refactor(api): move model_runtime into dify_graph (#32858)
api/dify_graph/model_runtime/README.md (new file, 51 lines)
@@ -0,0 +1,51 @@
# Model Runtime

This module provides the interfaces for invoking and authenticating the various models, and offers Dify unified model provider information and credential form rules.

- On one hand, it decouples models from upstream and downstream processes, making it easy for developers to extend model support horizontally.
- On the other hand, providers and models only need to be defined in the backend to be displayed directly in the frontend interface, with no changes to frontend logic.

## Features

- Supports capability invocation for 6 types of models

  - `LLM` - text completion, dialogue, and pre-computed token counting
  - `Text Embedding Model` - text embedding and pre-computed token counting
  - `Rerank Model` - segment reranking
  - `Speech-to-text Model` - speech-to-text conversion
  - `Text-to-speech Model` - text-to-speech conversion
  - `Moderation` - content moderation

- Model provider display

  Displays the list of all supported providers, including provider names, icons, supported model types, predefined model lists, configuration methods, and credential form rules.

- Selectable model list display

  After provider/model credentials are configured, the dropdowns (application orchestration interface / default model) show the available LLMs. Greyed-out items are the predefined models of providers whose credentials have not yet been configured, so users can see which models are supported.

  In addition, this list returns the configurable parameter information and rules for each LLM. These parameters are all defined in the backend, so different models can expose the different parameters they support (see the sketch after this list).

- Provider/model credential authentication

  The provider list returns the configuration information for the credentials form, and the credentials can be validated through the Runtime's interface.
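As a small, hedged illustration of what "defined in the backend" means here, the snippet below reads one of the default parameter rules added in this commit (`entities/defaults.py`); the printed values simply mirror that template.

```python
from dify_graph.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
from dify_graph.model_runtime.entities.model_entities import DefaultParameterName

# The backend-defined rule that drives the frontend's "Temperature" control.
rule = PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE]
print(rule["label"]["en_US"], rule["min"], rule["max"], rule["default"])  # Temperature 0.0 1.0 0.0
```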
## Structure

Model Runtime is divided into three layers:

- The outermost layer is the factory method

  It provides methods for obtaining all providers and all model lists, getting provider instances, and authenticating provider/model credentials.

- The second layer is the provider layer

  It provides the current provider's model list, model instance retrieval, provider credential authentication, and provider configuration rule information, **allowing horizontal expansion** to support different providers.

- The bottom layer is the model layer

  It offers direct invocation of the various model types, predefined model configuration information, retrieval of predefined/remote model lists, and model credential authentication. Different models provide additional special methods, such as the LLM's pre-computed token counting and cost information methods (see the sketch below), **allowing horizontal expansion** for different models under the same provider (within the supported model types).
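As a sketch of the cost-information method mentioned above: the model layer prices a call as `tokens * unit_price * unit` (mirroring `AIModel.get_price` added in this commit). The concrete prices below are made up for illustration.

```python
from decimal import ROUND_HALF_UP, Decimal

from dify_graph.model_runtime.entities.model_entities import PriceConfig

# Hypothetical pricing: USD per token, expressed per 1K tokens via `unit`.
pricing = PriceConfig(input=Decimal("0.5"), output=Decimal("1.5"), unit=Decimal("0.001"), currency="USD")

# Price a 1,200-token completion the same way get_price() does.
total = (Decimal(1200) * pricing.output * pricing.unit).quantize(Decimal("0.0000001"), rounding=ROUND_HALF_UP)
print(total, pricing.currency)  # 1.8000000 USD
```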
## Documentation

For detailed documentation on how to add new providers or models, please refer to the [Dify documentation](https://docs.dify.ai/).
api/dify_graph/model_runtime/README_CN.md (new file, 64 lines)
@@ -0,0 +1,64 @@
# Model Runtime

This module provides the invocation and authentication interfaces for the various models, and offers Dify unified model provider information and credential form rules.

- On one hand, it decouples models from upstream and downstream processes, making it easy for developers to extend model support horizontally.
- On the other hand, providers and models only need to be defined in the backend to be displayed directly on the frontend pages, with no changes to frontend logic.

## Features

- Supports capability invocation for 6 types of models

  - `LLM` - text completion, dialogue, and pre-computed token counting
  - `Text Embedding Model` - text embedding and pre-computed token counting
  - `Rerank Model` - segment reranking
  - `Speech-to-text Model` - speech-to-text conversion
  - `Text-to-speech Model` - text-to-speech conversion
  - `Moderation` - content moderation

- Model provider display

  Displays the list of all supported providers; besides the provider name and icon, it also returns the supported model types, predefined model lists, configuration methods, and credential form rules.

- Selectable model list display

  After provider/model credentials are configured, the dropdowns (application orchestration interface / default model) show the available LLMs. Greyed-out items are the predefined models of providers whose credentials have not yet been configured, so users can see which models are supported.

  In addition, this list also returns the configurable parameter information and rules for each LLM. These parameters are all defined in the backend; instead of the previous 5 fixed parameters, each model can now declare whichever parameters it supports.

- Provider/model credential authentication

  The provider list returns the credential form configuration, and the credentials can be validated through the interface exposed by the Runtime.

## Structure

Model Runtime is divided into three layers:

- The outermost layer is the factory method

  It provides methods for obtaining all providers and all model lists, getting provider instances, and authenticating provider/model credentials.

- The second layer is the provider layer

  It provides the current provider's model list, model instance retrieval, provider credential authentication, and provider configuration rule information, and **can be extended horizontally** to support different providers.

  For provider/model credentials, there are two cases:

  - Centralized providers such as OpenAI need authentication credentials such as an **api_key**.
  - Locally deployed providers such as [**Xinference**](https://github.com/xorbitsai/inference) need address credentials such as a **server_url**, and sometimes model credentials such as a **model_uid**. Once these credentials are defined at the provider layer, they can be displayed directly on the frontend pages without modifying frontend logic.

  Once the credentials are configured, the **Schema** (credential form rules) required by the corresponding provider can be obtained directly through DifyRuntime's external interface, so new providers/models can be supported without modifying frontend logic (see the sketch at the end of this section).

- The bottom layer is the model layer

  It offers direct invocation of the various model types, predefined model configuration information, retrieval of predefined/remote model lists, and model credential authentication. Different models provide additional special methods, such as the LLM's pre-computed token counting and cost information methods, and **can be extended horizontally** for different models under the same provider (within the supported model types).

  Here we first need to distinguish model parameters from model credentials.

  - Model parameters (**defined at this layer**) are parameters that change frequently and are adjusted at any time, such as an LLM's **max_tokens** and **temperature**. They are adjusted by users on the frontend pages, so the backend has to define the parameter rules for the frontend to display and adjust them. In DifyRuntime, they are usually passed as **model_parameters: dict[str, any]**.
  - Model credentials (**defined at the provider layer**) are parameters that rarely change and usually stay fixed once configured, such as **api_key** and **server_url**. In DifyRuntime, they are usually passed as **credentials: dict[str, any]**; the provider layer's credentials are passed straight through to this layer and do not need to be defined again (illustrated in the sketch below).
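A brief, hedged sketch of this split, using the credential-form entities added in this commit; the provider field names and parameter values are invented for illustration.

```python
from dify_graph.model_runtime.entities.common_entities import I18nObject
from dify_graph.model_runtime.entities.provider_entities import (
    CredentialFormSchema,
    FormType,
    ProviderCredentialSchema,
)

# Provider-layer Schema: a centralized provider that only needs an API key.
schema = ProviderCredentialSchema(
    credential_form_schemas=[
        CredentialFormSchema(variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.SECRET_INPUT)
    ]
)
print([form.variable for form in schema.credential_form_schemas])  # ['api_key']

# Values filled in against that Schema are passed through as `credentials`,
# while per-call knobs travel separately as `model_parameters`.
credentials: dict = {"api_key": "my-secret-key"}
model_parameters: dict = {"temperature": 0.7, "max_tokens": 256}
```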
## Documentation

For detailed documentation on how to add new providers or models, please refer to the [Dify documentation](https://docs.dify.ai/).
api/dify_graph/model_runtime/__init__.py (new file, empty)
api/dify_graph/model_runtime/callbacks/__init__.py (new file, empty)
api/dify_graph/model_runtime/callbacks/base_callback.py (new file, 151 lines)
@@ -0,0 +1,151 @@
from abc import ABC, abstractmethod
from collections.abc import Sequence

from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel

_TEXT_COLOR_MAPPING = {
    "blue": "36;1",
    "yellow": "33;1",
    "pink": "38;5;200",
    "green": "32;1",
    "red": "31;1",
}


class Callback(ABC):
    """
    Base class for callbacks.
    Only for LLM.
    """

    raise_error: bool = False

    @abstractmethod
    def on_before_invoke(
        self,
        llm_instance: AIModel,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: list[PromptMessageTool] | None = None,
        stop: Sequence[str] | None = None,
        stream: bool = True,
        user: str | None = None,
    ):
        """
        Before invoke callback

        :param llm_instance: LLM instance
        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        """
        raise NotImplementedError()

    @abstractmethod
    def on_new_chunk(
        self,
        llm_instance: AIModel,
        chunk: LLMResultChunk,
        model: str,
        credentials: dict,
        prompt_messages: Sequence[PromptMessage],
        model_parameters: dict,
        tools: list[PromptMessageTool] | None = None,
        stop: Sequence[str] | None = None,
        stream: bool = True,
        user: str | None = None,
    ):
        """
        On new chunk callback

        :param llm_instance: LLM instance
        :param chunk: chunk
        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        """
        raise NotImplementedError()

    @abstractmethod
    def on_after_invoke(
        self,
        llm_instance: AIModel,
        result: LLMResult,
        model: str,
        credentials: dict,
        prompt_messages: Sequence[PromptMessage],
        model_parameters: dict,
        tools: list[PromptMessageTool] | None = None,
        stop: Sequence[str] | None = None,
        stream: bool = True,
        user: str | None = None,
    ):
        """
        After invoke callback

        :param llm_instance: LLM instance
        :param result: result
        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        """
        raise NotImplementedError()

    @abstractmethod
    def on_invoke_error(
        self,
        llm_instance: AIModel,
        ex: Exception,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: list[PromptMessageTool] | None = None,
        stop: Sequence[str] | None = None,
        stream: bool = True,
        user: str | None = None,
    ):
        """
        Invoke error callback

        :param llm_instance: LLM instance
        :param ex: exception
        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        """
        raise NotImplementedError()

    def print_text(self, text: str, color: str | None = None, end: str = ""):
        """Print text with highlighting and no end characters."""
        text_to_print = self._get_colored_text(text, color) if color else text
        print(text_to_print, end=end)

    def _get_colored_text(self, text: str, color: str) -> str:
        """Get colored text."""
        color_str = _TEXT_COLOR_MAPPING[color]
        return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m"
api/dify_graph/model_runtime/callbacks/logging_callback.py (new file, 170 lines)
@@ -0,0 +1,170 @@
import json
import logging
import sys
from collections.abc import Sequence
from typing import cast

from dify_graph.model_runtime.callbacks.base_callback import Callback
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel

logger = logging.getLogger(__name__)


class LoggingCallback(Callback):
    def on_before_invoke(
        self,
        llm_instance: AIModel,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: list[PromptMessageTool] | None = None,
        stop: Sequence[str] | None = None,
        stream: bool = True,
        user: str | None = None,
    ):
        """
        Before invoke callback

        :param llm_instance: LLM instance
        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        """
        self.print_text("\n[on_llm_before_invoke]\n", color="blue")
        self.print_text(f"Model: {model}\n", color="blue")
        self.print_text("Parameters:\n", color="blue")
        for key, value in model_parameters.items():
            self.print_text(f"\t{key}: {value}\n", color="blue")

        if stop:
            self.print_text(f"\tstop: {stop}\n", color="blue")

        if tools:
            self.print_text("\tTools:\n", color="blue")
            for tool in tools:
                self.print_text(f"\t\t{tool.name}\n", color="blue")

        self.print_text(f"Stream: {stream}\n", color="blue")

        if user:
            self.print_text(f"User: {user}\n", color="blue")

        self.print_text("Prompt messages:\n", color="blue")
        for prompt_message in prompt_messages:
            if prompt_message.name:
                self.print_text(f"\tname: {prompt_message.name}\n", color="blue")

            self.print_text(f"\trole: {prompt_message.role.value}\n", color="blue")
            self.print_text(f"\tcontent: {prompt_message.content}\n", color="blue")

        if stream:
            self.print_text("\n[on_llm_new_chunk]")

    def on_new_chunk(
        self,
        llm_instance: AIModel,
        chunk: LLMResultChunk,
        model: str,
        credentials: dict,
        prompt_messages: Sequence[PromptMessage],
        model_parameters: dict,
        tools: list[PromptMessageTool] | None = None,
        stop: Sequence[str] | None = None,
        stream: bool = True,
        user: str | None = None,
    ):
        """
        On new chunk callback

        :param llm_instance: LLM instance
        :param chunk: chunk
        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        """
        sys.stdout.write(cast(str, chunk.delta.message.content))
        sys.stdout.flush()

    def on_after_invoke(
        self,
        llm_instance: AIModel,
        result: LLMResult,
        model: str,
        credentials: dict,
        prompt_messages: Sequence[PromptMessage],
        model_parameters: dict,
        tools: list[PromptMessageTool] | None = None,
        stop: Sequence[str] | None = None,
        stream: bool = True,
        user: str | None = None,
    ):
        """
        After invoke callback

        :param llm_instance: LLM instance
        :param result: result
        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        """
        self.print_text("\n[on_llm_after_invoke]\n", color="yellow")
        self.print_text(f"Content: {result.message.content}\n", color="yellow")

        if result.message.tool_calls:
            self.print_text("Tool calls:\n", color="yellow")
            for tool_call in result.message.tool_calls:
                self.print_text(f"\t{tool_call.id}\n", color="yellow")
                self.print_text(f"\t{tool_call.function.name}\n", color="yellow")
                self.print_text(f"\t{json.dumps(tool_call.function.arguments)}\n", color="yellow")

        self.print_text(f"Model: {result.model}\n", color="yellow")
        self.print_text(f"Usage: {result.usage}\n", color="yellow")
        self.print_text(f"System Fingerprint: {result.system_fingerprint}\n", color="yellow")

    def on_invoke_error(
        self,
        llm_instance: AIModel,
        ex: Exception,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: list[PromptMessageTool] | None = None,
        stop: Sequence[str] | None = None,
        stream: bool = True,
        user: str | None = None,
    ):
        """
        Invoke error callback

        :param llm_instance: LLM instance
        :param ex: exception
        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        """
        self.print_text("\n[on_llm_invoke_error]\n", color="red")
        logger.exception(ex)
api/dify_graph/model_runtime/entities/__init__.py (new file, 43 lines)
@@ -0,0 +1,43 @@
from .llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from .message_entities import (
    AssistantPromptMessage,
    AudioPromptMessageContent,
    DocumentPromptMessageContent,
    ImagePromptMessageContent,
    MultiModalPromptMessageContent,
    PromptMessage,
    PromptMessageContent,
    PromptMessageContentType,
    PromptMessageRole,
    PromptMessageTool,
    SystemPromptMessage,
    TextPromptMessageContent,
    ToolPromptMessage,
    UserPromptMessage,
    VideoPromptMessageContent,
)
from .model_entities import ModelPropertyKey

__all__ = [
    "AssistantPromptMessage",
    "AudioPromptMessageContent",
    "DocumentPromptMessageContent",
    "ImagePromptMessageContent",
    "LLMMode",
    "LLMResult",
    "LLMResultChunk",
    "LLMResultChunkDelta",
    "LLMUsage",
    "ModelPropertyKey",
    "MultiModalPromptMessageContent",
    "PromptMessage",
    "PromptMessageContent",
    "PromptMessageContentType",
    "PromptMessageRole",
    "PromptMessageTool",
    "SystemPromptMessage",
    "TextPromptMessageContent",
    "ToolPromptMessage",
    "UserPromptMessage",
    "VideoPromptMessageContent",
]
api/dify_graph/model_runtime/entities/common_entities.py (new file, 16 lines)
@@ -0,0 +1,16 @@
from pydantic import BaseModel, model_validator


class I18nObject(BaseModel):
    """
    Model class for i18n object.
    """

    zh_Hans: str | None = None
    en_US: str

    @model_validator(mode="after")
    def _(self):
        if not self.zh_Hans:
            self.zh_Hans = self.en_US
        return self
api/dify_graph/model_runtime/entities/defaults.py (new file, 130 lines)
@@ -0,0 +1,130 @@
from dify_graph.model_runtime.entities.model_entities import DefaultParameterName

PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
    DefaultParameterName.TEMPERATURE: {
        "label": {
            "en_US": "Temperature",
            "zh_Hans": "温度",
        },
        "type": "float",
        "help": {
            "en_US": "Controls randomness. Lower temperature results in less random completions."
            " As the temperature approaches zero, the model will become deterministic and repetitive."
            " Higher temperature results in more random completions.",
            "zh_Hans": "温度控制随机性。较低的温度会导致较少的随机完成。随着温度接近零,模型将变得确定性和重复性。"
            "较高的温度会导致更多的随机完成。",
        },
        "required": False,
        "default": 0.0,
        "min": 0.0,
        "max": 1.0,
        "precision": 2,
    },
    DefaultParameterName.TOP_P: {
        "label": {
            "en_US": "Top P",
            "zh_Hans": "Top P",
        },
        "type": "float",
        "help": {
            "en_US": "Controls diversity via nucleus sampling: 0.5 means half of all likelihood-weighted options"
            " are considered.",
            "zh_Hans": "通过核心采样控制多样性:0.5 表示考虑了一半的所有可能性加权选项。",
        },
        "required": False,
        "default": 1.0,
        "min": 0.0,
        "max": 1.0,
        "precision": 2,
    },
    DefaultParameterName.TOP_K: {
        "label": {
            "en_US": "Top K",
            "zh_Hans": "Top K",
        },
        "type": "int",
        "help": {
            "en_US": "Limits the number of tokens to consider for each step by keeping only the k most likely tokens.",
            "zh_Hans": "通过只保留每一步中最可能的 k 个标记来限制要考虑的标记数量。",
        },
        "required": False,
        "default": 50,
        "min": 1,
        "max": 100,
        "precision": 0,
    },
    DefaultParameterName.PRESENCE_PENALTY: {
        "label": {
            "en_US": "Presence Penalty",
            "zh_Hans": "存在惩罚",
        },
        "type": "float",
        "help": {
            "en_US": "Applies a penalty to the log-probability of tokens already in the text.",
            "zh_Hans": "对文本中已有的标记的对数概率施加惩罚。",
        },
        "required": False,
        "default": 0.0,
        "min": 0.0,
        "max": 1.0,
        "precision": 2,
    },
    DefaultParameterName.FREQUENCY_PENALTY: {
        "label": {
            "en_US": "Frequency Penalty",
            "zh_Hans": "频率惩罚",
        },
        "type": "float",
        "help": {
            "en_US": "Applies a penalty to the log-probability of tokens that appear in the text.",
            "zh_Hans": "对文本中出现的标记的对数概率施加惩罚。",
        },
        "required": False,
        "default": 0.0,
        "min": 0.0,
        "max": 1.0,
        "precision": 2,
    },
    DefaultParameterName.MAX_TOKENS: {
        "label": {
            "en_US": "Max Tokens",
            "zh_Hans": "最大 Token 数",
        },
        "type": "int",
        "help": {
            "en_US": "Specifies the upper limit on the length of generated results."
            " If the generated results are truncated, you can increase this parameter.",
            "zh_Hans": "指定生成结果长度的上限。如果生成结果截断,可以调大该参数。",
        },
        "required": False,
        "default": 64,
        "min": 1,
        "max": 2048,
        "precision": 0,
    },
    DefaultParameterName.RESPONSE_FORMAT: {
        "label": {
            "en_US": "Response Format",
            "zh_Hans": "回复格式",
        },
        "type": "string",
        "help": {
            "en_US": "Set a response format, ensure the output from llm is a valid code block as possible,"
            " such as JSON, XML, etc.",
            "zh_Hans": "设置一个返回格式,确保 llm 的输出尽可能是有效的代码块,如 JSON、XML 等",
        },
        "required": False,
        "options": ["JSON", "XML"],
    },
    DefaultParameterName.JSON_SCHEMA: {
        "label": {
            "en_US": "JSON Schema",
        },
        "type": "text",
        "help": {
            "en_US": "Set a response json schema will ensure LLM to adhere it.",
            "zh_Hans": "设置返回的 json schema,llm 将按照它返回",
        },
        "required": False,
    },
}
api/dify_graph/model_runtime/entities/llm_entities.py (new file, 219 lines)
@@ -0,0 +1,219 @@
from __future__ import annotations

from collections.abc import Mapping, Sequence
from decimal import Decimal
from enum import StrEnum
from typing import Any, TypedDict, Union

from pydantic import BaseModel, Field

from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage
from dify_graph.model_runtime.entities.model_entities import ModelUsage, PriceInfo


class LLMMode(StrEnum):
    """
    Enum class for large language model mode.
    """

    COMPLETION = "completion"
    CHAT = "chat"


class LLMUsageMetadata(TypedDict, total=False):
    """
    TypedDict for LLM usage metadata.
    All fields are optional.
    """

    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
    prompt_unit_price: Union[float, str]
    completion_unit_price: Union[float, str]
    total_price: Union[float, str]
    currency: str
    prompt_price_unit: Union[float, str]
    completion_price_unit: Union[float, str]
    prompt_price: Union[float, str]
    completion_price: Union[float, str]
    latency: float
    time_to_first_token: float
    time_to_generate: float


class LLMUsage(ModelUsage):
    """
    Model class for llm usage.
    """

    prompt_tokens: int
    prompt_unit_price: Decimal
    prompt_price_unit: Decimal
    prompt_price: Decimal
    completion_tokens: int
    completion_unit_price: Decimal
    completion_price_unit: Decimal
    completion_price: Decimal
    total_tokens: int
    total_price: Decimal
    currency: str
    latency: float
    time_to_first_token: float | None = None
    time_to_generate: float | None = None

    @classmethod
    def empty_usage(cls):
        return cls(
            prompt_tokens=0,
            prompt_unit_price=Decimal("0.0"),
            prompt_price_unit=Decimal("0.0"),
            prompt_price=Decimal("0.0"),
            completion_tokens=0,
            completion_unit_price=Decimal("0.0"),
            completion_price_unit=Decimal("0.0"),
            completion_price=Decimal("0.0"),
            total_tokens=0,
            total_price=Decimal("0.0"),
            currency="USD",
            latency=0.0,
            time_to_first_token=None,
            time_to_generate=None,
        )

    @classmethod
    def from_metadata(cls, metadata: LLMUsageMetadata) -> LLMUsage:
        """
        Create LLMUsage instance from metadata dictionary with default values.

        Args:
            metadata: TypedDict containing usage metadata

        Returns:
            LLMUsage instance with values from metadata or defaults
        """
        prompt_tokens = metadata.get("prompt_tokens", 0)
        completion_tokens = metadata.get("completion_tokens", 0)
        total_tokens = metadata.get("total_tokens", 0)

        # If total_tokens is not provided but prompt and completion tokens are,
        # calculate total_tokens
        if total_tokens == 0 and (prompt_tokens > 0 or completion_tokens > 0):
            total_tokens = prompt_tokens + completion_tokens

        return cls(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens,
            prompt_unit_price=Decimal(str(metadata.get("prompt_unit_price", 0))),
            completion_unit_price=Decimal(str(metadata.get("completion_unit_price", 0))),
            total_price=Decimal(str(metadata.get("total_price", 0))),
            currency=metadata.get("currency", "USD"),
            prompt_price_unit=Decimal(str(metadata.get("prompt_price_unit", 0))),
            completion_price_unit=Decimal(str(metadata.get("completion_price_unit", 0))),
            prompt_price=Decimal(str(metadata.get("prompt_price", 0))),
            completion_price=Decimal(str(metadata.get("completion_price", 0))),
            latency=metadata.get("latency", 0.0),
            time_to_first_token=metadata.get("time_to_first_token"),
            time_to_generate=metadata.get("time_to_generate"),
        )

    def plus(self, other: LLMUsage) -> LLMUsage:
        """
        Add two LLMUsage instances together.

        :param other: Another LLMUsage instance to add
        :return: A new LLMUsage instance with summed values
        """
        if self.total_tokens == 0:
            return other
        else:
            return LLMUsage(
                prompt_tokens=self.prompt_tokens + other.prompt_tokens,
                prompt_unit_price=other.prompt_unit_price,
                prompt_price_unit=other.prompt_price_unit,
                prompt_price=self.prompt_price + other.prompt_price,
                completion_tokens=self.completion_tokens + other.completion_tokens,
                completion_unit_price=other.completion_unit_price,
                completion_price_unit=other.completion_price_unit,
                completion_price=self.completion_price + other.completion_price,
                total_tokens=self.total_tokens + other.total_tokens,
                total_price=self.total_price + other.total_price,
                currency=other.currency,
                latency=self.latency + other.latency,
                time_to_first_token=other.time_to_first_token,
                time_to_generate=other.time_to_generate,
            )

    def __add__(self, other: LLMUsage) -> LLMUsage:
        """
        Overload the + operator to add two LLMUsage instances.

        :param other: Another LLMUsage instance to add
        :return: A new LLMUsage instance with summed values
        """
        return self.plus(other)


class LLMResult(BaseModel):
    """
    Model class for llm result.
    """

    id: str | None = None
    model: str
    prompt_messages: Sequence[PromptMessage] = Field(default_factory=list)
    message: AssistantPromptMessage
    usage: LLMUsage
    system_fingerprint: str | None = None
    reasoning_content: str | None = None


class LLMStructuredOutput(BaseModel):
    """
    Model class for llm structured output.
    """

    structured_output: Mapping[str, Any] | None = None


class LLMResultWithStructuredOutput(LLMResult, LLMStructuredOutput):
    """
    Model class for llm result with structured output.
    """


class LLMResultChunkDelta(BaseModel):
    """
    Model class for llm result chunk delta.
    """

    index: int
    message: AssistantPromptMessage
    usage: LLMUsage | None = None
    finish_reason: str | None = None


class LLMResultChunk(BaseModel):
    """
    Model class for llm result chunk.
    """

    model: str
    prompt_messages: Sequence[PromptMessage] = Field(default_factory=list)
    system_fingerprint: str | None = None
    delta: LLMResultChunkDelta


class LLMResultChunkWithStructuredOutput(LLMResultChunk, LLMStructuredOutput):
    """
    Model class for llm result chunk with structured output.
    """


class NumTokensResult(PriceInfo):
    """
    Model class for number of tokens result.
    """

    tokens: int
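As a hedged usage sketch of the entities defined in this file (the token counts are invented), `LLMUsage.from_metadata` fills in missing totals and `+` merges usage across calls:

```python
from dify_graph.model_runtime.entities.llm_entities import LLMUsage

first = LLMUsage.from_metadata({"prompt_tokens": 12, "completion_tokens": 30, "latency": 0.8})
second = LLMUsage.from_metadata({"prompt_tokens": 5, "completion_tokens": 7, "latency": 0.2})

# __add__ delegates to LLMUsage.plus(), summing token counts and prices.
combined = first + second
print(combined.total_tokens)  # 54
```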
api/dify_graph/model_runtime/entities/message_entities.py (new file, 282 lines)
@@ -0,0 +1,282 @@
from __future__ import annotations

from abc import ABC
from collections.abc import Mapping, Sequence
from enum import StrEnum, auto
from typing import Annotated, Any, Literal, Union

from pydantic import BaseModel, Field, field_serializer, field_validator


class PromptMessageRole(StrEnum):
    """
    Enum class for prompt message.
    """

    SYSTEM = auto()
    USER = auto()
    ASSISTANT = auto()
    TOOL = auto()

    @classmethod
    def value_of(cls, value: str) -> PromptMessageRole:
        """
        Get value of given mode.

        :param value: mode value
        :return: mode
        """
        for mode in cls:
            if mode.value == value:
                return mode
        raise ValueError(f"invalid prompt message type value {value}")


class PromptMessageTool(BaseModel):
    """
    Model class for prompt message tool.
    """

    name: str
    description: str
    parameters: dict


class PromptMessageFunction(BaseModel):
    """
    Model class for prompt message function.
    """

    type: str = "function"
    function: PromptMessageTool


class PromptMessageContentType(StrEnum):
    """
    Enum class for prompt message content type.
    """

    TEXT = auto()
    IMAGE = auto()
    AUDIO = auto()
    VIDEO = auto()
    DOCUMENT = auto()


class PromptMessageContent(ABC, BaseModel):
    """
    Model class for prompt message content.
    """

    type: PromptMessageContentType


class TextPromptMessageContent(PromptMessageContent):
    """
    Model class for text prompt message content.
    """

    type: Literal[PromptMessageContentType.TEXT] = PromptMessageContentType.TEXT  # type: ignore
    data: str


class MultiModalPromptMessageContent(PromptMessageContent):
    """
    Model class for multi-modal prompt message content.
    """

    format: str = Field(default=..., description="the format of multi-modal file")
    base64_data: str = Field(default="", description="the base64 data of multi-modal file")
    url: str = Field(default="", description="the url of multi-modal file")
    mime_type: str = Field(default=..., description="the mime type of multi-modal file")
    filename: str = Field(default="", description="the filename of multi-modal file")

    @property
    def data(self):
        return self.url or f"data:{self.mime_type};base64,{self.base64_data}"


class VideoPromptMessageContent(MultiModalPromptMessageContent):
    type: Literal[PromptMessageContentType.VIDEO] = PromptMessageContentType.VIDEO  # type: ignore


class AudioPromptMessageContent(MultiModalPromptMessageContent):
    type: Literal[PromptMessageContentType.AUDIO] = PromptMessageContentType.AUDIO  # type: ignore


class ImagePromptMessageContent(MultiModalPromptMessageContent):
    """
    Model class for image prompt message content.
    """

    class DETAIL(StrEnum):
        LOW = auto()
        HIGH = auto()

    type: Literal[PromptMessageContentType.IMAGE] = PromptMessageContentType.IMAGE  # type: ignore
    detail: DETAIL = DETAIL.LOW


class DocumentPromptMessageContent(MultiModalPromptMessageContent):
    type: Literal[PromptMessageContentType.DOCUMENT] = PromptMessageContentType.DOCUMENT  # type: ignore


PromptMessageContentUnionTypes = Annotated[
    Union[
        TextPromptMessageContent,
        ImagePromptMessageContent,
        DocumentPromptMessageContent,
        AudioPromptMessageContent,
        VideoPromptMessageContent,
    ],
    Field(discriminator="type"),
]


CONTENT_TYPE_MAPPING: Mapping[PromptMessageContentType, type[PromptMessageContent]] = {
    PromptMessageContentType.TEXT: TextPromptMessageContent,
    PromptMessageContentType.IMAGE: ImagePromptMessageContent,
    PromptMessageContentType.AUDIO: AudioPromptMessageContent,
    PromptMessageContentType.VIDEO: VideoPromptMessageContent,
    PromptMessageContentType.DOCUMENT: DocumentPromptMessageContent,
}


class PromptMessage(ABC, BaseModel):
    """
    Model class for prompt message.
    """

    role: PromptMessageRole
    content: str | list[PromptMessageContentUnionTypes] | None = None
    name: str | None = None

    def is_empty(self) -> bool:
        """
        Check if prompt message is empty.

        :return: True if prompt message is empty, False otherwise
        """
        return not self.content

    def get_text_content(self) -> str:
        """
        Get text content from prompt message.

        :return: Text content as string, empty string if no text content
        """
        if isinstance(self.content, str):
            return self.content
        elif isinstance(self.content, list):
            text_parts = []
            for item in self.content:
                if isinstance(item, TextPromptMessageContent):
                    text_parts.append(item.data)
            return "".join(text_parts)
        else:
            return ""

    @field_validator("content", mode="before")
    @classmethod
    def validate_content(cls, v):
        if isinstance(v, list):
            prompts = []
            for prompt in v:
                if isinstance(prompt, PromptMessageContent):
                    if not isinstance(prompt, TextPromptMessageContent | MultiModalPromptMessageContent):
                        prompt = CONTENT_TYPE_MAPPING[prompt.type].model_validate(prompt.model_dump())
                elif isinstance(prompt, dict):
                    prompt = CONTENT_TYPE_MAPPING[prompt["type"]].model_validate(prompt)
                else:
                    raise ValueError(f"invalid prompt message {prompt}")
                prompts.append(prompt)
            return prompts
        return v

    @field_serializer("content")
    def serialize_content(
        self, content: Union[str, Sequence[PromptMessageContent]] | None
    ) -> str | list[dict[str, Any] | PromptMessageContent] | Sequence[PromptMessageContent] | None:
        if content is None or isinstance(content, str):
            return content
        if isinstance(content, list):
            return [item.model_dump() if hasattr(item, "model_dump") else item for item in content]
        return content


class UserPromptMessage(PromptMessage):
    """
    Model class for user prompt message.
    """

    role: PromptMessageRole = PromptMessageRole.USER


class AssistantPromptMessage(PromptMessage):
    """
    Model class for assistant prompt message.
    """

    class ToolCall(BaseModel):
        """
        Model class for assistant prompt message tool call.
        """

        class ToolCallFunction(BaseModel):
            """
            Model class for assistant prompt message tool call function.
            """

            name: str
            arguments: str

        id: str
        type: str
        function: ToolCallFunction

        @field_validator("id", mode="before")
        @classmethod
        def transform_id_to_str(cls, value) -> str:
            if not isinstance(value, str):
                return str(value)
            else:
                return value

    role: PromptMessageRole = PromptMessageRole.ASSISTANT
    tool_calls: list[ToolCall] = []

    def is_empty(self) -> bool:
        """
        Check if prompt message is empty.

        :return: True if prompt message is empty, False otherwise
        """
        return super().is_empty() and not self.tool_calls


class SystemPromptMessage(PromptMessage):
    """
    Model class for system prompt message.
    """

    role: PromptMessageRole = PromptMessageRole.SYSTEM


class ToolPromptMessage(PromptMessage):
    """
    Model class for tool prompt message.
    """

    role: PromptMessageRole = PromptMessageRole.TOOL
    tool_call_id: str

    def is_empty(self) -> bool:
        """
        Check if prompt message is empty.

        :return: True if prompt message is empty, False otherwise
        """
        if not super().is_empty() and not self.tool_call_id:
            return False

        return True
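A brief, hedged sketch of building messages with the classes above (the content strings are arbitrary):

```python
from dify_graph.model_runtime.entities.message_entities import (
    SystemPromptMessage,
    TextPromptMessageContent,
    UserPromptMessage,
)

messages = [
    SystemPromptMessage(content="You are a helpful assistant."),
    # List content is normalized into typed content parts by PromptMessage.validate_content.
    UserPromptMessage(content=[TextPromptMessageContent(data="Summarize this document.")]),
]
print(messages[1].get_text_content())  # Summarize this document.
```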
api/dify_graph/model_runtime/entities/model_entities.py (new file, 242 lines)
@@ -0,0 +1,242 @@
from __future__ import annotations

from decimal import Decimal
from enum import StrEnum, auto
from typing import Any

from pydantic import BaseModel, ConfigDict, model_validator

from dify_graph.model_runtime.entities.common_entities import I18nObject


class ModelType(StrEnum):
    """
    Enum class for model type.
    """

    LLM = auto()
    TEXT_EMBEDDING = "text-embedding"
    RERANK = auto()
    SPEECH2TEXT = auto()
    MODERATION = auto()
    TTS = auto()

    @classmethod
    def value_of(cls, origin_model_type: str) -> ModelType:
        """
        Get model type from origin model type.

        :return: model type
        """
        if origin_model_type in {"text-generation", cls.LLM}:
            return cls.LLM
        elif origin_model_type in {"embeddings", cls.TEXT_EMBEDDING}:
            return cls.TEXT_EMBEDDING
        elif origin_model_type in {"reranking", cls.RERANK}:
            return cls.RERANK
        elif origin_model_type in {"speech2text", cls.SPEECH2TEXT}:
            return cls.SPEECH2TEXT
        elif origin_model_type in {"tts", cls.TTS}:
            return cls.TTS
        elif origin_model_type == cls.MODERATION:
            return cls.MODERATION
        else:
            raise ValueError(f"invalid origin model type {origin_model_type}")

    def to_origin_model_type(self) -> str:
        """
        Get origin model type from model type.

        :return: origin model type
        """
        if self == self.LLM:
            return "text-generation"
        elif self == self.TEXT_EMBEDDING:
            return "embeddings"
        elif self == self.RERANK:
            return "reranking"
        elif self == self.SPEECH2TEXT:
            return "speech2text"
        elif self == self.TTS:
            return "tts"
        elif self == self.MODERATION:
            return "moderation"
        else:
            raise ValueError(f"invalid model type {self}")


class FetchFrom(StrEnum):
    """
    Enum class for fetch from.
    """

    PREDEFINED_MODEL = "predefined-model"
    CUSTOMIZABLE_MODEL = "customizable-model"


class ModelFeature(StrEnum):
    """
    Enum class for llm feature.
    """

    TOOL_CALL = "tool-call"
    MULTI_TOOL_CALL = "multi-tool-call"
    AGENT_THOUGHT = "agent-thought"
    VISION = auto()
    STREAM_TOOL_CALL = "stream-tool-call"
    DOCUMENT = auto()
    VIDEO = auto()
    AUDIO = auto()
    STRUCTURED_OUTPUT = "structured-output"


class DefaultParameterName(StrEnum):
    """
    Enum class for parameter template variable.
    """

    TEMPERATURE = auto()
    TOP_P = auto()
    TOP_K = auto()
    PRESENCE_PENALTY = auto()
    FREQUENCY_PENALTY = auto()
    MAX_TOKENS = auto()
    RESPONSE_FORMAT = auto()
    JSON_SCHEMA = auto()

    @classmethod
    def value_of(cls, value: Any) -> DefaultParameterName:
        """
        Get parameter name from value.

        :param value: parameter value
        :return: parameter name
        """
        for name in cls:
            if name.value == value:
                return name
        raise ValueError(f"invalid parameter name {value}")


class ParameterType(StrEnum):
    """
    Enum class for parameter type.
    """

    FLOAT = auto()
    INT = auto()
    STRING = auto()
    BOOLEAN = auto()
    TEXT = auto()


class ModelPropertyKey(StrEnum):
    """
    Enum class for model property key.
    """

    MODE = auto()
    CONTEXT_SIZE = auto()
    MAX_CHUNKS = auto()
    FILE_UPLOAD_LIMIT = auto()
    SUPPORTED_FILE_EXTENSIONS = auto()
    MAX_CHARACTERS_PER_CHUNK = auto()
    DEFAULT_VOICE = auto()
    VOICES = auto()
    WORD_LIMIT = auto()
    AUDIO_TYPE = auto()
    MAX_WORKERS = auto()


class ProviderModel(BaseModel):
    """
    Model class for provider model.
    """

    model: str
    label: I18nObject
    model_type: ModelType
    features: list[ModelFeature] | None = None
    fetch_from: FetchFrom
    model_properties: dict[ModelPropertyKey, Any]
    deprecated: bool = False
    model_config = ConfigDict(protected_namespaces=())

    @property
    def support_structure_output(self) -> bool:
        return self.features is not None and ModelFeature.STRUCTURED_OUTPUT in self.features


class ParameterRule(BaseModel):
    """
    Model class for parameter rule.
    """

    name: str
    use_template: str | None = None
    label: I18nObject
    type: ParameterType
    help: I18nObject | None = None
    required: bool = False
    default: Any | None = None
    min: float | None = None
    max: float | None = None
    precision: int | None = None
    options: list[str] = []


class PriceConfig(BaseModel):
    """
    Model class for pricing info.
    """

    input: Decimal
    output: Decimal | None = None
    unit: Decimal
    currency: str


class AIModelEntity(ProviderModel):
    """
    Model class for AI model.
    """

    parameter_rules: list[ParameterRule] = []
    pricing: PriceConfig | None = None

    @model_validator(mode="after")
    def validate_model(self):
        supported_schema_keys = ["json_schema"]
        schema_key = next((rule.name for rule in self.parameter_rules if rule.name in supported_schema_keys), None)
        if not schema_key:
            return self
        if self.features is None:
            self.features = [ModelFeature.STRUCTURED_OUTPUT]
        else:
            if ModelFeature.STRUCTURED_OUTPUT not in self.features:
                self.features.append(ModelFeature.STRUCTURED_OUTPUT)
        return self


class ModelUsage(BaseModel):
    pass


class PriceType(StrEnum):
    """
    Enum class for price type.
    """

    INPUT = auto()
    OUTPUT = auto()


class PriceInfo(BaseModel):
    """
    Model class for price info.
    """

    unit_price: Decimal
    unit: Decimal
    total_amount: Decimal
    currency: str
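A hedged sketch of how `AIModelEntity.validate_model` reacts to a `json_schema` parameter rule; the model name and label below are invented:

```python
from dify_graph.model_runtime.entities.common_entities import I18nObject
from dify_graph.model_runtime.entities.model_entities import (
    AIModelEntity,
    FetchFrom,
    ModelFeature,
    ModelType,
    ParameterRule,
    ParameterType,
)

entity = AIModelEntity(
    model="example-chat-model",
    label=I18nObject(en_US="Example Chat Model"),
    model_type=ModelType.LLM,
    fetch_from=FetchFrom.PREDEFINED_MODEL,
    model_properties={},
    # Declaring a "json_schema" rule makes the validator add STRUCTURED_OUTPUT automatically.
    parameter_rules=[
        ParameterRule(name="json_schema", label=I18nObject(en_US="JSON Schema"), type=ParameterType.TEXT)
    ],
)
assert ModelFeature.STRUCTURED_OUTPUT in (entity.features or [])
```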
api/dify_graph/model_runtime/entities/provider_entities.py (new file, 169 lines)
@@ -0,0 +1,169 @@
from collections.abc import Sequence
from enum import StrEnum, auto

from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator

from dify_graph.model_runtime.entities.common_entities import I18nObject
from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType


class ConfigurateMethod(StrEnum):
    """
    Enum class for configurate method of provider model.
    """

    PREDEFINED_MODEL = "predefined-model"
    CUSTOMIZABLE_MODEL = "customizable-model"


class FormType(StrEnum):
    """
    Enum class for form type.
    """

    TEXT_INPUT = "text-input"
    SECRET_INPUT = "secret-input"
    SELECT = auto()
    RADIO = auto()
    SWITCH = auto()


class FormShowOnObject(BaseModel):
    """
    Model class for form show on.
    """

    variable: str
    value: str


class FormOption(BaseModel):
    """
    Model class for form option.
    """

    label: I18nObject
    value: str
    show_on: list[FormShowOnObject] = []

    @model_validator(mode="after")
    def _(self):
        if not self.label:
            self.label = I18nObject(en_US=self.value)
        return self


class CredentialFormSchema(BaseModel):
    """
    Model class for credential form schema.
    """

    variable: str
    label: I18nObject
    type: FormType
    required: bool = True
    default: str | None = None
    options: list[FormOption] | None = None
    placeholder: I18nObject | None = None
    max_length: int = 0
    show_on: list[FormShowOnObject] = []


class ProviderCredentialSchema(BaseModel):
    """
    Model class for provider credential schema.
    """

    credential_form_schemas: list[CredentialFormSchema]


class FieldModelSchema(BaseModel):
    label: I18nObject
    placeholder: I18nObject | None = None


class ModelCredentialSchema(BaseModel):
    """
    Model class for model credential schema.
    """

    model: FieldModelSchema
    credential_form_schemas: list[CredentialFormSchema]


class SimpleProviderEntity(BaseModel):
    """
    Simple model class for provider.
    """

    provider: str
    label: I18nObject
    icon_small: I18nObject | None = None
    icon_small_dark: I18nObject | None = None
    supported_model_types: Sequence[ModelType]
    models: list[AIModelEntity] = []


class ProviderHelpEntity(BaseModel):
    """
    Model class for provider help.
    """

    title: I18nObject
    url: I18nObject


class ProviderEntity(BaseModel):
    """
    Model class for provider.
    """

    provider: str
    label: I18nObject
    description: I18nObject | None = None
    icon_small: I18nObject | None = None
    icon_small_dark: I18nObject | None = None
    background: str | None = None
    help: ProviderHelpEntity | None = None
    supported_model_types: Sequence[ModelType]
    configurate_methods: list[ConfigurateMethod]
    models: list[AIModelEntity] = Field(default_factory=list)
    provider_credential_schema: ProviderCredentialSchema | None = None
    model_credential_schema: ModelCredentialSchema | None = None

    # pydantic configs
    model_config = ConfigDict(protected_namespaces=())

    # position from plugin _position.yaml
    position: dict[str, list[str]] | None = {}

    @field_validator("models", mode="before")
    @classmethod
    def validate_models(cls, v):
        # returns an empty list if v is empty
        if not v:
            return []
        return v

    def to_simple_provider(self) -> SimpleProviderEntity:
        """
        Convert to simple provider.

        :return: simple provider
        """
        return SimpleProviderEntity(
            provider=self.provider,
            label=self.label,
            icon_small=self.icon_small,
            supported_model_types=self.supported_model_types,
            models=self.models,
        )


class ProviderConfig(BaseModel):
    """
    Model class for provider config.
    """

    provider: str
    credentials: dict
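A hedged sketch of a provider entry and its simplified listing form; the provider name and label are invented:

```python
from dify_graph.model_runtime.entities.common_entities import I18nObject
from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity

provider = ProviderEntity(
    provider="example-provider",
    label=I18nObject(en_US="Example Provider"),
    supported_model_types=[ModelType.LLM],
    configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
)
# to_simple_provider() drops credential schemas and keeps only listing fields.
print(provider.to_simple_provider().provider)  # example-provider
```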
api/dify_graph/model_runtime/entities/rerank_entities.py (new file, 20 lines)
@@ -0,0 +1,20 @@
from pydantic import BaseModel


class RerankDocument(BaseModel):
    """
    Model class for rerank document.
    """

    index: int
    text: str
    score: float


class RerankResult(BaseModel):
    """
    Model class for rerank result.
    """

    model: str
    docs: list[RerankDocument]
@@ -0,0 +1,39 @@
from decimal import Decimal

from pydantic import BaseModel

from dify_graph.model_runtime.entities.model_entities import ModelUsage


class EmbeddingUsage(ModelUsage):
    """
    Model class for embedding usage.
    """

    tokens: int
    total_tokens: int
    unit_price: Decimal
    price_unit: Decimal
    total_price: Decimal
    currency: str
    latency: float


class EmbeddingResult(BaseModel):
    """
    Model class for text embedding result.
    """

    model: str
    embeddings: list[list[float]]
    usage: EmbeddingUsage


class FileEmbeddingResult(BaseModel):
    """
    Model class for file embedding result.
    """

    model: str
    embeddings: list[list[float]]
    usage: EmbeddingUsage
api/dify_graph/model_runtime/errors/__init__.py (new file, empty)
api/dify_graph/model_runtime/errors/invoke.py (new file, 40 lines)
@@ -0,0 +1,40 @@
class InvokeError(ValueError):
    """Base class for all LLM exceptions."""

    description: str | None = None

    def __init__(self, description: str | None = None):
        self.description = description

    def __str__(self):
        return self.description or self.__class__.__name__


class InvokeConnectionError(InvokeError):
    """Raised when the Invoke returns connection error."""

    description = "Connection Error"


class InvokeServerUnavailableError(InvokeError):
    """Raised when the Invoke returns server unavailable error."""

    description = "Server Unavailable Error"


class InvokeRateLimitError(InvokeError):
    """Raised when the Invoke returns rate limit error."""

    description = "Rate Limit Error"


class InvokeAuthorizationError(InvokeError):
    """Raised when the Invoke returns authorization error."""

    description = "Incorrect model credentials provided, please check and try again. "


class InvokeBadRequestError(InvokeError):
    """Raised when the Invoke returns bad request."""

    description = "Bad Request Error"
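Since every invoke error derives from `InvokeError` (itself a `ValueError`), callers can catch the base class; the message below is invented for illustration:

```python
from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError

try:
    raise InvokeAuthorizationError("API key rejected by the provider")
except InvokeError as e:
    # __str__ falls back to the class name when no description is set.
    print(str(e))  # API key rejected by the provider
```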
api/dify_graph/model_runtime/errors/validate.py (new file, 6 lines)
@@ -0,0 +1,6 @@
class CredentialsValidateFailedError(ValueError):
    """
    Credentials validate failed error
    """

    pass
api/dify_graph/model_runtime/memory/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
from .prompt_message_memory import DEFAULT_MEMORY_MAX_TOKEN_LIMIT, PromptMessageMemory

__all__ = ["DEFAULT_MEMORY_MAX_TOKEN_LIMIT", "PromptMessageMemory"]
api/dify_graph/model_runtime/memory/prompt_message_memory.py (new file, 18 lines)
@@ -0,0 +1,18 @@
from __future__ import annotations

from collections.abc import Sequence
from typing import Protocol

from dify_graph.model_runtime.entities import PromptMessage

DEFAULT_MEMORY_MAX_TOKEN_LIMIT = 2000


class PromptMessageMemory(Protocol):
    """Port for loading memory as prompt messages."""

    def get_history_prompt_messages(
        self, max_token_limit: int = DEFAULT_MEMORY_MAX_TOKEN_LIMIT, message_limit: int | None = None
    ) -> Sequence[PromptMessage]:
        """Return historical prompt messages constrained by token/message limits."""
        ...
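Because `PromptMessageMemory` is a `Protocol`, any object with a matching `get_history_prompt_messages` satisfies it structurally; a toy in-memory implementation (purely illustrative) might look like this:

```python
from collections.abc import Sequence

from dify_graph.model_runtime.entities import PromptMessage, UserPromptMessage
from dify_graph.model_runtime.memory.prompt_message_memory import (
    DEFAULT_MEMORY_MAX_TOKEN_LIMIT,
    PromptMessageMemory,
)


class ListMemory:
    """Toy history store that structurally satisfies PromptMessageMemory."""

    def __init__(self, messages: Sequence[PromptMessage]):
        self._messages = list(messages)

    def get_history_prompt_messages(
        self, max_token_limit: int = DEFAULT_MEMORY_MAX_TOKEN_LIMIT, message_limit: int | None = None
    ) -> Sequence[PromptMessage]:
        return self._messages if message_limit is None else self._messages[-message_limit:]


memory: PromptMessageMemory = ListMemory([UserPromptMessage(content="hello")])
print(len(memory.get_history_prompt_messages()))  # 1
```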
api/dify_graph/model_runtime/model_providers/__base/ai_model.py (new file, 286 lines; truncated below)
@ -0,0 +1,286 @@
|
||||
import decimal
|
||||
import hashlib
|
||||
import logging
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationError
|
||||
from redis import RedisError
|
||||
|
||||
from configs import dify_config
|
||||
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
||||
from dify_graph.model_runtime.entities.common_entities import I18nObject
|
||||
from dify_graph.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
|
||||
from dify_graph.model_runtime.entities.model_entities import (
|
||||
AIModelEntity,
|
||||
DefaultParameterName,
|
||||
ModelType,
|
||||
PriceConfig,
|
||||
PriceInfo,
|
||||
PriceType,
|
||||
)
|
||||
from dify_graph.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AIModel(BaseModel):
|
||||
"""
|
||||
Base class for all models.
|
||||
"""
|
||||
|
||||
tenant_id: str = Field(description="Tenant ID")
|
||||
model_type: ModelType = Field(description="Model type")
|
||||
plugin_id: str = Field(description="Plugin ID")
|
||||
provider_name: str = Field(description="Provider")
|
||||
plugin_model_provider: PluginModelProviderEntity = Field(description="Plugin model provider")
|
||||
started_at: float = Field(description="Invoke start time", default=0)
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[Exception], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
The value is the error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
from core.plugin.entities.plugin_daemon import PluginDaemonInnerError
|
||||
|
||||
return {
|
||||
InvokeConnectionError: [InvokeConnectionError],
|
||||
InvokeServerUnavailableError: [InvokeServerUnavailableError],
|
||||
InvokeRateLimitError: [InvokeRateLimitError],
|
||||
InvokeAuthorizationError: [InvokeAuthorizationError],
|
||||
InvokeBadRequestError: [InvokeBadRequestError],
|
||||
PluginDaemonInnerError: [PluginDaemonInnerError],
|
||||
ValueError: [ValueError],
|
||||
}
|
||||
|
||||
def _transform_invoke_error(self, error: Exception) -> Exception:
|
||||
"""
|
||||
Transform invoke error to unified error
|
||||
|
||||
:param error: model invoke error
|
||||
:return: unified error
|
||||
"""
|
||||
for invoke_error, model_errors in self._invoke_error_mapping.items():
|
||||
if isinstance(error, tuple(model_errors)):
|
||||
if invoke_error == InvokeAuthorizationError:
|
||||
return InvokeAuthorizationError(
|
||||
description=(
|
||||
f"[{self.provider_name}] Incorrect model credentials provided, please check and try again."
|
||||
)
|
||||
)
|
||||
elif issubclass(invoke_error, InvokeError):
|
||||
return InvokeError(description=f"[{self.provider_name}] {invoke_error.description}, {str(error)}")
|
||||
else:
|
||||
return error
|
||||
|
||||
return InvokeError(description=f"[{self.provider_name}] Error: {str(error)}")
|
||||
|
||||
def get_price(self, model: str, credentials: dict, price_type: PriceType, tokens: int) -> PriceInfo:
|
||||
"""
|
||||
Get price for given model and tokens
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param price_type: price type
|
||||
:param tokens: number of tokens
|
||||
:return: price info
|
||||
"""
|
||||
# get model schema
|
||||
model_schema = self.get_model_schema(model, credentials)
|
||||
|
||||
# get price info from predefined model schema
|
||||
price_config: PriceConfig | None = None
|
||||
if model_schema and model_schema.pricing:
|
||||
price_config = model_schema.pricing
|
||||
|
||||
# get unit price
|
||||
unit_price = None
|
||||
if price_config:
|
||||
if price_type == PriceType.INPUT:
|
||||
unit_price = price_config.input
|
||||
elif price_type == PriceType.OUTPUT and price_config.output is not None:
|
||||
unit_price = price_config.output
|
||||
|
||||
if unit_price is None:
|
||||
return PriceInfo(
|
||||
unit_price=decimal.Decimal("0.0"),
|
||||
unit=decimal.Decimal("0.0"),
|
||||
total_amount=decimal.Decimal("0.0"),
|
||||
currency="USD",
|
||||
)
|
||||
|
||||
# calculate total amount
|
||||
if not price_config:
|
||||
raise ValueError(f"Price config not found for model {model}")
|
||||
total_amount = tokens * unit_price * price_config.unit
|
||||
total_amount = total_amount.quantize(decimal.Decimal("0.0000001"), rounding=decimal.ROUND_HALF_UP)
|
||||
|
||||
return PriceInfo(
|
||||
unit_price=unit_price,
|
||||
unit=price_config.unit,
|
||||
total_amount=total_amount,
|
||||
currency=price_config.currency,
|
||||
)
|
||||
|
||||
def get_model_schema(self, model: str, credentials: dict | None = None) -> AIModelEntity | None:
|
||||
"""
|
||||
Get model schema by model name and credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return: model schema
|
||||
"""
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
plugin_model_manager = PluginModelClient()
|
||||
cache_key = f"{self.tenant_id}:{self.plugin_id}:{self.provider_name}:{self.model_type.value}:{model}"
|
||||
sorted_credentials = sorted(credentials.items()) if credentials else []
|
||||
cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials])
|
||||
|
||||
cached_schema_json = None
|
||||
try:
|
||||
cached_schema_json = redis_client.get(cache_key)
|
||||
except (RedisError, RuntimeError) as exc:
|
||||
logger.warning(
|
||||
"Failed to read plugin model schema cache for model %s: %s",
|
||||
model,
|
||||
str(exc),
|
||||
exc_info=True,
|
||||
)
|
||||
if cached_schema_json:
|
||||
try:
|
||||
return AIModelEntity.model_validate_json(cached_schema_json)
|
||||
except ValidationError:
|
||||
logger.warning(
|
||||
"Failed to validate cached plugin model schema for model %s",
|
||||
model,
|
||||
exc_info=True,
|
||||
)
|
||||
try:
|
||||
redis_client.delete(cache_key)
|
||||
except (RedisError, RuntimeError) as exc:
|
||||
logger.warning(
|
||||
"Failed to delete invalid plugin model schema cache for model %s: %s",
|
||||
model,
|
||||
str(exc),
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
schema = plugin_model_manager.get_model_schema(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id="unknown",
|
||||
plugin_id=self.plugin_id,
|
||||
provider=self.provider_name,
|
||||
model_type=self.model_type.value,
|
||||
model=model,
|
||||
credentials=credentials or {},
|
||||
)
|
||||
|
||||
if schema:
|
||||
try:
|
||||
redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_SCHEMA_CACHE_TTL, schema.model_dump_json())
|
||||
except (RedisError, RuntimeError) as exc:
|
||||
logger.warning(
|
||||
"Failed to write plugin model schema cache for model %s: %s",
|
||||
model,
|
||||
str(exc),
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
return schema
|
||||
|
||||
def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||
"""
|
||||
Get customizable model schema from credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return: model schema
|
||||
"""
|
||||
|
||||
# get customizable model schema
|
||||
schema = self.get_customizable_model_schema(model, credentials)
|
||||
if not schema:
|
||||
return None
|
||||
|
||||
# fill in the template
|
||||
new_parameter_rules = []
|
||||
for parameter_rule in schema.parameter_rules:
|
||||
if parameter_rule.use_template:
|
||||
try:
|
||||
default_parameter_name = DefaultParameterName.value_of(parameter_rule.use_template)
|
||||
default_parameter_rule = self._get_default_parameter_rule_variable_map(default_parameter_name)
|
||||
if not parameter_rule.max and "max" in default_parameter_rule:
|
||||
parameter_rule.max = default_parameter_rule["max"]
|
||||
if not parameter_rule.min and "min" in default_parameter_rule:
|
||||
parameter_rule.min = default_parameter_rule["min"]
|
||||
if not parameter_rule.default and "default" in default_parameter_rule:
|
||||
parameter_rule.default = default_parameter_rule["default"]
|
||||
if not parameter_rule.precision and "precision" in default_parameter_rule:
|
||||
parameter_rule.precision = default_parameter_rule["precision"]
|
||||
if not parameter_rule.required and "required" in default_parameter_rule:
|
||||
parameter_rule.required = default_parameter_rule["required"]
|
||||
if not parameter_rule.help and "help" in default_parameter_rule:
|
||||
parameter_rule.help = I18nObject(
|
||||
en_US=default_parameter_rule["help"]["en_US"],
|
||||
)
|
||||
if (
|
||||
parameter_rule.help
|
||||
and not parameter_rule.help.en_US
|
||||
and ("help" in default_parameter_rule and "en_US" in default_parameter_rule["help"])
|
||||
):
|
||||
parameter_rule.help.en_US = default_parameter_rule["help"]["en_US"]
|
||||
if (
|
||||
parameter_rule.help
|
||||
and not parameter_rule.help.zh_Hans
|
||||
and ("help" in default_parameter_rule and "zh_Hans" in default_parameter_rule["help"])
|
||||
):
|
||||
parameter_rule.help.zh_Hans = default_parameter_rule["help"].get(
|
||||
"zh_Hans", default_parameter_rule["help"]["en_US"]
|
||||
)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
new_parameter_rules.append(parameter_rule)
|
||||
|
||||
schema.parameter_rules = new_parameter_rules
|
||||
|
||||
return schema
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||
"""
|
||||
Get customizable model schema
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return: model schema
|
||||
"""
|
||||
return None
|
||||
|
||||
def _get_default_parameter_rule_variable_map(self, name: DefaultParameterName):
|
||||
"""
|
||||
Get default parameter rule for given name
|
||||
|
||||
:param name: parameter name
|
||||
:return: parameter rule
|
||||
"""
|
||||
default_parameter_rule = PARAMETER_RULE_TEMPLATE.get(name)
|
||||
|
||||
if not default_parameter_rule:
|
||||
raise Exception(f"Invalid model parameter rule name {name}")
|
||||
|
||||
return default_parameter_rule
|
||||
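The pricing arithmetic in `get_price` reduces to `tokens * unit_price * unit`, quantized to seven decimal places with half-up rounding. A small worked example, with all numbers chosen purely for illustration:

```python
import decimal

tokens = 1000
unit_price = decimal.Decimal("0.002")  # price per pricing unit (e.g. USD per 1K tokens)
unit = decimal.Decimal("0.001")        # pricing unit expressed per token

total_amount = (tokens * unit_price * unit).quantize(
    decimal.Decimal("0.0000001"), rounding=decimal.ROUND_HALF_UP
)
print(total_amount)  # 0.0020000
```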
668
api/dify_graph/model_runtime/model_providers/__base/large_language_model.py
Normal file
@@ -0,0 +1,668 @@
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Callable, Generator, Iterator, Sequence
|
||||
from typing import Union
|
||||
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from configs import dify_config
|
||||
from dify_graph.model_runtime.callbacks.base_callback import Callback
|
||||
from dify_graph.model_runtime.callbacks.logging_callback import LoggingCallback
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMUsage
|
||||
from dify_graph.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
PromptMessageContentUnionTypes,
|
||||
PromptMessageTool,
|
||||
TextPromptMessageContent,
|
||||
)
|
||||
from dify_graph.model_runtime.entities.model_entities import (
|
||||
ModelType,
|
||||
PriceType,
|
||||
)
|
||||
from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _gen_tool_call_id() -> str:
|
||||
return f"chatcmpl-tool-{str(uuid.uuid4().hex)}"
|
||||
|
||||
|
||||
def _run_callbacks(callbacks: Sequence[Callback] | None, *, event: str, invoke: Callable[[Callback], None]) -> None:
|
||||
if not callbacks:
|
||||
return
|
||||
|
||||
for callback in callbacks:
|
||||
try:
|
||||
invoke(callback)
|
||||
except Exception as e:
|
||||
if callback.raise_error:
|
||||
raise
|
||||
logger.warning("Callback %s %s failed with error %s", callback.__class__.__name__, event, e)
|
||||
|
||||
|
||||
def _get_or_create_tool_call(
|
||||
existing_tools_calls: list[AssistantPromptMessage.ToolCall],
|
||||
tool_call_id: str,
|
||||
) -> AssistantPromptMessage.ToolCall:
|
||||
"""
|
||||
Get or create a tool call by ID.
|
||||
|
||||
If `tool_call_id` is empty, returns the most recently created tool call.
|
||||
"""
|
||||
if not tool_call_id:
|
||||
if not existing_tools_calls:
|
||||
raise ValueError("tool_call_id is empty but no existing tool call is available to apply the delta")
|
||||
return existing_tools_calls[-1]
|
||||
|
||||
tool_call = next((tool_call for tool_call in existing_tools_calls if tool_call.id == tool_call_id), None)
|
||||
if tool_call is None:
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id=tool_call_id,
|
||||
type="function",
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""),
|
||||
)
|
||||
existing_tools_calls.append(tool_call)
|
||||
|
||||
return tool_call
|
||||
|
||||
|
||||
def _merge_tool_call_delta(
|
||||
tool_call: AssistantPromptMessage.ToolCall,
|
||||
delta: AssistantPromptMessage.ToolCall,
|
||||
) -> None:
|
||||
if delta.id:
|
||||
tool_call.id = delta.id
|
||||
if delta.type:
|
||||
tool_call.type = delta.type
|
||||
if delta.function.name:
|
||||
tool_call.function.name = delta.function.name
|
||||
if delta.function.arguments:
|
||||
tool_call.function.arguments += delta.function.arguments
|
||||
|
||||
|
||||
def _build_llm_result_from_chunks(
|
||||
model: str,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
chunks: Iterator[LLMResultChunk],
|
||||
) -> LLMResult:
|
||||
"""
|
||||
Build a single `LLMResult` by accumulating all returned chunks.
|
||||
|
||||
Some models only support streaming output (e.g. Qwen3 open-source edition)
|
||||
and the plugin side may still implement the response via a chunked stream,
|
||||
so all chunks must be consumed and concatenated into a single ``LLMResult``.
|
||||
|
||||
The ``usage`` is taken from the last chunk that carries it, which is the
|
||||
typical convention for streaming responses (the final chunk contains the
|
||||
aggregated token counts).
|
||||
"""
|
||||
content = ""
|
||||
content_list: list[PromptMessageContentUnionTypes] = []
|
||||
usage = LLMUsage.empty_usage()
|
||||
system_fingerprint: str | None = None
|
||||
tools_calls: list[AssistantPromptMessage.ToolCall] = []
|
||||
|
||||
try:
|
||||
for chunk in chunks:
|
||||
if isinstance(chunk.delta.message.content, str):
|
||||
content += chunk.delta.message.content
|
||||
elif isinstance(chunk.delta.message.content, list):
|
||||
content_list.extend(chunk.delta.message.content)
|
||||
|
||||
if chunk.delta.message.tool_calls:
|
||||
_increase_tool_call(chunk.delta.message.tool_calls, tools_calls)
|
||||
|
||||
if chunk.delta.usage:
|
||||
usage = chunk.delta.usage
|
||||
if chunk.system_fingerprint:
|
||||
system_fingerprint = chunk.system_fingerprint
|
||||
except Exception:
|
||||
logger.exception("Error while consuming non-stream plugin chunk iterator.")
|
||||
raise
|
||||
finally:
|
||||
# Close the chunk iterator so that underlying streaming resources (e.g. HTTP connections) are released.
|
||||
close = getattr(chunks, "close", None)
|
||||
if callable(close):
|
||||
close()
|
||||
|
||||
return LLMResult(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=AssistantPromptMessage(
|
||||
content=content or content_list,
|
||||
tool_calls=tools_calls,
|
||||
),
|
||||
usage=usage,
|
||||
system_fingerprint=system_fingerprint,
|
||||
)
|
||||
|
||||
|
||||
def _invoke_llm_via_plugin(
|
||||
*,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
model_parameters: dict,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
tools: list[PromptMessageTool] | None,
|
||||
stop: Sequence[str] | None,
|
||||
stream: bool,
|
||||
) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]:
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
plugin_model_manager = PluginModelClient()
|
||||
return plugin_model_manager.invoke_llm(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
model_parameters=model_parameters,
|
||||
prompt_messages=list(prompt_messages),
|
||||
tools=tools,
|
||||
stop=list(stop) if stop else None,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
|
||||
def _normalize_non_stream_plugin_result(
|
||||
model: str,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
result: Union[LLMResult, Iterator[LLMResultChunk]],
|
||||
) -> LLMResult:
|
||||
if isinstance(result, LLMResult):
|
||||
return result
|
||||
return _build_llm_result_from_chunks(model=model, prompt_messages=prompt_messages, chunks=result)
|
||||
|
||||
|
||||
def _increase_tool_call(
|
||||
new_tool_calls: list[AssistantPromptMessage.ToolCall], existing_tools_calls: list[AssistantPromptMessage.ToolCall]
|
||||
):
|
||||
"""
|
||||
Merge incremental tool call updates into existing tool calls.
|
||||
|
||||
:param new_tool_calls: List of new tool call deltas to be merged.
|
||||
:param existing_tools_calls: List of existing tool calls to be modified IN-PLACE.
|
||||
"""
|
||||
|
||||
for new_tool_call in new_tool_calls:
|
||||
# generate ID for tool calls with function name but no ID to track them
|
||||
if new_tool_call.function.name and not new_tool_call.id:
|
||||
new_tool_call.id = _gen_tool_call_id()
|
||||
|
||||
tool_call = _get_or_create_tool_call(existing_tools_calls, new_tool_call.id)
|
||||
_merge_tool_call_delta(tool_call, new_tool_call)
|
||||
|
||||
|
||||
class LargeLanguageModel(AIModel):
|
||||
"""
|
||||
Model class for large language model.
|
||||
"""
|
||||
|
||||
model_type: ModelType = ModelType.LLM
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict | None = None,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
stream: bool = True,
|
||||
user: str | None = None,
|
||||
callbacks: list[Callback] | None = None,
|
||||
) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
:param callbacks: callbacks
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
# validate and filter model parameters
|
||||
if model_parameters is None:
|
||||
model_parameters = {}
|
||||
|
||||
self.started_at = time.perf_counter()
|
||||
|
||||
callbacks = callbacks or []
|
||||
|
||||
if dify_config.DEBUG:
|
||||
callbacks.append(LoggingCallback())
|
||||
|
||||
# trigger before invoke callbacks
|
||||
self._trigger_before_invoke_callbacks(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
|
||||
result: Union[LLMResult, Generator[LLMResultChunk, None, None]]
|
||||
|
||||
try:
|
||||
result = _invoke_llm_via_plugin(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user or "unknown",
|
||||
plugin_id=self.plugin_id,
|
||||
provider=self.provider_name,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
model_parameters=model_parameters,
|
||||
prompt_messages=prompt_messages,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
if not stream:
|
||||
result = _normalize_non_stream_plugin_result(
|
||||
model=model, prompt_messages=prompt_messages, result=result
|
||||
)
|
||||
except Exception as e:
|
||||
self._trigger_invoke_error_callbacks(
|
||||
model=model,
|
||||
ex=e,
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
|
||||
# TODO
|
||||
raise self._transform_invoke_error(e)
|
||||
|
||||
if stream and not isinstance(result, LLMResult):
|
||||
return self._invoke_result_generator(
|
||||
model=model,
|
||||
result=result,
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
elif isinstance(result, LLMResult):
|
||||
self._trigger_after_invoke_callbacks(
|
||||
model=model,
|
||||
result=result,
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
# Following https://github.com/langgenius/dify/issues/17799,
|
||||
# we removed the prompt_messages from the chunk on the plugin daemon side.
|
||||
# To ensure compatibility, we add the prompt_messages back here.
|
||||
result.prompt_messages = prompt_messages
|
||||
return result
|
||||
raise NotImplementedError("unsupported invoke result type", type(result))
|
||||
|
||||
def _invoke_result_generator(
|
||||
self,
|
||||
model: str,
|
||||
result: Generator[LLMResultChunk, None, None],
|
||||
credentials: dict,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: Sequence[str] | None = None,
|
||||
stream: bool = True,
|
||||
user: str | None = None,
|
||||
callbacks: list[Callback] | None = None,
|
||||
) -> Generator[LLMResultChunk, None, None]:
|
||||
"""
|
||||
Invoke result generator
|
||||
|
||||
:param result: result generator
|
||||
:return: result generator
|
||||
"""
|
||||
callbacks = callbacks or []
|
||||
message_content: list[PromptMessageContentUnionTypes] = []
|
||||
usage = None
|
||||
system_fingerprint = None
|
||||
real_model = model
|
||||
|
||||
def _update_message_content(content: str | list[PromptMessageContentUnionTypes] | None):
|
||||
if not content:
|
||||
return
|
||||
if isinstance(content, list):
|
||||
message_content.extend(content)
|
||||
return
|
||||
if isinstance(content, str):
|
||||
message_content.append(TextPromptMessageContent(data=content))
|
||||
return
|
||||
|
||||
try:
|
||||
for chunk in result:
|
||||
# Following https://github.com/langgenius/dify/issues/17799,
|
||||
# we removed the prompt_messages from the chunk on the plugin daemon side.
|
||||
# To ensure compatibility, we add the prompt_messages back here.
|
||||
chunk.prompt_messages = prompt_messages
|
||||
yield chunk
|
||||
|
||||
self._trigger_new_chunk_callbacks(
|
||||
chunk=chunk,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
|
||||
_update_message_content(chunk.delta.message.content)
|
||||
|
||||
real_model = chunk.model
|
||||
if chunk.delta.usage:
|
||||
usage = chunk.delta.usage
|
||||
|
||||
if chunk.system_fingerprint:
|
||||
system_fingerprint = chunk.system_fingerprint
|
||||
except Exception as e:
|
||||
raise self._transform_invoke_error(e)
|
||||
|
||||
assistant_message = AssistantPromptMessage(content=message_content)
|
||||
self._trigger_after_invoke_callbacks(
|
||||
model=model,
|
||||
result=LLMResult(
|
||||
model=real_model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=assistant_message,
|
||||
usage=usage or LLMUsage.empty_usage(),
|
||||
system_fingerprint=system_fingerprint,
|
||||
),
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
|
||||
def get_num_tokens(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param tools: tools for tool calling
|
||||
:return:
|
||||
"""
|
||||
if dify_config.PLUGIN_BASED_TOKEN_COUNTING_ENABLED:
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
plugin_model_manager = PluginModelClient()
|
||||
return plugin_model_manager.get_llm_num_tokens(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id="unknown",
|
||||
plugin_id=self.plugin_id,
|
||||
provider=self.provider_name,
|
||||
model_type=self.model_type.value,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
tools=tools,
|
||||
)
|
||||
return 0
|
||||
|
||||
def calc_response_usage(
|
||||
self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int
|
||||
) -> LLMUsage:
|
||||
"""
|
||||
Calculate response usage
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_tokens: prompt tokens
|
||||
:param completion_tokens: completion tokens
|
||||
:return: usage
|
||||
"""
|
||||
# get prompt price info
|
||||
prompt_price_info = self.get_price(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
price_type=PriceType.INPUT,
|
||||
tokens=prompt_tokens,
|
||||
)
|
||||
|
||||
# get completion price info
|
||||
completion_price_info = self.get_price(
|
||||
model=model, credentials=credentials, price_type=PriceType.OUTPUT, tokens=completion_tokens
|
||||
)
|
||||
|
||||
# transform usage
|
||||
usage = LLMUsage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
prompt_unit_price=prompt_price_info.unit_price,
|
||||
prompt_price_unit=prompt_price_info.unit,
|
||||
prompt_price=prompt_price_info.total_amount,
|
||||
completion_tokens=completion_tokens,
|
||||
completion_unit_price=completion_price_info.unit_price,
|
||||
completion_price_unit=completion_price_info.unit,
|
||||
completion_price=completion_price_info.total_amount,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
total_price=prompt_price_info.total_amount + completion_price_info.total_amount,
|
||||
currency=prompt_price_info.currency,
|
||||
latency=time.perf_counter() - self.started_at,
|
||||
)
|
||||
|
||||
return usage
|
||||
|
||||
def _trigger_before_invoke_callbacks(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: Sequence[str] | None = None,
|
||||
stream: bool = True,
|
||||
user: str | None = None,
|
||||
callbacks: list[Callback] | None = None,
|
||||
):
|
||||
"""
|
||||
Trigger before invoke callbacks
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
:param callbacks: callbacks
|
||||
"""
|
||||
_run_callbacks(
|
||||
callbacks,
|
||||
event="on_before_invoke",
|
||||
invoke=lambda callback: callback.on_before_invoke(
|
||||
llm_instance=self,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
),
|
||||
)
|
||||
|
||||
def _trigger_new_chunk_callbacks(
|
||||
self,
|
||||
chunk: LLMResultChunk,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: Sequence[str] | None = None,
|
||||
stream: bool = True,
|
||||
user: str | None = None,
|
||||
callbacks: list[Callback] | None = None,
|
||||
):
|
||||
"""
|
||||
Trigger new chunk callbacks
|
||||
|
||||
:param chunk: chunk
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
"""
|
||||
_run_callbacks(
|
||||
callbacks,
|
||||
event="on_new_chunk",
|
||||
invoke=lambda callback: callback.on_new_chunk(
|
||||
llm_instance=self,
|
||||
chunk=chunk,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
),
|
||||
)
|
||||
|
||||
def _trigger_after_invoke_callbacks(
|
||||
self,
|
||||
model: str,
|
||||
result: LLMResult,
|
||||
credentials: dict,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: Sequence[str] | None = None,
|
||||
stream: bool = True,
|
||||
user: str | None = None,
|
||||
callbacks: list[Callback] | None = None,
|
||||
):
|
||||
"""
|
||||
Trigger after invoke callbacks
|
||||
|
||||
:param model: model name
|
||||
:param result: result
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
:param callbacks: callbacks
|
||||
"""
|
||||
_run_callbacks(
|
||||
callbacks,
|
||||
event="on_after_invoke",
|
||||
invoke=lambda callback: callback.on_after_invoke(
|
||||
llm_instance=self,
|
||||
result=result,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
),
|
||||
)
|
||||
|
||||
def _trigger_invoke_error_callbacks(
|
||||
self,
|
||||
model: str,
|
||||
ex: Exception,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: Sequence[str] | None = None,
|
||||
stream: bool = True,
|
||||
user: str | None = None,
|
||||
callbacks: list[Callback] | None = None,
|
||||
):
|
||||
"""
|
||||
Trigger invoke error callbacks
|
||||
|
||||
:param model: model name
|
||||
:param ex: exception
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
:param callbacks: callbacks
|
||||
"""
|
||||
_run_callbacks(
|
||||
callbacks,
|
||||
event="on_invoke_error",
|
||||
invoke=lambda callback: callback.on_invoke_error(
|
||||
llm_instance=self,
|
||||
ex=ex,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
),
|
||||
)
|
||||
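A hedged usage sketch of the streaming path: the provider string, model name, and credential key are placeholders, the `ModelProviderFactory` import path is assumed from the factory module added later in this commit, and `UserPromptMessage` is assumed to live alongside the other prompt message entities.

```python
from dify_graph.model_runtime.entities.message_entities import UserPromptMessage
from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory  # path assumed

factory = ModelProviderFactory(tenant_id="tenant-123")
llm = factory.get_model_type_instance(provider="langgenius/openai/openai", model_type=ModelType.LLM)

# invoke() returns a generator of LLMResultChunk when stream=True.
chunks = llm.invoke(
    model="gpt-4o-mini",
    credentials={"openai_api_key": "sk-..."},
    prompt_messages=[UserPromptMessage(content="Hello!")],
    stream=True,
)
for chunk in chunks:
    print(chunk.delta.message.content, end="", flush=True)
```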
45
api/dify_graph/model_runtime/model_providers/__base/moderation_model.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import time
|
||||
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
|
||||
|
||||
class ModerationModel(AIModel):
|
||||
"""
|
||||
Model class for moderation model.
|
||||
"""
|
||||
|
||||
model_type: ModelType = ModelType.MODERATION
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
def invoke(self, model: str, credentials: dict, text: str, user: str | None = None) -> bool:
|
||||
"""
|
||||
Invoke moderation model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param text: text to moderate
|
||||
:param user: unique user id
|
||||
:return: false if text is safe, true otherwise
|
||||
"""
|
||||
self.started_at = time.perf_counter()
|
||||
|
||||
try:
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
plugin_model_manager = PluginModelClient()
|
||||
return plugin_model_manager.invoke_moderation(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user or "unknown",
|
||||
plugin_id=self.plugin_id,
|
||||
provider=self.provider_name,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
text=text,
|
||||
)
|
||||
except Exception as e:
|
||||
raise self._transform_invoke_error(e)
|
||||
92
api/dify_graph/model_runtime/model_providers/__base/rerank_model.py
Normal file
@@ -0,0 +1,92 @@
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from dify_graph.model_runtime.entities.rerank_entities import RerankResult
|
||||
from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
|
||||
|
||||
class RerankModel(AIModel):
|
||||
"""
|
||||
Base Model class for rerank model.
|
||||
"""
|
||||
|
||||
model_type: ModelType = ModelType.RERANK
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
query: str,
|
||||
docs: list[str],
|
||||
score_threshold: float | None = None,
|
||||
top_n: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> RerankResult:
|
||||
"""
|
||||
Invoke rerank model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param query: search query
|
||||
:param docs: docs for reranking
|
||||
:param score_threshold: score threshold
|
||||
:param top_n: top n
|
||||
:param user: unique user id
|
||||
:return: rerank result
|
||||
"""
|
||||
try:
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
plugin_model_manager = PluginModelClient()
|
||||
return plugin_model_manager.invoke_rerank(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user or "unknown",
|
||||
plugin_id=self.plugin_id,
|
||||
provider=self.provider_name,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
query=query,
|
||||
docs=docs,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_n,
|
||||
)
|
||||
except Exception as e:
|
||||
raise self._transform_invoke_error(e)
|
||||
|
||||
def invoke_multimodal_rerank(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
query: dict,
|
||||
docs: list[dict],
|
||||
score_threshold: float | None = None,
|
||||
top_n: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> RerankResult:
|
||||
"""
|
||||
Invoke multimodal rerank model
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param query: search query
|
||||
:param docs: docs for reranking
|
||||
:param score_threshold: score threshold
|
||||
:param top_n: top n
|
||||
:param user: unique user id
|
||||
:return: rerank result
|
||||
"""
|
||||
try:
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
plugin_model_manager = PluginModelClient()
|
||||
return plugin_model_manager.invoke_multimodal_rerank(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user or "unknown",
|
||||
plugin_id=self.plugin_id,
|
||||
provider=self.provider_name,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
query=query,
|
||||
docs=docs,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_n,
|
||||
)
|
||||
except Exception as e:
|
||||
raise self._transform_invoke_error(e)
|
||||
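A hedged sketch of the rerank path. The model name and credential key are placeholders, the instance is assumed to come from `ModelProviderFactory.get_model_type_instance`, and the fields read from `RerankResult.docs` (index, score, text) follow the rerank entity definitions referenced above.

```python
# Assuming `rerank_model` is a RerankModel instance obtained from the factory.
result = rerank_model.invoke(
    model="rerank-english-v3.0",
    credentials={"api_key": "placeholder"},
    query="What is Dify?",
    docs=[
        "Dify is an LLM application development platform.",
        "The weather is sunny today.",
    ],
    score_threshold=0.5,
    top_n=1,
)
for doc in result.docs:
    print(doc.index, doc.score, doc.text)
```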
43
api/dify_graph/model_runtime/model_providers/__base/speech2text_model.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from typing import IO
|
||||
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
|
||||
|
||||
class Speech2TextModel(AIModel):
|
||||
"""
|
||||
Model class for speech2text model.
|
||||
"""
|
||||
|
||||
model_type: ModelType = ModelType.SPEECH2TEXT
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
def invoke(self, model: str, credentials: dict, file: IO[bytes], user: str | None = None) -> str:
|
||||
"""
|
||||
Invoke speech to text model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param file: audio file
|
||||
:param user: unique user id
|
||||
:return: text for given audio file
|
||||
"""
|
||||
try:
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
plugin_model_manager = PluginModelClient()
|
||||
return plugin_model_manager.invoke_speech_to_text(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user or "unknown",
|
||||
plugin_id=self.plugin_id,
|
||||
provider=self.provider_name,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
file=file,
|
||||
)
|
||||
except Exception as e:
|
||||
raise self._transform_invoke_error(e)
|
||||
121
api/dify_graph/model_runtime/model_providers/__base/text_embedding_model.py
Normal file
@@ -0,0 +1,121 @@
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from core.entities.embedding_type import EmbeddingInputType
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||
from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult
|
||||
from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
|
||||
|
||||
class TextEmbeddingModel(AIModel):
|
||||
"""
|
||||
Model class for text embedding model.
|
||||
"""
|
||||
|
||||
model_type: ModelType = ModelType.TEXT_EMBEDDING
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str] | None = None,
|
||||
multimodel_documents: list[dict] | None = None,
|
||||
user: str | None = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> EmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:param multimodel_documents: multimodal documents to embed
|
||||
:param user: unique user id
|
||||
:param input_type: input type
|
||||
:return: embeddings result
|
||||
"""
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
try:
|
||||
plugin_model_manager = PluginModelClient()
|
||||
if texts:
|
||||
return plugin_model_manager.invoke_text_embedding(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user or "unknown",
|
||||
plugin_id=self.plugin_id,
|
||||
provider=self.provider_name,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
texts=texts,
|
||||
input_type=input_type,
|
||||
)
|
||||
if multimodel_documents:
|
||||
return plugin_model_manager.invoke_multimodal_embedding(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user or "unknown",
|
||||
plugin_id=self.plugin_id,
|
||||
provider=self.provider_name,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
documents=multimodel_documents,
|
||||
input_type=input_type,
|
||||
)
|
||||
raise ValueError("No texts or files provided")
|
||||
except Exception as e:
|
||||
raise self._transform_invoke_error(e)
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> list[int]:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:return:
|
||||
"""
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
plugin_model_manager = PluginModelClient()
|
||||
return plugin_model_manager.get_text_embedding_num_tokens(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id="unknown",
|
||||
plugin_id=self.plugin_id,
|
||||
provider=self.provider_name,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
texts=texts,
|
||||
)
|
||||
|
||||
def _get_context_size(self, model: str, credentials: dict) -> int:
|
||||
"""
|
||||
Get context size for given embedding model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return: context size
|
||||
"""
|
||||
model_schema = self.get_model_schema(model, credentials)
|
||||
|
||||
if model_schema and ModelPropertyKey.CONTEXT_SIZE in model_schema.model_properties:
|
||||
content_size: int = model_schema.model_properties[ModelPropertyKey.CONTEXT_SIZE]
|
||||
return content_size
|
||||
|
||||
return 1000
|
||||
|
||||
def _get_max_chunks(self, model: str, credentials: dict) -> int:
|
||||
"""
|
||||
Get max chunks for given embedding model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return: max chunks
|
||||
"""
|
||||
model_schema = self.get_model_schema(model, credentials)
|
||||
|
||||
if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties:
|
||||
max_chunks: int = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
|
||||
return max_chunks
|
||||
|
||||
return 1
|
||||
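A short sketch of the text-embedding path. Model name and credentials are placeholders, the instance is assumed to come from the factory, and the `embeddings` field name on the result is assumed from the embedding entity referenced above.

```python
# Assuming `embedding_model` is a TextEmbeddingModel instance obtained from the factory.
result = embedding_model.invoke(
    model="text-embedding-3-small",
    credentials={"openai_api_key": "sk-..."},
    texts=["hello world", "goodbye world"],
)
print(len(result.embeddings))  # one vector per input text (field name assumed)

# get_num_tokens returns a list of per-text token counts, resolved by the plugin daemon.
print(embedding_model.get_num_tokens(
    model="text-embedding-3-small",
    credentials={"openai_api_key": "sk-..."},
    texts=["hello world"],
))
```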
@@ -0,0 +1,53 @@
|
||||
import logging
|
||||
from threading import Lock
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_tokenizer: Any | None = None
|
||||
_lock = Lock()
|
||||
|
||||
|
||||
class GPT2Tokenizer:
|
||||
@staticmethod
|
||||
def _get_num_tokens_by_gpt2(text: str) -> int:
|
||||
"""
|
||||
use gpt2 tokenizer to get num tokens
|
||||
"""
|
||||
_tokenizer = GPT2Tokenizer.get_encoder()
|
||||
tokens = _tokenizer.encode(text) # type: ignore
|
||||
return len(tokens)
|
||||
|
||||
@staticmethod
|
||||
def get_num_tokens(text: str) -> int:
|
||||
# Because this process needs more CPU resources, we fall back to direct execution until we find a better way to handle it.
|
||||
#
|
||||
# future = _executor.submit(GPT2Tokenizer._get_num_tokens_by_gpt2, text)
|
||||
# result = future.result()
|
||||
# return cast(int, result)
|
||||
return GPT2Tokenizer._get_num_tokens_by_gpt2(text)
|
||||
|
||||
@staticmethod
|
||||
def get_encoder():
|
||||
global _tokenizer, _lock
|
||||
if _tokenizer is not None:
|
||||
return _tokenizer
|
||||
with _lock:
|
||||
if _tokenizer is None:
|
||||
# Try to use tiktoken to get the tokenizer because it is faster
|
||||
#
|
||||
try:
|
||||
import tiktoken
|
||||
|
||||
_tokenizer = tiktoken.get_encoding("gpt2")
|
||||
except Exception:
|
||||
from os.path import abspath, dirname, join
|
||||
|
||||
from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer
|
||||
|
||||
base_path = abspath(__file__)
|
||||
gpt2_tokenizer_path = join(dirname(base_path), "gpt2")
|
||||
_tokenizer = TransformerGPT2Tokenizer.from_pretrained(gpt2_tokenizer_path)
|
||||
logger.info("Fallback to Transformers' GPT-2 tokenizer from tiktoken")
|
||||
|
||||
return _tokenizer
|
||||
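The encoder above is a process-wide singleton guarded by a lock: the first call pays the initialization cost (tiktoken if available, otherwise the bundled transformers GPT-2 tokenizer), and later calls reuse it. A minimal usage sketch, assuming `GPT2Tokenizer` is imported from the tokenizer module shown above:

```python
# First call initializes the shared encoder; subsequent calls reuse it.
num = GPT2Tokenizer.get_num_tokens("Model Runtime unifies provider credentials and invocation.")
print(num)
```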
79
api/dify_graph/model_runtime/model_providers/__base/tts_model.py
Normal file
@@ -0,0 +1,79 @@
|
||||
import logging
|
||||
from collections.abc import Iterable
|
||||
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TTSModel(AIModel):
|
||||
"""
|
||||
Model class for TTS model.
|
||||
"""
|
||||
|
||||
model_type: ModelType = ModelType.TTS
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
model: str,
|
||||
tenant_id: str,
|
||||
credentials: dict,
|
||||
content_text: str,
|
||||
voice: str,
|
||||
user: str | None = None,
|
||||
) -> Iterable[bytes]:
|
||||
"""
|
||||
Invoke text to speech model
|
||||
|
||||
:param model: model name
|
||||
:param tenant_id: user tenant id
|
||||
:param credentials: model credentials
|
||||
:param voice: model timbre
|
||||
:param content_text: text content to be converted to speech
|
||||
:param user: unique user id
|
||||
:return: synthesized audio stream
|
||||
"""
|
||||
try:
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
plugin_model_manager = PluginModelClient()
|
||||
return plugin_model_manager.invoke_tts(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user or "unknown",
|
||||
plugin_id=self.plugin_id,
|
||||
provider=self.provider_name,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
content_text=content_text,
|
||||
voice=voice,
|
||||
)
|
||||
except Exception as e:
|
||||
raise self._transform_invoke_error(e)
|
||||
|
||||
def get_tts_model_voices(self, model: str, credentials: dict, language: str | None = None):
|
||||
"""
|
||||
Retrieves the list of voices supported by a given text-to-speech (TTS) model.
|
||||
|
||||
:param language: The language for which the voices are requested.
|
||||
:param model: The name of the TTS model.
|
||||
:param credentials: The credentials required to access the TTS model.
|
||||
:return: A list of voices supported by the TTS model.
|
||||
"""
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
plugin_model_manager = PluginModelClient()
|
||||
return plugin_model_manager.get_tts_model_voices(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id="unknown",
|
||||
plugin_id=self.plugin_id,
|
||||
provider=self.provider_name,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
language=language,
|
||||
)
|
||||
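A hedged sketch of the TTS path. The model, voice, and credential values are placeholders, and the return value is an iterable of audio byte chunks that can be written to a file or streamed to a client.

```python
# Assuming `tts_model` is a TTSModel instance obtained from the factory.
audio_chunks = tts_model.invoke(
    model="tts-1",
    tenant_id="tenant-123",
    credentials={"openai_api_key": "sk-..."},
    content_text="Welcome to Dify.",
    voice="alloy",
)
with open("welcome.mp3", "wb") as f:
    for chunk in audio_chunks:
        f.write(chunk)
```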
43
api/dify_graph/model_runtime/model_providers/_position.yaml
Normal file
@@ -0,0 +1,43 @@
- openai
- deepseek
- anthropic
- azure_openai
- google
- vertex_ai
- nvidia
- nvidia_nim
- cohere
- upstage
- bedrock
- togetherai
- openrouter
- ollama
- mistralai
- groq
- replicate
- huggingface_hub
- xinference
- triton_inference_server
- zhipuai
- baichuan
- spark
- minimax
- tongyi
- wenxin
- moonshot
- tencent
- jina
- chatglm
- yi
- openllm
- localai
- volcengine_maas
- openai_api_compatible
- hunyuan
- siliconflow
- perfxcloud
- zhinao
- fireworks
- mixedbread
- nomic
- voyage
@@ -0,0 +1,386 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from threading import Lock
|
||||
|
||||
from pydantic import ValidationError
|
||||
from redis import RedisError
|
||||
|
||||
import contexts
|
||||
from configs import dify_config
|
||||
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
||||
from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||
from dify_graph.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity
|
||||
from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from dify_graph.model_runtime.model_providers.__base.moderation_model import ModerationModel
|
||||
from dify_graph.model_runtime.model_providers.__base.rerank_model import RerankModel
|
||||
from dify_graph.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
|
||||
from dify_graph.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from dify_graph.model_runtime.model_providers.__base.tts_model import TTSModel
|
||||
from dify_graph.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator
|
||||
from dify_graph.model_runtime.schema_validators.provider_credential_schema_validator import (
|
||||
ProviderCredentialSchemaValidator,
|
||||
)
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.provider_ids import ModelProviderID
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModelProviderFactory:
|
||||
def __init__(self, tenant_id: str):
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
self.tenant_id = tenant_id
|
||||
self.plugin_model_manager = PluginModelClient()
|
||||
|
||||
def get_providers(self) -> Sequence[ProviderEntity]:
|
||||
"""
|
||||
Get all providers
|
||||
:return: list of providers
|
||||
"""
|
||||
# FIXME(-LAN-): Removed position map sorting since providers are fetched from plugin server
|
||||
# The plugin server should return providers in the desired order
|
||||
plugin_providers = self.get_plugin_model_providers()
|
||||
return [provider.declaration for provider in plugin_providers]
|
||||
|
||||
def get_plugin_model_providers(self) -> Sequence[PluginModelProviderEntity]:
|
||||
"""
|
||||
Get all plugin model providers
|
||||
:return: list of plugin model providers
|
||||
"""
|
||||
# check if context is set
|
||||
try:
|
||||
contexts.plugin_model_providers.get()
|
||||
except LookupError:
|
||||
contexts.plugin_model_providers.set(None)
|
||||
contexts.plugin_model_providers_lock.set(Lock())
|
||||
|
||||
with contexts.plugin_model_providers_lock.get():
|
||||
plugin_model_providers = contexts.plugin_model_providers.get()
|
||||
if plugin_model_providers is not None:
|
||||
return plugin_model_providers
|
||||
|
||||
plugin_model_providers = []
|
||||
contexts.plugin_model_providers.set(plugin_model_providers)
|
||||
|
||||
# Fetch plugin model providers
|
||||
plugin_providers = self.plugin_model_manager.fetch_model_providers(self.tenant_id)
|
||||
|
||||
for provider in plugin_providers:
|
||||
provider.declaration.provider = provider.plugin_id + "/" + provider.declaration.provider
|
||||
plugin_model_providers.append(provider)
|
||||
|
||||
return plugin_model_providers
|
||||
|
||||
def get_provider_schema(self, provider: str) -> ProviderEntity:
|
||||
"""
|
||||
Get provider schema
|
||||
:param provider: provider name
|
||||
:return: provider schema
|
||||
"""
|
||||
plugin_model_provider_entity = self.get_plugin_model_provider(provider=provider)
|
||||
return plugin_model_provider_entity.declaration
|
||||
|
||||
def get_plugin_model_provider(self, provider: str) -> PluginModelProviderEntity:
|
||||
"""
|
||||
Get plugin model provider
|
||||
:param provider: provider name
|
||||
:return: provider schema
|
||||
"""
|
||||
if "/" not in provider:
|
||||
provider = str(ModelProviderID(provider))
|
||||
|
||||
# fetch plugin model providers
|
||||
plugin_model_provider_entities = self.get_plugin_model_providers()
|
||||
|
||||
# get the provider
|
||||
plugin_model_provider_entity = next(
|
||||
(p for p in plugin_model_provider_entities if p.declaration.provider == provider),
|
||||
None,
|
||||
)
|
||||
|
||||
if not plugin_model_provider_entity:
|
||||
raise ValueError(f"Invalid provider: {provider}")
|
||||
|
||||
return plugin_model_provider_entity
|
||||
|
||||
def provider_credentials_validate(self, *, provider: str, credentials: dict):
|
||||
"""
|
||||
Validate provider credentials
|
||||
|
||||
:param provider: provider name
|
||||
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
|
||||
:return:
|
||||
"""
|
||||
# fetch plugin model provider
|
||||
plugin_model_provider_entity = self.get_plugin_model_provider(provider=provider)
|
||||
|
||||
# get provider_credential_schema and validate credentials according to the rules
|
||||
provider_credential_schema = plugin_model_provider_entity.declaration.provider_credential_schema
|
||||
if not provider_credential_schema:
|
||||
raise ValueError(f"Provider {provider} does not have provider_credential_schema")
|
||||
|
||||
# validate provider credential schema
|
||||
validator = ProviderCredentialSchemaValidator(provider_credential_schema)
|
||||
filtered_credentials = validator.validate_and_filter(credentials)
|
||||
|
||||
# validate the credentials, raise exception if validation failed
|
||||
self.plugin_model_manager.validate_provider_credentials(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id="unknown",
|
||||
plugin_id=plugin_model_provider_entity.plugin_id,
|
||||
provider=plugin_model_provider_entity.provider,
|
||||
credentials=filtered_credentials,
|
||||
)
|
||||
|
||||
return filtered_credentials
|
||||
|
||||
def model_credentials_validate(self, *, provider: str, model_type: ModelType, model: str, credentials: dict):
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param provider: provider name
|
||||
:param model_type: model type
|
||||
:param model: model name
|
||||
:param credentials: model credentials, credentials form defined in `model_credential_schema`.
|
||||
:return:
|
||||
"""
|
||||
# fetch plugin model provider
|
||||
plugin_model_provider_entity = self.get_plugin_model_provider(provider=provider)
|
||||
|
||||
# get model_credential_schema and validate credentials according to the rules
|
||||
model_credential_schema = plugin_model_provider_entity.declaration.model_credential_schema
|
||||
if not model_credential_schema:
|
||||
raise ValueError(f"Provider {provider} does not have model_credential_schema")
|
||||
|
||||
# validate model credential schema
|
||||
validator = ModelCredentialSchemaValidator(model_type, model_credential_schema)
|
||||
filtered_credentials = validator.validate_and_filter(credentials)
|
||||
|
||||
# call validate_credentials method of model type to validate credentials, raise exception if validation failed
|
||||
self.plugin_model_manager.validate_model_credentials(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id="unknown",
|
||||
plugin_id=plugin_model_provider_entity.plugin_id,
|
||||
provider=plugin_model_provider_entity.provider,
|
||||
model_type=model_type.value,
|
||||
model=model,
|
||||
credentials=filtered_credentials,
|
||||
)
|
||||
|
||||
return filtered_credentials
|
||||
|
||||
def get_model_schema(
|
||||
self, *, provider: str, model_type: ModelType, model: str, credentials: dict | None
|
||||
) -> AIModelEntity | None:
|
||||
"""
|
||||
Get model schema
|
||||
"""
|
||||
plugin_id, provider_name = self.get_plugin_id_and_provider_name_from_provider(provider)
|
||||
cache_key = f"{self.tenant_id}:{plugin_id}:{provider_name}:{model_type.value}:{model}"
|
||||
sorted_credentials = sorted(credentials.items()) if credentials else []
|
||||
cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials])
|
||||
|
||||
cached_schema_json = None
|
||||
try:
|
||||
cached_schema_json = redis_client.get(cache_key)
|
||||
except (RedisError, RuntimeError) as exc:
|
||||
logger.warning(
|
||||
"Failed to read plugin model schema cache for model %s: %s",
|
||||
model,
|
||||
str(exc),
|
||||
exc_info=True,
|
||||
)
|
||||
if cached_schema_json:
|
||||
try:
|
||||
return AIModelEntity.model_validate_json(cached_schema_json)
|
||||
except ValidationError:
|
||||
logger.warning(
|
||||
"Failed to validate cached plugin model schema for model %s",
|
||||
model,
|
||||
exc_info=True,
|
||||
)
|
||||
try:
|
||||
redis_client.delete(cache_key)
|
||||
except (RedisError, RuntimeError) as exc:
|
||||
logger.warning(
|
||||
"Failed to delete invalid plugin model schema cache for model %s: %s",
|
||||
model,
|
||||
str(exc),
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
schema = self.plugin_model_manager.get_model_schema(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id="unknown",
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
model_type=model_type.value,
|
||||
model=model,
|
||||
credentials=credentials or {},
|
||||
)
|
||||
|
||||
if schema:
|
||||
try:
|
||||
redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_SCHEMA_CACHE_TTL, schema.model_dump_json())
|
||||
except (RedisError, RuntimeError) as exc:
|
||||
logger.warning(
|
||||
"Failed to write plugin model schema cache for model %s: %s",
|
||||
model,
|
||||
str(exc),
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
return schema
|
||||
|
||||
def get_models(
|
||||
self,
|
||||
*,
|
||||
provider: str | None = None,
|
||||
model_type: ModelType | None = None,
|
||||
provider_configs: list[ProviderConfig] | None = None,
|
||||
) -> list[SimpleProviderEntity]:
|
||||
"""
|
||||
Get all models for given model type
|
||||
|
||||
:param provider: provider name
|
||||
:param model_type: model type
|
||||
:param provider_configs: list of provider configs
|
||||
:return: list of models
|
||||
"""
|
||||
provider_configs = provider_configs or []
|
||||
|
||||
# scan all providers
|
||||
plugin_model_provider_entities = self.get_plugin_model_providers()
|
||||
|
||||
# traverse all model_provider_extensions
|
||||
providers = []
|
||||
for plugin_model_provider_entity in plugin_model_provider_entities:
|
||||
# filter by provider if provider is present
|
||||
if provider and plugin_model_provider_entity.declaration.provider != provider:
|
||||
continue
|
||||
|
||||
# get provider schema
|
||||
provider_schema = plugin_model_provider_entity.declaration
|
||||
|
||||
model_types = provider_schema.supported_model_types
|
||||
if model_type:
|
||||
if model_type not in model_types:
|
||||
continue
|
||||
|
||||
model_types = [model_type]
|
||||
|
||||
all_model_type_models = []
|
||||
for model_schema in provider_schema.models:
|
||||
if model_schema.model_type not in model_types:
|
||||
continue
|
||||
|
||||
all_model_type_models.append(model_schema)
|
||||
|
||||
simple_provider_schema = provider_schema.to_simple_provider()
|
||||
simple_provider_schema.models.extend(all_model_type_models)
|
||||
|
||||
providers.append(simple_provider_schema)
|
||||
|
||||
return providers
|
||||
|
||||
    def get_model_type_instance(self, provider: str, model_type: ModelType) -> AIModel:
        """
        Get model type instance by provider name and model type
        :param provider: provider name
        :param model_type: model type
        :return: model type instance
        """
        plugin_id, provider_name = self.get_plugin_id_and_provider_name_from_provider(provider)
        init_params = {
            "tenant_id": self.tenant_id,
            "plugin_id": plugin_id,
            "provider_name": provider_name,
            "plugin_model_provider": self.get_plugin_model_provider(provider),
        }

        if model_type == ModelType.LLM:
            return LargeLanguageModel.model_validate(init_params)
        elif model_type == ModelType.TEXT_EMBEDDING:
            return TextEmbeddingModel.model_validate(init_params)
        elif model_type == ModelType.RERANK:
            return RerankModel.model_validate(init_params)
        elif model_type == ModelType.SPEECH2TEXT:
            return Speech2TextModel.model_validate(init_params)
        elif model_type == ModelType.MODERATION:
            return ModerationModel.model_validate(init_params)
        elif model_type == ModelType.TTS:
            return TTSModel.model_validate(init_params)

        raise ValueError(f"Unsupported model type: {model_type}")

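    # Sketch of a caller obtaining a typed model instance (the provider string is a
    # hypothetical "plugin_id/provider_name" value, not taken from this diff):
    #
    #     llm = factory.get_model_type_instance("langgenius/openai/openai", ModelType.LLM)
    #     assert isinstance(llm, LargeLanguageModel)
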
    def get_provider_icon(self, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]:
        """
        Get provider icon
        :param provider: provider name
        :param icon_type: icon type (icon_small or icon_small_dark)
        :param lang: language (zh_Hans or en_US)
        :return: provider icon bytes and its MIME type
        """
        # get the provider schema
        provider_schema = self.get_provider_schema(provider)

        if icon_type.lower() == "icon_small":
            if not provider_schema.icon_small:
                raise ValueError(f"Provider {provider} does not have small icon.")

            if lang.lower() == "zh_hans":
                file_name = provider_schema.icon_small.zh_Hans
            else:
                file_name = provider_schema.icon_small.en_US
        elif icon_type.lower() == "icon_small_dark":
            if not provider_schema.icon_small_dark:
                raise ValueError(f"Provider {provider} does not have small dark icon.")

            if lang.lower() == "zh_hans":
                file_name = provider_schema.icon_small_dark.zh_Hans
            else:
                file_name = provider_schema.icon_small_dark.en_US
        else:
            raise ValueError(f"Unsupported icon type: {icon_type}.")

        if not file_name:
            raise ValueError(f"Provider {provider} does not have icon.")

        image_mime_types = {
            "jpg": "image/jpeg",
            "jpeg": "image/jpeg",
            "png": "image/png",
            "gif": "image/gif",
            "bmp": "image/bmp",
            "tiff": "image/tiff",
            "tif": "image/tiff",
            "webp": "image/webp",
            "svg": "image/svg+xml",
            "ico": "image/vnd.microsoft.icon",
            "heif": "image/heif",
            "heic": "image/heic",
        }

        extension = file_name.split(".")[-1]
        mime_type = image_mime_types.get(extension, "image/png")

        # get icon bytes from plugin asset manager
        from core.plugin.impl.asset import PluginAssetManager

        plugin_asset_manager = PluginAssetManager()
        return plugin_asset_manager.fetch_asset(tenant_id=self.tenant_id, id=file_name), mime_type

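    # Sketch of serving the icon over HTTP (the web-framework call is illustrative and not
    # part of this module):
    #
    #     icon_bytes, mime_type = factory.get_provider_icon(provider, "icon_small", "en_US")
    #     return Response(icon_bytes, mimetype=mime_type)  # e.g. a Flask Response
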
    def get_plugin_id_and_provider_name_from_provider(self, provider: str) -> tuple[str, str]:
        """
        Get plugin id and provider name from provider name
        :param provider: provider name
        :return: plugin id and provider name
        """

        provider_id = ModelProviderID(provider)
        return provider_id.plugin_id, provider_id.provider_name
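    # Illustration (the concrete provider string is hypothetical): a fully qualified
    # provider such as "langgenius/openai/openai" would be parsed by ModelProviderID into
    # plugin_id "langgenius/openai" and provider_name "openai".
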
@ -0,0 +1,92 @@
from typing import Union, cast

from dify_graph.model_runtime.entities.provider_entities import CredentialFormSchema, FormType


class CommonValidator:
    def _validate_and_filter_credential_form_schemas(
        self, credential_form_schemas: list[CredentialFormSchema], credentials: dict
    ):
        need_validate_credential_form_schema_map = {}
        for credential_form_schema in credential_form_schemas:
            if not credential_form_schema.show_on:
                need_validate_credential_form_schema_map[credential_form_schema.variable] = credential_form_schema
                continue

            all_show_on_match = True
            for show_on_object in credential_form_schema.show_on:
                if show_on_object.variable not in credentials:
                    all_show_on_match = False
                    break

                if credentials[show_on_object.variable] != show_on_object.value:
                    all_show_on_match = False
                    break

            if all_show_on_match:
                need_validate_credential_form_schema_map[credential_form_schema.variable] = credential_form_schema

        # Iterate over the remaining credential_form_schemas and validate each one
        validated_credentials = {}
        for credential_form_schema in need_validate_credential_form_schema_map.values():
            # validate the schema and, if a value is produced, add it to validated_credentials
            result = self._validate_credential_form_schema(credential_form_schema, credentials)
            if result:
                validated_credentials[credential_form_schema.variable] = result

        return validated_credentials

    def _validate_credential_form_schema(
        self, credential_form_schema: CredentialFormSchema, credentials: dict
    ) -> Union[str, bool, None]:
        """
        Validate credential form schema

        :param credential_form_schema: credential form schema
        :param credentials: credentials
        :return: validated credential form schema value
        """
        # If the variable does not exist in credentials, or is empty
        value: Union[str, bool, None] = None
        if credential_form_schema.variable not in credentials or not credentials[credential_form_schema.variable]:
            # If required is True, an exception is raised
            if credential_form_schema.required:
                raise ValueError(f"Variable {credential_form_schema.variable} is required")
            else:
                # Fall back to the default value if one is defined
                if credential_form_schema.default:
                    # If it exists, add it to validated_credentials
                    return credential_form_schema.default
                else:
                    # If no default exists, skip this field
                    return None

        # Get the value corresponding to the variable from credentials
        value = cast(str, credentials[credential_form_schema.variable])

        # If max_length is 0 or unset, no length validation is performed
        if credential_form_schema.max_length:
            if len(value) > credential_form_schema.max_length:
                raise ValueError(
                    f"Variable {credential_form_schema.variable} length should not be"
                    f" greater than {credential_form_schema.max_length}"
                )

        # check the type of value
        if not isinstance(value, str):
            raise ValueError(f"Variable {credential_form_schema.variable} should be string")

        if credential_form_schema.type in {FormType.SELECT, FormType.RADIO}:
            # The value must be one of the defined options
            if credential_form_schema.options:
                if value not in [option.value for option in credential_form_schema.options]:
                    raise ValueError(f"Variable {credential_form_schema.variable} is not in options")

        if credential_form_schema.type == FormType.SWITCH:
            # The value must be 'true' or 'false'
            if value.lower() not in {"true", "false"}:
                raise ValueError(f"Variable {credential_form_schema.variable} should be true or false")

            value = value.lower() == "true"

        return value
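A minimal sketch of how the show_on filter behaves. The stand-in objects below only mimic the attributes the validator reads; they are not the real CredentialFormSchema models, and the field names and values are illustrative:

    from types import SimpleNamespace

    from dify_graph.model_runtime.schema_validators.common_validator import CommonValidator

    # "api_key" is always shown; "endpoint_url" is only shown when auth_mode == "custom".
    api_key_field = SimpleNamespace(
        variable="api_key", show_on=[], required=True,
        default=None, max_length=0, type=None, options=None,
    )
    endpoint_field = SimpleNamespace(
        variable="endpoint_url", required=False,
        default=None, max_length=0, type=None, options=None,
        show_on=[SimpleNamespace(variable="auth_mode", value="custom")],
    )

    filtered = CommonValidator()._validate_and_filter_credential_form_schemas(
        [api_key_field, endpoint_field],
        {"auth_mode": "hosted", "api_key": "sk-example"},
    )
    print(filtered)  # {'api_key': 'sk-example'} -- endpoint_url was filtered out by show_on
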
@ -0,0 +1,27 @@
from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.model_runtime.entities.provider_entities import ModelCredentialSchema
from dify_graph.model_runtime.schema_validators.common_validator import CommonValidator


class ModelCredentialSchemaValidator(CommonValidator):
    def __init__(self, model_type: ModelType, model_credential_schema: ModelCredentialSchema):
        self.model_type = model_type
        self.model_credential_schema = model_credential_schema

    def validate_and_filter(self, credentials: dict):
        """
        Validate model credentials

        :param credentials: model credentials
        :return: filtered credentials
        """

        if self.model_credential_schema is None:
            raise ValueError("Model credential schema is None")

        # get the credential_form_schemas in model_credential_schema
        credential_form_schemas = self.model_credential_schema.credential_form_schemas

        credentials["__model_type"] = self.model_type.value

        return self._validate_and_filter_credential_form_schemas(credential_form_schemas, credentials)
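The "__model_type" entry injected above lets show_on rules in a model credential schema depend on the model type being configured. A usage sketch (the schema object and credential values are illustrative, not taken from this diff):

    validator = ModelCredentialSchemaValidator(ModelType.LLM, model_credential_schema)
    filtered = validator.validate_and_filter({"api_key": "sk-example"})
    # credential fields whose show_on requires __model_type == "llm" are now included in validation
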
@ -0,0 +1,19 @@
from dify_graph.model_runtime.entities.provider_entities import ProviderCredentialSchema
from dify_graph.model_runtime.schema_validators.common_validator import CommonValidator


class ProviderCredentialSchemaValidator(CommonValidator):
    def __init__(self, provider_credential_schema: ProviderCredentialSchema):
        self.provider_credential_schema = provider_credential_schema

    def validate_and_filter(self, credentials: dict):
        """
        Validate provider credentials

        :param credentials: provider credentials
        :return: validated provider credentials
        """
        # get the credential_form_schemas in provider_credential_schema
        credential_form_schemas = self.provider_credential_schema.credential_form_schemas

        return self._validate_and_filter_credential_form_schemas(credential_form_schemas, credentials)
0
api/dify_graph/model_runtime/utils/__init__.py
Normal file
216
api/dify_graph/model_runtime/utils/encoders.py
Normal file
@ -0,0 +1,216 @@
import dataclasses
import datetime
from collections import defaultdict, deque
from collections.abc import Callable
from decimal import Decimal
from enum import Enum
from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
from pathlib import Path, PurePath
from re import Pattern
from types import GeneratorType
from typing import Any, Literal, Union
from uuid import UUID

from pydantic import BaseModel
from pydantic.networks import AnyUrl, NameEmail
from pydantic.types import SecretBytes, SecretStr
from pydantic_core import Url
from pydantic_extra_types.color import Color


def _model_dump(model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any) -> Any:
    return model.model_dump(mode=mode, **kwargs)


# Taken from Pydantic v1 as is
def isoformat(o: Union[datetime.date, datetime.time]) -> str:
    return o.isoformat()


# Taken from Pydantic v1 as is
# TODO: pv2 should this return strings instead?
def decimal_encoder(dec_value: Decimal) -> Union[int, float]:
    """
    Encodes a Decimal as an int if there's no exponent, otherwise as a float

    This is useful when we use ConstrainedDecimal to represent Numeric(x,0)
    where an integer (but not int typed) is used. Encoding this as a float
    results in failed round-tripping between encode and parse.
    Our Id type is a prime example of this.

    >>> decimal_encoder(Decimal("1.0"))
    1.0

    >>> decimal_encoder(Decimal("1"))
    1
    """
    if dec_value.as_tuple().exponent >= 0:  # type: ignore[operator]
        return int(dec_value)
    else:
        return float(dec_value)


ENCODERS_BY_TYPE: dict[type[Any], Callable[[Any], Any]] = {
    bytes: lambda o: o.decode(),
    Color: str,
    datetime.date: isoformat,
    datetime.datetime: isoformat,
    datetime.time: isoformat,
    datetime.timedelta: lambda td: td.total_seconds(),
    Decimal: decimal_encoder,
    Enum: lambda o: o.value,
    frozenset: list,
    deque: list,
    GeneratorType: list,
    IPv4Address: str,
    IPv4Interface: str,
    IPv4Network: str,
    IPv6Address: str,
    IPv6Interface: str,
    IPv6Network: str,
    NameEmail: str,
    Path: str,
    Pattern: lambda o: o.pattern,
    SecretBytes: str,
    SecretStr: str,
    set: list,
    UUID: str,
    Url: str,
    AnyUrl: str,
}


def generate_encoders_by_class_tuples(
    type_encoder_map: dict[Any, Callable[[Any], Any]],
) -> dict[Callable[[Any], Any], tuple[Any, ...]]:
    encoders_by_class_tuples: dict[Callable[[Any], Any], tuple[Any, ...]] = defaultdict(tuple)
    for type_, encoder in type_encoder_map.items():
        encoders_by_class_tuples[encoder] += (type_,)
    return encoders_by_class_tuples


encoders_by_class_tuples = generate_encoders_by_class_tuples(ENCODERS_BY_TYPE)

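# The inversion above groups classes by their encoder so that jsonable_encoder (below) can
# fall back to one isinstance() check per encoder, which also covers subclasses that are
# not exact keys in ENCODERS_BY_TYPE.  For example, after the call above:
#
#     encoders_by_class_tuples[list] == (frozenset, deque, GeneratorType, set)
#     encoders_by_class_tuples[str] includes Color, Path, UUID, Url and the IP address types
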
def jsonable_encoder(
    obj: Any,
    by_alias: bool = True,
    exclude_unset: bool = False,
    exclude_defaults: bool = False,
    exclude_none: bool = False,
    custom_encoder: dict[Any, Callable[[Any], Any]] | None = None,
    sqlalchemy_safe: bool = True,
) -> Any:
    custom_encoder = custom_encoder or {}
    if custom_encoder:
        if type(obj) in custom_encoder:
            return custom_encoder[type(obj)](obj)
        else:
            for encoder_type, encoder_instance in custom_encoder.items():
                if isinstance(obj, encoder_type):
                    return encoder_instance(obj)
    if isinstance(obj, BaseModel):
        obj_dict = _model_dump(
            obj,
            mode="json",
            include=None,
            exclude=None,
            by_alias=by_alias,
            exclude_unset=exclude_unset,
            exclude_none=exclude_none,
            exclude_defaults=exclude_defaults,
        )
        if "__root__" in obj_dict:
            obj_dict = obj_dict["__root__"]
        return jsonable_encoder(
            obj_dict,
            exclude_none=exclude_none,
            exclude_defaults=exclude_defaults,
            sqlalchemy_safe=sqlalchemy_safe,
        )
    if dataclasses.is_dataclass(obj):
        # Ensure obj is a dataclass instance, not a dataclass type
        if not isinstance(obj, type):
            obj_dict = dataclasses.asdict(obj)
            return jsonable_encoder(
                obj_dict,
                by_alias=by_alias,
                exclude_unset=exclude_unset,
                exclude_defaults=exclude_defaults,
                exclude_none=exclude_none,
                custom_encoder=custom_encoder,
                sqlalchemy_safe=sqlalchemy_safe,
            )
    if isinstance(obj, Enum):
        return obj.value
    if isinstance(obj, PurePath):
        return str(obj)
    if isinstance(obj, str | int | float | type(None)):
        return obj
    if isinstance(obj, Decimal):
        return format(obj, "f")
    if isinstance(obj, dict):
        encoded_dict = {}
        for key, value in obj.items():
            if (not sqlalchemy_safe or (not isinstance(key, str)) or (not key.startswith("_sa"))) and (
                value is not None or not exclude_none
            ):
                encoded_key = jsonable_encoder(
                    key,
                    by_alias=by_alias,
                    exclude_unset=exclude_unset,
                    exclude_none=exclude_none,
                    custom_encoder=custom_encoder,
                    sqlalchemy_safe=sqlalchemy_safe,
                )
                encoded_value = jsonable_encoder(
                    value,
                    by_alias=by_alias,
                    exclude_unset=exclude_unset,
                    exclude_none=exclude_none,
                    custom_encoder=custom_encoder,
                    sqlalchemy_safe=sqlalchemy_safe,
                )
                encoded_dict[encoded_key] = encoded_value
        return encoded_dict
    if isinstance(obj, list | set | frozenset | GeneratorType | tuple | deque):
        encoded_list = []
        for item in obj:
            encoded_list.append(
                jsonable_encoder(
                    item,
                    by_alias=by_alias,
                    exclude_unset=exclude_unset,
                    exclude_defaults=exclude_defaults,
                    exclude_none=exclude_none,
                    custom_encoder=custom_encoder,
                    sqlalchemy_safe=sqlalchemy_safe,
                )
            )
        return encoded_list

    if type(obj) in ENCODERS_BY_TYPE:
        return ENCODERS_BY_TYPE[type(obj)](obj)
    for encoder, classes_tuple in encoders_by_class_tuples.items():
        if isinstance(obj, classes_tuple):
            return encoder(obj)

    try:
        data = dict(obj)  # type: ignore
    except Exception as e:
        errors: list[Exception] = []
        errors.append(e)
        try:
            data = vars(obj)  # type: ignore
        except Exception as e:
            errors.append(e)
            raise ValueError(str(errors)) from e
    return jsonable_encoder(
        data,
        by_alias=by_alias,
        exclude_unset=exclude_unset,
        exclude_defaults=exclude_defaults,
        exclude_none=exclude_none,
        custom_encoder=custom_encoder,
        sqlalchemy_safe=sqlalchemy_safe,
    )
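A short, self-contained usage sketch of jsonable_encoder; the Invoice dataclass and its values are illustrative, not part of the module:

    import datetime
    import uuid
    from dataclasses import dataclass, field
    from decimal import Decimal

    from dify_graph.model_runtime.utils.encoders import jsonable_encoder


    @dataclass
    class Invoice:
        id: uuid.UUID
        amount: Decimal
        issued_at: datetime.date
        tags: set[str] = field(default_factory=set)


    invoice = Invoice(
        id=uuid.uuid4(),
        amount=Decimal("19.90"),
        issued_at=datetime.date(2024, 1, 31),
        tags={"plugin"},
    )
    print(jsonable_encoder(invoice))
    # e.g. {'id': '7b9d...', 'amount': '19.90', 'issued_at': '2024-01-31', 'tags': ['plugin']}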