mirror of
https://github.com/langgenius/dify.git
synced 2026-03-23 23:37:55 +08:00
refactor: use EnumText for Conversation/Message invoke_from and from_source (#33901)
This commit is contained in:
@ -50,7 +50,7 @@ class BuiltinTool(Tool):
|
||||
return ModelInvocationUtils.invoke(
|
||||
user_id=user_id,
|
||||
tenant_id=self.runtime.tenant_id or "",
|
||||
tool_type="builtin",
|
||||
tool_type=ToolProviderType.BUILT_IN,
|
||||
tool_name=self.entity.identity.name,
|
||||
prompt_messages=prompt_messages,
|
||||
)
|
||||
|
||||
@ -38,7 +38,7 @@ class ToolLabelManager:
|
||||
db.session.add(
|
||||
ToolLabelBinding(
|
||||
tool_id=provider_id,
|
||||
tool_type=controller.provider_type.value,
|
||||
tool_type=controller.provider_type,
|
||||
label_name=label,
|
||||
)
|
||||
)
|
||||
@ -58,7 +58,7 @@ class ToolLabelManager:
|
||||
raise ValueError("Unsupported tool type")
|
||||
stmt = select(ToolLabelBinding.label_name).where(
|
||||
ToolLabelBinding.tool_id == provider_id,
|
||||
ToolLabelBinding.tool_type == controller.provider_type.value,
|
||||
ToolLabelBinding.tool_type == controller.provider_type,
|
||||
)
|
||||
labels = db.session.scalars(stmt).all()
|
||||
|
||||
|
||||
@ -9,6 +9,7 @@ from decimal import Decimal
|
||||
from typing import cast
|
||||
|
||||
from core.model_manager import ModelManager
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMResult
|
||||
from dify_graph.model_runtime.entities.message_entities import PromptMessage
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||
@ -78,7 +79,7 @@ class ModelInvocationUtils:
|
||||
|
||||
@staticmethod
|
||||
def invoke(
|
||||
user_id: str, tenant_id: str, tool_type: str, tool_name: str, prompt_messages: list[PromptMessage]
|
||||
user_id: str, tenant_id: str, tool_type: ToolProviderType, tool_name: str, prompt_messages: list[PromptMessage]
|
||||
) -> LLMResult:
|
||||
"""
|
||||
invoke model with parameters in user's own context
|
||||
|
||||
@ -43,6 +43,7 @@ from .enums import (
|
||||
MessageChainType,
|
||||
MessageFileBelongsTo,
|
||||
MessageStatus,
|
||||
TagType,
|
||||
)
|
||||
from .provider_ids import GenericProviderID
|
||||
from .types import EnumText, LongText, StringUUID
|
||||
@ -2404,7 +2405,7 @@ class Tag(TypeBase):
|
||||
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
|
||||
)
|
||||
tenant_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
|
||||
type: Mapped[str] = mapped_column(String(16), nullable=False)
|
||||
type: Mapped[TagType] = mapped_column(EnumText(TagType, length=16), nullable=False)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
|
||||
@ -13,12 +13,16 @@ from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration
|
||||
from core.tools.entities.tool_entities import (
|
||||
ApiProviderSchemaType,
|
||||
ToolProviderType,
|
||||
WorkflowToolParameterConfiguration,
|
||||
)
|
||||
|
||||
from .base import TypeBase
|
||||
from .engine import db
|
||||
from .model import Account, App, Tenant
|
||||
from .types import LongText, StringUUID
|
||||
from .types import EnumText, LongText, StringUUID
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.entities.mcp_provider import MCPProviderEntity
|
||||
@ -208,7 +212,7 @@ class ToolLabelBinding(TypeBase):
|
||||
# tool id
|
||||
tool_id: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
# tool type
|
||||
tool_type: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
tool_type: Mapped[ToolProviderType] = mapped_column(EnumText(ToolProviderType, length=40), nullable=False)
|
||||
# label name
|
||||
label_name: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
|
||||
@ -386,7 +390,7 @@ class ToolModelInvoke(TypeBase):
|
||||
# provider
|
||||
provider: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
# type
|
||||
tool_type: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
tool_type: Mapped[ToolProviderType] = mapped_column(EnumText(ToolProviderType, length=40), nullable=False)
|
||||
# tool name
|
||||
tool_name: Mapped[str] = mapped_column(String(128), nullable=False)
|
||||
# invoke parameters
|
||||
|
||||
@ -7,6 +7,7 @@ from werkzeug.exceptions import NotFound
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset
|
||||
from models.enums import TagType
|
||||
from models.model import App, Tag, TagBinding
|
||||
|
||||
|
||||
@ -83,7 +84,7 @@ class TagService:
|
||||
raise ValueError("Tag name already exists")
|
||||
tag = Tag(
|
||||
name=args["name"],
|
||||
type=args["type"],
|
||||
type=TagType(args["type"]),
|
||||
created_by=current_user.id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
)
|
||||
|
||||
@ -9,7 +9,7 @@ from werkzeug.exceptions import NotFound
|
||||
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import Dataset
|
||||
from models.enums import DataSourceType
|
||||
from models.enums import DataSourceType, TagType
|
||||
from models.model import App, Tag, TagBinding
|
||||
from services.tag_service import TagService
|
||||
|
||||
@ -547,7 +547,7 @@ class TestTagService:
|
||||
assert result is not None
|
||||
assert len(result) == 1
|
||||
assert result[0].name == "python_tag"
|
||||
assert result[0].type == "app"
|
||||
assert result[0].type == TagType.APP
|
||||
assert result[0].tenant_id == tenant.id
|
||||
|
||||
def test_get_tag_by_tag_name_no_matches(
|
||||
@ -638,7 +638,7 @@ class TestTagService:
|
||||
|
||||
# Verify all tags are returned
|
||||
for tag in result:
|
||||
assert tag.type == "app"
|
||||
assert tag.type == TagType.APP
|
||||
assert tag.tenant_id == tenant.id
|
||||
assert tag.id in [t.id for t in tags]
|
||||
|
||||
|
||||
@ -11,6 +11,7 @@ from controllers.console.tag.tags import (
|
||||
TagListApi,
|
||||
TagUpdateDeleteApi,
|
||||
)
|
||||
from models.enums import TagType
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
@ -52,7 +53,7 @@ def tag():
|
||||
tag = MagicMock()
|
||||
tag.id = "tag-1"
|
||||
tag.name = "test-tag"
|
||||
tag.type = "knowledge"
|
||||
tag.type = TagType.KNOWLEDGE
|
||||
return tag
|
||||
|
||||
|
||||
|
||||
@ -35,6 +35,7 @@ from controllers.service_api.dataset.dataset import (
|
||||
from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError
|
||||
from models.account import Account
|
||||
from models.dataset import DatasetPermissionEnum
|
||||
from models.enums import TagType
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
|
||||
from services.tag_service import TagService
|
||||
|
||||
@ -277,7 +278,7 @@ class TestDatasetTagsApi:
|
||||
mock_tag = Mock()
|
||||
mock_tag.id = "tag_1"
|
||||
mock_tag.name = "Test Tag"
|
||||
mock_tag.type = "knowledge"
|
||||
mock_tag.type = TagType.KNOWLEDGE
|
||||
mock_tag.binding_count = "0" # Required for Pydantic validation - must be string
|
||||
mock_tag_service.get_tags.return_value = [mock_tag]
|
||||
|
||||
@ -316,7 +317,7 @@ class TestDatasetTagsApi:
|
||||
mock_tag = Mock()
|
||||
mock_tag.id = "new_tag_1"
|
||||
mock_tag.name = "New Tag"
|
||||
mock_tag.type = "knowledge"
|
||||
mock_tag.type = TagType.KNOWLEDGE
|
||||
mock_tag_service.save_tags.return_value = mock_tag
|
||||
mock_service_api_ns.payload = {"name": "New Tag"}
|
||||
|
||||
@ -378,7 +379,7 @@ class TestDatasetTagsApi:
|
||||
mock_tag = Mock()
|
||||
mock_tag.id = "tag_1"
|
||||
mock_tag.name = "Updated Tag"
|
||||
mock_tag.type = "knowledge"
|
||||
mock_tag.type = TagType.KNOWLEDGE
|
||||
mock_tag.binding_count = "5"
|
||||
mock_tag_service.update_tags.return_value = mock_tag
|
||||
mock_tag_service.get_tag_binding_count.return_value = 5
|
||||
@ -866,7 +867,7 @@ class TestTagService:
|
||||
mock_tag = Mock()
|
||||
mock_tag.id = str(uuid.uuid4())
|
||||
mock_tag.name = "New Tag"
|
||||
mock_tag.type = "knowledge"
|
||||
mock_tag.type = TagType.KNOWLEDGE
|
||||
mock_save.return_value = mock_tag
|
||||
|
||||
result = TagService.save_tags({"name": "New Tag", "type": "knowledge"})
|
||||
|
||||
@ -12,7 +12,7 @@ This test suite covers:
|
||||
import json
|
||||
from uuid import uuid4
|
||||
|
||||
from core.tools.entities.tool_entities import ApiProviderSchemaType
|
||||
from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolProviderType
|
||||
from models.tools import (
|
||||
ApiToolProvider,
|
||||
BuiltinToolProvider,
|
||||
@ -631,7 +631,7 @@ class TestToolLabelBinding:
|
||||
"""Test creating a tool label binding."""
|
||||
# Arrange
|
||||
tool_id = "google.search"
|
||||
tool_type = "builtin"
|
||||
tool_type = ToolProviderType.BUILT_IN
|
||||
label_name = "search"
|
||||
|
||||
# Act
|
||||
@ -655,7 +655,7 @@ class TestToolLabelBinding:
|
||||
# Act
|
||||
label_binding = ToolLabelBinding(
|
||||
tool_id=tool_id,
|
||||
tool_type="builtin",
|
||||
tool_type=ToolProviderType.BUILT_IN,
|
||||
label_name=label_name,
|
||||
)
|
||||
|
||||
@ -667,7 +667,7 @@ class TestToolLabelBinding:
|
||||
"""Test multiple labels can be bound to the same tool."""
|
||||
# Arrange
|
||||
tool_id = "google.search"
|
||||
tool_type = "builtin"
|
||||
tool_type = ToolProviderType.BUILT_IN
|
||||
|
||||
# Act
|
||||
binding1 = ToolLabelBinding(
|
||||
@ -688,7 +688,7 @@ class TestToolLabelBinding:
|
||||
def test_tool_label_binding_different_tool_types(self):
|
||||
"""Test label bindings for different tool types."""
|
||||
# Arrange
|
||||
tool_types = ["builtin", "api", "workflow"]
|
||||
tool_types = [ToolProviderType.BUILT_IN, ToolProviderType.API, ToolProviderType.WORKFLOW]
|
||||
|
||||
# Act & Assert
|
||||
for tool_type in tool_types:
|
||||
@ -951,12 +951,12 @@ class TestToolProviderRelationships:
|
||||
# Act
|
||||
binding1 = ToolLabelBinding(
|
||||
tool_id=tool_id,
|
||||
tool_type="builtin",
|
||||
tool_type=ToolProviderType.BUILT_IN,
|
||||
label_name="search",
|
||||
)
|
||||
binding2 = ToolLabelBinding(
|
||||
tool_id=tool_id,
|
||||
tool_type="builtin",
|
||||
tool_type=ToolProviderType.BUILT_IN,
|
||||
label_name="web",
|
||||
)
|
||||
|
||||
|
||||
@ -75,6 +75,7 @@ import pytest
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from models.dataset import Dataset
|
||||
from models.enums import TagType
|
||||
from models.model import App, Tag, TagBinding
|
||||
from services.tag_service import TagService
|
||||
|
||||
@ -102,7 +103,7 @@ class TagServiceTestDataFactory:
|
||||
def create_tag_mock(
|
||||
tag_id: str = "tag-123",
|
||||
name: str = "Test Tag",
|
||||
tag_type: str = "app",
|
||||
tag_type: TagType = TagType.APP,
|
||||
tenant_id: str = "tenant-123",
|
||||
**kwargs,
|
||||
) -> Mock:
|
||||
@ -705,7 +706,7 @@ class TestTagServiceCRUD:
|
||||
# Verify tag attributes
|
||||
added_tag = mock_db_session.add.call_args[0][0]
|
||||
assert added_tag.name == "New Tag", "Tag name should match"
|
||||
assert added_tag.type == "app", "Tag type should match"
|
||||
assert added_tag.type == TagType.APP, "Tag type should match"
|
||||
assert added_tag.created_by == "user-123", "Created by should match current user"
|
||||
assert added_tag.tenant_id == "tenant-123", "Tenant ID should match current tenant"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user