refactor: use EnumText for Conversation/Message invoke_from and from_source (#33901)

This commit is contained in:
tmimmanuel
2026-03-23 08:03:35 +01:00
committed by GitHub
parent 6ecf89e262
commit 2b6f761dfe
11 changed files with 37 additions and 27 deletions

View File

@ -50,7 +50,7 @@ class BuiltinTool(Tool):
return ModelInvocationUtils.invoke(
user_id=user_id,
tenant_id=self.runtime.tenant_id or "",
tool_type="builtin",
tool_type=ToolProviderType.BUILT_IN,
tool_name=self.entity.identity.name,
prompt_messages=prompt_messages,
)

View File

@ -38,7 +38,7 @@ class ToolLabelManager:
db.session.add(
ToolLabelBinding(
tool_id=provider_id,
tool_type=controller.provider_type.value,
tool_type=controller.provider_type,
label_name=label,
)
)
@ -58,7 +58,7 @@ class ToolLabelManager:
raise ValueError("Unsupported tool type")
stmt = select(ToolLabelBinding.label_name).where(
ToolLabelBinding.tool_id == provider_id,
ToolLabelBinding.tool_type == controller.provider_type.value,
ToolLabelBinding.tool_type == controller.provider_type,
)
labels = db.session.scalars(stmt).all()

View File

@ -9,6 +9,7 @@ from decimal import Decimal
from typing import cast
from core.model_manager import ModelManager
from core.tools.entities.tool_entities import ToolProviderType
from dify_graph.model_runtime.entities.llm_entities import LLMResult
from dify_graph.model_runtime.entities.message_entities import PromptMessage
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
@ -78,7 +79,7 @@ class ModelInvocationUtils:
@staticmethod
def invoke(
user_id: str, tenant_id: str, tool_type: str, tool_name: str, prompt_messages: list[PromptMessage]
user_id: str, tenant_id: str, tool_type: ToolProviderType, tool_name: str, prompt_messages: list[PromptMessage]
) -> LLMResult:
"""
invoke model with parameters in user's own context

View File

@ -43,6 +43,7 @@ from .enums import (
MessageChainType,
MessageFileBelongsTo,
MessageStatus,
TagType,
)
from .provider_ids import GenericProviderID
from .types import EnumText, LongText, StringUUID
@ -2404,7 +2405,7 @@ class Tag(TypeBase):
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
tenant_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
type: Mapped[str] = mapped_column(String(16), nullable=False)
type: Mapped[TagType] = mapped_column(EnumText(TagType, length=16), nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(

View File

@ -13,12 +13,16 @@ from sqlalchemy.orm import Mapped, mapped_column
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration
from core.tools.entities.tool_entities import (
ApiProviderSchemaType,
ToolProviderType,
WorkflowToolParameterConfiguration,
)
from .base import TypeBase
from .engine import db
from .model import Account, App, Tenant
from .types import LongText, StringUUID
from .types import EnumText, LongText, StringUUID
if TYPE_CHECKING:
from core.entities.mcp_provider import MCPProviderEntity
@ -208,7 +212,7 @@ class ToolLabelBinding(TypeBase):
# tool id
tool_id: Mapped[str] = mapped_column(String(64), nullable=False)
# tool type
tool_type: Mapped[str] = mapped_column(String(40), nullable=False)
tool_type: Mapped[ToolProviderType] = mapped_column(EnumText(ToolProviderType, length=40), nullable=False)
# label name
label_name: Mapped[str] = mapped_column(String(40), nullable=False)
@ -386,7 +390,7 @@ class ToolModelInvoke(TypeBase):
# provider
provider: Mapped[str] = mapped_column(String(255), nullable=False)
# type
tool_type: Mapped[str] = mapped_column(String(40), nullable=False)
tool_type: Mapped[ToolProviderType] = mapped_column(EnumText(ToolProviderType, length=40), nullable=False)
# tool name
tool_name: Mapped[str] = mapped_column(String(128), nullable=False)
# invoke parameters

View File

@ -7,6 +7,7 @@ from werkzeug.exceptions import NotFound
from extensions.ext_database import db
from models.dataset import Dataset
from models.enums import TagType
from models.model import App, Tag, TagBinding
@ -83,7 +84,7 @@ class TagService:
raise ValueError("Tag name already exists")
tag = Tag(
name=args["name"],
type=args["type"],
type=TagType(args["type"]),
created_by=current_user.id,
tenant_id=current_user.current_tenant_id,
)

View File

@ -9,7 +9,7 @@ from werkzeug.exceptions import NotFound
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset
from models.enums import DataSourceType
from models.enums import DataSourceType, TagType
from models.model import App, Tag, TagBinding
from services.tag_service import TagService
@ -547,7 +547,7 @@ class TestTagService:
assert result is not None
assert len(result) == 1
assert result[0].name == "python_tag"
assert result[0].type == "app"
assert result[0].type == TagType.APP
assert result[0].tenant_id == tenant.id
def test_get_tag_by_tag_name_no_matches(
@ -638,7 +638,7 @@ class TestTagService:
# Verify all tags are returned
for tag in result:
assert tag.type == "app"
assert tag.type == TagType.APP
assert tag.tenant_id == tenant.id
assert tag.id in [t.id for t in tags]

View File

@ -11,6 +11,7 @@ from controllers.console.tag.tags import (
TagListApi,
TagUpdateDeleteApi,
)
from models.enums import TagType
def unwrap(func):
@ -52,7 +53,7 @@ def tag():
tag = MagicMock()
tag.id = "tag-1"
tag.name = "test-tag"
tag.type = "knowledge"
tag.type = TagType.KNOWLEDGE
return tag

View File

@ -35,6 +35,7 @@ from controllers.service_api.dataset.dataset import (
from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError
from models.account import Account
from models.dataset import DatasetPermissionEnum
from models.enums import TagType
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
from services.tag_service import TagService
@ -277,7 +278,7 @@ class TestDatasetTagsApi:
mock_tag = Mock()
mock_tag.id = "tag_1"
mock_tag.name = "Test Tag"
mock_tag.type = "knowledge"
mock_tag.type = TagType.KNOWLEDGE
mock_tag.binding_count = "0" # Required for Pydantic validation - must be string
mock_tag_service.get_tags.return_value = [mock_tag]
@ -316,7 +317,7 @@ class TestDatasetTagsApi:
mock_tag = Mock()
mock_tag.id = "new_tag_1"
mock_tag.name = "New Tag"
mock_tag.type = "knowledge"
mock_tag.type = TagType.KNOWLEDGE
mock_tag_service.save_tags.return_value = mock_tag
mock_service_api_ns.payload = {"name": "New Tag"}
@ -378,7 +379,7 @@ class TestDatasetTagsApi:
mock_tag = Mock()
mock_tag.id = "tag_1"
mock_tag.name = "Updated Tag"
mock_tag.type = "knowledge"
mock_tag.type = TagType.KNOWLEDGE
mock_tag.binding_count = "5"
mock_tag_service.update_tags.return_value = mock_tag
mock_tag_service.get_tag_binding_count.return_value = 5
@ -866,7 +867,7 @@ class TestTagService:
mock_tag = Mock()
mock_tag.id = str(uuid.uuid4())
mock_tag.name = "New Tag"
mock_tag.type = "knowledge"
mock_tag.type = TagType.KNOWLEDGE
mock_save.return_value = mock_tag
result = TagService.save_tags({"name": "New Tag", "type": "knowledge"})

View File

@ -12,7 +12,7 @@ This test suite covers:
import json
from uuid import uuid4
from core.tools.entities.tool_entities import ApiProviderSchemaType
from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolProviderType
from models.tools import (
ApiToolProvider,
BuiltinToolProvider,
@ -631,7 +631,7 @@ class TestToolLabelBinding:
"""Test creating a tool label binding."""
# Arrange
tool_id = "google.search"
tool_type = "builtin"
tool_type = ToolProviderType.BUILT_IN
label_name = "search"
# Act
@ -655,7 +655,7 @@ class TestToolLabelBinding:
# Act
label_binding = ToolLabelBinding(
tool_id=tool_id,
tool_type="builtin",
tool_type=ToolProviderType.BUILT_IN,
label_name=label_name,
)
@ -667,7 +667,7 @@ class TestToolLabelBinding:
"""Test multiple labels can be bound to the same tool."""
# Arrange
tool_id = "google.search"
tool_type = "builtin"
tool_type = ToolProviderType.BUILT_IN
# Act
binding1 = ToolLabelBinding(
@ -688,7 +688,7 @@ class TestToolLabelBinding:
def test_tool_label_binding_different_tool_types(self):
"""Test label bindings for different tool types."""
# Arrange
tool_types = ["builtin", "api", "workflow"]
tool_types = [ToolProviderType.BUILT_IN, ToolProviderType.API, ToolProviderType.WORKFLOW]
# Act & Assert
for tool_type in tool_types:
@ -951,12 +951,12 @@ class TestToolProviderRelationships:
# Act
binding1 = ToolLabelBinding(
tool_id=tool_id,
tool_type="builtin",
tool_type=ToolProviderType.BUILT_IN,
label_name="search",
)
binding2 = ToolLabelBinding(
tool_id=tool_id,
tool_type="builtin",
tool_type=ToolProviderType.BUILT_IN,
label_name="web",
)

View File

@ -75,6 +75,7 @@ import pytest
from werkzeug.exceptions import NotFound
from models.dataset import Dataset
from models.enums import TagType
from models.model import App, Tag, TagBinding
from services.tag_service import TagService
@ -102,7 +103,7 @@ class TagServiceTestDataFactory:
def create_tag_mock(
tag_id: str = "tag-123",
name: str = "Test Tag",
tag_type: str = "app",
tag_type: TagType = TagType.APP,
tenant_id: str = "tenant-123",
**kwargs,
) -> Mock:
@ -705,7 +706,7 @@ class TestTagServiceCRUD:
# Verify tag attributes
added_tag = mock_db_session.add.call_args[0][0]
assert added_tag.name == "New Tag", "Tag name should match"
assert added_tag.type == "app", "Tag type should match"
assert added_tag.type == TagType.APP, "Tag type should match"
assert added_tag.created_by == "user-123", "Created by should match current user"
assert added_tag.tenant_id == "tenant-123", "Tenant ID should match current tenant"