From 2b6f761dfef8d55ce95d55ca4d6749b2f7945d49 Mon Sep 17 00:00:00 2001 From: tmimmanuel <14046872+tmimmanuel@users.noreply.github.com> Date: Mon, 23 Mar 2026 08:03:35 +0100 Subject: [PATCH] refactor: use EnumText for Conversation/Message invoke_from and from_source (#33901) --- api/core/tools/builtin_tool/tool.py | 2 +- api/core/tools/tool_label_manager.py | 4 ++-- api/core/tools/utils/model_invocation_utils.py | 3 ++- api/models/model.py | 3 ++- api/models/tools.py | 12 ++++++++---- api/services/tag_service.py | 3 ++- .../services/test_tag_service.py | 6 +++--- .../controllers/console/tag/test_tags.py | 3 ++- .../service_api/dataset/test_dataset.py | 9 +++++---- api/tests/unit_tests/models/test_tool_models.py | 14 +++++++------- api/tests/unit_tests/services/test_tag_service.py | 5 +++-- 11 files changed, 37 insertions(+), 27 deletions(-) diff --git a/api/core/tools/builtin_tool/tool.py b/api/core/tools/builtin_tool/tool.py index 00f5931088..bcf58394ba 100644 --- a/api/core/tools/builtin_tool/tool.py +++ b/api/core/tools/builtin_tool/tool.py @@ -50,7 +50,7 @@ class BuiltinTool(Tool): return ModelInvocationUtils.invoke( user_id=user_id, tenant_id=self.runtime.tenant_id or "", - tool_type="builtin", + tool_type=ToolProviderType.BUILT_IN, tool_name=self.entity.identity.name, prompt_messages=prompt_messages, ) diff --git a/api/core/tools/tool_label_manager.py b/api/core/tools/tool_label_manager.py index 90d5a647e9..250dd91bfd 100644 --- a/api/core/tools/tool_label_manager.py +++ b/api/core/tools/tool_label_manager.py @@ -38,7 +38,7 @@ class ToolLabelManager: db.session.add( ToolLabelBinding( tool_id=provider_id, - tool_type=controller.provider_type.value, + tool_type=controller.provider_type, label_name=label, ) ) @@ -58,7 +58,7 @@ class ToolLabelManager: raise ValueError("Unsupported tool type") stmt = select(ToolLabelBinding.label_name).where( ToolLabelBinding.tool_id == provider_id, - ToolLabelBinding.tool_type == controller.provider_type.value, + ToolLabelBinding.tool_type == controller.provider_type, ) labels = db.session.scalars(stmt).all() diff --git a/api/core/tools/utils/model_invocation_utils.py b/api/core/tools/utils/model_invocation_utils.py index 8f958563bd..373bd1b1c8 100644 --- a/api/core/tools/utils/model_invocation_utils.py +++ b/api/core/tools/utils/model_invocation_utils.py @@ -9,6 +9,7 @@ from decimal import Decimal from typing import cast from core.model_manager import ModelManager +from core.tools.entities.tool_entities import ToolProviderType from dify_graph.model_runtime.entities.llm_entities import LLMResult from dify_graph.model_runtime.entities.message_entities import PromptMessage from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType @@ -78,7 +79,7 @@ class ModelInvocationUtils: @staticmethod def invoke( - user_id: str, tenant_id: str, tool_type: str, tool_name: str, prompt_messages: list[PromptMessage] + user_id: str, tenant_id: str, tool_type: ToolProviderType, tool_name: str, prompt_messages: list[PromptMessage] ) -> LLMResult: """ invoke model with parameters in user's own context diff --git a/api/models/model.py b/api/models/model.py index a08e43d128..b098966052 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -43,6 +43,7 @@ from .enums import ( MessageChainType, MessageFileBelongsTo, MessageStatus, + TagType, ) from .provider_ids import GenericProviderID from .types import EnumText, LongText, StringUUID @@ -2404,7 +2405,7 @@ class Tag(TypeBase): StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False ) tenant_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) - type: Mapped[str] = mapped_column(String(16), nullable=False) + type: Mapped[TagType] = mapped_column(EnumText(TagType, length=16), nullable=False) name: Mapped[str] = mapped_column(String(255), nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column( diff --git a/api/models/tools.py b/api/models/tools.py index c09f054e7d..01182af867 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -13,12 +13,16 @@ from sqlalchemy.orm import Mapped, mapped_column from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_bundle import ApiToolBundle -from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration +from core.tools.entities.tool_entities import ( + ApiProviderSchemaType, + ToolProviderType, + WorkflowToolParameterConfiguration, +) from .base import TypeBase from .engine import db from .model import Account, App, Tenant -from .types import LongText, StringUUID +from .types import EnumText, LongText, StringUUID if TYPE_CHECKING: from core.entities.mcp_provider import MCPProviderEntity @@ -208,7 +212,7 @@ class ToolLabelBinding(TypeBase): # tool id tool_id: Mapped[str] = mapped_column(String(64), nullable=False) # tool type - tool_type: Mapped[str] = mapped_column(String(40), nullable=False) + tool_type: Mapped[ToolProviderType] = mapped_column(EnumText(ToolProviderType, length=40), nullable=False) # label name label_name: Mapped[str] = mapped_column(String(40), nullable=False) @@ -386,7 +390,7 @@ class ToolModelInvoke(TypeBase): # provider provider: Mapped[str] = mapped_column(String(255), nullable=False) # type - tool_type: Mapped[str] = mapped_column(String(40), nullable=False) + tool_type: Mapped[ToolProviderType] = mapped_column(EnumText(ToolProviderType, length=40), nullable=False) # tool name tool_name: Mapped[str] = mapped_column(String(128), nullable=False) # invoke parameters diff --git a/api/services/tag_service.py b/api/services/tag_service.py index bd3585acf4..70bf7f16f2 100644 --- a/api/services/tag_service.py +++ b/api/services/tag_service.py @@ -7,6 +7,7 @@ from werkzeug.exceptions import NotFound from extensions.ext_database import db from models.dataset import Dataset +from models.enums import TagType from models.model import App, Tag, TagBinding @@ -83,7 +84,7 @@ class TagService: raise ValueError("Tag name already exists") tag = Tag( name=args["name"], - type=args["type"], + type=TagType(args["type"]), created_by=current_user.id, tenant_id=current_user.current_tenant_id, ) diff --git a/api/tests/test_containers_integration_tests/services/test_tag_service.py b/api/tests/test_containers_integration_tests/services/test_tag_service.py index fa6e651529..1a72e3b6c2 100644 --- a/api/tests/test_containers_integration_tests/services/test_tag_service.py +++ b/api/tests/test_containers_integration_tests/services/test_tag_service.py @@ -9,7 +9,7 @@ from werkzeug.exceptions import NotFound from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset -from models.enums import DataSourceType +from models.enums import DataSourceType, TagType from models.model import App, Tag, TagBinding from services.tag_service import TagService @@ -547,7 +547,7 @@ class TestTagService: assert result is not None assert len(result) == 1 assert result[0].name == "python_tag" - assert result[0].type == "app" + assert result[0].type == TagType.APP assert result[0].tenant_id == tenant.id def test_get_tag_by_tag_name_no_matches( @@ -638,7 +638,7 @@ class TestTagService: # Verify all tags are returned for tag in result: - assert tag.type == "app" + assert tag.type == TagType.APP assert tag.tenant_id == tenant.id assert tag.id in [t.id for t in tags] diff --git a/api/tests/unit_tests/controllers/console/tag/test_tags.py b/api/tests/unit_tests/controllers/console/tag/test_tags.py index 769edc8d1c..e89b89c8b1 100644 --- a/api/tests/unit_tests/controllers/console/tag/test_tags.py +++ b/api/tests/unit_tests/controllers/console/tag/test_tags.py @@ -11,6 +11,7 @@ from controllers.console.tag.tags import ( TagListApi, TagUpdateDeleteApi, ) +from models.enums import TagType def unwrap(func): @@ -52,7 +53,7 @@ def tag(): tag = MagicMock() tag.id = "tag-1" tag.name = "test-tag" - tag.type = "knowledge" + tag.type = TagType.KNOWLEDGE return tag diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_dataset.py b/api/tests/unit_tests/controllers/service_api/dataset/test_dataset.py index 7cb2f1050c..8fe41cd19f 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_dataset.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_dataset.py @@ -35,6 +35,7 @@ from controllers.service_api.dataset.dataset import ( from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError from models.account import Account from models.dataset import DatasetPermissionEnum +from models.enums import TagType from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService from services.tag_service import TagService @@ -277,7 +278,7 @@ class TestDatasetTagsApi: mock_tag = Mock() mock_tag.id = "tag_1" mock_tag.name = "Test Tag" - mock_tag.type = "knowledge" + mock_tag.type = TagType.KNOWLEDGE mock_tag.binding_count = "0" # Required for Pydantic validation - must be string mock_tag_service.get_tags.return_value = [mock_tag] @@ -316,7 +317,7 @@ class TestDatasetTagsApi: mock_tag = Mock() mock_tag.id = "new_tag_1" mock_tag.name = "New Tag" - mock_tag.type = "knowledge" + mock_tag.type = TagType.KNOWLEDGE mock_tag_service.save_tags.return_value = mock_tag mock_service_api_ns.payload = {"name": "New Tag"} @@ -378,7 +379,7 @@ class TestDatasetTagsApi: mock_tag = Mock() mock_tag.id = "tag_1" mock_tag.name = "Updated Tag" - mock_tag.type = "knowledge" + mock_tag.type = TagType.KNOWLEDGE mock_tag.binding_count = "5" mock_tag_service.update_tags.return_value = mock_tag mock_tag_service.get_tag_binding_count.return_value = 5 @@ -866,7 +867,7 @@ class TestTagService: mock_tag = Mock() mock_tag.id = str(uuid.uuid4()) mock_tag.name = "New Tag" - mock_tag.type = "knowledge" + mock_tag.type = TagType.KNOWLEDGE mock_save.return_value = mock_tag result = TagService.save_tags({"name": "New Tag", "type": "knowledge"}) diff --git a/api/tests/unit_tests/models/test_tool_models.py b/api/tests/unit_tests/models/test_tool_models.py index 1a75eb9a01..a6c2eae2c0 100644 --- a/api/tests/unit_tests/models/test_tool_models.py +++ b/api/tests/unit_tests/models/test_tool_models.py @@ -12,7 +12,7 @@ This test suite covers: import json from uuid import uuid4 -from core.tools.entities.tool_entities import ApiProviderSchemaType +from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolProviderType from models.tools import ( ApiToolProvider, BuiltinToolProvider, @@ -631,7 +631,7 @@ class TestToolLabelBinding: """Test creating a tool label binding.""" # Arrange tool_id = "google.search" - tool_type = "builtin" + tool_type = ToolProviderType.BUILT_IN label_name = "search" # Act @@ -655,7 +655,7 @@ class TestToolLabelBinding: # Act label_binding = ToolLabelBinding( tool_id=tool_id, - tool_type="builtin", + tool_type=ToolProviderType.BUILT_IN, label_name=label_name, ) @@ -667,7 +667,7 @@ class TestToolLabelBinding: """Test multiple labels can be bound to the same tool.""" # Arrange tool_id = "google.search" - tool_type = "builtin" + tool_type = ToolProviderType.BUILT_IN # Act binding1 = ToolLabelBinding( @@ -688,7 +688,7 @@ class TestToolLabelBinding: def test_tool_label_binding_different_tool_types(self): """Test label bindings for different tool types.""" # Arrange - tool_types = ["builtin", "api", "workflow"] + tool_types = [ToolProviderType.BUILT_IN, ToolProviderType.API, ToolProviderType.WORKFLOW] # Act & Assert for tool_type in tool_types: @@ -951,12 +951,12 @@ class TestToolProviderRelationships: # Act binding1 = ToolLabelBinding( tool_id=tool_id, - tool_type="builtin", + tool_type=ToolProviderType.BUILT_IN, label_name="search", ) binding2 = ToolLabelBinding( tool_id=tool_id, - tool_type="builtin", + tool_type=ToolProviderType.BUILT_IN, label_name="web", ) diff --git a/api/tests/unit_tests/services/test_tag_service.py b/api/tests/unit_tests/services/test_tag_service.py index 264eac4d77..4d2d63e501 100644 --- a/api/tests/unit_tests/services/test_tag_service.py +++ b/api/tests/unit_tests/services/test_tag_service.py @@ -75,6 +75,7 @@ import pytest from werkzeug.exceptions import NotFound from models.dataset import Dataset +from models.enums import TagType from models.model import App, Tag, TagBinding from services.tag_service import TagService @@ -102,7 +103,7 @@ class TagServiceTestDataFactory: def create_tag_mock( tag_id: str = "tag-123", name: str = "Test Tag", - tag_type: str = "app", + tag_type: TagType = TagType.APP, tenant_id: str = "tenant-123", **kwargs, ) -> Mock: @@ -705,7 +706,7 @@ class TestTagServiceCRUD: # Verify tag attributes added_tag = mock_db_session.add.call_args[0][0] assert added_tag.name == "New Tag", "Tag name should match" - assert added_tag.type == "app", "Tag type should match" + assert added_tag.type == TagType.APP, "Tag type should match" assert added_tag.created_by == "user-123", "Created by should match current user" assert added_tag.tenant_id == "tenant-123", "Tenant ID should match current tenant"