mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 17:38:04 +08:00
Merge commit '92bde350' into sandboxed-agent-rebase
Made-with: Cursor # Conflicts: # api/controllers/console/app/workflow_draft_variable.py # api/core/agent/cot_agent_runner.py # api/core/agent/cot_chat_agent_runner.py # api/core/agent/cot_completion_agent_runner.py # api/core/agent/fc_agent_runner.py # api/core/app/apps/advanced_chat/app_generator.py # api/core/app/apps/advanced_chat/app_runner.py # api/core/app/apps/agent_chat/app_runner.py # api/core/app/apps/workflow/app_generator.py # api/core/app/apps/workflow/app_runner.py # api/core/app/entities/app_invoke_entities.py # api/core/app/entities/queue_entities.py # api/core/llm_generator/output_parser/structured_output.py # api/core/workflow/workflow_entry.py # api/dify_graph/context/__init__.py # api/dify_graph/entities/tool_entities.py # api/dify_graph/file/file_manager.py # api/dify_graph/graph_engine/response_coordinator/coordinator.py # api/dify_graph/graph_events/node.py # api/dify_graph/node_events/node.py # api/dify_graph/nodes/agent/agent_node.py # api/dify_graph/nodes/llm/entities.py # api/dify_graph/nodes/llm/llm_utils.py # api/dify_graph/nodes/llm/node.py # api/dify_graph/nodes/question_classifier/question_classifier_node.py # api/dify_graph/runtime/graph_runtime_state.py # api/dify_graph/variables/segments.py # api/factories/variable_factory.py # api/services/variable_truncator.py # api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py # api/uv.lock # web/app/components/app-sidebar/app-info.tsx # web/app/components/app-sidebar/app-sidebar-dropdown.tsx # web/app/components/app/create-app-modal/index.spec.tsx # web/app/components/apps/__tests__/list.spec.tsx # web/app/components/apps/app-card.tsx # web/app/components/apps/list.tsx # web/app/components/header/account-dropdown/compliance.tsx # web/app/components/header/account-dropdown/index.tsx # web/app/components/header/account-dropdown/support.tsx # web/app/components/workflow-app/components/workflow-onboarding-modal/index.tsx # 
web/app/components/workflow/panel/debug-and-preview/hooks.ts # web/contract/console/apps.ts # web/contract/router.ts # web/eslint-suppressions.json # web/next.config.ts # web/pnpm-lock.yaml
This commit is contained in:
@ -19,7 +19,15 @@ from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackH
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities import (
|
||||
from core.prompt.utils.extract_thread_messages import extract_thread_messages
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolParameter,
|
||||
)
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool
|
||||
from dify_graph.file import file_manager
|
||||
from dify_graph.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
LLMUsage,
|
||||
PromptMessage,
|
||||
@ -29,17 +37,9 @@ from core.model_runtime.entities import (
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
|
||||
from core.model_runtime.entities.model_entities import ModelFeature
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.prompt.utils.extract_thread_messages import extract_thread_messages
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolParameter,
|
||||
)
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool
|
||||
from core.workflow.file import file_manager
|
||||
from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelFeature
|
||||
from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from models.enums import CreatorUserRole
|
||||
|
||||
@ -4,7 +4,7 @@ from collections.abc import Generator
|
||||
from typing import Union
|
||||
|
||||
from core.agent.entities import AgentScratchpadUnit
|
||||
from core.model_runtime.entities.llm_entities import LLMResultChunk
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMResultChunk
|
||||
|
||||
|
||||
class CotAgentOutputParser:
|
||||
|
||||
@ -4,10 +4,10 @@ from core.app.app_config.entities import EasyUIBasedAppConfig
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.entities.model_entities import ModelStatus
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.entities.llm_entities import LLMMode
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.provider_manager import ProviderManager
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMMode
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||
from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
|
||||
|
||||
class ModelConfigConverter:
|
||||
|
||||
@ -2,9 +2,9 @@ from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.app.app_config.entities import ModelConfigEntity
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from core.provider_manager import ProviderManager
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from models.provider_ids import ModelProviderID
|
||||
|
||||
|
||||
|
||||
@ -4,8 +4,8 @@ from core.app.app_config.entities import (
|
||||
AdvancedCompletionPromptTemplateEntity,
|
||||
PromptTemplateEntity,
|
||||
)
|
||||
from core.model_runtime.entities.message_entities import PromptMessageRole
|
||||
from core.prompt.simple_prompt_transform import ModelMode
|
||||
from dify_graph.model_runtime.entities.message_entities import PromptMessageRole
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
|
||||
@ -2,7 +2,7 @@ import re
|
||||
|
||||
from core.app.app_config.entities import ExternalDataVariableEntity
|
||||
from core.external_data_tool.factory import ExternalDataToolFactory
|
||||
from core.workflow.variables.input_entities import VariableEntity, VariableEntityType
|
||||
from dify_graph.variables.input_entities import VariableEntity, VariableEntityType
|
||||
|
||||
_ALLOWED_VARIABLE_ENTITY_TYPE = frozenset(
|
||||
[
|
||||
|
||||
@ -4,10 +4,10 @@ from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMMode
|
||||
from core.model_runtime.entities.message_entities import PromptMessageRole
|
||||
from core.workflow.file import FileUploadConfig
|
||||
from core.workflow.variables.input_entities import VariableEntity as WorkflowVariableEntity
|
||||
from dify_graph.file import FileUploadConfig
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMMode
|
||||
from dify_graph.model_runtime.entities.message_entities import PromptMessageRole
|
||||
from dify_graph.variables.input_entities import VariableEntity as WorkflowVariableEntity
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
|
||||
@ -2,7 +2,7 @@ from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from constants import DEFAULT_FILE_NUMBER_LIMITS
|
||||
from core.workflow.file import FileUploadConfig
|
||||
from dify_graph.file import FileUploadConfig
|
||||
|
||||
|
||||
class FileUploadConfigManager:
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import re
|
||||
|
||||
from core.app.app_config.entities import RagPipelineVariableEntity
|
||||
from core.workflow.variables.input_entities import VariableEntity
|
||||
from dify_graph.variables.input_entities import VariableEntity
|
||||
from models.workflow import Workflow
|
||||
|
||||
|
||||
|
||||
@ -32,19 +32,19 @@ from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity,
|
||||
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
|
||||
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
|
||||
from core.helper.trace_id_helper import extract_external_trace_id_from_args
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from core.sandbox import Sandbox
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.repositories.draft_variable_repository import (
|
||||
from dify_graph.graph_engine.layers.base import GraphEngineLayer
|
||||
from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from dify_graph.repositories.draft_variable_repository import (
|
||||
DraftVariableSaverFactory,
|
||||
)
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
|
||||
from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from dify_graph.runtime import GraphRuntimeState
|
||||
from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from libs.flask_utils import preserve_flask_contexts
|
||||
|
||||
@ -26,16 +26,16 @@ from core.db.session_factory import session_factory
|
||||
from core.moderation.base import ModerationError
|
||||
from core.moderation.input_moderation import InputModeration
|
||||
from core.sandbox import Sandbox
|
||||
from core.workflow.enums import WorkflowType
|
||||
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.variable_loader import VariableLoader
|
||||
from core.workflow.variables.variables import Variable
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from dify_graph.enums import WorkflowType
|
||||
from dify_graph.graph_engine.command_channels.redis_channel import RedisChannel
|
||||
from dify_graph.graph_engine.layers.base import GraphEngineLayer
|
||||
from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from dify_graph.runtime import GraphRuntimeState, VariablePool
|
||||
from dify_graph.system_variable import SystemVariable
|
||||
from dify_graph.variable_loader import VariableLoader
|
||||
from dify_graph.variables.variables import Variable
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from extensions.otel import WorkflowAppRunnerHandler, trace_span
|
||||
|
||||
@ -65,16 +65,16 @@ from core.app.entities.task_entities import (
|
||||
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
||||
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
|
||||
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.repositories.human_input_repository import HumanInputFormRepositoryImpl
|
||||
from core.workflow.entities.pause_reason import HumanInputRequired
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from dify_graph.entities.pause_reason import HumanInputRequired
|
||||
from dify_graph.enums import WorkflowExecutionStatus
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
|
||||
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
|
||||
from dify_graph.nodes import NodeType
|
||||
from dify_graph.repositories.draft_variable_repository import DraftVariableSaverFactory
|
||||
from dify_graph.runtime import GraphRuntimeState
|
||||
from dify_graph.system_variable import SystemVariable
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account, Conversation, EndUser, LLMGenerationDetail, Message, MessageFile
|
||||
@ -916,7 +916,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
|
||||
def _load_human_input_form_id(self, *, node_id: str) -> str | None:
|
||||
form_repository = HumanInputFormRepositoryImpl(
|
||||
session_factory=db.engine,
|
||||
tenant_id=self._workflow_tenant_id,
|
||||
)
|
||||
form = form_repository.get_form(self._workflow_run_id, node_id)
|
||||
|
||||
@ -20,8 +20,8 @@ from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from libs.flask_utils import preserve_flask_contexts
|
||||
|
||||
@ -12,9 +12,9 @@ from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity
|
||||
from core.app.entities.queue_entities import QueueAnnotationReplyEvent
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.model_entities import ModelFeature
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.moderation.base import ModerationError
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelFeature
|
||||
from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, Conversation, Message
|
||||
|
||||
|
||||
@ -6,7 +6,7 @@ from typing import Any, Union
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.task_entities import AppBlockingResponse, AppStreamResponse
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from dify_graph.model_runtime.errors.invoke import InvokeError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -4,21 +4,21 @@ from typing import TYPE_CHECKING, Any, Union, final
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.file import File, FileUploadConfig
|
||||
from core.workflow.repositories.draft_variable_repository import (
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.file import File, FileUploadConfig
|
||||
from dify_graph.repositories.draft_variable_repository import (
|
||||
DraftVariableSaver,
|
||||
DraftVariableSaverFactory,
|
||||
NoopDraftVariableSaver,
|
||||
)
|
||||
from core.workflow.variables.input_entities import VariableEntityType
|
||||
from dify_graph.variables.input_entities import VariableEntityType
|
||||
from factories import file_factory
|
||||
from libs.orjson import orjson_dumps
|
||||
from models import Account, EndUser
|
||||
from services.workflow_draft_variable_service import DraftVariableSaver as DraftVariableSaverImpl
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.variables.input_entities import VariableEntity
|
||||
from dify_graph.variables.input_entities import VariableEntity
|
||||
|
||||
|
||||
class BaseAppGenerator:
|
||||
|
||||
@ -20,7 +20,7 @@ from core.app.entities.queue_entities import (
|
||||
QueueStopEvent,
|
||||
WorkflowQueueMessage,
|
||||
)
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
from dify_graph.runtime import GraphRuntimeState
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -24,27 +24,27 @@ from core.app.features.hosting_moderation.hosting_moderation import HostingModer
|
||||
from core.external_data_tool.external_data_fetch import ExternalDataFetch
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
ImagePromptMessageContent,
|
||||
PromptMessage,
|
||||
TextPromptMessageContent,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey
|
||||
from core.model_runtime.errors.invoke import InvokeBadRequestError
|
||||
from core.moderation.input_moderation import InputModeration
|
||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
|
||||
from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.workflow.file.enums import FileTransferMethod, FileType
|
||||
from dify_graph.file.enums import FileTransferMethod, FileType
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
from dify_graph.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
ImagePromptMessageContent,
|
||||
PromptMessage,
|
||||
TextPromptMessageContent,
|
||||
)
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey
|
||||
from dify_graph.model_runtime.errors.invoke import InvokeBadRequestError
|
||||
from extensions.ext_database import db
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import App, AppMode, Message, MessageAnnotation, MessageFile
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.file.models import File
|
||||
from dify_graph.file.models import File
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -19,8 +19,8 @@ from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from models import Account
|
||||
|
||||
@ -13,10 +13,10 @@ from core.app.entities.queue_entities import QueueAnnotationReplyEvent
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||
from core.moderation.base import ModerationError
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.workflow.file import File
|
||||
from dify_graph.file import File
|
||||
from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, Conversation, Message
|
||||
|
||||
|
||||
@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
from dify_graph.runtime import GraphRuntimeState
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
||||
|
||||
@ -49,21 +49,21 @@ from core.plugin.impl.datasource import PluginDatasourceManager
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.trigger.trigger_manager import TriggerManager
|
||||
from core.workflow.entities.pause_reason import HumanInputRequired
|
||||
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
|
||||
from core.workflow.enums import (
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from dify_graph.entities.pause_reason import HumanInputRequired
|
||||
from dify_graph.entities.workflow_start_reason import WorkflowStartReason
|
||||
from dify_graph.enums import (
|
||||
NodeType,
|
||||
SystemVariableKey,
|
||||
WorkflowExecutionStatus,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from core.workflow.file import FILE_MODEL_IDENTITY, File
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.variables.segments import ArrayFileSegment, FileSegment, Segment
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
|
||||
from dify_graph.file import FILE_MODEL_IDENTITY, File
|
||||
from dify_graph.runtime import GraphRuntimeState
|
||||
from dify_graph.system_variable import SystemVariable
|
||||
from dify_graph.variables.segments import ArrayFileSegment, FileSegment, Segment
|
||||
from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account, EndUser
|
||||
|
||||
@ -19,8 +19,8 @@ from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||
from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from models import Account, App, EndUser, Message
|
||||
|
||||
@ -11,10 +11,10 @@ from core.app.entities.app_invoke_entities import (
|
||||
)
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||
from core.moderation.base import ModerationError
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.workflow.file import File
|
||||
from dify_graph.file import File
|
||||
from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, Message
|
||||
|
||||
|
||||
@ -33,13 +33,13 @@ from core.datasource.entities.datasource_entities import (
|
||||
)
|
||||
from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
|
||||
from core.entities.knowledge_entities import PipelineDataset, PipelineDocument
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from core.rag.index_processor.constant.built_in_field import BuiltInField
|
||||
from core.repositories.factory import DifyCoreRepositoryFactory
|
||||
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
|
||||
from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from dify_graph.repositories.draft_variable_repository import DraftVariableSaverFactory
|
||||
from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
|
||||
from extensions.ext_database import db
|
||||
from libs.flask_utils import preserve_flask_contexts
|
||||
from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
@ -8,23 +8,24 @@ from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
InvokeFrom,
|
||||
RagPipelineGenerateEntity,
|
||||
UserFrom,
|
||||
build_dify_run_context,
|
||||
)
|
||||
from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
|
||||
from core.app.workflow.node_factory import DifyNodeFactory
|
||||
from core.workflow.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.enums import WorkflowType
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import GraphEngineEvent, GraphRunFailedEvent
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.variable_loader import VariableLoader
|
||||
from core.workflow.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput
|
||||
from core.workflow.node_factory import DifyNodeFactory
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from dify_graph.entities.graph_init_params import GraphInitParams
|
||||
from dify_graph.enums import WorkflowType
|
||||
from dify_graph.graph import Graph
|
||||
from dify_graph.graph_events import GraphEngineEvent, GraphRunFailedEvent
|
||||
from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from dify_graph.runtime import GraphRuntimeState, VariablePool
|
||||
from dify_graph.system_variable import SystemVariable
|
||||
from dify_graph.variable_loader import VariableLoader
|
||||
from dify_graph.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Document, Pipeline
|
||||
from models.enums import UserFrom
|
||||
from models.model import EndUser
|
||||
from models.workflow import Workflow
|
||||
|
||||
@ -257,13 +258,15 @@ class PipelineRunner(WorkflowBasedAppRunner):
|
||||
# init graph
|
||||
# Create required parameters for Graph.init
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=self._app_id,
|
||||
workflow_id=workflow.id,
|
||||
graph_config=graph_config,
|
||||
user_id=self.application_generate_entity.user_id,
|
||||
user_from=user_from,
|
||||
invoke_from=invoke_from,
|
||||
run_context=build_dify_run_context(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=self._app_id,
|
||||
user_id=self.application_generate_entity.user_id,
|
||||
user_from=user_from,
|
||||
invoke_from=invoke_from,
|
||||
),
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
|
||||
@ -29,16 +29,16 @@ from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, Pau
|
||||
from core.app.layers.sandbox_layer import SandboxLayer
|
||||
from core.db.session_factory import session_factory
|
||||
from core.helper.trace_id_helper import extract_external_trace_id_from_args
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from core.sandbox.sandbox import Sandbox
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
|
||||
from dify_graph.graph_engine.layers.base import GraphEngineLayer
|
||||
from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from dify_graph.repositories.draft_variable_repository import DraftVariableSaverFactory
|
||||
from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from dify_graph.runtime import GraphRuntimeState
|
||||
from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from libs.flask_utils import preserve_flask_contexts
|
||||
|
||||
@ -9,15 +9,15 @@ from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
|
||||
from core.sandbox import Sandbox
|
||||
from core.workflow.enums import WorkflowType
|
||||
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.variable_loader import VariableLoader
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from dify_graph.enums import WorkflowType
|
||||
from dify_graph.graph_engine.command_channels.redis_channel import RedisChannel
|
||||
from dify_graph.graph_engine.layers.base import GraphEngineLayer
|
||||
from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from dify_graph.runtime import GraphRuntimeState, VariablePool
|
||||
from dify_graph.system_variable import SystemVariable
|
||||
from dify_graph.variable_loader import VariableLoader
|
||||
from extensions.ext_redis import redis_client
|
||||
from extensions.otel import WorkflowAppRunnerHandler, trace_span
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
@ -56,11 +56,11 @@ from core.app.entities.task_entities import (
|
||||
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
||||
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from dify_graph.entities.workflow_start_reason import WorkflowStartReason
|
||||
from dify_graph.enums import WorkflowExecutionStatus
|
||||
from dify_graph.repositories.draft_variable_repository import DraftVariableSaverFactory
|
||||
from dify_graph.runtime import GraphRuntimeState
|
||||
from dify_graph.system_variable import SystemVariable
|
||||
from extensions.ext_database import db
|
||||
from models import Account
|
||||
from models.enums import CreatorUserRole
|
||||
|
||||
@ -4,7 +4,7 @@ from collections.abc import Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
QueueAgentLogEvent,
|
||||
@ -29,12 +29,13 @@ from core.app.entities.queue_entities import (
|
||||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from core.app.workflow.node_factory import DifyNodeFactory
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.entities.pause_reason import HumanInputRequired
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.graph_events import (
|
||||
from core.workflow.node_factory import DifyNodeFactory
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from dify_graph.entities import GraphInitParams
|
||||
from dify_graph.entities.pause_reason import HumanInputRequired
|
||||
from dify_graph.graph import Graph
|
||||
from dify_graph.graph_engine.layers.base import GraphEngineLayer
|
||||
from dify_graph.graph_events import (
|
||||
GraphEngineEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunPartialSucceededEvent,
|
||||
@ -60,14 +61,12 @@ from core.workflow.graph_events import (
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.graph_events.graph import GraphRunAbortedEvent
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from models.enums import UserFrom
|
||||
from dify_graph.graph_events.graph import GraphRunAbortedEvent
|
||||
from dify_graph.nodes import NodeType
|
||||
from dify_graph.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
from dify_graph.runtime import GraphRuntimeState, VariablePool
|
||||
from dify_graph.system_variable import SystemVariable
|
||||
from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
|
||||
from models.workflow import Workflow
|
||||
from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task
|
||||
|
||||
@ -119,13 +118,15 @@ class WorkflowBasedAppRunner:
|
||||
|
||||
# Create required parameters for Graph.init
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id=tenant_id or "",
|
||||
app_id=self._app_id,
|
||||
workflow_id=workflow_id,
|
||||
graph_config=graph_config,
|
||||
user_id=user_id,
|
||||
user_from=user_from,
|
||||
invoke_from=invoke_from,
|
||||
run_context=build_dify_run_context(
|
||||
tenant_id=tenant_id or "",
|
||||
app_id=self._app_id,
|
||||
user_id=user_id,
|
||||
user_from=user_from,
|
||||
invoke_from=invoke_from,
|
||||
),
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
@ -267,13 +268,15 @@ class WorkflowBasedAppRunner:
|
||||
|
||||
# Create required parameters for Graph.init
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=self._app_id,
|
||||
workflow_id=workflow.id,
|
||||
graph_config=graph_config,
|
||||
user_id="",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
run_context=build_dify_run_context(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=self._app_id,
|
||||
user_id="",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
),
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
|
||||
@ -7,81 +7,77 @@ from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validat
|
||||
from constants import UUID_NIL
|
||||
from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
|
||||
from core.entities.provider_configuration import ProviderModelBundle
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||
from core.workflow.file import File, FileUploadConfig
|
||||
from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY
|
||||
from dify_graph.file import File, FileUploadConfig
|
||||
from dify_graph.model_runtime.entities.model_entities import AIModelEntity
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
|
||||
|
||||
class UserFrom(StrEnum):
|
||||
ACCOUNT = "account"
|
||||
END_USER = "end-user"
|
||||
|
||||
|
||||
class InvokeFrom(StrEnum):
|
||||
"""
|
||||
Invoke From.
|
||||
"""
|
||||
|
||||
# SERVICE_API indicates that this invocation is from an API call to Dify app.
|
||||
#
|
||||
# Description of service api in Dify docs:
|
||||
# https://docs.dify.ai/en/guides/application-publishing/developing-with-apis
|
||||
SERVICE_API = "service-api"
|
||||
|
||||
# WEB_APP indicates that this invocation is from
|
||||
# the web app of the workflow (or chatflow).
|
||||
#
|
||||
# Description of web app in Dify docs:
|
||||
# https://docs.dify.ai/en/guides/application-publishing/launch-your-webapp-quickly/README
|
||||
WEB_APP = "web-app"
|
||||
|
||||
# TRIGGER indicates that this invocation is from a trigger.
|
||||
# this is used for plugin trigger and webhook trigger.
|
||||
TRIGGER = "trigger"
|
||||
|
||||
# AGENT indicates that this invocation is from an agent.
|
||||
AGENT = "agent"
|
||||
|
||||
# EXPLORE indicates that this invocation is from
|
||||
# the workflow (or chatflow) explore page.
|
||||
EXPLORE = "explore"
|
||||
# DEBUGGER indicates that this invocation is from
|
||||
# the workflow (or chatflow) edit page.
|
||||
DEBUGGER = "debugger"
|
||||
# PUBLISHED_PIPELINE indicates that this invocation runs a published RAG pipeline workflow.
|
||||
PUBLISHED_PIPELINE = "published"
|
||||
|
||||
# VALIDATION indicates that this invocation is from validation.
|
||||
VALIDATION = "validation"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str):
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
:param value: mode value
|
||||
:return: mode
|
||||
"""
|
||||
for mode in cls:
|
||||
if mode.value == value:
|
||||
return mode
|
||||
raise ValueError(f"invalid invoke from value {value}")
|
||||
def value_of(cls, value: str) -> "InvokeFrom":
|
||||
return cls(value)
|
||||
|
||||
def to_source(self) -> str:
|
||||
"""
|
||||
Get source of invoke from.
|
||||
source_mapping = {
|
||||
InvokeFrom.WEB_APP: "web_app",
|
||||
InvokeFrom.DEBUGGER: "dev",
|
||||
InvokeFrom.EXPLORE: "explore_app",
|
||||
InvokeFrom.TRIGGER: "trigger",
|
||||
InvokeFrom.SERVICE_API: "api",
|
||||
}
|
||||
return source_mapping.get(self, "dev")
|
||||
|
||||
:return: source
|
||||
"""
|
||||
if self == InvokeFrom.WEB_APP:
|
||||
return "web_app"
|
||||
elif self == InvokeFrom.DEBUGGER:
|
||||
return "dev"
|
||||
elif self == InvokeFrom.EXPLORE:
|
||||
return "explore_app"
|
||||
elif self == InvokeFrom.TRIGGER:
|
||||
return "trigger"
|
||||
elif self == InvokeFrom.SERVICE_API:
|
||||
return "api"
|
||||
|
||||
return "dev"
|
||||
class DifyRunContext(BaseModel):
|
||||
tenant_id: str
|
||||
app_id: str
|
||||
user_id: str
|
||||
user_from: UserFrom
|
||||
invoke_from: InvokeFrom
|
||||
|
||||
|
||||
def build_dify_run_context(
|
||||
*,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
user_id: str,
|
||||
user_from: UserFrom,
|
||||
invoke_from: InvokeFrom,
|
||||
extra_context: Mapping[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Build graph run_context with the reserved Dify runtime payload.
|
||||
|
||||
`extra_context` can carry user-defined context keys. The reserved `_dify`
|
||||
payload is always overwritten by this function to keep one canonical source.
|
||||
"""
|
||||
run_context = dict(extra_context) if extra_context else {}
|
||||
run_context[DIFY_RUN_CONTEXT_KEY] = DifyRunContext(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
user_id=user_id,
|
||||
user_from=user_from,
|
||||
invoke_from=invoke_from,
|
||||
)
|
||||
return run_context
|
||||
|
||||
|
||||
class ModelConfigWithCredentialsEntity(BaseModel):
|
||||
|
||||
@ -5,13 +5,13 @@ from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.entities import AgentNodeStrategyInit, ToolCall, ToolResult
|
||||
from core.workflow.entities.pause_reason import PauseReason
|
||||
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
|
||||
from core.workflow.enums import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.nodes import NodeType
|
||||
from dify_graph.entities import AgentNodeStrategyInit, ToolCall, ToolResult
|
||||
from dify_graph.entities.pause_reason import PauseReason
|
||||
from dify_graph.entities.workflow_start_reason import WorkflowStartReason
|
||||
from dify_graph.enums import WorkflowNodeExecutionMetadataKey
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||
from dify_graph.nodes import NodeType
|
||||
|
||||
|
||||
class QueueEvent(StrEnum):
|
||||
|
||||
@ -4,12 +4,12 @@ from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.entities import AgentNodeStrategyInit
|
||||
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
|
||||
from core.workflow.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.human_input.entities import FormInput, UserAction
|
||||
from dify_graph.entities import AgentNodeStrategyInit
|
||||
from dify_graph.entities.workflow_start_reason import WorkflowStartReason
|
||||
from dify_graph.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from dify_graph.nodes.human_input.entities import FormInput, UserAction
|
||||
|
||||
|
||||
class AnnotationReplyAccount(BaseModel):
|
||||
|
||||
@ -2,7 +2,7 @@ import logging
|
||||
|
||||
from core.app.entities.app_invoke_entities import EasyUIBasedAppGenerateEntity
|
||||
from core.helper import moderation
|
||||
from core.model_runtime.entities.message_entities import PromptMessage
|
||||
from dify_graph.model_runtime.entities.message_entities import PromptMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -1,12 +1,12 @@
|
||||
import logging
|
||||
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.graph_events import GraphEngineEvent, NodeRunSucceededEvent
|
||||
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
|
||||
from core.workflow.variables import VariableBase
|
||||
from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from dify_graph.conversation_variable_updater import ConversationVariableUpdater
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.graph_engine.layers.base import GraphEngineLayer
|
||||
from dify_graph.graph_events import GraphEngineEvent, NodeRunSucceededEvent
|
||||
from dify_graph.nodes.variable_assigner.common import helpers as common_helpers
|
||||
from dify_graph.variables import VariableBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -6,9 +6,9 @@ from sqlalchemy import Engine
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.graph_events.base import GraphEngineEvent
|
||||
from core.workflow.graph_events.graph import GraphRunPausedEvent
|
||||
from dify_graph.graph_engine.layers.base import GraphEngineLayer
|
||||
from dify_graph.graph_events.base import GraphEngineEvent
|
||||
from dify_graph.graph_events.graph import GraphRunPausedEvent
|
||||
from models.model import AppMode
|
||||
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
|
||||
@ -1,9 +1,10 @@
|
||||
import logging
|
||||
|
||||
from core.sandbox import Sandbox
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.graph_events.base import GraphEngineEvent
|
||||
|
||||
from core.sandbox import Sandbox
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.graph_events.base import GraphEngineEvent
|
||||
from core.workflow.graph_events.graph import GraphRunPausedEvent
|
||||
from dify_graph.graph_engine.layers.base import GraphEngineLayer
|
||||
from dify_graph.graph_events.base import GraphEngineEvent
|
||||
from dify_graph.graph_events.graph import GraphRunPausedEvent
|
||||
|
||||
|
||||
class SuspendLayer(GraphEngineLayer):
|
||||
|
||||
@ -4,9 +4,9 @@ from typing import ClassVar
|
||||
|
||||
from apscheduler.schedulers.background import BackgroundScheduler # type: ignore
|
||||
|
||||
from core.workflow.graph_engine.entities.commands import CommandType, GraphEngineCommand
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.graph_events.base import GraphEngineEvent
|
||||
from dify_graph.graph_engine.entities.commands import CommandType, GraphEngineCommand
|
||||
from dify_graph.graph_engine.layers.base import GraphEngineLayer
|
||||
from dify_graph.graph_events.base import GraphEngineEvent
|
||||
from services.workflow.entities import WorkflowScheduleCFSPlanEntity
|
||||
from services.workflow.scheduler import CFSPlanScheduler, SchedulerCommand
|
||||
|
||||
|
||||
@ -5,9 +5,9 @@ from typing import Any, ClassVar
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.graph_events.base import GraphEngineEvent
|
||||
from core.workflow.graph_events.graph import GraphRunFailedEvent, GraphRunPausedEvent, GraphRunSucceededEvent
|
||||
from dify_graph.graph_engine.layers.base import GraphEngineLayer
|
||||
from dify_graph.graph_events.base import GraphEngineEvent
|
||||
from dify_graph.graph_events.graph import GraphRunFailedEvent, GraphRunPausedEvent, GraphRunSucceededEvent
|
||||
from models.enums import WorkflowTriggerStatus
|
||||
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
|
||||
from tasks.workflow_cfs_scheduler.cfs_scheduler import AsyncWorkflowCFSPlanEntity
|
||||
|
||||
@ -5,11 +5,11 @@ from typing import Any
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.errors.error import ProviderTokenNotInitError
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.provider_manager import ProviderManager
|
||||
from core.workflow.nodes.llm.entities import ModelConfig
|
||||
from core.workflow.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
|
||||
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from dify_graph.nodes.llm.entities import ModelConfig
|
||||
from dify_graph.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
|
||||
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
||||
|
||||
|
||||
class DifyCredentialsProvider:
|
||||
|
||||
@ -6,7 +6,7 @@ from core.entities.model_entities import ModelStatus
|
||||
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
|
||||
from core.errors.error import QuotaExceededError
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.provider import Provider, ProviderType
|
||||
|
||||
@ -16,8 +16,8 @@ from core.app.entities.task_entities import (
|
||||
PingStreamResponse,
|
||||
)
|
||||
from core.errors.error import QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
from core.moderation.output_moderation import ModerationRule, OutputModeration
|
||||
from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
from models.enums import MessageStatus
|
||||
from models.model import Message
|
||||
|
||||
|
||||
@ -47,19 +47,19 @@ from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTas
|
||||
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
|
||||
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
)
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from core.tools.signature import sign_tool_file
|
||||
from core.workflow.file import helpers as file_helpers
|
||||
from core.workflow.file.enums import FileTransferMethod
|
||||
from dify_graph.file import helpers as file_helpers
|
||||
from dify_graph.file.enums import FileTransferMethod
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
from dify_graph.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
)
|
||||
from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from events.message_event import message_was_created
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
from .node_factory import DifyNodeFactory
|
||||
from core.workflow.node_factory import DifyNodeFactory
|
||||
|
||||
__all__ = ["DifyNodeFactory"]
|
||||
|
||||
@ -5,13 +5,13 @@ from collections.abc import Generator
|
||||
from configs import dify_config
|
||||
from core.helper.ssrf_proxy import ssrf_proxy
|
||||
from core.tools.signature import sign_tool_file
|
||||
from core.workflow.file.protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol
|
||||
from core.workflow.file.runtime import set_workflow_file_runtime
|
||||
from dify_graph.file.protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol
|
||||
from dify_graph.file.runtime import set_workflow_file_runtime
|
||||
from extensions.ext_storage import storage
|
||||
|
||||
|
||||
class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
|
||||
"""Production runtime wiring for ``core.workflow.file``."""
|
||||
"""Production runtime wiring for ``dify_graph.file``."""
|
||||
|
||||
@property
|
||||
def files_url(self) -> str:
|
||||
|
||||
@ -12,17 +12,17 @@ from typing_extensions import override
|
||||
from core.app.llm import deduct_llm_quota, ensure_llm_quota_available
|
||||
from core.errors.error import QuotaExceededError
|
||||
from core.model_manager import ModelInstance
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase
|
||||
from core.workflow.graph_events.node import NodeRunSucceededEvent
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.graph_engine.entities.commands import AbortCommand, CommandType
|
||||
from dify_graph.graph_engine.layers.base import GraphEngineLayer
|
||||
from dify_graph.graph_events import GraphEngineEvent, GraphNodeEventBase
|
||||
from dify_graph.graph_events.node import NodeRunSucceededEvent
|
||||
from dify_graph.nodes.base.node import Node
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.nodes.llm.node import LLMNode
|
||||
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
|
||||
from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
|
||||
from dify_graph.nodes.llm.node import LLMNode
|
||||
from dify_graph.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
|
||||
from dify_graph.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -75,8 +75,9 @@ class LLMQuotaLayer(GraphEngineLayer):
|
||||
return
|
||||
|
||||
try:
|
||||
dify_ctx = node.require_dify_context()
|
||||
deduct_llm_quota(
|
||||
tenant_id=node.tenant_id,
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
model_instance=model_instance,
|
||||
usage=result_event.node_run_result.llm_usage,
|
||||
)
|
||||
|
||||
@ -16,10 +16,10 @@ from opentelemetry.trace import Span, SpanKind, Tracer, get_tracer, set_span_in_
|
||||
from typing_extensions import override
|
||||
|
||||
from configs import dify_config
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.graph_events import GraphNodeEventBase
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.graph_engine.layers.base import GraphEngineLayer
|
||||
from dify_graph.graph_events import GraphNodeEventBase
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from extensions.otel.parser import (
|
||||
DefaultNodeOTelParser,
|
||||
LLMNodeOTelParser,
|
||||
|
||||
@ -17,17 +17,17 @@ from typing import Any, Union
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.entities import WorkflowExecution, WorkflowNodeExecution
|
||||
from core.workflow.enums import (
|
||||
from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
from dify_graph.entities import WorkflowExecution, WorkflowNodeExecution
|
||||
from dify_graph.enums import (
|
||||
SystemVariableKey,
|
||||
WorkflowExecutionStatus,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
WorkflowType,
|
||||
)
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.graph_events import (
|
||||
from dify_graph.graph_engine.layers.base import GraphEngineLayer
|
||||
from dify_graph.graph_events import (
|
||||
GraphEngineEvent,
|
||||
GraphRunAbortedEvent,
|
||||
GraphRunFailedEvent,
|
||||
@ -42,9 +42,9 @@ from core.workflow.graph_events import (
|
||||
NodeRunStartedEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from dify_graph.node_events import NodeRunResult
|
||||
from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
|
||||
|
||||
@ -15,8 +15,8 @@ from core.app.entities.queue_entities import (
|
||||
WorkflowQueueMessage,
|
||||
)
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_runtime.entities.message_entities import TextPromptMessageContent
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from dify_graph.model_runtime.entities.message_entities import TextPromptMessageContent
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
|
||||
|
||||
class AudioTrunk:
|
||||
|
||||
@ -213,6 +213,6 @@ class DatasourceFileManager:
|
||||
|
||||
|
||||
# init tool_file_parser
|
||||
# from core.workflow.file.datasource_file_parser import datasource_file_manager
|
||||
# from dify_graph.file.datasource_file_parser import datasource_file_manager
|
||||
#
|
||||
# datasource_file_manager["manager"] = DatasourceFileManager
|
||||
|
||||
@ -24,12 +24,12 @@ from core.datasource.utils.message_transformer import DatasourceFileMessageTrans
|
||||
from core.datasource.website_crawl.website_crawl_provider import WebsiteCrawlDatasourcePluginProviderController
|
||||
from core.db.session_factory import session_factory
|
||||
from core.plugin.impl.datasource import PluginDatasourceManager
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.enums import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.file import File
|
||||
from core.workflow.file.enums import FileTransferMethod, FileType
|
||||
from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
|
||||
from core.workflow.repositories.datasource_manager_protocol import DatasourceParameter, OnlineDriveDownloadFileParam
|
||||
from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from dify_graph.enums import WorkflowNodeExecutionMetadataKey
|
||||
from dify_graph.file import File
|
||||
from dify_graph.file.enums import FileTransferMethod, FileType
|
||||
from dify_graph.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
|
||||
from dify_graph.repositories.datasource_manager_protocol import DatasourceParameter, OnlineDriveDownloadFileParam
|
||||
from factories import file_factory
|
||||
from models.model import UploadFile
|
||||
from models.tools import ToolFile
|
||||
|
||||
@ -3,8 +3,8 @@ from typing import Literal, Optional
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from core.datasource.entities.datasource_entities import DatasourceParameter
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
|
||||
|
||||
|
||||
class DatasourceApiEntity(BaseModel):
|
||||
|
||||
@ -4,7 +4,7 @@ from mimetypes import guess_extension, guess_type
|
||||
|
||||
from core.datasource.entities.datasource_entities import DatasourceMessage
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.workflow.file import File, FileTransferMethod, FileType
|
||||
from dify_graph.file import File, FileTransferMethod, FileType
|
||||
from models.tools import ToolFile
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -5,7 +5,7 @@ from typing import Any, TypeAlias
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from core.workflow.nodes.human_input.entities import FormInput, UserAction
|
||||
from dify_graph.nodes.human_input.entities import FormInput, UserAction
|
||||
from models.execution_extra_content import ExecutionContentType
|
||||
|
||||
|
||||
|
||||
@ -15,7 +15,7 @@ from core.helper.provider_cache import NoOpProviderCredentialCache
|
||||
from core.mcp.types import OAuthClientInformation, OAuthClientMetadata, OAuthTokens
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.workflow.file import helpers as file_helpers
|
||||
from dify_graph.file import helpers as file_helpers
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from models.tools import MCPToolProvider
|
||||
|
||||
@ -3,9 +3,9 @@ from enum import StrEnum, auto
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import ModelType, ProviderModel
|
||||
from core.model_runtime.entities.provider_entities import ProviderEntity
|
||||
from dify_graph.model_runtime.entities.common_entities import I18nObject
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType, ProviderModel
|
||||
from dify_graph.model_runtime.entities.provider_entities import ProviderEntity
|
||||
|
||||
|
||||
class ModelStatus(StrEnum):
|
||||
|
||||
@ -19,15 +19,15 @@ from core.entities.provider_entities import (
|
||||
)
|
||||
from core.helper import encrypter
|
||||
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
|
||||
from core.model_runtime.entities.provider_entities import (
|
||||
from dify_graph.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
|
||||
from dify_graph.model_runtime.entities.provider_entities import (
|
||||
ConfigurateMethod,
|
||||
CredentialFormSchema,
|
||||
FormType,
|
||||
ProviderEntity,
|
||||
)
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.engine import db
|
||||
from models.provider import (
|
||||
|
||||
@ -11,8 +11,8 @@ from core.entities.parameter_entities import (
|
||||
ModelSelectorScope,
|
||||
ToolSelectorScope,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
|
||||
|
||||
class ProviderQuotaType(StrEnum):
|
||||
|
||||
@ -13,7 +13,7 @@ from core.helper.code_executor.jinja2.jinja2_transformer import Jinja2TemplateTr
|
||||
from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer
|
||||
from core.helper.code_executor.template_transformer import TemplateTransformer
|
||||
from core.helper.http_client_pooling import get_pooled_http_client
|
||||
from core.workflow.nodes.code.entities import CodeLanguage
|
||||
from dify_graph.nodes.code.entities import CodeLanguage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
code_execution_endpoint_url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT))
|
||||
|
||||
@ -5,7 +5,7 @@ from base64 import b64encode
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.variables.utils import dumps_with_segments
|
||||
from dify_graph.variables.utils import dumps_with_segments
|
||||
|
||||
|
||||
class TemplateTransformer(ABC):
|
||||
|
||||
@ -4,10 +4,10 @@ from typing import cast
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.entities import DEFAULT_PLUGIN_ID
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.invoke import InvokeBadRequestError
|
||||
from core.model_runtime.model_providers.__base.moderation_model import ModerationModel
|
||||
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from dify_graph.model_runtime.errors.invoke import InvokeBadRequestError
|
||||
from dify_graph.model_runtime.model_providers.__base.moderation_model import ModerationModel
|
||||
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from extensions.ext_hosting_provider import hosting_configuration
|
||||
from models.provider import ProviderType
|
||||
|
||||
|
||||
@ -4,7 +4,7 @@ from pydantic import BaseModel
|
||||
from configs import dify_config
|
||||
from core.entities import DEFAULT_PLUGIN_ID
|
||||
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit, RestrictModel
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
|
||||
|
||||
class HostingQuota(BaseModel):
|
||||
|
||||
@ -15,7 +15,6 @@ from configs import dify_config
|
||||
from core.entities.knowledge_entities import IndexingEstimate, PreviewDetail, QAPreviewDetail
|
||||
from core.errors.error import ProviderTokenNotInitError
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.rag.cleaner.clean_processor import CleanProcessor
|
||||
from core.rag.datasource.keyword.keyword_factory import Keyword
|
||||
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
|
||||
@ -31,6 +30,7 @@ from core.rag.splitter.fixed_text_splitter import (
|
||||
)
|
||||
from core.rag.splitter.text_splitter import TextSplitter
|
||||
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from extensions.ext_storage import storage
|
||||
|
||||
@ -34,15 +34,15 @@ from core.llm_generator.prompts import (
|
||||
WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE,
|
||||
)
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.llm_entities import LLMResult
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.ops.utils import measure_time
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
||||
from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMResult
|
||||
from dify_graph.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from models import App, Message, WorkflowNodeExecutionModel
|
||||
|
||||
@ -15,18 +15,18 @@ from core.llm_generator.prompts import (
|
||||
STRUCTURED_OUTPUT_TOOL_CALL_PROMPT,
|
||||
)
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.callbacks.base_callback import Callback
|
||||
from core.model_runtime.entities.llm_entities import (
|
||||
from dify_graph.model_runtime.callbacks.base_callback import Callback
|
||||
from dify_graph.model_runtime.entities.llm_entities import (
|
||||
LLMResult,
|
||||
LLMResultWithStructuredOutput,
|
||||
)
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
from dify_graph.model_runtime.entities.message_entities import (
|
||||
PromptMessage,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ParameterRule
|
||||
from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ParameterRule
|
||||
|
||||
|
||||
class ResponseFormat(StrEnum):
|
||||
|
||||
@ -7,7 +7,7 @@ from configs import dify_config
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
|
||||
from core.mcp import types as mcp_types
|
||||
from core.workflow.variables.input_entities import VariableEntity, VariableEntityType
|
||||
from dify_graph.variables.input_entities import VariableEntity, VariableEntityType
|
||||
from models.model import App, AppMCPServer, AppMode, EndUser
|
||||
from services.app_generate_service import AppGenerateService
|
||||
|
||||
|
||||
@ -8,7 +8,7 @@ from httpx_sse import connect_sse
|
||||
|
||||
from configs import dify_config
|
||||
from core.mcp.types import ErrorData, JSONRPCError
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
|
||||
|
||||
HTTP_REQUEST_NODE_SSL_VERIFY = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY
|
||||
|
||||
|
||||
@ -6,16 +6,16 @@ from sqlalchemy.orm import sessionmaker
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.memory.base import BaseMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities import (
|
||||
from core.prompt.utils.extract_thread_messages import extract_thread_messages
|
||||
from dify_graph.file import file_manager
|
||||
from dify_graph.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
ImagePromptMessageContent,
|
||||
PromptMessage,
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
|
||||
from core.prompt.utils.extract_thread_messages import extract_thread_messages
|
||||
from core.workflow.file import file_manager
|
||||
from dify_graph.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from models.model import AppMode, Conversation, Message, MessageFile
|
||||
|
||||
@ -7,20 +7,20 @@ from core.entities.embedding_type import EmbeddingInputType
|
||||
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
|
||||
from core.entities.provider_entities import ModelLoadBalancingConfiguration
|
||||
from core.errors.error import ProviderTokenNotInitError
|
||||
from core.model_runtime.callbacks.base_callback import Callback
|
||||
from core.model_runtime.entities.llm_entities import LLMResult
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
|
||||
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
|
||||
from core.model_runtime.entities.rerank_entities import RerankResult
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingResult
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.model_providers.__base.moderation_model import ModerationModel
|
||||
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
|
||||
from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from core.model_runtime.model_providers.__base.tts_model import TTSModel
|
||||
from core.provider_manager import ProviderManager
|
||||
from dify_graph.model_runtime.callbacks.base_callback import Callback
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMResult
|
||||
from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType
|
||||
from dify_graph.model_runtime.entities.rerank_entities import RerankResult
|
||||
from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult
|
||||
from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError
|
||||
from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from dify_graph.model_runtime.model_providers.__base.moderation_model import ModerationModel
|
||||
from dify_graph.model_runtime.model_providers.__base.rerank_model import RerankModel
|
||||
from dify_graph.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
|
||||
from dify_graph.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from dify_graph.model_runtime.model_providers.__base.tts_model import TTSModel
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.provider import ProviderType
|
||||
from services.enterprise.plugin_manager_service import PluginCredentialType
|
||||
|
||||
@ -1,51 +0,0 @@
|
||||
# Model Runtime
|
||||
|
||||
This module provides the interfaces for invoking and authenticating various models, and gives Dify a unified set of provider information and credential form rules for model providers.
|
||||
|
||||
- On the one hand, it decouples models from upstream and downstream processes, making horizontal expansion easier for developers.
|
||||
- On the other hand, it allows for direct display of providers and models in the frontend interface by simply defining them in the backend, eliminating the need to modify frontend logic.
|
||||
|
||||
## Features
|
||||
|
||||
- Supports capability invocation for 6 types of models
|
||||
|
||||
- `LLM` - LLM text completion, dialogue, pre-computed tokens capability
|
||||
- `Text Embedding Model` - Text Embedding, pre-computed tokens capability
|
||||
- `Rerank Model` - Segment Rerank capability
|
||||
- `Speech-to-text Model` - Speech to text capability
|
||||
- `Text-to-speech Model` - Text to speech capability
|
||||
- `Moderation` - Moderation capability
|
||||
|
||||
- Model provider display
|
||||
|
||||
Displays a list of all supported providers, including provider names, icons, supported model types list, predefined model list, configuration method, and credentials form rules, etc.
|
||||
|
||||
- Selectable model list display
|
||||
|
||||
After configuring provider/model credentials, the dropdown (application orchestration interface/default model) allows viewing of the available LLM list. Greyed out items represent predefined model lists from providers without configured credentials, facilitating user review of supported models.
|
||||
|
||||
In addition, this list also returns configurable parameter information and rules for LLM. These parameters are all defined in the backend, allowing different settings for various parameters supported by different models.
|
||||
|
||||
- Provider/model credential authentication
|
||||
|
||||
The provider list returns configuration information for the credentials form, which can be authenticated through Runtime's interface.
|
||||
|
||||
## Structure
|
||||
|
||||
Model Runtime is divided into three layers:
|
||||
|
||||
- The outermost layer is the factory method
|
||||
|
||||
It provides methods for obtaining all providers, all model lists, getting provider instances, and authenticating provider/model credentials.
|
||||
|
||||
- The second layer is the provider layer
|
||||
|
||||
It provides the current provider's model list, access to model instances, provider credential authentication, and provider configuration rule information, **allowing horizontal expansion** to support different providers.
|
||||
|
||||
- The bottom layer is the model layer
|
||||
|
||||
It offers direct invocation of the various model types, predefined model configuration information, retrieval of predefined/remote model lists, and model credential authentication methods. Different models provide additional special methods, such as the LLM's token pre-calculation method and its cost-information retrieval method, **allowing horizontal expansion** for different models under the same provider (within supported model types).
|
||||
|
||||
## Documentation
|
||||
|
||||
For detailed documentation on how to add new providers or models, please refer to the [Dify documentation](https://docs.dify.ai/).
|
||||
@ -1,64 +0,0 @@
|
||||
# Model Runtime
|
||||
|
||||
该模块提供了各模型的调用、鉴权接口,并为 Dify 提供了统一的模型供应商的信息和凭据表单规则。
|
||||
|
||||
- 一方面将模型和上下游解耦,方便开发者对模型横向扩展,
|
||||
- 另一方面提供了只需在后端定义供应商和模型,即可在前端页面直接展示,无需修改前端逻辑。
|
||||
|
||||
## 功能介绍
|
||||
|
||||
- 支持 6 种模型类型的能力调用
|
||||
|
||||
- `LLM` - LLM 文本补全、对话,预计算 tokens 能力
|
||||
- `Text Embedding Model` - 文本 Embedding,预计算 tokens 能力
|
||||
- `Rerank Model` - 分段 Rerank 能力
|
||||
- `Speech-to-text Model` - 语音转文本能力
|
||||
- `Text-to-speech Model` - 文本转语音能力
|
||||
- `Moderation` - Moderation 能力
|
||||
|
||||
- 模型供应商展示
|
||||
|
||||
展示所有已支持的供应商列表,除了返回供应商名称、图标之外,还提供了支持的模型类型列表,预定义模型列表、配置方式以及配置凭据的表单规则等等。
|
||||
|
||||
- 可选择的模型列表展示
|
||||
|
||||
配置供应商/模型凭据后,可在此下拉(应用编排界面/默认模型)查看可用的 LLM 列表,其中灰色的为未配置凭据供应商的预定义模型列表,方便用户查看已支持的模型。
|
||||
|
||||
除此之外,该列表还返回了 LLM 可配置的参数信息和规则。这里的参数均为后端定义,相比之前只有 5 种固定参数,这里可为不同模型设置所支持的各种参数。
|
||||
|
||||
- 供应商/模型凭据鉴权
|
||||
|
||||
供应商列表返回了凭据表单的配置信息,可通过 Runtime 提供的接口对凭据进行鉴权。
|
||||
|
||||
## 结构
|
||||
|
||||
Model Runtime 分三层:
|
||||
|
||||
- 最外层为工厂方法
|
||||
|
||||
提供获取所有供应商、所有模型列表、获取供应商实例、供应商/模型凭据鉴权方法。
|
||||
|
||||
- 第二层为供应商层
|
||||
|
||||
提供获取当前供应商模型列表、获取模型实例、供应商凭据鉴权、供应商配置规则信息,**可横向扩展**以支持不同的供应商。
|
||||
|
||||
对于供应商/模型凭据,有两种情况
|
||||
|
||||
- 如 OpenAI 这类中心化供应商,需要定义如**api_key**这类的鉴权凭据
|
||||
- 如[**Xinference**](https://github.com/xorbitsai/inference)这类本地部署的供应商,需要定义如**server_url**这类的地址凭据,有时候还需要定义**model_uid**之类的模型类型凭据。当在供应商层定义了这些凭据后,就可以在前端页面上直接展示,无需修改前端逻辑。
|
||||
|
||||
当配置好凭据后,就可以通过 DifyRuntime 的外部接口直接获取到对应供应商所需要的**Schema**(凭据表单规则),从而可以在不修改前端逻辑的情况下,提供新的供应商/模型的支持。
|
||||
|
||||
- 最底层为模型层
|
||||
|
||||
提供各种模型类型的直接调用、预定义模型配置信息、获取预定义/远程模型列表、模型凭据鉴权方法,不同模型额外提供了特殊方法,如 LLM 提供预计算 tokens 方法、获取费用信息方法等,**可横向扩展**同供应商下不同的模型(支持的模型类型下)。
|
||||
|
||||
在这里我们需要先区分模型参数与模型凭据。
|
||||
|
||||
- 模型参数 (**在本层定义**):这是一类经常需要变动,随时调整的参数,如 LLM 的 **max_tokens**、**temperature** 等,这些参数是由用户在前端页面上进行调整的,因此需要在后端定义参数的规则,以便前端页面进行展示和调整。在 DifyRuntime 中,他们的参数名一般为**model_parameters: dict[str, any]**。
|
||||
|
||||
- 模型凭据 (**在供应商层定义**):这是一类不经常变动,一般在配置好后就不会再变动的参数,如 **api_key**、**server_url** 等。在 DifyRuntime 中,他们的参数名一般为**credentials: dict[str, any]**,Provider 层的 credentials 会直接被传递到这一层,不需要再单独定义。
|
||||
|
||||
## 文档
|
||||
|
||||
有关如何添加新供应商或模型的详细文档,请参阅 [Dify 文档](https://docs.dify.ai/)。
|
||||
@ -1,151 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
|
||||
# ANSI SGR parameter strings, keyed by the colour names accepted by ``Callback.print_text``.
# NOTE(review): SGR 36 is cyan on standard terminals, so "blue" will render as cyan - confirm intended.
_TEXT_COLOR_MAPPING = {
    "blue": "36;1",
    "yellow": "33;1",
    "pink": "38;5;200",
    "green": "32;1",
    "red": "31;1",
}


class Callback(ABC):
    """
    Abstract base class for callbacks attached to an LLM invocation.

    Subclasses must implement the four ``on_*`` hooks. ``print_text`` is a
    shared helper for writing optionally coloured text to stdout.
    """

    # NOTE(review): never read inside this class; presumably the invoker checks it to
    # decide whether an exception raised by a callback should propagate - confirm.
    raise_error: bool = False

    @abstractmethod
    def on_before_invoke(
        self,
        llm_instance: AIModel,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: list[PromptMessageTool] | None = None,
        stop: Sequence[str] | None = None,
        stream: bool = True,
        user: str | None = None,
    ):
        """
        Hook for the moment before the model is invoked.

        :param llm_instance: LLM instance
        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        :raises NotImplementedError: always, in this base implementation
        """
        raise NotImplementedError()

    @abstractmethod
    def on_new_chunk(
        self,
        llm_instance: AIModel,
        chunk: LLMResultChunk,
        model: str,
        credentials: dict,
        prompt_messages: Sequence[PromptMessage],
        model_parameters: dict,
        tools: list[PromptMessageTool] | None = None,
        stop: Sequence[str] | None = None,
        stream: bool = True,
        user: str | None = None,
    ):
        """
        Hook for each new result chunk.

        :param llm_instance: LLM instance
        :param chunk: chunk
        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        :raises NotImplementedError: always, in this base implementation
        """
        raise NotImplementedError()

    @abstractmethod
    def on_after_invoke(
        self,
        llm_instance: AIModel,
        result: LLMResult,
        model: str,
        credentials: dict,
        prompt_messages: Sequence[PromptMessage],
        model_parameters: dict,
        tools: list[PromptMessageTool] | None = None,
        stop: Sequence[str] | None = None,
        stream: bool = True,
        user: str | None = None,
    ):
        """
        Hook for the moment after the invocation has produced its result.

        :param llm_instance: LLM instance
        :param result: result
        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        :raises NotImplementedError: always, in this base implementation
        """
        raise NotImplementedError()

    @abstractmethod
    def on_invoke_error(
        self,
        llm_instance: AIModel,
        ex: Exception,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: list[PromptMessageTool] | None = None,
        stop: Sequence[str] | None = None,
        stream: bool = True,
        user: str | None = None,
    ):
        """
        Hook for an invocation that failed with an exception.

        :param llm_instance: LLM instance
        :param ex: exception
        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        :raises NotImplementedError: always, in this base implementation
        """
        raise NotImplementedError()

    def print_text(self, text: str, color: str | None = None, end: str = ""):
        """
        Write ``text`` to stdout, coloured when ``color`` is given.

        :param text: text to print
        :param color: key of ``_TEXT_COLOR_MAPPING``; a falsy value prints the text unchanged
        :param end: terminator handed to ``print`` (empty by default, so no newline is added)
        """
        if color:
            text = self._get_colored_text(text, color)
        print(text, end=end)

    def _get_colored_text(self, text: str, color: str) -> str:
        """
        Wrap ``text`` in ANSI escape sequences for ``color``.

        :param text: text to wrap
        :param color: key of ``_TEXT_COLOR_MAPPING``
        :return: the text preceded by the colour and ``1;3`` sequences and followed by a reset
        :raises KeyError: if ``color`` is not a key of ``_TEXT_COLOR_MAPPING``
        """
        color_str = _TEXT_COLOR_MAPPING[color]
        return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m"
|
||||
@ -1,170 +0,0 @@
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from collections.abc import Sequence
|
||||
from typing import cast
|
||||
|
||||
from core.model_runtime.callbacks.base_callback import Callback
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LoggingCallback(Callback):
    """
    Callback that dumps an LLM invocation to the console.

    Request details are printed in blue, streamed chunk text is written
    straight to stdout, the final result is printed in yellow, and failures
    are announced in red and handed to the module logger.
    """

    def on_before_invoke(
        self,
        llm_instance: AIModel,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: list[PromptMessageTool] | None = None,
        stop: Sequence[str] | None = None,
        stream: bool = True,
        user: str | None = None,
    ):
        """
        Print the request (model, parameters, tools, prompt messages) before the invocation.

        ``credentials`` is accepted but never printed.

        :param llm_instance: LLM instance
        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        """
        self.print_text("\n[on_llm_before_invoke]\n", color="blue")
        self.print_text(f"Model: {model}\n", color="blue")
        self.print_text("Parameters:\n", color="blue")
        for key, value in model_parameters.items():
            self.print_text(f"\t{key}: {value}\n", color="blue")

        if stop:
            self.print_text(f"\tstop: {stop}\n", color="blue")

        if tools:
            self.print_text("\tTools:\n", color="blue")
            for tool in tools:
                self.print_text(f"\t\t{tool.name}\n", color="blue")

        self.print_text(f"Stream: {stream}\n", color="blue")

        if user:
            self.print_text(f"User: {user}\n", color="blue")

        self.print_text("Prompt messages:\n", color="blue")
        for prompt_message in prompt_messages:
            if prompt_message.name:
                self.print_text(f"\tname: {prompt_message.name}\n", color="blue")

            self.print_text(f"\trole: {prompt_message.role.value}\n", color="blue")
            self.print_text(f"\tcontent: {prompt_message.content}\n", color="blue")

        if stream:
            # Header for the chunk text that on_new_chunk writes next.
            self.print_text("\n[on_llm_new_chunk]")

    def on_new_chunk(
        self,
        llm_instance: AIModel,
        chunk: LLMResultChunk,
        model: str,
        credentials: dict,
        prompt_messages: Sequence[PromptMessage],
        model_parameters: dict,
        tools: list[PromptMessageTool] | None = None,
        stop: Sequence[str] | None = None,
        stream: bool = True,
        user: str | None = None,
    ):
        """
        Write the chunk's message content to stdout and flush immediately.

        A chunk whose content is ``None`` is skipped; non-string content is
        written as ``str(content)``.

        :param llm_instance: LLM instance
        :param chunk: chunk
        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        """
        content = chunk.delta.message.content
        if content is None:
            return

        # typing.cast() does not convert at runtime, so a non-str value would make
        # sys.stdout.write raise TypeError; stringify it explicitly instead.
        sys.stdout.write(content if isinstance(content, str) else str(content))
        sys.stdout.flush()

    def on_after_invoke(
        self,
        llm_instance: AIModel,
        result: LLMResult,
        model: str,
        credentials: dict,
        prompt_messages: Sequence[PromptMessage],
        model_parameters: dict,
        tools: list[PromptMessageTool] | None = None,
        stop: Sequence[str] | None = None,
        stream: bool = True,
        user: str | None = None,
    ):
        """
        Print the final result: content, tool calls, model, usage and system fingerprint.

        :param llm_instance: LLM instance
        :param result: result
        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        """
        self.print_text("\n[on_llm_after_invoke]\n", color="yellow")
        self.print_text(f"Content: {result.message.content}\n", color="yellow")

        if result.message.tool_calls:
            self.print_text("Tool calls:\n", color="yellow")
            for tool_call in result.message.tool_calls:
                self.print_text(f"\t{tool_call.id}\n", color="yellow")
                self.print_text(f"\t{tool_call.function.name}\n", color="yellow")
                self.print_text(f"\t{json.dumps(tool_call.function.arguments)}\n", color="yellow")

        self.print_text(f"Model: {result.model}\n", color="yellow")
        self.print_text(f"Usage: {result.usage}\n", color="yellow")
        self.print_text(f"System Fingerprint: {result.system_fingerprint}\n", color="yellow")

    def on_invoke_error(
        self,
        llm_instance: AIModel,
        ex: Exception,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: list[PromptMessageTool] | None = None,
        stop: Sequence[str] | None = None,
        stream: bool = True,
        user: str | None = None,
    ):
        """
        Announce the failure on the console and log the exception with its traceback.

        :param llm_instance: LLM instance
        :param ex: exception
        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        """
        self.print_text("\n[on_llm_invoke_error]\n", color="red")
        # Bind the traceback to ``ex`` explicitly: Logger.exception otherwise uses the
        # exception currently being handled, and nothing here guarantees this hook is
        # called from inside an ``except`` block. The logged message is still str(ex).
        logger.exception("%s", ex, exc_info=ex)
|
||||
@ -1,43 +0,0 @@
|
||||
from .llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
from .message_entities import (
|
||||
AssistantPromptMessage,
|
||||
AudioPromptMessageContent,
|
||||
DocumentPromptMessageContent,
|
||||
ImagePromptMessageContent,
|
||||
MultiModalPromptMessageContent,
|
||||
PromptMessage,
|
||||
PromptMessageContent,
|
||||
PromptMessageContentType,
|
||||
PromptMessageRole,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
VideoPromptMessageContent,
|
||||
)
|
||||
from .model_entities import ModelPropertyKey
|
||||
|
||||
__all__ = [
|
||||
"AssistantPromptMessage",
|
||||
"AudioPromptMessageContent",
|
||||
"DocumentPromptMessageContent",
|
||||
"ImagePromptMessageContent",
|
||||
"LLMMode",
|
||||
"LLMResult",
|
||||
"LLMResultChunk",
|
||||
"LLMResultChunkDelta",
|
||||
"LLMUsage",
|
||||
"ModelPropertyKey",
|
||||
"MultiModalPromptMessageContent",
|
||||
"PromptMessage",
|
||||
"PromptMessageContent",
|
||||
"PromptMessageContentType",
|
||||
"PromptMessageRole",
|
||||
"PromptMessageTool",
|
||||
"SystemPromptMessage",
|
||||
"TextPromptMessageContent",
|
||||
"ToolPromptMessage",
|
||||
"UserPromptMessage",
|
||||
"VideoPromptMessageContent",
|
||||
]
|
||||
@ -1,16 +0,0 @@
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
|
||||
class I18nObject(BaseModel):
    """
    Localized text: a required English value and an optional Simplified Chinese value.
    """

    zh_Hans: str | None = None
    en_US: str

    @model_validator(mode="after")
    def _(self):
        # Keep a provided (non-empty) Chinese text as it is.
        if self.zh_Hans:
            return self
        # Otherwise fall back to the English text.
        self.zh_Hans = self.en_US
        return self
|
||||
@ -1,130 +0,0 @@
|
||||
from core.model_runtime.entities.model_entities import DefaultParameterName
|
||||
|
||||
# Default rule definition (label, type, help text and, where given, default/min/max/precision
# or options) for each well-known model parameter name.
PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
    # Sampling temperature.
    DefaultParameterName.TEMPERATURE: {
        "label": {
            "en_US": "Temperature",
            "zh_Hans": "温度",
        },
        "type": "float",
        "help": {
            "en_US": "Controls randomness. Lower temperature results in less random completions. "
            "As the temperature approaches zero, the model will become deterministic and repetitive. "
            "Higher temperature results in more random completions.",
            "zh_Hans": "温度控制随机性。较低的温度会导致较少的随机完成。"
            "随着温度接近零,模型将变得确定性和重复性。较高的温度会导致更多的随机完成。",
        },
        "required": False,
        "default": 0.0,
        "min": 0.0,
        "max": 1.0,
        "precision": 2,
    },
    # Nucleus sampling.
    DefaultParameterName.TOP_P: {
        "label": {
            "en_US": "Top P",
            "zh_Hans": "Top P",
        },
        "type": "float",
        "help": {
            "en_US": "Controls diversity via nucleus sampling: "
            "0.5 means half of all likelihood-weighted options are considered.",
            "zh_Hans": "通过核心采样控制多样性:0.5 表示考虑了一半的所有可能性加权选项。",
        },
        "required": False,
        "default": 1.0,
        "min": 0.0,
        "max": 1.0,
        "precision": 2,
    },
    # Top-k token filtering.
    DefaultParameterName.TOP_K: {
        "label": {
            "en_US": "Top K",
            "zh_Hans": "Top K",
        },
        "type": "int",
        "help": {
            "en_US": "Limits the number of tokens to consider for each step by keeping only the k most likely tokens.",
            "zh_Hans": "通过只保留每一步中最可能的 k 个标记来限制要考虑的标记数量。",
        },
        "required": False,
        "default": 50,
        "min": 1,
        "max": 100,
        "precision": 0,
    },
    # Penalty on tokens already present.
    DefaultParameterName.PRESENCE_PENALTY: {
        "label": {
            "en_US": "Presence Penalty",
            "zh_Hans": "存在惩罚",
        },
        "type": "float",
        "help": {
            "en_US": "Applies a penalty to the log-probability of tokens already in the text.",
            "zh_Hans": "对文本中已有的标记的对数概率施加惩罚。",
        },
        "required": False,
        "default": 0.0,
        "min": 0.0,
        "max": 1.0,
        "precision": 2,
    },
    # Penalty on tokens that appear in the text.
    DefaultParameterName.FREQUENCY_PENALTY: {
        "label": {
            "en_US": "Frequency Penalty",
            "zh_Hans": "频率惩罚",
        },
        "type": "float",
        "help": {
            "en_US": "Applies a penalty to the log-probability of tokens that appear in the text.",
            "zh_Hans": "对文本中出现的标记的对数概率施加惩罚。",
        },
        "required": False,
        "default": 0.0,
        "min": 0.0,
        "max": 1.0,
        "precision": 2,
    },
    # Upper bound on generated length.
    DefaultParameterName.MAX_TOKENS: {
        "label": {
            "en_US": "Max Tokens",
            "zh_Hans": "最大 Token 数",
        },
        "type": "int",
        "help": {
            "en_US": "Specifies the upper limit on the length of generated results. "
            "If the generated results are truncated, you can increase this parameter.",
            "zh_Hans": "指定生成结果长度的上限。如果生成结果截断,可以调大该参数。",
        },
        "required": False,
        "default": 64,
        "min": 1,
        "max": 2048,
        "precision": 0,
    },
    # Output format selector; no default/min/max, only a list of options.
    DefaultParameterName.RESPONSE_FORMAT: {
        "label": {
            "en_US": "Response Format",
            "zh_Hans": "回复格式",
        },
        "type": "string",
        "help": {
            "en_US": "Set a response format, ensure the output from llm is a valid code block as possible, "
            "such as JSON, XML, etc.",
            "zh_Hans": "设置一个返回格式,确保 llm 的输出尽可能是有效的代码块,如 JSON、XML 等",
        },
        "required": False,
        "options": ["JSON", "XML"],
    },
    # JSON schema for the response; the label has an English entry only.
    DefaultParameterName.JSON_SCHEMA: {
        "label": {
            "en_US": "JSON Schema",
        },
        "type": "text",
        "help": {
            "en_US": "Set a response json schema will ensure LLM to adhere it.",
            "zh_Hans": "设置返回的 json schema,llm 将按照它返回",
        },
        "required": False,
    },
}
|
||||
@ -1,219 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping, Sequence
|
||||
from decimal import Decimal
|
||||
from enum import StrEnum
|
||||
from typing import Any, TypedDict, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage
|
||||
from core.model_runtime.entities.model_entities import ModelUsage, PriceInfo
|
||||
|
||||
|
||||
class LLMMode(StrEnum):
|
||||
"""
|
||||
Enum class for large language model mode.
|
||||
"""
|
||||
|
||||
COMPLETION = "completion"
|
||||
CHAT = "chat"
|
||||
|
||||
|
||||
class LLMUsageMetadata(TypedDict, total=False):
    """
    Dictionary shape accepted by LLMUsage.from_metadata.

    Declared with total=False, so every key may be omitted.
    """

    # Token counts.
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
    # Prices and currency; price values may be given as numbers or strings.
    prompt_unit_price: Union[float, str]
    completion_unit_price: Union[float, str]
    total_price: Union[float, str]
    currency: str
    prompt_price_unit: Union[float, str]
    completion_price_unit: Union[float, str]
    prompt_price: Union[float, str]
    completion_price: Union[float, str]
    # Timing values.
    latency: float
    time_to_first_token: float
    time_to_generate: float
|
||||
|
||||
|
||||
class LLMUsage(ModelUsage):
    """
    Token counts, prices and timing values recorded for LLM invocations.
    """

    prompt_tokens: int
    prompt_unit_price: Decimal
    prompt_price_unit: Decimal
    prompt_price: Decimal
    completion_tokens: int
    completion_unit_price: Decimal
    completion_price_unit: Decimal
    completion_price: Decimal
    total_tokens: int
    total_price: Decimal
    currency: str
    latency: float
    time_to_first_token: float | None = None
    time_to_generate: float | None = None

    @classmethod
    def empty_usage(cls):
        """
        Build a usage with zero tokens, zero prices, currency "USD", zero latency
        and no time_to_first_token / time_to_generate.
        """
        zero = Decimal("0.0")
        return cls(
            prompt_tokens=0,
            prompt_unit_price=zero,
            prompt_price_unit=zero,
            prompt_price=zero,
            completion_tokens=0,
            completion_unit_price=zero,
            completion_price_unit=zero,
            completion_price=zero,
            total_tokens=0,
            total_price=zero,
            currency="USD",
            latency=0.0,
            time_to_first_token=None,
            time_to_generate=None,
        )

    @classmethod
    def from_metadata(cls, metadata: LLMUsageMetadata) -> LLMUsage:
        """
        Create LLMUsage instance from metadata dictionary with default values.

        Missing token counts and prices default to 0, currency to "USD", latency to 0.0
        and the two time_to_* values to None.

        Args:
            metadata: TypedDict containing usage metadata

        Returns:
            LLMUsage instance with values from metadata or defaults
        """

        def as_decimal(key: str) -> Decimal:
            # Convert through str(), e.g. Decimal(str(0.1)) == Decimal("0.1").
            return Decimal(str(metadata.get(key, 0)))

        prompt_count = metadata.get("prompt_tokens", 0)
        completion_count = metadata.get("completion_tokens", 0)
        total_count = metadata.get("total_tokens", 0)

        # Derive the total from its parts when only the parts were reported.
        if total_count == 0 and (prompt_count > 0 or completion_count > 0):
            total_count = prompt_count + completion_count

        return cls(
            prompt_tokens=prompt_count,
            completion_tokens=completion_count,
            total_tokens=total_count,
            prompt_unit_price=as_decimal("prompt_unit_price"),
            completion_unit_price=as_decimal("completion_unit_price"),
            total_price=as_decimal("total_price"),
            currency=metadata.get("currency", "USD"),
            prompt_price_unit=as_decimal("prompt_price_unit"),
            completion_price_unit=as_decimal("completion_price_unit"),
            prompt_price=as_decimal("prompt_price"),
            completion_price=as_decimal("completion_price"),
            latency=metadata.get("latency", 0.0),
            time_to_first_token=metadata.get("time_to_first_token"),
            time_to_generate=metadata.get("time_to_generate"),
        )

    def plus(self, other: LLMUsage) -> LLMUsage:
        """
        Add two LLMUsage instances together.

        When this usage has zero total tokens, ``other`` itself is returned. Otherwise
        token counts, prices and latency are summed, while unit prices, price units,
        currency and the time_to_* values are taken from ``other``.

        :param other: Another LLMUsage instance to add
        :return: ``other`` or a new LLMUsage instance with summed values
        """
        if self.total_tokens == 0:
            return other

        return LLMUsage(
            # Summed values.
            prompt_tokens=self.prompt_tokens + other.prompt_tokens,
            completion_tokens=self.completion_tokens + other.completion_tokens,
            total_tokens=self.total_tokens + other.total_tokens,
            prompt_price=self.prompt_price + other.prompt_price,
            completion_price=self.completion_price + other.completion_price,
            total_price=self.total_price + other.total_price,
            latency=self.latency + other.latency,
            # Values copied from the right-hand operand.
            prompt_unit_price=other.prompt_unit_price,
            prompt_price_unit=other.prompt_price_unit,
            completion_unit_price=other.completion_unit_price,
            completion_price_unit=other.completion_price_unit,
            currency=other.currency,
            time_to_first_token=other.time_to_first_token,
            time_to_generate=other.time_to_generate,
        )

    def __add__(self, other: LLMUsage) -> LLMUsage:
        """
        Support ``usage_a + usage_b`` by delegating to plus().

        :param other: Another LLMUsage instance to add
        :return: the result of ``self.plus(other)``
        """
        return self.plus(other)
|
||||
|
||||
|
||||
class LLMResult(BaseModel):
    """
    Result of an LLM invocation: the assistant message together with its usage.
    """

    id: str | None = None
    model: str
    # Each instance gets its own empty list by default.
    prompt_messages: Sequence[PromptMessage] = Field(default_factory=list)
    message: AssistantPromptMessage
    usage: LLMUsage
    system_fingerprint: str | None = None
    reasoning_content: str | None = None
|
||||
|
||||
|
||||
class LLMStructuredOutput(BaseModel):
    """
    Holder for an optional structured-output mapping.
    """

    # None when no structured output is attached.
    structured_output: Mapping[str, Any] | None = None
|
||||
|
||||
|
||||
class LLMResultWithStructuredOutput(LLMResult, LLMStructuredOutput):
    """
    LLMResult that additionally carries the structured_output field.
    """
|
||||
|
||||
|
||||
class LLMResultChunkDelta(BaseModel):
    """
    Delta payload of one LLM result chunk.
    """

    index: int
    message: AssistantPromptMessage
    # Both are optional and default to None.
    usage: LLMUsage | None = None
    finish_reason: str | None = None
|
||||
|
||||
|
||||
class LLMResultChunk(BaseModel):
    """
    One chunk of an LLM result, wrapping a LLMResultChunkDelta.
    """

    model: str
    # Each instance gets its own empty list by default.
    prompt_messages: Sequence[PromptMessage] = Field(default_factory=list)
    system_fingerprint: str | None = None
    delta: LLMResultChunkDelta
|
||||
|
||||
|
||||
class LLMResultChunkWithStructuredOutput(LLMResultChunk, LLMStructuredOutput):
    """
    LLMResultChunk that additionally carries the structured_output field.
    """
|
||||
|
||||
|
||||
class NumTokensResult(PriceInfo):
    """
    Price info extended with a token count.
    """

    tokens: int
|
||||
@ -1,283 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC
|
||||
from collections.abc import Mapping, Sequence
|
||||
from enum import StrEnum, auto
|
||||
from typing import Annotated, Any, Literal, Union
|
||||
|
||||
from pydantic import BaseModel, Field, field_serializer, field_validator
|
||||
|
||||
|
||||
class PromptMessageRole(StrEnum):
|
||||
"""
|
||||
Enum class for prompt message.
|
||||
"""
|
||||
|
||||
SYSTEM = auto()
|
||||
USER = auto()
|
||||
ASSISTANT = auto()
|
||||
TOOL = auto()
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> PromptMessageRole:
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
:param value: mode value
|
||||
:return: mode
|
||||
"""
|
||||
for mode in cls:
|
||||
if mode.value == value:
|
||||
return mode
|
||||
raise ValueError(f"invalid prompt message type value {value}")
|
||||
|
||||
|
||||
class PromptMessageTool(BaseModel):
    """
    Tool definition passed along with a prompt: name, description and parameters.
    """

    name: str
    description: str
    # Untyped dict; its structure is not constrained here.
    parameters: dict
|
||||
|
||||
|
||||
class PromptMessageFunction(BaseModel):
    """
    Wrapper pairing a PromptMessageTool with a type tag.
    """

    # Defaults to "function".
    type: str = "function"
    function: PromptMessageTool
|
||||
|
||||
|
||||
class PromptMessageContentType(StrEnum):
|
||||
"""
|
||||
Enum class for prompt message content type.
|
||||
"""
|
||||
|
||||
TEXT = auto()
|
||||
IMAGE = auto()
|
||||
AUDIO = auto()
|
||||
VIDEO = auto()
|
||||
DOCUMENT = auto()
|
||||
|
||||
|
||||
class PromptMessageContent(ABC, BaseModel):
    """
    Base model for a prompt message content part, identified by its type.
    """

    type: PromptMessageContentType
|
||||
|
||||
|
||||
class TextPromptMessageContent(PromptMessageContent):
    """
    Text part of a prompt message.
    """

    # Fixed to the TEXT content type.
    type: Literal[PromptMessageContentType.TEXT] = PromptMessageContentType.TEXT  # type: ignore
    data: str
|
||||
|
||||
|
||||
class MultiModalPromptMessageContent(PromptMessageContent):
    """
    Base model for file-backed content parts, given as a URL or as base64 data.
    """

    # format and mime_type are required; the other string fields default to "".
    format: str = Field(default=..., description="the format of multi-modal file")
    base64_data: str = Field(default="", description="the base64 data of multi-modal file")
    url: str = Field(default="", description="the url of multi-modal file")
    mime_type: str = Field(default=..., description="the mime type of multi-modal file")
    filename: str = Field(default="", description="the filename of multi-modal file")

    # File reference for context restoration, format: "transfer_method:related_id" or "remote:url"
    file_ref: str | None = Field(default=None, description="Encoded file reference for restoration")

    @property
    def data(self):
        """Return the url when it is non-empty, otherwise a base64 ``data:`` URI."""
        if self.url:
            return self.url
        return f"data:{self.mime_type};base64,{self.base64_data}"
|
||||
|
||||
|
||||
class VideoPromptMessageContent(MultiModalPromptMessageContent):
    # Multi-modal content part fixed to the VIDEO content type.
    type: Literal[PromptMessageContentType.VIDEO] = PromptMessageContentType.VIDEO  # type: ignore
|
||||
|
||||
|
||||
class AudioPromptMessageContent(MultiModalPromptMessageContent):
    # Multi-modal content part fixed to the AUDIO content type.
    type: Literal[PromptMessageContentType.AUDIO] = PromptMessageContentType.AUDIO  # type: ignore
|
||||
|
||||
|
||||
class ImagePromptMessageContent(MultiModalPromptMessageContent):
    """
    Image part of a prompt message, with a detail level.
    """

    class DETAIL(StrEnum):
        # auto() on a StrEnum yields the lower-cased member name: "low" / "high".
        LOW = auto()
        HIGH = auto()

    # Fixed to the IMAGE content type.
    type: Literal[PromptMessageContentType.IMAGE] = PromptMessageContentType.IMAGE  # type: ignore
    # Defaults to the low detail level.
    detail: DETAIL = DETAIL.LOW
|
||||
|
||||
|
||||
class DocumentPromptMessageContent(MultiModalPromptMessageContent):
    # Multi-modal content part fixed to the DOCUMENT content type.
    type: Literal[PromptMessageContentType.DOCUMENT] = PromptMessageContentType.DOCUMENT  # type: ignore
|
||||
|
||||
|
||||
# Union of the concrete content part models, discriminated by their "type" field.
PromptMessageContentUnionTypes = Annotated[
    Union[
        TextPromptMessageContent,
        ImagePromptMessageContent,
        DocumentPromptMessageContent,
        AudioPromptMessageContent,
        VideoPromptMessageContent,
    ],
    Field(discriminator="type"),
]
|
||||
|
||||
|
||||
# Content type -> concrete content model class; one entry per PromptMessageContentType member.
CONTENT_TYPE_MAPPING: Mapping[PromptMessageContentType, type[PromptMessageContent]] = {
    PromptMessageContentType.TEXT: TextPromptMessageContent,
    PromptMessageContentType.IMAGE: ImagePromptMessageContent,
    PromptMessageContentType.AUDIO: AudioPromptMessageContent,
    PromptMessageContentType.VIDEO: VideoPromptMessageContent,
    PromptMessageContentType.DOCUMENT: DocumentPromptMessageContent,
}
|
||||
|
||||
|
||||
class PromptMessage(ABC, BaseModel):
    """
    Base model for one prompt message: a role plus string or multi-part content.
    """

    role: PromptMessageRole
    content: str | list[PromptMessageContentUnionTypes] | None = None
    name: str | None = None

    def is_empty(self) -> bool:
        """
        Check if prompt message is empty.

        :return: True when content is None, an empty string or an empty list
        """
        return not self.content

    def get_text_content(self) -> str:
        """
        Get text content from prompt message.

        :return: the string content itself, or the joined data of all text parts of a
            list content; empty string when content is None
        """
        if isinstance(self.content, str):
            return self.content
        if isinstance(self.content, list):
            return "".join(item.data for item in self.content if isinstance(item, TextPromptMessageContent))
        return ""

    @field_validator("content", mode="before")
    @classmethod
    def validate_content(cls, v):
        """
        Turn the items of a list content into concrete content models.

        :param v: raw content value
        :return: ``v`` unchanged unless it is a list; otherwise a new list of content models
        :raises ValueError: if an item is neither a content model nor a dict, or a dict
            item has a missing or unknown "type"
        """
        if not isinstance(v, list):
            return v

        prompts = []
        for prompt in v:
            if isinstance(prompt, PromptMessageContent):
                # Instances of neither concrete family are rebuilt as the class mapped to their type.
                if not isinstance(prompt, TextPromptMessageContent | MultiModalPromptMessageContent):
                    prompt = CONTENT_TYPE_MAPPING[prompt.type].model_validate(prompt.model_dump())
            elif isinstance(prompt, dict):
                content_type = prompt.get("type")
                try:
                    content_cls = CONTENT_TYPE_MAPPING[content_type]
                except (KeyError, TypeError):
                    # Raise ValueError so pydantic reports this as a validation error
                    # instead of leaking a bare KeyError/TypeError to the caller.
                    raise ValueError(f"invalid prompt message content type {content_type!r}") from None
                prompt = content_cls.model_validate(prompt)
            else:
                raise ValueError(f"invalid prompt message {prompt}")
            prompts.append(prompt)
        return prompts

    @field_serializer("content")
    def serialize_content(
        self, content: Union[str, Sequence[PromptMessageContent]] | None
    ) -> str | list[dict[str, Any] | PromptMessageContent] | Sequence[PromptMessageContent] | None:
        """
        Serialize content: list items exposing ``model_dump`` are dumped to dicts,
        every other value is returned unchanged.
        """
        if content is None or isinstance(content, str):
            return content
        if isinstance(content, list):
            return [item.model_dump() if hasattr(item, "model_dump") else item for item in content]
        return content
|
||||
|
||||
|
||||
class UserPromptMessage(PromptMessage):
    """
    Prompt message whose role defaults to USER.
    """

    role: PromptMessageRole = PromptMessageRole.USER
|
||||
|
||||
|
||||
class AssistantPromptMessage(PromptMessage):
    """
    Prompt message whose role defaults to ASSISTANT and which may carry tool calls.
    """

    class ToolCall(BaseModel):
        """
        One tool call of an assistant message.
        """

        class ToolCallFunction(BaseModel):
            """
            Called function: its name and its arguments as a string.
            """

            name: str
            arguments: str

        id: str
        type: str
        function: ToolCallFunction

        @field_validator("id", mode="before")
        @classmethod
        def transform_id_to_str(cls, value) -> str:
            # Non-string ids are converted with str() before validation.
            return value if isinstance(value, str) else str(value)

    role: PromptMessageRole = PromptMessageRole.ASSISTANT
    tool_calls: list[ToolCall] = []

    def is_empty(self) -> bool:
        """
        Check if prompt message is empty.

        :return: True only when there are no tool calls and the content is empty
        """
        if self.tool_calls:
            return False
        return super().is_empty()
|
||||
|
||||
|
||||
class SystemPromptMessage(PromptMessage):
    """
    Prompt message whose role defaults to SYSTEM.
    """

    role: PromptMessageRole = PromptMessageRole.SYSTEM
|
||||
|
||||
|
||||
class ToolPromptMessage(PromptMessage):
    """
    Prompt message whose role defaults to TOOL, tied to a tool call id.
    """

    role: PromptMessageRole = PromptMessageRole.TOOL
    tool_call_id: str

    def is_empty(self) -> bool:
        """
        Check if prompt message is empty.

        :return: True only when both the content and the tool_call_id are empty
        """
        # A non-empty tool_call_id alone makes the message non-empty.
        if self.tool_call_id:
            return False
        return super().is_empty()
|
||||
@ -1,242 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from decimal import Decimal
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, model_validator
|
||||
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
|
||||
|
||||
class ModelType(StrEnum):
|
||||
"""
|
||||
Enum class for model type.
|
||||
"""
|
||||
|
||||
LLM = auto()
|
||||
TEXT_EMBEDDING = "text-embedding"
|
||||
RERANK = auto()
|
||||
SPEECH2TEXT = auto()
|
||||
MODERATION = auto()
|
||||
TTS = auto()
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, origin_model_type: str) -> ModelType:
|
||||
"""
|
||||
Get model type from origin model type.
|
||||
|
||||
:return: model type
|
||||
"""
|
||||
if origin_model_type in {"text-generation", cls.LLM}:
|
||||
return cls.LLM
|
||||
elif origin_model_type in {"embeddings", cls.TEXT_EMBEDDING}:
|
||||
return cls.TEXT_EMBEDDING
|
||||
elif origin_model_type in {"reranking", cls.RERANK}:
|
||||
return cls.RERANK
|
||||
elif origin_model_type in {"speech2text", cls.SPEECH2TEXT}:
|
||||
return cls.SPEECH2TEXT
|
||||
elif origin_model_type in {"tts", cls.TTS}:
|
||||
return cls.TTS
|
||||
elif origin_model_type == cls.MODERATION:
|
||||
return cls.MODERATION
|
||||
else:
|
||||
raise ValueError(f"invalid origin model type {origin_model_type}")
|
||||
|
||||
def to_origin_model_type(self) -> str:
|
||||
"""
|
||||
Get origin model type from model type.
|
||||
|
||||
:return: origin model type
|
||||
"""
|
||||
if self == self.LLM:
|
||||
return "text-generation"
|
||||
elif self == self.TEXT_EMBEDDING:
|
||||
return "embeddings"
|
||||
elif self == self.RERANK:
|
||||
return "reranking"
|
||||
elif self == self.SPEECH2TEXT:
|
||||
return "speech2text"
|
||||
elif self == self.TTS:
|
||||
return "tts"
|
||||
elif self == self.MODERATION:
|
||||
return "moderation"
|
||||
else:
|
||||
raise ValueError(f"invalid model type {self}")
|
||||
|
||||
|
||||
class FetchFrom(StrEnum):
|
||||
"""
|
||||
Enum class for fetch from.
|
||||
"""
|
||||
|
||||
PREDEFINED_MODEL = "predefined-model"
|
||||
CUSTOMIZABLE_MODEL = "customizable-model"
|
||||
|
||||
|
||||
class ModelFeature(StrEnum):
|
||||
"""
|
||||
Enum class for llm feature.
|
||||
"""
|
||||
|
||||
TOOL_CALL = "tool-call"
|
||||
MULTI_TOOL_CALL = "multi-tool-call"
|
||||
AGENT_THOUGHT = "agent-thought"
|
||||
VISION = auto()
|
||||
STREAM_TOOL_CALL = "stream-tool-call"
|
||||
DOCUMENT = auto()
|
||||
VIDEO = auto()
|
||||
AUDIO = auto()
|
||||
STRUCTURED_OUTPUT = "structured-output"
|
||||
|
||||
|
||||
class DefaultParameterName(StrEnum):
|
||||
"""
|
||||
Enum class for parameter template variable.
|
||||
"""
|
||||
|
||||
TEMPERATURE = auto()
|
||||
TOP_P = auto()
|
||||
TOP_K = auto()
|
||||
PRESENCE_PENALTY = auto()
|
||||
FREQUENCY_PENALTY = auto()
|
||||
MAX_TOKENS = auto()
|
||||
RESPONSE_FORMAT = auto()
|
||||
JSON_SCHEMA = auto()
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: Any) -> DefaultParameterName:
|
||||
"""
|
||||
Get parameter name from value.
|
||||
|
||||
:param value: parameter value
|
||||
:return: parameter name
|
||||
"""
|
||||
for name in cls:
|
||||
if name.value == value:
|
||||
return name
|
||||
raise ValueError(f"invalid parameter name {value}")
|
||||
|
||||
|
||||
class ParameterType(StrEnum):
|
||||
"""
|
||||
Enum class for parameter type.
|
||||
"""
|
||||
|
||||
FLOAT = auto()
|
||||
INT = auto()
|
||||
STRING = auto()
|
||||
BOOLEAN = auto()
|
||||
TEXT = auto()
|
||||
|
||||
|
||||
class ModelPropertyKey(StrEnum):
|
||||
"""
|
||||
Enum class for model property key.
|
||||
"""
|
||||
|
||||
MODE = auto()
|
||||
CONTEXT_SIZE = auto()
|
||||
MAX_CHUNKS = auto()
|
||||
FILE_UPLOAD_LIMIT = auto()
|
||||
SUPPORTED_FILE_EXTENSIONS = auto()
|
||||
MAX_CHARACTERS_PER_CHUNK = auto()
|
||||
DEFAULT_VOICE = auto()
|
||||
VOICES = auto()
|
||||
WORD_LIMIT = auto()
|
||||
AUDIO_TYPE = auto()
|
||||
MAX_WORKERS = auto()
|
||||
|
||||
|
||||
class ProviderModel(BaseModel):
    """
    Description of one model offered by a provider.
    """

    # Empty protected_namespaces; presumably needed because fields such as
    # model_type start with pydantic's protected "model_" prefix.
    model_config = ConfigDict(protected_namespaces=())

    model: str
    label: I18nObject
    model_type: ModelType
    features: list[ModelFeature] | None = None
    fetch_from: FetchFrom
    model_properties: dict[ModelPropertyKey, Any]
    deprecated: bool = False

    @property
    def support_structure_output(self) -> bool:
        """True when features is set and contains STRUCTURED_OUTPUT."""
        if self.features is None:
            return False
        return ModelFeature.STRUCTURED_OUTPUT in self.features
|
||||
|
||||
|
||||
class ParameterRule(BaseModel):
    """
    Rule describing one configurable model parameter.
    """

    name: str
    use_template: str | None = None
    label: I18nObject
    type: ParameterType
    help: I18nObject | None = None
    required: bool = False
    default: Any | None = None
    # Optional numeric bounds and precision.
    min: float | None = None
    max: float | None = None
    precision: int | None = None
    # Allowed values; empty by default.
    options: list[str] = []
|
||||
|
||||
|
||||
class PriceConfig(BaseModel):
    """
    Pricing configuration of a model.
    """

    input: Decimal
    # Optional; defaults to None.
    output: Decimal | None = None
    unit: Decimal
    currency: str
|
||||
|
||||
|
||||
class AIModelEntity(ProviderModel):
    """
    Provider model extended with parameter rules and optional pricing.
    """

    parameter_rules: list[ParameterRule] = []
    pricing: PriceConfig | None = None

    @model_validator(mode="after")
    def validate_model(self):
        """
        Add the STRUCTURED_OUTPUT feature when a "json_schema" parameter rule is present.
        """
        supported_schema_keys = ["json_schema"]
        if not any(rule.name in supported_schema_keys for rule in self.parameter_rules):
            return self

        if self.features is None:
            self.features = [ModelFeature.STRUCTURED_OUTPUT]
        elif ModelFeature.STRUCTURED_OUTPUT not in self.features:
            self.features.append(ModelFeature.STRUCTURED_OUTPUT)
        return self
|
||||
|
||||
|
||||
class ModelUsage(BaseModel):
    # Field-less base class for usage models.
    pass
|
||||
|
||||
|
||||
class PriceType(StrEnum):
|
||||
"""
|
||||
Enum class for price type.
|
||||
"""
|
||||
|
||||
INPUT = auto()
|
||||
OUTPUT = auto()
|
||||
|
||||
|
||||
class PriceInfo(BaseModel):
    """
    Price figures: unit price, unit, total amount and currency.
    """

    unit_price: Decimal
    unit: Decimal
    total_amount: Decimal
    currency: str
|
||||
@ -1,169 +0,0 @@
|
||||
from collections.abc import Sequence
|
||||
from enum import StrEnum, auto
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
||||
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||
|
||||
|
||||
class ConfigurateMethod(StrEnum):
|
||||
"""
|
||||
Enum class for configurate method of provider model.
|
||||
"""
|
||||
|
||||
PREDEFINED_MODEL = "predefined-model"
|
||||
CUSTOMIZABLE_MODEL = "customizable-model"
|
||||
|
||||
|
||||
class FormType(StrEnum):
|
||||
"""
|
||||
Enum class for form type.
|
||||
"""
|
||||
|
||||
TEXT_INPUT = "text-input"
|
||||
SECRET_INPUT = "secret-input"
|
||||
SELECT = auto()
|
||||
RADIO = auto()
|
||||
SWITCH = auto()
|
||||
|
||||
|
||||
class FormShowOnObject(BaseModel):
    """
    A variable/value pair used in the show_on lists of form models.
    """

    variable: str
    value: str
|
||||
|
||||
|
||||
class FormOption(BaseModel):
    """
    One selectable option of a form field.
    """

    label: I18nObject
    value: str
    show_on: list[FormShowOnObject] = []

    @model_validator(mode="after")
    def _(self):
        # Use the raw value as English label when the label is falsy.
        # NOTE(review): label is a required I18nObject, so this fallback may never run — confirm.
        if self.label:
            return self
        self.label = I18nObject(en_US=self.value)
        return self
|
||||
|
||||
|
||||
class CredentialFormSchema(BaseModel):
    """
    Schema of one field of a credential form.
    """

    variable: str
    label: I18nObject
    type: FormType
    # Fields are required unless stated otherwise.
    required: bool = True
    default: str | None = None
    options: list[FormOption] | None = None
    placeholder: I18nObject | None = None
    max_length: int = 0
    show_on: list[FormShowOnObject] = []
|
||||
|
||||
|
||||
class ProviderCredentialSchema(BaseModel):
    """
    Credential form of a provider, as a list of field schemas.
    """

    credential_form_schemas: list[CredentialFormSchema]
|
||||
|
||||
|
||||
class FieldModelSchema(BaseModel):
    # Label and optional placeholder; used as the `model` field of ModelCredentialSchema.
    label: I18nObject
    placeholder: I18nObject | None = None
|
||||
|
||||
|
||||
class ModelCredentialSchema(BaseModel):
    """
    Credential form of a model: the model field plus a list of field schemas.
    """

    model: FieldModelSchema
    credential_form_schemas: list[CredentialFormSchema]
|
||||
|
||||
|
||||
class SimpleProviderEntity(BaseModel):
    """
    Reduced provider description: identity, icons, supported model types and models.
    """

    provider: str
    label: I18nObject
    # Both icons are optional.
    icon_small: I18nObject | None = None
    icon_small_dark: I18nObject | None = None
    supported_model_types: Sequence[ModelType]
    models: list[AIModelEntity] = []
|
||||
|
||||
|
||||
class ProviderHelpEntity(BaseModel):
    """
    Help entry of a provider: a localized title and a localized url.
    """

    title: I18nObject
    url: I18nObject
|
||||
|
||||
|
||||
class ProviderEntity(BaseModel):
    """
    Full description of a model provider.
    """

    provider: str
    label: I18nObject
    description: I18nObject | None = None
    icon_small: I18nObject | None = None
    icon_small_dark: I18nObject | None = None
    background: str | None = None
    help: ProviderHelpEntity | None = None
    supported_model_types: Sequence[ModelType]
    configurate_methods: list[ConfigurateMethod]
    models: list[AIModelEntity] = Field(default_factory=list)
    provider_credential_schema: ProviderCredentialSchema | None = None
    model_credential_schema: ModelCredentialSchema | None = None

    # pydantic configs
    model_config = ConfigDict(protected_namespaces=())

    # position from plugin _position.yaml
    position: dict[str, list[str]] | None = {}

    @field_validator("models", mode="before")
    @classmethod
    def validate_models(cls, v):
        """
        Normalize a falsy ``models`` input (e.g. None) to an empty list.

        :param v: raw models value
        :return: ``[]`` when v is falsy, otherwise v unchanged
        """
        if not v:
            return []
        return v

    def to_simple_provider(self) -> SimpleProviderEntity:
        """
        Convert to simple provider.

        Copies provider, label, both small icons, supported model types and models.

        :return: simple provider
        """
        return SimpleProviderEntity(
            provider=self.provider,
            label=self.label,
            icon_small=self.icon_small,
            # Carry the dark icon over too; SimpleProviderEntity declares the same field.
            icon_small_dark=self.icon_small_dark,
            supported_model_types=self.supported_model_types,
            models=self.models,
        )
|
||||
|
||||
|
||||
class ProviderConfig(BaseModel):
    """
    A provider name together with its credentials.
    """

    provider: str
    # Untyped dict; its keys are not constrained here.
    credentials: dict
|
||||
@ -1,20 +0,0 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class RerankDocument(BaseModel):
    """
    One document of a rerank result: index, text and score.
    """

    index: int
    text: str
    score: float
|
||||
|
||||
|
||||
class RerankResult(BaseModel):
    """
    Result of a rerank call: the model name and the list of documents.
    """

    model: str
    docs: list[RerankDocument]
|
||||
@ -1,39 +0,0 @@
|
||||
from decimal import Decimal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelUsage
|
||||
|
||||
|
||||
class EmbeddingUsage(ModelUsage):
    """
    Model class for embedding usage.

    Records token counts, pricing figures, currency and latency for an embedding call.
    """

    # Token counts for the call.
    tokens: int
    total_tokens: int
    # Pricing figures, kept as Decimal.
    unit_price: Decimal
    price_unit: Decimal
    total_price: Decimal
    # Currency of the price fields.
    currency: str
    # Latency of the call; the unit is not stated here (presumably seconds -- confirm).
    latency: float
|
||||
|
||||
|
||||
class EmbeddingResult(BaseModel):
    """
    Model class for text embedding result.

    Bundles the model name, the embedding vectors and the usage record.
    """

    # Name of the model that produced the embeddings.
    model: str
    # One list of floats per embedding vector.
    embeddings: list[list[float]]
    # Usage information for the call.
    usage: EmbeddingUsage
|
||||
|
||||
|
||||
class FileEmbeddingResult(BaseModel):
    """
    Model class for file embedding result.

    Has the same fields as ``EmbeddingResult``: model name, embedding vectors and usage.
    """

    # Name of the model that produced the embeddings.
    model: str
    # One list of floats per embedding vector.
    embeddings: list[list[float]]
    # Usage information for the call.
    usage: EmbeddingUsage
|
||||
@ -1,40 +0,0 @@
|
||||
class InvokeError(ValueError):
|
||||
"""Base class for all LLM exceptions."""
|
||||
|
||||
description: str | None = None
|
||||
|
||||
def __init__(self, description: str | None = None):
|
||||
self.description = description
|
||||
|
||||
def __str__(self):
|
||||
return self.description or self.__class__.__name__
|
||||
|
||||
|
||||
class InvokeConnectionError(InvokeError):
    """Raised when the Invoke returns connection error."""

    # Class-level description text for this error type.
    description = "Connection Error"
|
||||
|
||||
|
||||
class InvokeServerUnavailableError(InvokeError):
    """Raised when the Invoke returns server unavailable error."""

    # Class-level description text for this error type.
    description = "Server Unavailable Error"
|
||||
|
||||
|
||||
class InvokeRateLimitError(InvokeError):
    """Raised when the Invoke returns rate limit error."""

    # Class-level description text for this error type.
    description = "Rate Limit Error"
|
||||
|
||||
|
||||
class InvokeAuthorizationError(InvokeError):
    """Raised when the Invoke returns authorization error."""

    # Class-level description text for this error type (note the trailing space in the literal).
    description = "Incorrect model credentials provided, please check and try again. "
|
||||
|
||||
|
||||
class InvokeBadRequestError(InvokeError):
    """Raised when the Invoke returns bad request."""

    # Class-level description text for this error type.
    description = "Bad Request Error"
|
||||
@ -1,6 +0,0 @@
|
||||
class CredentialsValidateFailedError(ValueError):
    """Error raised when credentials validation fails; adds nothing on top of ``ValueError``."""
|
||||
@ -1,3 +0,0 @@
|
||||
from .prompt_message_memory import DEFAULT_MEMORY_MAX_TOKEN_LIMIT, PromptMessageMemory
|
||||
|
||||
__all__ = ["DEFAULT_MEMORY_MAX_TOKEN_LIMIT", "PromptMessageMemory"]
|
||||
@ -1,18 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Protocol
|
||||
|
||||
from core.model_runtime.entities import PromptMessage
|
||||
|
||||
DEFAULT_MEMORY_MAX_TOKEN_LIMIT = 2000
|
||||
|
||||
|
||||
class PromptMessageMemory(Protocol):
    """Port for loading memory as prompt messages."""

    def get_history_prompt_messages(
        self, max_token_limit: int = DEFAULT_MEMORY_MAX_TOKEN_LIMIT, message_limit: int | None = None
    ) -> Sequence[PromptMessage]:
        """
        Return historical prompt messages constrained by token/message limits.

        :param max_token_limit: token limit for the returned history; defaults to
            ``DEFAULT_MEMORY_MAX_TOKEN_LIMIT``
        :param message_limit: optional message limit; ``None`` by default
        :return: sequence of prompt messages
        """
        ...
|
||||
@ -1,286 +0,0 @@
|
||||
import decimal
|
||||
import hashlib
|
||||
import logging
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationError
|
||||
from redis import RedisError
|
||||
|
||||
from configs import dify_config
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
|
||||
from core.model_runtime.entities.model_entities import (
|
||||
AIModelEntity,
|
||||
DefaultParameterName,
|
||||
ModelType,
|
||||
PriceConfig,
|
||||
PriceInfo,
|
||||
PriceType,
|
||||
)
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AIModel(BaseModel):
    """
    Base class for all models.

    Holds the tenant / plugin / provider identity of a model and offers shared
    helpers: invoke-error normalization, price calculation and model schema
    lookup (fetched through the plugin client and cached in Redis).
    """

    tenant_id: str = Field(description="Tenant ID")
    model_type: ModelType = Field(description="Model type")
    plugin_id: str = Field(description="Plugin ID")
    provider_name: str = Field(description="Provider")
    plugin_model_provider: PluginModelProviderEntity = Field(description="Plugin model provider")
    started_at: float = Field(description="Invoke start time", default=0)

    # pydantic configs
    model_config = ConfigDict(protected_namespaces=())

    @property
    def _invoke_error_mapping(self) -> dict[type[Exception], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.

        :return: Invoke error mapping
        """
        from core.plugin.entities.plugin_daemon import PluginDaemonInnerError

        return {
            InvokeConnectionError: [InvokeConnectionError],
            InvokeServerUnavailableError: [InvokeServerUnavailableError],
            InvokeRateLimitError: [InvokeRateLimitError],
            InvokeAuthorizationError: [InvokeAuthorizationError],
            InvokeBadRequestError: [InvokeBadRequestError],
            PluginDaemonInnerError: [PluginDaemonInnerError],
            ValueError: [ValueError],
        }

    def _transform_invoke_error(self, error: Exception) -> Exception:
        """
        Transform invoke error to unified error

        The first mapping entry (in mapping order) whose model error types match
        ``error`` decides the result:

        - authorization errors become a fresh ``InvokeAuthorizationError`` with a
          provider-prefixed message;
        - other ``InvokeError`` subclasses are re-created as the *same* mapped
          class with a provider-prefixed message, so callers catching the
          concrete type keep working;
        - anything else is returned unchanged.

        Errors matching no entry are wrapped in a plain ``InvokeError``.

        :param error: model invoke error
        :return: unified error
        """
        for invoke_error, model_errors in self._invoke_error_mapping.items():
            if not isinstance(error, tuple(model_errors)):
                continue

            if invoke_error is InvokeAuthorizationError:
                return InvokeAuthorizationError(
                    description=(
                        f"[{self.provider_name}] Incorrect model credentials provided, please check and try again."
                    )
                )

            # ``invoke_error`` is a class, so it must be checked with ``issubclass``;
            # ``isinstance(invoke_error, InvokeError)`` can never be true for a class object.
            if isinstance(invoke_error, type) and issubclass(invoke_error, InvokeError):
                return invoke_error(description=f"[{self.provider_name}] {invoke_error.description}, {str(error)}")

            return error

        return InvokeError(description=f"[{self.provider_name}] Error: {str(error)}")

    def get_price(self, model: str, credentials: dict, price_type: PriceType, tokens: int) -> PriceInfo:
        """
        Get price for given model and tokens

        When the model schema has no pricing, or no unit price applies to
        ``price_type``, an all-zero ``PriceInfo`` in USD is returned.

        :param model: model name
        :param credentials: model credentials
        :param price_type: price type
        :param tokens: number of tokens
        :return: price info
        """
        # get model schema
        model_schema = self.get_model_schema(model, credentials)

        # get price info from predefined model schema
        price_config: PriceConfig | None = None
        if model_schema and model_schema.pricing:
            price_config = model_schema.pricing

        # get unit price
        unit_price = None
        if price_config:
            if price_type == PriceType.INPUT:
                unit_price = price_config.input
            elif price_type == PriceType.OUTPUT and price_config.output is not None:
                unit_price = price_config.output

        # A unit price can only be set when a price config exists, so one guard covers both cases.
        if not price_config or unit_price is None:
            return PriceInfo(
                unit_price=decimal.Decimal("0.0"),
                unit=decimal.Decimal("0.0"),
                total_amount=decimal.Decimal("0.0"),
                currency="USD",
            )

        # calculate total amount, rounded half-up to 7 decimal places
        total_amount = tokens * unit_price * price_config.unit
        total_amount = total_amount.quantize(decimal.Decimal("0.0000001"), rounding=decimal.ROUND_HALF_UP)

        return PriceInfo(
            unit_price=unit_price,
            unit=price_config.unit,
            total_amount=total_amount,
            currency=price_config.currency,
        )

    def get_model_schema(self, model: str, credentials: dict | None = None) -> AIModelEntity | None:
        """
        Get model schema by model name and credentials

        The schema is looked up in Redis first and fetched through the plugin
        client on a miss. Redis failures (``RedisError`` / ``RuntimeError``) are
        logged as warnings and never abort the lookup; a cached entry that fails
        validation is deleted and refetched.

        :param model: model name
        :param credentials: model credentials
        :return: model schema
        """
        from core.plugin.impl.model import PluginModelClient

        plugin_model_manager = PluginModelClient()

        # Cache key: identity of the model plus an md5 digest per (sorted) credential item.
        cache_key = f"{self.tenant_id}:{self.plugin_id}:{self.provider_name}:{self.model_type.value}:{model}"
        sorted_credentials = sorted(credentials.items()) if credentials else []
        cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials])

        cached_schema_json = None
        try:
            cached_schema_json = redis_client.get(cache_key)
        except (RedisError, RuntimeError) as exc:
            logger.warning(
                "Failed to read plugin model schema cache for model %s: %s",
                model,
                str(exc),
                exc_info=True,
            )

        if cached_schema_json:
            try:
                return AIModelEntity.model_validate_json(cached_schema_json)
            except ValidationError:
                logger.warning(
                    "Failed to validate cached plugin model schema for model %s",
                    model,
                    exc_info=True,
                )
                # Drop the unusable entry so the fresh schema fetched below can replace it.
                try:
                    redis_client.delete(cache_key)
                except (RedisError, RuntimeError) as exc:
                    logger.warning(
                        "Failed to delete invalid plugin model schema cache for model %s: %s",
                        model,
                        str(exc),
                        exc_info=True,
                    )

        schema = plugin_model_manager.get_model_schema(
            tenant_id=self.tenant_id,
            user_id="unknown",
            plugin_id=self.plugin_id,
            provider=self.provider_name,
            model_type=self.model_type.value,
            model=model,
            credentials=credentials or {},
        )

        if schema:
            try:
                redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_SCHEMA_CACHE_TTL, schema.model_dump_json())
            except (RedisError, RuntimeError) as exc:
                logger.warning(
                    "Failed to write plugin model schema cache for model %s: %s",
                    model,
                    str(exc),
                    exc_info=True,
                )

        return schema

    def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> AIModelEntity | None:
        """
        Get customizable model schema from credentials

        Parameter rules that name a template (``use_template``) have their unset
        attributes filled in from the matching default parameter rule. Rules
        whose template name is not a valid ``DefaultParameterName`` are kept as is.

        :param model: model name
        :param credentials: model credentials
        :return: model schema
        """

        # get customizable model schema
        schema = self.get_customizable_model_schema(model, credentials)
        if not schema:
            return None

        # fill in the template
        new_parameter_rules = []
        for parameter_rule in schema.parameter_rules:
            if parameter_rule.use_template:
                try:
                    default_parameter_name = DefaultParameterName.value_of(parameter_rule.use_template)
                    default_parameter_rule = self._get_default_parameter_rule_variable_map(default_parameter_name)
                    # NOTE(review): the checks below are truthiness checks, so an explicit
                    # falsy value (e.g. 0 or False) is also replaced by the template value.
                    if not parameter_rule.max and "max" in default_parameter_rule:
                        parameter_rule.max = default_parameter_rule["max"]
                    if not parameter_rule.min and "min" in default_parameter_rule:
                        parameter_rule.min = default_parameter_rule["min"]
                    if not parameter_rule.default and "default" in default_parameter_rule:
                        parameter_rule.default = default_parameter_rule["default"]
                    if not parameter_rule.precision and "precision" in default_parameter_rule:
                        parameter_rule.precision = default_parameter_rule["precision"]
                    if not parameter_rule.required and "required" in default_parameter_rule:
                        parameter_rule.required = default_parameter_rule["required"]
                    if not parameter_rule.help and "help" in default_parameter_rule:
                        parameter_rule.help = I18nObject(
                            en_US=default_parameter_rule["help"]["en_US"],
                        )
                    if (
                        parameter_rule.help
                        and not parameter_rule.help.en_US
                        and ("help" in default_parameter_rule and "en_US" in default_parameter_rule["help"])
                    ):
                        parameter_rule.help.en_US = default_parameter_rule["help"]["en_US"]
                    if (
                        parameter_rule.help
                        and not parameter_rule.help.zh_Hans
                        and ("help" in default_parameter_rule and "zh_Hans" in default_parameter_rule["help"])
                    ):
                        parameter_rule.help.zh_Hans = default_parameter_rule["help"].get(
                            "zh_Hans", default_parameter_rule["help"]["en_US"]
                        )
                except ValueError:
                    # Deliberate best-effort: on ValueError the rule is kept without template values.
                    pass

            new_parameter_rules.append(parameter_rule)

        schema.parameter_rules = new_parameter_rules

        return schema

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
        """
        Get customizable model schema

        This base implementation always returns ``None``.

        :param model: model name
        :param credentials: model credentials
        :return: model schema
        """
        return None

    def _get_default_parameter_rule_variable_map(self, name: DefaultParameterName):
        """
        Get default parameter rule for given name

        :param name: parameter name
        :return: parameter rule
        :raises Exception: if ``PARAMETER_RULE_TEMPLATE`` has no (non-empty) entry for ``name``
        """
        default_parameter_rule = PARAMETER_RULE_TEMPLATE.get(name)

        if not default_parameter_rule:
            raise Exception(f"Invalid model parameter rule name {name}")

        return default_parameter_rule
|
||||
@ -1,668 +0,0 @@
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Callable, Generator, Iterator, Sequence
|
||||
from typing import Union
|
||||
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from configs import dify_config
|
||||
from core.model_runtime.callbacks.base_callback import Callback
|
||||
from core.model_runtime.callbacks.logging_callback import LoggingCallback
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMUsage
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
PromptMessageContentUnionTypes,
|
||||
PromptMessageTool,
|
||||
TextPromptMessageContent,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import (
|
||||
ModelType,
|
||||
PriceType,
|
||||
)
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _gen_tool_call_id() -> str:
|
||||
return f"chatcmpl-tool-{str(uuid.uuid4().hex)}"
|
||||
|
||||
|
||||
def _run_callbacks(callbacks: Sequence[Callback] | None, *, event: str, invoke: Callable[[Callback], None]) -> None:
|
||||
if not callbacks:
|
||||
return
|
||||
|
||||
for callback in callbacks:
|
||||
try:
|
||||
invoke(callback)
|
||||
except Exception as e:
|
||||
if callback.raise_error:
|
||||
raise
|
||||
logger.warning("Callback %s %s failed with error %s", callback.__class__.__name__, event, e)
|
||||
|
||||
|
||||
def _get_or_create_tool_call(
|
||||
existing_tools_calls: list[AssistantPromptMessage.ToolCall],
|
||||
tool_call_id: str,
|
||||
) -> AssistantPromptMessage.ToolCall:
|
||||
"""
|
||||
Get or create a tool call by ID.
|
||||
|
||||
If `tool_call_id` is empty, returns the most recently created tool call.
|
||||
"""
|
||||
if not tool_call_id:
|
||||
if not existing_tools_calls:
|
||||
raise ValueError("tool_call_id is empty but no existing tool call is available to apply the delta")
|
||||
return existing_tools_calls[-1]
|
||||
|
||||
tool_call = next((tool_call for tool_call in existing_tools_calls if tool_call.id == tool_call_id), None)
|
||||
if tool_call is None:
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id=tool_call_id,
|
||||
type="function",
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""),
|
||||
)
|
||||
existing_tools_calls.append(tool_call)
|
||||
|
||||
return tool_call
|
||||
|
||||
|
||||
def _merge_tool_call_delta(
|
||||
tool_call: AssistantPromptMessage.ToolCall,
|
||||
delta: AssistantPromptMessage.ToolCall,
|
||||
) -> None:
|
||||
if delta.id:
|
||||
tool_call.id = delta.id
|
||||
if delta.type:
|
||||
tool_call.type = delta.type
|
||||
if delta.function.name:
|
||||
tool_call.function.name = delta.function.name
|
||||
if delta.function.arguments:
|
||||
tool_call.function.arguments += delta.function.arguments
|
||||
|
||||
|
||||
def _build_llm_result_from_chunks(
    model: str,
    prompt_messages: Sequence[PromptMessage],
    chunks: Iterator[LLMResultChunk],
) -> LLMResult:
    """
    Collapse a chunk stream into one `LLMResult`.

    Some models only support streaming output (e.g. Qwen3 open-source edition)
    and the plugin side may still answer a non-stream request with chunks, so
    every chunk is consumed here and merged.

    String contents are concatenated; list contents are collected separately and
    used only when no string content was received. Tool-call deltas are merged,
    and ``usage`` / ``system_fingerprint`` come from the last chunk carrying them.
    """
    text_fragments: list[str] = []
    structured_content: list[PromptMessageContentUnionTypes] = []
    latest_usage = LLMUsage.empty_usage()
    fingerprint: str | None = None
    merged_tool_calls: list[AssistantPromptMessage.ToolCall] = []

    try:
        for chunk in chunks:
            delta_message = chunk.delta.message
            piece = delta_message.content
            if isinstance(piece, str):
                text_fragments.append(piece)
            elif isinstance(piece, list):
                structured_content.extend(piece)

            if delta_message.tool_calls:
                _increase_tool_call(delta_message.tool_calls, merged_tool_calls)

            if chunk.delta.usage:
                latest_usage = chunk.delta.usage
            if chunk.system_fingerprint:
                fingerprint = chunk.system_fingerprint
    except Exception:
        logger.exception("Error while consuming non-stream plugin chunk iterator.")
        raise
    finally:
        # Close the iterator when it supports it, so underlying streaming
        # resources (e.g. HTTP connections) are released even on failure.
        closer = getattr(chunks, "close", None)
        if callable(closer):
            closer()

    text = "".join(text_fragments)
    return LLMResult(
        model=model,
        prompt_messages=prompt_messages,
        message=AssistantPromptMessage(
            content=text or structured_content,
            tool_calls=merged_tool_calls,
        ),
        usage=latest_usage,
        system_fingerprint=fingerprint,
    )
|
||||
|
||||
|
||||
def _invoke_llm_via_plugin(
    *,
    tenant_id: str,
    user_id: str,
    plugin_id: str,
    provider: str,
    model: str,
    credentials: dict,
    model_parameters: dict,
    prompt_messages: Sequence[PromptMessage],
    tools: list[PromptMessageTool] | None,
    stop: Sequence[str] | None,
    stream: bool,
) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]:
    """
    Forward an LLM invocation to a new ``PluginModelClient``.

    ``prompt_messages`` is copied into a list, and ``stop`` is copied into a list
    or passed as ``None`` when empty; all other arguments go through unchanged.
    """
    from core.plugin.impl.model import PluginModelClient

    client = PluginModelClient()
    request = {
        "tenant_id": tenant_id,
        "user_id": user_id,
        "plugin_id": plugin_id,
        "provider": provider,
        "model": model,
        "credentials": credentials,
        "model_parameters": model_parameters,
        "prompt_messages": list(prompt_messages),
        "tools": tools,
        "stop": list(stop) if stop else None,
        "stream": stream,
    }
    return client.invoke_llm(**request)
|
||||
|
||||
|
||||
def _normalize_non_stream_plugin_result(
    model: str,
    prompt_messages: Sequence[PromptMessage],
    result: Union[LLMResult, Iterator[LLMResultChunk]],
) -> LLMResult:
    """
    Return ``result`` as an ``LLMResult``.

    An ``LLMResult`` is returned as is; anything else is treated as a chunk
    iterator and accumulated into a single result.
    """
    if not isinstance(result, LLMResult):
        return _build_llm_result_from_chunks(model=model, prompt_messages=prompt_messages, chunks=result)
    return result
|
||||
|
||||
|
||||
def _increase_tool_call(
    new_tool_calls: list[AssistantPromptMessage.ToolCall], existing_tools_calls: list[AssistantPromptMessage.ToolCall]
):
    """
    Apply a batch of tool-call deltas to the accumulated tool calls.

    :param new_tool_calls: tool call deltas to merge; a delta that has a function
        name but no ID gets a generated ID assigned (the delta itself is modified)
    :param existing_tools_calls: accumulated tool calls, modified IN-PLACE
    """
    for delta in new_tool_calls:
        # A named delta without an ID starts a new call; give it an ID so it can be tracked.
        if not delta.id and delta.function.name:
            delta.id = _gen_tool_call_id()

        target = _get_or_create_tool_call(existing_tools_calls, delta.id)
        _merge_tool_call_delta(target, delta)
|
||||
|
||||
|
||||
class LargeLanguageModel(AIModel):
|
||||
"""
|
||||
Model class for large language model.
|
||||
"""
|
||||
|
||||
model_type: ModelType = ModelType.LLM
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
def invoke(
    self,
    model: str,
    credentials: dict,
    prompt_messages: list[PromptMessage],
    model_parameters: dict | None = None,
    tools: list[PromptMessageTool] | None = None,
    stop: list[str] | None = None,
    stream: bool = True,
    user: str | None = None,
    callbacks: list[Callback] | None = None,
) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]:
    """
    Invoke large language model

    Runs the before-invoke callbacks, calls the plugin, then either returns a
    generator that yields chunks (stream) or a full ``LLMResult`` after running
    the after-invoke callbacks. A failure of the plugin call triggers the
    invoke-error callbacks and is re-raised as a unified error.

    :param model: model name
    :param credentials: model credentials
    :param prompt_messages: prompt messages
    :param model_parameters: model parameters
    :param tools: tools for tool calling
    :param stop: stop words
    :param stream: is stream response
    :param user: unique user id
    :param callbacks: callbacks; the passed list is never modified
    :return: full response or stream response chunk generator result
    """
    # validate and filter model parameters
    if model_parameters is None:
        model_parameters = {}

    self.started_at = time.perf_counter()

    # Work on a private copy: appending the debug callback below must not leak
    # into the caller's list (which may be reused across invocations).
    callbacks = list(callbacks) if callbacks else []

    if dify_config.DEBUG:
        callbacks.append(LoggingCallback())

    # trigger before invoke callbacks
    self._trigger_before_invoke_callbacks(
        model=model,
        credentials=credentials,
        prompt_messages=prompt_messages,
        model_parameters=model_parameters,
        tools=tools,
        stop=stop,
        stream=stream,
        user=user,
        callbacks=callbacks,
    )

    result: Union[LLMResult, Generator[LLMResultChunk, None, None]]

    try:
        result = _invoke_llm_via_plugin(
            tenant_id=self.tenant_id,
            user_id=user or "unknown",
            plugin_id=self.plugin_id,
            provider=self.provider_name,
            model=model,
            credentials=credentials,
            model_parameters=model_parameters,
            prompt_messages=prompt_messages,
            tools=tools,
            stop=stop,
            stream=stream,
        )

        if not stream:
            # The plugin may answer a non-stream request with chunks; collapse them.
            result = _normalize_non_stream_plugin_result(
                model=model, prompt_messages=prompt_messages, result=result
            )
    except Exception as e:
        self._trigger_invoke_error_callbacks(
            model=model,
            ex=e,
            credentials=credentials,
            prompt_messages=prompt_messages,
            model_parameters=model_parameters,
            tools=tools,
            stop=stop,
            stream=stream,
            user=user,
            callbacks=callbacks,
        )

        # TODO
        raise self._transform_invoke_error(e)

    if stream and not isinstance(result, LLMResult):
        return self._invoke_result_generator(
            model=model,
            result=result,
            credentials=credentials,
            prompt_messages=prompt_messages,
            model_parameters=model_parameters,
            tools=tools,
            stop=stop,
            stream=stream,
            user=user,
            callbacks=callbacks,
        )
    elif isinstance(result, LLMResult):
        self._trigger_after_invoke_callbacks(
            model=model,
            result=result,
            credentials=credentials,
            prompt_messages=prompt_messages,
            model_parameters=model_parameters,
            tools=tools,
            stop=stop,
            stream=stream,
            user=user,
            callbacks=callbacks,
        )
        # Following https://github.com/langgenius/dify/issues/17799,
        # we removed the prompt_messages from the chunk on the plugin daemon side.
        # To ensure compatibility, we add the prompt_messages back here.
        result.prompt_messages = prompt_messages
        return result
    raise NotImplementedError("unsupported invoke result type", type(result))
|
||||
|
||||
def _invoke_result_generator(
    self,
    model: str,
    result: Generator[LLMResultChunk, None, None],
    credentials: dict,
    prompt_messages: Sequence[PromptMessage],
    model_parameters: dict,
    tools: list[PromptMessageTool] | None = None,
    stop: Sequence[str] | None = None,
    stream: bool = True,
    user: str | None = None,
    callbacks: list[Callback] | None = None,
) -> Generator[LLMResultChunk, None, None]:
    """
    Invoke result generator

    Re-yields every chunk of ``result`` (with ``prompt_messages`` attached),
    runs the new-chunk callbacks after each yield, and accumulates content,
    usage, fingerprint and model name. Once the stream is exhausted the
    after-invoke callbacks receive the assembled ``LLMResult``. Exceptions
    raised while streaming are re-raised as unified invoke errors.

    :param result: result generator
    :return: result generator
    """
    callbacks = callbacks or []
    collected_content: list[PromptMessageContentUnionTypes] = []
    latest_usage = None
    fingerprint = None
    reported_model = model

    try:
        for chunk in result:
            # Following https://github.com/langgenius/dify/issues/17799,
            # we removed the prompt_messages from the chunk on the plugin daemon side.
            # To ensure compatibility, we add the prompt_messages back here.
            chunk.prompt_messages = prompt_messages
            yield chunk

            self._trigger_new_chunk_callbacks(
                chunk=chunk,
                model=model,
                credentials=credentials,
                prompt_messages=prompt_messages,
                model_parameters=model_parameters,
                tools=tools,
                stop=stop,
                stream=stream,
                user=user,
                callbacks=callbacks,
            )

            # Accumulate content: lists are extended, strings are wrapped as text content.
            piece = chunk.delta.message.content
            if piece:
                if isinstance(piece, list):
                    collected_content.extend(piece)
                elif isinstance(piece, str):
                    collected_content.append(TextPromptMessageContent(data=piece))

            reported_model = chunk.model
            if chunk.delta.usage:
                latest_usage = chunk.delta.usage

            if chunk.system_fingerprint:
                fingerprint = chunk.system_fingerprint
    except Exception as e:
        raise self._transform_invoke_error(e)

    final_result = LLMResult(
        model=reported_model,
        prompt_messages=prompt_messages,
        message=AssistantPromptMessage(content=collected_content),
        usage=latest_usage or LLMUsage.empty_usage(),
        system_fingerprint=fingerprint,
    )
    self._trigger_after_invoke_callbacks(
        model=model,
        result=final_result,
        credentials=credentials,
        prompt_messages=prompt_messages,
        model_parameters=model_parameters,
        tools=tools,
        stop=stop,
        stream=stream,
        user=user,
        callbacks=callbacks,
    )
|
||||
|
||||
def get_num_tokens(
    self,
    model: str,
    credentials: dict,
    prompt_messages: list[PromptMessage],
    tools: list[PromptMessageTool] | None = None,
) -> int:
    """
    Get number of tokens for given prompt messages

    Counting is delegated to the plugin only when
    ``dify_config.PLUGIN_BASED_TOKEN_COUNTING_ENABLED`` is set; otherwise 0 is returned.

    :param model: model name
    :param credentials: model credentials
    :param prompt_messages: prompt messages
    :param tools: tools for tool calling
    :return: number of tokens, or 0 when plugin-based counting is disabled
    """
    if not dify_config.PLUGIN_BASED_TOKEN_COUNTING_ENABLED:
        return 0

    from core.plugin.impl.model import PluginModelClient

    client = PluginModelClient()
    return client.get_llm_num_tokens(
        tenant_id=self.tenant_id,
        user_id="unknown",
        plugin_id=self.plugin_id,
        provider=self.provider_name,
        model_type=self.model_type.value,
        model=model,
        credentials=credentials,
        prompt_messages=prompt_messages,
        tools=tools,
    )
|
||||
|
||||
def calc_response_usage(
    self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int
) -> LLMUsage:
    """
    Calculate response usage

    Prices the prompt tokens as input and the completion tokens as output, then
    combines both into one usage record. Latency is the time elapsed since
    ``self.started_at``, measured with ``time.perf_counter()``.

    :param model: model name
    :param credentials: model credentials
    :param prompt_tokens: prompt tokens
    :param completion_tokens: completion tokens
    :return: usage
    """
    # price of the prompt (input) side
    input_price = self.get_price(
        model=model,
        credentials=credentials,
        price_type=PriceType.INPUT,
        tokens=prompt_tokens,
    )

    # price of the completion (output) side
    output_price = self.get_price(
        model=model,
        credentials=credentials,
        price_type=PriceType.OUTPUT,
        tokens=completion_tokens,
    )

    # the currency is taken from the prompt price info
    return LLMUsage(
        prompt_tokens=prompt_tokens,
        prompt_unit_price=input_price.unit_price,
        prompt_price_unit=input_price.unit,
        prompt_price=input_price.total_amount,
        completion_tokens=completion_tokens,
        completion_unit_price=output_price.unit_price,
        completion_price_unit=output_price.unit,
        completion_price=output_price.total_amount,
        total_tokens=prompt_tokens + completion_tokens,
        total_price=input_price.total_amount + output_price.total_amount,
        currency=input_price.currency,
        latency=time.perf_counter() - self.started_at,
    )
|
||||
|
||||
def _trigger_before_invoke_callbacks(
    self,
    model: str,
    credentials: dict,
    prompt_messages: list[PromptMessage],
    model_parameters: dict,
    tools: list[PromptMessageTool] | None = None,
    stop: Sequence[str] | None = None,
    stream: bool = True,
    user: str | None = None,
    callbacks: list[Callback] | None = None,
):
    """
    Trigger before invoke callbacks

    Calls ``on_before_invoke`` on every callback, passing this instance and the
    invocation arguments.

    :param model: model name
    :param credentials: model credentials
    :param prompt_messages: prompt messages
    :param model_parameters: model parameters
    :param tools: tools for tool calling
    :param stop: stop words
    :param stream: is stream response
    :param user: unique user id
    :param callbacks: callbacks
    """

    def _notify(callback):
        return callback.on_before_invoke(
            llm_instance=self,
            model=model,
            credentials=credentials,
            prompt_messages=prompt_messages,
            model_parameters=model_parameters,
            tools=tools,
            stop=stop,
            stream=stream,
            user=user,
        )

    _run_callbacks(callbacks, event="on_before_invoke", invoke=_notify)
|
||||
|
||||
def _trigger_new_chunk_callbacks(
    self,
    chunk: LLMResultChunk,
    model: str,
    credentials: dict,
    prompt_messages: Sequence[PromptMessage],
    model_parameters: dict,
    tools: list[PromptMessageTool] | None = None,
    stop: Sequence[str] | None = None,
    stream: bool = True,
    user: str | None = None,
    callbacks: list[Callback] | None = None,
):
    """
    Trigger new chunk callbacks

    Calls ``on_new_chunk`` on every callback, passing this instance, the chunk
    and the invocation arguments.

    :param chunk: chunk
    :param model: model name
    :param credentials: model credentials
    :param prompt_messages: prompt messages
    :param model_parameters: model parameters
    :param tools: tools for tool calling
    :param stop: stop words
    :param stream: is stream response
    :param user: unique user id
    :param callbacks: callbacks
    """

    def _notify(callback):
        return callback.on_new_chunk(
            llm_instance=self,
            chunk=chunk,
            model=model,
            credentials=credentials,
            prompt_messages=prompt_messages,
            model_parameters=model_parameters,
            tools=tools,
            stop=stop,
            stream=stream,
            user=user,
        )

    _run_callbacks(callbacks, event="on_new_chunk", invoke=_notify)
|
||||
|
||||
def _trigger_after_invoke_callbacks(
    self,
    model: str,
    result: LLMResult,
    credentials: dict,
    prompt_messages: Sequence[PromptMessage],
    model_parameters: dict,
    tools: list[PromptMessageTool] | None = None,
    stop: Sequence[str] | None = None,
    stream: bool = True,
    user: str | None = None,
    callbacks: list[Callback] | None = None,
):
    """
    Dispatch the ``on_after_invoke`` hook to the given callbacks.

    The actual iteration over ``callbacks`` is delegated to ``_run_callbacks``;
    this method only describes how a single callback must be called.

    :param model: model name
    :param result: result
    :param credentials: model credentials
    :param prompt_messages: prompt messages
    :param model_parameters: model parameters
    :param tools: tools for tool calling
    :param stop: stop words
    :param stream: is stream response
    :param user: unique user id
    :param callbacks: callbacks
    """

    def _notify(callback):
        # Hand the complete invocation context, plus the result, to one callback.
        return callback.on_after_invoke(
            llm_instance=self,
            result=result,
            model=model,
            credentials=credentials,
            prompt_messages=prompt_messages,
            model_parameters=model_parameters,
            tools=tools,
            stop=stop,
            stream=stream,
            user=user,
        )

    _run_callbacks(callbacks, event="on_after_invoke", invoke=_notify)
|
||||
|
||||
def _trigger_invoke_error_callbacks(
    self,
    model: str,
    ex: Exception,
    credentials: dict,
    prompt_messages: list[PromptMessage],
    model_parameters: dict,
    tools: list[PromptMessageTool] | None = None,
    stop: Sequence[str] | None = None,
    stream: bool = True,
    user: str | None = None,
    callbacks: list[Callback] | None = None,
):
    """
    Dispatch the ``on_invoke_error`` hook to the given callbacks.

    The actual iteration over ``callbacks`` is delegated to ``_run_callbacks``;
    this method only describes how a single callback must be called.

    :param model: model name
    :param ex: exception
    :param credentials: model credentials
    :param prompt_messages: prompt messages
    :param model_parameters: model parameters
    :param tools: tools for tool calling
    :param stop: stop words
    :param stream: is stream response
    :param user: unique user id
    :param callbacks: callbacks
    """

    def _notify(callback):
        # Hand the complete invocation context, plus the exception, to one callback.
        return callback.on_invoke_error(
            llm_instance=self,
            ex=ex,
            model=model,
            credentials=credentials,
            prompt_messages=prompt_messages,
            model_parameters=model_parameters,
            tools=tools,
            stop=stop,
            stream=stream,
            user=user,
        )

    _run_callbacks(callbacks, event="on_invoke_error", invoke=_notify)
|
||||
@ -1,45 +0,0 @@
|
||||
import time
|
||||
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
|
||||
|
||||
class ModerationModel(AIModel):
    """
    Model class for moderation models.

    Moderation requests are forwarded to ``PluginModelClient``.
    """

    model_type: ModelType = ModelType.MODERATION

    # pydantic configs
    model_config = ConfigDict(protected_namespaces=())

    def invoke(self, model: str, credentials: dict, text: str, user: str | None = None) -> bool:
        """
        Invoke moderation model

        :param model: model name
        :param credentials: model credentials
        :param text: text to moderate
        :param user: unique user id; ``"unknown"`` is sent when it is empty or None
        :return: false if text is safe, true otherwise
        :raises Exception: whatever ``_transform_invoke_error`` produces for a failure
        """
        # Remember when this invocation began.
        self.started_at = time.perf_counter()

        try:
            # The import sits inside the guarded region, so an import failure is
            # transformed like any other error.
            from core.plugin.impl.model import PluginModelClient

            client = PluginModelClient()
            request = {
                "tenant_id": self.tenant_id,
                "user_id": user or "unknown",
                "plugin_id": self.plugin_id,
                "provider": self.provider_name,
                "model": model,
                "credentials": credentials,
                "text": text,
            }
            return client.invoke_moderation(**request)
        except Exception as error:
            raise self._transform_invoke_error(error)
|
||||
@ -1,92 +0,0 @@
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.entities.rerank_entities import RerankResult
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
|
||||
|
||||
class RerankModel(AIModel):
    """
    Base model class for rerank models.

    Both entry points forward the request to ``PluginModelClient`` and pass any
    failure through ``_transform_invoke_error`` before raising it.
    """

    model_type: ModelType = ModelType.RERANK

    def invoke(
        self,
        model: str,
        credentials: dict,
        query: str,
        docs: list[str],
        score_threshold: float | None = None,
        top_n: int | None = None,
        user: str | None = None,
    ) -> RerankResult:
        """
        Invoke rerank model

        :param model: model name
        :param credentials: model credentials
        :param query: search query
        :param docs: docs for reranking
        :param score_threshold: score threshold
        :param top_n: top n
        :param user: unique user id; ``"unknown"`` is sent when it is empty or None
        :return: rerank result
        """
        try:
            # Imported inside the guarded region so import failures are transformed too.
            from core.plugin.impl.model import PluginModelClient

            client = PluginModelClient()
            request = {
                "tenant_id": self.tenant_id,
                "user_id": user or "unknown",
                "plugin_id": self.plugin_id,
                "provider": self.provider_name,
                "model": model,
                "credentials": credentials,
                "query": query,
                "docs": docs,
                "score_threshold": score_threshold,
                "top_n": top_n,
            }
            return client.invoke_rerank(**request)
        except Exception as error:
            raise self._transform_invoke_error(error)

    def invoke_multimodal_rerank(
        self,
        model: str,
        credentials: dict,
        query: dict,
        docs: list[dict],
        score_threshold: float | None = None,
        top_n: int | None = None,
        user: str | None = None,
    ) -> RerankResult:
        """
        Invoke multimodal rerank model

        :param model: model name
        :param credentials: model credentials
        :param query: search query (a dict; its schema is defined by the plugin client)
        :param docs: docs for reranking (a list of dicts)
        :param score_threshold: score threshold
        :param top_n: top n
        :param user: unique user id; ``"unknown"`` is sent when it is empty or None
        :return: rerank result
        """
        try:
            # Imported inside the guarded region so import failures are transformed too.
            from core.plugin.impl.model import PluginModelClient

            client = PluginModelClient()
            request = {
                "tenant_id": self.tenant_id,
                "user_id": user or "unknown",
                "plugin_id": self.plugin_id,
                "provider": self.provider_name,
                "model": model,
                "credentials": credentials,
                "query": query,
                "docs": docs,
                "score_threshold": score_threshold,
                "top_n": top_n,
            }
            return client.invoke_multimodal_rerank(**request)
        except Exception as error:
            raise self._transform_invoke_error(error)
|
||||
@ -1,43 +0,0 @@
|
||||
from typing import IO
|
||||
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
|
||||
|
||||
class Speech2TextModel(AIModel):
    """
    Model class for speech2text models.

    Transcription requests are forwarded to ``PluginModelClient``.
    """

    model_type: ModelType = ModelType.SPEECH2TEXT

    # pydantic configs
    model_config = ConfigDict(protected_namespaces=())

    def invoke(self, model: str, credentials: dict, file: IO[bytes], user: str | None = None) -> str:
        """
        Invoke speech to text model

        :param model: model name
        :param credentials: model credentials
        :param file: audio file
        :param user: unique user id; ``"unknown"`` is sent when it is empty or None
        :return: text for given audio file
        :raises Exception: whatever ``_transform_invoke_error`` produces for a failure
        """
        try:
            # Imported inside the guarded region so import failures are transformed too.
            from core.plugin.impl.model import PluginModelClient

            client = PluginModelClient()
            request = {
                "tenant_id": self.tenant_id,
                "user_id": user or "unknown",
                "plugin_id": self.plugin_id,
                "provider": self.provider_name,
                "model": model,
                "credentials": credentials,
                "file": file,
            }
            return client.invoke_speech_to_text(**request)
        except Exception as error:
            raise self._transform_invoke_error(error)
|
||||
@ -1,121 +0,0 @@
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from core.entities.embedding_type import EmbeddingInputType
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingResult
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
|
||||
|
||||
class TextEmbeddingModel(AIModel):
    """
    Model class for text embedding models.

    Embedding and token-count requests are forwarded to ``PluginModelClient``.
    """

    model_type: ModelType = ModelType.TEXT_EMBEDDING

    # pydantic configs
    model_config = ConfigDict(protected_namespaces=())

    def invoke(
        self,
        model: str,
        credentials: dict,
        texts: list[str] | None = None,
        multimodel_documents: list[dict] | None = None,
        user: str | None = None,
        input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
    ) -> EmbeddingResult:
        """
        Invoke text embedding model

        ``texts`` takes precedence: ``multimodel_documents`` is only used when
        ``texts`` is empty or None.

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :param multimodel_documents: documents to embed through the multimodal endpoint
        :param user: unique user id; ``"unknown"`` is sent when it is empty or None
        :param input_type: input type
        :return: embeddings result
        :raises Exception: whatever ``_transform_invoke_error`` produces for a failure,
            including the ``ValueError`` raised when neither input is provided
        """
        from core.plugin.impl.model import PluginModelClient

        try:
            client = PluginModelClient()

            if not texts and not multimodel_documents:
                # Raised inside the guarded region, so it is transformed below as well.
                raise ValueError("No texts or files provided")

            # Arguments shared by the text and the multimodal endpoint.
            common = {
                "tenant_id": self.tenant_id,
                "user_id": user or "unknown",
                "plugin_id": self.plugin_id,
                "provider": self.provider_name,
                "model": model,
                "credentials": credentials,
                "input_type": input_type,
            }
            if texts:
                return client.invoke_text_embedding(texts=texts, **common)
            return client.invoke_multimodal_embedding(documents=multimodel_documents, **common)
        except Exception as error:
            raise self._transform_invoke_error(error)

    def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> list[int]:
        """
        Get number of tokens for the given texts

        The request is always sent with the user id ``"unknown"`` and errors are
        not transformed here.

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :return: token counts as returned by the plugin client
        """
        from core.plugin.impl.model import PluginModelClient

        client = PluginModelClient()
        request = {
            "tenant_id": self.tenant_id,
            "user_id": "unknown",
            "plugin_id": self.plugin_id,
            "provider": self.provider_name,
            "model": model,
            "credentials": credentials,
            "texts": texts,
        }
        return client.get_text_embedding_num_tokens(**request)

    def _get_context_size(self, model: str, credentials: dict) -> int:
        """
        Get context size for given embedding model

        :param model: model name
        :param credentials: model credentials
        :return: context size from the model schema, or 1000 when the schema is
            missing or does not declare one
        """
        schema = self.get_model_schema(model, credentials)

        if not schema or ModelPropertyKey.CONTEXT_SIZE not in schema.model_properties:
            return 1000

        context_size: int = schema.model_properties[ModelPropertyKey.CONTEXT_SIZE]
        return context_size

    def _get_max_chunks(self, model: str, credentials: dict) -> int:
        """
        Get max chunks for given embedding model

        :param model: model name
        :param credentials: model credentials
        :return: max chunks from the model schema, or 1 when the schema is
            missing or does not declare it
        """
        schema = self.get_model_schema(model, credentials)

        if not schema or ModelPropertyKey.MAX_CHUNKS not in schema.model_properties:
            return 1

        chunk_limit: int = schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
        return chunk_limit
|
||||
@ -1,53 +0,0 @@
|
||||
import logging
|
||||
from threading import Lock
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)

# Module-wide encoder cache, filled on the first call to GPT2Tokenizer.get_encoder().
_tokenizer: Any | None = None
# Guards the one-time creation of _tokenizer.
_lock = Lock()


class GPT2Tokenizer:
    """Token counting helpers built on a lazily created, module-cached GPT-2 encoder."""

    @staticmethod
    def _get_num_tokens_by_gpt2(text: str) -> int:
        """
        Count the tokens the GPT-2 encoder produces for ``text``.

        :param text: text to tokenize
        :return: number of tokens
        """
        encoder = GPT2Tokenizer.get_encoder()
        return len(encoder.encode(text))  # type: ignore

    @staticmethod
    def get_num_tokens(text: str) -> int:
        """
        Return the GPT-2 token count of ``text``.

        :param text: text to tokenize
        :return: number of tokens
        """
        # Counting runs synchronously in the caller: submitting the work to an
        # executor was turned off because this process needs more cpu resource,
        # until a better way to handle it is found.
        return GPT2Tokenizer._get_num_tokens_by_gpt2(text)

    @staticmethod
    def get_encoder():
        """
        Return the shared encoder, creating it on first use.

        ``tiktoken``'s "gpt2" encoding is tried first; if that fails for any
        reason, the Transformers GPT-2 tokenizer is loaded from the ``gpt2``
        directory next to this file.

        :return: the cached encoder object
        """
        global _tokenizer

        # Already initialised: no need to take the lock.
        if _tokenizer is not None:
            return _tokenizer

        with _lock:
            # Another thread may have finished initialisation while we waited.
            if _tokenizer is not None:
                return _tokenizer

            # Try to use tiktoken to get the tokenizer because it is faster
            try:
                import tiktoken

                _tokenizer = tiktoken.get_encoding("gpt2")
            except Exception:
                from os.path import abspath, dirname, join

                from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer

                local_gpt2_dir = join(dirname(abspath(__file__)), "gpt2")
                _tokenizer = TransformerGPT2Tokenizer.from_pretrained(local_gpt2_dir)
                logger.info("Fallback to Transformers' GPT-2 tokenizer from tiktoken")

            return _tokenizer
|
||||
@ -1,79 +0,0 @@
|
||||
import logging
|
||||
from collections.abc import Iterable
|
||||
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
|
||||
logger = logging.getLogger(__name__)


class TTSModel(AIModel):
    """
    Model class for TTS (text-to-speech) models.

    Synthesis and voice-listing requests are forwarded to ``PluginModelClient``.
    """

    model_type: ModelType = ModelType.TTS

    # pydantic configs
    model_config = ConfigDict(protected_namespaces=())

    def invoke(
        self,
        model: str,
        tenant_id: str,
        credentials: dict,
        content_text: str,
        voice: str,
        user: str | None = None,
    ) -> Iterable[bytes]:
        """
        Invoke TTS model

        :param model: model name
        :param tenant_id: user tenant id; NOTE: this argument is not read here,
            the request is sent with ``self.tenant_id``
        :param credentials: model credentials
        :param voice: model timbre
        :param content_text: text content to be converted to speech
        :param user: unique user id; ``"unknown"`` is sent when it is empty or None
        :return: audio data as an iterable of bytes
        :raises Exception: whatever ``_transform_invoke_error`` produces for a failure
        """
        try:
            # Imported inside the guarded region so import failures are transformed too.
            from core.plugin.impl.model import PluginModelClient

            client = PluginModelClient()
            request = {
                "tenant_id": self.tenant_id,
                "user_id": user or "unknown",
                "plugin_id": self.plugin_id,
                "provider": self.provider_name,
                "model": model,
                "credentials": credentials,
                "content_text": content_text,
                "voice": voice,
            }
            return client.invoke_tts(**request)
        except Exception as error:
            raise self._transform_invoke_error(error)

    def get_tts_model_voices(self, model: str, credentials: dict, language: str | None = None):
        """
        Retrieves the list of voices supported by a given text-to-speech (TTS) model.

        The request is always sent with the user id ``"unknown"`` and errors are
        not transformed here.

        :param language: The language for which the voices are requested.
        :param model: The name of the TTS model.
        :param credentials: The credentials required to access the TTS model.
        :return: A list of voices supported by the TTS model.
        """
        from core.plugin.impl.model import PluginModelClient

        client = PluginModelClient()
        request = {
            "tenant_id": self.tenant_id,
            "user_id": "unknown",
            "plugin_id": self.plugin_id,
            "provider": self.provider_name,
            "model": model,
            "credentials": credentials,
            "language": language,
        }
        return client.get_tts_model_voices(**request)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user