refactor(api): move model_runtime into dify_graph (#32858)

This commit is contained in:
-LAN-
2026-03-02 20:15:32 +08:00
committed by GitHub
parent e985e73bdc
commit 4fd6b52808
253 changed files with 557 additions and 589 deletions

View File

@ -3,14 +3,14 @@ from __future__ import annotations
import base64
from collections.abc import Mapping
from core.model_runtime.entities import (
from dify_graph.model_runtime.entities import (
AudioPromptMessageContent,
DocumentPromptMessageContent,
ImagePromptMessageContent,
TextPromptMessageContent,
VideoPromptMessageContent,
)
from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
from dify_graph.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
from . import helpers
from .enums import FileAttribute

View File

@ -5,7 +5,7 @@ from typing import Any
from pydantic import BaseModel, Field, model_validator
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent
from . import helpers
from .constants import FILE_MODEL_IDENTITY

View File

@ -7,7 +7,6 @@ from collections.abc import Mapping
from functools import singledispatchmethod
from typing import TYPE_CHECKING, final
from core.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.enums import ErrorStrategy, NodeExecutionType, NodeState
from dify_graph.graph import Graph
from dify_graph.graph_events import (
@ -30,6 +29,7 @@ from dify_graph.graph_events import (
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.runtime import GraphRuntimeState
from ..domain.graph_execution import GraphExecution

View File

@ -0,0 +1,51 @@
# Model Runtime
This module provides the interfaces for invoking and authenticating models, and gives Dify a unified set of provider information and credential form rules.
- On one hand, it decouples models from upstream and downstream processes, making it easy for developers to extend the supported models horizontally.
- On the other hand, providers and models only need to be defined in the backend to be displayed directly in the frontend, with no changes to frontend logic.
## Features
- Supports capability invocation for six model types:
  - `LLM` - text completion, chat, and token pre-counting
  - `Text Embedding Model` - text embedding and token pre-counting
  - `Rerank Model` - segment reranking
  - `Speech-to-text Model` - speech-to-text
  - `Text-to-speech Model` - text-to-speech
  - `Moderation` - content moderation
- Model provider display
  Shows the list of all supported providers, including each provider's name, icon, supported model types, predefined models, configuration method, and credential form rules.
- Selectable model list display
  After provider/model credentials are configured, the dropdown (in the application orchestration interface and the default-model settings) shows the available LLMs. Greyed-out items are the predefined models of providers whose credentials have not been configured, so users can still see which models are supported.
  This list also returns the configurable parameters and their rules for each LLM. The parameters are defined entirely in the backend, so different models can expose different sets of supported parameters.
- Provider/model credential authentication
  The provider list returns the configuration of the credential forms, and credentials can be validated through the Runtime's interface.
## Structure
Model Runtime is divided into three layers:
- The outermost layer is the factory method
  It provides methods for listing all providers and models, obtaining provider instances, and validating provider/model credentials.
- The second layer is the provider layer
  It provides the current provider's model list, model instance retrieval, provider credential validation, and provider configuration rules, **allowing horizontal expansion** to support different providers.
- The bottom layer is the model layer
  It offers direct invocation of the various model types, predefined model configuration, retrieval of predefined/remote model lists, and model credential validation. Individual models add their own special methods, such as the LLM's token pre-counting and pricing methods, **allowing horizontal expansion** to different models under the same provider (within the supported model types); see the sketch below.
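As a rough illustration of how these layers fit together, here is a minimal sketch of invoking an LLM once a model instance has been obtained from the provider layer. The model name and credentials are placeholders, and the factory/provider wiring that produces the `llm` instance is omitted.

```python
from dify_graph.model_runtime.entities import SystemPromptMessage, UserPromptMessage


def ask(llm, credentials: dict) -> str:
    """`llm` is assumed to be a LargeLanguageModel obtained from the provider layer."""
    result = llm.invoke(
        model="gpt-4o-mini",          # placeholder model name
        credentials=credentials,      # e.g. {"api_key": "..."} from the provider layer
        prompt_messages=[
            SystemPromptMessage(content="You are a helpful assistant."),
            UserPromptMessage(content="Hello!"),
        ],
        model_parameters={"temperature": 0.7, "max_tokens": 256},
        stream=False,                 # a non-streaming call returns a single LLMResult
    )
    return result.message.get_text_content()
```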
## Documentation
For detailed documentation on how to add new providers or models, please refer to the [Dify documentation](https://docs.dify.ai/).

View File

@ -0,0 +1,64 @@
# Model Runtime
This module provides the interfaces for invoking and authenticating models, and gives Dify unified provider information and credential form rules.
- On one hand, it decouples models from upstream and downstream processes, making it easy for developers to extend models horizontally.
- On the other hand, providers and models only need to be defined in the backend to be displayed directly in the frontend, with no changes to frontend logic.
## Features
- Supports capability invocation for six model types:
  - `LLM` - text completion, chat, and token pre-counting
  - `Text Embedding Model` - text embedding and token pre-counting
  - `Rerank Model` - segment reranking
  - `Speech-to-text Model` - speech-to-text
  - `Text-to-speech Model` - text-to-speech
  - `Moderation` - content moderation
- Model provider display
  Shows the list of all supported providers, returning not only each provider's name and icon but also its supported model types, predefined models, configuration method, and credential form rules.
- Selectable model list display
  After provider/model credentials are configured, the dropdown (in the application orchestration interface and the default-model settings) shows the available LLMs; greyed-out items are the predefined models of providers whose credentials have not been configured, so users can still see which models are supported.
  This list also returns the configurable parameters and rules for each LLM. The parameters are all defined in the backend; instead of the previous fixed set of 5 parameters, different models can now declare whichever parameters they support.
- Provider/model credential authentication
  The provider list returns the configuration of the credential forms, and credentials can be validated through the interface provided by the Runtime.
## Structure
Model Runtime is divided into three layers:
- The outermost layer is the factory method
  It provides methods for listing all providers and models, obtaining provider instances, and validating provider/model credentials.
- The second layer is the provider layer
  It provides the current provider's model list, model instance retrieval, provider credential validation, and provider configuration rules, and **can be extended horizontally** to support different providers.
  For provider/model credentials, there are two cases:
  - Centralized providers such as OpenAI require authentication credentials such as an **api_key**.
  - Locally deployed providers such as [**Xinference**](https://github.com/xorbitsai/inference) require address credentials such as a **server_url**, and sometimes model credentials such as a **model_uid**. Once these credentials are defined at the provider layer, they can be displayed directly in the frontend with no changes to frontend logic.
  Once credentials are configured, the **Schema** (credential form rules) required by a provider can be fetched directly through DifyRuntime's external interface, so new providers/models can be supported without modifying frontend logic.
- The bottom layer is the model layer
  It offers direct invocation of the various model types, predefined model configuration, retrieval of predefined/remote model lists, and model credential validation. Individual models add their own special methods, such as the LLM's token pre-counting and pricing methods, and the layer **can be extended horizontally** to different models under the same provider (within the supported model types).
  Here we first need to distinguish model parameters from model credentials.
  - Model parameters (**defined at this layer**): parameters that change frequently and are adjusted at any time, such as the LLM's **max_tokens** and **temperature**. They are tuned by users in the frontend, so their rules must be defined in the backend for the frontend to display and adjust. In DifyRuntime they are usually passed as **model_parameters: dict[str, any]**.
  - Model credentials (**defined at the provider layer**): parameters that rarely change and are usually fixed once configured, such as **api_key** and **server_url**. In DifyRuntime they are usually passed as **credentials: dict[str, any]**; the provider layer's credentials are passed straight through to this layer and do not need to be defined again (see the sketch below).
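A minimal sketch of the distinction, with placeholder values:

```python
# Credentials: defined at the provider layer, rarely change once configured.
credentials = {"api_key": "sk-...", "server_url": "http://127.0.0.1:9997"}  # placeholder values

# Model parameters: defined at the model layer, adjusted per request from the frontend.
model_parameters = {"temperature": 0.7, "max_tokens": 512}
```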
## Documentation
For detailed documentation on how to add new providers or models, please refer to the [Dify documentation](https://docs.dify.ai/).

View File

View File

@ -0,0 +1,151 @@
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel
_TEXT_COLOR_MAPPING = {
"blue": "36;1",
"yellow": "33;1",
"pink": "38;5;200",
"green": "32;1",
"red": "31;1",
}
class Callback(ABC):
"""
Base class for callbacks.
Only for LLM.
"""
raise_error: bool = False
@abstractmethod
def on_before_invoke(
self,
llm_instance: AIModel,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: Sequence[str] | None = None,
stream: bool = True,
user: str | None = None,
):
"""
Before invoke callback
:param llm_instance: LLM instance
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
"""
raise NotImplementedError()
@abstractmethod
def on_new_chunk(
self,
llm_instance: AIModel,
chunk: LLMResultChunk,
model: str,
credentials: dict,
prompt_messages: Sequence[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: Sequence[str] | None = None,
stream: bool = True,
user: str | None = None,
):
"""
On new chunk callback
:param llm_instance: LLM instance
:param chunk: chunk
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
"""
raise NotImplementedError()
@abstractmethod
def on_after_invoke(
self,
llm_instance: AIModel,
result: LLMResult,
model: str,
credentials: dict,
prompt_messages: Sequence[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: Sequence[str] | None = None,
stream: bool = True,
user: str | None = None,
):
"""
After invoke callback
:param llm_instance: LLM instance
:param result: result
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
"""
raise NotImplementedError()
@abstractmethod
def on_invoke_error(
self,
llm_instance: AIModel,
ex: Exception,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: Sequence[str] | None = None,
stream: bool = True,
user: str | None = None,
):
"""
Invoke error callback
:param llm_instance: LLM instance
:param ex: exception
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
"""
raise NotImplementedError()
def print_text(self, text: str, color: str | None = None, end: str = ""):
"""Print text with highlighting and no end characters."""
text_to_print = self._get_colored_text(text, color) if color else text
print(text_to_print, end=end)
def _get_colored_text(self, text: str, color: str) -> str:
"""Get colored text."""
color_str = _TEXT_COLOR_MAPPING[color]
return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m"

View File

@ -0,0 +1,170 @@
import json
import logging
import sys
from collections.abc import Sequence
from typing import cast
from dify_graph.model_runtime.callbacks.base_callback import Callback
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel
logger = logging.getLogger(__name__)
class LoggingCallback(Callback):
def on_before_invoke(
self,
llm_instance: AIModel,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: Sequence[str] | None = None,
stream: bool = True,
user: str | None = None,
):
"""
Before invoke callback
:param llm_instance: LLM instance
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
"""
self.print_text("\n[on_llm_before_invoke]\n", color="blue")
self.print_text(f"Model: {model}\n", color="blue")
self.print_text("Parameters:\n", color="blue")
for key, value in model_parameters.items():
self.print_text(f"\t{key}: {value}\n", color="blue")
if stop:
self.print_text(f"\tstop: {stop}\n", color="blue")
if tools:
self.print_text("\tTools:\n", color="blue")
for tool in tools:
self.print_text(f"\t\t{tool.name}\n", color="blue")
self.print_text(f"Stream: {stream}\n", color="blue")
if user:
self.print_text(f"User: {user}\n", color="blue")
self.print_text("Prompt messages:\n", color="blue")
for prompt_message in prompt_messages:
if prompt_message.name:
self.print_text(f"\tname: {prompt_message.name}\n", color="blue")
self.print_text(f"\trole: {prompt_message.role.value}\n", color="blue")
self.print_text(f"\tcontent: {prompt_message.content}\n", color="blue")
if stream:
self.print_text("\n[on_llm_new_chunk]")
def on_new_chunk(
self,
llm_instance: AIModel,
chunk: LLMResultChunk,
model: str,
credentials: dict,
prompt_messages: Sequence[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: Sequence[str] | None = None,
stream: bool = True,
user: str | None = None,
):
"""
On new chunk callback
:param llm_instance: LLM instance
:param chunk: chunk
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
"""
sys.stdout.write(cast(str, chunk.delta.message.content))
sys.stdout.flush()
def on_after_invoke(
self,
llm_instance: AIModel,
result: LLMResult,
model: str,
credentials: dict,
prompt_messages: Sequence[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: Sequence[str] | None = None,
stream: bool = True,
user: str | None = None,
):
"""
After invoke callback
:param llm_instance: LLM instance
:param result: result
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
"""
self.print_text("\n[on_llm_after_invoke]\n", color="yellow")
self.print_text(f"Content: {result.message.content}\n", color="yellow")
if result.message.tool_calls:
self.print_text("Tool calls:\n", color="yellow")
for tool_call in result.message.tool_calls:
self.print_text(f"\t{tool_call.id}\n", color="yellow")
self.print_text(f"\t{tool_call.function.name}\n", color="yellow")
self.print_text(f"\t{json.dumps(tool_call.function.arguments)}\n", color="yellow")
self.print_text(f"Model: {result.model}\n", color="yellow")
self.print_text(f"Usage: {result.usage}\n", color="yellow")
self.print_text(f"System Fingerprint: {result.system_fingerprint}\n", color="yellow")
def on_invoke_error(
self,
llm_instance: AIModel,
ex: Exception,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: Sequence[str] | None = None,
stream: bool = True,
user: str | None = None,
):
"""
Invoke error callback
:param llm_instance: LLM instance
:param ex: exception
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
"""
self.print_text("\n[on_llm_invoke_error]\n", color="red")
logger.exception(ex)

View File

@ -0,0 +1,43 @@
from .llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from .message_entities import (
AssistantPromptMessage,
AudioPromptMessageContent,
DocumentPromptMessageContent,
ImagePromptMessageContent,
MultiModalPromptMessageContent,
PromptMessage,
PromptMessageContent,
PromptMessageContentType,
PromptMessageRole,
PromptMessageTool,
SystemPromptMessage,
TextPromptMessageContent,
ToolPromptMessage,
UserPromptMessage,
VideoPromptMessageContent,
)
from .model_entities import ModelPropertyKey
__all__ = [
"AssistantPromptMessage",
"AudioPromptMessageContent",
"DocumentPromptMessageContent",
"ImagePromptMessageContent",
"LLMMode",
"LLMResult",
"LLMResultChunk",
"LLMResultChunkDelta",
"LLMUsage",
"ModelPropertyKey",
"MultiModalPromptMessageContent",
"PromptMessage",
"PromptMessageContent",
"PromptMessageContentType",
"PromptMessageRole",
"PromptMessageTool",
"SystemPromptMessage",
"TextPromptMessageContent",
"ToolPromptMessage",
"UserPromptMessage",
"VideoPromptMessageContent",
]

View File

@ -0,0 +1,16 @@
from pydantic import BaseModel, model_validator
class I18nObject(BaseModel):
"""
Model class for i18n object.
"""
zh_Hans: str | None = None
en_US: str
@model_validator(mode="after")
def _(self):
if not self.zh_Hans:
self.zh_Hans = self.en_US
return self
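# Usage sketch (illustrative, not part of this module): the validator above falls back
# to the English text when no Chinese label is provided.
#
#   I18nObject(en_US="Temperature").zh_Hans                   # -> "Temperature"
#   I18nObject(en_US="Temperature", zh_Hans="温度").zh_Hans    # -> "温度"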

View File

@ -0,0 +1,130 @@
from dify_graph.model_runtime.entities.model_entities import DefaultParameterName
PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
DefaultParameterName.TEMPERATURE: {
"label": {
"en_US": "Temperature",
"zh_Hans": "温度",
},
"type": "float",
"help": {
"en_US": "Controls randomness. Lower temperature results in less random completions."
" As the temperature approaches zero, the model will become deterministic and repetitive."
" Higher temperature results in more random completions.",
"zh_Hans": "温度控制随机性。较低的温度会导致较少的随机完成。随着温度接近零,模型将变得确定性和重复性。"
"较高的温度会导致更多的随机完成。",
},
"required": False,
"default": 0.0,
"min": 0.0,
"max": 1.0,
"precision": 2,
},
DefaultParameterName.TOP_P: {
"label": {
"en_US": "Top P",
"zh_Hans": "Top P",
},
"type": "float",
"help": {
"en_US": "Controls diversity via nucleus sampling: 0.5 means half of all likelihood-weighted options"
" are considered.",
"zh_Hans": "通过核心采样控制多样性0.5 表示考虑了一半的所有可能性加权选项。",
},
"required": False,
"default": 1.0,
"min": 0.0,
"max": 1.0,
"precision": 2,
},
DefaultParameterName.TOP_K: {
"label": {
"en_US": "Top K",
"zh_Hans": "Top K",
},
"type": "int",
"help": {
"en_US": "Limits the number of tokens to consider for each step by keeping only the k most likely tokens.",
"zh_Hans": "通过只保留每一步中最可能的 k 个标记来限制要考虑的标记数量。",
},
"required": False,
"default": 50,
"min": 1,
"max": 100,
"precision": 0,
},
DefaultParameterName.PRESENCE_PENALTY: {
"label": {
"en_US": "Presence Penalty",
"zh_Hans": "存在惩罚",
},
"type": "float",
"help": {
"en_US": "Applies a penalty to the log-probability of tokens already in the text.",
"zh_Hans": "对文本中已有的标记的对数概率施加惩罚。",
},
"required": False,
"default": 0.0,
"min": 0.0,
"max": 1.0,
"precision": 2,
},
DefaultParameterName.FREQUENCY_PENALTY: {
"label": {
"en_US": "Frequency Penalty",
"zh_Hans": "频率惩罚",
},
"type": "float",
"help": {
"en_US": "Applies a penalty to the log-probability of tokens that appear in the text.",
"zh_Hans": "对文本中出现的标记的对数概率施加惩罚。",
},
"required": False,
"default": 0.0,
"min": 0.0,
"max": 1.0,
"precision": 2,
},
DefaultParameterName.MAX_TOKENS: {
"label": {
"en_US": "Max Tokens",
"zh_Hans": "最大 Token 数",
},
"type": "int",
"help": {
"en_US": "Specifies the upper limit on the length of generated results."
" If the generated results are truncated, you can increase this parameter.",
"zh_Hans": "指定生成结果长度的上限。如果生成结果截断,可以调大该参数。",
},
"required": False,
"default": 64,
"min": 1,
"max": 2048,
"precision": 0,
},
DefaultParameterName.RESPONSE_FORMAT: {
"label": {
"en_US": "Response Format",
"zh_Hans": "回复格式",
},
"type": "string",
"help": {
"en_US": "Set a response format, ensure the output from llm is a valid code block as possible,"
" such as JSON, XML, etc.",
"zh_Hans": "设置一个返回格式,确保 llm 的输出尽可能是有效的代码块,如 JSON、XML 等",
},
"required": False,
"options": ["JSON", "XML"],
},
DefaultParameterName.JSON_SCHEMA: {
"label": {
"en_US": "JSON Schema",
},
"type": "text",
"help": {
"en_US": "Set a response json schema will ensure LLM to adhere it.",
"zh_Hans": "设置返回的 json schemallm 将按照它返回",
},
"required": False,
},
}
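# Illustrative lookup (not part of this module): AIModel._get_default_parameter_rule_variable_map
# reads this template to fill in parameter rules that reference a template by name, e.g.
#
#   rule = PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE]
#   rule["min"], rule["max"], rule["default"]  # -> (0.0, 1.0, 0.0)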

View File

@ -0,0 +1,219 @@
from __future__ import annotations
from collections.abc import Mapping, Sequence
from decimal import Decimal
from enum import StrEnum
from typing import Any, TypedDict, Union
from pydantic import BaseModel, Field
from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage
from dify_graph.model_runtime.entities.model_entities import ModelUsage, PriceInfo
class LLMMode(StrEnum):
"""
Enum class for large language model mode.
"""
COMPLETION = "completion"
CHAT = "chat"
class LLMUsageMetadata(TypedDict, total=False):
"""
TypedDict for LLM usage metadata.
All fields are optional.
"""
prompt_tokens: int
completion_tokens: int
total_tokens: int
prompt_unit_price: Union[float, str]
completion_unit_price: Union[float, str]
total_price: Union[float, str]
currency: str
prompt_price_unit: Union[float, str]
completion_price_unit: Union[float, str]
prompt_price: Union[float, str]
completion_price: Union[float, str]
latency: float
time_to_first_token: float
time_to_generate: float
class LLMUsage(ModelUsage):
"""
Model class for llm usage.
"""
prompt_tokens: int
prompt_unit_price: Decimal
prompt_price_unit: Decimal
prompt_price: Decimal
completion_tokens: int
completion_unit_price: Decimal
completion_price_unit: Decimal
completion_price: Decimal
total_tokens: int
total_price: Decimal
currency: str
latency: float
time_to_first_token: float | None = None
time_to_generate: float | None = None
@classmethod
def empty_usage(cls):
return cls(
prompt_tokens=0,
prompt_unit_price=Decimal("0.0"),
prompt_price_unit=Decimal("0.0"),
prompt_price=Decimal("0.0"),
completion_tokens=0,
completion_unit_price=Decimal("0.0"),
completion_price_unit=Decimal("0.0"),
completion_price=Decimal("0.0"),
total_tokens=0,
total_price=Decimal("0.0"),
currency="USD",
latency=0.0,
time_to_first_token=None,
time_to_generate=None,
)
@classmethod
def from_metadata(cls, metadata: LLMUsageMetadata) -> LLMUsage:
"""
Create LLMUsage instance from metadata dictionary with default values.
Args:
metadata: TypedDict containing usage metadata
Returns:
LLMUsage instance with values from metadata or defaults
"""
prompt_tokens = metadata.get("prompt_tokens", 0)
completion_tokens = metadata.get("completion_tokens", 0)
total_tokens = metadata.get("total_tokens", 0)
# If total_tokens is not provided but prompt and completion tokens are,
# calculate total_tokens
if total_tokens == 0 and (prompt_tokens > 0 or completion_tokens > 0):
total_tokens = prompt_tokens + completion_tokens
return cls(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
prompt_unit_price=Decimal(str(metadata.get("prompt_unit_price", 0))),
completion_unit_price=Decimal(str(metadata.get("completion_unit_price", 0))),
total_price=Decimal(str(metadata.get("total_price", 0))),
currency=metadata.get("currency", "USD"),
prompt_price_unit=Decimal(str(metadata.get("prompt_price_unit", 0))),
completion_price_unit=Decimal(str(metadata.get("completion_price_unit", 0))),
prompt_price=Decimal(str(metadata.get("prompt_price", 0))),
completion_price=Decimal(str(metadata.get("completion_price", 0))),
latency=metadata.get("latency", 0.0),
time_to_first_token=metadata.get("time_to_first_token"),
time_to_generate=metadata.get("time_to_generate"),
)
def plus(self, other: LLMUsage) -> LLMUsage:
"""
Add two LLMUsage instances together.
:param other: Another LLMUsage instance to add
:return: A new LLMUsage instance with summed values
"""
if self.total_tokens == 0:
return other
else:
return LLMUsage(
prompt_tokens=self.prompt_tokens + other.prompt_tokens,
prompt_unit_price=other.prompt_unit_price,
prompt_price_unit=other.prompt_price_unit,
prompt_price=self.prompt_price + other.prompt_price,
completion_tokens=self.completion_tokens + other.completion_tokens,
completion_unit_price=other.completion_unit_price,
completion_price_unit=other.completion_price_unit,
completion_price=self.completion_price + other.completion_price,
total_tokens=self.total_tokens + other.total_tokens,
total_price=self.total_price + other.total_price,
currency=other.currency,
latency=self.latency + other.latency,
time_to_first_token=other.time_to_first_token,
time_to_generate=other.time_to_generate,
)
def __add__(self, other: LLMUsage) -> LLMUsage:
"""
Overload the + operator to add two LLMUsage instances.
:param other: Another LLMUsage instance to add
:return: A new LLMUsage instance with summed values
"""
return self.plus(other)
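# Usage sketch (illustrative): build usage from provider metadata and accumulate it
# across calls with `plus` / `+`.
#
#   usage = LLMUsage.from_metadata({"prompt_tokens": 10, "completion_tokens": 5})
#   usage.total_tokens                       # -> 15 (derived when not supplied)
#   total = LLMUsage.empty_usage() + usage   # an empty usage is replaced by the other operand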
class LLMResult(BaseModel):
"""
Model class for llm result.
"""
id: str | None = None
model: str
prompt_messages: Sequence[PromptMessage] = Field(default_factory=list)
message: AssistantPromptMessage
usage: LLMUsage
system_fingerprint: str | None = None
reasoning_content: str | None = None
class LLMStructuredOutput(BaseModel):
"""
Model class for llm structured output.
"""
structured_output: Mapping[str, Any] | None = None
class LLMResultWithStructuredOutput(LLMResult, LLMStructuredOutput):
"""
Model class for llm result with structured output.
"""
class LLMResultChunkDelta(BaseModel):
"""
Model class for llm result chunk delta.
"""
index: int
message: AssistantPromptMessage
usage: LLMUsage | None = None
finish_reason: str | None = None
class LLMResultChunk(BaseModel):
"""
Model class for llm result chunk.
"""
model: str
prompt_messages: Sequence[PromptMessage] = Field(default_factory=list)
system_fingerprint: str | None = None
delta: LLMResultChunkDelta
class LLMResultChunkWithStructuredOutput(LLMResultChunk, LLMStructuredOutput):
"""
Model class for llm result chunk with structured output.
"""
class NumTokensResult(PriceInfo):
"""
Model class for number of tokens result.
"""
tokens: int

View File

@ -0,0 +1,282 @@
from __future__ import annotations
from abc import ABC
from collections.abc import Mapping, Sequence
from enum import StrEnum, auto
from typing import Annotated, Any, Literal, Union
from pydantic import BaseModel, Field, field_serializer, field_validator
class PromptMessageRole(StrEnum):
"""
Enum class for prompt message role.
"""
SYSTEM = auto()
USER = auto()
ASSISTANT = auto()
TOOL = auto()
@classmethod
def value_of(cls, value: str) -> PromptMessageRole:
"""
Get value of given mode.
:param value: mode value
:return: mode
"""
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f"invalid prompt message type value {value}")
class PromptMessageTool(BaseModel):
"""
Model class for prompt message tool.
"""
name: str
description: str
parameters: dict
class PromptMessageFunction(BaseModel):
"""
Model class for prompt message function.
"""
type: str = "function"
function: PromptMessageTool
class PromptMessageContentType(StrEnum):
"""
Enum class for prompt message content type.
"""
TEXT = auto()
IMAGE = auto()
AUDIO = auto()
VIDEO = auto()
DOCUMENT = auto()
class PromptMessageContent(ABC, BaseModel):
"""
Model class for prompt message content.
"""
type: PromptMessageContentType
class TextPromptMessageContent(PromptMessageContent):
"""
Model class for text prompt message content.
"""
type: Literal[PromptMessageContentType.TEXT] = PromptMessageContentType.TEXT # type: ignore
data: str
class MultiModalPromptMessageContent(PromptMessageContent):
"""
Model class for multi-modal prompt message content.
"""
format: str = Field(default=..., description="the format of multi-modal file")
base64_data: str = Field(default="", description="the base64 data of multi-modal file")
url: str = Field(default="", description="the url of multi-modal file")
mime_type: str = Field(default=..., description="the mime type of multi-modal file")
filename: str = Field(default="", description="the filename of multi-modal file")
@property
def data(self):
return self.url or f"data:{self.mime_type};base64,{self.base64_data}"
class VideoPromptMessageContent(MultiModalPromptMessageContent):
type: Literal[PromptMessageContentType.VIDEO] = PromptMessageContentType.VIDEO # type: ignore
class AudioPromptMessageContent(MultiModalPromptMessageContent):
type: Literal[PromptMessageContentType.AUDIO] = PromptMessageContentType.AUDIO # type: ignore
class ImagePromptMessageContent(MultiModalPromptMessageContent):
"""
Model class for image prompt message content.
"""
class DETAIL(StrEnum):
LOW = auto()
HIGH = auto()
type: Literal[PromptMessageContentType.IMAGE] = PromptMessageContentType.IMAGE # type: ignore
detail: DETAIL = DETAIL.LOW
class DocumentPromptMessageContent(MultiModalPromptMessageContent):
type: Literal[PromptMessageContentType.DOCUMENT] = PromptMessageContentType.DOCUMENT # type: ignore
PromptMessageContentUnionTypes = Annotated[
Union[
TextPromptMessageContent,
ImagePromptMessageContent,
DocumentPromptMessageContent,
AudioPromptMessageContent,
VideoPromptMessageContent,
],
Field(discriminator="type"),
]
CONTENT_TYPE_MAPPING: Mapping[PromptMessageContentType, type[PromptMessageContent]] = {
PromptMessageContentType.TEXT: TextPromptMessageContent,
PromptMessageContentType.IMAGE: ImagePromptMessageContent,
PromptMessageContentType.AUDIO: AudioPromptMessageContent,
PromptMessageContentType.VIDEO: VideoPromptMessageContent,
PromptMessageContentType.DOCUMENT: DocumentPromptMessageContent,
}
class PromptMessage(ABC, BaseModel):
"""
Model class for prompt message.
"""
role: PromptMessageRole
content: str | list[PromptMessageContentUnionTypes] | None = None
name: str | None = None
def is_empty(self) -> bool:
"""
Check if prompt message is empty.
:return: True if prompt message is empty, False otherwise
"""
return not self.content
def get_text_content(self) -> str:
"""
Get text content from prompt message.
:return: Text content as string, empty string if no text content
"""
if isinstance(self.content, str):
return self.content
elif isinstance(self.content, list):
text_parts = []
for item in self.content:
if isinstance(item, TextPromptMessageContent):
text_parts.append(item.data)
return "".join(text_parts)
else:
return ""
@field_validator("content", mode="before")
@classmethod
def validate_content(cls, v):
if isinstance(v, list):
prompts = []
for prompt in v:
if isinstance(prompt, PromptMessageContent):
if not isinstance(prompt, TextPromptMessageContent | MultiModalPromptMessageContent):
prompt = CONTENT_TYPE_MAPPING[prompt.type].model_validate(prompt.model_dump())
elif isinstance(prompt, dict):
prompt = CONTENT_TYPE_MAPPING[prompt["type"]].model_validate(prompt)
else:
raise ValueError(f"invalid prompt message {prompt}")
prompts.append(prompt)
return prompts
return v
@field_serializer("content")
def serialize_content(
self, content: Union[str, Sequence[PromptMessageContent]] | None
) -> str | list[dict[str, Any] | PromptMessageContent] | Sequence[PromptMessageContent] | None:
if content is None or isinstance(content, str):
return content
if isinstance(content, list):
return [item.model_dump() if hasattr(item, "model_dump") else item for item in content]
return content
class UserPromptMessage(PromptMessage):
"""
Model class for user prompt message.
"""
role: PromptMessageRole = PromptMessageRole.USER
class AssistantPromptMessage(PromptMessage):
"""
Model class for assistant prompt message.
"""
class ToolCall(BaseModel):
"""
Model class for assistant prompt message tool call.
"""
class ToolCallFunction(BaseModel):
"""
Model class for assistant prompt message tool call function.
"""
name: str
arguments: str
id: str
type: str
function: ToolCallFunction
@field_validator("id", mode="before")
@classmethod
def transform_id_to_str(cls, value) -> str:
if not isinstance(value, str):
return str(value)
else:
return value
role: PromptMessageRole = PromptMessageRole.ASSISTANT
tool_calls: list[ToolCall] = []
def is_empty(self) -> bool:
"""
Check if prompt message is empty.
:return: True if prompt message is empty, False otherwise
"""
return super().is_empty() and not self.tool_calls
class SystemPromptMessage(PromptMessage):
"""
Model class for system prompt message.
"""
role: PromptMessageRole = PromptMessageRole.SYSTEM
class ToolPromptMessage(PromptMessage):
"""
Model class for tool prompt message.
"""
role: PromptMessageRole = PromptMessageRole.TOOL
tool_call_id: str
def is_empty(self) -> bool:
"""
Check if prompt message is empty.
:return: True if prompt message is empty, False otherwise
"""
if not super().is_empty() and not self.tool_call_id:
return False
return True
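# Usage sketch (illustrative): `validate_content` coerces plain dicts into the typed content
# classes via CONTENT_TYPE_MAPPING, and `get_text_content` keeps only the text parts.
#
#   message = UserPromptMessage(content=[
#       {"type": "text", "data": "Describe this image."},
#       {"type": "image", "format": "png", "mime_type": "image/png", "base64_data": "..."},
#   ])
#   isinstance(message.content[0], TextPromptMessageContent)  # -> True
#   message.get_text_content()                                # -> "Describe this image."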

View File

@ -0,0 +1,242 @@
from __future__ import annotations
from decimal import Decimal
from enum import StrEnum, auto
from typing import Any
from pydantic import BaseModel, ConfigDict, model_validator
from dify_graph.model_runtime.entities.common_entities import I18nObject
class ModelType(StrEnum):
"""
Enum class for model type.
"""
LLM = auto()
TEXT_EMBEDDING = "text-embedding"
RERANK = auto()
SPEECH2TEXT = auto()
MODERATION = auto()
TTS = auto()
@classmethod
def value_of(cls, origin_model_type: str) -> ModelType:
"""
Get model type from origin model type.
:return: model type
"""
if origin_model_type in {"text-generation", cls.LLM}:
return cls.LLM
elif origin_model_type in {"embeddings", cls.TEXT_EMBEDDING}:
return cls.TEXT_EMBEDDING
elif origin_model_type in {"reranking", cls.RERANK}:
return cls.RERANK
elif origin_model_type in {"speech2text", cls.SPEECH2TEXT}:
return cls.SPEECH2TEXT
elif origin_model_type in {"tts", cls.TTS}:
return cls.TTS
elif origin_model_type == cls.MODERATION:
return cls.MODERATION
else:
raise ValueError(f"invalid origin model type {origin_model_type}")
def to_origin_model_type(self) -> str:
"""
Get origin model type from model type.
:return: origin model type
"""
if self == self.LLM:
return "text-generation"
elif self == self.TEXT_EMBEDDING:
return "embeddings"
elif self == self.RERANK:
return "reranking"
elif self == self.SPEECH2TEXT:
return "speech2text"
elif self == self.TTS:
return "tts"
elif self == self.MODERATION:
return "moderation"
else:
raise ValueError(f"invalid model type {self}")
class FetchFrom(StrEnum):
"""
Enum class for fetch from.
"""
PREDEFINED_MODEL = "predefined-model"
CUSTOMIZABLE_MODEL = "customizable-model"
class ModelFeature(StrEnum):
"""
Enum class for llm feature.
"""
TOOL_CALL = "tool-call"
MULTI_TOOL_CALL = "multi-tool-call"
AGENT_THOUGHT = "agent-thought"
VISION = auto()
STREAM_TOOL_CALL = "stream-tool-call"
DOCUMENT = auto()
VIDEO = auto()
AUDIO = auto()
STRUCTURED_OUTPUT = "structured-output"
class DefaultParameterName(StrEnum):
"""
Enum class for parameter template variable.
"""
TEMPERATURE = auto()
TOP_P = auto()
TOP_K = auto()
PRESENCE_PENALTY = auto()
FREQUENCY_PENALTY = auto()
MAX_TOKENS = auto()
RESPONSE_FORMAT = auto()
JSON_SCHEMA = auto()
@classmethod
def value_of(cls, value: Any) -> DefaultParameterName:
"""
Get parameter name from value.
:param value: parameter value
:return: parameter name
"""
for name in cls:
if name.value == value:
return name
raise ValueError(f"invalid parameter name {value}")
class ParameterType(StrEnum):
"""
Enum class for parameter type.
"""
FLOAT = auto()
INT = auto()
STRING = auto()
BOOLEAN = auto()
TEXT = auto()
class ModelPropertyKey(StrEnum):
"""
Enum class for model property key.
"""
MODE = auto()
CONTEXT_SIZE = auto()
MAX_CHUNKS = auto()
FILE_UPLOAD_LIMIT = auto()
SUPPORTED_FILE_EXTENSIONS = auto()
MAX_CHARACTERS_PER_CHUNK = auto()
DEFAULT_VOICE = auto()
VOICES = auto()
WORD_LIMIT = auto()
AUDIO_TYPE = auto()
MAX_WORKERS = auto()
class ProviderModel(BaseModel):
"""
Model class for provider model.
"""
model: str
label: I18nObject
model_type: ModelType
features: list[ModelFeature] | None = None
fetch_from: FetchFrom
model_properties: dict[ModelPropertyKey, Any]
deprecated: bool = False
model_config = ConfigDict(protected_namespaces=())
@property
def support_structure_output(self) -> bool:
return self.features is not None and ModelFeature.STRUCTURED_OUTPUT in self.features
class ParameterRule(BaseModel):
"""
Model class for parameter rule.
"""
name: str
use_template: str | None = None
label: I18nObject
type: ParameterType
help: I18nObject | None = None
required: bool = False
default: Any | None = None
min: float | None = None
max: float | None = None
precision: int | None = None
options: list[str] = []
class PriceConfig(BaseModel):
"""
Model class for pricing info.
"""
input: Decimal
output: Decimal | None = None
unit: Decimal
currency: str
class AIModelEntity(ProviderModel):
"""
Model class for AI model.
"""
parameter_rules: list[ParameterRule] = []
pricing: PriceConfig | None = None
@model_validator(mode="after")
def validate_model(self):
supported_schema_keys = ["json_schema"]
schema_key = next((rule.name for rule in self.parameter_rules if rule.name in supported_schema_keys), None)
if not schema_key:
return self
if self.features is None:
self.features = [ModelFeature.STRUCTURED_OUTPUT]
else:
if ModelFeature.STRUCTURED_OUTPUT not in self.features:
self.features.append(ModelFeature.STRUCTURED_OUTPUT)
return self
class ModelUsage(BaseModel):
pass
class PriceType(StrEnum):
"""
Enum class for price type.
"""
INPUT = auto()
OUTPUT = auto()
class PriceInfo(BaseModel):
"""
Model class for price info.
"""
unit_price: Decimal
unit: Decimal
total_amount: Decimal
currency: str

View File

@ -0,0 +1,169 @@
from collections.abc import Sequence
from enum import StrEnum, auto
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from dify_graph.model_runtime.entities.common_entities import I18nObject
from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType
class ConfigurateMethod(StrEnum):
"""
Enum class for configurate method of provider model.
"""
PREDEFINED_MODEL = "predefined-model"
CUSTOMIZABLE_MODEL = "customizable-model"
class FormType(StrEnum):
"""
Enum class for form type.
"""
TEXT_INPUT = "text-input"
SECRET_INPUT = "secret-input"
SELECT = auto()
RADIO = auto()
SWITCH = auto()
class FormShowOnObject(BaseModel):
"""
Model class for form show on.
"""
variable: str
value: str
class FormOption(BaseModel):
"""
Model class for form option.
"""
label: I18nObject
value: str
show_on: list[FormShowOnObject] = []
@model_validator(mode="after")
def _(self):
if not self.label:
self.label = I18nObject(en_US=self.value)
return self
class CredentialFormSchema(BaseModel):
"""
Model class for credential form schema.
"""
variable: str
label: I18nObject
type: FormType
required: bool = True
default: str | None = None
options: list[FormOption] | None = None
placeholder: I18nObject | None = None
max_length: int = 0
show_on: list[FormShowOnObject] = []
class ProviderCredentialSchema(BaseModel):
"""
Model class for provider credential schema.
"""
credential_form_schemas: list[CredentialFormSchema]
class FieldModelSchema(BaseModel):
label: I18nObject
placeholder: I18nObject | None = None
class ModelCredentialSchema(BaseModel):
"""
Model class for model credential schema.
"""
model: FieldModelSchema
credential_form_schemas: list[CredentialFormSchema]
class SimpleProviderEntity(BaseModel):
"""
Simple model class for provider.
"""
provider: str
label: I18nObject
icon_small: I18nObject | None = None
icon_small_dark: I18nObject | None = None
supported_model_types: Sequence[ModelType]
models: list[AIModelEntity] = []
class ProviderHelpEntity(BaseModel):
"""
Model class for provider help.
"""
title: I18nObject
url: I18nObject
class ProviderEntity(BaseModel):
"""
Model class for provider.
"""
provider: str
label: I18nObject
description: I18nObject | None = None
icon_small: I18nObject | None = None
icon_small_dark: I18nObject | None = None
background: str | None = None
help: ProviderHelpEntity | None = None
supported_model_types: Sequence[ModelType]
configurate_methods: list[ConfigurateMethod]
models: list[AIModelEntity] = Field(default_factory=list)
provider_credential_schema: ProviderCredentialSchema | None = None
model_credential_schema: ModelCredentialSchema | None = None
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
# position from plugin _position.yaml
position: dict[str, list[str]] | None = {}
@field_validator("models", mode="before")
@classmethod
def validate_models(cls, v):
# return an empty list if v is falsy
if not v:
return []
return v
def to_simple_provider(self) -> SimpleProviderEntity:
"""
Convert to simple provider.
:return: simple provider
"""
return SimpleProviderEntity(
provider=self.provider,
label=self.label,
icon_small=self.icon_small,
supported_model_types=self.supported_model_types,
models=self.models,
)
class ProviderConfig(BaseModel):
"""
Model class for provider config.
"""
provider: str
credentials: dict

View File

@ -0,0 +1,20 @@
from pydantic import BaseModel
class RerankDocument(BaseModel):
"""
Model class for rerank document.
"""
index: int
text: str
score: float
class RerankResult(BaseModel):
"""
Model class for rerank result.
"""
model: str
docs: list[RerankDocument]

View File

@ -0,0 +1,39 @@
from decimal import Decimal
from pydantic import BaseModel
from dify_graph.model_runtime.entities.model_entities import ModelUsage
class EmbeddingUsage(ModelUsage):
"""
Model class for embedding usage.
"""
tokens: int
total_tokens: int
unit_price: Decimal
price_unit: Decimal
total_price: Decimal
currency: str
latency: float
class EmbeddingResult(BaseModel):
"""
Model class for text embedding result.
"""
model: str
embeddings: list[list[float]]
usage: EmbeddingUsage
class FileEmbeddingResult(BaseModel):
"""
Model class for file embedding result.
"""
model: str
embeddings: list[list[float]]
usage: EmbeddingUsage

View File

@ -0,0 +1,40 @@
class InvokeError(ValueError):
"""Base class for all LLM exceptions."""
description: str | None = None
def __init__(self, description: str | None = None):
self.description = description
def __str__(self):
return self.description or self.__class__.__name__
class InvokeConnectionError(InvokeError):
"""Raised when the Invoke returns connection error."""
description = "Connection Error"
class InvokeServerUnavailableError(InvokeError):
"""Raised when the Invoke returns server unavailable error."""
description = "Server Unavailable Error"
class InvokeRateLimitError(InvokeError):
"""Raised when the Invoke returns rate limit error."""
description = "Rate Limit Error"
class InvokeAuthorizationError(InvokeError):
"""Raised when the Invoke returns authorization error."""
description = "Incorrect model credentials provided, please check and try again. "
class InvokeBadRequestError(InvokeError):
"""Raised when the Invoke returns bad request."""
description = "Bad Request Error"

View File

@ -0,0 +1,6 @@
class CredentialsValidateFailedError(ValueError):
"""
Credentials validate failed error
"""
pass

View File

@ -0,0 +1,3 @@
from .prompt_message_memory import DEFAULT_MEMORY_MAX_TOKEN_LIMIT, PromptMessageMemory
__all__ = ["DEFAULT_MEMORY_MAX_TOKEN_LIMIT", "PromptMessageMemory"]

View File

@ -0,0 +1,18 @@
from __future__ import annotations
from collections.abc import Sequence
from typing import Protocol
from dify_graph.model_runtime.entities import PromptMessage
DEFAULT_MEMORY_MAX_TOKEN_LIMIT = 2000
class PromptMessageMemory(Protocol):
"""Port for loading memory as prompt messages."""
def get_history_prompt_messages(
self, max_token_limit: int = DEFAULT_MEMORY_MAX_TOKEN_LIMIT, message_limit: int | None = None
) -> Sequence[PromptMessage]:
"""Return historical prompt messages constrained by token/message limits."""
...
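# Implementation sketch (illustrative): any object with a matching method satisfies this
# protocol, for example a fixed in-memory history.
#
#   class StaticMemory:
#       def __init__(self, messages: Sequence[PromptMessage]):
#           self._messages = list(messages)
#
#       def get_history_prompt_messages(
#           self, max_token_limit: int = DEFAULT_MEMORY_MAX_TOKEN_LIMIT, message_limit: int | None = None
#       ) -> Sequence[PromptMessage]:
#           return self._messages if message_limit is None else self._messages[-message_limit:]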

View File

@ -0,0 +1,286 @@
import decimal
import hashlib
import logging
from pydantic import BaseModel, ConfigDict, Field, ValidationError
from redis import RedisError
from configs import dify_config
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from dify_graph.model_runtime.entities.common_entities import I18nObject
from dify_graph.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
from dify_graph.model_runtime.entities.model_entities import (
AIModelEntity,
DefaultParameterName,
ModelType,
PriceConfig,
PriceInfo,
PriceType,
)
from dify_graph.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from extensions.ext_redis import redis_client
logger = logging.getLogger(__name__)
class AIModel(BaseModel):
"""
Base class for all models.
"""
tenant_id: str = Field(description="Tenant ID")
model_type: ModelType = Field(description="Model type")
plugin_id: str = Field(description="Plugin ID")
provider_name: str = Field(description="Provider")
plugin_model_provider: PluginModelProviderEntity = Field(description="Plugin model provider")
started_at: float = Field(description="Invoke start time", default=0)
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
@property
def _invoke_error_mapping(self) -> dict[type[Exception], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
from core.plugin.entities.plugin_daemon import PluginDaemonInnerError
return {
InvokeConnectionError: [InvokeConnectionError],
InvokeServerUnavailableError: [InvokeServerUnavailableError],
InvokeRateLimitError: [InvokeRateLimitError],
InvokeAuthorizationError: [InvokeAuthorizationError],
InvokeBadRequestError: [InvokeBadRequestError],
PluginDaemonInnerError: [PluginDaemonInnerError],
ValueError: [ValueError],
}
def _transform_invoke_error(self, error: Exception) -> Exception:
"""
Transform invoke error to unified error
:param error: model invoke error
:return: unified error
"""
for invoke_error, model_errors in self._invoke_error_mapping.items():
if isinstance(error, tuple(model_errors)):
if invoke_error == InvokeAuthorizationError:
return InvokeAuthorizationError(
description=(
f"[{self.provider_name}] Incorrect model credentials provided, please check and try again."
)
)
elif issubclass(invoke_error, InvokeError):
return InvokeError(description=f"[{self.provider_name}] {invoke_error.description}, {str(error)}")
else:
return error
return InvokeError(description=f"[{self.provider_name}] Error: {str(error)}")
def get_price(self, model: str, credentials: dict, price_type: PriceType, tokens: int) -> PriceInfo:
"""
Get price for given model and tokens
:param model: model name
:param credentials: model credentials
:param price_type: price type
:param tokens: number of tokens
:return: price info
"""
# get model schema
model_schema = self.get_model_schema(model, credentials)
# get price info from predefined model schema
price_config: PriceConfig | None = None
if model_schema and model_schema.pricing:
price_config = model_schema.pricing
# get unit price
unit_price = None
if price_config:
if price_type == PriceType.INPUT:
unit_price = price_config.input
elif price_type == PriceType.OUTPUT and price_config.output is not None:
unit_price = price_config.output
if unit_price is None:
return PriceInfo(
unit_price=decimal.Decimal("0.0"),
unit=decimal.Decimal("0.0"),
total_amount=decimal.Decimal("0.0"),
currency="USD",
)
# calculate total amount
if not price_config:
raise ValueError(f"Price config not found for model {model}")
total_amount = tokens * unit_price * price_config.unit
total_amount = total_amount.quantize(decimal.Decimal("0.0000001"), rounding=decimal.ROUND_HALF_UP)
return PriceInfo(
unit_price=unit_price,
unit=price_config.unit,
total_amount=total_amount,
currency=price_config.currency,
)
def get_model_schema(self, model: str, credentials: dict | None = None) -> AIModelEntity | None:
"""
Get model schema by model name and credentials
:param model: model name
:param credentials: model credentials
:return: model schema
"""
from core.plugin.impl.model import PluginModelClient
plugin_model_manager = PluginModelClient()
cache_key = f"{self.tenant_id}:{self.plugin_id}:{self.provider_name}:{self.model_type.value}:{model}"
sorted_credentials = sorted(credentials.items()) if credentials else []
cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials])
cached_schema_json = None
try:
cached_schema_json = redis_client.get(cache_key)
except (RedisError, RuntimeError) as exc:
logger.warning(
"Failed to read plugin model schema cache for model %s: %s",
model,
str(exc),
exc_info=True,
)
if cached_schema_json:
try:
return AIModelEntity.model_validate_json(cached_schema_json)
except ValidationError:
logger.warning(
"Failed to validate cached plugin model schema for model %s",
model,
exc_info=True,
)
try:
redis_client.delete(cache_key)
except (RedisError, RuntimeError) as exc:
logger.warning(
"Failed to delete invalid plugin model schema cache for model %s: %s",
model,
str(exc),
exc_info=True,
)
schema = plugin_model_manager.get_model_schema(
tenant_id=self.tenant_id,
user_id="unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
model_type=self.model_type.value,
model=model,
credentials=credentials or {},
)
if schema:
try:
redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_SCHEMA_CACHE_TTL, schema.model_dump_json())
except (RedisError, RuntimeError) as exc:
logger.warning(
"Failed to write plugin model schema cache for model %s: %s",
model,
str(exc),
exc_info=True,
)
return schema
def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> AIModelEntity | None:
"""
Get customizable model schema from credentials
:param model: model name
:param credentials: model credentials
:return: model schema
"""
# get customizable model schema
schema = self.get_customizable_model_schema(model, credentials)
if not schema:
return None
# fill in the template
new_parameter_rules = []
for parameter_rule in schema.parameter_rules:
if parameter_rule.use_template:
try:
default_parameter_name = DefaultParameterName.value_of(parameter_rule.use_template)
default_parameter_rule = self._get_default_parameter_rule_variable_map(default_parameter_name)
if not parameter_rule.max and "max" in default_parameter_rule:
parameter_rule.max = default_parameter_rule["max"]
if not parameter_rule.min and "min" in default_parameter_rule:
parameter_rule.min = default_parameter_rule["min"]
if not parameter_rule.default and "default" in default_parameter_rule:
parameter_rule.default = default_parameter_rule["default"]
if not parameter_rule.precision and "precision" in default_parameter_rule:
parameter_rule.precision = default_parameter_rule["precision"]
if not parameter_rule.required and "required" in default_parameter_rule:
parameter_rule.required = default_parameter_rule["required"]
if not parameter_rule.help and "help" in default_parameter_rule:
parameter_rule.help = I18nObject(
en_US=default_parameter_rule["help"]["en_US"],
)
if (
parameter_rule.help
and not parameter_rule.help.en_US
and ("help" in default_parameter_rule and "en_US" in default_parameter_rule["help"])
):
parameter_rule.help.en_US = default_parameter_rule["help"]["en_US"]
if (
parameter_rule.help
and not parameter_rule.help.zh_Hans
and ("help" in default_parameter_rule and "zh_Hans" in default_parameter_rule["help"])
):
parameter_rule.help.zh_Hans = default_parameter_rule["help"].get(
"zh_Hans", default_parameter_rule["help"]["en_US"]
)
except ValueError:
pass
new_parameter_rules.append(parameter_rule)
schema.parameter_rules = new_parameter_rules
return schema
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
"""
Get customizable model schema
:param model: model name
:param credentials: model credentials
:return: model schema
"""
return None
def _get_default_parameter_rule_variable_map(self, name: DefaultParameterName):
"""
Get default parameter rule for given name
:param name: parameter name
:return: parameter rule
"""
default_parameter_rule = PARAMETER_RULE_TEMPLATE.get(name)
if not default_parameter_rule:
raise Exception(f"Invalid model parameter rule name {name}")
return default_parameter_rule

View File

@ -0,0 +1,668 @@
import logging
import time
import uuid
from collections.abc import Callable, Generator, Iterator, Sequence
from typing import Union
from pydantic import ConfigDict
from configs import dify_config
from dify_graph.model_runtime.callbacks.base_callback import Callback
from dify_graph.model_runtime.callbacks.logging_callback import LoggingCallback
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMUsage
from dify_graph.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
PromptMessageContentUnionTypes,
PromptMessageTool,
TextPromptMessageContent,
)
from dify_graph.model_runtime.entities.model_entities import (
ModelType,
PriceType,
)
from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel
logger = logging.getLogger(__name__)
def _gen_tool_call_id() -> str:
return f"chatcmpl-tool-{str(uuid.uuid4().hex)}"
def _run_callbacks(callbacks: Sequence[Callback] | None, *, event: str, invoke: Callable[[Callback], None]) -> None:
if not callbacks:
return
for callback in callbacks:
try:
invoke(callback)
except Exception as e:
if callback.raise_error:
raise
logger.warning("Callback %s %s failed with error %s", callback.__class__.__name__, event, e)
def _get_or_create_tool_call(
existing_tools_calls: list[AssistantPromptMessage.ToolCall],
tool_call_id: str,
) -> AssistantPromptMessage.ToolCall:
"""
Get or create a tool call by ID.
If `tool_call_id` is empty, returns the most recently created tool call.
"""
if not tool_call_id:
if not existing_tools_calls:
raise ValueError("tool_call_id is empty but no existing tool call is available to apply the delta")
return existing_tools_calls[-1]
tool_call = next((tool_call for tool_call in existing_tools_calls if tool_call.id == tool_call_id), None)
if tool_call is None:
tool_call = AssistantPromptMessage.ToolCall(
id=tool_call_id,
type="function",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""),
)
existing_tools_calls.append(tool_call)
return tool_call
def _merge_tool_call_delta(
tool_call: AssistantPromptMessage.ToolCall,
delta: AssistantPromptMessage.ToolCall,
) -> None:
if delta.id:
tool_call.id = delta.id
if delta.type:
tool_call.type = delta.type
if delta.function.name:
tool_call.function.name = delta.function.name
if delta.function.arguments:
tool_call.function.arguments += delta.function.arguments
def _build_llm_result_from_chunks(
model: str,
prompt_messages: Sequence[PromptMessage],
chunks: Iterator[LLMResultChunk],
) -> LLMResult:
"""
Build a single `LLMResult` by accumulating all returned chunks.
Some models only support streaming output (e.g. Qwen3 open-source edition)
and the plugin side may still implement the response via a chunked stream,
so all chunks must be consumed and concatenated into a single ``LLMResult``.
The ``usage`` is taken from the last chunk that carries it, which is the
typical convention for streaming responses (the final chunk contains the
aggregated token counts).
"""
content = ""
content_list: list[PromptMessageContentUnionTypes] = []
usage = LLMUsage.empty_usage()
system_fingerprint: str | None = None
tools_calls: list[AssistantPromptMessage.ToolCall] = []
try:
for chunk in chunks:
if isinstance(chunk.delta.message.content, str):
content += chunk.delta.message.content
elif isinstance(chunk.delta.message.content, list):
content_list.extend(chunk.delta.message.content)
if chunk.delta.message.tool_calls:
_increase_tool_call(chunk.delta.message.tool_calls, tools_calls)
if chunk.delta.usage:
usage = chunk.delta.usage
if chunk.system_fingerprint:
system_fingerprint = chunk.system_fingerprint
except Exception:
logger.exception("Error while consuming non-stream plugin chunk iterator.")
raise
finally:
# Close the chunk iterator to release underlying streaming resources (e.g. HTTP connections).
close = getattr(chunks, "close", None)
if callable(close):
close()
return LLMResult(
model=model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(
content=content or content_list,
tool_calls=tools_calls,
),
usage=usage,
system_fingerprint=system_fingerprint,
)
def _invoke_llm_via_plugin(
*,
tenant_id: str,
user_id: str,
plugin_id: str,
provider: str,
model: str,
credentials: dict,
model_parameters: dict,
prompt_messages: Sequence[PromptMessage],
tools: list[PromptMessageTool] | None,
stop: Sequence[str] | None,
stream: bool,
) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]:
from core.plugin.impl.model import PluginModelClient
plugin_model_manager = PluginModelClient()
return plugin_model_manager.invoke_llm(
tenant_id=tenant_id,
user_id=user_id,
plugin_id=plugin_id,
provider=provider,
model=model,
credentials=credentials,
model_parameters=model_parameters,
prompt_messages=list(prompt_messages),
tools=tools,
stop=list(stop) if stop else None,
stream=stream,
)
def _normalize_non_stream_plugin_result(
model: str,
prompt_messages: Sequence[PromptMessage],
result: Union[LLMResult, Iterator[LLMResultChunk]],
) -> LLMResult:
if isinstance(result, LLMResult):
return result
return _build_llm_result_from_chunks(model=model, prompt_messages=prompt_messages, chunks=result)
def _increase_tool_call(
new_tool_calls: list[AssistantPromptMessage.ToolCall], existing_tools_calls: list[AssistantPromptMessage.ToolCall]
):
"""
Merge incremental tool call updates into existing tool calls.
:param new_tool_calls: List of new tool call deltas to be merged.
:param existing_tools_calls: List of existing tool calls to be modified IN-PLACE.
"""
for new_tool_call in new_tool_calls:
# generate ID for tool calls with function name but no ID to track them
if new_tool_call.function.name and not new_tool_call.id:
new_tool_call.id = _gen_tool_call_id()
tool_call = _get_or_create_tool_call(existing_tools_calls, new_tool_call.id)
_merge_tool_call_delta(tool_call, new_tool_call)
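# A minimal sketch of the incremental merge above, using plain dataclasses as hypothetical
# stand-ins for AssistantPromptMessage.ToolCall: scalar fields are overwritten whenever a delta
# carries them, while the function arguments are concatenated as they stream in fragments.
from dataclasses import dataclass, field

@dataclass
class _SketchFunction:
    name: str = ""
    arguments: str = ""

@dataclass
class _SketchToolCall:
    id: str = ""
    type: str = ""
    function: _SketchFunction = field(default_factory=_SketchFunction)

def _example_merge_tool_call_deltas(deltas: list[_SketchToolCall]) -> list[_SketchToolCall]:
    merged: list[_SketchToolCall] = []
    for delta in deltas:
        # find the call this delta belongs to, or start a new one
        call = next((c for c in merged if c.id == delta.id), None)
        if call is None:
            call = _SketchToolCall(id=delta.id)
            merged.append(call)
        call.type = delta.type or call.type
        call.function.name = delta.function.name or call.function.name
        call.function.arguments += delta.function.arguments
    return merged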
class LargeLanguageModel(AIModel):
"""
Model class for large language model.
"""
model_type: ModelType = ModelType.LLM
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
def invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict | None = None,
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = True,
user: str | None = None,
callbacks: list[Callback] | None = None,
) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]:
"""
Invoke large language model
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:param callbacks: callbacks
:return: full response or stream response chunk generator result
"""
# validate and filter model parameters
if model_parameters is None:
model_parameters = {}
self.started_at = time.perf_counter()
callbacks = callbacks or []
if dify_config.DEBUG:
callbacks.append(LoggingCallback())
# trigger before invoke callbacks
self._trigger_before_invoke_callbacks(
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user,
callbacks=callbacks,
)
result: Union[LLMResult, Generator[LLMResultChunk, None, None]]
try:
result = _invoke_llm_via_plugin(
tenant_id=self.tenant_id,
user_id=user or "unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
model=model,
credentials=credentials,
model_parameters=model_parameters,
prompt_messages=prompt_messages,
tools=tools,
stop=stop,
stream=stream,
)
if not stream:
result = _normalize_non_stream_plugin_result(
model=model, prompt_messages=prompt_messages, result=result
)
except Exception as e:
self._trigger_invoke_error_callbacks(
model=model,
ex=e,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user,
callbacks=callbacks,
)
# TODO
raise self._transform_invoke_error(e)
if stream and not isinstance(result, LLMResult):
return self._invoke_result_generator(
model=model,
result=result,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user,
callbacks=callbacks,
)
elif isinstance(result, LLMResult):
self._trigger_after_invoke_callbacks(
model=model,
result=result,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user,
callbacks=callbacks,
)
# Following https://github.com/langgenius/dify/issues/17799,
# we removed the prompt_messages from the chunk on the plugin daemon side.
# To ensure compatibility, we add the prompt_messages back here.
result.prompt_messages = prompt_messages
return result
raise NotImplementedError("unsupported invoke result type", type(result))
def _invoke_result_generator(
self,
model: str,
result: Generator[LLMResultChunk, None, None],
credentials: dict,
prompt_messages: Sequence[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: Sequence[str] | None = None,
stream: bool = True,
user: str | None = None,
callbacks: list[Callback] | None = None,
) -> Generator[LLMResultChunk, None, None]:
"""
Invoke result generator
:param result: result generator
:return: result generator
"""
callbacks = callbacks or []
message_content: list[PromptMessageContentUnionTypes] = []
usage = None
system_fingerprint = None
real_model = model
def _update_message_content(content: str | list[PromptMessageContentUnionTypes] | None):
if not content:
return
if isinstance(content, list):
message_content.extend(content)
return
if isinstance(content, str):
message_content.append(TextPromptMessageContent(data=content))
return
try:
for chunk in result:
# Following https://github.com/langgenius/dify/issues/17799,
# we removed the prompt_messages from the chunk on the plugin daemon side.
# To ensure compatibility, we add the prompt_messages back here.
chunk.prompt_messages = prompt_messages
yield chunk
self._trigger_new_chunk_callbacks(
chunk=chunk,
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user,
callbacks=callbacks,
)
_update_message_content(chunk.delta.message.content)
real_model = chunk.model
if chunk.delta.usage:
usage = chunk.delta.usage
if chunk.system_fingerprint:
system_fingerprint = chunk.system_fingerprint
except Exception as e:
raise self._transform_invoke_error(e)
assistant_message = AssistantPromptMessage(content=message_content)
self._trigger_after_invoke_callbacks(
model=model,
result=LLMResult(
model=real_model,
prompt_messages=prompt_messages,
message=assistant_message,
usage=usage or LLMUsage.empty_usage(),
system_fingerprint=system_fingerprint,
),
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user,
callbacks=callbacks,
)
def get_num_tokens(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool] | None = None,
) -> int:
"""
Get number of tokens for given prompt messages
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param tools: tools for tool calling
:return: number of tokens
"""
if dify_config.PLUGIN_BASED_TOKEN_COUNTING_ENABLED:
from core.plugin.impl.model import PluginModelClient
plugin_model_manager = PluginModelClient()
return plugin_model_manager.get_llm_num_tokens(
tenant_id=self.tenant_id,
user_id="unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
model_type=self.model_type.value,
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
tools=tools,
)
return 0
def calc_response_usage(
self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int
) -> LLMUsage:
"""
Calculate response usage
:param model: model name
:param credentials: model credentials
:param prompt_tokens: prompt tokens
:param completion_tokens: completion tokens
:return: usage
"""
# get prompt price info
prompt_price_info = self.get_price(
model=model,
credentials=credentials,
price_type=PriceType.INPUT,
tokens=prompt_tokens,
)
# get completion price info
completion_price_info = self.get_price(
model=model, credentials=credentials, price_type=PriceType.OUTPUT, tokens=completion_tokens
)
# transform usage
usage = LLMUsage(
prompt_tokens=prompt_tokens,
prompt_unit_price=prompt_price_info.unit_price,
prompt_price_unit=prompt_price_info.unit,
prompt_price=prompt_price_info.total_amount,
completion_tokens=completion_tokens,
completion_unit_price=completion_price_info.unit_price,
completion_price_unit=completion_price_info.unit,
completion_price=completion_price_info.total_amount,
total_tokens=prompt_tokens + completion_tokens,
total_price=prompt_price_info.total_amount + completion_price_info.total_amount,
currency=prompt_price_info.currency,
latency=time.perf_counter() - self.started_at,
)
return usage
def _trigger_before_invoke_callbacks(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: Sequence[str] | None = None,
stream: bool = True,
user: str | None = None,
callbacks: list[Callback] | None = None,
):
"""
Trigger before invoke callbacks
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:param callbacks: callbacks
"""
_run_callbacks(
callbacks,
event="on_before_invoke",
invoke=lambda callback: callback.on_before_invoke(
llm_instance=self,
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user,
),
)
def _trigger_new_chunk_callbacks(
self,
chunk: LLMResultChunk,
model: str,
credentials: dict,
prompt_messages: Sequence[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: Sequence[str] | None = None,
stream: bool = True,
user: str | None = None,
callbacks: list[Callback] | None = None,
):
"""
Trigger new chunk callbacks
:param chunk: chunk
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
"""
_run_callbacks(
callbacks,
event="on_new_chunk",
invoke=lambda callback: callback.on_new_chunk(
llm_instance=self,
chunk=chunk,
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user,
),
)
def _trigger_after_invoke_callbacks(
self,
model: str,
result: LLMResult,
credentials: dict,
prompt_messages: Sequence[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: Sequence[str] | None = None,
stream: bool = True,
user: str | None = None,
callbacks: list[Callback] | None = None,
):
"""
Trigger after invoke callbacks
:param model: model name
:param result: result
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:param callbacks: callbacks
"""
_run_callbacks(
callbacks,
event="on_after_invoke",
invoke=lambda callback: callback.on_after_invoke(
llm_instance=self,
result=result,
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user,
),
)
def _trigger_invoke_error_callbacks(
self,
model: str,
ex: Exception,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: Sequence[str] | None = None,
stream: bool = True,
user: str | None = None,
callbacks: list[Callback] | None = None,
):
"""
Trigger invoke error callbacks
:param model: model name
:param ex: exception
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:param callbacks: callbacks
"""
_run_callbacks(
callbacks,
event="on_invoke_error",
invoke=lambda callback: callback.on_invoke_error(
llm_instance=self,
ex=ex,
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user,
),
)
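# A minimal usage sketch for the class above; provider, model name and credentials are
# placeholders, and constructing a user message from a plain string is assumed from
# dify_graph.model_runtime.entities.message_entities.
def _example_invoke_llm(llm: LargeLanguageModel) -> None:
    from dify_graph.model_runtime.entities.message_entities import UserPromptMessage

    result = llm.invoke(
        model="gpt-4o-mini",  # placeholder model name
        credentials={"api_key": "..."},  # placeholder credentials
        prompt_messages=[UserPromptMessage(content="Hello")],
        model_parameters={"temperature": 0.2},
        stream=True,
    )
    # with stream=True a chunk generator is returned; each chunk carries an incremental delta
    for chunk in result:
        print(chunk.delta.message.content, end="")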

View File

@ -0,0 +1,45 @@
import time
from pydantic import ConfigDict
from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel
class ModerationModel(AIModel):
"""
Model class for moderation model.
"""
model_type: ModelType = ModelType.MODERATION
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
def invoke(self, model: str, credentials: dict, text: str, user: str | None = None) -> bool:
"""
Invoke moderation model
:param model: model name
:param credentials: model credentials
:param text: text to moderate
:param user: unique user id
:return: false if text is safe, true otherwise
"""
self.started_at = time.perf_counter()
try:
from core.plugin.impl.model import PluginModelClient
plugin_model_manager = PluginModelClient()
return plugin_model_manager.invoke_moderation(
tenant_id=self.tenant_id,
user_id=user or "unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
model=model,
credentials=credentials,
text=text,
)
except Exception as e:
raise self._transform_invoke_error(e)

View File

@ -0,0 +1,92 @@
from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.model_runtime.entities.rerank_entities import RerankResult
from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel
class RerankModel(AIModel):
"""
Base Model class for rerank model.
"""
model_type: ModelType = ModelType.RERANK
def invoke(
self,
model: str,
credentials: dict,
query: str,
docs: list[str],
score_threshold: float | None = None,
top_n: int | None = None,
user: str | None = None,
) -> RerankResult:
"""
Invoke rerank model
:param model: model name
:param credentials: model credentials
:param query: search query
:param docs: docs for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id
:return: rerank result
"""
try:
from core.plugin.impl.model import PluginModelClient
plugin_model_manager = PluginModelClient()
return plugin_model_manager.invoke_rerank(
tenant_id=self.tenant_id,
user_id=user or "unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
model=model,
credentials=credentials,
query=query,
docs=docs,
score_threshold=score_threshold,
top_n=top_n,
)
except Exception as e:
raise self._transform_invoke_error(e)
def invoke_multimodal_rerank(
self,
model: str,
credentials: dict,
query: dict,
docs: list[dict],
score_threshold: float | None = None,
top_n: int | None = None,
user: str | None = None,
) -> RerankResult:
"""
Invoke multimodal rerank model
:param model: model name
:param credentials: model credentials
:param query: search query
:param docs: docs for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id
:return: rerank result
"""
try:
from core.plugin.impl.model import PluginModelClient
plugin_model_manager = PluginModelClient()
return plugin_model_manager.invoke_multimodal_rerank(
tenant_id=self.tenant_id,
user_id=user or "unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
model=model,
credentials=credentials,
query=query,
docs=docs,
score_threshold=score_threshold,
top_n=top_n,
)
except Exception as e:
raise self._transform_invoke_error(e)
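# A minimal usage sketch for the class above; model name and credentials are placeholders, and
# the shape of the returned result is defined in dify_graph.model_runtime.entities.rerank_entities.
def _example_rerank(rerank_model: RerankModel) -> RerankResult:
    return rerank_model.invoke(
        model="rerank-english-v3.0",  # placeholder model name
        credentials={"api_key": "..."},  # placeholder credentials
        query="What does the model runtime module do?",
        docs=[
            "Model Runtime provides a unified interface for invoking models.",
            "An unrelated paragraph about something else.",
        ],
        score_threshold=0.5,  # drop documents below this relevance score
        top_n=1,  # keep only the best match
    )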

View File

@ -0,0 +1,43 @@
from typing import IO
from pydantic import ConfigDict
from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel
class Speech2TextModel(AIModel):
"""
Model class for speech2text model.
"""
model_type: ModelType = ModelType.SPEECH2TEXT
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
def invoke(self, model: str, credentials: dict, file: IO[bytes], user: str | None = None) -> str:
"""
Invoke speech to text model
:param model: model name
:param credentials: model credentials
:param file: audio file
:param user: unique user id
:return: text for given audio file
"""
try:
from core.plugin.impl.model import PluginModelClient
plugin_model_manager = PluginModelClient()
return plugin_model_manager.invoke_speech_to_text(
tenant_id=self.tenant_id,
user_id=user or "unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
model=model,
credentials=credentials,
file=file,
)
except Exception as e:
raise self._transform_invoke_error(e)

View File

@ -0,0 +1,121 @@
from pydantic import ConfigDict
from core.entities.embedding_type import EmbeddingInputType
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult
from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel
class TextEmbeddingModel(AIModel):
"""
Model class for text embedding model.
"""
model_type: ModelType = ModelType.TEXT_EMBEDDING
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
def invoke(
self,
model: str,
credentials: dict,
texts: list[str] | None = None,
multimodel_documents: list[dict] | None = None,
user: str | None = None,
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
) -> EmbeddingResult:
"""
Invoke text embedding model
:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:param multimodel_documents: multimodal documents to embed
:param user: unique user id
:param input_type: input type
:return: embeddings result
"""
from core.plugin.impl.model import PluginModelClient
try:
plugin_model_manager = PluginModelClient()
if texts:
return plugin_model_manager.invoke_text_embedding(
tenant_id=self.tenant_id,
user_id=user or "unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
model=model,
credentials=credentials,
texts=texts,
input_type=input_type,
)
if multimodel_documents:
return plugin_model_manager.invoke_multimodal_embedding(
tenant_id=self.tenant_id,
user_id=user or "unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
model=model,
credentials=credentials,
documents=multimodel_documents,
input_type=input_type,
)
raise ValueError("No texts or multimodal documents provided")
except Exception as e:
raise self._transform_invoke_error(e)
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> list[int]:
"""
Get number of tokens for given prompt messages
:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:return: list of token counts
"""
from core.plugin.impl.model import PluginModelClient
plugin_model_manager = PluginModelClient()
return plugin_model_manager.get_text_embedding_num_tokens(
tenant_id=self.tenant_id,
user_id="unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
model=model,
credentials=credentials,
texts=texts,
)
def _get_context_size(self, model: str, credentials: dict) -> int:
"""
Get context size for given embedding model
:param model: model name
:param credentials: model credentials
:return: context size
"""
model_schema = self.get_model_schema(model, credentials)
if model_schema and ModelPropertyKey.CONTEXT_SIZE in model_schema.model_properties:
context_size: int = model_schema.model_properties[ModelPropertyKey.CONTEXT_SIZE]
return context_size
return 1000
def _get_max_chunks(self, model: str, credentials: dict) -> int:
"""
Get max chunks for given embedding model
:param model: model name
:param credentials: model credentials
:return: max chunks
"""
model_schema = self.get_model_schema(model, credentials)
if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties:
max_chunks: int = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
return max_chunks
return 1
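# A minimal usage sketch for the class above; model name and credentials are placeholders.
# get_num_tokens returns a token count for the input texts, which callers can use to keep a
# batch within the context size reported by _get_context_size.
def _example_embed_texts(embedding_model: TextEmbeddingModel) -> EmbeddingResult:
    texts = ["first paragraph", "second paragraph"]
    token_counts = embedding_model.get_num_tokens(
        model="text-embedding-3-small",  # placeholder model name
        credentials={"api_key": "..."},  # placeholder credentials
        texts=texts,
    )
    print(token_counts)
    return embedding_model.invoke(
        model="text-embedding-3-small",
        credentials={"api_key": "..."},
        texts=texts,
    )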

View File

@ -0,0 +1,53 @@
import logging
from threading import Lock
from typing import Any
logger = logging.getLogger(__name__)
_tokenizer: Any | None = None
_lock = Lock()
class GPT2Tokenizer:
@staticmethod
def _get_num_tokens_by_gpt2(text: str) -> int:
"""
use gpt2 tokenizer to get num tokens
"""
_tokenizer = GPT2Tokenizer.get_encoder()
tokens = _tokenizer.encode(text) # type: ignore
return len(tokens)
@staticmethod
def get_num_tokens(text: str) -> int:
# Because this computation is CPU-intensive, we reverted to synchronous execution until a better approach is found.
#
# future = _executor.submit(GPT2Tokenizer._get_num_tokens_by_gpt2, text)
# result = future.result()
# return cast(int, result)
return GPT2Tokenizer._get_num_tokens_by_gpt2(text)
@staticmethod
def get_encoder():
global _tokenizer, _lock
if _tokenizer is not None:
return _tokenizer
with _lock:
if _tokenizer is None:
# Try to use tiktoken to get the tokenizer because it is faster
#
try:
import tiktoken
_tokenizer = tiktoken.get_encoding("gpt2")
except Exception:
from os.path import abspath, dirname, join
from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer
base_path = abspath(__file__)
gpt2_tokenizer_path = join(dirname(base_path), "gpt2")
_tokenizer = TransformerGPT2Tokenizer.from_pretrained(gpt2_tokenizer_path)
logger.info("Fallback to Transformers' GPT-2 tokenizer from tiktoken")
return _tokenizer
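# A minimal usage sketch: counting tokens with the shared tokenizer above. This requires either
# tiktoken or transformers to be installed, and the exact count depends on which backend is used.
def _example_token_count() -> int:
    return GPT2Tokenizer.get_num_tokens("Dify model runtime")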

View File

@ -0,0 +1,79 @@
import logging
from collections.abc import Iterable
from pydantic import ConfigDict
from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel
logger = logging.getLogger(__name__)
class TTSModel(AIModel):
"""
Model class for TTS model.
"""
model_type: ModelType = ModelType.TTS
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
def invoke(
self,
model: str,
tenant_id: str,
credentials: dict,
content_text: str,
voice: str,
user: str | None = None,
) -> Iterable[bytes]:
"""
Invoke text to speech model
:param model: model name
:param tenant_id: user tenant id
:param credentials: model credentials
:param voice: model timbre
:param content_text: text content to be converted to speech
:param user: unique user id
:return: generated audio as a stream of bytes
"""
try:
from core.plugin.impl.model import PluginModelClient
plugin_model_manager = PluginModelClient()
return plugin_model_manager.invoke_tts(
tenant_id=self.tenant_id,
user_id=user or "unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
model=model,
credentials=credentials,
content_text=content_text,
voice=voice,
)
except Exception as e:
raise self._transform_invoke_error(e)
def get_tts_model_voices(self, model: str, credentials: dict, language: str | None = None):
"""
Retrieves the list of voices supported by a given text-to-speech (TTS) model.
:param language: The language for which the voices are requested.
:param model: The name of the TTS model.
:param credentials: The credentials required to access the TTS model.
:return: A list of voices supported by the TTS model.
"""
from core.plugin.impl.model import PluginModelClient
plugin_model_manager = PluginModelClient()
return plugin_model_manager.get_tts_model_voices(
tenant_id=self.tenant_id,
user_id="unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
model=model,
credentials=credentials,
language=language,
)
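# A minimal usage sketch for the class above; model name, credentials and voice are placeholders.
# invoke yields raw audio bytes, which can be written straight to a file or streamed to a client.
def _example_synthesize_speech(tts_model: TTSModel, tenant_id: str) -> None:
    audio = tts_model.invoke(
        model="tts-1",  # placeholder model name
        tenant_id=tenant_id,
        credentials={"api_key": "..."},  # placeholder credentials
        content_text="Hello from Dify.",
        voice="alloy",  # placeholder voice id; see get_tts_model_voices for the real list
    )
    with open("output.mp3", "wb") as f:
        for chunk in audio:
            f.write(chunk)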

View File

@ -0,0 +1,43 @@
- openai
- deepseek
- anthropic
- azure_openai
- google
- vertex_ai
- nvidia
- nvidia_nim
- cohere
- upstage
- bedrock
- togetherai
- openrouter
- ollama
- mistralai
- groq
- replicate
- huggingface_hub
- xinference
- triton_inference_server
- zhipuai
- baichuan
- spark
- minimax
- tongyi
- wenxin
- moonshot
- tencent
- jina
- chatglm
- yi
- openllm
- localai
- volcengine_maas
- openai_api_compatible
- hunyuan
- siliconflow
- perfxcloud
- zhinao
- fireworks
- mixedbread
- nomic
- voyage

View File

@ -0,0 +1,386 @@
from __future__ import annotations
import hashlib
import logging
from collections.abc import Sequence
from threading import Lock
from pydantic import ValidationError
from redis import RedisError
import contexts
from configs import dify_config
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType
from dify_graph.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity
from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel
from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from dify_graph.model_runtime.model_providers.__base.moderation_model import ModerationModel
from dify_graph.model_runtime.model_providers.__base.rerank_model import RerankModel
from dify_graph.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
from dify_graph.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from dify_graph.model_runtime.model_providers.__base.tts_model import TTSModel
from dify_graph.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator
from dify_graph.model_runtime.schema_validators.provider_credential_schema_validator import (
ProviderCredentialSchemaValidator,
)
from extensions.ext_redis import redis_client
from models.provider_ids import ModelProviderID
logger = logging.getLogger(__name__)
class ModelProviderFactory:
def __init__(self, tenant_id: str):
from core.plugin.impl.model import PluginModelClient
self.tenant_id = tenant_id
self.plugin_model_manager = PluginModelClient()
def get_providers(self) -> Sequence[ProviderEntity]:
"""
Get all providers
:return: list of providers
"""
# FIXME(-LAN-): Removed position map sorting since providers are fetched from plugin server
# The plugin server should return providers in the desired order
plugin_providers = self.get_plugin_model_providers()
return [provider.declaration for provider in plugin_providers]
def get_plugin_model_providers(self) -> Sequence[PluginModelProviderEntity]:
"""
Get all plugin model providers
:return: list of plugin model providers
"""
# check if context is set
try:
contexts.plugin_model_providers.get()
except LookupError:
contexts.plugin_model_providers.set(None)
contexts.plugin_model_providers_lock.set(Lock())
with contexts.plugin_model_providers_lock.get():
plugin_model_providers = contexts.plugin_model_providers.get()
if plugin_model_providers is not None:
return plugin_model_providers
plugin_model_providers = []
contexts.plugin_model_providers.set(plugin_model_providers)
# Fetch plugin model providers
plugin_providers = self.plugin_model_manager.fetch_model_providers(self.tenant_id)
for provider in plugin_providers:
provider.declaration.provider = provider.plugin_id + "/" + provider.declaration.provider
plugin_model_providers.append(provider)
return plugin_model_providers
def get_provider_schema(self, provider: str) -> ProviderEntity:
"""
Get provider schema
:param provider: provider name
:return: provider schema
"""
plugin_model_provider_entity = self.get_plugin_model_provider(provider=provider)
return plugin_model_provider_entity.declaration
def get_plugin_model_provider(self, provider: str) -> PluginModelProviderEntity:
"""
Get plugin model provider
:param provider: provider name
:return: plugin model provider entity
"""
if "/" not in provider:
provider = str(ModelProviderID(provider))
# fetch plugin model providers
plugin_model_provider_entities = self.get_plugin_model_providers()
# get the provider
plugin_model_provider_entity = next(
(p for p in plugin_model_provider_entities if p.declaration.provider == provider),
None,
)
if not plugin_model_provider_entity:
raise ValueError(f"Invalid provider: {provider}")
return plugin_model_provider_entity
def provider_credentials_validate(self, *, provider: str, credentials: dict):
"""
Validate provider credentials
:param provider: provider name
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
:return:
"""
# fetch plugin model provider
plugin_model_provider_entity = self.get_plugin_model_provider(provider=provider)
# get provider_credential_schema and validate credentials according to the rules
provider_credential_schema = plugin_model_provider_entity.declaration.provider_credential_schema
if not provider_credential_schema:
raise ValueError(f"Provider {provider} does not have provider_credential_schema")
# validate provider credential schema
validator = ProviderCredentialSchemaValidator(provider_credential_schema)
filtered_credentials = validator.validate_and_filter(credentials)
# validate the credentials, raise exception if validation failed
self.plugin_model_manager.validate_provider_credentials(
tenant_id=self.tenant_id,
user_id="unknown",
plugin_id=plugin_model_provider_entity.plugin_id,
provider=plugin_model_provider_entity.provider,
credentials=filtered_credentials,
)
return filtered_credentials
def model_credentials_validate(self, *, provider: str, model_type: ModelType, model: str, credentials: dict):
"""
Validate model credentials
:param provider: provider name
:param model_type: model type
:param model: model name
:param credentials: model credentials, credentials form defined in `model_credential_schema`.
:return:
"""
# fetch plugin model provider
plugin_model_provider_entity = self.get_plugin_model_provider(provider=provider)
# get model_credential_schema and validate credentials according to the rules
model_credential_schema = plugin_model_provider_entity.declaration.model_credential_schema
if not model_credential_schema:
raise ValueError(f"Provider {provider} does not have model_credential_schema")
# validate model credential schema
validator = ModelCredentialSchemaValidator(model_type, model_credential_schema)
filtered_credentials = validator.validate_and_filter(credentials)
# call validate_credentials method of model type to validate credentials, raise exception if validation failed
self.plugin_model_manager.validate_model_credentials(
tenant_id=self.tenant_id,
user_id="unknown",
plugin_id=plugin_model_provider_entity.plugin_id,
provider=plugin_model_provider_entity.provider,
model_type=model_type.value,
model=model,
credentials=filtered_credentials,
)
return filtered_credentials
def get_model_schema(
self, *, provider: str, model_type: ModelType, model: str, credentials: dict | None
) -> AIModelEntity | None:
"""
Get model schema
"""
plugin_id, provider_name = self.get_plugin_id_and_provider_name_from_provider(provider)
cache_key = f"{self.tenant_id}:{plugin_id}:{provider_name}:{model_type.value}:{model}"
sorted_credentials = sorted(credentials.items()) if credentials else []
cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials])
cached_schema_json = None
try:
cached_schema_json = redis_client.get(cache_key)
except (RedisError, RuntimeError) as exc:
logger.warning(
"Failed to read plugin model schema cache for model %s: %s",
model,
str(exc),
exc_info=True,
)
if cached_schema_json:
try:
return AIModelEntity.model_validate_json(cached_schema_json)
except ValidationError:
logger.warning(
"Failed to validate cached plugin model schema for model %s",
model,
exc_info=True,
)
try:
redis_client.delete(cache_key)
except (RedisError, RuntimeError) as exc:
logger.warning(
"Failed to delete invalid plugin model schema cache for model %s: %s",
model,
str(exc),
exc_info=True,
)
schema = self.plugin_model_manager.get_model_schema(
tenant_id=self.tenant_id,
user_id="unknown",
plugin_id=plugin_id,
provider=provider_name,
model_type=model_type.value,
model=model,
credentials=credentials or {},
)
if schema:
try:
redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_SCHEMA_CACHE_TTL, schema.model_dump_json())
except (RedisError, RuntimeError) as exc:
logger.warning(
"Failed to write plugin model schema cache for model %s: %s",
model,
str(exc),
exc_info=True,
)
return schema
def get_models(
self,
*,
provider: str | None = None,
model_type: ModelType | None = None,
provider_configs: list[ProviderConfig] | None = None,
) -> list[SimpleProviderEntity]:
"""
Get all models for given model type
:param provider: provider name
:param model_type: model type
:param provider_configs: list of provider configs
:return: list of models
"""
provider_configs = provider_configs or []
# scan all providers
plugin_model_provider_entities = self.get_plugin_model_providers()
# traverse all model_provider_extensions
providers = []
for plugin_model_provider_entity in plugin_model_provider_entities:
# filter by provider if provider is present
if provider and plugin_model_provider_entity.declaration.provider != provider:
continue
# get provider schema
provider_schema = plugin_model_provider_entity.declaration
model_types = provider_schema.supported_model_types
if model_type:
if model_type not in model_types:
continue
model_types = [model_type]
all_model_type_models = []
for model_schema in provider_schema.models:
if model_schema.model_type != model_type:
continue
all_model_type_models.append(model_schema)
simple_provider_schema = provider_schema.to_simple_provider()
simple_provider_schema.models.extend(all_model_type_models)
providers.append(simple_provider_schema)
return providers
def get_model_type_instance(self, provider: str, model_type: ModelType) -> AIModel:
"""
Get model type instance by provider name and model type
:param provider: provider name
:param model_type: model type
:return: model type instance
"""
plugin_id, provider_name = self.get_plugin_id_and_provider_name_from_provider(provider)
init_params = {
"tenant_id": self.tenant_id,
"plugin_id": plugin_id,
"provider_name": provider_name,
"plugin_model_provider": self.get_plugin_model_provider(provider),
}
if model_type == ModelType.LLM:
return LargeLanguageModel.model_validate(init_params)
elif model_type == ModelType.TEXT_EMBEDDING:
return TextEmbeddingModel.model_validate(init_params)
elif model_type == ModelType.RERANK:
return RerankModel.model_validate(init_params)
elif model_type == ModelType.SPEECH2TEXT:
return Speech2TextModel.model_validate(init_params)
elif model_type == ModelType.MODERATION:
return ModerationModel.model_validate(init_params)
elif model_type == ModelType.TTS:
return TTSModel.model_validate(init_params)
raise ValueError(f"Unsupported model type: {model_type}")
def get_provider_icon(self, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]:
"""
Get provider icon
:param provider: provider name
:param icon_type: icon type (icon_small or icon_small_dark)
:param lang: language (zh_Hans or en_US)
:return: provider icon bytes and mime type
"""
# get the provider schema
provider_schema = self.get_provider_schema(provider)
if icon_type.lower() == "icon_small":
if not provider_schema.icon_small:
raise ValueError(f"Provider {provider} does not have small icon.")
if lang.lower() == "zh_hans":
file_name = provider_schema.icon_small.zh_Hans
else:
file_name = provider_schema.icon_small.en_US
elif icon_type.lower() == "icon_small_dark":
if not provider_schema.icon_small_dark:
raise ValueError(f"Provider {provider} does not have small dark icon.")
if lang.lower() == "zh_hans":
file_name = provider_schema.icon_small_dark.zh_Hans
else:
file_name = provider_schema.icon_small_dark.en_US
else:
raise ValueError(f"Unsupported icon type: {icon_type}.")
if not file_name:
raise ValueError(f"Provider {provider} does not have icon.")
image_mime_types = {
"jpg": "image/jpeg",
"jpeg": "image/jpeg",
"png": "image/png",
"gif": "image/gif",
"bmp": "image/bmp",
"tiff": "image/tiff",
"tif": "image/tiff",
"webp": "image/webp",
"svg": "image/svg+xml",
"ico": "image/vnd.microsoft.icon",
"heif": "image/heif",
"heic": "image/heic",
}
extension = file_name.split(".")[-1]
mime_type = image_mime_types.get(extension, "image/png")
# get icon bytes from plugin asset manager
from core.plugin.impl.asset import PluginAssetManager
plugin_asset_manager = PluginAssetManager()
return plugin_asset_manager.fetch_asset(tenant_id=self.tenant_id, id=file_name), mime_type
def get_plugin_id_and_provider_name_from_provider(self, provider: str) -> tuple[str, str]:
"""
Get plugin id and provider name from provider name
:param provider: provider name
:return: plugin id and provider name
"""
provider_id = ModelProviderID(provider)
return provider_id.plugin_id, provider_id.provider_name
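# A minimal usage sketch for the factory above; the plugin-qualified provider name, model name
# and credentials are placeholders. Provider identifiers take the "plugin_id/provider" form set
# in get_plugin_model_providers.
def _example_factory_usage(tenant_id: str) -> AIModel:
    factory = ModelProviderFactory(tenant_id)
    providers = factory.get_providers()  # declarations of every installed model provider plugin
    assert providers, "no model provider plugins are installed"
    schema = factory.get_model_schema(
        provider="langgenius/openai/openai",  # placeholder plugin-qualified provider
        model_type=ModelType.LLM,
        model="gpt-4o-mini",  # placeholder model name
        credentials={"api_key": "..."},  # placeholder credentials
    )
    # schema may be None if the plugin does not declare this model
    print(schema)
    return factory.get_model_type_instance("langgenius/openai/openai", ModelType.LLM)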

View File

@ -0,0 +1,92 @@
from typing import Union, cast
from dify_graph.model_runtime.entities.provider_entities import CredentialFormSchema, FormType
class CommonValidator:
def _validate_and_filter_credential_form_schemas(
self, credential_form_schemas: list[CredentialFormSchema], credentials: dict
):
need_validate_credential_form_schema_map = {}
for credential_form_schema in credential_form_schemas:
if not credential_form_schema.show_on:
need_validate_credential_form_schema_map[credential_form_schema.variable] = credential_form_schema
continue
all_show_on_match = True
for show_on_object in credential_form_schema.show_on:
if show_on_object.variable not in credentials:
all_show_on_match = False
break
if credentials[show_on_object.variable] != show_on_object.value:
all_show_on_match = False
break
if all_show_on_match:
need_validate_credential_form_schema_map[credential_form_schema.variable] = credential_form_schema
# Iterate over the remaining credential_form_schemas, verify each credential_form_schema
validated_credentials = {}
for credential_form_schema in need_validate_credential_form_schema_map.values():
# add the value of the credential_form_schema corresponding to it to validated_credentials
result = self._validate_credential_form_schema(credential_form_schema, credentials)
if result:
validated_credentials[credential_form_schema.variable] = result
return validated_credentials
def _validate_credential_form_schema(
self, credential_form_schema: CredentialFormSchema, credentials: dict
) -> Union[str, bool, None]:
"""
Validate credential form schema
:param credential_form_schema: credential form schema
:param credentials: credentials
:return: validated credential form schema value
"""
# If the variable does not exist in credentials
value: Union[str, bool, None] = None
if credential_form_schema.variable not in credentials or not credentials[credential_form_schema.variable]:
# If required is True, an exception is thrown
if credential_form_schema.required:
raise ValueError(f"Variable {credential_form_schema.variable} is required")
else:
# Get the value of default
if credential_form_schema.default:
# If it exists, add it to validated_credentials
return credential_form_schema.default
else:
# If default does not exist, skip
return None
# Get the value corresponding to the variable from credentials
value = cast(str, credentials[credential_form_schema.variable])
# If max_length=0, no validation is performed
if credential_form_schema.max_length:
if len(value) > credential_form_schema.max_length:
raise ValueError(
f"Variable {credential_form_schema.variable} length should not be"
f" greater than {credential_form_schema.max_length}"
)
# check the type of value
if not isinstance(value, str):
raise ValueError(f"Variable {credential_form_schema.variable} should be string")
if credential_form_schema.type in {FormType.SELECT, FormType.RADIO}:
# If the value is in options, no validation is performed
if credential_form_schema.options:
if value not in [option.value for option in credential_form_schema.options]:
raise ValueError(f"Variable {credential_form_schema.variable} is not in options")
if credential_form_schema.type == FormType.SWITCH:
# If the value is not in ['true', 'false'], an exception is thrown
if value.lower() not in {"true", "false"}:
raise ValueError(f"Variable {credential_form_schema.variable} should be true or false")
value = value.lower() == "true"
return value
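# A minimal sketch of the `show_on` rule used above, with plain dicts as hypothetical stand-ins
# for the schema's show_on entries: a form field only needs validation when every show_on
# condition matches a credential value that has already been provided.
def _example_show_on_match(show_on: list[dict], credentials: dict) -> bool:
    return all(
        condition["variable"] in credentials and credentials[condition["variable"]] == condition["value"]
        for condition in show_on
    )
# e.g. _example_show_on_match([{"variable": "mode", "value": "chat"}], {"mode": "chat"}) -> True
#      _example_show_on_match([{"variable": "mode", "value": "chat"}], {"mode": "completion"}) -> False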

View File

@ -0,0 +1,27 @@
from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.model_runtime.entities.provider_entities import ModelCredentialSchema
from dify_graph.model_runtime.schema_validators.common_validator import CommonValidator
class ModelCredentialSchemaValidator(CommonValidator):
def __init__(self, model_type: ModelType, model_credential_schema: ModelCredentialSchema):
self.model_type = model_type
self.model_credential_schema = model_credential_schema
def validate_and_filter(self, credentials: dict):
"""
Validate model credentials
:param credentials: model credentials
:return: filtered credentials
"""
if self.model_credential_schema is None:
raise ValueError("Model credential schema is None")
# get the credential_form_schemas in provider_credential_schema
credential_form_schemas = self.model_credential_schema.credential_form_schemas
credentials["__model_type"] = self.model_type.value
return self._validate_and_filter_credential_form_schemas(credential_form_schemas, credentials)

View File

@ -0,0 +1,19 @@
from dify_graph.model_runtime.entities.provider_entities import ProviderCredentialSchema
from dify_graph.model_runtime.schema_validators.common_validator import CommonValidator
class ProviderCredentialSchemaValidator(CommonValidator):
def __init__(self, provider_credential_schema: ProviderCredentialSchema):
self.provider_credential_schema = provider_credential_schema
def validate_and_filter(self, credentials: dict):
"""
Validate provider credentials
:param credentials: provider credentials
:return: validated provider credentials
"""
# get the credential_form_schemas in provider_credential_schema
credential_form_schemas = self.provider_credential_schema.credential_form_schemas
return self._validate_and_filter_credential_form_schemas(credential_form_schemas, credentials)

View File

@ -0,0 +1,216 @@
import dataclasses
import datetime
from collections import defaultdict, deque
from collections.abc import Callable
from decimal import Decimal
from enum import Enum
from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
from pathlib import Path, PurePath
from re import Pattern
from types import GeneratorType
from typing import Any, Literal, Union
from uuid import UUID
from pydantic import BaseModel
from pydantic.networks import AnyUrl, NameEmail
from pydantic.types import SecretBytes, SecretStr
from pydantic_core import Url
from pydantic_extra_types.color import Color
def _model_dump(model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any) -> Any:
return model.model_dump(mode=mode, **kwargs)
# Taken from Pydantic v1 as is
def isoformat(o: Union[datetime.date, datetime.time]) -> str:
return o.isoformat()
# Taken from Pydantic v1 as is
# TODO: pv2 should this return strings instead?
def decimal_encoder(dec_value: Decimal) -> Union[int, float]:
"""
Encodes a Decimal as int if there's no exponent, otherwise float.
This is useful when we use ConstrainedDecimal to represent Numeric(x,0)
where an integer (but not int typed) is used. Encoding this as a float
results in failed round-tripping between encode and parse.
Our Id type is a prime example of this.
>>> decimal_encoder(Decimal("1.0"))
1.0
>>> decimal_encoder(Decimal("1"))
1
"""
if dec_value.as_tuple().exponent >= 0: # type: ignore[operator]
return int(dec_value)
else:
return float(dec_value)
ENCODERS_BY_TYPE: dict[type[Any], Callable[[Any], Any]] = {
bytes: lambda o: o.decode(),
Color: str,
datetime.date: isoformat,
datetime.datetime: isoformat,
datetime.time: isoformat,
datetime.timedelta: lambda td: td.total_seconds(),
Decimal: decimal_encoder,
Enum: lambda o: o.value,
frozenset: list,
deque: list,
GeneratorType: list,
IPv4Address: str,
IPv4Interface: str,
IPv4Network: str,
IPv6Address: str,
IPv6Interface: str,
IPv6Network: str,
NameEmail: str,
Path: str,
Pattern: lambda o: o.pattern,
SecretBytes: str,
SecretStr: str,
set: list,
UUID: str,
Url: str,
AnyUrl: str,
}
def generate_encoders_by_class_tuples(
type_encoder_map: dict[Any, Callable[[Any], Any]],
) -> dict[Callable[[Any], Any], tuple[Any, ...]]:
encoders_by_class_tuples: dict[Callable[[Any], Any], tuple[Any, ...]] = defaultdict(tuple)
for type_, encoder in type_encoder_map.items():
encoders_by_class_tuples[encoder] += (type_,)
return encoders_by_class_tuples
encoders_by_class_tuples = generate_encoders_by_class_tuples(ENCODERS_BY_TYPE)
def jsonable_encoder(
obj: Any,
by_alias: bool = True,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
custom_encoder: dict[Any, Callable[[Any], Any]] | None = None,
sqlalchemy_safe: bool = True,
) -> Any:
custom_encoder = custom_encoder or {}
if custom_encoder:
if type(obj) in custom_encoder:
return custom_encoder[type(obj)](obj)
else:
for encoder_type, encoder_instance in custom_encoder.items():
if isinstance(obj, encoder_type):
return encoder_instance(obj)
if isinstance(obj, BaseModel):
obj_dict = _model_dump(
obj,
mode="json",
include=None,
exclude=None,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_none=exclude_none,
exclude_defaults=exclude_defaults,
)
if "__root__" in obj_dict:
obj_dict = obj_dict["__root__"]
return jsonable_encoder(
obj_dict,
exclude_none=exclude_none,
exclude_defaults=exclude_defaults,
sqlalchemy_safe=sqlalchemy_safe,
)
if dataclasses.is_dataclass(obj):
# Ensure obj is a dataclass instance, not a dataclass type
if not isinstance(obj, type):
obj_dict = dataclasses.asdict(obj)
return jsonable_encoder(
obj_dict,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
if isinstance(obj, Enum):
return obj.value
if isinstance(obj, PurePath):
return str(obj)
if isinstance(obj, str | int | float | type(None)):
return obj
if isinstance(obj, Decimal):
return format(obj, "f")
if isinstance(obj, dict):
encoded_dict = {}
for key, value in obj.items():
if (not sqlalchemy_safe or (not isinstance(key, str)) or (not key.startswith("_sa"))) and (
value is not None or not exclude_none
):
encoded_key = jsonable_encoder(
key,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
encoded_value = jsonable_encoder(
value,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
encoded_dict[encoded_key] = encoded_value
return encoded_dict
if isinstance(obj, list | set | frozenset | GeneratorType | tuple | deque):
encoded_list = []
for item in obj:
encoded_list.append(
jsonable_encoder(
item,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
)
return encoded_list
if type(obj) in ENCODERS_BY_TYPE:
return ENCODERS_BY_TYPE[type(obj)](obj)
for encoder, classes_tuple in encoders_by_class_tuples.items():
if isinstance(obj, classes_tuple):
return encoder(obj)
try:
data = dict(obj) # type: ignore
except Exception as e:
errors: list[Exception] = []
errors.append(e)
try:
data = vars(obj) # type: ignore
except Exception as e:
errors.append(e)
raise ValueError(str(errors)) from e
return jsonable_encoder(
data,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
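# A minimal usage sketch for the encoder above: pydantic models and common Python types are
# reduced to JSON-serializable primitives according to ENCODERS_BY_TYPE. All names used here
# are already imported at the top of this module.
def _example_jsonable_encoder() -> Any:
    class _Payload(BaseModel):
        created_at: datetime.datetime
        amount: Decimal
        request_id: UUID

    payload = _Payload(
        created_at=datetime.datetime(2026, 3, 2, 20, 15, 32),
        amount=Decimal("19.99"),
        request_id=UUID("12345678-1234-5678-1234-567812345678"),
    )
    # e.g. {"created_at": "2026-03-02T20:15:32", "amount": ..., "request_id": "12345678-..."}
    return jsonable_encoder(payload)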

View File

@ -3,8 +3,8 @@ from typing import Any
from pydantic import BaseModel, Field
from core.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
class NodeEventBase(BaseModel):

View File

@ -3,10 +3,10 @@ from datetime import datetime
from pydantic import Field
from core.model_runtime.entities.llm_entities import LLMUsage
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from dify_graph.entities.pause_reason import PauseReason
from dify_graph.file import File
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.node_events import NodeRunResult
from .base import NodeEventBase

View File

@ -13,9 +13,6 @@ from core.agent.entities import AgentToolEntity
from core.agent.plugin_entities import AgentStrategyParameter
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
from core.model_runtime.utils.encoders import jsonable_encoder
from core.provider_manager import ProviderManager
from core.tools.entities.tool_entities import (
ToolIdentity,
@ -32,6 +29,9 @@ from dify_graph.enums import (
WorkflowNodeExecutionStatus,
)
from dify_graph.file import File, FileTransferMethod
from dify_graph.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
from dify_graph.node_events import (
AgentLogEvent,
NodeEventBase,

View File

@ -1,4 +1,4 @@
from core.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.runtime import GraphRuntimeState

View File

@ -6,7 +6,6 @@ from typing import TYPE_CHECKING, Any, NewType, cast
from typing_extensions import TypeIs
from core.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID
from dify_graph.enums import (
NodeExecutionType,
@ -20,6 +19,7 @@ from dify_graph.graph_events import (
GraphRunPartialSucceededEvent,
GraphRunSucceededEvent,
)
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.node_events import (
IterationFailedEvent,
IterationNextEvent,

View File

@ -3,14 +3,14 @@ from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal
from core.app.app_config.entities import DatasetRetrieveConfigEntity
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from dify_graph.entities import GraphInitParams
from dify_graph.enums import (
NodeType,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base import LLMUsageTrackingMixin
from dify_graph.nodes.base.node import Node

View File

@ -3,8 +3,8 @@ from typing import Any, Literal
from pydantic import BaseModel, Field, field_validator
from core.model_runtime.entities import ImagePromptMessageContent, LLMMode
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from dify_graph.model_runtime.entities import ImagePromptMessageContent, LLMMode
from dify_graph.nodes.base import BaseNodeData
from dify_graph.nodes.base.entities import VariableSelector

View File

@ -2,16 +2,16 @@ from collections.abc import Sequence
from typing import cast
from core.model_manager import ModelInstance
from core.model_runtime.entities import PromptMessageRole
from core.model_runtime.entities.message_entities import (
from dify_graph.file.models import File
from dify_graph.model_runtime.entities import PromptMessageRole
from dify_graph.model_runtime.entities.message_entities import (
ImagePromptMessageContent,
PromptMessage,
TextPromptMessageContent,
)
from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.memory import PromptMessageMemory
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from dify_graph.file.models import File
from dify_graph.model_runtime.entities.model_entities import AIModelEntity
from dify_graph.model_runtime.memory import PromptMessageMemory
from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from dify_graph.runtime import VariablePool
from dify_graph.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment

View File

@ -15,30 +15,6 @@ from core.helper.code_executor import CodeExecutor, CodeLanguage
from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
from core.model_manager import ModelInstance
from core.model_runtime.entities import (
ImagePromptMessageContent,
PromptMessage,
PromptMessageContentType,
TextPromptMessageContent,
)
from core.model_runtime.entities.llm_entities import (
LLMResult,
LLMResultChunk,
LLMResultChunkWithStructuredOutput,
LLMResultWithStructuredOutput,
LLMStructuredOutput,
LLMUsage,
)
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessageContentUnionTypes,
PromptMessageRole,
SystemPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
from core.model_runtime.memory import PromptMessageMemory
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
@ -52,6 +28,30 @@ from dify_graph.enums import (
WorkflowNodeExecutionStatus,
)
from dify_graph.file import File, FileTransferMethod, FileType, file_manager
from dify_graph.model_runtime.entities import (
ImagePromptMessageContent,
PromptMessage,
PromptMessageContentType,
TextPromptMessageContent,
)
from dify_graph.model_runtime.entities.llm_entities import (
LLMResult,
LLMResultChunk,
LLMResultChunkWithStructuredOutput,
LLMResultWithStructuredOutput,
LLMStructuredOutput,
LLMUsage,
)
from dify_graph.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessageContentUnionTypes,
PromptMessageRole,
SystemPromptMessage,
UserPromptMessage,
)
from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
from dify_graph.model_runtime.memory import PromptMessageMemory
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
from dify_graph.node_events import (
ModelInvokeCompletedEvent,
NodeEventBase,

View File

@ -5,7 +5,6 @@ from collections.abc import Callable, Generator, Mapping, Sequence
from datetime import datetime
from typing import TYPE_CHECKING, Any, Literal, cast
from core.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.enums import (
NodeExecutionType,
NodeType,
@ -17,6 +16,7 @@ from dify_graph.graph_events import (
GraphRunFailedEvent,
NodeRunSucceededEvent,
)
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.node_events import (
LoopFailedEvent,
LoopNextEvent,

View File

@ -6,20 +6,6 @@ from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, cast
from core.model_manager import ModelInstance
from core.model_runtime.entities import ImagePromptMessageContent
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
PromptMessageRole,
PromptMessageTool,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
from core.model_runtime.memory import PromptMessageMemory
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.prompt.simple_prompt_transform import ModelMode
@ -30,6 +16,20 @@ from dify_graph.enums import (
WorkflowNodeExecutionStatus,
)
from dify_graph.file import File
from dify_graph.model_runtime.entities import ImagePromptMessageContent
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
PromptMessageRole,
PromptMessageTool,
ToolPromptMessage,
UserPromptMessage,
)
from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
from dify_graph.model_runtime.memory import PromptMessageMemory
from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base import variable_template_parser
from dify_graph.nodes.base.node import Node

View File

@ -4,9 +4,6 @@ from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any
from core.model_manager import ModelInstance
from core.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole
from core.model_runtime.memory import PromptMessageMemory
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from dify_graph.entities import GraphInitParams
@ -16,6 +13,9 @@ from dify_graph.enums import (
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from dify_graph.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole
from dify_graph.model_runtime.memory import PromptMessageMemory
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
from dify_graph.node_events import ModelInvokeCompletedEvent, NodeRunResult
from dify_graph.nodes.base.entities import VariableSelector
from dify_graph.nodes.base.node import Node

View File

@ -5,7 +5,6 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.model_runtime.entities.llm_entities import LLMUsage
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.errors import ToolInvokeError
@ -18,6 +17,7 @@ from dify_graph.enums import (
WorkflowNodeExecutionStatus,
)
from dify_graph.file import File, FileTransferMethod
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.node_events import NodeEventBase, NodeRunResult, StreamChunkEvent, StreamCompletedEvent
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser

View File

@ -2,7 +2,7 @@ from typing import Any, Literal, Protocol
from pydantic import BaseModel, Field
from core.model_runtime.entities import LLMUsage
from dify_graph.model_runtime.entities import LLMUsage
from dify_graph.nodes.knowledge_retrieval.entities import MetadataFilteringCondition
from dify_graph.nodes.llm.entities import ModelConfig

View File

@ -10,8 +10,8 @@ from typing import TYPE_CHECKING, Any, ClassVar, Protocol
from pydantic import BaseModel, Field
from pydantic.json import pydantic_encoder
from core.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.enums import NodeExecutionType, NodeState, NodeType
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.runtime.variable_pool import VariablePool
if TYPE_CHECKING:

View File

@ -1,7 +1,7 @@
from collections.abc import Mapping, Sequence
from typing import Any, Protocol
from core.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.system_variable import SystemVariableReadOnlyView
from dify_graph.variables.segments import Segment

View File

@ -4,7 +4,7 @@ from collections.abc import Mapping, Sequence
from copy import deepcopy
from typing import Any
from core.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.system_variable import SystemVariableReadOnlyView
from dify_graph.variables.segments import Segment