diff --git a/api/.env.example b/api/.env.example index 833d83797d..f645ba7bf0 100644 --- a/api/.env.example +++ b/api/.env.example @@ -657,6 +657,7 @@ PLUGIN_REMOTE_INSTALL_PORT=5003 PLUGIN_REMOTE_INSTALL_HOST=localhost PLUGIN_MAX_PACKAGE_SIZE=15728640 PLUGIN_MODEL_SCHEMA_CACHE_TTL=3600 +PLUGIN_MODEL_PROVIDERS_CACHE_TTL=86400 INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1 # Marketplace configuration diff --git a/api/commands/plugin.py b/api/commands/plugin.py index 8ad2321b07..e1b3cf0fa1 100644 --- a/api/commands/plugin.py +++ b/api/commands/plugin.py @@ -11,6 +11,7 @@ from configs import dify_config from core.helper import encrypter from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.impl.plugin import PluginInstaller +from core.plugin.plugin_service import PluginService from core.tools.utils.system_encryption import encrypt_system_params from extensions.ext_database import db from models import Tenant @@ -20,7 +21,6 @@ from models.source import DataSourceApiKeyAuthBinding, DataSourceOauthBinding from models.tools import ToolOAuthSystemClient from services.plugin.data_migration import PluginDataMigration from services.plugin.plugin_migration import PluginMigration -from services.plugin.plugin_service import PluginService logger = logging.getLogger(__name__) diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index b99b2dc6b5..5083bb11f6 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -265,6 +265,11 @@ class PluginConfig(BaseSettings): default=60 * 60, ) + PLUGIN_MODEL_PROVIDERS_CACHE_TTL: PositiveInt = Field( + description="TTL in seconds for caching tenant plugin model providers in Redis", + default=60 * 60 * 24, + ) + PLUGIN_MAX_FILE_SIZE: PositiveInt = Field( description="Maximum allowed size (bytes) for plugin-generated files", default=50 * 1024 * 1024, diff --git a/api/configs/remote_settings_sources/apollo/__init__.py b/api/configs/remote_settings_sources/apollo/__init__.py index 55c14ead56..d017b86ad5 100644 --- a/api/configs/remote_settings_sources/apollo/__init__.py +++ b/api/configs/remote_settings_sources/apollo/__init__.py @@ -1,5 +1,5 @@ from collections.abc import Mapping -from typing import Any +from typing import Any, override from pydantic import Field from pydantic.fields import FieldInfo @@ -48,6 +48,7 @@ class ApolloSettingsSource(RemoteSettingsSource): self.namespace = configs["APOLLO_NAMESPACE"] self.remote_configs = self.client.get_all_dicts(self.namespace) + @override def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]: if not isinstance(self.remote_configs, dict): raise ValueError(f"remote configs is not dict, but {type(self.remote_configs)}") diff --git a/api/configs/remote_settings_sources/nacos/__init__.py b/api/configs/remote_settings_sources/nacos/__init__.py index f3e6306753..ddef8a5f49 100644 --- a/api/configs/remote_settings_sources/nacos/__init__.py +++ b/api/configs/remote_settings_sources/nacos/__init__.py @@ -1,7 +1,7 @@ import logging import os from collections.abc import Mapping -from typing import Any +from typing import Any, override from pydantic.fields import FieldInfo @@ -41,6 +41,7 @@ class NacosSettingsSource(RemoteSettingsSource): except Exception as e: raise RuntimeError(f"Failed to parse config: {e}") + @override def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]: field_value = self.remote_configs.get(field_name) if field_value is None: diff --git a/api/context/execution_context.py b/api/context/execution_context.py index e687dfc4b1..6fb3ca1971 100644 --- a/api/context/execution_context.py +++ b/api/context/execution_context.py @@ -10,7 +10,7 @@ import threading from abc import ABC, abstractmethod from collections.abc import Callable, Generator from contextlib import AbstractContextManager, contextmanager -from typing import Any, Protocol, final, runtime_checkable +from typing import Any, Protocol, final, override, runtime_checkable from pydantic import BaseModel @@ -133,10 +133,12 @@ class NullAppContext(AppContext): self._config = config or {} self._extensions: dict[str, Any] = {} + @override def get_config(self, key: str, default: Any = None) -> Any: """Get configuration value by key.""" return self._config.get(key, default) + @override def get_extension(self, name: str) -> Any: """Get extension by name.""" return self._extensions.get(name) @@ -146,6 +148,7 @@ class NullAppContext(AppContext): self._extensions[name] = extension @contextmanager + @override def enter(self) -> Generator[None, None, None]: """Enter null context (no-op).""" yield diff --git a/api/context/flask_app_context.py b/api/context/flask_app_context.py index eddd6448d8..1201bad041 100644 --- a/api/context/flask_app_context.py +++ b/api/context/flask_app_context.py @@ -6,7 +6,7 @@ import contextvars import threading from collections.abc import Generator from contextlib import contextmanager -from typing import Any, final +from typing import Any, final, override from flask import Flask, current_app, g @@ -30,15 +30,18 @@ class FlaskAppContext(AppContext): """ self._flask_app = flask_app + @override def get_config(self, key: str, default: Any = None) -> Any: """Get configuration value from Flask app config.""" return self._flask_app.config.get(key, default) + @override def get_extension(self, name: str) -> Any: """Get Flask extension by name.""" return self._flask_app.extensions.get(name) @contextmanager + @override def enter(self) -> Generator[None, None, None]: """Enter Flask app context.""" with self._flask_app.app_context(): diff --git a/api/controllers/console/agent/roster.py b/api/controllers/console/agent/roster.py index 3334f1bb2d..daa0d97496 100644 --- a/api/controllers/console/agent/roster.py +++ b/api/controllers/console/agent/roster.py @@ -1,3 +1,5 @@ +from uuid import UUID + from flask import request from flask_restx import Resource from pydantic import BaseModel, Field @@ -80,7 +82,7 @@ class AgentRosterDetailApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, agent_id): + def get(self, agent_id: UUID): _, tenant_id = current_account_with_tenant() return _agent_roster_service().get_roster_agent_detail(tenant_id=tenant_id, agent_id=str(agent_id)) @@ -89,7 +91,7 @@ class AgentRosterDetailApi(Resource): @login_required @account_initialization_required @edit_permission_required - def patch(self, agent_id): + def patch(self, agent_id: UUID): account, tenant_id = current_account_with_tenant() payload = RosterAgentUpdatePayload.model_validate(console_ns.payload or {}) return _agent_roster_service().update_roster_agent( @@ -100,7 +102,7 @@ class AgentRosterDetailApi(Resource): @login_required @account_initialization_required @edit_permission_required - def delete(self, agent_id): + def delete(self, agent_id: UUID): account, tenant_id = current_account_with_tenant() _agent_roster_service().archive_roster_agent(tenant_id=tenant_id, agent_id=str(agent_id), account_id=account.id) return "", 204 @@ -111,7 +113,7 @@ class AgentRosterVersionsApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, agent_id): + def get(self, agent_id: UUID): _, tenant_id = current_account_with_tenant() return {"data": _agent_roster_service().list_agent_versions(tenant_id=tenant_id, agent_id=str(agent_id))} @@ -121,7 +123,7 @@ class AgentRosterVersionDetailApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, agent_id, version_id): + def get(self, agent_id: UUID, version_id: UUID): _, tenant_id = current_account_with_tenant() return _agent_roster_service().get_agent_version_detail( tenant_id=tenant_id, diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py index aca22d5c5a..133c57d34d 100644 --- a/api/controllers/console/apikey.py +++ b/api/controllers/console/apikey.py @@ -1,4 +1,5 @@ from datetime import datetime +from uuid import UUID import flask_restx from flask_restx import Resource @@ -155,7 +156,7 @@ class AppApiKeyListResource(BaseApiKeyListResource): @console_ns.doc(description="Get all API keys for an app") @console_ns.doc(params={"resource_id": "App ID"}) @console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__]) - def get(self, resource_id): # type: ignore + def get(self, resource_id: UUID): """Get all API keys for an app""" return super().get(resource_id) @@ -164,7 +165,7 @@ class AppApiKeyListResource(BaseApiKeyListResource): @console_ns.doc(params={"resource_id": "App ID"}) @console_ns.response(201, "API key created successfully", console_ns.models[ApiKeyItem.__name__]) @console_ns.response(400, "Maximum keys exceeded") - def post(self, resource_id): # type: ignore + def post(self, resource_id: UUID): """Create a new API key for an app""" return super().post(resource_id) @@ -180,9 +181,9 @@ class AppApiKeyResource(BaseApiKeyResource): @console_ns.doc(description="Delete an API key for an app") @console_ns.doc(params={"resource_id": "App ID", "api_key_id": "API key ID"}) @console_ns.response(204, "API key deleted successfully") - def delete(self, resource_id, api_key_id): + def delete(self, resource_id: UUID, api_key_id: UUID): """Delete an API key for an app""" - return super().delete(resource_id, api_key_id) + return super().delete(str(resource_id), str(api_key_id)) resource_type = ApiTokenType.APP resource_model = App @@ -195,7 +196,7 @@ class DatasetApiKeyListResource(BaseApiKeyListResource): @console_ns.doc(description="Get all API keys for a dataset") @console_ns.doc(params={"resource_id": "Dataset ID"}) @console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__]) - def get(self, resource_id): # type: ignore + def get(self, resource_id: UUID): """Get all API keys for a dataset""" return super().get(resource_id) @@ -204,7 +205,7 @@ class DatasetApiKeyListResource(BaseApiKeyListResource): @console_ns.doc(params={"resource_id": "Dataset ID"}) @console_ns.response(201, "API key created successfully", console_ns.models[ApiKeyItem.__name__]) @console_ns.response(400, "Maximum keys exceeded") - def post(self, resource_id): # type: ignore + def post(self, resource_id: UUID): """Create a new API key for a dataset""" return super().post(resource_id) @@ -220,9 +221,9 @@ class DatasetApiKeyResource(BaseApiKeyResource): @console_ns.doc(description="Delete an API key for a dataset") @console_ns.doc(params={"resource_id": "Dataset ID", "api_key_id": "API key ID"}) @console_ns.response(204, "API key deleted successfully") - def delete(self, resource_id, api_key_id): + def delete(self, resource_id: UUID, api_key_id: UUID): """Delete an API key for a dataset""" - return super().delete(resource_id, api_key_id) + return super().delete(str(resource_id), str(api_key_id)) resource_type = ApiTokenType.DATASET resource_model = Dataset diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py index bf8b57685f..d066177df3 100644 --- a/api/controllers/console/app/annotation.py +++ b/api/controllers/console/app/annotation.py @@ -159,13 +159,15 @@ class AppAnnotationSettingUpdateApi(Resource): @login_required @account_initialization_required @edit_permission_required - def post(self, app_id: UUID, annotation_setting_id): - annotation_setting_id = str(annotation_setting_id) + def post(self, app_id: UUID, annotation_setting_id: UUID): + annotation_setting_id_str = str(annotation_setting_id) args = AnnotationSettingUpdatePayload.model_validate(console_ns.payload) setting_args: UpdateAnnotationSettingArgs = {"score_threshold": args.score_threshold} - result = AppAnnotationService.update_app_annotation_setting(str(app_id), annotation_setting_id, setting_args) + result = AppAnnotationService.update_app_annotation_setting( + str(app_id), annotation_setting_id_str, setting_args + ) return result, 200 @@ -181,9 +183,9 @@ class AnnotationReplyActionStatusApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("annotation") @edit_permission_required - def get(self, app_id: UUID, job_id, action): - job_id = str(job_id) - app_annotation_job_key = f"{action}_app_annotation_job_{str(job_id)}" + def get(self, app_id: UUID, job_id: UUID, action: str): + job_id_str = str(job_id) + app_annotation_job_key = f"{action}_app_annotation_job_{job_id_str}" cache_result = redis_client.get(app_annotation_job_key) if cache_result is None: raise ValueError("The job does not exist.") @@ -191,10 +193,10 @@ class AnnotationReplyActionStatusApi(Resource): job_status = cache_result.decode() error_msg = "" if job_status == "error": - app_annotation_error_key = f"{action}_app_annotation_error_{str(job_id)}" + app_annotation_error_key = f"{action}_app_annotation_error_{job_id_str}" error_msg = redis_client.get(app_annotation_error_key).decode() - return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200 + return {"job_id": job_id_str, "job_status": job_status, "error_msg": error_msg}, 200 @console_ns.route("/apps//annotations") diff --git a/api/controllers/console/app/app_import.py b/api/controllers/console/app/app_import.py index b653016319..e8e8234ac4 100644 --- a/api/controllers/console/app/app_import.py +++ b/api/controllers/console/app/app_import.py @@ -97,7 +97,7 @@ class AppImportConfirmApi(Resource): @login_required @account_initialization_required @edit_permission_required - def post(self, import_id): + def post(self, import_id: str): # Check user role first current_user, _ = current_account_with_tenant() diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index 817623bf7f..fddfe2f4bc 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -131,7 +131,7 @@ class CompletionMessageStopApi(Resource): @login_required @account_initialization_required @get_app_model(mode=AppMode.COMPLETION) - def post(self, app_model, task_id): + def post(self, app_model, task_id: str): if not isinstance(current_user, Account): raise ValueError("current_user must be an Account instance") @@ -212,7 +212,7 @@ class ChatMessageStopApi(Resource): @login_required @account_initialization_required @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) - def post(self, app_model, task_id): + def post(self, app_model, task_id: str): if not isinstance(current_user, Account): raise ValueError("current_user must be an Account instance") diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index 1978018366..6216a7bcbe 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -1,4 +1,5 @@ from typing import Literal +from uuid import UUID import sqlalchemy as sa from flask import abort, request @@ -164,10 +165,10 @@ class CompletionConversationDetailApi(Resource): @account_initialization_required @get_app_model(mode=AppMode.COMPLETION) @edit_permission_required - def get(self, app_model, conversation_id): - conversation_id = str(conversation_id) + def get(self, app_model, conversation_id: UUID): + conversation_id_str = str(conversation_id) return ConversationMessageDetailResponse.model_validate( - _get_conversation(app_model, conversation_id), from_attributes=True + _get_conversation(app_model, conversation_id_str), from_attributes=True ).model_dump(mode="json") @console_ns.doc("delete_completion_conversation") @@ -181,12 +182,12 @@ class CompletionConversationDetailApi(Resource): @account_initialization_required @get_app_model(mode=AppMode.COMPLETION) @edit_permission_required - def delete(self, app_model, conversation_id): + def delete(self, app_model, conversation_id: UUID): current_user, _ = current_account_with_tenant() - conversation_id = str(conversation_id) + conversation_id_str = str(conversation_id) try: - ConversationService.delete(app_model, conversation_id, current_user) + ConversationService.delete(app_model, conversation_id_str, current_user) except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") @@ -317,10 +318,10 @@ class ChatConversationDetailApi(Resource): @account_initialization_required @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) @edit_permission_required - def get(self, app_model, conversation_id): - conversation_id = str(conversation_id) + def get(self, app_model, conversation_id: UUID): + conversation_id_str = str(conversation_id) return ConversationDetailResponse.model_validate( - _get_conversation(app_model, conversation_id), from_attributes=True + _get_conversation(app_model, conversation_id_str), from_attributes=True ).model_dump(mode="json") @console_ns.doc("delete_chat_conversation") @@ -334,12 +335,12 @@ class ChatConversationDetailApi(Resource): @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) @account_initialization_required @edit_permission_required - def delete(self, app_model, conversation_id): + def delete(self, app_model, conversation_id: UUID): current_user, _ = current_account_with_tenant() - conversation_id = str(conversation_id) + conversation_id_str = str(conversation_id) try: - ConversationService.delete(app_model, conversation_id, current_user) + ConversationService.delete(app_model, conversation_id_str, current_user) except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") diff --git a/api/controllers/console/app/mcp_server.py b/api/controllers/console/app/mcp_server.py index 13f6e098ba..a5259527ea 100644 --- a/api/controllers/console/app/mcp_server.py +++ b/api/controllers/console/app/mcp_server.py @@ -1,6 +1,7 @@ import json from datetime import datetime from typing import Any +from uuid import UUID from flask_restx import Resource from pydantic import BaseModel, Field, field_validator @@ -162,7 +163,7 @@ class AppMCPServerRefreshController(Resource): @login_required @account_initialization_required @edit_permission_required - def get(self, server_id): + def get(self, server_id: UUID): _, current_tenant_id = current_account_with_tenant() server = db.session.scalar( select(AppMCPServer) diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index 06afe15548..faa1e0fcda 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -1,6 +1,7 @@ import logging from datetime import datetime from typing import Literal +from uuid import UUID from flask import request from flask_restx import Resource @@ -336,13 +337,13 @@ class MessageSuggestedQuestionApi(Resource): @login_required @account_initialization_required @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) - def get(self, app_model, message_id): + def get(self, app_model, message_id: UUID): current_user, _ = current_account_with_tenant() - message_id = str(message_id) + message_id_str = str(message_id) try: questions = MessageService.get_suggested_questions_after_answer( - app_model=app_model, message_id=message_id, user=current_user, invoke_from=InvokeFrom.DEBUGGER + app_model=app_model, message_id=message_id_str, user=current_user, invoke_from=InvokeFrom.DEBUGGER ) except MessageNotExistsError: raise NotFound("Message not found") @@ -417,10 +418,10 @@ class MessageApi(Resource): @login_required @account_initialization_required def get(self, app_model, message_id: str): - message_id = str(message_id) + message_id_str = str(message_id) message = db.session.scalar( - select(Message).where(Message.id == message_id, Message.app_id == app_model.id).limit(1) + select(Message).where(Message.id == message_id_str, Message.app_id == app_model.id).limit(1) ) if not message: diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py index 97d2003209..2d48b59de2 100644 --- a/api/controllers/console/app/workflow_run.py +++ b/api/controllers/console/app/workflow_run.py @@ -1,5 +1,6 @@ from datetime import UTC, datetime, timedelta from typing import Literal, cast +from uuid import UUID from flask import request from flask_restx import Resource @@ -367,14 +368,14 @@ class WorkflowRunDetailApi(Resource): @login_required @account_initialization_required @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) - def get(self, app_model: App, run_id): + def get(self, app_model: App, run_id: UUID): """ Get workflow run detail """ - run_id = str(run_id) + run_id_str = str(run_id) workflow_run_service = WorkflowRunService() - workflow_run = workflow_run_service.get_workflow_run(app_model=app_model, run_id=run_id) + workflow_run = workflow_run_service.get_workflow_run(app_model=app_model, run_id=run_id_str) if workflow_run is None: raise NotFoundError("Workflow run not found") @@ -396,17 +397,17 @@ class WorkflowRunNodeExecutionListApi(Resource): @login_required @account_initialization_required @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) - def get(self, app_model: App, run_id): + def get(self, app_model: App, run_id: UUID): """ Get workflow run node execution list """ - run_id = str(run_id) + run_id_str = str(run_id) workflow_run_service = WorkflowRunService() user = cast("Account | EndUser", current_user) node_executions = workflow_run_service.get_workflow_run_node_executions( app_model=app_model, - run_id=run_id, + run_id=run_id_str, user=user, ) diff --git a/api/controllers/console/auth/data_source_bearer_auth.py b/api/controllers/console/auth/data_source_bearer_auth.py index 3f0650389f..7544c4dbdc 100644 --- a/api/controllers/console/auth/data_source_bearer_auth.py +++ b/api/controllers/console/auth/data_source_bearer_auth.py @@ -1,3 +1,5 @@ +from uuid import UUID + from flask_restx import Resource from pydantic import BaseModel, Field @@ -87,10 +89,10 @@ class ApiKeyAuthDataSourceBindingDelete(Resource): @account_initialization_required @is_admin_or_owner_required @console_ns.response(204, "Binding deleted successfully") - def delete(self, binding_id): + def delete(self, binding_id: UUID): # The role of the current user in the table must be admin or owner _, current_tenant_id = current_account_with_tenant() - ApiKeyAuthService.delete_provider_auth(current_tenant_id, binding_id) + ApiKeyAuthService.delete_provider_auth(current_tenant_id, str(binding_id)) return "", 204 diff --git a/api/controllers/console/auth/data_source_oauth.py b/api/controllers/console/auth/data_source_oauth.py index 3a3278ec9d..997dac9210 100644 --- a/api/controllers/console/auth/data_source_oauth.py +++ b/api/controllers/console/auth/data_source_oauth.py @@ -1,4 +1,5 @@ import logging +from uuid import UUID import httpx from flask import current_app, redirect, request @@ -158,16 +159,15 @@ class OAuthDataSourceSync(Resource): @setup_required @login_required @account_initialization_required - def get(self, provider, binding_id): - provider = str(provider) - binding_id = str(binding_id) + def get(self, provider: str, binding_id: UUID): + binding_id_str = str(binding_id) OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers() with current_app.app_context(): oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider) if not oauth_provider: return {"error": "Invalid provider"}, 400 try: - oauth_provider.sync_data_source(binding_id) + oauth_provider.sync_data_source(binding_id_str) except httpx.HTTPStatusError as e: logger.exception( "An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index f81adb0313..e3a6d20b2b 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -1,6 +1,7 @@ import json from collections.abc import Generator from typing import Any, Literal, cast +from uuid import UUID from flask import request from flask_restx import Resource, fields, marshal_with @@ -293,7 +294,7 @@ class DataSourceNotionApi(Resource): @login_required @account_initialization_required @console_ns.response(200, "Success", console_ns.models[TextContentResponse.__name__]) - def get(self, page_id, page_type): + def get(self, page_id: UUID, page_type: str): _, current_tenant_id = current_account_with_tenant() query = DataSourceNotionPreviewQuery.model_validate(request.args.to_dict()) @@ -306,11 +307,11 @@ class DataSourceNotionApi(Resource): plugin_id="langgenius/notion_datasource", ) - page_id = str(page_id) + page_id_str = str(page_id) extractor = NotionExtractor( notion_workspace_id="", - notion_obj_id=page_id, + notion_obj_id=page_id_str, notion_page_type=page_type, notion_access_token=credential.get("integration_secret"), tenant_id=current_tenant_id, @@ -367,7 +368,7 @@ class DataSourceNotionDatasetSyncApi(Resource): @login_required @account_initialization_required @console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__]) - def get(self, dataset_id): + def get(self, dataset_id: UUID): dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) if dataset is None: @@ -385,7 +386,7 @@ class DataSourceNotionDocumentSyncApi(Resource): @login_required @account_initialization_required @console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__]) - def get(self, dataset_id, document_id): + def get(self, dataset_id: UUID, document_id: UUID): dataset_id_str = str(dataset_id) document_id_str = str(document_id) dataset = DatasetService.get_dataset(dataset_id_str) diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 3cc1e6b028..8e453f96dd 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -1,5 +1,6 @@ from datetime import datetime from typing import Any +from uuid import UUID from flask import request from flask_restx import Resource @@ -511,7 +512,7 @@ class DatasetApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, dataset_id): + def get(self, dataset_id: UUID): current_user, current_tenant_id = current_account_with_tenant() dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) @@ -565,7 +566,7 @@ class DatasetApi(Resource): @login_required @account_initialization_required @cloud_edition_billing_rate_limit_check("knowledge") - def patch(self, dataset_id): + def patch(self, dataset_id: UUID): dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) if dataset is None: @@ -613,7 +614,7 @@ class DatasetApi(Resource): @account_initialization_required @cloud_edition_billing_rate_limit_check("knowledge") @console_ns.response(204, "Dataset deleted successfully") - def delete(self, dataset_id): + def delete(self, dataset_id: UUID): dataset_id_str = str(dataset_id) current_user, _ = current_account_with_tenant() @@ -643,7 +644,7 @@ class DatasetUseCheckApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, dataset_id): + def get(self, dataset_id: UUID): dataset_id_str = str(dataset_id) dataset_is_using = DatasetService.dataset_use_check(dataset_id_str) @@ -663,7 +664,7 @@ class DatasetQueryApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, dataset_id): + def get(self, dataset_id: UUID): current_user, _ = current_account_with_tenant() dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) @@ -803,7 +804,7 @@ class DatasetRelatedAppListApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, dataset_id): + def get(self, dataset_id: UUID): current_user, _ = current_account_with_tenant() dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) @@ -839,11 +840,11 @@ class DatasetIndexingStatusApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, dataset_id): + def get(self, dataset_id: UUID): _, current_tenant_id = current_account_with_tenant() - dataset_id = str(dataset_id) + dataset_id_str = str(dataset_id) documents = db.session.scalars( - select(Document).where(Document.dataset_id == dataset_id, Document.tenant_id == current_tenant_id) + select(Document).where(Document.dataset_id == dataset_id_str, Document.tenant_id == current_tenant_id) ).all() documents_status = [] for document in documents: @@ -951,15 +952,15 @@ class DatasetApiDeleteApi(Resource): @login_required @is_admin_or_owner_required @account_initialization_required - def delete(self, api_key_id): + def delete(self, api_key_id: UUID): _, current_tenant_id = current_account_with_tenant() - api_key_id = str(api_key_id) + api_key_id_str = str(api_key_id) key = db.session.scalar( select(ApiToken) .where( ApiToken.tenant_id == current_tenant_id, ApiToken.type == self.resource_type, - ApiToken.id == api_key_id, + ApiToken.id == api_key_id_str, ) .limit(1) ) @@ -984,7 +985,7 @@ class DatasetEnableApiApi(Resource): @login_required @account_initialization_required @console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__]) - def post(self, dataset_id, status): + def post(self, dataset_id: UUID, status: str): dataset_id_str = str(dataset_id) DatasetService.update_dataset_api_status(dataset_id_str, status == "enable") @@ -1036,7 +1037,7 @@ class DatasetRetrievalSettingMockApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, vector_type): + def get(self, vector_type: str): return dump_response( RetrievalSettingResponse, _get_retrieval_methods_by_vector_type(vector_type, is_mock=True), @@ -1053,7 +1054,7 @@ class DatasetErrorDocs(Resource): @setup_required @login_required @account_initialization_required - def get(self, dataset_id): + def get(self, dataset_id: UUID): dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) if dataset is None: @@ -1078,7 +1079,7 @@ class DatasetPermissionUserListApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, dataset_id): + def get(self, dataset_id: UUID): current_user, _ = current_account_with_tenant() dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) @@ -1108,7 +1109,7 @@ class DatasetAutoDisableLogApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, dataset_id): + def get(self, dataset_id: UUID): dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) if dataset is None: diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index fabd61e6b0..d387834e9b 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -5,6 +5,7 @@ from collections.abc import Sequence from contextlib import ExitStack from datetime import datetime from typing import Any, Literal, cast +from uuid import UUID import sqlalchemy as sa from flask import request, send_file @@ -315,9 +316,9 @@ class DatasetDocumentListApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, dataset_id): + def get(self, dataset_id: UUID): current_user, current_tenant_id = current_account_with_tenant() - dataset_id = str(dataset_id) + dataset_id_str = str(dataset_id) raw_args = request.args.to_dict() param = DocumentDatasetListParam.model_validate(raw_args) page = param.page @@ -342,7 +343,7 @@ class DatasetDocumentListApi(Resource): ) except (ArgumentTypeError, ValueError, Exception): fetch = False - dataset = DatasetService.get_dataset(dataset_id) + dataset = DatasetService.get_dataset(dataset_id_str) if not dataset: raise NotFound("Dataset not found.") @@ -351,7 +352,7 @@ class DatasetDocumentListApi(Resource): except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) - query = select(Document).where(Document.dataset_id == str(dataset_id), Document.tenant_id == current_tenant_id) + query = select(Document).where(Document.dataset_id == dataset_id_str, Document.tenant_id == current_tenant_id) if status: query = DocumentService.apply_display_status_filter(query, status) @@ -372,7 +373,7 @@ class DatasetDocumentListApi(Resource): sa.select( DocumentSegment.document_id, sa.func.sum(DocumentSegment.hit_count).label("total_hit_count") ) - .where(DocumentSegment.dataset_id == str(dataset_id)) + .where(DocumentSegment.dataset_id == dataset_id_str) .group_by(DocumentSegment.document_id) .subquery() ) @@ -444,11 +445,11 @@ class DatasetDocumentListApi(Resource): @cloud_edition_billing_rate_limit_check("knowledge") @console_ns.expect(console_ns.models[KnowledgeConfig.__name__]) @console_ns.response(200, "Documents created successfully", console_ns.models[DatasetAndDocumentResponse.__name__]) - def post(self, dataset_id): + def post(self, dataset_id: UUID): current_user, _ = current_account_with_tenant() - dataset_id = str(dataset_id) + dataset_id_str = str(dataset_id) - dataset = DatasetService.get_dataset(dataset_id) + dataset = DatasetService.get_dataset(dataset_id_str) if not dataset: raise NotFound("Dataset not found.") @@ -472,7 +473,7 @@ class DatasetDocumentListApi(Resource): try: documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, current_user) - dataset = DatasetService.get_dataset(dataset_id) + dataset = DatasetService.get_dataset(dataset_id_str) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) @@ -490,9 +491,9 @@ class DatasetDocumentListApi(Resource): @account_initialization_required @cloud_edition_billing_rate_limit_check("knowledge") @console_ns.response(204, "Documents deleted successfully") - def delete(self, dataset_id): - dataset_id = str(dataset_id) - dataset = DatasetService.get_dataset(dataset_id) + def delete(self, dataset_id: UUID): + dataset_id_str = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id_str) if dataset is None: raise NotFound("Dataset not found.") # check user's model setting @@ -582,11 +583,11 @@ class DocumentIndexingEstimateApi(DocumentResource): @setup_required @login_required @account_initialization_required - def get(self, dataset_id, document_id): + def get(self, dataset_id: UUID, document_id: UUID): _, current_tenant_id = current_account_with_tenant() - dataset_id = str(dataset_id) - document_id = str(document_id) - document = self.get_document(dataset_id, document_id) + dataset_id_str = str(dataset_id) + document_id_str = str(document_id) + document = self.get_document(dataset_id_str, document_id_str) if document.indexing_status in {IndexingStatus.COMPLETED, IndexingStatus.ERROR}: raise DocumentAlreadyFinishedError() @@ -624,7 +625,7 @@ class DocumentIndexingEstimateApi(DocumentResource): data_process_rule_dict, document.doc_form, "English", - dataset_id, + dataset_id_str, ) return estimate_response.model_dump(), 200 except LLMBadRequestError: @@ -647,11 +648,10 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): @setup_required @login_required @account_initialization_required - def get(self, dataset_id, batch): + def get(self, dataset_id: UUID, batch: str): _, current_tenant_id = current_account_with_tenant() - dataset_id = str(dataset_id) - batch = str(batch) - documents = self.get_batch_documents(dataset_id, batch) + dataset_id_str = str(dataset_id) + documents = self.get_batch_documents(dataset_id_str, batch) if not documents: return {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}, 200 data_process_rule = documents[0].dataset_process_rule @@ -725,7 +725,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): data_process_rule_dict, document.doc_form, "English", - dataset_id, + dataset_id_str, ) return response.model_dump(), 200 except LLMBadRequestError: @@ -745,10 +745,9 @@ class DocumentBatchIndexingStatusApi(DocumentResource): @setup_required @login_required @account_initialization_required - def get(self, dataset_id, batch): - dataset_id = str(dataset_id) - batch = str(batch) - documents = self.get_batch_documents(dataset_id, batch) + def get(self, dataset_id: UUID, batch: str): + dataset_id_str = str(dataset_id) + documents = self.get_batch_documents(dataset_id_str, batch) documents_status = [] for document in documents: completed_segments = ( @@ -800,16 +799,16 @@ class DocumentIndexingStatusApi(DocumentResource): @setup_required @login_required @account_initialization_required - def get(self, dataset_id, document_id): - dataset_id = str(dataset_id) - document_id = str(document_id) - document = self.get_document(dataset_id, document_id) + def get(self, dataset_id: UUID, document_id: UUID): + dataset_id_str = str(dataset_id) + document_id_str = str(document_id) + document = self.get_document(dataset_id_str, document_id_str) completed_segments = ( db.session.scalar( select(func.count(DocumentSegment.id)).where( DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document_id), + DocumentSegment.document_id == str(document_id_str), DocumentSegment.status != SegmentStatus.RE_SEGMENT, ) ) @@ -818,7 +817,7 @@ class DocumentIndexingStatusApi(DocumentResource): total_segments = ( db.session.scalar( select(func.count(DocumentSegment.id)).where( - DocumentSegment.document_id == str(document_id), + DocumentSegment.document_id == str(document_id_str), DocumentSegment.status != SegmentStatus.RE_SEGMENT, ) ) @@ -861,10 +860,10 @@ class DocumentApi(DocumentResource): @setup_required @login_required @account_initialization_required - def get(self, dataset_id, document_id): - dataset_id = str(dataset_id) - document_id = str(document_id) - document = self.get_document(dataset_id, document_id) + def get(self, dataset_id: UUID, document_id: UUID): + dataset_id_str = str(dataset_id) + document_id_str = str(document_id) + document = self.get_document(dataset_id_str, document_id_str) metadata = request.args.get("metadata", "all") if metadata not in self.METADATA_CHOICES: @@ -873,7 +872,7 @@ class DocumentApi(DocumentResource): if metadata == "only": response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details} elif metadata == "without": - dataset_process_rules = DatasetService.get_process_rules(dataset_id) + dataset_process_rules = DatasetService.get_process_rules(dataset_id_str) document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {} response = { "id": document.id, @@ -907,7 +906,7 @@ class DocumentApi(DocumentResource): "need_summary": document.need_summary if document.need_summary is not None else False, } else: - dataset_process_rules = DatasetService.get_process_rules(dataset_id) + dataset_process_rules = DatasetService.get_process_rules(dataset_id_str) document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {} response = { "id": document.id, @@ -950,16 +949,16 @@ class DocumentApi(DocumentResource): @account_initialization_required @cloud_edition_billing_rate_limit_check("knowledge") @console_ns.response(204, "Document deleted successfully") - def delete(self, dataset_id, document_id): - dataset_id = str(dataset_id) - document_id = str(document_id) - dataset = DatasetService.get_dataset(dataset_id) + def delete(self, dataset_id: UUID, document_id: UUID): + dataset_id_str = str(dataset_id) + document_id_str = str(document_id) + dataset = DatasetService.get_dataset(dataset_id_str) if dataset is None: raise NotFound("Dataset not found.") # check user's model setting DatasetService.check_dataset_model_setting(dataset) - document = self.get_document(dataset_id, document_id) + document = self.get_document(dataset_id_str, document_id_str) try: DocumentService.delete_document(document) @@ -1003,10 +1002,10 @@ class DocumentBatchDownloadZipApi(DocumentResource): payload = DocumentBatchDownloadZipPayload.model_validate(console_ns.payload or {}) current_user, current_tenant_id = current_account_with_tenant() - dataset_id = str(dataset_id) + dataset_id_str = str(dataset_id) document_ids: list[str] = [str(document_id) for document_id in payload.document_ids] upload_files, download_name = DocumentService.prepare_document_batch_download_zip( - dataset_id=dataset_id, + dataset_id=dataset_id_str, document_ids=document_ids, tenant_id=current_tenant_id, current_user=current_user, @@ -1044,11 +1043,11 @@ class DocumentProcessingApi(DocumentResource): @login_required @account_initialization_required @cloud_edition_billing_rate_limit_check("knowledge") - def patch(self, dataset_id, document_id, action: Literal["pause", "resume"]): + def patch(self, dataset_id: UUID, document_id: UUID, action: Literal["pause", "resume"]): current_user, _ = current_account_with_tenant() - dataset_id = str(dataset_id) - document_id = str(document_id) - document = self.get_document(dataset_id, document_id) + dataset_id_str = str(dataset_id) + document_id_str = str(document_id) + document = self.get_document(dataset_id_str, document_id_str) # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor if not current_user.is_dataset_editor: @@ -1092,11 +1091,11 @@ class DocumentMetadataApi(DocumentResource): @setup_required @login_required @account_initialization_required - def put(self, dataset_id, document_id): + def put(self, dataset_id: UUID, document_id: UUID): current_user, _ = current_account_with_tenant() - dataset_id = str(dataset_id) - document_id = str(document_id) - document = self.get_document(dataset_id, document_id) + dataset_id_str = str(dataset_id) + document_id_str = str(document_id) + document = self.get_document(dataset_id_str, document_id_str) req_data = DocumentMetadataUpdatePayload.model_validate(request.get_json() or {}) @@ -1141,10 +1140,10 @@ class DocumentStatusApi(DocumentResource): @cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_rate_limit_check("knowledge") @console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__]) - def patch(self, dataset_id, action: Literal["enable", "disable", "archive", "un_archive"]): + def patch(self, dataset_id: UUID, action: Literal["enable", "disable", "archive", "un_archive"]): current_user, _ = current_account_with_tenant() - dataset_id = str(dataset_id) - dataset = DatasetService.get_dataset(dataset_id) + dataset_id_str = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id_str) if dataset is None: raise NotFound("Dataset not found.") @@ -1179,16 +1178,16 @@ class DocumentPauseApi(DocumentResource): @account_initialization_required @cloud_edition_billing_rate_limit_check("knowledge") @console_ns.response(204, "Document paused successfully") - def patch(self, dataset_id, document_id): + def patch(self, dataset_id: UUID, document_id: UUID): """pause document.""" - dataset_id = str(dataset_id) - document_id = str(document_id) + dataset_id_str = str(dataset_id) + document_id_str = str(document_id) - dataset = DatasetService.get_dataset(dataset_id) + dataset = DatasetService.get_dataset(dataset_id_str) if not dataset: raise NotFound("Dataset not found.") - document = DocumentService.get_document(dataset.id, document_id) + document = DocumentService.get_document(dataset.id, document_id_str) # 404 if document not found if document is None: @@ -1214,14 +1213,14 @@ class DocumentRecoverApi(DocumentResource): @account_initialization_required @cloud_edition_billing_rate_limit_check("knowledge") @console_ns.response(204, "Document resumed successfully") - def patch(self, dataset_id, document_id): + def patch(self, dataset_id: UUID, document_id: UUID): """recover document.""" - dataset_id = str(dataset_id) - document_id = str(document_id) - dataset = DatasetService.get_dataset(dataset_id) + dataset_id_str = str(dataset_id) + document_id_str = str(document_id) + dataset = DatasetService.get_dataset(dataset_id_str) if not dataset: raise NotFound("Dataset not found.") - document = DocumentService.get_document(dataset.id, document_id) + document = DocumentService.get_document(dataset.id, document_id_str) # 404 if document not found if document is None: @@ -1247,11 +1246,11 @@ class DocumentRetryApi(DocumentResource): @cloud_edition_billing_rate_limit_check("knowledge") @console_ns.expect(console_ns.models[DocumentRetryPayload.__name__]) @console_ns.response(204, "Documents retry started successfully") - def post(self, dataset_id): + def post(self, dataset_id: UUID): """retry document.""" payload = DocumentRetryPayload.model_validate(console_ns.payload or {}) - dataset_id = str(dataset_id) - dataset = DatasetService.get_dataset(dataset_id) + dataset_id_str = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id_str) retry_documents = [] if not dataset: raise NotFound("Dataset not found.") @@ -1277,7 +1276,7 @@ class DocumentRetryApi(DocumentResource): logger.exception("Failed to retry document, document id: %s", document_id) continue # retry document - DocumentService.retry_document(dataset_id, retry_documents) + DocumentService.retry_document(dataset_id_str, retry_documents) return "", 204 @@ -1289,7 +1288,7 @@ class DocumentRenameApi(DocumentResource): @account_initialization_required @console_ns.response(200, "Document renamed successfully", console_ns.models[DocumentResponse.__name__]) @console_ns.expect(console_ns.models[DocumentRenamePayload.__name__]) - def post(self, dataset_id, document_id): + def post(self, dataset_id: UUID, document_id: UUID): # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator current_user, _ = current_account_with_tenant() if not current_user.is_dataset_editor: @@ -1301,7 +1300,7 @@ class DocumentRenameApi(DocumentResource): payload = DocumentRenamePayload.model_validate(console_ns.payload or {}) try: - document = DocumentService.rename_document(dataset_id, document_id, payload.name) + document = DocumentService.rename_document(str(dataset_id), str(document_id), payload.name) except services.errors.document.DocumentIndexingError: raise DocumentIndexingError("Cannot delete document during indexing.") @@ -1314,15 +1313,15 @@ class WebsiteDocumentSyncApi(DocumentResource): @login_required @account_initialization_required @console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__]) - def get(self, dataset_id, document_id): + def get(self, dataset_id: UUID, document_id: UUID): """sync website document.""" _, current_tenant_id = current_account_with_tenant() - dataset_id = str(dataset_id) - dataset = DatasetService.get_dataset(dataset_id) + dataset_id_str = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id_str) if not dataset: raise NotFound("Dataset not found.") - document_id = str(document_id) - document = DocumentService.get_document(dataset.id, document_id) + document_id_str = str(document_id) + document = DocumentService.get_document(dataset.id, document_id_str) if not document: raise NotFound("Document not found.") if document.tenant_id != current_tenant_id: @@ -1333,7 +1332,7 @@ class WebsiteDocumentSyncApi(DocumentResource): if DocumentService.check_archived(document): raise ArchivedDocumentImmutableError() # sync document - DocumentService.sync_website_document(dataset_id, document) + DocumentService.sync_website_document(dataset_id_str, document) return {"result": "success"}, 200 @@ -1343,19 +1342,19 @@ class DocumentPipelineExecutionLogApi(DocumentResource): @setup_required @login_required @account_initialization_required - def get(self, dataset_id, document_id): - dataset_id = str(dataset_id) - document_id = str(document_id) + def get(self, dataset_id: UUID, document_id: UUID): + dataset_id_str = str(dataset_id) + document_id_str = str(document_id) - dataset = DatasetService.get_dataset(dataset_id) + dataset = DatasetService.get_dataset(dataset_id_str) if not dataset: raise NotFound("Dataset not found.") - document = DocumentService.get_document(dataset.id, document_id) + document = DocumentService.get_document(dataset.id, document_id_str) if not document: raise NotFound("Document not found.") log = db.session.scalar( select(DocumentPipelineExecutionLog) - .where(DocumentPipelineExecutionLog.document_id == document_id) + .where(DocumentPipelineExecutionLog.document_id == document_id_str) .order_by(DocumentPipelineExecutionLog.created_at.desc()) .limit(1) ) @@ -1392,7 +1391,7 @@ class DocumentGenerateSummaryApi(Resource): @login_required @account_initialization_required @cloud_edition_billing_rate_limit_check("knowledge") - def post(self, dataset_id): + def post(self, dataset_id: UUID): """ Generate summary index for specified documents. @@ -1401,10 +1400,10 @@ class DocumentGenerateSummaryApi(Resource): then asynchronously generates summary indexes for the provided documents. """ current_user, _ = current_account_with_tenant() - dataset_id = str(dataset_id) + dataset_id_str = str(dataset_id) # Get dataset - dataset = DatasetService.get_dataset(dataset_id) + dataset = DatasetService.get_dataset(dataset_id_str) if not dataset: raise NotFound("Dataset not found.") @@ -1438,7 +1437,7 @@ class DocumentGenerateSummaryApi(Resource): raise ValueError("Summary index is not enabled for this dataset. Please enable it in the dataset settings.") # Verify all documents exist and belong to the dataset - documents = DocumentService.get_documents_by_ids(dataset_id, document_list) + documents = DocumentService.get_documents_by_ids(dataset_id_str, document_list) if len(documents) != len(document_list): found_ids = {doc.id for doc in documents} @@ -1452,7 +1451,7 @@ class DocumentGenerateSummaryApi(Resource): if documents_to_update: document_ids_to_update = [str(doc.id) for doc in documents_to_update] DocumentService.update_documents_need_summary( - dataset_id=dataset_id, + dataset_id=dataset_id_str, document_ids=document_ids_to_update, need_summary=True, ) @@ -1465,11 +1464,11 @@ class DocumentGenerateSummaryApi(Resource): continue # Dispatch async task - generate_summary_index_task.delay(dataset_id, document.id) + generate_summary_index_task.delay(dataset_id_str, document.id) logger.info( "Dispatched summary generation task for document %s in dataset %s", document.id, - dataset_id, + dataset_id_str, ) return {"result": "success"}, 200 @@ -1485,7 +1484,7 @@ class DocumentSummaryStatusApi(DocumentResource): @setup_required @login_required @account_initialization_required - def get(self, dataset_id, document_id): + def get(self, dataset_id: UUID, document_id: UUID): """ Get summary index generation status for a document. @@ -1499,11 +1498,11 @@ class DocumentSummaryStatusApi(DocumentResource): - summaries: List of summary records with status and content preview """ current_user, _ = current_account_with_tenant() - dataset_id = str(dataset_id) - document_id = str(document_id) + dataset_id_str = str(dataset_id) + document_id_str = str(document_id) # Get dataset - dataset = DatasetService.get_dataset(dataset_id) + dataset = DatasetService.get_dataset(dataset_id_str) if not dataset: raise NotFound("Dataset not found.") @@ -1517,8 +1516,8 @@ class DocumentSummaryStatusApi(DocumentResource): from services.summary_index_service import SummaryIndexService result = SummaryIndexService.get_document_summary_status_detail( - document_id=document_id, - dataset_id=dataset_id, + document_id=document_id_str, + dataset_id=dataset_id_str, ) return result, 200 diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index 1d3bc96c1b..38ad7dfdd1 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -1,4 +1,6 @@ import uuid +from typing import Literal +from uuid import UUID from flask import request from flask_restx import Resource, marshal @@ -113,12 +115,12 @@ class DatasetDocumentSegmentListApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, dataset_id, document_id): + def get(self, dataset_id: UUID, document_id: UUID): current_user, current_tenant_id = current_account_with_tenant() - dataset_id = str(dataset_id) - document_id = str(document_id) - dataset = DatasetService.get_dataset(dataset_id) + dataset_id_str = str(dataset_id) + document_id_str = str(document_id) + dataset = DatasetService.get_dataset(dataset_id_str) if not dataset: raise NotFound("Dataset not found.") @@ -127,7 +129,7 @@ class DatasetDocumentSegmentListApi(Resource): except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) - document = DocumentService.get_document(dataset_id, document_id) + document = DocumentService.get_document(dataset_id_str, document_id_str) if not document: raise NotFound("Document not found.") @@ -148,7 +150,7 @@ class DatasetDocumentSegmentListApi(Resource): query = ( select(DocumentSegment) .where( - DocumentSegment.document_id == str(document_id), + DocumentSegment.document_id == document_id_str, DocumentSegment.tenant_id == current_tenant_id, ) .order_by(DocumentSegment.position.asc()) @@ -201,7 +203,9 @@ class DatasetDocumentSegmentListApi(Resource): if segment_ids: from services.summary_index_service import SummaryIndexService - summary_records = SummaryIndexService.get_segments_summaries(segment_ids=segment_ids, dataset_id=dataset_id) + summary_records = SummaryIndexService.get_segments_summaries( + segment_ids=segment_ids, dataset_id=dataset_id_str + ) # Only include enabled summaries (already filtered by service) summaries = {chunk_id: summary.summary_content for chunk_id, summary in summary_records.items()} @@ -226,19 +230,19 @@ class DatasetDocumentSegmentListApi(Resource): @account_initialization_required @cloud_edition_billing_rate_limit_check("knowledge") @console_ns.response(204, "Segments deleted successfully") - def delete(self, dataset_id, document_id): + def delete(self, dataset_id: UUID, document_id: UUID): current_user, _ = current_account_with_tenant() # check dataset - dataset_id = str(dataset_id) - dataset = DatasetService.get_dataset(dataset_id) + dataset_id_str = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id_str) if not dataset: raise NotFound("Dataset not found.") # check user's model setting DatasetService.check_dataset_model_setting(dataset) # check document - document_id = str(document_id) - document = DocumentService.get_document(dataset_id, document_id) + document_id_str = str(document_id) + document = DocumentService.get_document(dataset_id_str, document_id_str) if not document: raise NotFound("Document not found.") segment_ids = request.args.getlist("segment_id") @@ -262,15 +266,15 @@ class DatasetDocumentSegmentApi(Resource): @cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_rate_limit_check("knowledge") @console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__]) - def patch(self, dataset_id, document_id, action): + def patch(self, dataset_id: UUID, document_id: UUID, action: Literal["enable", "disable"]): current_user, current_tenant_id = current_account_with_tenant() - dataset_id = str(dataset_id) - dataset = DatasetService.get_dataset(dataset_id) + dataset_id_str = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id_str) if not dataset: raise NotFound("Dataset not found.") - document_id = str(document_id) - document = DocumentService.get_document(dataset_id, document_id) + document_id_str = str(document_id) + document = DocumentService.get_document(dataset_id_str, document_id_str) if not document: raise NotFound("Document not found.") # check user's model setting @@ -321,17 +325,17 @@ class DatasetDocumentSegmentAddApi(Resource): @cloud_edition_billing_knowledge_limit_check("add_segment") @cloud_edition_billing_rate_limit_check("knowledge") @console_ns.expect(console_ns.models[SegmentCreatePayload.__name__]) - def post(self, dataset_id, document_id): + def post(self, dataset_id: UUID, document_id: UUID): current_user, current_tenant_id = current_account_with_tenant() # check dataset - dataset_id = str(dataset_id) - dataset = DatasetService.get_dataset(dataset_id) + dataset_id_str = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id_str) if not dataset: raise NotFound("Dataset not found.") # check document - document_id = str(document_id) - document = DocumentService.get_document(dataset_id, document_id) + document_id_str = str(document_id) + document = DocumentService.get_document(dataset_id_str, document_id_str) if not document: raise NotFound("Document not found.") if not current_user.is_dataset_editor: @@ -361,7 +365,7 @@ class DatasetDocumentSegmentAddApi(Resource): payload_dict = payload.model_dump(exclude_none=True) SegmentService.segment_create_args_validate(payload_dict, document) segment = SegmentService.create_segment(payload_dict, document, dataset) - return {"data": _get_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200 + return {"data": _get_segment_with_summary(segment, dataset_id_str), "doc_form": document.doc_form}, 200 @console_ns.route("/datasets//documents//segments/") @@ -372,19 +376,19 @@ class DatasetDocumentSegmentUpdateApi(Resource): @cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_rate_limit_check("knowledge") @console_ns.expect(console_ns.models[SegmentUpdatePayload.__name__]) - def patch(self, dataset_id, document_id, segment_id): + def patch(self, dataset_id: UUID, document_id: UUID, segment_id: UUID): current_user, current_tenant_id = current_account_with_tenant() # check dataset - dataset_id = str(dataset_id) - dataset = DatasetService.get_dataset(dataset_id) + dataset_id_str = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id_str) if not dataset: raise NotFound("Dataset not found.") # check user's model setting DatasetService.check_dataset_model_setting(dataset) # check document - document_id = str(document_id) - document = DocumentService.get_document(dataset_id, document_id) + document_id_str = str(document_id) + document = DocumentService.get_document(dataset_id_str, document_id_str) if not document: raise NotFound("Document not found.") if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: @@ -404,10 +408,10 @@ class DatasetDocumentSegmentUpdateApi(Resource): except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) # check segment - segment_id = str(segment_id) + segment_id_str = str(segment_id) segment = db.session.scalar( select(DocumentSegment) - .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) + .where(DocumentSegment.id == segment_id_str, DocumentSegment.tenant_id == current_tenant_id) .limit(1) ) if not segment: @@ -428,33 +432,33 @@ class DatasetDocumentSegmentUpdateApi(Resource): segment = SegmentService.update_segment( SegmentUpdateArgs.model_validate(payload.model_dump(exclude_none=True)), segment, document, dataset ) - return {"data": _get_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200 + return {"data": _get_segment_with_summary(segment, dataset_id_str), "doc_form": document.doc_form}, 200 @setup_required @login_required @account_initialization_required @cloud_edition_billing_rate_limit_check("knowledge") @console_ns.response(204, "Segment deleted successfully") - def delete(self, dataset_id, document_id, segment_id): + def delete(self, dataset_id: UUID, document_id: UUID, segment_id: UUID): current_user, current_tenant_id = current_account_with_tenant() # check dataset - dataset_id = str(dataset_id) - dataset = DatasetService.get_dataset(dataset_id) + dataset_id_str = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id_str) if not dataset: raise NotFound("Dataset not found.") # check user's model setting DatasetService.check_dataset_model_setting(dataset) # check document - document_id = str(document_id) - document = DocumentService.get_document(dataset_id, document_id) + document_id_str = str(document_id) + document = DocumentService.get_document(dataset_id_str, document_id_str) if not document: raise NotFound("Document not found.") # check segment - segment_id = str(segment_id) + segment_id_str = str(segment_id) segment = db.session.scalar( select(DocumentSegment) - .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) + .where(DocumentSegment.id == segment_id_str, DocumentSegment.tenant_id == current_tenant_id) .limit(1) ) if not segment: @@ -483,17 +487,17 @@ class DatasetDocumentSegmentBatchImportApi(Resource): @cloud_edition_billing_knowledge_limit_check("add_segment") @cloud_edition_billing_rate_limit_check("knowledge") @console_ns.expect(console_ns.models[BatchImportPayload.__name__]) - def post(self, dataset_id, document_id): + def post(self, dataset_id: UUID, document_id: UUID): current_user, current_tenant_id = current_account_with_tenant() # check dataset - dataset_id = str(dataset_id) - dataset = DatasetService.get_dataset(dataset_id) + dataset_id_str = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id_str) if not dataset: raise NotFound("Dataset not found.") # check document - document_id = str(document_id) - document = DocumentService.get_document(dataset_id, document_id) + document_id_str = str(document_id) + document = DocumentService.get_document(dataset_id_str, document_id_str) if not document: raise NotFound("Document not found.") @@ -517,8 +521,8 @@ class DatasetDocumentSegmentBatchImportApi(Resource): batch_create_segment_to_index_task.delay( str(job_id), upload_file_id, - dataset_id, - document_id, + dataset_id_str, + document_id_str, current_tenant_id, current_user.id, ) @@ -530,7 +534,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, job_id=None, dataset_id=None, document_id=None): + def get(self, job_id=None, dataset_id: UUID | None = None, document_id: UUID | None = None): if job_id is None: raise NotFound("The job does not exist.") job_id = str(job_id) @@ -551,24 +555,24 @@ class ChildChunkAddApi(Resource): @cloud_edition_billing_knowledge_limit_check("add_segment") @cloud_edition_billing_rate_limit_check("knowledge") @console_ns.expect(console_ns.models[ChildChunkCreatePayload.__name__]) - def post(self, dataset_id, document_id, segment_id): + def post(self, dataset_id: UUID, document_id: UUID, segment_id: UUID): current_user, current_tenant_id = current_account_with_tenant() # check dataset - dataset_id = str(dataset_id) - dataset = DatasetService.get_dataset(dataset_id) + dataset_id_str = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id_str) if not dataset: raise NotFound("Dataset not found.") # check document - document_id = str(document_id) - document = DocumentService.get_document(dataset_id, document_id) + document_id_str = str(document_id) + document = DocumentService.get_document(dataset_id_str, document_id_str) if not document: raise NotFound("Document not found.") # check segment - segment_id = str(segment_id) + segment_id_str = str(segment_id) segment = db.session.scalar( select(DocumentSegment) - .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) + .where(DocumentSegment.id == segment_id_str, DocumentSegment.tenant_id == current_tenant_id) .limit(1) ) if not segment: @@ -606,26 +610,26 @@ class ChildChunkAddApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, dataset_id, document_id, segment_id): + def get(self, dataset_id: UUID, document_id: UUID, segment_id: UUID): _, current_tenant_id = current_account_with_tenant() # check dataset - dataset_id = str(dataset_id) - dataset = DatasetService.get_dataset(dataset_id) + dataset_id_str = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id_str) if not dataset: raise NotFound("Dataset not found.") # check user's model setting DatasetService.check_dataset_model_setting(dataset) # check document - document_id = str(document_id) - document = DocumentService.get_document(dataset_id, document_id) + document_id_str = str(document_id) + document = DocumentService.get_document(dataset_id_str, document_id_str) if not document: raise NotFound("Document not found.") # check segment - segment_id = str(segment_id) + segment_id_str = str(segment_id) segment = db.session.scalar( select(DocumentSegment) - .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) + .where(DocumentSegment.id == segment_id_str, DocumentSegment.tenant_id == current_tenant_id) .limit(1) ) if not segment: @@ -642,7 +646,9 @@ class ChildChunkAddApi(Resource): limit = min(args.limit, 100) keyword = args.keyword - child_chunks = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword) + child_chunks = SegmentService.get_child_chunks( + segment_id_str, document_id_str, dataset_id_str, page, limit, keyword + ) return { "data": marshal(child_chunks.items, child_chunk_fields), "total": child_chunks.total, @@ -656,26 +662,26 @@ class ChildChunkAddApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_rate_limit_check("knowledge") - def patch(self, dataset_id, document_id, segment_id): + def patch(self, dataset_id: UUID, document_id: UUID, segment_id: UUID): current_user, current_tenant_id = current_account_with_tenant() # check dataset - dataset_id = str(dataset_id) - dataset = DatasetService.get_dataset(dataset_id) + dataset_id_str = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id_str) if not dataset: raise NotFound("Dataset not found.") # check user's model setting DatasetService.check_dataset_model_setting(dataset) # check document - document_id = str(document_id) - document = DocumentService.get_document(dataset_id, document_id) + document_id_str = str(document_id) + document = DocumentService.get_document(dataset_id_str, document_id_str) if not document: raise NotFound("Document not found.") # check segment - segment_id = str(segment_id) + segment_id_str = str(segment_id) segment = db.session.scalar( select(DocumentSegment) - .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) + .where(DocumentSegment.id == segment_id_str, DocumentSegment.tenant_id == current_tenant_id) .limit(1) ) if not segment: @@ -705,39 +711,39 @@ class ChildChunkUpdateApi(Resource): @account_initialization_required @cloud_edition_billing_rate_limit_check("knowledge") @console_ns.response(204, "Child chunk deleted successfully") - def delete(self, dataset_id, document_id, segment_id, child_chunk_id): + def delete(self, dataset_id: UUID, document_id: UUID, segment_id: UUID, child_chunk_id: UUID): current_user, current_tenant_id = current_account_with_tenant() # check dataset - dataset_id = str(dataset_id) - dataset = DatasetService.get_dataset(dataset_id) + dataset_id_str = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id_str) if not dataset: raise NotFound("Dataset not found.") # check user's model setting DatasetService.check_dataset_model_setting(dataset) # check document - document_id = str(document_id) - document = DocumentService.get_document(dataset_id, document_id) + document_id_str = str(document_id) + document = DocumentService.get_document(dataset_id_str, document_id_str) if not document: raise NotFound("Document not found.") # check segment - segment_id = str(segment_id) + segment_id_str = str(segment_id) segment = db.session.scalar( select(DocumentSegment) - .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) + .where(DocumentSegment.id == segment_id_str, DocumentSegment.tenant_id == current_tenant_id) .limit(1) ) if not segment: raise NotFound("Segment not found.") # check child chunk - child_chunk_id = str(child_chunk_id) + child_chunk_id_str = str(child_chunk_id) child_chunk = db.session.scalar( select(ChildChunk) .where( - ChildChunk.id == str(child_chunk_id), + ChildChunk.id == str(child_chunk_id_str), ChildChunk.tenant_id == current_tenant_id, ChildChunk.segment_id == segment.id, - ChildChunk.document_id == document_id, + ChildChunk.document_id == document_id_str, ) .limit(1) ) @@ -762,39 +768,39 @@ class ChildChunkUpdateApi(Resource): @cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_rate_limit_check("knowledge") @console_ns.expect(console_ns.models[ChildChunkUpdatePayload.__name__]) - def patch(self, dataset_id, document_id, segment_id, child_chunk_id): + def patch(self, dataset_id: UUID, document_id: UUID, segment_id: UUID, child_chunk_id: UUID): current_user, current_tenant_id = current_account_with_tenant() # check dataset - dataset_id = str(dataset_id) - dataset = DatasetService.get_dataset(dataset_id) + dataset_id_str = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id_str) if not dataset: raise NotFound("Dataset not found.") # check user's model setting DatasetService.check_dataset_model_setting(dataset) # check document - document_id = str(document_id) - document = DocumentService.get_document(dataset_id, document_id) + document_id_str = str(document_id) + document = DocumentService.get_document(dataset_id_str, document_id_str) if not document: raise NotFound("Document not found.") # check segment - segment_id = str(segment_id) + segment_id_str = str(segment_id) segment = db.session.scalar( select(DocumentSegment) - .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) + .where(DocumentSegment.id == segment_id_str, DocumentSegment.tenant_id == current_tenant_id) .limit(1) ) if not segment: raise NotFound("Segment not found.") # check child chunk - child_chunk_id = str(child_chunk_id) + child_chunk_id_str = str(child_chunk_id) child_chunk = db.session.scalar( select(ChildChunk) .where( - ChildChunk.id == str(child_chunk_id), + ChildChunk.id == str(child_chunk_id_str), ChildChunk.tenant_id == current_tenant_id, ChildChunk.segment_id == segment.id, - ChildChunk.document_id == document_id, + ChildChunk.document_id == document_id_str, ) .limit(1) ) diff --git a/api/controllers/console/datasets/external.py b/api/controllers/console/datasets/external.py index d1cdc15d0b..d6cc176a39 100644 --- a/api/controllers/console/datasets/external.py +++ b/api/controllers/console/datasets/external.py @@ -1,3 +1,5 @@ +from uuid import UUID + from flask import request from flask_restx import Resource, fields, marshal from pydantic import BaseModel, Field @@ -175,11 +177,11 @@ class ExternalApiTemplateApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, external_knowledge_api_id): + def get(self, external_knowledge_api_id: UUID): _, current_tenant_id = current_account_with_tenant() - external_knowledge_api_id = str(external_knowledge_api_id) + external_knowledge_api_id_str = str(external_knowledge_api_id) external_knowledge_api = ExternalDatasetService.get_external_knowledge_api( - external_knowledge_api_id, current_tenant_id + external_knowledge_api_id_str, current_tenant_id ) if external_knowledge_api is None: raise NotFound("API template not found.") @@ -190,9 +192,9 @@ class ExternalApiTemplateApi(Resource): @login_required @account_initialization_required @console_ns.expect(console_ns.models[ExternalKnowledgeApiPayload.__name__]) - def patch(self, external_knowledge_api_id): + def patch(self, external_knowledge_api_id: UUID): current_user, current_tenant_id = current_account_with_tenant() - external_knowledge_api_id = str(external_knowledge_api_id) + external_knowledge_api_id_str = str(external_knowledge_api_id) payload = ExternalKnowledgeApiPayload.model_validate(console_ns.payload or {}) ExternalDatasetService.validate_api_list(payload.settings) @@ -200,7 +202,7 @@ class ExternalApiTemplateApi(Resource): external_knowledge_api = ExternalDatasetService.update_external_knowledge_api( tenant_id=current_tenant_id, user_id=current_user.id, - external_knowledge_api_id=external_knowledge_api_id, + external_knowledge_api_id=external_knowledge_api_id_str, args=payload.model_dump(), ) @@ -210,14 +212,14 @@ class ExternalApiTemplateApi(Resource): @login_required @account_initialization_required @console_ns.response(204, "External knowledge API deleted successfully") - def delete(self, external_knowledge_api_id): + def delete(self, external_knowledge_api_id: UUID): current_user, current_tenant_id = current_account_with_tenant() - external_knowledge_api_id = str(external_knowledge_api_id) + external_knowledge_api_id_str = str(external_knowledge_api_id) if not (current_user.has_edit_permission or current_user.is_dataset_operator): raise Forbidden() - ExternalDatasetService.delete_external_knowledge_api(current_tenant_id, external_knowledge_api_id) + ExternalDatasetService.delete_external_knowledge_api(current_tenant_id, external_knowledge_api_id_str) return "", 204 @@ -230,12 +232,12 @@ class ExternalApiUseCheckApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, external_knowledge_api_id): + def get(self, external_knowledge_api_id: UUID): _, current_tenant_id = current_account_with_tenant() - external_knowledge_api_id = str(external_knowledge_api_id) + external_knowledge_api_id_str = str(external_knowledge_api_id) external_knowledge_api_is_using, count = ExternalDatasetService.external_knowledge_api_use_check( - external_knowledge_api_id, current_tenant_id + external_knowledge_api_id_str, current_tenant_id ) return {"is_using": external_knowledge_api_is_using, "count": count}, 200 @@ -286,7 +288,7 @@ class ExternalKnowledgeHitTestingApi(Resource): @setup_required @login_required @account_initialization_required - def post(self, dataset_id): + def post(self, dataset_id: UUID): current_user, _ = current_account_with_tenant() dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) diff --git a/api/controllers/console/datasets/hit_testing.py b/api/controllers/console/datasets/hit_testing.py index 8758f983ee..110a2e16f5 100644 --- a/api/controllers/console/datasets/hit_testing.py +++ b/api/controllers/console/datasets/hit_testing.py @@ -2,6 +2,7 @@ from __future__ import annotations from datetime import datetime from typing import Any +from uuid import UUID from flask_restx import Resource from pydantic import Field, field_validator @@ -118,7 +119,7 @@ class HitTestingApi(Resource, DatasetsHitTestingBase): @login_required @account_initialization_required @cloud_edition_billing_rate_limit_check("knowledge") - def post(self, dataset_id): + def post(self, dataset_id: UUID): dataset_id_str = str(dataset_id) dataset = self.get_and_validate_dataset(dataset_id_str) diff --git a/api/controllers/console/datasets/metadata.py b/api/controllers/console/datasets/metadata.py index 4de5f32fb8..cf516aa63b 100644 --- a/api/controllers/console/datasets/metadata.py +++ b/api/controllers/console/datasets/metadata.py @@ -1,4 +1,5 @@ from typing import Literal +from uuid import UUID from flask_restx import Resource from werkzeug.exceptions import NotFound @@ -42,7 +43,7 @@ class DatasetMetadataCreateApi(Resource): @enterprise_license_required @console_ns.response(201, "Metadata created successfully", console_ns.models[DatasetMetadataResponse.__name__]) @console_ns.expect(console_ns.models[MetadataArgs.__name__]) - def post(self, dataset_id): + def post(self, dataset_id: UUID): current_user, _ = current_account_with_tenant() metadata_args = MetadataArgs.model_validate(console_ns.payload or {}) @@ -62,7 +63,7 @@ class DatasetMetadataCreateApi(Resource): @console_ns.response( 200, "Metadata retrieved successfully", console_ns.models[DatasetMetadataListResponse.__name__] ) - def get(self, dataset_id): + def get(self, dataset_id: UUID): dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) if dataset is None: @@ -79,7 +80,7 @@ class DatasetMetadataApi(Resource): @enterprise_license_required @console_ns.response(200, "Metadata updated successfully", console_ns.models[DatasetMetadataResponse.__name__]) @console_ns.expect(console_ns.models[MetadataUpdatePayload.__name__]) - def patch(self, dataset_id, metadata_id): + def patch(self, dataset_id: UUID, metadata_id: UUID): current_user, _ = current_account_with_tenant() payload = MetadataUpdatePayload.model_validate(console_ns.payload or {}) name = payload.name @@ -99,7 +100,7 @@ class DatasetMetadataApi(Resource): @account_initialization_required @enterprise_license_required @console_ns.response(204, "Metadata deleted successfully") - def delete(self, dataset_id, metadata_id): + def delete(self, dataset_id: UUID, metadata_id: UUID): current_user, _ = current_account_with_tenant() dataset_id_str = str(dataset_id) metadata_id_str = str(metadata_id) @@ -136,7 +137,7 @@ class DatasetMetadataBuiltInFieldActionApi(Resource): @account_initialization_required @enterprise_license_required @console_ns.response(204, "Action completed successfully") - def post(self, dataset_id, action: Literal["enable", "disable"]): + def post(self, dataset_id: UUID, action: Literal["enable", "disable"]): current_user, _ = current_account_with_tenant() dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) @@ -164,7 +165,7 @@ class DocumentMetadataEditApi(Resource): 204, "Documents metadata updated successfully", ) - def post(self, dataset_id): + def post(self, dataset_id: UUID): current_user, _ = current_account_with_tenant() dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py index a6ca0689d0..39c8aaa451 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py @@ -1,6 +1,6 @@ from flask_restx import Resource, marshal from pydantic import BaseModel -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden import services @@ -54,12 +54,13 @@ class CreateRagPipelineDatasetApi(Resource): yaml_content=payload.yaml_content, ) try: - with sessionmaker(db.engine).begin() as session: + with Session(db.engine, expire_on_commit=False) as session: rag_pipeline_dsl_service = RagPipelineDslService(session) import_info = rag_pipeline_dsl_service.create_rag_pipeline_dataset( tenant_id=current_tenant_id, rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity, ) + session.commit() if rag_pipeline_dataset_create_entity.permission == "partial_members": DatasetPermissionService.update_partial_member_list( current_tenant_id, diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py index aa27458176..3ae5d308c2 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py @@ -1,7 +1,7 @@ from flask import request from flask_restx import Resource, fields, marshal_with # type: ignore from pydantic import BaseModel, Field -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import Session from controllers.common.schema import get_or_create_model, register_schema_models from controllers.console import console_ns @@ -67,10 +67,12 @@ class RagPipelineImportApi(Resource): current_user, _ = current_account_with_tenant() payload = RagPipelineImportPayload.model_validate(console_ns.payload or {}) - # Create service with session - with sessionmaker(db.engine).begin() as session: + # Use a plain Session so that caught exceptions inside the service + # (which return FAILED status instead of re-raising) do not leave the + # transaction in a closed state that a .begin() context manager cannot + # handle. See app_import.py for the canonical pattern. + with Session(db.engine, expire_on_commit=False) as session: import_service = RagPipelineDslService(session) - # Import app account = current_user result = import_service.import_rag_pipeline( account=account, @@ -80,6 +82,10 @@ class RagPipelineImportApi(Resource): pipeline_id=payload.pipeline_id, dataset_name=payload.name, ) + if result.status == ImportStatus.FAILED: + session.rollback() + else: + session.commit() # Return appropriate status code based on result status = result.status @@ -99,15 +105,17 @@ class RagPipelineImportConfirmApi(Resource): @account_initialization_required @edit_permission_required @marshal_with(pipeline_import_model) - def post(self, import_id): + def post(self, import_id: str): current_user, _ = current_account_with_tenant() - # Create service with session - with sessionmaker(db.engine).begin() as session: + with Session(db.engine, expire_on_commit=False) as session: import_service = RagPipelineDslService(session) - # Confirm import account = current_user result = import_service.confirm_import(import_id=import_id, account=account) + if result.status == ImportStatus.FAILED: + session.rollback() + else: + session.commit() # Return appropriate status code based on result if result.status == ImportStatus.FAILED: @@ -124,7 +132,7 @@ class RagPipelineImportCheckDependenciesApi(Resource): @edit_permission_required @marshal_with(pipeline_import_check_dependencies_model) def get(self, pipeline: Pipeline): - with sessionmaker(db.engine).begin() as session: + with Session(db.engine, expire_on_commit=False) as session: import_service = RagPipelineDslService(session) result = import_service.check_dependencies(pipeline=pipeline) @@ -142,7 +150,7 @@ class RagPipelineExportApi(Resource): # Add include_secret params query = IncludeSecretQuery.model_validate(request.args.to_dict()) - with sessionmaker(db.engine).begin() as session: + with Session(db.engine, expire_on_commit=False) as session: export_service = RagPipelineDslService(session) result = export_service.export_rag_pipeline_dsl( pipeline=pipeline, include_secret=query.include_secret == "true" diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index 25e8b060b8..a7727513df 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -1,6 +1,7 @@ import json import logging from typing import Any, Literal, cast +from uuid import UUID from flask import abort, request from flask_restx import Resource @@ -875,14 +876,14 @@ class RagPipelineWorkflowRunDetailApi(Resource): @login_required @account_initialization_required @get_rag_pipeline - def get(self, pipeline: Pipeline, run_id): + def get(self, pipeline: Pipeline, run_id: UUID): """ Get workflow run detail """ - run_id = str(run_id) + run_id_str = str(run_id) rag_pipeline_service = RagPipelineService() - workflow_run = rag_pipeline_service.get_rag_pipeline_workflow_run(pipeline=pipeline, run_id=run_id) + workflow_run = rag_pipeline_service.get_rag_pipeline_workflow_run(pipeline=pipeline, run_id=run_id_str) if workflow_run is None: raise NotFound("Workflow run not found") @@ -904,13 +905,13 @@ class RagPipelineWorkflowRunNodeExecutionListApi(Resource): """ Get workflow run node execution list """ - run_id = str(run_id) + run_id_str = str(run_id) rag_pipeline_service = RagPipelineService() user = cast("Account | EndUser", current_user) node_executions = rag_pipeline_service.get_rag_pipeline_workflow_run_node_executions( pipeline=pipeline, - run_id=run_id, + run_id=run_id_str, user=user, ) @@ -960,15 +961,15 @@ class RagPipelineTransformApi(Resource): @setup_required @login_required @account_initialization_required - def post(self, dataset_id: str): + def post(self, dataset_id: UUID): current_user, _ = current_account_with_tenant() if not (current_user.has_edit_permission or current_user.is_dataset_operator): raise Forbidden() - dataset_id = str(dataset_id) + dataset_id_str = str(dataset_id) rag_pipeline_transform_service = RagPipelineTransformService() - result = rag_pipeline_transform_service.transform_dataset(dataset_id) + result = rag_pipeline_transform_service.transform_dataset(dataset_id_str) return result diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index 2a699efb1d..72e2d923da 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -133,7 +133,7 @@ class CompletionApi(InstalledAppResource): ) class CompletionStopApi(InstalledAppResource): @console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__]) - def post(self, installed_app, task_id): + def post(self, installed_app, task_id: str): app_model = installed_app.app if app_model.mode != AppMode.COMPLETION: raise NotCompletionAppError() @@ -209,7 +209,7 @@ class ChatApi(InstalledAppResource): ) class ChatStopApi(InstalledAppResource): @console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__]) - def post(self, installed_app, task_id): + def post(self, installed_app, task_id: str): app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: diff --git a/api/controllers/console/explore/conversation.py b/api/controllers/console/explore/conversation.py index ae32571219..a3ae59aaf7 100644 --- a/api/controllers/console/explore/conversation.py +++ b/api/controllers/console/explore/conversation.py @@ -1,4 +1,5 @@ from typing import Any +from uuid import UUID from flask import request from pydantic import BaseModel, Field, TypeAdapter @@ -91,7 +92,7 @@ class ConversationListApi(InstalledAppResource): ) class ConversationApi(InstalledAppResource): @console_ns.response(204, "Conversation deleted successfully") - def delete(self, installed_app, c_id): + def delete(self, installed_app, c_id: UUID): app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: @@ -114,7 +115,7 @@ class ConversationApi(InstalledAppResource): ) class ConversationRenameApi(InstalledAppResource): @console_ns.expect(console_ns.models[ConversationRenamePayload.__name__]) - def post(self, installed_app, c_id): + def post(self, installed_app, c_id: UUID): app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: @@ -145,7 +146,7 @@ class ConversationRenameApi(InstalledAppResource): ) class ConversationPinApi(InstalledAppResource): @console_ns.response(200, "Success", console_ns.models[ResultResponse.__name__]) - def patch(self, installed_app, c_id): + def patch(self, installed_app, c_id: UUID): app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: @@ -169,7 +170,7 @@ class ConversationPinApi(InstalledAppResource): ) class ConversationUnPinApi(InstalledAppResource): @console_ns.response(200, "Success", console_ns.models[ResultResponse.__name__]) - def patch(self, installed_app, c_id): + def patch(self, installed_app, c_id: UUID): app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index 0d25fbb66d..c6930a76cb 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -1,5 +1,6 @@ import logging from typing import Literal +from uuid import UUID from flask import request from pydantic import BaseModel, TypeAdapter @@ -95,18 +96,18 @@ class MessageListApi(InstalledAppResource): class MessageFeedbackApi(InstalledAppResource): @console_ns.expect(console_ns.models[MessageFeedbackPayload.__name__]) @console_ns.response(200, "Feedback submitted successfully", console_ns.models[ResultResponse.__name__]) - def post(self, installed_app, message_id): + def post(self, installed_app, message_id: UUID): current_user, _ = current_account_with_tenant() app_model = installed_app.app - message_id = str(message_id) + message_id_str = str(message_id) payload = MessageFeedbackPayload.model_validate(console_ns.payload or {}) try: MessageService.create_feedback( app_model=app_model, - message_id=message_id, + message_id=message_id_str, user=current_user, rating=FeedbackRating(payload.rating) if payload.rating else None, content=payload.content, @@ -123,13 +124,13 @@ class MessageFeedbackApi(InstalledAppResource): ) class MessageMoreLikeThisApi(InstalledAppResource): @console_ns.expect(console_ns.models[MoreLikeThisQuery.__name__]) - def get(self, installed_app, message_id): + def get(self, installed_app, message_id: UUID): current_user, _ = current_account_with_tenant() app_model = installed_app.app if app_model.mode != "completion": raise NotCompletionAppError() - message_id = str(message_id) + message_id_str = str(message_id) args = MoreLikeThisQuery.model_validate(request.args.to_dict()) @@ -139,7 +140,7 @@ class MessageMoreLikeThisApi(InstalledAppResource): response = AppGenerateService.generate_more_like_this( app_model=app_model, user=current_user, - message_id=message_id, + message_id=message_id_str, invoke_from=InvokeFrom.EXPLORE, streaming=streaming, ) @@ -169,18 +170,18 @@ class MessageMoreLikeThisApi(InstalledAppResource): ) class MessageSuggestedQuestionApi(InstalledAppResource): @console_ns.response(200, "Success", console_ns.models[SuggestedQuestionsResponse.__name__]) - def get(self, installed_app, message_id): + def get(self, installed_app, message_id: UUID): current_user, _ = current_account_with_tenant() app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - message_id = str(message_id) + message_id_str = str(message_id) try: questions = MessageService.get_suggested_questions_after_answer( - app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE + app_model=app_model, user=current_user, message_id=message_id_str, invoke_from=InvokeFrom.EXPLORE ) except MessageNotExistsError: raise NotFound("Message not found") diff --git a/api/controllers/console/explore/saved_message.py b/api/controllers/console/explore/saved_message.py index 09f214bd2b..224715d255 100644 --- a/api/controllers/console/explore/saved_message.py +++ b/api/controllers/console/explore/saved_message.py @@ -1,3 +1,5 @@ +from uuid import UUID + from flask import request from pydantic import TypeAdapter from werkzeug.exceptions import NotFound @@ -65,15 +67,15 @@ class SavedMessageListApi(InstalledAppResource): ) class SavedMessageApi(InstalledAppResource): @console_ns.response(204, "Saved message deleted successfully") - def delete(self, installed_app, message_id): + def delete(self, installed_app, message_id: UUID): current_user, _ = current_account_with_tenant() app_model = installed_app.app - message_id = str(message_id) + message_id_str = str(message_id) if app_model.mode != "completion": raise NotCompletionAppError() - SavedMessageService.delete(app_model, current_user, message_id) + SavedMessageService.delete(app_model, current_user, message_id_str) return "", 204 diff --git a/api/controllers/console/extension.py b/api/controllers/console/extension.py index 35e62e3c7e..029e2e7f0d 100644 --- a/api/controllers/console/extension.py +++ b/api/controllers/console/extension.py @@ -1,5 +1,6 @@ from datetime import datetime from typing import Any +from uuid import UUID from flask import request from flask_restx import Resource @@ -152,7 +153,7 @@ class APIBasedExtensionDetailAPI(Resource): @setup_required @login_required @account_initialization_required - def get(self, id): + def get(self, id: UUID): api_based_extension_id = str(id) _, tenant_id = current_account_with_tenant() @@ -168,7 +169,7 @@ class APIBasedExtensionDetailAPI(Resource): @setup_required @login_required @account_initialization_required - def post(self, id): + def post(self, id: UUID): api_based_extension_id = str(id) _, current_tenant_id = current_account_with_tenant() @@ -196,7 +197,7 @@ class APIBasedExtensionDetailAPI(Resource): @setup_required @login_required @account_initialization_required - def delete(self, id): + def delete(self, id: UUID): api_based_extension_id = str(id) _, current_tenant_id = current_account_with_tenant() diff --git a/api/controllers/console/files.py b/api/controllers/console/files.py index 6811c8ce77..499a623872 100644 --- a/api/controllers/console/files.py +++ b/api/controllers/console/files.py @@ -1,4 +1,5 @@ from typing import Literal +from uuid import UUID from flask import request from flask_restx import Resource @@ -106,10 +107,10 @@ class FilePreviewApi(Resource): @login_required @account_initialization_required @console_ns.response(200, "Success", console_ns.models[TextContentResponse.__name__]) - def get(self, file_id): - file_id = str(file_id) + def get(self, file_id: UUID): + file_id_str = str(file_id) _, tenant_id = current_account_with_tenant() - text = FileService(db.engine).get_file_preview(file_id, tenant_id) + text = FileService(db.engine).get_file_preview(file_id_str, tenant_id) return {"content": text} diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py index 199a05d9c2..a37e56e2b8 100644 --- a/api/controllers/console/tag/tags.py +++ b/api/controllers/console/tag/tags.py @@ -1,4 +1,5 @@ from typing import Literal +from uuid import UUID from flask import request from flask_restx import Resource @@ -131,17 +132,17 @@ class TagUpdateDeleteApi(Resource): @setup_required @login_required @account_initialization_required - def patch(self, tag_id): + def patch(self, tag_id: UUID): current_user, _ = current_account_with_tenant() - tag_id = str(tag_id) + tag_id_str = str(tag_id) # The role of the current user in the ta table must be admin, owner, or editor if not (current_user.has_edit_permission or current_user.is_dataset_editor): raise Forbidden() payload = TagUpdateRequestPayload.model_validate(console_ns.payload or {}) - tag = TagService.update_tags(UpdateTagPayload(name=payload.name), tag_id) + tag = TagService.update_tags(UpdateTagPayload(name=payload.name), tag_id_str) - binding_count = TagService.get_tag_binding_count(tag_id) + binding_count = TagService.get_tag_binding_count(tag_id_str) response = TagResponse.model_validate( {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count} @@ -154,10 +155,10 @@ class TagUpdateDeleteApi(Resource): @account_initialization_required @edit_permission_required @console_ns.response(204, "Tag deleted successfully") - def delete(self, tag_id): - tag_id = str(tag_id) + def delete(self, tag_id: UUID): + tag_id_str = str(tag_id) - TagService.delete_tag(tag_id) + TagService.delete_tag(tag_id_str) return "", 204 diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index 2be91626e7..910e07e14d 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -1,4 +1,5 @@ from urllib import parse +from uuid import UUID from flask import abort, request from flask_restx import Resource @@ -175,7 +176,7 @@ class MemberCancelInviteApi(Resource): @setup_required @login_required @account_initialization_required - def delete(self, member_id): + def delete(self, member_id: UUID): current_user, _ = current_account_with_tenant() if not current_user.current_tenant: raise ValueError("No current tenant") @@ -208,7 +209,7 @@ class MemberUpdateRoleApi(Resource): @setup_required @login_required @account_initialization_required - def put(self, member_id): + def put(self, member_id: UUID): payload = console_ns.payload or {} args = MemberRoleUpdatePayload.model_validate(payload) new_role = args.role @@ -351,7 +352,7 @@ class OwnerTransfer(Resource): @login_required @account_initialization_required @is_allow_transfer_owner - def post(self, member_id): + def post(self, member_id: UUID): payload = console_ns.payload or {} args = OwnerTransferPayload.model_validate(payload) diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index c23207e402..21c6b1c7ce 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -532,7 +532,7 @@ class ModelProviderAvailableModelApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, model_type): + def get(self, model_type: str): _, tenant_id = current_account_with_tenant() model_provider_service = ModelProviderService() models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type) diff --git a/api/controllers/console/workspace/plugin.py b/api/controllers/console/workspace/plugin.py index 9e5e766f45..c41bf99563 100644 --- a/api/controllers/console/workspace/plugin.py +++ b/api/controllers/console/workspace/plugin.py @@ -15,6 +15,7 @@ from controllers.console import console_ns from controllers.console.workspace import plugin_permission_required from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required from core.plugin.impl.exc import PluginDaemonClientSideError +from core.plugin.plugin_service import PluginService from fields.base import ResponseModel from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required @@ -22,7 +23,6 @@ from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermissi from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService from services.plugin.plugin_parameter_service import PluginParameterService from services.plugin.plugin_permission_service import PluginPermissionService -from services.plugin.plugin_service import PluginService class ParserList(BaseModel): diff --git a/api/controllers/files/image_preview.py b/api/controllers/files/image_preview.py index be7886e831..0d4b89096a 100644 --- a/api/controllers/files/image_preview.py +++ b/api/controllers/files/image_preview.py @@ -1,4 +1,5 @@ from urllib.parse import quote +from uuid import UUID from flask import Response, request from flask_restx import Resource @@ -49,8 +50,8 @@ class ImagePreviewApi(Resource): 415: "Unsupported file type", } ) - def get(self, file_id): - file_id = str(file_id) + def get(self, file_id: UUID): + file_id_str = str(file_id) args = FileSignatureQuery.model_validate(request.args.to_dict(flat=True)) timestamp = args.timestamp @@ -59,7 +60,7 @@ class ImagePreviewApi(Resource): try: generator, mimetype = FileService(db.engine).get_image_preview( - file_id=file_id, + file_id=file_id_str, timestamp=timestamp, nonce=nonce, sign=sign, @@ -91,14 +92,14 @@ class FilePreviewApi(Resource): 415: "Unsupported file type", } ) - def get(self, file_id): - file_id = str(file_id) + def get(self, file_id: UUID): + file_id_str = str(file_id) args = FilePreviewQuery.model_validate(request.args.to_dict(flat=True)) try: generator, upload_file = FileService(db.engine).get_file_generator_by_file_id( - file_id=file_id, + file_id=file_id_str, timestamp=args.timestamp, nonce=args.nonce, sign=args.sign, @@ -159,10 +160,10 @@ class WorkspaceWebappLogoApi(Resource): 415: "Unsupported file type", } ) - def get(self, workspace_id): - workspace_id = str(workspace_id) + def get(self, workspace_id: UUID): + workspace_id_str = str(workspace_id) - custom_config = TenantService.get_custom_config(workspace_id) + custom_config = TenantService.get_custom_config(workspace_id_str) webapp_logo_file_id = custom_config.get("replace_webapp_logo") if custom_config is not None else None if not webapp_logo_file_id: diff --git a/api/controllers/files/tool_files.py b/api/controllers/files/tool_files.py index 8ae16ce7f4..ef47485d80 100644 --- a/api/controllers/files/tool_files.py +++ b/api/controllers/files/tool_files.py @@ -1,4 +1,5 @@ from urllib.parse import quote +from uuid import UUID from flask import Response, request from flask_restx import Resource @@ -45,17 +46,19 @@ class ToolFileApi(Resource): 415: "Unsupported file type", } ) - def get(self, file_id, extension): - file_id = str(file_id) + def get(self, file_id: UUID, extension: str): + file_id_str = str(file_id) args = ToolFileQuery.model_validate(request.args.to_dict()) - if not verify_tool_file_signature(file_id=file_id, timestamp=args.timestamp, nonce=args.nonce, sign=args.sign): + if not verify_tool_file_signature( + file_id=file_id_str, timestamp=args.timestamp, nonce=args.nonce, sign=args.sign + ): raise Forbidden("Invalid request.") try: tool_file_manager = ToolFileManager() stream, tool_file = tool_file_manager.get_file_generator_by_tool_file_id( - file_id, + file_id_str, ) if not stream or not tool_file: diff --git a/api/controllers/service_api/app/annotation.py b/api/controllers/service_api/app/annotation.py index 00bb9aa463..9f6b1cf52e 100644 --- a/api/controllers/service_api/app/annotation.py +++ b/api/controllers/service_api/app/annotation.py @@ -1,4 +1,5 @@ from typing import Literal +from uuid import UUID from flask import request from flask_restx import Resource @@ -78,10 +79,10 @@ class AnnotationReplyActionStatusApi(Resource): } ) @validate_app_token - def get(self, app_model: App, job_id, action): + def get(self, app_model: App, job_id: UUID, action: str): """Get the status of an annotation reply action job.""" - job_id = str(job_id) - app_annotation_job_key = f"{action}_app_annotation_job_{str(job_id)}" + job_id_str = str(job_id) + app_annotation_job_key = f"{action}_app_annotation_job_{job_id_str}" cache_result = redis_client.get(app_annotation_job_key) if cache_result is None: raise ValueError("The job does not exist.") @@ -89,10 +90,10 @@ class AnnotationReplyActionStatusApi(Resource): job_status = cache_result.decode() error_msg = "" if job_status == "error": - app_annotation_error_key = f"{action}_app_annotation_error_{str(job_id)}" + app_annotation_error_key = f"{action}_app_annotation_error_{job_id_str}" error_msg = redis_client.get(app_annotation_error_key).decode() - return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200 + return {"job_id": job_id_str, "job_status": job_status, "error_msg": error_msg}, 200 @service_api_ns.route("/apps/annotations") diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index 64b2038f9c..cd247c7a8e 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -1,5 +1,6 @@ from datetime import datetime from typing import Any, Literal +from uuid import UUID from flask import request from flask_restx import Resource @@ -195,7 +196,7 @@ class ConversationDetailApi(Resource): } ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) - def delete(self, app_model: App, end_user: EndUser, c_id): + def delete(self, app_model: App, end_user: EndUser, c_id: UUID): """Delete a specific conversation.""" app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: @@ -224,7 +225,7 @@ class ConversationRenameApi(Resource): } ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) - def post(self, app_model: App, end_user: EndUser, c_id): + def post(self, app_model: App, end_user: EndUser, c_id: UUID): """Rename a conversation or auto-generate a name.""" app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: @@ -266,7 +267,7 @@ class ConversationVariablesApi(Resource): service_api_ns.models[ConversationVariableInfiniteScrollPaginationResponse.__name__], ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) - def get(self, app_model: App, end_user: EndUser, c_id): + def get(self, app_model: App, end_user: EndUser, c_id: UUID): """List all variables for a conversation. Conversational variables are only available for chat applications. @@ -312,7 +313,7 @@ class ConversationVariableDetailApi(Resource): service_api_ns.models[ConversationVariableResponse.__name__], ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) - def put(self, app_model: App, end_user: EndUser, c_id, variable_id): + def put(self, app_model: App, end_user: EndUser, c_id: UUID, variable_id: UUID): """Update a conversation variable's value. Allows updating the value of a specific conversation variable. @@ -323,13 +324,13 @@ class ConversationVariableDetailApi(Resource): raise NotChatAppError() conversation_id = str(c_id) - variable_id = str(variable_id) + variable_id_str = str(variable_id) payload = ConversationVariableUpdatePayload.model_validate(service_api_ns.payload or {}) try: variable = ConversationService.update_conversation_variable( - app_model, conversation_id, variable_id, end_user, payload.value + app_model, conversation_id, variable_id_str, end_user, payload.value ) return ConversationVariableResponse.model_validate(variable, from_attributes=True).model_dump(mode="json") except services.errors.conversation.ConversationNotExistsError: diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py index 0ff0ae5104..d26d4c09b8 100644 --- a/api/controllers/service_api/app/message.py +++ b/api/controllers/service_api/app/message.py @@ -1,4 +1,5 @@ import logging +from uuid import UUID from flask import request from flask_restx import Resource @@ -94,19 +95,19 @@ class MessageFeedbackApi(Resource): } ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) - def post(self, app_model: App, end_user: EndUser, message_id): + def post(self, app_model: App, end_user: EndUser, message_id: UUID): """Submit feedback for a message. Allows users to rate messages as like/dislike and provide optional feedback content. """ - message_id = str(message_id) + message_id_str = str(message_id) payload = MessageFeedbackPayload.model_validate(service_api_ns.payload or {}) try: MessageService.create_feedback( app_model=app_model, - message_id=message_id, + message_id=message_id_str, user=end_user, rating=FeedbackRating(payload.rating) if payload.rating else None, content=payload.content, @@ -159,19 +160,19 @@ class MessageSuggestedApi(Resource): } ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY, required=True)) - def get(self, app_model: App, end_user: EndUser, message_id): + def get(self, app_model: App, end_user: EndUser, message_id: UUID): """Get suggested follow-up questions for a message. Returns AI-generated follow-up questions based on the message content. """ - message_id = str(message_id) + message_id_str = str(message_id) app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() try: questions = MessageService.get_suggested_questions_after_answer( - app_model=app_model, user=end_user, message_id=message_id, invoke_from=InvokeFrom.SERVICE_API + app_model=app_model, user=end_user, message_id=message_id_str, invoke_from=InvokeFrom.SERVICE_API ) except MessageNotExistsError: raise NotFound("Message Not Exists.") diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 4745ca1275..4bcf969701 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -1,4 +1,5 @@ from typing import Any, Literal +from uuid import UUID from flask import request from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator, model_validator @@ -336,7 +337,7 @@ class DatasetApi(DatasetApiResource): "Dataset retrieved successfully", service_api_ns.models[DatasetDetailWithPartialMembersResponse.__name__], ) - def get(self, _, dataset_id): + def get(self, _, dataset_id: UUID): dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) if dataset is None: @@ -403,7 +404,7 @@ class DatasetApi(DatasetApiResource): service_api_ns.models[DatasetDetailWithPartialMembersResponse.__name__], ) @cloud_edition_billing_rate_limit_check("knowledge", "dataset") - def patch(self, _, dataset_id): + def patch(self, _, dataset_id: UUID): dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) if dataset is None: @@ -479,7 +480,7 @@ class DatasetApi(DatasetApiResource): } ) @cloud_edition_billing_rate_limit_check("knowledge", "dataset") - def delete(self, _, dataset_id): + def delete(self, _, dataset_id: UUID): """ Deletes a dataset given its ID. @@ -534,7 +535,7 @@ class DocumentStatusApi(DatasetApiResource): 400: "Bad request - invalid action", } ) - def patch(self, tenant_id, dataset_id, action: Literal["enable", "disable", "archive", "un_archive"]): + def patch(self, tenant_id, dataset_id: UUID, action: Literal["enable", "disable", "archive", "un_archive"]): """ Batch update document status. diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index 1b9b9d40db..c1d1e1f0a0 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -374,7 +374,7 @@ class DocumentAddByFileApi(DatasetApiResource): @cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_resource_check("documents", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset") - def post(self, tenant_id, dataset_id): + def post(self, tenant_id, dataset_id: UUID): """Create document by upload file.""" dataset = db.session.scalar( select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) @@ -395,7 +395,6 @@ class DocumentAddByFileApi(DatasetApiResource): args["doc_language"] = "English" # get dataset info - dataset_id = str(dataset_id) tenant_id = str(tenant_id) indexing_technique = args.get("indexing_technique") or dataset.indexing_technique @@ -586,17 +585,17 @@ class DocumentListApi(DatasetApiResource): 404: "Dataset not found", } ) - def get(self, tenant_id, dataset_id): - dataset_id = str(dataset_id) + def get(self, tenant_id, dataset_id: UUID): + dataset_id_str = str(dataset_id) tenant_id = str(tenant_id) query_params = DocumentListQuery.model_validate(request.args.to_dict()) dataset = db.session.scalar( - select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1) ) if not dataset: raise NotFound("Dataset not found.") - query = select(Document).where(Document.dataset_id == dataset_id, Document.tenant_id == tenant_id) + query = select(Document).where(Document.dataset_id == dataset_id_str, Document.tenant_id == tenant_id) if query_params.status: query = DocumentService.apply_display_status_filter(query, query_params.status) @@ -646,7 +645,7 @@ class DocumentBatchDownloadZipApi(DatasetApiResource): } ) @cloud_edition_billing_rate_limit_check("knowledge", "dataset") - def post(self, tenant_id, dataset_id): + def post(self, tenant_id, dataset_id: UUID): payload = DocumentBatchDownloadZipPayload.model_validate(service_api_ns.payload or {}) upload_files, download_name = DocumentService.prepare_document_batch_download_zip( @@ -681,18 +680,17 @@ class DocumentIndexingStatusApi(DatasetApiResource): 404: "Dataset or documents not found", } ) - def get(self, tenant_id, dataset_id, batch): - dataset_id = str(dataset_id) - batch = str(batch) + def get(self, tenant_id, dataset_id: UUID, batch: str): + dataset_id_str = str(dataset_id) tenant_id = str(tenant_id) # get dataset dataset = db.session.scalar( - select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1) ) if not dataset: raise NotFound("Dataset not found.") # get documents - documents = DocumentService.get_batch_documents(dataset_id, batch) + documents = DocumentService.get_batch_documents(dataset_id_str, batch) if not documents: raise NotFound("Documents not found.") documents_status = [] @@ -757,7 +755,7 @@ class DocumentDownloadApi(DatasetApiResource): service_api_ns.models[UrlResponse.__name__], ) @cloud_edition_billing_rate_limit_check("knowledge", "dataset") - def get(self, tenant_id, dataset_id, document_id): + def get(self, tenant_id, dataset_id: UUID, document_id: UUID): dataset = self.get_dataset(str(dataset_id), str(tenant_id)) document = DocumentService.get_document(dataset.id, str(document_id)) @@ -785,13 +783,13 @@ class DocumentApi(DatasetApiResource): 404: "Document not found", } ) - def get(self, tenant_id, dataset_id, document_id): - dataset_id = str(dataset_id) - document_id = str(document_id) + def get(self, tenant_id, dataset_id: UUID, document_id: UUID): + dataset_id_str = str(dataset_id) + document_id_str = str(document_id) - dataset = self.get_dataset(dataset_id, tenant_id) + dataset = self.get_dataset(dataset_id_str, tenant_id) - document = DocumentService.get_document(dataset.id, document_id) + document = DocumentService.get_document(dataset.id, document_id_str) if not document: raise NotFound("Document not found.") @@ -808,15 +806,15 @@ class DocumentApi(DatasetApiResource): has_summary_index = dataset.summary_index_setting and dataset.summary_index_setting.get("enable") is True if has_summary_index and document.need_summary is True: summary_index_status = SummaryIndexService.get_document_summary_index_status( - document_id=document_id, - dataset_id=dataset_id, + document_id=document_id_str, + dataset_id=dataset_id_str, tenant_id=tenant_id, ) if metadata == "only": response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details} elif metadata == "without": - dataset_process_rules = DatasetService.get_process_rules(dataset_id) + dataset_process_rules = DatasetService.get_process_rules(dataset_id_str) document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {} data_source_info = document.data_source_detail_dict response = { @@ -851,7 +849,7 @@ class DocumentApi(DatasetApiResource): "need_summary": document.need_summary if document.need_summary is not None else False, } else: - dataset_process_rules = DatasetService.get_process_rules(dataset_id) + dataset_process_rules = DatasetService.get_process_rules(dataset_id_str) document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {} data_source_info = document.data_source_detail_dict response = { @@ -918,21 +916,21 @@ class DocumentApi(DatasetApiResource): } ) @cloud_edition_billing_rate_limit_check("knowledge", "dataset") - def delete(self, tenant_id, dataset_id, document_id): + def delete(self, tenant_id, dataset_id: UUID, document_id: UUID): """Delete document.""" - document_id = str(document_id) - dataset_id = str(dataset_id) + document_id_str = str(document_id) + dataset_id_str = str(dataset_id) tenant_id = str(tenant_id) # get dataset info dataset = db.session.scalar( - select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1) ) if not dataset: raise ValueError("Dataset does not exist.") - document = DocumentService.get_document(dataset.id, document_id) + document = DocumentService.get_document(dataset.id, document_id_str) # 404 if document not found if document is None: diff --git a/api/controllers/service_api/dataset/hit_testing.py b/api/controllers/service_api/dataset/hit_testing.py index 97a70f5d0e..ba914c4dd4 100644 --- a/api/controllers/service_api/dataset/hit_testing.py +++ b/api/controllers/service_api/dataset/hit_testing.py @@ -1,3 +1,5 @@ +from uuid import UUID + from controllers.common.schema import register_schema_model from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase, HitTestingPayload from controllers.service_api import service_api_ns @@ -20,7 +22,7 @@ class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase): ) @service_api_ns.expect(service_api_ns.models[HitTestingPayload.__name__]) @cloud_edition_billing_rate_limit_check("knowledge", "dataset") - def post(self, tenant_id, dataset_id): + def post(self, tenant_id, dataset_id: UUID): """Perform hit testing on a dataset. Tests retrieval performance for the specified dataset. diff --git a/api/controllers/service_api/dataset/metadata.py b/api/controllers/service_api/dataset/metadata.py index 58bdd0f611..293a77fc5e 100644 --- a/api/controllers/service_api/dataset/metadata.py +++ b/api/controllers/service_api/dataset/metadata.py @@ -1,4 +1,5 @@ from typing import Literal +from uuid import UUID from flask_login import current_user from werkzeug.exceptions import NotFound @@ -57,7 +58,7 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource): 201, "Metadata created successfully", service_api_ns.models[DatasetMetadataResponse.__name__] ) @cloud_edition_billing_rate_limit_check("knowledge", "dataset") - def post(self, tenant_id, dataset_id): + def post(self, tenant_id, dataset_id: UUID): """Create metadata for a dataset.""" metadata_args = MetadataArgs.model_validate(service_api_ns.payload or {}) @@ -83,7 +84,7 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource): @service_api_ns.response( 200, "Metadata retrieved successfully", service_api_ns.models[DatasetMetadataListResponse.__name__] ) - def get(self, tenant_id, dataset_id): + def get(self, tenant_id, dataset_id: UUID): """Get all metadata for a dataset.""" dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) @@ -110,7 +111,7 @@ class DatasetMetadataServiceApi(DatasetApiResource): 200, "Metadata updated successfully", service_api_ns.models[DatasetMetadataResponse.__name__] ) @cloud_edition_billing_rate_limit_check("knowledge", "dataset") - def patch(self, tenant_id, dataset_id, metadata_id): + def patch(self, tenant_id, dataset_id: UUID, metadata_id: UUID): """Update metadata name.""" payload = MetadataUpdatePayload.model_validate(service_api_ns.payload or {}) @@ -136,7 +137,7 @@ class DatasetMetadataServiceApi(DatasetApiResource): ) @service_api_ns.response(204, "Metadata deleted successfully") @cloud_edition_billing_rate_limit_check("knowledge", "dataset") - def delete(self, tenant_id, dataset_id, metadata_id): + def delete(self, tenant_id, dataset_id: UUID, metadata_id: UUID): """Delete metadata.""" dataset_id_str = str(dataset_id) metadata_id_str = str(metadata_id) @@ -164,7 +165,7 @@ class DatasetMetadataBuiltInFieldServiceApi(DatasetApiResource): "Built-in fields retrieved successfully", service_api_ns.models[DatasetMetadataBuiltInFieldsResponse.__name__], ) - def get(self, tenant_id, dataset_id): + def get(self, tenant_id, dataset_id: UUID): """Get all built-in metadata fields.""" built_in_fields = MetadataService.get_built_in_fields() return dump_response(DatasetMetadataBuiltInFieldsResponse, {"fields": built_in_fields}), 200 @@ -186,7 +187,7 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource): 200, "Action completed successfully", service_api_ns.models[DatasetMetadataActionResponse.__name__] ) @cloud_edition_billing_rate_limit_check("knowledge", "dataset") - def post(self, tenant_id, dataset_id, action: Literal["enable", "disable"]): + def post(self, tenant_id, dataset_id: UUID, action: Literal["enable", "disable"]): """Enable or disable built-in metadata field.""" dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) @@ -221,7 +222,7 @@ class DocumentMetadataEditServiceApi(DatasetApiResource): service_api_ns.models[DatasetMetadataActionResponse.__name__], ) @cloud_edition_billing_rate_limit_check("knowledge", "dataset") - def post(self, tenant_id, dataset_id): + def post(self, tenant_id, dataset_id: UUID): """Update metadata for multiple documents.""" dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index 8f52988ef9..f08d08ab7d 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -140,7 +140,7 @@ class CompletionStopApi(WebApiResource): } ) @web_ns.response(200, "Success", web_ns.models[SimpleResultResponse.__name__]) - def post(self, app_model, end_user, task_id): + def post(self, app_model, end_user, task_id: str): if app_model.mode != AppMode.COMPLETION: raise NotCompletionAppError() @@ -226,7 +226,7 @@ class ChatStopApi(WebApiResource): } ) @web_ns.response(200, "Success", web_ns.models[SimpleResultResponse.__name__]) - def post(self, app_model, end_user, task_id): + def post(self, app_model, end_user, task_id: str): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() diff --git a/api/controllers/web/conversation.py b/api/controllers/web/conversation.py index a99adb391f..00db29a606 100644 --- a/api/controllers/web/conversation.py +++ b/api/controllers/web/conversation.py @@ -1,4 +1,5 @@ from typing import Literal +from uuid import UUID from flask import request from pydantic import BaseModel, Field, TypeAdapter, field_validator @@ -126,7 +127,7 @@ class ConversationApi(WebApiResource): 500: "Internal Server Error", } ) - def delete(self, app_model, end_user, c_id): + def delete(self, app_model, end_user, c_id: UUID): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() @@ -165,7 +166,7 @@ class ConversationRenameApi(WebApiResource): 500: "Internal Server Error", } ) - def post(self, app_model, end_user, c_id): + def post(self, app_model, end_user, c_id: UUID): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() @@ -203,7 +204,7 @@ class ConversationPinApi(WebApiResource): } ) @web_ns.response(200, "Conversation pinned successfully", web_ns.models[ResultResponse.__name__]) - def patch(self, app_model, end_user, c_id): + def patch(self, app_model, end_user, c_id: UUID): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() @@ -234,7 +235,7 @@ class ConversationUnPinApi(WebApiResource): } ) @web_ns.response(200, "Conversation unpinned successfully", web_ns.models[ResultResponse.__name__]) - def patch(self, app_model, end_user, c_id): + def patch(self, app_model, end_user, c_id: UUID): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index fc5e266c5c..cf0363b66e 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -1,5 +1,6 @@ import logging from typing import Literal +from uuid import UUID from flask import request from pydantic import BaseModel, Field, TypeAdapter @@ -132,15 +133,15 @@ class MessageFeedbackApi(WebApiResource): } ) @web_ns.response(200, "Feedback submitted successfully", web_ns.models[ResultResponse.__name__]) - def post(self, app_model, end_user, message_id): - message_id = str(message_id) + def post(self, app_model, end_user, message_id: UUID): + message_id_str = str(message_id) payload = MessageFeedbackPayload.model_validate(web_ns.payload or {}) try: MessageService.create_feedback( app_model=app_model, - message_id=message_id, + message_id=message_id_str, user=end_user, rating=FeedbackRating(payload.rating) if payload.rating else None, content=payload.content, @@ -166,11 +167,11 @@ class MessageMoreLikeThisApi(WebApiResource): 500: "Internal Server Error", } ) - def get(self, app_model, end_user, message_id): + def get(self, app_model, end_user, message_id: UUID): if app_model.mode != "completion": raise NotCompletionAppError() - message_id = str(message_id) + message_id_str = str(message_id) raw_args = request.args.to_dict() query = MessageMoreLikeThisQuery.model_validate(raw_args) @@ -181,7 +182,7 @@ class MessageMoreLikeThisApi(WebApiResource): response = AppGenerateService.generate_more_like_this( app_model=app_model, user=end_user, - message_id=message_id, + message_id=message_id_str, invoke_from=InvokeFrom.WEB_APP, streaming=streaming, ) @@ -222,16 +223,16 @@ class MessageSuggestedQuestionApi(WebApiResource): 500: "Internal Server Error", } ) - def get(self, app_model, end_user, message_id): + def get(self, app_model, end_user, message_id: UUID): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - message_id = str(message_id) + message_id_str = str(message_id) try: questions = MessageService.get_suggested_questions_after_answer( - app_model=app_model, user=end_user, message_id=message_id, invoke_from=InvokeFrom.WEB_APP + app_model=app_model, user=end_user, message_id=message_id_str, invoke_from=InvokeFrom.WEB_APP ) # questions is a list of strings, not a list of Message objects except MessageNotExistsError: diff --git a/api/controllers/web/saved_message.py b/api/controllers/web/saved_message.py index e307367b64..766cfc6c60 100644 --- a/api/controllers/web/saved_message.py +++ b/api/controllers/web/saved_message.py @@ -1,3 +1,5 @@ +from uuid import UUID + from flask import request from pydantic import TypeAdapter from werkzeug.exceptions import NotFound @@ -104,12 +106,12 @@ class SavedMessageApi(WebApiResource): 500: "Internal Server Error", } ) - def delete(self, app_model, end_user, message_id): - message_id = str(message_id) + def delete(self, app_model, end_user, message_id: UUID): + message_id_str = str(message_id) if app_model.mode != "completion": raise NotCompletionAppError() - SavedMessageService.delete(app_model, end_user, message_id) + SavedMessageService.delete(app_model, end_user, message_id_str) return "", 204 diff --git a/api/core/plugin/impl/model_runtime.py b/api/core/plugin/impl/model_runtime.py index d555f4d965..3d5ba94f2b 100644 --- a/api/core/plugin/impl/model_runtime.py +++ b/api/core/plugin/impl/model_runtime.py @@ -3,7 +3,6 @@ from __future__ import annotations import hashlib import logging from collections.abc import Generator, Iterable, Sequence -from threading import Lock from typing import IO, Any, Literal, cast, overload, override from pydantic import ValidationError @@ -13,9 +12,9 @@ from configs import dify_config from core.llm_generator.output_parser.structured_output import ( invoke_llm_with_structured_output as invoke_llm_with_structured_output_helper, ) -from core.plugin.entities.plugin_daemon import PluginModelProviderEntity from core.plugin.impl.asset import PluginAssetManager from core.plugin.impl.model import PluginModelClient +from core.plugin.plugin_service import PluginService from extensions.ext_redis import redis_client from graphon.model_runtime.entities.llm_entities import ( LLMResult, @@ -101,35 +100,36 @@ class _PluginStructuredOutputModelInstance: class PluginModelRuntime(ModelRuntime): - """Plugin-backed runtime adapter bound to tenant context and optional caller scope.""" + """Plugin-backed runtime adapter bound to tenant context and optional caller scope. + + Provider discovery goes through ``PluginService`` so the plugin lifecycle + methods and provider reads share one tenant-scoped cache owner. + """ tenant_id: str user_id: str | None client: PluginModelClient - _provider_entities: tuple[ProviderEntity, ...] | None - _provider_entities_lock: Lock + _plugin_service: type[PluginService] - def __init__(self, tenant_id: str, user_id: str | None, client: PluginModelClient) -> None: + def __init__( + self, + tenant_id: str, + user_id: str | None, + client: PluginModelClient, + plugin_service: type[PluginService], + ) -> None: if client is None: raise ValueError("client is required.") + if plugin_service is None: + raise ValueError("plugin_service is required.") self.tenant_id = tenant_id self.user_id = user_id self.client = client - self._provider_entities = None - self._provider_entities_lock = Lock() + self._plugin_service = plugin_service @override def fetch_model_providers(self) -> Sequence[ProviderEntity]: - if self._provider_entities is not None: - return self._provider_entities - - with self._provider_entities_lock: - if self._provider_entities is None: - self._provider_entities = tuple( - self._to_provider_entity(provider) for provider in self.client.fetch_model_providers(self.tenant_id) - ) - - return self._provider_entities + return self._plugin_service.fetch_plugin_model_providers(tenant_id=self.tenant_id, client=self.client) @override def get_provider_icon(self, *, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]: @@ -628,34 +628,6 @@ class PluginModelRuntime(ModelRuntime): text=text, ) - def _get_provider_short_name_alias(self, provider: PluginModelProviderEntity) -> str: - """ - Expose a bare provider alias only for the canonical provider mapping. - - Multiple plugins can publish the same short provider slug. If every - provider entity keeps that slug in ``provider_name``, callers that still - resolve by short name become order-dependent. Restrict the alias to the - provider selected by ``ModelProviderID`` so legacy short-name lookups - remain deterministic while the runtime surface stays canonical. - """ - try: - canonical_provider_id = ModelProviderID(provider.provider) - except ValueError: - return "" - - if canonical_provider_id.plugin_id != provider.plugin_id: - return "" - if canonical_provider_id.provider_name != provider.provider: - return "" - - return provider.provider - - def _to_provider_entity(self, provider: PluginModelProviderEntity) -> ProviderEntity: - declaration = provider.declaration.model_copy(deep=True) - declaration.provider = f"{provider.plugin_id}/{provider.provider}" - declaration.provider_name = self._get_provider_short_name_alias(provider) - return declaration - def _get_provider_schema(self, provider: str) -> ProviderEntity: providers = self.fetch_model_providers() provider_entity = next((item for item in providers if item.provider == provider), None) diff --git a/api/core/plugin/impl/model_runtime_factory.py b/api/core/plugin/impl/model_runtime_factory.py index fbe307ea60..32a4c7b89a 100644 --- a/api/core/plugin/impl/model_runtime_factory.py +++ b/api/core/plugin/impl/model_runtime_factory.py @@ -3,6 +3,7 @@ from __future__ import annotations from typing import TYPE_CHECKING from core.plugin.impl.model import PluginModelClient +from core.plugin.plugin_service import PluginService from graphon.model_runtime.entities.model_entities import ModelType from graphon.model_runtime.entities.provider_entities import ProviderEntity from graphon.model_runtime.model_providers.base.ai_model import AIModel @@ -117,6 +118,7 @@ def create_plugin_model_runtime(*, tenant_id: str, user_id: str | None = None) - tenant_id=tenant_id, user_id=user_id, client=PluginModelClient(), + plugin_service=PluginService, ) diff --git a/api/services/plugin/plugin_service.py b/api/core/plugin/plugin_service.py similarity index 73% rename from api/services/plugin/plugin_service.py rename to api/core/plugin/plugin_service.py index 72271c55d8..a88ddb5f3d 100644 --- a/api/services/plugin/plugin_service.py +++ b/api/core/plugin/plugin_service.py @@ -1,8 +1,17 @@ +"""Core plugin service and tenant-scoped plugin metadata cache ownership. + +This module owns plugin daemon management calls that are shared by API services +and core runtimes. Plugin model provider discovery is cached here, alongside +plugin install, uninstall, and upgrade invalidation, so all cache mutations for +plugin-owned provider metadata stay tenant-scoped and in one place. +""" + import logging from collections.abc import Mapping, Sequence from mimetypes import guess_type -from pydantic import BaseModel +from pydantic import BaseModel, TypeAdapter, ValidationError +from redis import RedisError from sqlalchemy import delete, select, update from sqlalchemy.orm import Session from yarl import URL @@ -22,16 +31,20 @@ from core.plugin.entities.plugin import ( from core.plugin.entities.plugin_daemon import ( PluginDecodeResponse, PluginInstallTask, + PluginInstallTaskStatus, PluginListResponse, + PluginModelProviderEntity, PluginVerification, ) from core.plugin.impl.asset import PluginAssetManager from core.plugin.impl.debugging import PluginDebuggingClient +from core.plugin.impl.model import PluginModelClient from core.plugin.impl.plugin import PluginInstaller from extensions.ext_database import db from extensions.ext_redis import redis_client +from graphon.model_runtime.entities.provider_entities import ProviderEntity from models.provider import Provider, ProviderCredential, TenantPreferredModelProvider -from models.provider_ids import GenericProviderID +from models.provider_ids import GenericProviderID, ModelProviderID from services.enterprise.plugin_manager_service import ( PluginManagerService, PreUninstallPluginRequest, @@ -40,6 +53,7 @@ from services.errors.plugin import PluginInstallationForbiddenError from services.feature_service import FeatureService, PluginInstallationScope logger = logging.getLogger(__name__) +_provider_entities_adapter: TypeAdapter[list[ProviderEntity]] = TypeAdapter(list[ProviderEntity]) class PluginService: @@ -53,6 +67,102 @@ class PluginService: REDIS_KEY_PREFIX = "plugin_service:latest_plugin:" REDIS_TTL = 60 * 5 # 5 minutes + PLUGIN_MODEL_PROVIDERS_REDIS_KEY_PREFIX = "plugin_model_providers:tenant_id:" + PLUGIN_INSTALL_TASK_TERMINAL_STATUSES = (PluginInstallTaskStatus.Success, PluginInstallTaskStatus.Failed) + + @classmethod + def _get_plugin_model_providers_cache_key(cls, tenant_id: str) -> str: + return f"{cls.PLUGIN_MODEL_PROVIDERS_REDIS_KEY_PREFIX}{tenant_id}" + + @staticmethod + def _get_provider_short_name_alias(provider: PluginModelProviderEntity) -> str: + """ + Expose a bare provider alias only for the canonical provider mapping. + + Multiple plugins can publish the same short provider slug. If every + provider entity keeps that slug in ``provider_name``, callers that still + resolve by short name become order-dependent. Restrict the alias to the + provider selected by ``ModelProviderID`` so legacy short-name lookups + remain deterministic while the runtime surface stays canonical. + """ + try: + canonical_provider_id = ModelProviderID(provider.provider) + except ValueError: + return "" + + if canonical_provider_id.plugin_id != provider.plugin_id: + return "" + if canonical_provider_id.provider_name != provider.provider: + return "" + + return provider.provider + + @classmethod + def _to_provider_entity(cls, provider: PluginModelProviderEntity) -> ProviderEntity: + declaration = provider.declaration.model_copy(deep=True) + declaration.provider = f"{provider.plugin_id}/{provider.provider}" + declaration.provider_name = cls._get_provider_short_name_alias(provider) + return declaration + + @classmethod + def _load_cached_plugin_model_providers(cls, tenant_id: str) -> tuple[ProviderEntity, ...] | None: + cache_key = cls._get_plugin_model_providers_cache_key(tenant_id) + try: + cached_providers = redis_client.get(cache_key) + except (RedisError, RuntimeError): + logger.warning("Failed to read cached plugin model providers for tenant %s.", tenant_id, exc_info=True) + return None + + if not cached_providers: + return None + + try: + return tuple(_provider_entities_adapter.validate_json(cached_providers)) + except (TypeError, ValueError, ValidationError): + logger.warning( + "Invalid cached plugin model providers for tenant %s; deleting cache.", tenant_id, exc_info=True + ) + cls.invalidate_plugin_model_providers_cache(tenant_id) + return None + + @classmethod + def _store_cached_plugin_model_providers(cls, tenant_id: str, providers: Sequence[ProviderEntity]) -> None: + cache_key = cls._get_plugin_model_providers_cache_key(tenant_id) + try: + payload = _provider_entities_adapter.dump_json(list(providers)).decode("utf-8") + redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_PROVIDERS_CACHE_TTL, payload) + except (RedisError, RuntimeError): + logger.warning("Failed to cache plugin model providers for tenant %s.", tenant_id, exc_info=True) + + @classmethod + def invalidate_plugin_model_providers_cache(cls, tenant_id: str) -> None: + """Delete the tenant-scoped plugin model provider list cache.""" + try: + redis_client.delete(cls._get_plugin_model_providers_cache_key(tenant_id)) + except (RedisError, RuntimeError): + logger.warning("Failed to invalidate plugin model providers cache for tenant %s.", tenant_id, exc_info=True) + + @classmethod + def fetch_plugin_model_providers( + cls, *, tenant_id: str, client: PluginModelClient | None = None + ) -> Sequence[ProviderEntity]: + """ + Fetch plugin model providers through the tenant-scoped plugin cache. + + Plugin daemon provider discovery and plugin lifecycle cache invalidation + are intentionally owned by this service so tenant isolation and cache + expiry are handled in one place. + """ + cached_providers = cls._load_cached_plugin_model_providers(tenant_id) + if cached_providers is not None: + return cached_providers + + model_client = client or PluginModelClient() + providers = tuple( + cls._to_provider_entity(provider) for provider in model_client.fetch_model_providers(tenant_id) + ) + cls._store_cached_plugin_model_providers(tenant_id, providers) + return providers @staticmethod def fetch_latest_plugin_version(plugin_ids: Sequence[str]) -> Mapping[str, LatestPluginCache | None]: @@ -248,12 +358,18 @@ class PluginService: Fetch plugin installation tasks """ manager = PluginInstaller() - return manager.fetch_plugin_installation_tasks(tenant_id, page, page_size) + tasks = manager.fetch_plugin_installation_tasks(tenant_id, page, page_size) + if any(task.status in PluginService.PLUGIN_INSTALL_TASK_TERMINAL_STATUSES for task in tasks): + PluginService.invalidate_plugin_model_providers_cache(tenant_id) + return tasks @staticmethod def fetch_install_task(tenant_id: str, task_id: str) -> PluginInstallTask: manager = PluginInstaller() - return manager.fetch_plugin_installation_task(tenant_id, task_id) + task = manager.fetch_plugin_installation_task(tenant_id, task_id) + if task.status in PluginService.PLUGIN_INSTALL_TASK_TERMINAL_STATUSES: + PluginService.invalidate_plugin_model_providers_cache(tenant_id) + return task @staticmethod def delete_install_task(tenant_id: str, task_id: str) -> bool: @@ -315,7 +431,7 @@ class PluginService: # check if the plugin is available to install PluginService._check_plugin_installation_scope(response.verification) - return manager.upgrade_plugin( + result = manager.upgrade_plugin( tenant_id, original_plugin_unique_identifier, new_plugin_unique_identifier, @@ -324,6 +440,8 @@ class PluginService: "plugin_unique_identifier": new_plugin_unique_identifier, }, ) + PluginService.invalidate_plugin_model_providers_cache(tenant_id) + return result @staticmethod def upgrade_plugin_with_github( @@ -339,7 +457,7 @@ class PluginService: """ PluginService._check_marketplace_only_permission() manager = PluginInstaller() - return manager.upgrade_plugin( + result = manager.upgrade_plugin( tenant_id, original_plugin_unique_identifier, new_plugin_unique_identifier, @@ -350,6 +468,8 @@ class PluginService: "package": package, }, ) + PluginService.invalidate_plugin_model_providers_cache(tenant_id) + return result @staticmethod def upload_pkg(tenant_id: str, pkg: bytes, verify_signature: bool = False) -> PluginDecodeResponse: @@ -415,12 +535,14 @@ class PluginService: resp = manager.decode_plugin_from_identifier(tenant_id, plugin_unique_identifier) PluginService._check_plugin_installation_scope(resp.verification) - return manager.install_from_identifiers( + result = manager.install_from_identifiers( tenant_id, plugin_unique_identifiers, PluginInstallationSource.Package, [{}], ) + PluginService.invalidate_plugin_model_providers_cache(tenant_id) + return result @staticmethod def install_from_github(tenant_id: str, plugin_unique_identifier: str, repo: str, version: str, package: str): @@ -434,7 +556,7 @@ class PluginService: plugin_decode_response = manager.decode_plugin_from_identifier(tenant_id, plugin_unique_identifier) PluginService._check_plugin_installation_scope(plugin_decode_response.verification) - return manager.install_from_identifiers( + result = manager.install_from_identifiers( tenant_id, [plugin_unique_identifier], PluginInstallationSource.Github, @@ -446,6 +568,8 @@ class PluginService: } ], ) + PluginService.invalidate_plugin_model_providers_cache(tenant_id) + return result @staticmethod def fetch_marketplace_pkg(tenant_id: str, plugin_unique_identifier: str) -> PluginDeclaration: @@ -513,12 +637,14 @@ class PluginService: actual_plugin_unique_identifiers.append(response.unique_identifier) metas.append({"plugin_unique_identifier": response.unique_identifier}) - return manager.install_from_identifiers( + result = manager.install_from_identifiers( tenant_id, actual_plugin_unique_identifiers, PluginInstallationSource.Marketplace, metas, ) + PluginService.invalidate_plugin_model_providers_cache(tenant_id) + return result @staticmethod def uninstall(tenant_id: str, plugin_installation_id: str) -> bool: @@ -529,7 +655,10 @@ class PluginService: plugin = next((p for p in plugins if p.installation_id == plugin_installation_id), None) if not plugin: - return manager.uninstall(tenant_id, plugin_installation_id) + result = manager.uninstall(tenant_id, plugin_installation_id) + if result: + PluginService.invalidate_plugin_model_providers_cache(tenant_id) + return result if dify_config.ENTERPRISE_ENABLED: PluginManagerService.try_pre_uninstall_plugin( @@ -559,37 +688,39 @@ class PluginService: if not credential_ids: logger.info("No credentials found for plugin: %s", plugin_id) - return manager.uninstall(tenant_id, plugin_installation_id) + else: + provider_ids = session.scalars( + select(Provider.id).where( + Provider.tenant_id == tenant_id, + Provider.provider_name.like(f"{plugin_id}/%"), + Provider.credential_id.in_(credential_ids), + ) + ).all() - provider_ids = session.scalars( - select(Provider.id).where( - Provider.tenant_id == tenant_id, - Provider.provider_name.like(f"{plugin_id}/%"), - Provider.credential_id.in_(credential_ids), + session.execute(update(Provider).where(Provider.id.in_(provider_ids)).values(credential_id=None)) + + for provider_id in provider_ids: + ProviderCredentialsCache( + tenant_id=tenant_id, + identity_id=provider_id, + cache_type=ProviderCredentialsCacheType.PROVIDER, + ).delete() + + session.execute( + delete(ProviderCredential).where( + ProviderCredential.id.in_(credential_ids), + ) ) - ).all() - session.execute(update(Provider).where(Provider.id.in_(provider_ids)).values(credential_id=None)) - - for provider_id in provider_ids: - ProviderCredentialsCache( - tenant_id=tenant_id, - identity_id=provider_id, - cache_type=ProviderCredentialsCacheType.PROVIDER, - ).delete() - - session.execute( - delete(ProviderCredential).where( - ProviderCredential.id.in_(credential_ids), + logger.info( + "Completed deleting credentials and cleaning provider associations for plugin: %s", + plugin_id, ) - ) - logger.info( - "Completed deleting credentials and cleaning provider associations for plugin: %s", - plugin_id, - ) - - return manager.uninstall(tenant_id, plugin_installation_id) + result = manager.uninstall(tenant_id, plugin_installation_id) + if result: + PluginService.invalidate_plugin_model_providers_cache(tenant_id) + return result @staticmethod def check_tools_existence(tenant_id: str, provider_ids: Sequence[GenericProviderID]) -> Sequence[bool]: diff --git a/api/core/trigger/provider.py b/api/core/trigger/provider.py index 10fa31fdfa..893526b5e0 100644 --- a/api/core/trigger/provider.py +++ b/api/core/trigger/provider.py @@ -16,6 +16,7 @@ from core.plugin.entities.request import ( TriggerSubscriptionResponse, ) from core.plugin.impl.trigger import PluginTriggerClient +from core.plugin.plugin_service import PluginService from core.trigger.entities.api_entities import EventApiEntity, TriggerProviderApiEntity from core.trigger.entities.entities import ( EventEntity, @@ -30,7 +31,6 @@ from core.trigger.entities.entities import ( ) from core.trigger.errors import TriggerProviderCredentialValidationError from models.provider_ids import TriggerProviderID -from services.plugin.plugin_service import PluginService logger = logging.getLogger(__name__) diff --git a/api/core/workflow/nodes/knowledge_index/entities.py b/api/core/workflow/nodes/knowledge_index/entities.py index 260881e49c..12f47f3d57 100644 --- a/api/core/workflow/nodes/knowledge_index/entities.py +++ b/api/core/workflow/nodes/knowledge_index/entities.py @@ -1,6 +1,6 @@ -from typing import Union +from typing import Any, Union -from pydantic import BaseModel +from pydantic import BaseModel, field_validator from core.rag.entities import RerankingModelConfig, WeightedScoreConfig from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict @@ -101,3 +101,14 @@ class KnowledgeIndexNodeData(BaseNodeData): index_chunk_variable_selector: list[str] indexing_technique: str | None = None summary_index_setting: SummaryIndexSettingDict | None = None + + @field_validator("summary_index_setting", mode="before") + @classmethod + def normalize_summary_index_setting(cls, v: Any) -> Any: + """Treat dicts with enable=None (or missing enable) as None (#36233).""" + if v is None: + return None + if isinstance(v, dict): + if v.get("enable") is None: + return None + return v diff --git a/api/models/model.py b/api/models/model.py index f7f90465cf..3647fbf6f7 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -492,8 +492,8 @@ class App(Base): @property def deleted_tools(self) -> list[DeletedToolInfo]: + from core.plugin.plugin_service import PluginService from core.tools.tool_manager import ToolManager, ToolProviderType - from services.plugin.plugin_service import PluginService # get agent mode tools app_model_config = self.app_model_config diff --git a/api/models/types.py b/api/models/types.py index 23028220f6..092db63856 100644 --- a/api/models/types.py +++ b/api/models/types.py @@ -1,7 +1,7 @@ import enum import json import uuid -from typing import Any, cast +from typing import Any, cast, override import sqlalchemy as sa from pydantic import BaseModel @@ -18,6 +18,7 @@ class StringUUID(TypeDecorator[uuid.UUID | str | None]): impl = CHAR cache_ok = True + @override def process_bind_param(self, value: uuid.UUID | str | None, dialect: Dialect) -> str | None: if value is None: return value @@ -28,12 +29,14 @@ class StringUUID(TypeDecorator[uuid.UUID | str | None]): return value.hex return value + @override def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]: if dialect.name == "postgresql": return dialect.type_descriptor(UUID()) else: return dialect.type_descriptor(CHAR(36)) + @override def process_result_value(self, value: uuid.UUID | str | None, dialect: Dialect) -> str | None: if value is None: return value @@ -44,11 +47,13 @@ class LongText(TypeDecorator[str | None]): impl = TEXT cache_ok = True + @override def process_bind_param(self, value: str | None, dialect: Dialect) -> str | None: if value is None: return value return value + @override def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]: if dialect.name == "postgresql": return dialect.type_descriptor(TEXT()) @@ -57,6 +62,7 @@ class LongText(TypeDecorator[str | None]): else: return dialect.type_descriptor(TEXT()) + @override def process_result_value(self, value: str | None, dialect: Dialect) -> str | None: if value is None: return value @@ -77,6 +83,7 @@ class JSONModelColumn[T: BaseModel](TypeDecorator[T | None]): self._model_class = model_class super().__init__() + @override def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]: if dialect.name == "postgresql": return dialect.type_descriptor(TEXT()) @@ -85,6 +92,7 @@ class JSONModelColumn[T: BaseModel](TypeDecorator[T | None]): else: return dialect.type_descriptor(TEXT()) + @override def process_bind_param(self, value: T | dict[str, Any] | str | None, dialect: Dialect) -> str | None: if value is None: return None @@ -96,6 +104,7 @@ class JSONModelColumn[T: BaseModel](TypeDecorator[T | None]): model = self._model_class.model_validate(value) return json.dumps(model.model_dump(mode="json"), ensure_ascii=False, sort_keys=True, separators=(",", ":")) + @override def process_result_value(self, value: str | None, dialect: Dialect) -> T | None: if value is None or value == "": return None @@ -106,11 +115,13 @@ class BinaryData(TypeDecorator[bytes | None]): impl = LargeBinary cache_ok = True + @override def process_bind_param(self, value: bytes | None, dialect: Dialect) -> bytes | None: if value is None: return value return value + @override def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]: if dialect.name == "postgresql": return dialect.type_descriptor(BYTEA()) @@ -119,6 +130,7 @@ class BinaryData(TypeDecorator[bytes | None]): else: return dialect.type_descriptor(LargeBinary()) + @override def process_result_value(self, value: bytes | None, dialect: Dialect) -> bytes | None: if value is None: return value @@ -133,6 +145,7 @@ class AdjustedJSON(TypeDecorator[dict | list | None]): self.astext_type = astext_type super().__init__() + @override def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]: if dialect.name == "postgresql": if self.astext_type: @@ -144,11 +157,13 @@ class AdjustedJSON(TypeDecorator[dict | list | None]): else: return dialect.type_descriptor(sa.JSON()) + @override def process_bind_param( self, value: dict[str, Any] | list[Any] | None, dialect: Dialect ) -> dict[str, Any] | list[Any] | None: return value + @override def process_result_value( self, value: dict[str, Any] | list[Any] | None, dialect: Dialect ) -> dict[str, Any] | list[Any] | None: @@ -173,6 +188,7 @@ class EnumText[T: enum.StrEnum](TypeDecorator[T | None]): # leave some rooms for future longer enum values. self._length = max(max_enum_value_len, 20) + @override def process_bind_param(self, value: T | str | None, dialect: Dialect) -> str | None: if value is None: return value @@ -182,9 +198,11 @@ class EnumText[T: enum.StrEnum](TypeDecorator[T | None]): self._enum_class(value) return value + @override def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]: return dialect.type_descriptor(VARCHAR(self._length)) + @override def process_result_value(self, value: str | None, dialect: Dialect) -> T | None: if value is None or value == "": return None @@ -197,6 +215,7 @@ class EnumText[T: enum.StrEnum](TypeDecorator[T | None]): return cast(T, value_of(value)) raise + @override def compare_values(self, x: T | None, y: T | None) -> bool: if x is None or y is None: return x is y diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index 2245adb681..e3c6f122ab 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -14,13 +14,13 @@ from core.helper.provider_cache import NoOpProviderCredentialCache from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.impl.datasource import PluginDatasourceManager from core.plugin.impl.oauth import OAuthHandler +from core.plugin.plugin_service import PluginService from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter from extensions.ext_database import db from extensions.ext_redis import redis_client from graphon.model_runtime.entities.provider_entities import FormType from models.oauth import DatasourceOauthParamConfig, DatasourceOauthTenantParamConfig, DatasourceProvider from models.provider_ids import DatasourceProviderID -from services.plugin.plugin_service import PluginService logger = logging.getLogger(__name__) diff --git a/api/services/plugin/plugin_migration.py b/api/services/plugin/plugin_migration.py index 8fa3c3d4ef..88b9eeefa1 100644 --- a/api/services/plugin/plugin_migration.py +++ b/api/services/plugin/plugin_migration.py @@ -22,6 +22,7 @@ from core.helper import marketplace from core.plugin.entities.plugin import PluginInstallationSource from core.plugin.entities.plugin_daemon import PluginInstallTaskStatus from core.plugin.impl.plugin import PluginInstaller +from core.plugin.plugin_service import PluginService from core.tools.entities.tool_entities import ToolProviderType from extensions.ext_database import db from models.account import Tenant @@ -29,7 +30,6 @@ from models.model import App, AppMode, AppModelConfig from models.provider_ids import ModelProviderID, ToolProviderID from models.tools import BuiltinToolProvider from models.workflow import Workflow -from services.plugin.plugin_service import PluginService logger = logging.getLogger(__name__) @@ -389,17 +389,19 @@ class PluginMigration: for plugin_id in batch_plugin_ids if plugin_id not in installed_plugins_ids and plugin_id in plugins["plugins"] ] - manager.install_from_identifiers( - tenant_id, - batch_plugin_identifiers, - PluginInstallationSource.Marketplace, - metas=[ - { - "plugin_unique_identifier": identifier, - } - for identifier in batch_plugin_identifiers - ], - ) + if batch_plugin_identifiers: + manager.install_from_identifiers( + tenant_id, + batch_plugin_identifiers, + PluginInstallationSource.Marketplace, + metas=[ + { + "plugin_unique_identifier": identifier, + } + for identifier in batch_plugin_identifiers + ], + ) + PluginService.invalidate_plugin_model_providers_cache(tenant_id) with open(extracted_plugins) as f: """ @@ -595,6 +597,7 @@ class PluginMigration: for identifier in batch_plugin_identifiers ], ) + PluginService.invalidate_plugin_model_providers_cache(tenant_id) except Exception: # add to failed failed.extend(batch_plugin_identifiers) @@ -609,6 +612,7 @@ class PluginMigration: while not done: status = manager.fetch_plugin_installation_task(tenant_id, task_id) if status.status in [PluginInstallTaskStatus.Failed, PluginInstallTaskStatus.Success]: + PluginService.invalidate_plugin_model_providers_cache(tenant_id) for plugin in status.plugins: if plugin.status == PluginInstallTaskStatus.Success: success.append(reverse_map[plugin.plugin_unique_identifier]) diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py index 37ebffbeb4..99fd3f5628 100644 --- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py +++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py @@ -78,9 +78,9 @@ class CheckDependenciesPendingData(BaseModel): class RagPipelineDslService: """Import, export, and inspect RAG pipeline DSL using the caller-owned session. - Controllers wrap this service in a SQLAlchemy transaction context, so methods must only flush interim changes when - generated IDs are needed. Committing inside the service would close the caller's transaction and break later work in - the same context manager. + Callers pass a plain ``Session`` (not wrapped in ``.begin()``) and are responsible for calling + ``session.commit()`` on success or ``session.rollback()`` on failure. Methods here only flush + when generated IDs are needed mid-operation; they never commit or rollback. """ def __init__(self, session: Session): diff --git a/api/services/rag_pipeline/rag_pipeline_transform_service.py b/api/services/rag_pipeline/rag_pipeline_transform_service.py index f95519fc9e..ca755d0b91 100644 --- a/api/services/rag_pipeline/rag_pipeline_transform_service.py +++ b/api/services/rag_pipeline/rag_pipeline_transform_service.py @@ -12,6 +12,7 @@ from sqlalchemy import select from configs import dify_config from constants import DOCUMENT_EXTENSIONS from core.plugin.impl.plugin import PluginInstaller +from core.plugin.plugin_service import PluginService from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db @@ -22,7 +23,6 @@ from models.model import UploadFile from models.workflow import Workflow, WorkflowType from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeConfiguration, RetrievalSetting from services.plugin.plugin_migration import PluginMigration -from services.plugin.plugin_service import PluginService logger = logging.getLogger(__name__) diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 20de1f4058..6b7092e318 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -13,6 +13,7 @@ from core.helper.name_generator import generate_incremental_name from core.helper.position_helper import is_filtered from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache from core.plugin.entities.plugin_daemon import CredentialType +from core.plugin.plugin_service import PluginService from core.tools.builtin_tool.provider import BuiltinToolProviderController from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort from core.tools.entities.api_entities import ( @@ -31,7 +32,6 @@ from extensions.ext_database import db from extensions.ext_redis import redis_client from models.provider_ids import ToolProviderID from models.tools import BuiltinToolProvider, ToolOAuthSystemClient, ToolOAuthTenantClient -from services.plugin.plugin_service import PluginService from services.tools.tools_transform_service import ToolTransformService logger = logging.getLogger(__name__) diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index effbbaad01..c25f120917 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -9,6 +9,7 @@ from configs import dify_config from core.helper.provider_cache import ToolProviderCredentialsCache from core.mcp.types import Tool as MCPTool from core.plugin.entities.plugin_daemon import CredentialType, PluginDatasourceProviderEntity +from core.plugin.plugin_service import PluginService from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.builtin_tool.provider import BuiltinToolProviderController @@ -27,7 +28,6 @@ from core.tools.utils.encryption import create_provider_encrypter, create_tool_p from core.tools.workflow_as_tool.provider import WorkflowToolProviderController from core.tools.workflow_as_tool.tool import WorkflowTool from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider -from services.plugin.plugin_service import PluginService logger = logging.getLogger(__name__) diff --git a/api/services/trigger/trigger_provider_service.py b/api/services/trigger/trigger_provider_service.py index b8a76e4945..855f88fb90 100644 --- a/api/services/trigger/trigger_provider_service.py +++ b/api/services/trigger/trigger_provider_service.py @@ -14,6 +14,7 @@ from core.helper.provider_cache import NoOpProviderCredentialCache from core.helper.provider_encryption import ProviderConfigEncrypter, create_provider_encrypter from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.impl.oauth import OAuthHandler +from core.plugin.plugin_service import PluginService from core.tools.utils.system_encryption import decrypt_system_params from core.trigger.entities.api_entities import ( TriggerProviderApiEntity, @@ -37,7 +38,6 @@ from models.trigger import ( TriggerSubscription, WorkflowPluginTrigger, ) -from services.plugin.plugin_service import PluginService logger = logging.getLogger(__name__) diff --git a/api/services/workflow_app_service.py b/api/services/workflow_app_service.py index 59e02ec9b9..8ff9d651a4 100644 --- a/api/services/workflow_app_service.py +++ b/api/services/workflow_app_service.py @@ -6,11 +6,11 @@ from typing import Any, TypedDict from sqlalchemy import and_, func, or_, select from sqlalchemy.orm import Session +from core.plugin.plugin_service import PluginService from graphon.enums import WorkflowExecutionStatus from models import Account, App, EndUser, TenantAccountJoin, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun from models.enums import AppTriggerType, CreatorUserRole from models.trigger import WorkflowTriggerLog -from services.plugin.plugin_service import PluginService from services.workflow.entities import TriggerMetadata diff --git a/api/tasks/process_tenant_plugin_autoupgrade_check_task.py b/api/tasks/process_tenant_plugin_autoupgrade_check_task.py index 48d1774ce3..ab54b9e72e 100644 --- a/api/tasks/process_tenant_plugin_autoupgrade_check_task.py +++ b/api/tasks/process_tenant_plugin_autoupgrade_check_task.py @@ -9,9 +9,9 @@ from celery import shared_task from core.plugin.entities.marketplace import MarketplacePluginSnapshot from core.plugin.entities.plugin import PluginInstallationSource from core.plugin.impl.plugin import PluginInstaller +from core.plugin.plugin_service import PluginService from extensions.ext_redis import redis_client from models.account import TenantPluginAutoUpgradeStrategy -from services.plugin.plugin_service import PluginService logger = logging.getLogger(__name__) diff --git a/api/tests/test_containers_integration_tests/services/plugin/test_plugin_service.py b/api/tests/test_containers_integration_tests/services/plugin/test_plugin_service.py index 0fa4d9043b..66ff24f374 100644 --- a/api/tests/test_containers_integration_tests/services/plugin/test_plugin_service.py +++ b/api/tests/test_containers_integration_tests/services/plugin/test_plugin_service.py @@ -1,4 +1,4 @@ -"""Tests for services.plugin.plugin_service.PluginService. +"""Tests for core.plugin.plugin_service.PluginService. Covers: version caching with Redis, install permission/scope gates, icon URL construction, asset retrieval with MIME guessing, plugin @@ -17,11 +17,11 @@ from sqlalchemy.orm import Session from core.plugin.entities.plugin import PluginInstallationSource from core.plugin.entities.plugin_daemon import PluginVerification +from core.plugin.plugin_service import PluginService from models import ProviderType from models.provider import Provider, ProviderCredential, TenantPreferredModelProvider from services.errors.plugin import PluginInstallationForbiddenError from services.feature_service import PluginInstallationScope -from services.plugin.plugin_service import PluginService def _make_features( @@ -35,8 +35,8 @@ def _make_features( class TestFetchLatestPluginVersion: - @patch("services.plugin.plugin_service.marketplace") - @patch("services.plugin.plugin_service.redis_client") + @patch("core.plugin.plugin_service.marketplace") + @patch("core.plugin.plugin_service.redis_client") def test_returns_cached_version(self, mock_redis, mock_marketplace): cached_json = PluginService.LatestPluginCache( plugin_id="p1", @@ -53,8 +53,8 @@ class TestFetchLatestPluginVersion: assert result["p1"].version == "1.0.0" mock_marketplace.batch_fetch_plugin_manifests.assert_not_called() - @patch("services.plugin.plugin_service.marketplace") - @patch("services.plugin.plugin_service.redis_client") + @patch("core.plugin.plugin_service.marketplace") + @patch("core.plugin.plugin_service.redis_client") def test_fetches_from_marketplace_on_cache_miss(self, mock_redis, mock_marketplace): mock_redis.get.return_value = None manifest = MagicMock() @@ -71,8 +71,8 @@ class TestFetchLatestPluginVersion: assert result["p1"].version == "2.0.0" mock_redis.setex.assert_called_once() - @patch("services.plugin.plugin_service.marketplace") - @patch("services.plugin.plugin_service.redis_client") + @patch("core.plugin.plugin_service.marketplace") + @patch("core.plugin.plugin_service.redis_client") def test_returns_none_for_unknown_plugin(self, mock_redis, mock_marketplace): mock_redis.get.return_value = None mock_marketplace.batch_fetch_plugin_manifests.return_value = [] @@ -81,8 +81,8 @@ class TestFetchLatestPluginVersion: assert result["unknown"] is None - @patch("services.plugin.plugin_service.marketplace") - @patch("services.plugin.plugin_service.redis_client") + @patch("core.plugin.plugin_service.marketplace") + @patch("core.plugin.plugin_service.redis_client") def test_handles_marketplace_exception_gracefully(self, mock_redis, mock_marketplace): mock_redis.get.return_value = None mock_marketplace.batch_fetch_plugin_manifests.side_effect = RuntimeError("network error") @@ -93,14 +93,14 @@ class TestFetchLatestPluginVersion: class TestCheckMarketplaceOnlyPermission: - @patch("services.plugin.plugin_service.FeatureService") + @patch("core.plugin.plugin_service.FeatureService") def test_raises_when_restricted(self, mock_fs): mock_fs.get_system_features.return_value = _make_features(restrict_to_marketplace=True) with pytest.raises(PluginInstallationForbiddenError): PluginService._check_marketplace_only_permission() - @patch("services.plugin.plugin_service.FeatureService") + @patch("core.plugin.plugin_service.FeatureService") def test_passes_when_not_restricted(self, mock_fs): mock_fs.get_system_features.return_value = _make_features(restrict_to_marketplace=False) @@ -108,7 +108,7 @@ class TestCheckMarketplaceOnlyPermission: class TestCheckPluginInstallationScope: - @patch("services.plugin.plugin_service.FeatureService") + @patch("core.plugin.plugin_service.FeatureService") def test_official_only_allows_langgenius(self, mock_fs): mock_fs.get_system_features.return_value = _make_features(scope=PluginInstallationScope.OFFICIAL_ONLY) verification = MagicMock() @@ -116,14 +116,14 @@ class TestCheckPluginInstallationScope: PluginService._check_plugin_installation_scope(verification) # should not raise - @patch("services.plugin.plugin_service.FeatureService") + @patch("core.plugin.plugin_service.FeatureService") def test_official_only_rejects_third_party(self, mock_fs): mock_fs.get_system_features.return_value = _make_features(scope=PluginInstallationScope.OFFICIAL_ONLY) with pytest.raises(PluginInstallationForbiddenError): PluginService._check_plugin_installation_scope(None) - @patch("services.plugin.plugin_service.FeatureService") + @patch("core.plugin.plugin_service.FeatureService") def test_official_and_partners_allows_partner(self, mock_fs): mock_fs.get_system_features.return_value = _make_features( scope=PluginInstallationScope.OFFICIAL_AND_SPECIFIC_PARTNERS @@ -133,7 +133,7 @@ class TestCheckPluginInstallationScope: PluginService._check_plugin_installation_scope(verification) # should not raise - @patch("services.plugin.plugin_service.FeatureService") + @patch("core.plugin.plugin_service.FeatureService") def test_official_and_partners_rejects_none(self, mock_fs): mock_fs.get_system_features.return_value = _make_features( scope=PluginInstallationScope.OFFICIAL_AND_SPECIFIC_PARTNERS @@ -142,7 +142,7 @@ class TestCheckPluginInstallationScope: with pytest.raises(PluginInstallationForbiddenError): PluginService._check_plugin_installation_scope(None) - @patch("services.plugin.plugin_service.FeatureService") + @patch("core.plugin.plugin_service.FeatureService") def test_none_scope_always_raises(self, mock_fs): mock_fs.get_system_features.return_value = _make_features(scope=PluginInstallationScope.NONE) verification = MagicMock() @@ -151,7 +151,7 @@ class TestCheckPluginInstallationScope: with pytest.raises(PluginInstallationForbiddenError): PluginService._check_plugin_installation_scope(verification) - @patch("services.plugin.plugin_service.FeatureService") + @patch("core.plugin.plugin_service.FeatureService") def test_all_scope_passes_any(self, mock_fs): mock_fs.get_system_features.return_value = _make_features(scope=PluginInstallationScope.ALL) @@ -159,7 +159,7 @@ class TestCheckPluginInstallationScope: class TestGetPluginIconUrl: - @patch("services.plugin.plugin_service.dify_config") + @patch("core.plugin.plugin_service.dify_config") def test_constructs_url_with_params(self, mock_config): mock_config.CONSOLE_API_URL = "https://console.example.com" @@ -171,7 +171,7 @@ class TestGetPluginIconUrl: class TestGetAsset: - @patch("services.plugin.plugin_service.PluginAssetManager") + @patch("core.plugin.plugin_service.PluginAssetManager") def test_returns_bytes_and_guessed_mime(self, mock_asset_cls): mock_asset_cls.return_value.fetch_asset.return_value = b"" @@ -180,7 +180,7 @@ class TestGetAsset: assert data == b"" assert "svg" in mime - @patch("services.plugin.plugin_service.PluginAssetManager") + @patch("core.plugin.plugin_service.PluginAssetManager") def test_fallback_to_octet_stream_for_unknown(self, mock_asset_cls): mock_asset_cls.return_value.fetch_asset.return_value = b"\x00" @@ -190,13 +190,13 @@ class TestGetAsset: class TestIsPluginVerified: - @patch("services.plugin.plugin_service.PluginInstaller") + @patch("core.plugin.plugin_service.PluginInstaller") def test_returns_true_when_verified(self, mock_installer_cls): mock_installer_cls.return_value.fetch_plugin_manifest.return_value.verified = True assert PluginService.is_plugin_verified("t1", "uid-1") is True - @patch("services.plugin.plugin_service.PluginInstaller") + @patch("core.plugin.plugin_service.PluginInstaller") def test_returns_false_on_exception(self, mock_installer_cls): mock_installer_cls.return_value.fetch_plugin_manifest.side_effect = RuntimeError("not found") @@ -204,24 +204,24 @@ class TestIsPluginVerified: class TestUpgradePluginWithMarketplace: - @patch("services.plugin.plugin_service.dify_config") + @patch("core.plugin.plugin_service.dify_config") def test_raises_when_marketplace_disabled(self, mock_config): mock_config.MARKETPLACE_ENABLED = False with pytest.raises(ValueError, match="marketplace is not enabled"): PluginService.upgrade_plugin_with_marketplace("t1", "old-uid", "new-uid") - @patch("services.plugin.plugin_service.dify_config") + @patch("core.plugin.plugin_service.dify_config") def test_raises_when_same_identifier(self, mock_config): mock_config.MARKETPLACE_ENABLED = True with pytest.raises(ValueError, match="same plugin"): PluginService.upgrade_plugin_with_marketplace("t1", "same-uid", "same-uid") - @patch("services.plugin.plugin_service.marketplace") - @patch("services.plugin.plugin_service.FeatureService") - @patch("services.plugin.plugin_service.PluginInstaller") - @patch("services.plugin.plugin_service.dify_config") + @patch("core.plugin.plugin_service.marketplace") + @patch("core.plugin.plugin_service.FeatureService") + @patch("core.plugin.plugin_service.PluginInstaller") + @patch("core.plugin.plugin_service.dify_config") def test_skips_download_when_already_installed(self, mock_config, mock_installer_cls, mock_fs, mock_marketplace): mock_config.MARKETPLACE_ENABLED = True mock_fs.get_system_features.return_value = _make_features() @@ -234,10 +234,10 @@ class TestUpgradePluginWithMarketplace: mock_marketplace.record_install_plugin_event.assert_called_once_with("new-uid") installer.upgrade_plugin.assert_called_once() - @patch("services.plugin.plugin_service.download_plugin_pkg") - @patch("services.plugin.plugin_service.FeatureService") - @patch("services.plugin.plugin_service.PluginInstaller") - @patch("services.plugin.plugin_service.dify_config") + @patch("core.plugin.plugin_service.download_plugin_pkg") + @patch("core.plugin.plugin_service.FeatureService") + @patch("core.plugin.plugin_service.PluginInstaller") + @patch("core.plugin.plugin_service.dify_config") def test_downloads_when_not_installed(self, mock_config, mock_installer_cls, mock_fs, mock_download): mock_config.MARKETPLACE_ENABLED = True mock_fs.get_system_features.return_value = _make_features() @@ -256,8 +256,8 @@ class TestUpgradePluginWithMarketplace: class TestUpgradePluginWithGithub: - @patch("services.plugin.plugin_service.FeatureService") - @patch("services.plugin.plugin_service.PluginInstaller") + @patch("core.plugin.plugin_service.FeatureService") + @patch("core.plugin.plugin_service.PluginInstaller") def test_checks_marketplace_permission_and_delegates(self, mock_installer_cls, mock_fs): mock_fs.get_system_features.return_value = _make_features() installer = mock_installer_cls.return_value @@ -271,8 +271,8 @@ class TestUpgradePluginWithGithub: class TestUploadPkg: - @patch("services.plugin.plugin_service.FeatureService") - @patch("services.plugin.plugin_service.PluginInstaller") + @patch("core.plugin.plugin_service.FeatureService") + @patch("core.plugin.plugin_service.PluginInstaller") def test_runs_permission_and_scope_checks(self, mock_installer_cls, mock_fs): mock_fs.get_system_features.return_value = _make_features() upload_resp = MagicMock() @@ -285,17 +285,17 @@ class TestUploadPkg: class TestInstallFromMarketplacePkg: - @patch("services.plugin.plugin_service.dify_config") + @patch("core.plugin.plugin_service.dify_config") def test_raises_when_marketplace_disabled(self, mock_config): mock_config.MARKETPLACE_ENABLED = False with pytest.raises(ValueError, match="marketplace is not enabled"): PluginService.install_from_marketplace_pkg("t1", ["uid-1"]) - @patch("services.plugin.plugin_service.download_plugin_pkg") - @patch("services.plugin.plugin_service.FeatureService") - @patch("services.plugin.plugin_service.PluginInstaller") - @patch("services.plugin.plugin_service.dify_config") + @patch("core.plugin.plugin_service.download_plugin_pkg") + @patch("core.plugin.plugin_service.FeatureService") + @patch("core.plugin.plugin_service.PluginInstaller") + @patch("core.plugin.plugin_service.dify_config") def test_downloads_when_not_cached(self, mock_config, mock_installer_cls, mock_fs, mock_download): mock_config.MARKETPLACE_ENABLED = True mock_fs.get_system_features.return_value = _make_features() @@ -315,9 +315,9 @@ class TestInstallFromMarketplacePkg: call_args = installer.install_from_identifiers.call_args[0] assert call_args[1] == ["resolved-uid"] - @patch("services.plugin.plugin_service.FeatureService") - @patch("services.plugin.plugin_service.PluginInstaller") - @patch("services.plugin.plugin_service.dify_config") + @patch("core.plugin.plugin_service.FeatureService") + @patch("core.plugin.plugin_service.PluginInstaller") + @patch("core.plugin.plugin_service.dify_config") def test_uses_cached_when_already_downloaded(self, mock_config, mock_installer_cls, mock_fs): mock_config.MARKETPLACE_ENABLED = True mock_fs.get_system_features.return_value = _make_features() @@ -336,7 +336,7 @@ class TestInstallFromMarketplacePkg: class TestUninstall: - @patch("services.plugin.plugin_service.PluginInstaller") + @patch("core.plugin.plugin_service.PluginInstaller") def test_direct_uninstall_when_plugin_not_found(self, mock_installer_cls): installer = mock_installer_cls.return_value installer.list_plugins.return_value = [] @@ -347,7 +347,7 @@ class TestUninstall: assert result is True installer.uninstall.assert_called_once_with("t1", "install-1") - @patch("services.plugin.plugin_service.PluginInstaller") + @patch("core.plugin.plugin_service.PluginInstaller") def test_cleans_credentials_when_plugin_found( self, mock_installer_cls, flask_app_with_containers: Flask, db_session_with_containers: Session ): @@ -389,7 +389,7 @@ class TestUninstall: installer.list_plugins.return_value = [plugin] installer.uninstall.return_value = True - with patch("services.plugin.plugin_service.dify_config") as mock_config: + with patch("core.plugin.plugin_service.dify_config") as mock_config: mock_config.ENTERPRISE_ENABLED = False result = PluginService.uninstall(tenant_id, "install-1") diff --git a/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py b/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py index 2dc50cc720..a174f5d69f 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py @@ -6,6 +6,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session +from core.plugin.plugin_service import PluginService from core.tools.__base.tool import Tool from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity from core.tools.entities.common_entities import I18nObject @@ -20,7 +21,6 @@ from core.tools.entities.tool_entities import ( ToolProviderType, ) from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider -from services.plugin.plugin_service import PluginService from services.tools.tools_transform_service import ToolTransformService @@ -31,7 +31,7 @@ class TestToolTransformService: def mock_external_service_dependencies(self): """Mock setup for external service dependencies.""" with patch("services.tools.tools_transform_service.dify_config") as mock_dify_config: - with patch("services.plugin.plugin_service.dify_config", new=mock_dify_config): + with patch("core.plugin.plugin_service.dify_config", new=mock_dify_config): # Setup default mock returns mock_dify_config.CONSOLE_API_URL = "https://console.example.com" diff --git a/api/tests/unit_tests/core/plugin/impl/test_model_runtime_factory.py b/api/tests/unit_tests/core/plugin/impl/test_model_runtime_factory.py index 52da674f06..c2dfe3be30 100644 --- a/api/tests/unit_tests/core/plugin/impl/test_model_runtime_factory.py +++ b/api/tests/unit_tests/core/plugin/impl/test_model_runtime_factory.py @@ -1,6 +1,7 @@ from unittest.mock import Mock, patch from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly +from core.plugin.plugin_service import PluginService def test_plugin_model_assembly_reuses_single_runtime_across_views(): @@ -34,3 +35,11 @@ def test_plugin_model_assembly_reuses_single_runtime_across_views(): mock_provider_factory_cls.assert_called_once_with(runtime=runtime) mock_provider_manager_cls.assert_called_once_with(model_runtime=runtime) mock_model_manager_cls.assert_called_once_with(provider_manager=provider_manager) + + +def test_create_plugin_model_runtime_injects_plugin_service(): + from core.plugin.impl.model_runtime_factory import create_plugin_model_runtime + + runtime = create_plugin_model_runtime(tenant_id="tenant-1", user_id="user-1") + + assert runtime._plugin_service is PluginService diff --git a/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py b/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py index b1ecaa4ead..f9abc7d02a 100644 --- a/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py +++ b/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py @@ -12,6 +12,7 @@ from core.plugin.impl import model_runtime as model_runtime_module from core.plugin.impl.model import PluginModelClient from core.plugin.impl.model_runtime import TENANT_SCOPE_SCHEMA_CACHE_USER_ID, PluginModelRuntime from core.plugin.impl.model_runtime_factory import create_plugin_model_runtime +from core.plugin.plugin_service import PluginService from graphon.model_runtime.entities.common_entities import I18nObject from graphon.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta, LLMUsage from graphon.model_runtime.entities.message_entities import AssistantPromptMessage @@ -19,6 +20,22 @@ from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFr from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity +class _FakeRedis: + def __init__(self) -> None: + self._values: dict[str, str] = {} + self.setex_calls: list[tuple[str, int, str]] = [] + + def get(self, key: str) -> str | None: + return self._values.get(key) + + def setex(self, key: str, ttl: int, value: str) -> None: + self._values[key] = value + self.setex_calls.append((key, ttl, value)) + + def delete(self, key: str) -> None: + self._values.pop(key, None) + + def _build_model_schema() -> AIModelEntity: return AIModelEntity( model="gpt-4o-mini", @@ -29,6 +46,24 @@ def _build_model_schema() -> AIModelEntity: ) +def _build_plugin_model_provider(*, tenant_id: str, provider: str = "openai") -> PluginModelProviderEntity: + return PluginModelProviderEntity( + id=uuid.uuid4().hex, + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + provider=provider, + tenant_id=tenant_id, + plugin_unique_identifier=f"langgenius/{provider}/{provider}", + plugin_id=f"langgenius/{provider}", + declaration=ProviderEntity( + provider=provider, + label=I18nObject(en_US=provider.title()), + supported_model_types=[], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ), + ) + + class TestPluginModelRuntime: """Validate the adapter keeps plugin-specific routing out of the runtime port.""" @@ -51,7 +86,7 @@ class TestPluginModelRuntime: ), ) ] - runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client, plugin_service=PluginService) providers = runtime.fetch_model_providers() @@ -95,7 +130,7 @@ class TestPluginModelRuntime: ), ), ] - runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client, plugin_service=PluginService) providers = runtime.fetch_model_providers() @@ -122,7 +157,7 @@ class TestPluginModelRuntime: ), ) ] - runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client, plugin_service=PluginService) providers = runtime.fetch_model_providers() @@ -131,7 +166,7 @@ class TestPluginModelRuntime: def test_validate_provider_credentials_resolves_plugin_fields(self) -> None: client = Mock(spec=PluginModelClient) - runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client, plugin_service=PluginService) runtime.validate_provider_credentials( provider="langgenius/openai/openai", @@ -173,7 +208,7 @@ class TestPluginModelRuntime: ), ] ) - runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client, plugin_service=PluginService) result = runtime.invoke_llm( provider="langgenius/openai/openai", @@ -209,7 +244,7 @@ class TestPluginModelRuntime: client = Mock(spec=PluginModelClient) stream_result = iter([]) client.invoke_llm.return_value = stream_result - runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client, plugin_service=PluginService) result = runtime.invoke_llm( provider="langgenius/openai/openai", @@ -240,7 +275,9 @@ class TestPluginModelRuntime: def test_invoke_llm_rejects_per_call_user_override(self) -> None: client = Mock(spec=PluginModelClient) client.invoke_llm.return_value = sentinel.result - runtime = PluginModelRuntime(tenant_id="tenant", user_id="bound-user", client=client) + runtime = PluginModelRuntime( + tenant_id="tenant", user_id="bound-user", client=client, plugin_service=PluginService + ) with pytest.raises(TypeError, match="unexpected keyword argument 'user_id'"): runtime.invoke_llm( # type: ignore[call-arg] @@ -260,7 +297,7 @@ class TestPluginModelRuntime: def test_invoke_tts_uses_bound_runtime_user_when_runtime_is_unbound(self) -> None: client = Mock(spec=PluginModelClient) client.invoke_tts.return_value = iter([b"chunk"]) - runtime = PluginModelRuntime(tenant_id="tenant", user_id=None, client=client) + runtime = PluginModelRuntime(tenant_id="tenant", user_id=None, client=client, plugin_service=PluginService) result = runtime.invoke_tts( provider="langgenius/openai/openai", @@ -282,15 +319,107 @@ class TestPluginModelRuntime: voice="alloy", ) - def test_fetch_model_providers_uses_bound_runtime_cache(self) -> None: + def test_fetch_model_providers_does_not_keep_bound_runtime_cache(self, monkeypatch: pytest.MonkeyPatch) -> None: client = Mock(spec=PluginModelClient) client.fetch_model_providers.return_value = [] - runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + from core.plugin import plugin_service as plugin_service_module + + monkeypatch.setattr( + plugin_service_module, + "redis_client", + SimpleNamespace( + get=Mock(return_value=None), + delete=Mock(), + setex=Mock(), + ), + ) + monkeypatch.setattr(plugin_service_module.dify_config, "PLUGIN_MODEL_PROVIDERS_CACHE_TTL", 300) + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client, plugin_service=PluginService) runtime.fetch_model_providers() runtime.fetch_model_providers() - client.fetch_model_providers.assert_called_once_with("tenant") + assert client.fetch_model_providers.call_count == 2 + + def test_fetch_model_providers_uses_tenant_ttl_cache_across_runtime_instances( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + redis = _FakeRedis() + from core.plugin import plugin_service as plugin_service_module + + monkeypatch.setattr(plugin_service_module, "redis_client", redis) + monkeypatch.setattr(plugin_service_module.dify_config, "PLUGIN_MODEL_PROVIDERS_CACHE_TTL", 300) + first_client = Mock(spec=PluginModelClient) + first_client.fetch_model_providers.return_value = [_build_plugin_model_provider(tenant_id="tenant")] + second_client = Mock(spec=PluginModelClient) + first_runtime = PluginModelRuntime( + tenant_id="tenant", user_id="user-a", client=first_client, plugin_service=PluginService + ) + second_runtime = PluginModelRuntime( + tenant_id="tenant", user_id="user-b", client=second_client, plugin_service=PluginService + ) + + first_providers = first_runtime.fetch_model_providers() + second_providers = second_runtime.fetch_model_providers() + + assert [provider.provider for provider in first_providers] == ["langgenius/openai/openai"] + assert [provider.provider for provider in second_providers] == ["langgenius/openai/openai"] + first_client.fetch_model_providers.assert_called_once_with("tenant") + second_client.fetch_model_providers.assert_not_called() + assert redis.setex_calls[0][1] == 300 + + def test_fetch_model_providers_cache_is_tenant_isolated(self, monkeypatch: pytest.MonkeyPatch) -> None: + redis = _FakeRedis() + from core.plugin import plugin_service as plugin_service_module + + monkeypatch.setattr(plugin_service_module, "redis_client", redis) + monkeypatch.setattr(plugin_service_module.dify_config, "PLUGIN_MODEL_PROVIDERS_CACHE_TTL", 300) + first_client = Mock(spec=PluginModelClient) + first_client.fetch_model_providers.return_value = [_build_plugin_model_provider(tenant_id="tenant-a")] + second_client = Mock(spec=PluginModelClient) + second_client.fetch_model_providers.return_value = [_build_plugin_model_provider(tenant_id="tenant-b")] + first_runtime = PluginModelRuntime( + tenant_id="tenant-a", user_id="user", client=first_client, plugin_service=PluginService + ) + second_runtime = PluginModelRuntime( + tenant_id="tenant-b", user_id="user", client=second_client, plugin_service=PluginService + ) + + first_providers = first_runtime.fetch_model_providers() + second_providers = second_runtime.fetch_model_providers() + + assert [provider.provider for provider in first_providers] == ["langgenius/openai/openai"] + assert [provider.provider for provider in second_providers] == ["langgenius/openai/openai"] + first_client.fetch_model_providers.assert_called_once_with("tenant-a") + second_client.fetch_model_providers.assert_called_once_with("tenant-b") + assert len(redis.setex_calls) == 2 + + def test_fetch_model_providers_delegates_cache_to_injected_plugin_service(self) -> None: + client = Mock(spec=PluginModelClient) + service_result = [ + ProviderEntity( + provider="langgenius/openai/openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ) + ] + fetch_plugin_model_providers = Mock(return_value=service_result) + + class TestPluginService(PluginService): + pass + + TestPluginService.fetch_plugin_model_providers = fetch_plugin_model_providers + + runtime = PluginModelRuntime( + tenant_id="tenant", user_id="user", client=client, plugin_service=TestPluginService + ) + + result = runtime.fetch_model_providers() + + assert result is service_result + fetch_plugin_model_providers.assert_called_once_with(tenant_id="tenant", client=client) + client.fetch_model_providers.assert_not_called() def test_create_plugin_model_runtime_without_user_context() -> None: @@ -301,7 +430,17 @@ def test_create_plugin_model_runtime_without_user_context() -> None: def test_plugin_model_runtime_requires_client() -> None: with pytest.raises(ValueError, match="client is required"): - PluginModelRuntime(tenant_id="tenant", user_id="user", client=None) # type: ignore[arg-type] + PluginModelRuntime(tenant_id="tenant", user_id="user", client=None, plugin_service=PluginService) # type: ignore[arg-type] + + +def test_plugin_model_runtime_requires_plugin_service() -> None: + with pytest.raises(ValueError, match="plugin_service is required"): + PluginModelRuntime( + tenant_id="tenant", + user_id="user", + client=Mock(spec=PluginModelClient), + plugin_service=None, # type: ignore[arg-type] + ) def test_get_model_schema_uses_cached_schema_without_hitting_client(monkeypatch: pytest.MonkeyPatch) -> None: @@ -317,7 +456,7 @@ def test_get_model_schema_uses_cached_schema_without_hitting_client(monkeypatch: ), ) - runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client, plugin_service=PluginService) result = runtime.get_model_schema( provider="langgenius/openai/openai", model_type=ModelType.LLM, @@ -395,7 +534,7 @@ def test_structured_output_adapter_invokes_bound_runtime_non_streaming() -> None def test_invoke_llm_with_structured_output_delegates_with_bound_adapter() -> None: client = Mock(spec=PluginModelClient) - runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client, plugin_service=PluginService) schema = _build_model_schema() runtime.get_model_schema = Mock(return_value=schema) # type: ignore[method-assign] @@ -436,7 +575,7 @@ def test_invoke_llm_with_structured_output_delegates_with_bound_adapter() -> Non def test_invoke_llm_with_structured_output_raises_when_model_schema_is_missing() -> None: client = Mock(spec=PluginModelClient) - runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client, plugin_service=PluginService) runtime.get_model_schema = Mock(return_value=None) # type: ignore[method-assign] with pytest.raises(ValueError, match="Model schema not found for gpt-4o-mini"): @@ -468,7 +607,7 @@ def test_get_model_schema_deletes_invalid_cache_and_refetches(monkeypatch: pytes ) monkeypatch.setattr(model_runtime_module.dify_config, "PLUGIN_MODEL_SCHEMA_CACHE_TTL", 300) client.get_model_schema.return_value = schema - runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client, plugin_service=PluginService) result = runtime.get_model_schema( provider="langgenius/openai/openai", @@ -494,7 +633,7 @@ def test_get_model_schema_deletes_invalid_cache_and_refetches(monkeypatch: pytes def test_get_llm_num_tokens_returns_zero_when_plugin_counting_is_disabled(monkeypatch: pytest.MonkeyPatch) -> None: client = Mock(spec=PluginModelClient) monkeypatch.setattr(model_runtime_module.dify_config, "PLUGIN_BASED_TOKEN_COUNTING_ENABLED", False) - runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client, plugin_service=PluginService) assert ( runtime.get_llm_num_tokens( @@ -533,7 +672,7 @@ def test_get_provider_icon_reads_requested_variant_and_detects_svg_mime(monkeypa ] fetch_asset = Mock(return_value=b"") monkeypatch.setattr(model_runtime_module.PluginAssetManager, "fetch_asset", fetch_asset) - runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client, plugin_service=PluginService) icon_bytes, mime_type = runtime.get_provider_icon( provider="langgenius/openai/openai", @@ -565,7 +704,7 @@ def test_get_provider_icon_rejects_unsupported_types_and_missing_variants() -> N ), ) ] - runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client, plugin_service=PluginService) with pytest.raises(ValueError, match="does not have small dark icon"): runtime.get_provider_icon( @@ -583,7 +722,9 @@ def test_get_provider_icon_rejects_unsupported_types_and_missing_variants() -> N def test_get_schema_cache_key_is_stable_across_credential_order() -> None: - runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=Mock(spec=PluginModelClient)) + runtime = PluginModelRuntime( + tenant_id="tenant", user_id="user", client=Mock(spec=PluginModelClient), plugin_service=PluginService + ) first = runtime._get_schema_cache_key( provider="langgenius/openai/openai", @@ -602,8 +743,12 @@ def test_get_schema_cache_key_is_stable_across_credential_order() -> None: def test_get_schema_cache_key_separates_distinct_user_scopes() -> None: - first_runtime = PluginModelRuntime(tenant_id="tenant", user_id="user-a", client=Mock(spec=PluginModelClient)) - second_runtime = PluginModelRuntime(tenant_id="tenant", user_id="user-b", client=Mock(spec=PluginModelClient)) + first_runtime = PluginModelRuntime( + tenant_id="tenant", user_id="user-a", client=Mock(spec=PluginModelClient), plugin_service=PluginService + ) + second_runtime = PluginModelRuntime( + tenant_id="tenant", user_id="user-b", client=Mock(spec=PluginModelClient), plugin_service=PluginService + ) first = first_runtime._get_schema_cache_key( provider="langgenius/openai/openai", @@ -622,8 +767,12 @@ def test_get_schema_cache_key_separates_distinct_user_scopes() -> None: def test_get_schema_cache_key_separates_tenant_scope_from_user_scope() -> None: - tenant_runtime = PluginModelRuntime(tenant_id="tenant", user_id=None, client=Mock(spec=PluginModelClient)) - user_runtime = PluginModelRuntime(tenant_id="tenant", user_id="user-a", client=Mock(spec=PluginModelClient)) + tenant_runtime = PluginModelRuntime( + tenant_id="tenant", user_id=None, client=Mock(spec=PluginModelClient), plugin_service=PluginService + ) + user_runtime = PluginModelRuntime( + tenant_id="tenant", user_id="user-a", client=Mock(spec=PluginModelClient), plugin_service=PluginService + ) tenant_key = tenant_runtime._get_schema_cache_key( provider="langgenius/openai/openai", @@ -643,8 +792,12 @@ def test_get_schema_cache_key_separates_tenant_scope_from_user_scope() -> None: def test_get_schema_cache_key_separates_tenant_scope_from_empty_string_user_scope() -> None: - tenant_runtime = PluginModelRuntime(tenant_id="tenant", user_id=None, client=Mock(spec=PluginModelClient)) - empty_user_runtime = PluginModelRuntime(tenant_id="tenant", user_id="", client=Mock(spec=PluginModelClient)) + tenant_runtime = PluginModelRuntime( + tenant_id="tenant", user_id=None, client=Mock(spec=PluginModelClient), plugin_service=PluginService + ) + empty_user_runtime = PluginModelRuntime( + tenant_id="tenant", user_id="", client=Mock(spec=PluginModelClient), plugin_service=PluginService + ) tenant_key = tenant_runtime._get_schema_cache_key( provider="langgenius/openai/openai", @@ -683,7 +836,7 @@ def test_get_provider_schema_supports_short_alias_and_rejects_invalid_provider() ), ) ] - runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client, plugin_service=PluginService) assert runtime._get_provider_schema("openai").provider == "langgenius/openai/openai" diff --git a/api/tests/unit_tests/models/test_app_models.py b/api/tests/unit_tests/models/test_app_models.py index e3b8269e15..8850137eb9 100644 --- a/api/tests/unit_tests/models/test_app_models.py +++ b/api/tests/unit_tests/models/test_app_models.py @@ -11,6 +11,7 @@ This test suite covers: import json from datetime import UTC, datetime from decimal import Decimal +from types import SimpleNamespace from unittest.mock import MagicMock, patch from uuid import uuid4 @@ -197,6 +198,55 @@ class TestAppModelValidation: # Assert assert result == AppMode.CHAT + def test_deleted_tools_checks_plugin_builtin_providers_through_core_plugin_service(self): + """Plugin-backed built-in tools are checked through core PluginService.""" + # Arrange + app = App( + tenant_id="tenant-1", + name="Test App", + mode=AppMode.CHAT, + enable_site=True, + enable_api=False, + created_by=str(uuid4()), + ) + app_model_config = AppModelConfig( + app_id=str(uuid4()), + agent_mode=json.dumps( + { + "enabled": True, + "strategy": "function_call", + "tools": [ + { + "provider_type": "builtin", + "provider_id": "langgenius/openai/openai", + "tool_name": "chat", + "tool_parameters": {}, + } + ], + "prompt": None, + } + ), + ) + session_context = MagicMock() + session_context.__enter__.return_value = MagicMock() + session_factory = SimpleNamespace(begin=MagicMock(return_value=session_context)) + + # Act + with ( + patch.object(App, "app_model_config", new_callable=lambda: property(lambda self: app_model_config)), + patch("models.model.db", SimpleNamespace(engine=object())), + patch("models.model.sessionmaker", return_value=session_factory), + patch("core.tools.tool_manager.ToolManager.get_hardcoded_provider", side_effect=Exception), + patch("core.plugin.plugin_service.PluginService.check_tools_existence", return_value=[False]) as exists, + ): + result = app.deleted_tools + + # Assert + assert result == [{"type": "builtin", "tool_name": "chat", "provider_id": "langgenius/openai/openai"}] + exists.assert_called_once() + assert exists.call_args.args[0] == "tenant-1" + assert [str(provider_id) for provider_id in exists.call_args.args[1]] == ["langgenius/openai/openai"] + class TestAppModelConfig: """Test suite for AppModelConfig model.""" diff --git a/api/tests/unit_tests/services/plugin/conftest.py b/api/tests/unit_tests/services/plugin/conftest.py index 9dc4fa0390..cb30058494 100644 --- a/api/tests/unit_tests/services/plugin/conftest.py +++ b/api/tests/unit_tests/services/plugin/conftest.py @@ -24,7 +24,7 @@ def make_features( def mock_installer(monkeypatch: pytest.MonkeyPatch): """Patch PluginInstaller at the service import site.""" mock = MagicMock() - monkeypatch.setattr("services.plugin.plugin_service.PluginInstaller", lambda: mock) + monkeypatch.setattr("core.plugin.plugin_service.PluginInstaller", lambda: mock) return mock @@ -34,6 +34,6 @@ def mock_features(): from unittest.mock import patch features = make_features() - with patch("services.plugin.plugin_service.FeatureService") as mock_fs: + with patch("core.plugin.plugin_service.FeatureService") as mock_fs: mock_fs.get_system_features.return_value = features yield features diff --git a/api/tests/unit_tests/services/plugin/test_plugin_migration.py b/api/tests/unit_tests/services/plugin/test_plugin_migration.py index 12b6ea23a1..8f730d4ed3 100644 --- a/api/tests/unit_tests/services/plugin/test_plugin_migration.py +++ b/api/tests/unit_tests/services/plugin/test_plugin_migration.py @@ -61,6 +61,7 @@ class TestHandlePluginInstanceInstall: patch(f"{MIGRATION_MODULE}.dify_config") as mock_cfg, patch(f"{MIGRATION_MODULE}.marketplace") as mock_marketplace, patch(f"{MIGRATION_MODULE}.PluginInstaller") as mock_installer_cls, + patch(f"{MIGRATION_MODULE}.PluginService.invalidate_plugin_model_providers_cache") as invalidate_cache, ): mock_cfg.MARKETPLACE_ENABLED = True mock_marketplace.download_plugin_pkg.return_value = b"pkg_data" @@ -73,4 +74,31 @@ class TestHandlePluginInstanceInstall: ) mock_marketplace.download_plugin_pkg.assert_called_once() + invalidate_cache.assert_called_once_with("tenant1") assert "success" in result or "failed" in result + + def test_install_plugins_invalidates_cache_after_direct_tenant_install(self, tmp_path) -> None: + extracted_plugins = tmp_path / "plugins.jsonl" + output_file = tmp_path / "output.json" + extracted_plugins.write_text('{"tenant_id":"tenant1","plugins":["langgenius/openai"]}\n') + + with ( + patch( + f"{MIGRATION_MODULE}.PluginMigration.extract_unique_plugins", + return_value={ + "plugins": {"langgenius/openai": "langgenius/openai:1.0.0@abc"}, + "plugin_not_exist": [], + }, + ), + patch(f"{MIGRATION_MODULE}.PluginMigration.handle_plugin_instance_install", return_value={}), + patch(f"{MIGRATION_MODULE}.PluginInstaller") as mock_installer_cls, + patch(f"{MIGRATION_MODULE}.PluginService.invalidate_plugin_model_providers_cache") as invalidate_cache, + ): + mock_installer = MagicMock() + mock_installer.list_plugins.return_value = [] + mock_installer_cls.return_value = mock_installer + + PluginMigration.install_plugins(str(extracted_plugins), str(output_file), workers=1) + + mock_installer.install_from_identifiers.assert_called_once() + invalidate_cache.assert_called_once_with("tenant1") diff --git a/api/tests/unit_tests/services/plugin/test_plugin_service.py b/api/tests/unit_tests/services/plugin/test_plugin_service.py index 05bb3b65c0..42edd5582b 100644 --- a/api/tests/unit_tests/services/plugin/test_plugin_service.py +++ b/api/tests/unit_tests/services/plugin/test_plugin_service.py @@ -1,6 +1,71 @@ -from unittest.mock import MagicMock, patch +import datetime +import uuid +from types import SimpleNamespace +from unittest.mock import MagicMock, Mock, patch -MODULE = "services.plugin.plugin_service" +from pydantic import TypeAdapter +from redis import RedisError + +from core.plugin.entities.plugin_daemon import PluginInstallTask, PluginInstallTaskStatus, PluginModelProviderEntity +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity + +MODULE = "core.plugin.plugin_service" + + +class _FakeSession: + def __init__(self) -> None: + self.execute = Mock() + self.scalars = Mock(return_value=SimpleNamespace(all=Mock(return_value=[]))) + + def __enter__(self) -> "_FakeSession": + return self + + def __exit__(self, exc_type, exc, traceback) -> None: + return None + + def begin(self) -> "_FakeSession": + return self + + +def _build_provider_entity(provider: str = "openai") -> ProviderEntity: + return ProviderEntity( + provider=f"langgenius/{provider}/{provider}", + label=I18nObject(en_US=provider.title()), + supported_model_types=[], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ) + + +def _build_plugin_model_provider(*, tenant_id: str = "tenant-1", provider: str = "openai") -> PluginModelProviderEntity: + return PluginModelProviderEntity( + id=uuid.uuid4().hex, + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + provider=provider, + tenant_id=tenant_id, + plugin_unique_identifier=f"langgenius/{provider}/{provider}", + plugin_id=f"langgenius/{provider}", + declaration=ProviderEntity( + provider=provider, + label=I18nObject(en_US=provider.title()), + supported_model_types=[], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ), + ) + + +def _build_install_task(*, task_id: str = "task-1", status: PluginInstallTaskStatus) -> PluginInstallTask: + now = datetime.datetime.now() + return PluginInstallTask( + id=task_id, + created_at=now, + updated_at=now, + status=status, + total_plugins=1, + completed_plugins=1 if status != PluginInstallTaskStatus.Pending else 0, + plugins=[], + ) class TestFetchLatestPluginVersion: @@ -14,7 +79,7 @@ class TestFetchLatestPluginVersion: mock_cfg.MARKETPLACE_ENABLED = False mock_redis.get.return_value = None # all cache misses - from services.plugin.plugin_service import PluginService + from core.plugin.plugin_service import PluginService result = PluginService.fetch_latest_plugin_version(["langgenius/openai", "langgenius/anthropic"]) @@ -40,7 +105,7 @@ class TestFetchLatestPluginVersion: mock_redis.get.return_value = None mock_marketplace.batch_fetch_plugin_manifests.return_value = [manifest] - from services.plugin.plugin_service import PluginService + from core.plugin.plugin_service import PluginService result = PluginService.fetch_latest_plugin_version(["langgenius/openai"]) @@ -48,3 +113,322 @@ class TestFetchLatestPluginVersion: mock_marketplace.batch_fetch_plugin_manifests.assert_called_once() assert result["langgenius/openai"] is not None assert result["langgenius/openai"].version == "1.0.0" + + +class TestPluginModelProviderCache: + def test_fetch_plugin_model_providers_returns_cached_provider_without_calling_daemon(self) -> None: + """A valid tenant cache entry is reused across runtime calls without plugin daemon access.""" + cached_provider = _build_provider_entity() + cached_payload = TypeAdapter(list[ProviderEntity]).dump_json([cached_provider]).decode("utf-8") + + with patch(f"{MODULE}.redis_client") as redis_client: + redis_client.get.return_value = cached_payload + + from core.plugin.plugin_service import PluginService + + client = Mock() + result = PluginService.fetch_plugin_model_providers(tenant_id="tenant-1", client=client) + + assert [provider.provider for provider in result] == ["langgenius/openai/openai"] + client.fetch_model_providers.assert_not_called() + redis_client.setex.assert_not_called() + + def test_fetch_plugin_model_providers_deletes_invalid_cache_and_refetches(self) -> None: + """Invalid cache payloads are tenant-scoped invalidated before falling back to the daemon.""" + with ( + patch(f"{MODULE}.redis_client") as redis_client, + patch(f"{MODULE}.dify_config") as mock_config, + ): + redis_client.get.return_value = "not-json" + mock_config.PLUGIN_MODEL_PROVIDERS_CACHE_TTL = 86400 + client = Mock() + client.fetch_model_providers.return_value = [_build_plugin_model_provider()] + + from core.plugin.plugin_service import PluginService + + result = PluginService.fetch_plugin_model_providers(tenant_id="tenant-1", client=client) + + cache_key = "plugin_model_providers:tenant_id:tenant-1" + redis_client.delete.assert_called_once_with(cache_key) + redis_client.setex.assert_called_once() + assert redis_client.setex.call_args.args[0] == cache_key + assert redis_client.setex.call_args.args[1] == 86400 + assert [provider.provider for provider in result] == ["langgenius/openai/openai"] + + def test_fetch_plugin_model_providers_refetches_when_cache_read_fails(self) -> None: + """Redis read failures do not block provider discovery for the tenant.""" + with patch(f"{MODULE}.redis_client") as redis_client: + redis_client.get.side_effect = RedisError("redis unavailable") + client = Mock() + client.fetch_model_providers.return_value = [_build_plugin_model_provider()] + + from core.plugin.plugin_service import PluginService + + result = PluginService.fetch_plugin_model_providers(tenant_id="tenant-1", client=client) + + client.fetch_model_providers.assert_called_once_with("tenant-1") + assert [provider.provider for provider in result] == ["langgenius/openai/openai"] + + def test_fetch_plugin_model_providers_returns_fresh_result_when_cache_write_fails(self) -> None: + """Redis write failures are non-fatal after fresh provider data has been fetched.""" + with patch(f"{MODULE}.redis_client") as redis_client: + redis_client.get.return_value = None + redis_client.setex.side_effect = RedisError("redis unavailable") + client = Mock() + client.fetch_model_providers.return_value = [_build_plugin_model_provider()] + + from core.plugin.plugin_service import PluginService + + result = PluginService.fetch_plugin_model_providers(tenant_id="tenant-1", client=client) + + client.fetch_model_providers.assert_called_once_with("tenant-1") + assert [provider.provider for provider in result] == ["langgenius/openai/openai"] + + def test_fetch_plugin_model_providers_creates_default_client_on_cache_miss(self) -> None: + """The service owns plugin daemon access when no runtime-provided client is injected.""" + with ( + patch(f"{MODULE}.redis_client") as redis_client, + patch(f"{MODULE}.PluginModelClient") as client_cls, + ): + redis_client.get.return_value = None + client = client_cls.return_value + client.fetch_model_providers.return_value = [_build_plugin_model_provider()] + + from core.plugin.plugin_service import PluginService + + result = PluginService.fetch_plugin_model_providers(tenant_id="tenant-1") + + client_cls.assert_called_once_with() + client.fetch_model_providers.assert_called_once_with("tenant-1") + assert [provider.provider for provider in result] == ["langgenius/openai/openai"] + + def test_invalidate_plugin_model_providers_cache_uses_tenant_cache_key(self) -> None: + with patch(f"{MODULE}.redis_client") as redis_client: + from core.plugin.plugin_service import PluginService + + PluginService.invalidate_plugin_model_providers_cache("tenant-1") + + redis_client.delete.assert_called_once_with("plugin_model_providers:tenant_id:tenant-1") + + def test_invalidate_plugin_model_providers_cache_ignores_redis_delete_failure(self) -> None: + with patch(f"{MODULE}.redis_client") as redis_client: + redis_client.delete.side_effect = RedisError("redis unavailable") + + from core.plugin.plugin_service import PluginService + + PluginService.invalidate_plugin_model_providers_cache("tenant-1") + + redis_client.delete.assert_called_once_with("plugin_model_providers:tenant_id:tenant-1") + + +class TestPluginModelProviderCacheInvalidation: + def test_fetch_install_task_invalidates_model_provider_cache_when_finished(self) -> None: + """Finished plugin install tasks invalidate tenant provider cache.""" + task = _build_install_task(status=PluginInstallTaskStatus.Success) + + with ( + patch(f"{MODULE}.PluginInstaller") as installer_cls, + patch(f"{MODULE}.PluginService.invalidate_plugin_model_providers_cache") as invalidate_cache, + ): + installer_cls.return_value.fetch_plugin_installation_task.return_value = task + + from core.plugin.plugin_service import PluginService + + result = PluginService.fetch_install_task("tenant-1", "task-1") + + assert result is task + invalidate_cache.assert_called_once_with("tenant-1") + + def test_fetch_install_tasks_invalidates_model_provider_cache_for_finished_tasks(self) -> None: + """Finished tasks from task list polling also invalidate tenant provider cache.""" + task = _build_install_task(status=PluginInstallTaskStatus.Success) + + with ( + patch(f"{MODULE}.PluginInstaller") as installer_cls, + patch(f"{MODULE}.PluginService.invalidate_plugin_model_providers_cache") as invalidate_cache, + ): + installer_cls.return_value.fetch_plugin_installation_tasks.return_value = [task] + + from core.plugin.plugin_service import PluginService + + result = PluginService.fetch_install_tasks("tenant-1", 1, 256) + + assert result == [task] + invalidate_cache.assert_called_once_with("tenant-1") + + def test_fetch_install_tasks_ignores_running_model_provider_cache_tasks(self) -> None: + """Running plugin install tasks do not invalidate provider cache until they reach a terminal state.""" + task = _build_install_task(status=PluginInstallTaskStatus.Running) + + with ( + patch(f"{MODULE}.PluginInstaller") as installer_cls, + patch(f"{MODULE}.PluginService.invalidate_plugin_model_providers_cache") as invalidate_cache, + ): + installer_cls.return_value.fetch_plugin_installation_tasks.return_value = [task] + + from core.plugin.plugin_service import PluginService + + result = PluginService.fetch_install_tasks("tenant-1", 1, 256) + + assert result == [task] + invalidate_cache.assert_not_called() + + def test_upgrade_plugin_with_marketplace_invalidates_model_provider_cache_for_tenant(self) -> None: + """Marketplace upgrades invalidate only the mutated tenant provider cache.""" + with ( + patch(f"{MODULE}.dify_config") as mock_config, + patch(f"{MODULE}.FeatureService") as feature_service, + patch(f"{MODULE}.PluginInstaller") as installer_cls, + patch(f"{MODULE}.marketplace") as marketplace, + patch(f"{MODULE}.PluginService.invalidate_plugin_model_providers_cache") as invalidate_cache, + ): + mock_config.MARKETPLACE_ENABLED = True + feature_service.get_system_features.return_value = SimpleNamespace( + plugin_installation_permission=SimpleNamespace(restrict_to_marketplace_only=False) + ) + installer = installer_cls.return_value + installer.fetch_plugin_manifest.return_value = MagicMock() + installer.upgrade_plugin.return_value = "task-id" + + from core.plugin.plugin_service import PluginService + + result = PluginService.upgrade_plugin_with_marketplace("tenant-1", "old-uid", "new-uid") + + assert result == "task-id" + marketplace.record_install_plugin_event.assert_called_once_with("new-uid") + invalidate_cache.assert_called_once_with("tenant-1") + + def test_install_from_local_pkg_invalidates_model_provider_cache_for_tenant(self) -> None: + """Starting a plugin install invalidates only the mutated tenant provider cache.""" + with ( + patch(f"{MODULE}.PluginService._check_marketplace_only_permission"), + patch(f"{MODULE}.PluginService._check_plugin_installation_scope"), + patch(f"{MODULE}.PluginInstaller") as installer_cls, + patch(f"{MODULE}.PluginService.invalidate_plugin_model_providers_cache") as invalidate_cache, + ): + installer = installer_cls.return_value + decode_response = MagicMock() + decode_response.verification = None + installer.decode_plugin_from_identifier.return_value = decode_response + installer.install_from_identifiers.return_value = "task-id" + + from core.plugin.plugin_service import PluginService + + result = PluginService.install_from_local_pkg("tenant-1", ["langgenius/openai:1.0.0"]) + + assert result == "task-id" + invalidate_cache.assert_called_once_with("tenant-1") + + def test_upgrade_plugin_with_github_invalidates_model_provider_cache_for_tenant(self) -> None: + """Starting a plugin upgrade invalidates only the mutated tenant provider cache.""" + with ( + patch(f"{MODULE}.PluginService._check_marketplace_only_permission"), + patch(f"{MODULE}.PluginInstaller") as installer_cls, + patch(f"{MODULE}.PluginService.invalidate_plugin_model_providers_cache") as invalidate_cache, + ): + installer = installer_cls.return_value + installer.upgrade_plugin.return_value = "task-id" + + from core.plugin.plugin_service import PluginService + + result = PluginService.upgrade_plugin_with_github( + "tenant-1", "old-uid", "new-uid", "langgenius/openai", "1.0.0", "openai.difypkg" + ) + + assert result == "task-id" + invalidate_cache.assert_called_once_with("tenant-1") + + def test_install_from_github_invalidates_model_provider_cache_for_tenant(self) -> None: + """GitHub installs invalidate only the mutated tenant provider cache.""" + with ( + patch(f"{MODULE}.PluginService._check_marketplace_only_permission"), + patch(f"{MODULE}.PluginService._check_plugin_installation_scope"), + patch(f"{MODULE}.PluginInstaller") as installer_cls, + patch(f"{MODULE}.PluginService.invalidate_plugin_model_providers_cache") as invalidate_cache, + ): + installer = installer_cls.return_value + decode_response = MagicMock() + decode_response.verification = None + installer.decode_plugin_from_identifier.return_value = decode_response + installer.install_from_identifiers.return_value = "task-id" + + from core.plugin.plugin_service import PluginService + + result = PluginService.install_from_github( + "tenant-1", "langgenius/openai:1.0.0", "langgenius/openai", "1.0.0", "openai.difypkg" + ) + + assert result == "task-id" + invalidate_cache.assert_called_once_with("tenant-1") + + def test_install_from_marketplace_pkg_invalidates_model_provider_cache_for_tenant(self) -> None: + """Marketplace package installs invalidate only the mutated tenant provider cache.""" + with ( + patch(f"{MODULE}.dify_config") as mock_config, + patch(f"{MODULE}.FeatureService") as feature_service, + patch(f"{MODULE}.PluginService._check_plugin_installation_scope"), + patch(f"{MODULE}.PluginInstaller") as installer_cls, + patch(f"{MODULE}.PluginService.invalidate_plugin_model_providers_cache") as invalidate_cache, + ): + mock_config.MARKETPLACE_ENABLED = True + feature_service.get_system_features.return_value = SimpleNamespace( + plugin_installation_permission=SimpleNamespace(restrict_to_marketplace_only=False) + ) + installer = installer_cls.return_value + installer.fetch_plugin_manifest.return_value = MagicMock() + decode_response = MagicMock() + decode_response.verification = None + installer.decode_plugin_from_identifier.return_value = decode_response + installer.install_from_identifiers.return_value = "task-id" + + from core.plugin.plugin_service import PluginService + + result = PluginService.install_from_marketplace_pkg("tenant-1", ["langgenius/openai:1.0.0"]) + + assert result == "task-id" + invalidate_cache.assert_called_once_with("tenant-1") + + def test_uninstall_invalidates_model_provider_cache_for_tenant(self) -> None: + """Successful uninstall invalidates only the mutated tenant provider cache.""" + with ( + patch(f"{MODULE}.PluginInstaller") as installer_cls, + patch(f"{MODULE}.PluginService.invalidate_plugin_model_providers_cache") as invalidate_cache, + ): + installer = installer_cls.return_value + installer.list_plugins.return_value = [] + installer.uninstall.return_value = True + + from core.plugin.plugin_service import PluginService + + result = PluginService.uninstall("tenant-1", "installation-1") + + assert result is True + invalidate_cache.assert_called_once_with("tenant-1") + + def test_uninstall_existing_plugin_invalidates_cache_after_credential_cleanup(self) -> None: + """Successful uninstall with plugin metadata also invalidates the mutated tenant provider cache.""" + plugin = SimpleNamespace( + installation_id="installation-1", + plugin_id="langgenius/openai", + plugin_unique_identifier="langgenius/openai:1.0.0", + ) + session = _FakeSession() + with ( + patch(f"{MODULE}.db", SimpleNamespace(engine=object())), + patch(f"{MODULE}.dify_config") as mock_config, + patch(f"{MODULE}.PluginInstaller") as installer_cls, + patch(f"{MODULE}.Session", return_value=session), + patch(f"{MODULE}.PluginService.invalidate_plugin_model_providers_cache") as invalidate_cache, + ): + mock_config.ENTERPRISE_ENABLED = False + installer = installer_cls.return_value + installer.list_plugins.return_value = [plugin] + installer.uninstall.return_value = True + + from core.plugin.plugin_service import PluginService + + result = PluginService.uninstall("tenant-1", "installation-1") + + assert result is True + installer.uninstall.assert_called_once_with("tenant-1", "installation-1") + invalidate_cache.assert_called_once_with("tenant-1") diff --git a/docker/envs/core-services/shared.env.example b/docker/envs/core-services/shared.env.example index 6bdd7218b6..aaab029807 100644 --- a/docker/envs/core-services/shared.env.example +++ b/docker/envs/core-services/shared.env.example @@ -60,6 +60,7 @@ SSRF_PROXY_HTTPS_URL=http://ssrf_proxy:3128 PGDATA=/var/lib/postgresql/data/pgdata PLUGIN_MAX_PACKAGE_SIZE=52428800 PLUGIN_MODEL_SCHEMA_CACHE_TTL=3600 +PLUGIN_MODEL_PROVIDERS_CACHE_TTL=86400 ENDPOINT_URL_TEMPLATE=http://localhost/e/{hook_id} LOG_LEVEL=INFO LOG_OUTPUT_FORMAT=text