Model Runtime (#1858)
Co-authored-by: StyleZhang <jasonapring2015@outlook.com> Co-authored-by: Garfield Dai <dai.hai@foxmail.com> Co-authored-by: chenhe <guchenhe@gmail.com> Co-authored-by: jyong <jyong@dify.ai> Co-authored-by: Joel <iamjoel007@gmail.com> Co-authored-by: Yeuoly <admin@srmxy.cn>
69
api/core/model_runtime/README.md
Normal file
@ -0,0 +1,69 @@
|
||||
# Model Runtime
|
||||
|
||||
This module provides the interface for invoking and authenticating various models, and offers Dify a unified information and credentials form rule for model providers.
|
||||
|
||||
- On one hand, it decouples models from upstream and downstream processes, facilitating horizontal expansion for developers,
|
||||
- On the other hand, it allows for direct display of providers and models in the frontend interface by simply defining them in the backend, eliminating the need to modify frontend logic.
|
||||
|
||||
## Features
|
||||
|
||||
- Supports capability invocation for 5 types of models
|
||||
|
||||
- `LLM` - LLM text completion, dialogue, pre-computed tokens capability
|
||||
- `Text Embedding Model` - Text Embedding, pre-computed tokens capability
|
||||
- `Rerank Model` - Segment Rerank capability
|
||||
- `Speech-to-text Model` - Speech to text capability
|
||||
- `Moderation` - Moderation capability
|
||||
|
||||
- Model provider display
|
||||
|
||||

|
||||
|
||||
Displays a list of all supported providers, including provider names, icons, supported model types list, predefined model list, configuration method, and credentials form rules, etc. For detailed rule design, see: [Schema](./schema.md).
|
||||
|
||||
- Selectable model list display
|
||||
|
||||

|
||||
|
||||
After configuring provider/model credentials, the dropdown (application orchestration interface/default model) allows viewing of the available LLM list. Greyed out items represent predefined model lists from providers without configured credentials, facilitating user review of supported models.
|
||||
|
||||
In addition, this list also returns configurable parameter information and rules for LLM, as shown below:
|
||||
|
||||

|
||||
|
||||
These parameters are all defined in the backend, allowing different settings for various parameters supported by different models, as detailed in: [Schema](./docs/en_US/schema.md#ParameterRule).
|
||||
|
||||
- Provider/model credential authentication
|
||||
|
||||

|
||||
|
||||

|
||||
|
||||
The provider list returns configuration information for the credentials form, which can be authenticated through Runtime's interface. The first image above is a provider credential DEMO, and the second is a model credential DEMO.
|
||||
|
||||
## Structure
|
||||
|
||||

|
||||
|
||||
Model Runtime is divided into three layers:
|
||||
|
||||
- The outermost layer is the factory method
|
||||
|
||||
It provides methods for obtaining all providers, all model lists, getting provider instances, and authenticating provider/model credentials.
|
||||
|
||||
- The second layer is the provider layer
|
||||
|
||||
It provides the current provider's model list, model instance obtaining, provider credential authentication, and provider configuration rule information, **allowing horizontal expansion** to support different providers.
|
||||
|
||||
- The bottom layer is the model layer
|
||||
|
||||
It offers direct invocation of various model types, predefined model configuration information, getting predefined/remote model lists, model credential authentication methods. Different models provide additional special methods, like LLM's pre-computed tokens method, cost information obtaining method, etc., **allowing horizontal expansion** for different models under the same provider (within supported model types).
|
||||
|
||||
|
||||
|
||||
## Next Steps
|
||||
|
||||
- Add new provider configuration: [Link](./docs/en_US/provider_scale_out.md)
|
||||
- Add new models for existing providers: [Link](./docs/en_US/provider_scale_out.md#AddModel)
|
||||
- View YAML configuration rules: [Link](./docs/en_US/schema.md)
|
||||
- Implement interface methods: [Link](./docs/en_US/interfaces.md)
|
||||
88
api/core/model_runtime/README_CN.md
Normal file
@ -0,0 +1,88 @@
|
||||
# Model Runtime
|
||||
|
||||
该模块提供了各模型的调用、鉴权接口,并为 Dify 提供了统一的模型供应商的信息和凭据表单规则。
|
||||
|
||||
- 一方面将模型和上下游解耦,方便开发者对模型横向扩展,
|
||||
- 另一方面提供了只需在后端定义供应商和模型,即可在前端页面直接展示,无需修改前端逻辑。
|
||||
|
||||
## 功能介绍
|
||||
|
||||
- 支持 5 种模型类型的能力调用
|
||||
|
||||
- `LLM` - LLM 文本补全、对话,预计算 tokens 能力
|
||||
- `Text Embedidng Model` - 文本 Embedding ,预计算 tokens 能力
|
||||
- `Rerank Model` - 分段 Rerank 能力
|
||||
- `Speech-to-text Model` - 语音转文本能力
|
||||
- `Moderation` - Moderation 能力
|
||||
|
||||
- 模型供应商展示
|
||||
|
||||

|
||||
|
||||
展示所有已支持的供应商列表,除了返回供应商名称、图标之外,还提供了支持的模型类型列表,预定义模型列表、配置方式以及配置凭据的表单规则等等,规则设计详见:[Schema](./schema.md)。
|
||||
|
||||
- 可选择的模型列表展示
|
||||
|
||||

|
||||
|
||||
配置供应商/模型凭据后,可在此下拉(应用编排界面/默认模型)查看可用的 LLM 列表,其中灰色的为未配置凭据供应商的预定义模型列表,方便用户查看已支持的模型。
|
||||
|
||||
除此之外,该列表还返回了 LLM 可配置的参数信息和规则,如下图:
|
||||
|
||||

|
||||
|
||||
这里的参数均为后端定义,相比之前只有 5 种固定参数,这里可为不同模型设置所支持的各种参数,详见:[Schema](./docs/zh_Hans/schema.md#ParameterRule)。
|
||||
|
||||
- 供应商/模型凭据鉴权
|
||||
|
||||

|
||||
|
||||

|
||||
|
||||
供应商列表返回了凭据表单的配置信息,可通过 Runtime 提供的接口对凭据进行鉴权,上图 1 为供应商凭据 DEMO,上图 2 为模型凭据 DEMO。
|
||||
|
||||
## 结构
|
||||
|
||||

|
||||
|
||||
Model Runtime 分三层:
|
||||
|
||||
- 最外层为工厂方法
|
||||
|
||||
提供获取所有供应商、所有模型列表、获取供应商实例、供应商/模型凭据鉴权方法。
|
||||
|
||||
- 第二层为供应商层
|
||||
|
||||
提供获取当前供应商模型列表、获取模型实例、供应商凭据鉴权、供应商配置规则信息,**可横向扩展**以支持不同的供应商。
|
||||
|
||||
对于供应商/模型凭据,有两种情况
|
||||
- 如OpenAI这类中心化供应商,需要定义如**api_key**这类的鉴权凭据
|
||||
- 如[**Xinference**](https://github.com/xorbitsai/inference)这类本地部署的供应商,需要定义如**server_url**这类的地址凭据,有时候还需要定义**model_uid**之类的模型类型凭据,就像下面这样,当在供应商层定义了这些凭据后,就可以在前端页面上直接展示,无需修改前端逻辑。
|
||||

|
||||
|
||||
当配置好凭据后,就可以通过DifyRuntime的外部接口直接获取到对应供应商所需要的**Schema**(凭据表单规则),从而在可以在不修改前端逻辑的情况下,提供新的供应商/模型的支持。
|
||||
|
||||
- 最底层为模型层
|
||||
|
||||
提供各种模型类型的直接调用、预定义模型配置信息、获取预定义/远程模型列表、模型凭据鉴权方法,不同模型额外提供了特殊方法,如 LLM 提供预计算 tokens 方法、获取费用信息方法等,**可横向扩展**同供应商下不同的模型(支持的模型类型下)。
|
||||
|
||||
在这里我们需要先区分模型参数与模型凭据。
|
||||
|
||||
- 模型参数(**在本层定义**):这是一类经常需要变动,随时调整的参数,如 LLM 的 **max_tokens**、**temperature** 等,这些参数是由用户在前端页面上进行调整的,因此需要在后端定义参数的规则,以便前端页面进行展示和调整。在DifyRuntime中,他们的参数名一般为**model_parameters: dict[str, any]**。
|
||||
|
||||
- 模型凭据(**在供应商层定义**):这是一类不经常变动,一般在配置好后就不会再变动的参数,如 **api_key**、**server_url** 等。在DifyRuntime中,他们的参数名一般为**credentials: dict[str, any]**,Provider层的credentials会直接被传递到这一层,不需要再单独定义。
|
||||
|
||||
## 下一步
|
||||
|
||||
### [增加新的供应商配置 👈🏻](./docs/zh_Hans/provider_scale_out.md)
|
||||
当添加后,这里将会出现一个新的供应商
|
||||
|
||||

|
||||
|
||||
### [为已存在的供应商新增模型 👈🏻](./docs/zh_Hans/provider_scale_out.md#增加模型)
|
||||
当添加后,对应供应商的模型列表中将会出现一个新的预定义模型供用户选择,如GPT-3.5 GPT-4 ChatGLM3-6b等,而对于支持自定义模型的供应商,则不需要新增模型。
|
||||
|
||||

|
||||
|
||||
### [接口的具体实现 👈🏻](./docs/zh_Hans/interfaces.md)
|
||||
你可以在这里找到你想要查看的接口的具体实现,以及接口的参数和返回值的具体含义。
|
||||
0
api/core/model_runtime/__init__.py
Normal file
0
api/core/model_runtime/callbacks/__init__.py
Normal file
113
api/core/model_runtime/callbacks/base_callback.py
Normal file
@ -0,0 +1,113 @@
|
||||
from abc import ABC
|
||||
from typing import Optional, List
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||
from core.model_runtime.entities.message_entities import PromptMessageTool, PromptMessage
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
|
||||
_TEXT_COLOR_MAPPING = {
|
||||
"blue": "36;1",
|
||||
"yellow": "33;1",
|
||||
"pink": "38;5;200",
|
||||
"green": "32;1",
|
||||
"red": "31;1",
|
||||
}
|
||||
|
||||
|
||||
class Callback(ABC):
|
||||
"""
|
||||
Base class for callbacks.
|
||||
Only for LLM.
|
||||
"""
|
||||
raise_error: bool = False
|
||||
|
||||
def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) -> None:
|
||||
"""
|
||||
Before invoke callback
|
||||
|
||||
:param llm_instance: LLM instance
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None):
|
||||
"""
|
||||
On new chunk callback
|
||||
|
||||
:param llm_instance: LLM instance
|
||||
:param chunk: chunk
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) -> None:
|
||||
"""
|
||||
After invoke callback
|
||||
|
||||
:param llm_instance: LLM instance
|
||||
:param result: result
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) -> None:
|
||||
"""
|
||||
Invoke error callback
|
||||
|
||||
:param llm_instance: LLM instance
|
||||
:param ex: exception
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def print_text(
|
||||
self, text: str, color: Optional[str] = None, end: str = ""
|
||||
) -> None:
|
||||
"""Print text with highlighting and no end characters."""
|
||||
text_to_print = self._get_colored_text(text, color) if color else text
|
||||
print(text_to_print, end=end)
|
||||
|
||||
def _get_colored_text(self, text: str, color: str) -> str:
|
||||
"""Get colored text."""
|
||||
color_str = _TEXT_COLOR_MAPPING[color]
|
||||
return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m"
|
||||
133
api/core/model_runtime/callbacks/logging_callback.py
Normal file
@ -0,0 +1,133 @@
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from typing import Optional, List
|
||||
|
||||
from core.model_runtime.callbacks.base_callback import Callback
|
||||
from core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResult
|
||||
from core.model_runtime.entities.message_entities import PromptMessageTool, PromptMessage
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class LoggingCallback(Callback):
|
||||
def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) -> None:
|
||||
"""
|
||||
Before invoke callback
|
||||
|
||||
:param llm_instance: LLM instance
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
"""
|
||||
self.print_text("\n[on_llm_before_invoke]\n", color='blue')
|
||||
self.print_text(f"Model: {model}\n", color='blue')
|
||||
self.print_text(f"Parameters:\n", color='blue')
|
||||
for key, value in model_parameters.items():
|
||||
self.print_text(f"\t{key}: {value}\n", color='blue')
|
||||
|
||||
if stop:
|
||||
self.print_text(f"\tstop: {stop}\n", color='blue')
|
||||
|
||||
if tools:
|
||||
self.print_text(f"\tTools:\n", color='blue')
|
||||
for tool in tools:
|
||||
self.print_text(f"\t\t{tool.name}\n", color='blue')
|
||||
|
||||
self.print_text(f"Stream: {stream}\n", color='blue')
|
||||
|
||||
if user:
|
||||
self.print_text(f"User: {user}\n", color='blue')
|
||||
|
||||
self.print_text(f"Prompt messages:\n", color='blue')
|
||||
for prompt_message in prompt_messages:
|
||||
if prompt_message.name:
|
||||
self.print_text(f"\tname: {prompt_message.name}\n", color='blue')
|
||||
|
||||
self.print_text(f"\trole: {prompt_message.role.value}\n", color='blue')
|
||||
self.print_text(f"\tcontent: {prompt_message.content}\n", color='blue')
|
||||
|
||||
if stream:
|
||||
self.print_text("\n[on_llm_new_chunk]")
|
||||
|
||||
def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None):
|
||||
"""
|
||||
On new chunk callback
|
||||
|
||||
:param llm_instance: LLM instance
|
||||
:param chunk: chunk
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
"""
|
||||
sys.stdout.write(chunk.delta.message.content)
|
||||
sys.stdout.flush()
|
||||
|
||||
def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) -> None:
|
||||
"""
|
||||
After invoke callback
|
||||
|
||||
:param llm_instance: LLM instance
|
||||
:param result: result
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
"""
|
||||
self.print_text("\n[on_llm_after_invoke]\n", color='yellow')
|
||||
self.print_text(f"Content: {result.message.content}\n", color='yellow')
|
||||
|
||||
if result.message.tool_calls:
|
||||
self.print_text(f"Tool calls:\n", color='yellow')
|
||||
for tool_call in result.message.tool_calls:
|
||||
self.print_text(f"\t{tool_call.id}\n", color='yellow')
|
||||
self.print_text(f"\t{tool_call.function.name}\n", color='yellow')
|
||||
self.print_text(f"\t{json.dumps(tool_call.function.arguments)}\n", color='yellow')
|
||||
|
||||
self.print_text(f"Model: {result.model}\n", color='yellow')
|
||||
self.print_text(f"Usage: {result.usage}\n", color='yellow')
|
||||
self.print_text(f"System Fingerprint: {result.system_fingerprint}\n", color='yellow')
|
||||
|
||||
def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) -> None:
|
||||
"""
|
||||
Invoke error callback
|
||||
|
||||
:param llm_instance: LLM instance
|
||||
:param ex: exception
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
"""
|
||||
self.print_text("\n[on_llm_invoke_error]\n", color='red')
|
||||
logger.exception(ex)
|
||||
|
After Width: | Height: | Size: 370 KiB |
|
After Width: | Height: | Size: 113 KiB |
|
After Width: | Height: | Size: 109 KiB |
|
After Width: | Height: | Size: 70 KiB |
|
After Width: | Height: | Size: 75 KiB |
|
After Width: | Height: | Size: 541 KiB |
668
api/core/model_runtime/docs/en_US/interfaces.md
Normal file
@ -0,0 +1,668 @@
|
||||
# Interface Methods
|
||||
|
||||
This section describes the interface methods and parameter explanations that need to be implemented by providers and various model types.
|
||||
|
||||
## Provider
|
||||
|
||||
Inherit the `__base.model_provider.ModelProvider` base class and implement the following interfaces:
|
||||
|
||||
```python
|
||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||
"""
|
||||
Validate provider credentials
|
||||
You can choose any validate_credentials method of model type or implement validate method by yourself,
|
||||
such as: get model list api
|
||||
|
||||
if validate failed, raise exception
|
||||
|
||||
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
|
||||
"""
|
||||
```
|
||||
|
||||
- `credentials` (object) Credential information
|
||||
|
||||
The parameters of credential information are defined by the `provider_credential_schema` in the provider's YAML configuration file. Inputs such as `api_key` are included.
|
||||
|
||||
If verification fails, throw the `errors.validate.CredentialsValidateFailedError` error.
|
||||
|
||||
## Model
|
||||
|
||||
Models are divided into 5 different types, each inheriting from different base classes and requiring the implementation of different methods.
|
||||
|
||||
All models need to uniformly implement the following 2 methods:
|
||||
|
||||
- Model Credential Verification
|
||||
|
||||
Similar to provider credential verification, this step involves verification for an individual model.
|
||||
|
||||
|
||||
```python
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
```
|
||||
|
||||
Parameters:
|
||||
|
||||
- `model` (string) Model name
|
||||
|
||||
- `credentials` (object) Credential information
|
||||
|
||||
The parameters of credential information are defined by either the `provider_credential_schema` or `model_credential_schema` in the provider's YAML configuration file. Inputs such as `api_key` are included.
|
||||
|
||||
If verification fails, throw the `errors.validate.CredentialsValidateFailedError` error.
|
||||
|
||||
- Invocation Error Mapping Table
|
||||
|
||||
When there is an exception in model invocation, it needs to be mapped to the `InvokeError` type specified by Runtime. This facilitates Dify's ability to handle different errors with appropriate follow-up actions.
|
||||
|
||||
Runtime Errors:
|
||||
|
||||
- `InvokeConnectionError` Invocation connection error
|
||||
- `InvokeServerUnavailableError` Invocation service provider unavailable
|
||||
- `InvokeRateLimitError` Invocation reached rate limit
|
||||
- `InvokeAuthorizationError` Invocation authorization failure
|
||||
- `InvokeBadRequestError` Invocation parameter error
|
||||
|
||||
```python
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
The value is the error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
```
|
||||
|
||||
You can refer to OpenAI's `_invoke_error_mapping` for an example.
|
||||
|
||||
### LLM
|
||||
|
||||
Inherit the `__base.large_language_model.LargeLanguageModel` base class and implement the following interfaces:
|
||||
|
||||
- LLM Invocation
|
||||
|
||||
Implement the core method for LLM invocation, which can support both streaming and synchronous returns.
|
||||
|
||||
|
||||
```python
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
```
|
||||
|
||||
- Parameters:
|
||||
|
||||
- `model` (string) Model name
|
||||
|
||||
- `credentials` (object) Credential information
|
||||
|
||||
The parameters of credential information are defined by either the `provider_credential_schema` or `model_credential_schema` in the provider's YAML configuration file. Inputs such as `api_key` are included.
|
||||
|
||||
- `prompt_messages` (array[[PromptMessage](#PromptMessage)]) List of prompts
|
||||
|
||||
If the model is of the `Completion` type, the list only needs to include one [UserPromptMessage](#UserPromptMessage) element;
|
||||
|
||||
If the model is of the `Chat` type, it requires a list of elements such as [SystemPromptMessage](#SystemPromptMessage), [UserPromptMessage](#UserPromptMessage), [AssistantPromptMessage](#AssistantPromptMessage), [ToolPromptMessage](#ToolPromptMessage) depending on the message.
|
||||
|
||||
- `model_parameters` (object) Model parameters
|
||||
|
||||
The model parameters are defined by the `parameter_rules` in the model's YAML configuration.
|
||||
|
||||
- `tools` (array[[PromptMessageTool](#PromptMessageTool)]) [optional] List of tools, equivalent to the `function` in `function calling`.
|
||||
|
||||
That is, the tool list for tool calling.
|
||||
|
||||
- `stop` (array[string]) [optional] Stop sequences
|
||||
|
||||
The model output will stop before the string defined by the stop sequence.
|
||||
|
||||
- `stream` (bool) Whether to output in a streaming manner, default is True
|
||||
|
||||
Streaming output returns Generator[[LLMResultChunk](#LLMResultChunk)], non-streaming output returns [LLMResult](#LLMResult).
|
||||
|
||||
- `user` (string) [optional] Unique identifier of the user
|
||||
|
||||
This can help the provider monitor and detect abusive behavior.
|
||||
|
||||
- Returns
|
||||
|
||||
Streaming output returns Generator[[LLMResultChunk](#LLMResultChunk)], non-streaming output returns [LLMResult](#LLMResult).
|
||||
|
||||
- Pre-calculating Input Tokens
|
||||
|
||||
If the model does not provide a pre-calculated tokens interface, you can directly return 0.
|
||||
|
||||
```python
|
||||
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param tools: tools for tool calling
|
||||
:return:
|
||||
"""
|
||||
```
|
||||
|
||||
For parameter explanations, refer to the above section on `LLM Invocation`.
|
||||
|
||||
- Fetch Custom Model Schema [Optional]
|
||||
|
||||
```python
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
"""
|
||||
Get customizable model schema
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return: model schema
|
||||
"""
|
||||
```
|
||||
|
||||
When the provider supports adding custom LLMs, this method can be implemented to allow custom models to fetch model schema. The default return null.
|
||||
|
||||
|
||||
### TextEmbedding
|
||||
|
||||
Inherit the `__base.text_embedding_model.TextEmbeddingModel` base class and implement the following interfaces:
|
||||
|
||||
- Embedding Invocation
|
||||
|
||||
```python
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
texts: list[str], user: Optional[str] = None) \
|
||||
-> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:param user: unique user id
|
||||
:return: embeddings result
|
||||
"""
|
||||
```
|
||||
|
||||
- Parameters:
|
||||
|
||||
- `model` (string) Model name
|
||||
|
||||
- `credentials` (object) Credential information
|
||||
|
||||
The parameters of credential information are defined by either the `provider_credential_schema` or `model_credential_schema` in the provider's YAML configuration file. Inputs such as `api_key` are included.
|
||||
|
||||
- `texts` (array[string]) List of texts, capable of batch processing
|
||||
|
||||
- `user` (string) [optional] Unique identifier of the user
|
||||
|
||||
This can help the provider monitor and detect abusive behavior.
|
||||
|
||||
- Returns:
|
||||
|
||||
[TextEmbeddingResult](#TextEmbeddingResult) entity.
|
||||
|
||||
- Pre-calculating Tokens
|
||||
|
||||
```python
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:return:
|
||||
"""
|
||||
```
|
||||
|
||||
For parameter explanations, refer to the above section on `Embedding Invocation`.
|
||||
|
||||
### Rerank
|
||||
|
||||
Inherit the `__base.rerank_model.RerankModel` base class and implement the following interfaces:
|
||||
|
||||
- Rerank Invocation
|
||||
|
||||
```python
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
|
||||
user: Optional[str] = None) \
|
||||
-> RerankResult:
|
||||
"""
|
||||
Invoke rerank model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param query: search query
|
||||
:param docs: docs for reranking
|
||||
:param score_threshold: score threshold
|
||||
:param top_n: top n
|
||||
:param user: unique user id
|
||||
:return: rerank result
|
||||
"""
|
||||
```
|
||||
|
||||
- Parameters:
|
||||
|
||||
- `model` (string) Model name
|
||||
|
||||
- `credentials` (object) Credential information
|
||||
|
||||
The parameters of credential information are defined by either the `provider_credential_schema` or `model_credential_schema` in the provider's YAML configuration file. Inputs such as `api_key` are included.
|
||||
|
||||
- `query` (string) Query request content
|
||||
|
||||
- `docs` (array[string]) List of segments to be reranked
|
||||
|
||||
- `score_threshold` (float) [optional] Score threshold
|
||||
|
||||
- `top_n` (int) [optional] Select the top n segments
|
||||
|
||||
- `user` (string) [optional] Unique identifier of the user
|
||||
|
||||
This can help the provider monitor and detect abusive behavior.
|
||||
|
||||
- Returns:
|
||||
|
||||
[RerankResult](#RerankResult) entity.
|
||||
|
||||
### Speech2text
|
||||
|
||||
Inherit the `__base.speech2text_model.Speech2TextModel` base class and implement the following interfaces:
|
||||
|
||||
- Invoke Invocation
|
||||
|
||||
```python
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
file: IO[bytes], user: Optional[str] = None) \
|
||||
-> str:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param file: audio file
|
||||
:param user: unique user id
|
||||
:return: text for given audio file
|
||||
"""
|
||||
```
|
||||
|
||||
- Parameters:
|
||||
|
||||
- `model` (string) Model name
|
||||
|
||||
- `credentials` (object) Credential information
|
||||
|
||||
The parameters of credential information are defined by either the `provider_credential_schema` or `model_credential_schema` in the provider's YAML configuration file. Inputs such as `api_key` are included.
|
||||
|
||||
- `file` (File) File stream
|
||||
|
||||
- `user` (string) [optional] Unique identifier of the user
|
||||
|
||||
This can help the provider monitor and detect abusive behavior.
|
||||
|
||||
- Returns:
|
||||
|
||||
The string after speech-to-text conversion.
|
||||
|
||||
### Moderation
|
||||
|
||||
Inherit the `__base.moderation_model.ModerationModel` base class and implement the following interfaces:
|
||||
|
||||
- Invoke Invocation
|
||||
|
||||
```python
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
text: str, user: Optional[str] = None) \
|
||||
-> bool:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param text: text to moderate
|
||||
:param user: unique user id
|
||||
:return: false if text is safe, true otherwise
|
||||
"""
|
||||
```
|
||||
|
||||
- Parameters:
|
||||
|
||||
- `model` (string) Model name
|
||||
|
||||
- `credentials` (object) Credential information
|
||||
|
||||
The parameters of credential information are defined by either the `provider_credential_schema` or `model_credential_schema` in the provider's YAML configuration file. Inputs such as `api_key` are included.
|
||||
|
||||
- `text` (string) Text content
|
||||
|
||||
- `user` (string) [optional] Unique identifier of the user
|
||||
|
||||
This can help the provider monitor and detect abusive behavior.
|
||||
|
||||
- Returns:
|
||||
|
||||
False indicates that the input text is safe, True indicates otherwise.
|
||||
|
||||
|
||||
|
||||
## Entities
|
||||
|
||||
### PromptMessageRole
|
||||
|
||||
Message role
|
||||
|
||||
```python
|
||||
class PromptMessageRole(Enum):
|
||||
"""
|
||||
Enum class for prompt message.
|
||||
"""
|
||||
SYSTEM = "system"
|
||||
USER = "user"
|
||||
ASSISTANT = "assistant"
|
||||
TOOL = "tool"
|
||||
```
|
||||
|
||||
### PromptMessageContentType
|
||||
|
||||
Message content types, divided into text and image.
|
||||
|
||||
```python
|
||||
class PromptMessageContentType(Enum):
|
||||
"""
|
||||
Enum class for prompt message content type.
|
||||
"""
|
||||
TEXT = 'text'
|
||||
IMAGE = 'image'
|
||||
```
|
||||
|
||||
### PromptMessageContent
|
||||
|
||||
Message content base class, used only for parameter declaration and cannot be initialized.
|
||||
|
||||
```python
|
||||
class PromptMessageContent(BaseModel):
|
||||
"""
|
||||
Model class for prompt message content.
|
||||
"""
|
||||
type: PromptMessageContentType
|
||||
data: str
|
||||
```
|
||||
|
||||
Currently, two types are supported: text and image. It's possible to simultaneously input text and multiple images.
|
||||
|
||||
You need to initialize `TextPromptMessageContent` and `ImagePromptMessageContent` separately for input.
|
||||
|
||||
### TextPromptMessageContent
|
||||
|
||||
```python
|
||||
class TextPromptMessageContent(PromptMessageContent):
|
||||
"""
|
||||
Model class for text prompt message content.
|
||||
"""
|
||||
type: PromptMessageContentType = PromptMessageContentType.TEXT
|
||||
```
|
||||
|
||||
If inputting a combination of text and images, the text needs to be constructed into this entity as part of the `content` list.
|
||||
|
||||
### ImagePromptMessageContent
|
||||
|
||||
```python
|
||||
class ImagePromptMessageContent(PromptMessageContent):
|
||||
"""
|
||||
Model class for image prompt message content.
|
||||
"""
|
||||
class DETAIL(Enum):
|
||||
LOW = 'low'
|
||||
HIGH = 'high'
|
||||
|
||||
type: PromptMessageContentType = PromptMessageContentType.IMAGE
|
||||
detail: DETAIL = DETAIL.LOW # Resolution
|
||||
```
|
||||
|
||||
If inputting a combination of text and images, the images need to be constructed into this entity as part of the `content` list.
|
||||
|
||||
`data` can be either a `url` or a `base64` encoded string of the image.
|
||||
|
||||
### PromptMessage
|
||||
|
||||
The base class for all Role message bodies, used only for parameter declaration and cannot be initialized.
|
||||
|
||||
```python
|
||||
class PromptMessage(ABC, BaseModel):
|
||||
"""
|
||||
Model class for prompt message.
|
||||
"""
|
||||
role: PromptMessageRole
|
||||
content: Optional[str | list[PromptMessageContent]] = None # Supports two types: string and content list. The content list is designed to meet the needs of multimodal inputs. For more details, see the PromptMessageContent explanation.
|
||||
name: Optional[str] = None
|
||||
```
|
||||
|
||||
### UserPromptMessage
|
||||
|
||||
UserMessage message body, representing a user's message.
|
||||
|
||||
```python
|
||||
class UserPromptMessage(PromptMessage):
|
||||
"""
|
||||
Model class for user prompt message.
|
||||
"""
|
||||
role: PromptMessageRole = PromptMessageRole.USER
|
||||
```
|
||||
|
||||
### AssistantPromptMessage
|
||||
|
||||
Represents a message returned by the model, typically used for `few-shots` or inputting chat history.
|
||||
|
||||
```python
|
||||
class AssistantPromptMessage(PromptMessage):
|
||||
"""
|
||||
Model class for assistant prompt message.
|
||||
"""
|
||||
class ToolCall(BaseModel):
|
||||
"""
|
||||
Model class for assistant prompt message tool call.
|
||||
"""
|
||||
class ToolCallFunction(BaseModel):
|
||||
"""
|
||||
Model class for assistant prompt message tool call function.
|
||||
"""
|
||||
name: str # tool name
|
||||
arguments: str # tool arguments
|
||||
|
||||
id: str # Tool ID, effective only in OpenAI tool calls. It's the unique ID for tool invocation and the same tool can be called multiple times.
|
||||
type: str # default: function
|
||||
function: ToolCallFunction # tool call information
|
||||
|
||||
role: PromptMessageRole = PromptMessageRole.ASSISTANT
|
||||
tool_calls: list[ToolCall] = [] # The result of tool invocation in response from the model (returned only when tools are input and the model deems it necessary to invoke a tool).
|
||||
```
|
||||
|
||||
Where `tool_calls` are the list of `tool calls` returned by the model after invoking the model with the `tools` input.
|
||||
|
||||
### SystemPromptMessage
|
||||
|
||||
Represents system messages, usually used for setting system commands given to the model.
|
||||
|
||||
```python
|
||||
class SystemPromptMessage(PromptMessage):
|
||||
"""
|
||||
Model class for system prompt message.
|
||||
"""
|
||||
role: PromptMessageRole = PromptMessageRole.SYSTEM
|
||||
```
|
||||
|
||||
### ToolPromptMessage
|
||||
|
||||
Represents tool messages, used for conveying the results of a tool execution to the model for the next step of processing.
|
||||
|
||||
```python
|
||||
class ToolPromptMessage(PromptMessage):
|
||||
"""
|
||||
Model class for tool prompt message.
|
||||
"""
|
||||
role: PromptMessageRole = PromptMessageRole.TOOL
|
||||
tool_call_id: str # Tool invocation ID. If OpenAI tool call is not supported, the name of the tool can also be inputted.
|
||||
```
|
||||
|
||||
The base class's `content` takes in the results of tool execution.
|
||||
|
||||
### PromptMessageTool
|
||||
|
||||
```python
|
||||
class PromptMessageTool(BaseModel):
|
||||
"""
|
||||
Model class for prompt message tool.
|
||||
"""
|
||||
name: str
|
||||
description: str
|
||||
parameters: dict
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### LLMResult
|
||||
|
||||
```python
|
||||
class LLMResult(BaseModel):
|
||||
"""
|
||||
Model class for llm result.
|
||||
"""
|
||||
model: str # Actual used modele
|
||||
prompt_messages: list[PromptMessage] # prompt messages
|
||||
message: AssistantPromptMessage # response message
|
||||
usage: LLMUsage # usage info
|
||||
system_fingerprint: Optional[str] = None # request fingerprint, refer to OpenAI definition
|
||||
```
|
||||
|
||||
### LLMResultChunkDelta
|
||||
|
||||
In streaming returns, each iteration contains the `delta` entity.
|
||||
|
||||
```python
|
||||
class LLMResultChunkDelta(BaseModel):
|
||||
"""
|
||||
Model class for llm result chunk delta.
|
||||
"""
|
||||
index: int
|
||||
message: AssistantPromptMessage # response message
|
||||
usage: Optional[LLMUsage] = None # usage info
|
||||
finish_reason: Optional[str] = None # finish reason, only the last one returns
|
||||
```
|
||||
|
||||
### LLMResultChunk
|
||||
|
||||
Each iteration entity in streaming returns.
|
||||
|
||||
```python
|
||||
class LLMResultChunk(BaseModel):
|
||||
"""
|
||||
Model class for llm result chunk.
|
||||
"""
|
||||
model: str # Actual used modele
|
||||
prompt_messages: list[PromptMessage] # prompt messages
|
||||
system_fingerprint: Optional[str] = None # request fingerprint, refer to OpenAI definition
|
||||
delta: LLMResultChunkDelta
|
||||
```
|
||||
|
||||
### LLMUsage
|
||||
|
||||
```python
|
||||
class LLMUsage(ModelUsage):
|
||||
"""
|
||||
Model class for LLM usage.
|
||||
"""
|
||||
prompt_tokens: int # Tokens used for prompt
|
||||
prompt_unit_price: Decimal # Unit price for prompt
|
||||
prompt_price_unit: Decimal # Price unit for prompt, i.e., the unit price based on how many tokens
|
||||
prompt_price: Decimal # Cost for prompt
|
||||
completion_tokens: int # Tokens used for response
|
||||
completion_unit_price: Decimal # Unit price for response
|
||||
completion_price_unit: Decimal # Price unit for response, i.e., the unit price based on how many tokens
|
||||
completion_price: Decimal # Cost for response
|
||||
total_tokens: int # Total number of tokens used
|
||||
total_price: Decimal # Total cost
|
||||
currency: str # Currency unit
|
||||
latency: float # Request latency (s)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### TextEmbeddingResult
|
||||
|
||||
```python
|
||||
class TextEmbeddingResult(BaseModel):
|
||||
"""
|
||||
Model class for text embedding result.
|
||||
"""
|
||||
model: str # Actual model used
|
||||
embeddings: list[list[float]] # List of embedding vectors, corresponding to the input texts list
|
||||
usage: EmbeddingUsage # Usage information
|
||||
```
|
||||
|
||||
### EmbeddingUsage
|
||||
|
||||
```python
|
||||
class EmbeddingUsage(ModelUsage):
|
||||
"""
|
||||
Model class for embedding usage.
|
||||
"""
|
||||
tokens: int # Number of tokens used
|
||||
total_tokens: int # Total number of tokens used
|
||||
unit_price: Decimal # Unit price
|
||||
price_unit: Decimal # Price unit, i.e., the unit price based on how many tokens
|
||||
total_price: Decimal # Total cost
|
||||
currency: str # Currency unit
|
||||
latency: float # Request latency (s)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### RerankResult
|
||||
|
||||
```python
|
||||
class RerankResult(BaseModel):
|
||||
"""
|
||||
Model class for rerank result.
|
||||
"""
|
||||
model: str # Actual model used
|
||||
docs: list[RerankDocument] # Reranked document list
|
||||
```
|
||||
|
||||
### RerankDocument
|
||||
|
||||
```python
|
||||
class RerankDocument(BaseModel):
|
||||
"""
|
||||
Model class for rerank document.
|
||||
"""
|
||||
index: int # original index
|
||||
text: str
|
||||
score: float
|
||||
```
|
||||
264
api/core/model_runtime/docs/en_US/provider_scale_out.md
Normal file
@ -0,0 +1,264 @@
|
||||
## Adding a New Provider
|
||||
|
||||
Providers support three types of model configuration methods:
|
||||
|
||||
- `predefined-model` Predefined model
|
||||
|
||||
This indicates that users only need to configure the unified provider credentials to use the predefined models under the provider.
|
||||
|
||||
- `customizable-model` Customizable model
|
||||
|
||||
Users need to add credential configurations for each model.
|
||||
|
||||
- `fetch-from-remote` Fetch from remote
|
||||
|
||||
This is consistent with the `predefined-model` configuration method. Only unified provider credentials need to be configured, and models are obtained from the provider through credential information.
|
||||
|
||||
These three configuration methods **can coexist**, meaning a provider can support `predefined-model` + `customizable-model` or `predefined-model` + `fetch-from-remote`, etc. In other words, configuring the unified provider credentials allows the use of predefined and remotely fetched models, and if new models are added, they can be used in addition to the custom models.
|
||||
|
||||
## Getting Started
|
||||
|
||||
Adding a new provider starts with determining the English identifier of the provider, such as `anthropic`, and using this identifier to create a `module` in `model_providers`.
|
||||
|
||||
Under this `module`, we first need to prepare the provider's YAML configuration.
|
||||
|
||||
### Preparing Provider YAML
|
||||
|
||||
Here, using `Anthropic` as an example, we preset the provider's basic information, supported model types, configuration methods, and credential rules.
|
||||
|
||||
```YAML
|
||||
provider: anthropic # Provider identifier
|
||||
label: # Provider display name, can be set in en_US English and zh_Hans Chinese, zh_Hans will default to en_US if not set.
|
||||
en_US: Anthropic
|
||||
icon_small: # Small provider icon, stored in the _assets directory under the corresponding provider implementation directory, same language strategy as label
|
||||
en_US: icon_s_en.png
|
||||
icon_large: # Large provider icon, stored in the _assets directory under the corresponding provider implementation directory, same language strategy as label
|
||||
en_US: icon_l_en.png
|
||||
supported_model_types: # Supported model types, Anthropic only supports LLM
|
||||
- llm
|
||||
configurate_methods: # Supported configuration methods, Anthropic only supports predefined models
|
||||
- predefined-model
|
||||
provider_credential_schema: # Provider credential rules, as Anthropic only supports predefined models, unified provider credential rules need to be defined
|
||||
credential_form_schemas: # List of credential form items
|
||||
- variable: anthropic_api_key # Credential parameter variable name
|
||||
label: # Display name
|
||||
en_US: API Key
|
||||
type: secret-input # Form type, here secret-input represents an encrypted information input box, showing masked information when editing.
|
||||
required: true # Whether required
|
||||
placeholder: # Placeholder information
|
||||
zh_Hans: Enter your API Key here
|
||||
en_US: Enter your API Key
|
||||
- variable: anthropic_api_url
|
||||
label:
|
||||
en_US: API URL
|
||||
type: text-input # Form type, here text-input represents a text input box
|
||||
required: false
|
||||
placeholder:
|
||||
zh_Hans: Enter your API URL here
|
||||
en_US: Enter your API URL
|
||||
```
|
||||
|
||||
You can also refer to the YAML configuration information under other provider directories in `model_providers`. The complete YAML rules are available at: [Schema](schema.md#Provider).
|
||||
|
||||
### Implementing Provider Code
|
||||
|
||||
Providers need to inherit the `__base.model_provider.ModelProvider` base class and implement the `validate_provider_credentials` method for unified provider credential verification. For reference, see [AnthropicProvider](https://github.com/langgenius/dify-runtime/blob/main/lib/model_providers/anthropic/anthropic.py).
|
||||
> If the provider is the type of `customizable-model`, there is no need to implement the `validate_provider_credentials` method.
|
||||
|
||||
```python
|
||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||
"""
|
||||
Validate provider credentials
|
||||
You can choose any validate_credentials method of model type or implement validate method by yourself,
|
||||
such as: get model list api
|
||||
|
||||
if validate failed, raise exception
|
||||
|
||||
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
|
||||
"""
|
||||
```
|
||||
|
||||
Of course, you can also preliminarily reserve the implementation of `validate_provider_credentials` and directly reuse it after the model credential verification method is implemented.
|
||||
|
||||
---
|
||||
|
||||
### Adding Models
|
||||
|
||||
After the provider integration is complete, the next step is to integrate models under the provider.
|
||||
|
||||
First, we need to determine the type of the model to be integrated and create a `module` for the corresponding model type in the provider's directory.
|
||||
|
||||
The currently supported model types are as follows:
|
||||
|
||||
- `llm` Text generation model
|
||||
- `text_embedding` Text Embedding model
|
||||
- `rerank` Rerank model
|
||||
- `speech2text` Speech to text
|
||||
- `moderation` Moderation
|
||||
|
||||
Continuing with `Anthropic` as an example, since `Anthropic` only supports LLM, we create a `module` named `llm` in `model_providers.anthropic`.
|
||||
|
||||
For predefined models, we first need to create a YAML file named after the model, such as `claude-2.1.yaml`, under the `llm` `module`.
|
||||
|
||||
#### Preparing Model YAML
|
||||
|
||||
```yaml
|
||||
model: claude-2.1 # Model identifier
|
||||
# Model display name, can be set in en_US English and zh_Hans Chinese, zh_Hans will default to en_US if not set.
|
||||
# Alternatively, if the label is not set, use the model identifier content.
|
||||
label:
|
||||
en_US: claude-2.1
|
||||
model_type: llm # Model type, claude-2.1 is an LLM
|
||||
features: # Supported features, agent-thought for Agent reasoning, vision for image understanding
|
||||
- agent-thought
|
||||
model_properties: # Model properties
|
||||
mode: chat # LLM mode, complete for text completion model, chat for dialogue model
|
||||
context_size: 200000 # Maximum supported context size
|
||||
parameter_rules: # Model invocation parameter rules, only required for LLM
|
||||
- name: temperature # Invocation parameter variable name
|
||||
# Default preset with 5 variable content configuration templates: temperature/top_p/max_tokens/presence_penalty/frequency_penalty
|
||||
# Directly set the template variable name in use_template, which will use the default configuration in entities.defaults.PARAMETER_RULE_TEMPLATE
|
||||
# If additional configuration parameters are set, they will override the default configuration
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label: # Invocation parameter display name
|
||||
zh_Hans: Sampling quantity
|
||||
en_US: Top k
|
||||
type: int # Parameter type, supports float/int/string/boolean
|
||||
help: # Help information, describing the role of the parameter
|
||||
zh_Hans: Only sample from the top K options for each subsequent token.
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false # Whether required, can be left unset
|
||||
- name: max_tokens_to_sample
|
||||
use_template: max_tokens
|
||||
default: 4096 # Default parameter value
|
||||
min: 1 # Minimum parameter value, only applicable for float/int
|
||||
max: 4096 # Maximum parameter value, only applicable for float/int
|
||||
pricing: # Pricing information
|
||||
input: '8.00' # Input price, i.e., Prompt price
|
||||
output: '24.00' # Output price, i.e., returned content price
|
||||
unit: '0.000001' # Pricing unit, i.e., the above prices are per 100K
|
||||
currency: USD # Currency
|
||||
```
|
||||
|
||||
It is recommended to prepare all model configurations before starting the implementation of the model code.
|
||||
|
||||
Similarly, you can also refer to the YAML configuration information for corresponding model types of other providers in the `model_providers` directory. The complete YAML rules can be found at: [Schema](schema.md#AIModel).
|
||||
|
||||
#### Implementing Model Invocation Code
|
||||
|
||||
Next, you need to create a python file named `llm.py` under the `llm` `module` to write the implementation code.
|
||||
|
||||
In `llm.py`, create an Anthropic LLM class, which we name `AnthropicLargeLanguageModel` (arbitrarily), inheriting the `__base.large_language_model.LargeLanguageModel` base class, and implement the following methods:
|
||||
|
||||
- LLM Invocation
|
||||
|
||||
Implement the core method for LLM invocation, which can support both streaming and synchronous returns.
|
||||
|
||||
```python
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
```
|
||||
|
||||
- Pre-calculating Input Tokens
|
||||
|
||||
If the model does not provide a pre-calculated tokens interface, you can directly return 0.
|
||||
|
||||
```python
|
||||
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param tools: tools for tool calling
|
||||
:return:
|
||||
"""
|
||||
```
|
||||
|
||||
- Model Credential Verification
|
||||
|
||||
Similar to provider credential verification, this step involves verification for an individual model.
|
||||
|
||||
```python
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
```
|
||||
|
||||
- Invocation Error Mapping Table
|
||||
|
||||
When there is an exception in model invocation, it needs to be mapped to the `InvokeError` type specified by Runtime. This facilitates Dify's ability to handle different errors with appropriate follow-up actions.
|
||||
|
||||
Runtime Errors:
|
||||
|
||||
- `InvokeConnectionError` Invocation connection error
|
||||
- `InvokeServerUnavailableError` Invocation service provider unavailable
|
||||
- `InvokeRateLimitError` Invocation reached rate limit
|
||||
- `InvokeAuthorizationError` Invocation authorization failure
|
||||
- `InvokeBadRequestError` Invocation parameter error
|
||||
|
||||
```python
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
The value is the error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
```
|
||||
|
||||
For details on the interface methods, see: [Interfaces](interfaces.md). For specific implementations, refer to: [llm.py](https://github.com/langgenius/dify-runtime/blob/main/lib/model_providers/anthropic/llm/llm.py).
|
||||
|
||||
### Testing
|
||||
|
||||
To ensure the availability of integrated providers/models, each method written needs corresponding integration test code in the `tests` directory.
|
||||
|
||||
Continuing with `Anthropic` as an example:
|
||||
|
||||
Before writing test code, you need to first add the necessary credential environment variables for the test provider in `.env.example`, such as: `ANTHROPIC_API_KEY`.
|
||||
|
||||
Before execution, copy `.env.example` to `.env` and then execute.
|
||||
|
||||
#### Writing Test Code
|
||||
|
||||
Create a `module` with the same name as the provider in the `tests` directory: `anthropic`, and continue to create `test_provider.py` and test py files for the corresponding model types within this module, as shown below:
|
||||
|
||||
```shell
|
||||
.
|
||||
├── __init__.py
|
||||
├── anthropic
|
||||
│ ├── __init__.py
|
||||
│ ├── test_llm.py # LLM Testing
|
||||
│ └── test_provider.py # Provider Testing
|
||||
```
|
||||
|
||||
Write test code for all the various cases implemented above and submit the code after passing the tests.
|
||||
194
api/core/model_runtime/docs/en_US/schema.md
Normal file
@ -0,0 +1,194 @@
|
||||
# Configuration Rules
|
||||
|
||||
- Provider rules are based on the [Provider](#Provider) entity.
|
||||
- Model rules are based on the [AIModelEntity](#AIModelEntity) entity.
|
||||
|
||||
> All entities mentioned below are based on `Pydantic BaseModel` and can be found in the `entities` module.
|
||||
|
||||
### Provider
|
||||
|
||||
- `provider` (string) Provider identifier, e.g., `openai`
|
||||
- `label` (object) Provider display name, i18n, with `en_US` English and `zh_Hans` Chinese language settings
|
||||
- `zh_Hans` (string) [optional] Chinese label name, if `zh_Hans` is not set, `en_US` will be used by default.
|
||||
- `en_US` (string) English label name
|
||||
- `description` (object) Provider description, i18n
|
||||
- `zh_Hans` (string) [optional] Chinese description
|
||||
- `en_US` (string) English description
|
||||
- `icon_small` (string) [optional] Small provider ICON, stored in the `_assets` directory under the corresponding provider implementation directory, with the same language strategy as `label`
|
||||
- `zh_Hans` (string) Chinese ICON
|
||||
- `en_US` (string) English ICON
|
||||
- `icon_large` (string) [optional] Large provider ICON, stored in the `_assets` directory under the corresponding provider implementation directory, with the same language strategy as `label`
|
||||
- `zh_Hans` (string) Chinese ICON
|
||||
- `en_US` (string) English ICON
|
||||
- `background` (string) [optional] Background color value, e.g., #FFFFFF, if empty, the default frontend color value will be displayed.
|
||||
- `help` (object) [optional] help information
|
||||
- `title` (object) help title, i18n
|
||||
- `zh_Hans` (string) [optional] Chinese title
|
||||
- `en_US` (string) English title
|
||||
- `url` (object) help link, i18n
|
||||
- `zh_Hans` (string) [optional] Chinese link
|
||||
- `en_US` (string) English link
|
||||
- `supported_model_types` (array[[ModelType](#ModelType)]) Supported model types
|
||||
- `configurate_methods` (array[[ConfigurateMethod](#ConfigurateMethod)]) Configuration methods
|
||||
- `provider_credential_schema` ([ProviderCredentialSchema](#ProviderCredentialSchema)) Provider credential specification
|
||||
- `model_credential_schema` ([ModelCredentialSchema](#ModelCredentialSchema)) Model credential specification
|
||||
|
||||
### AIModelEntity
|
||||
|
||||
- `model` (string) Model identifier, e.g., `gpt-3.5-turbo`
|
||||
- `label` (object) [optional] Model display name, i18n, with `en_US` English and `zh_Hans` Chinese language settings
|
||||
- `zh_Hans` (string) [optional] Chinese label name
|
||||
- `en_US` (string) English label name
|
||||
- `model_type` ([ModelType](#ModelType)) Model type
|
||||
- `features` (array[[ModelFeature](#ModelFeature)]) [optional] Supported feature list
|
||||
- `model_properties` (object) Model properties
|
||||
- `mode` ([LLMMode](#LLMMode)) Mode (available for model type `llm`)
|
||||
- `context_size` (int) Context size (available for model types `llm`, `text-embedding`)
|
||||
- `max_chunks` (int) Maximum number of chunks (available for model types `text-embedding`, `moderation`)
|
||||
- `file_upload_limit` (int) Maximum file upload limit, in MB (available for model type `speech2text`)
|
||||
- `supported_file_extensions` (string) Supported file extension formats, e.g., mp3, mp4 (available for model type `speech2text`)
|
||||
- `max_characters_per_chunk` (int) Maximum characters per chunk (available for model type `moderation`)
|
||||
- `parameter_rules` (array[[ParameterRule](#ParameterRule)]) [optional] Model invocation parameter rules
|
||||
- `pricing` ([PriceConfig](#PriceConfig)) [optional] Pricing information
|
||||
- `deprecated` (bool) Whether deprecated. If deprecated, the model will no longer be displayed in the list, but those already configured can continue to be used. Default False.
|
||||
|
||||
### ModelType
|
||||
|
||||
- `llm` Text generation model
|
||||
- `text-embedding` Text Embedding model
|
||||
- `rerank` Rerank model
|
||||
- `speech2text` Speech to text
|
||||
- `moderation` Moderation
|
||||
|
||||
### ConfigurateMethod
|
||||
|
||||
- `predefined-model` Predefined model
|
||||
|
||||
Indicates that users can use the predefined models under the provider by configuring the unified provider credentials.
|
||||
- `customizable-model` Customizable model
|
||||
|
||||
Users need to add credential configuration for each model.
|
||||
|
||||
- `fetch-from-remote` Fetch from remote
|
||||
|
||||
Consistent with the `predefined-model` configuration method, only unified provider credentials need to be configured, and models are obtained from the provider through credential information.
|
||||
|
||||
### ModelFeature
|
||||
|
||||
- `agent-thought` Agent reasoning, generally over 70B with thought chain capability.
|
||||
- `vision` Vision, i.e., image understanding.
|
||||
|
||||
### FetchFrom
|
||||
|
||||
- `predefined-model` Predefined model
|
||||
- `fetch-from-remote` Remote model
|
||||
|
||||
### LLMMode
|
||||
|
||||
- `complete` Text completion
|
||||
- `chat` Dialogue
|
||||
|
||||
### ParameterRule
|
||||
|
||||
- `name` (string) Actual model invocation parameter name
|
||||
- `use_template` (string) [optional] Using template
|
||||
|
||||
By default, 5 variable content configuration templates are preset:
|
||||
|
||||
- `temperature`
|
||||
- `top_p`
|
||||
- `frequency_penalty`
|
||||
- `presence_penalty`
|
||||
- `max_tokens`
|
||||
|
||||
In use_template, you can directly set the template variable name, which will use the default configuration in entities.defaults.PARAMETER_RULE_TEMPLATE
|
||||
No need to set any parameters other than `name` and `use_template`. If additional configuration parameters are set, they will override the default configuration.
|
||||
Refer to `openai/llm/gpt-3.5-turbo.yaml`.
|
||||
|
||||
- `label` (object) [optional] Label, i18n
|
||||
|
||||
- `zh_Hans`(string) [optional] Chinese label name
|
||||
- `en_US` (string) English label name
|
||||
|
||||
- `type`(string) [optional] Parameter type
|
||||
|
||||
- `int` Integer
|
||||
- `float` Float
|
||||
- `string` String
|
||||
- `boolean` Boolean
|
||||
|
||||
- `help` (string) [optional] Help information
|
||||
|
||||
- `zh_Hans` (string) [optional] Chinese help information
|
||||
- `en_US` (string) English help information
|
||||
|
||||
- `required` (bool) Required, default False.
|
||||
|
||||
- `default`(int/float/string/bool) [optional] Default value
|
||||
|
||||
- `min`(int/float) [optional] Minimum value, applicable only to numeric types
|
||||
|
||||
- `max`(int/float) [optional] Maximum value, applicable only to numeric types
|
||||
|
||||
- `precision`(int) [optional] Precision, number of decimal places to keep, applicable only to numeric types
|
||||
|
||||
- `options` (array[string]) [optional] Dropdown option values, applicable only when `type` is `string`, if not set or null, option values are not restricted
|
||||
|
||||
### PriceConfig
|
||||
|
||||
- `input` (float) Input price, i.e., Prompt price
|
||||
- `output` (float) Output price, i.e., returned content price
|
||||
- `unit` (float) Pricing unit, e.g., per 100K price is `0.000001`
|
||||
- `currency` (string) Currency unit
|
||||
|
||||
### ProviderCredentialSchema
|
||||
|
||||
- `credential_form_schemas` (array[[CredentialFormSchema](#CredentialFormSchema)]) Credential form standard
|
||||
|
||||
### ModelCredentialSchema
|
||||
|
||||
- `model` (object) Model identifier, variable name defaults to `model`
|
||||
- `label` (object) Model form item display name
|
||||
- `en_US` (string) English
|
||||
- `zh_Hans`(string) [optional] Chinese
|
||||
- `placeholder` (object) Model prompt content
|
||||
- `en_US`(string) English
|
||||
- `zh_Hans`(string) [optional] Chinese
|
||||
- `credential_form_schemas` (array[[CredentialFormSchema](#CredentialFormSchema)]) Credential form standard
|
||||
|
||||
### CredentialFormSchema
|
||||
|
||||
- `variable` (string) Form item variable name
|
||||
- `label` (object) Form item label name
|
||||
- `en_US`(string) English
|
||||
- `zh_Hans` (string) [optional] Chinese
|
||||
- `type` ([FormType](#FormType)) Form item type
|
||||
- `required` (bool) Whether required
|
||||
- `default`(string) Default value
|
||||
- `options` (array[[FormOption](#FormOption)]) Specific property of form items of type `select` or `radio`, defining dropdown content
|
||||
- `placeholder`(object) Specific property of form items of type `text-input`, placeholder content
|
||||
- `en_US`(string) English
|
||||
- `zh_Hans` (string) [optional] Chinese
|
||||
- `max_length` (int) Specific property of form items of type `text-input`, defining maximum input length, 0 for no limit.
|
||||
- `show_on` (array[[FormShowOnObject](#FormShowOnObject)]) Displayed when other form item values meet certain conditions, displayed always if empty.
|
||||
|
||||
### FormType
|
||||
|
||||
- `text-input` Text input component
|
||||
- `secret-input` Password input component
|
||||
- `select` Single-choice dropdown
|
||||
- `radio` Radio component
|
||||
- `switch` Switch component, only supports `true` and `false` values
|
||||
|
||||
### FormOption
|
||||
|
||||
- `label` (object) Label
|
||||
- `en_US`(string) English
|
||||
- `zh_Hans`(string) [optional] Chinese
|
||||
- `value` (string) Dropdown option value
|
||||
- `show_on` (array[[FormShowOnObject](#FormShowOnObject)]) Displayed when other form item values meet certain conditions, displayed always if empty.
|
||||
|
||||
### FormShowOnObject
|
||||
|
||||
- `variable` (string) Variable name of other form items
|
||||
- `value` (string) Variable value of other form items
|
||||
@ -0,0 +1,296 @@
|
||||
## 自定义预定义模型接入
|
||||
|
||||
### 介绍
|
||||
|
||||
供应商集成完成后,接下来为供应商下模型的接入,为了帮助理解整个接入过程,我们以`Xinference`为例,逐步完成一个完整的供应商接入。
|
||||
|
||||
需要注意的是,对于自定义模型,每一个模型的接入都需要填写一个完整的供应商凭据。
|
||||
|
||||
而不同于预定义模型,自定义供应商接入时永远会拥有如下两个参数,不需要在供应商yaml中定义。
|
||||
|
||||

|
||||
|
||||
|
||||
在前文中,我们已经知道了供应商无需实现`validate_provider_credential`,Runtime会自行根据用户在此选择的模型类型和模型名称调用对应的模型层的`validate_credentials`来进行验证。
|
||||
|
||||
### 编写供应商yaml
|
||||
|
||||
我们首先要确定,接入的这个供应商支持哪些类型的模型。
|
||||
|
||||
当前支持模型类型如下:
|
||||
|
||||
- `llm` 文本生成模型
|
||||
- `text_embedding` 文本 Embedding 模型
|
||||
- `rerank` Rerank 模型
|
||||
- `speech2text` 语音转文字
|
||||
- `moderation` 审查
|
||||
|
||||
`Xinference`支持`LLM`和`Text Embedding`和Rerank,那么我们开始编写`xinference.yaml`。
|
||||
|
||||
```yaml
|
||||
provider: xinference #确定供应商标识
|
||||
label: # 供应商展示名称,可设置 en_US 英文、zh_Hans 中文两种语言,zh_Hans 不设置将默认使用 en_US。
|
||||
en_US: Xorbots Inference
|
||||
icon_small: # 小图标,可以参考其他供应商的图标,存储在对应供应商实现目录下的 _assets 目录,中英文策略同 label
|
||||
en_US: icon_s_en.svg
|
||||
icon_large: # 大图标
|
||||
en_US: icon_l_en.svg
|
||||
help: # 帮助
|
||||
title:
|
||||
en_US: How to deploy Xinference
|
||||
zh_Hans: 如何部署 Xinference
|
||||
url:
|
||||
en_US: https://github.com/xorbitsai/inference
|
||||
supported_model_types: # 支持的模型类型,Xinference同时支持LLM/Text Embedding/Rerank
|
||||
- llm
|
||||
- text-embedding
|
||||
- rerank
|
||||
configurate_methods: # 因为Xinference为本地部署的供应商,并且没有预定义模型,需要用什么模型需要根据Xinference的文档自己部署,所以这里只支持自定义模型
|
||||
- customizable-model
|
||||
provider_credential_schema:
|
||||
credential_form_schemas:
|
||||
```
|
||||
|
||||
随后,我们需要思考在Xinference中定义一个模型需要哪些凭据
|
||||
|
||||
- 它支持三种不同的模型,因此,我们需要有`model_type`来指定这个模型的类型,它有三种类型,所以我们这么编写
|
||||
```yaml
|
||||
provider_credential_schema:
|
||||
credential_form_schemas:
|
||||
- variable: model_type
|
||||
type: select
|
||||
label:
|
||||
en_US: Model type
|
||||
zh_Hans: 模型类型
|
||||
required: true
|
||||
options:
|
||||
- value: text-generation
|
||||
label:
|
||||
en_US: Language Model
|
||||
zh_Hans: 语言模型
|
||||
- value: embeddings
|
||||
label:
|
||||
en_US: Text Embedding
|
||||
- value: reranking
|
||||
label:
|
||||
en_US: Rerank
|
||||
```
|
||||
- 每一个模型都有自己的名称`model_name`,因此需要在这里定义
|
||||
```yaml
|
||||
- variable: model_name
|
||||
type: text-input
|
||||
label:
|
||||
en_US: Model name
|
||||
zh_Hans: 模型名称
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 填写模型名称
|
||||
en_US: Input model name
|
||||
```
|
||||
- 填写Xinference本地部署的地址
|
||||
```yaml
|
||||
- variable: server_url
|
||||
label:
|
||||
zh_Hans: 服务器URL
|
||||
en_US: Server url
|
||||
type: text-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 在此输入Xinference的服务器地址,如 https://example.com/xxx
|
||||
en_US: Enter the url of your Xinference, for example https://example.com/xxx
|
||||
```
|
||||
- 每个模型都有唯一的model_uid,因此需要在这里定义
|
||||
```yaml
|
||||
- variable: model_uid
|
||||
label:
|
||||
zh_Hans: 模型UID
|
||||
en_US: Model uid
|
||||
type: text-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的Model UID
|
||||
en_US: Enter the model uid
|
||||
```
|
||||
现在,我们就完成了供应商的基础定义。
|
||||
|
||||
### 编写模型代码
|
||||
|
||||
然后我们以`llm`类型为例,编写`xinference.llm.llm.py`
|
||||
|
||||
在 `llm.py` 中创建一个 Xinference LLM 类,我们取名为 `XinferenceAILargeLanguageModel`(随意),继承 `__base.large_language_model.LargeLanguageModel` 基类,实现以下几个方法:
|
||||
|
||||
- LLM 调用
|
||||
|
||||
实现 LLM 调用的核心方法,可同时支持流式和同步返回。
|
||||
|
||||
```python
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
```
|
||||
|
||||
在实现时,需要注意使用两个函数来返回数据,分别用于处理同步返回和流式返回,因为Python会将函数中包含 `yield` 关键字的函数识别为生成器函数,返回的数据类型固定为 `Generator`,因此同步和流式返回需要分别实现,就像下面这样(注意下面例子使用了简化参数,实际实现时需要按照上面的参数列表进行实现):
|
||||
|
||||
```python
|
||||
def _invoke(self, stream: bool, **kwargs) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
if stream:
|
||||
return self._handle_stream_response(**kwargs)
|
||||
return self._handle_sync_response(**kwargs)
|
||||
|
||||
def _handle_stream_response(self, **kwargs) -> Generator:
|
||||
for chunk in response:
|
||||
yield chunk
|
||||
def _handle_sync_response(self, **kwargs) -> LLMResult:
|
||||
return LLMResult(**response)
|
||||
```
|
||||
|
||||
- 预计算输入 tokens
|
||||
|
||||
若模型未提供预计算 tokens 接口,可直接返回 0。
|
||||
|
||||
```python
|
||||
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param tools: tools for tool calling
|
||||
:return:
|
||||
"""
|
||||
```
|
||||
|
||||
有时候,也许你不需要直接返回0,所以你可以使用`self._get_num_tokens_by_gpt2(text: str)`来获取预计算的tokens,这个方法位于`AIModel`基类中,它会使用GPT2的Tokenizer进行计算,但是只能作为替代方法,并不完全准确。
|
||||
|
||||
- 模型凭据校验
|
||||
|
||||
与供应商凭据校验类似,这里针对单个模型进行校验。
|
||||
|
||||
```python
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
```
|
||||
|
||||
- 模型参数Schema
|
||||
|
||||
与自定义类型不同,由于没有在yaml文件中定义一个模型支持哪些参数,因此,我们需要动态时间模型参数的Schema。
|
||||
|
||||
如Xinference支持`max_tokens` `temperature` `top_p` 这三个模型参数。
|
||||
|
||||
但是有的供应商根据不同的模型支持不同的参数,如供应商`OpenLLM`支持`top_k`,但是并不是这个供应商提供的所有模型都支持`top_k`,我们这里举例A模型支持`top_k`,B模型不支持`top_k`,那么我们需要在这里动态生成模型参数的Schema,如下所示:
|
||||
|
||||
```python
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||
"""
|
||||
used to define customizable model schema
|
||||
"""
|
||||
rules = [
|
||||
ParameterRule(
|
||||
name='temperature', type=ParameterType.FLOAT,
|
||||
use_template='temperature',
|
||||
label=I18nObject(
|
||||
zh_Hans='温度', en_US='Temperature'
|
||||
)
|
||||
),
|
||||
ParameterRule(
|
||||
name='top_p', type=ParameterType.FLOAT,
|
||||
use_template='top_p',
|
||||
label=I18nObject(
|
||||
zh_Hans='Top P', en_US='Top P'
|
||||
)
|
||||
),
|
||||
ParameterRule(
|
||||
name='max_tokens', type=ParameterType.INT,
|
||||
use_template='max_tokens',
|
||||
min=1,
|
||||
default=512,
|
||||
label=I18nObject(
|
||||
zh_Hans='最大生成长度', en_US='Max Tokens'
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
# if model is A, add top_k to rules
|
||||
if model == 'A':
|
||||
rules.append(
|
||||
ParameterRule(
|
||||
name='top_k', type=ParameterType.INT,
|
||||
use_template='top_k',
|
||||
min=1,
|
||||
default=50,
|
||||
label=I18nObject(
|
||||
zh_Hans='Top K', en_US='Top K'
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
"""
|
||||
some NOT IMPORTANT code here
|
||||
"""
|
||||
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(
|
||||
en_US=model
|
||||
),
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_type=model_type,
|
||||
model_properties={
|
||||
'mode': ModelType.LLM,
|
||||
},
|
||||
parameter_rules=rules
|
||||
)
|
||||
|
||||
return entity
|
||||
```
|
||||
|
||||
- 调用异常错误映射表
|
||||
|
||||
当模型调用异常时需要映射到 Runtime 指定的 `InvokeError` 类型,方便 Dify 针对不同错误做不同后续处理。
|
||||
|
||||
Runtime Errors:
|
||||
|
||||
- `InvokeConnectionError` 调用连接错误
|
||||
- `InvokeServerUnavailableError ` 调用服务方不可用
|
||||
- `InvokeRateLimitError ` 调用达到限额
|
||||
- `InvokeAuthorizationError` 调用鉴权失败
|
||||
- `InvokeBadRequestError ` 调用传参有误
|
||||
|
||||
```python
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
The value is the error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
```
|
||||
|
||||
接口方法说明见:[Interfaces](./interfaces.md),具体实现可参考:[llm.py](https://github.com/langgenius/dify-runtime/blob/main/lib/model_providers/anthropic/llm/llm.py)。
|
||||
BIN
api/core/model_runtime/docs/zh_Hans/images/index/image-1.png
Normal file
|
After Width: | Height: | Size: 230 KiB |
BIN
api/core/model_runtime/docs/zh_Hans/images/index/image-2.png
Normal file
|
After Width: | Height: | Size: 205 KiB |
|
After Width: | Height: | Size: 385 KiB |
|
After Width: | Height: | Size: 113 KiB |
|
After Width: | Height: | Size: 109 KiB |
|
After Width: | Height: | Size: 70 KiB |
|
After Width: | Height: | Size: 75 KiB |
|
After Width: | Height: | Size: 541 KiB |
BIN
api/core/model_runtime/docs/zh_Hans/images/index/image-3.png
Normal file
|
After Width: | Height: | Size: 44 KiB |
BIN
api/core/model_runtime/docs/zh_Hans/images/index/image.png
Normal file
|
After Width: | Height: | Size: 262 KiB |
706
api/core/model_runtime/docs/zh_Hans/interfaces.md
Normal file
@ -0,0 +1,706 @@
|
||||
# 接口方法
|
||||
|
||||
这里介绍供应商和各模型类型需要实现的接口方法和参数说明。
|
||||
|
||||
## 供应商
|
||||
|
||||
继承 `__base.model_provider.ModelProvider` 基类,实现以下接口:
|
||||
|
||||
```python
|
||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||
"""
|
||||
Validate provider credentials
|
||||
You can choose any validate_credentials method of model type or implement validate method by yourself,
|
||||
such as: get model list api
|
||||
|
||||
if validate failed, raise exception
|
||||
|
||||
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
|
||||
"""
|
||||
```
|
||||
|
||||
- `credentials` (object) 凭据信息
|
||||
|
||||
凭据信息的参数由供应商 YAML 配置文件的 `provider_credential_schema` 定义,传入如:`api_key` 等。
|
||||
|
||||
验证失败请抛出 `errors.validate.CredentialsValidateFailedError` 错误。
|
||||
|
||||
**注:预定义模型需完整实现该接口,自定义模型供应商只需要如下简单实现即可**
|
||||
|
||||
```python
|
||||
class XinferenceProvider(Provider):
|
||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||
pass
|
||||
```
|
||||
|
||||
## 模型
|
||||
|
||||
模型分为 5 种不同的模型类型,不同模型类型继承的基类不同,需要实现的方法也不同。
|
||||
|
||||
### 通用接口
|
||||
|
||||
所有模型均需要统一实现下面 2 个方法:
|
||||
|
||||
- 模型凭据校验
|
||||
|
||||
与供应商凭据校验类似,这里针对单个模型进行校验。
|
||||
|
||||
```python
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
```
|
||||
|
||||
参数:
|
||||
|
||||
- `model` (string) 模型名称
|
||||
|
||||
- `credentials` (object) 凭据信息
|
||||
|
||||
凭据信息的参数由供应商 YAML 配置文件的 `provider_credential_schema` 或 `model_credential_schema` 定义,传入如:`api_key` 等。
|
||||
|
||||
验证失败请抛出 `errors.validate.CredentialsValidateFailedError` 错误。
|
||||
|
||||
- 调用异常错误映射表
|
||||
|
||||
当模型调用异常时需要映射到 Runtime 指定的 `InvokeError` 类型,方便 Dify 针对不同错误做不同后续处理。
|
||||
|
||||
Runtime Errors:
|
||||
|
||||
- `InvokeConnectionError` 调用连接错误
|
||||
- `InvokeServerUnavailableError ` 调用服务方不可用
|
||||
- `InvokeRateLimitError ` 调用达到限额
|
||||
- `InvokeAuthorizationError` 调用鉴权失败
|
||||
- `InvokeBadRequestError ` 调用传参有误
|
||||
|
||||
```python
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
The value is the error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
```
|
||||
|
||||
也可以直接抛出对应Erros,并做如下定义,这样在之后的调用中可以直接抛出`InvokeConnectionError`等异常。
|
||||
|
||||
```python
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
return {
|
||||
InvokeConnectionError: [
|
||||
InvokeConnectionError
|
||||
],
|
||||
InvokeServerUnavailableError: [
|
||||
InvokeServerUnavailableError
|
||||
],
|
||||
InvokeRateLimitError: [
|
||||
InvokeRateLimitError
|
||||
],
|
||||
InvokeAuthorizationError: [
|
||||
InvokeAuthorizationError
|
||||
],
|
||||
InvokeBadRequestError: [
|
||||
InvokeBadRequestError
|
||||
],
|
||||
}
|
||||
```
|
||||
|
||||
可参考 OpenAI `_invoke_error_mapping`。
|
||||
|
||||
### LLM
|
||||
|
||||
继承 `__base.large_language_model.LargeLanguageModel` 基类,实现以下接口:
|
||||
|
||||
- LLM 调用
|
||||
|
||||
实现 LLM 调用的核心方法,可同时支持流式和同步返回。
|
||||
|
||||
```python
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
```
|
||||
|
||||
- 参数:
|
||||
|
||||
- `model` (string) 模型名称
|
||||
|
||||
- `credentials` (object) 凭据信息
|
||||
|
||||
凭据信息的参数由供应商 YAML 配置文件的 `provider_credential_schema` 或 `model_credential_schema` 定义,传入如:`api_key` 等。
|
||||
|
||||
- `prompt_messages` (array[[PromptMessage](#PromptMessage)]) Prompt 列表
|
||||
|
||||
若模型为 `Completion` 类型,则列表只需要传入一个 [UserPromptMessage](#UserPromptMessage) 元素即可;
|
||||
|
||||
若模型为 `Chat` 类型,需要根据消息不同传入 [SystemPromptMessage](#SystemPromptMessage), [UserPromptMessage](#UserPromptMessage), [AssistantPromptMessage](#AssistantPromptMessage), [ToolPromptMessage](#ToolPromptMessage) 元素列表
|
||||
|
||||
- `model_parameters` (object) 模型参数
|
||||
|
||||
模型参数由模型 YAML 配置的 `parameter_rules` 定义。
|
||||
|
||||
- `tools` (array[[PromptMessageTool](#PromptMessageTool)]) [optional] 工具列表,等同于 `function calling` 中的 `function`。
|
||||
|
||||
即传入 tool calling 的工具列表。
|
||||
|
||||
- `stop` (array[string]) [optional] 停止序列
|
||||
|
||||
模型返回将在停止序列定义的字符串之前停止输出。
|
||||
|
||||
- `stream` (bool) 是否流式输出,默认 True
|
||||
|
||||
流式输出返回 Generator[[LLMResultChunk](#LLMResultChunk)],非流式输出返回 [LLMResult](#LLMResult)。
|
||||
|
||||
- `user` (string) [optional] 用户的唯一标识符
|
||||
|
||||
可以帮助供应商监控和检测滥用行为。
|
||||
|
||||
- 返回
|
||||
|
||||
流式输出返回 Generator[[LLMResultChunk](#LLMResultChunk)],非流式输出返回 [LLMResult](#LLMResult)。
|
||||
|
||||
- 预计算输入 tokens
|
||||
|
||||
若模型未提供预计算 tokens 接口,可直接返回 0。
|
||||
|
||||
```python
|
||||
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param tools: tools for tool calling
|
||||
:return:
|
||||
"""
|
||||
```
|
||||
|
||||
参数说明见上述 `LLM 调用`。
|
||||
|
||||
该接口需要根据对应`model`选择合适的`tokenizer`进行计算,如果对应模型没有提供`tokenizer`,可以使用`AIModel`基类中的`_get_num_tokens_by_gpt2(text: str)`方法进行计算。
|
||||
|
||||
- 获取自定义模型规则 [可选]
|
||||
|
||||
```python
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
"""
|
||||
Get customizable model schema
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return: model schema
|
||||
"""
|
||||
```
|
||||
|
||||
当供应商支持增加自定义 LLM 时,可实现此方法让自定义模型可获取模型规则,默认返回 None。
|
||||
|
||||
对于`OpenAI`供应商下的大部分微调模型,可以通过其微调模型名称获取到其基类模型,如`gpt-3.5-turbo-1106`,然后返回基类模型的预定义参数规则,参考[openai](https://github.com/langgenius/dify/blob/feat/model-runtime/api/core/model_runtime/model_providers/openai/llm/llm.py#L801)
|
||||
的具体实现
|
||||
|
||||
### TextEmbedding
|
||||
|
||||
继承 `__base.text_embedding_model.TextEmbeddingModel` 基类,实现以下接口:
|
||||
|
||||
- Embedding 调用
|
||||
|
||||
```python
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
texts: list[str], user: Optional[str] = None) \
|
||||
-> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:param user: unique user id
|
||||
:return: embeddings result
|
||||
"""
|
||||
```
|
||||
|
||||
- 参数:
|
||||
|
||||
- `model` (string) 模型名称
|
||||
|
||||
- `credentials` (object) 凭据信息
|
||||
|
||||
凭据信息的参数由供应商 YAML 配置文件的 `provider_credential_schema` 或 `model_credential_schema` 定义,传入如:`api_key` 等。
|
||||
|
||||
- `texts` (array[string]) 文本列表,可批量处理
|
||||
|
||||
- `user` (string) [optional] 用户的唯一标识符
|
||||
|
||||
可以帮助供应商监控和检测滥用行为。
|
||||
|
||||
- 返回:
|
||||
|
||||
[TextEmbeddingResult](#TextEmbeddingResult) 实体。
|
||||
|
||||
- 预计算 tokens
|
||||
|
||||
```python
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:return:
|
||||
"""
|
||||
```
|
||||
|
||||
参数说明见上述 `Embedding 调用`。
|
||||
|
||||
同上述`LargeLanguageModel`,该接口需要根据对应`model`选择合适的`tokenizer`进行计算,如果对应模型没有提供`tokenizer`,可以使用`AIModel`基类中的`_get_num_tokens_by_gpt2(text: str)`方法进行计算。
|
||||
|
||||
### Rerank
|
||||
|
||||
继承 `__base.rerank_model.RerankModel` 基类,实现以下接口:
|
||||
|
||||
- rerank 调用
|
||||
|
||||
```python
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
|
||||
user: Optional[str] = None) \
|
||||
-> RerankResult:
|
||||
"""
|
||||
Invoke rerank model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param query: search query
|
||||
:param docs: docs for reranking
|
||||
:param score_threshold: score threshold
|
||||
:param top_n: top n
|
||||
:param user: unique user id
|
||||
:return: rerank result
|
||||
"""
|
||||
```
|
||||
|
||||
- 参数:
|
||||
|
||||
- `model` (string) 模型名称
|
||||
|
||||
- `credentials` (object) 凭据信息
|
||||
|
||||
凭据信息的参数由供应商 YAML 配置文件的 `provider_credential_schema` 或 `model_credential_schema` 定义,传入如:`api_key` 等。
|
||||
|
||||
- `query` (string) 查询请求内容
|
||||
|
||||
- `docs` (array[string]) 需要重排的分段列表
|
||||
|
||||
- `score_threshold` (float) [optional] Score 阈值
|
||||
|
||||
- `top_n` (int) [optional] 取前 n 个分段
|
||||
|
||||
- `user` (string) [optional] 用户的唯一标识符
|
||||
|
||||
可以帮助供应商监控和检测滥用行为。
|
||||
|
||||
- 返回:
|
||||
|
||||
[RerankResult](#RerankResult) 实体。
|
||||
|
||||
### Speech2text
|
||||
|
||||
继承 `__base.speech2text_model.Speech2TextModel` 基类,实现以下接口:
|
||||
|
||||
- Invoke 调用
|
||||
|
||||
```python
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
file: IO[bytes], user: Optional[str] = None) \
|
||||
-> str:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param file: audio file
|
||||
:param user: unique user id
|
||||
:return: text for given audio file
|
||||
"""
|
||||
```
|
||||
|
||||
- 参数:
|
||||
|
||||
- `model` (string) 模型名称
|
||||
|
||||
- `credentials` (object) 凭据信息
|
||||
|
||||
凭据信息的参数由供应商 YAML 配置文件的 `provider_credential_schema` 或 `model_credential_schema` 定义,传入如:`api_key` 等。
|
||||
|
||||
- `file` (File) 文件流
|
||||
|
||||
- `user` (string) [optional] 用户的唯一标识符
|
||||
|
||||
可以帮助供应商监控和检测滥用行为。
|
||||
|
||||
- 返回:
|
||||
|
||||
语音转换后的字符串。
|
||||
|
||||
### Moderation
|
||||
|
||||
继承 `__base.moderation_model.ModerationModel` 基类,实现以下接口:
|
||||
|
||||
- Invoke 调用
|
||||
|
||||
```python
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
text: str, user: Optional[str] = None) \
|
||||
-> bool:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param text: text to moderate
|
||||
:param user: unique user id
|
||||
:return: false if text is safe, true otherwise
|
||||
"""
|
||||
```
|
||||
|
||||
- 参数:
|
||||
|
||||
- `model` (string) 模型名称
|
||||
|
||||
- `credentials` (object) 凭据信息
|
||||
|
||||
凭据信息的参数由供应商 YAML 配置文件的 `provider_credential_schema` 或 `model_credential_schema` 定义,传入如:`api_key` 等。
|
||||
|
||||
- `text` (string) 文本内容
|
||||
|
||||
- `user` (string) [optional] 用户的唯一标识符
|
||||
|
||||
可以帮助供应商监控和检测滥用行为。
|
||||
|
||||
- 返回:
|
||||
|
||||
False 代表传入的文本安全,True 则反之。
|
||||
|
||||
|
||||
|
||||
## 实体
|
||||
|
||||
### PromptMessageRole
|
||||
|
||||
消息角色
|
||||
|
||||
```python
|
||||
class PromptMessageRole(Enum):
|
||||
"""
|
||||
Enum class for prompt message.
|
||||
"""
|
||||
SYSTEM = "system"
|
||||
USER = "user"
|
||||
ASSISTANT = "assistant"
|
||||
TOOL = "tool"
|
||||
```
|
||||
|
||||
### PromptMessageContentType
|
||||
|
||||
消息内容类型,分为纯文本和图片。
|
||||
|
||||
```python
|
||||
class PromptMessageContentType(Enum):
|
||||
"""
|
||||
Enum class for prompt message content type.
|
||||
"""
|
||||
TEXT = 'text'
|
||||
IMAGE = 'image'
|
||||
```
|
||||
|
||||
### PromptMessageContent
|
||||
|
||||
消息内容基类,仅作为参数声明用,不可初始化。
|
||||
|
||||
```python
|
||||
class PromptMessageContent(BaseModel):
|
||||
"""
|
||||
Model class for prompt message content.
|
||||
"""
|
||||
type: PromptMessageContentType
|
||||
data: str # 内容数据
|
||||
```
|
||||
|
||||
当前支持文本和图片两种类型,可支持同时传入文本和多图。
|
||||
|
||||
需要分别初始化 `TextPromptMessageContent` 和 `ImagePromptMessageContent` 传入。
|
||||
|
||||
### TextPromptMessageContent
|
||||
|
||||
```python
|
||||
class TextPromptMessageContent(PromptMessageContent):
|
||||
"""
|
||||
Model class for text prompt message content.
|
||||
"""
|
||||
type: PromptMessageContentType = PromptMessageContentType.TEXT
|
||||
```
|
||||
|
||||
若传入图文,其中文字需要构造此实体作为 `content` 列表中的一部分。
|
||||
|
||||
### ImagePromptMessageContent
|
||||
|
||||
```python
|
||||
class ImagePromptMessageContent(PromptMessageContent):
|
||||
"""
|
||||
Model class for image prompt message content.
|
||||
"""
|
||||
class DETAIL(Enum):
|
||||
LOW = 'low'
|
||||
HIGH = 'high'
|
||||
|
||||
type: PromptMessageContentType = PromptMessageContentType.IMAGE
|
||||
detail: DETAIL = DETAIL.LOW # 分辨率
|
||||
```
|
||||
|
||||
若传入图文,其中图片需要构造此实体作为 `content` 列表中的一部分
|
||||
|
||||
`data` 可以为 `url` 或者图片 `base64` 加密后的字符串。
|
||||
|
||||
### PromptMessage
|
||||
|
||||
所有 Role 消息体的基类,仅作为参数声明用,不可初始化。
|
||||
|
||||
```python
|
||||
class PromptMessage(ABC, BaseModel):
|
||||
"""
|
||||
Model class for prompt message.
|
||||
"""
|
||||
role: PromptMessageRole # 消息角色
|
||||
content: Optional[str | list[PromptMessageContent]] = None # 支持两种类型,字符串和内容列表,内容列表是为了满足多模态的需要,可详见 PromptMessageContent 说明。
|
||||
name: Optional[str] = None # 名称,可选。
|
||||
```
|
||||
|
||||
### UserPromptMessage
|
||||
|
||||
UserMessage 消息体,代表用户消息。
|
||||
|
||||
```python
|
||||
class UserPromptMessage(PromptMessage):
|
||||
"""
|
||||
Model class for user prompt message.
|
||||
"""
|
||||
role: PromptMessageRole = PromptMessageRole.USER
|
||||
```
|
||||
|
||||
### AssistantPromptMessage
|
||||
|
||||
代表模型返回消息,通常用于 `few-shots` 或聊天历史传入。
|
||||
|
||||
```python
|
||||
class AssistantPromptMessage(PromptMessage):
|
||||
"""
|
||||
Model class for assistant prompt message.
|
||||
"""
|
||||
class ToolCall(BaseModel):
|
||||
"""
|
||||
Model class for assistant prompt message tool call.
|
||||
"""
|
||||
class ToolCallFunction(BaseModel):
|
||||
"""
|
||||
Model class for assistant prompt message tool call function.
|
||||
"""
|
||||
name: str # 工具名称
|
||||
arguments: str # 工具参数
|
||||
|
||||
id: str # 工具 ID,仅在 OpenAI tool call 生效,为工具调用的唯一 ID,同一个工具可以调用多次
|
||||
type: str # 默认 function
|
||||
function: ToolCallFunction # 工具调用信息
|
||||
|
||||
role: PromptMessageRole = PromptMessageRole.ASSISTANT
|
||||
tool_calls: list[ToolCall] = [] # 模型回复的工具调用结果(仅当传入 tools,并且模型认为需要调用工具时返回)
|
||||
```
|
||||
|
||||
其中 `tool_calls` 为调用模型传入 `tools` 后,由模型返回的 `tool call` 列表。
|
||||
|
||||
### SystemPromptMessage
|
||||
|
||||
代表系统消息,通常用于设定给模型的系统指令。
|
||||
|
||||
```python
|
||||
class SystemPromptMessage(PromptMessage):
|
||||
"""
|
||||
Model class for system prompt message.
|
||||
"""
|
||||
role: PromptMessageRole = PromptMessageRole.SYSTEM
|
||||
```
|
||||
|
||||
### ToolPromptMessage
|
||||
|
||||
代表工具消息,用于工具执行后将结果交给模型进行下一步计划。
|
||||
|
||||
```python
|
||||
class ToolPromptMessage(PromptMessage):
|
||||
"""
|
||||
Model class for tool prompt message.
|
||||
"""
|
||||
role: PromptMessageRole = PromptMessageRole.TOOL
|
||||
tool_call_id: str # 工具调用 ID,若不支持 OpenAI tool call,也可传入工具名称
|
||||
```
|
||||
|
||||
基类的 `content` 传入工具执行结果。
|
||||
|
||||
### PromptMessageTool
|
||||
|
||||
```python
|
||||
class PromptMessageTool(BaseModel):
|
||||
"""
|
||||
Model class for prompt message tool.
|
||||
"""
|
||||
name: str # 工具名称
|
||||
description: str # 工具描述
|
||||
parameters: dict # 工具参数 dict
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### LLMResult
|
||||
|
||||
```python
|
||||
class LLMResult(BaseModel):
|
||||
"""
|
||||
Model class for llm result.
|
||||
"""
|
||||
model: str # 实际使用模型
|
||||
prompt_messages: list[PromptMessage] # prompt 消息列表
|
||||
message: AssistantPromptMessage # 回复消息
|
||||
usage: LLMUsage # 使用的 tokens 及费用信息
|
||||
system_fingerprint: Optional[str] = None # 请求指纹,可参考 OpenAI 该参数定义
|
||||
```
|
||||
|
||||
### LLMResultChunkDelta
|
||||
|
||||
流式返回中每个迭代内部 `delta` 实体
|
||||
|
||||
```python
|
||||
class LLMResultChunkDelta(BaseModel):
|
||||
"""
|
||||
Model class for llm result chunk delta.
|
||||
"""
|
||||
index: int # 序号
|
||||
message: AssistantPromptMessage # 回复消息
|
||||
usage: Optional[LLMUsage] = None # 使用的 tokens 及费用信息,仅最后一条返回
|
||||
finish_reason: Optional[str] = None # 结束原因,仅最后一条返回
|
||||
```
|
||||
|
||||
### LLMResultChunk
|
||||
|
||||
流式返回中每个迭代实体
|
||||
|
||||
```python
|
||||
class LLMResultChunk(BaseModel):
|
||||
"""
|
||||
Model class for llm result chunk.
|
||||
"""
|
||||
model: str # 实际使用模型
|
||||
prompt_messages: list[PromptMessage] # prompt 消息列表
|
||||
system_fingerprint: Optional[str] = None # 请求指纹,可参考 OpenAI 该参数定义
|
||||
delta: LLMResultChunkDelta # 每个迭代存在变化的内容
|
||||
```
|
||||
|
||||
### LLMUsage
|
||||
|
||||
```python
|
||||
class LLMUsage(ModelUsage):
|
||||
"""
|
||||
Model class for llm usage.
|
||||
"""
|
||||
prompt_tokens: int # prompt 使用 tokens
|
||||
prompt_unit_price: Decimal # prompt 单价
|
||||
prompt_price_unit: Decimal # prompt 价格单位,即单价基于多少 tokens
|
||||
prompt_price: Decimal # prompt 费用
|
||||
completion_tokens: int # 回复使用 tokens
|
||||
completion_unit_price: Decimal # 回复单价
|
||||
completion_price_unit: Decimal # 回复价格单位,即单价基于多少 tokens
|
||||
completion_price: Decimal # 回复费用
|
||||
total_tokens: int # 总使用 token 数
|
||||
total_price: Decimal # 总费用
|
||||
currency: str # 货币单位
|
||||
latency: float # 请求耗时(s)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### TextEmbeddingResult
|
||||
|
||||
```python
|
||||
class TextEmbeddingResult(BaseModel):
|
||||
"""
|
||||
Model class for text embedding result.
|
||||
"""
|
||||
model: str # 实际使用模型
|
||||
embeddings: list[list[float]] # embedding 向量列表,对应传入的 texts 列表
|
||||
usage: EmbeddingUsage # 使用信息
|
||||
```
|
||||
|
||||
### EmbeddingUsage
|
||||
|
||||
```python
|
||||
class EmbeddingUsage(ModelUsage):
|
||||
"""
|
||||
Model class for embedding usage.
|
||||
"""
|
||||
tokens: int # 使用 token 数
|
||||
total_tokens: int # 总使用 token 数
|
||||
unit_price: Decimal # 单价
|
||||
price_unit: Decimal # 价格单位,即单价基于多少 tokens
|
||||
total_price: Decimal # 总费用
|
||||
currency: str # 货币单位
|
||||
latency: float # 请求耗时(s)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### RerankResult
|
||||
|
||||
```python
|
||||
class RerankResult(BaseModel):
|
||||
"""
|
||||
Model class for rerank result.
|
||||
"""
|
||||
model: str # 实际使用模型
|
||||
docs: list[RerankDocument] # 重排后的分段列表
|
||||
```
|
||||
|
||||
### RerankDocument
|
||||
|
||||
```python
|
||||
class RerankDocument(BaseModel):
|
||||
"""
|
||||
Model class for rerank document.
|
||||
"""
|
||||
index: int # 原序号
|
||||
text: str # 分段文本内容
|
||||
score: float # 分数
|
||||
```
|
||||
@ -0,0 +1,171 @@
|
||||
## 预定义模型接入
|
||||
|
||||
供应商集成完成后,接下来为供应商下模型的接入。
|
||||
|
||||
我们首先需要确定接入模型的类型,并在对应供应商的目录下创建对应模型类型的 `module`。
|
||||
|
||||
当前支持模型类型如下:
|
||||
|
||||
- `llm` 文本生成模型
|
||||
- `text_embedding` 文本 Embedding 模型
|
||||
- `rerank` Rerank 模型
|
||||
- `speech2text` 语音转文字
|
||||
- `moderation` 审查
|
||||
|
||||
依旧以 `Anthropic` 为例,`Anthropic` 仅支持 LLM,因此在 `model_providers.anthropic` 创建一个 `llm` 为名称的 `module`。
|
||||
|
||||
对于预定义的模型,我们首先需要在 `llm` `module` 下创建以模型名为文件名称的 YAML 文件,如:`claude-2.1.yaml`。
|
||||
|
||||
### 准备模型 YAML
|
||||
|
||||
```yaml
|
||||
model: claude-2.1 # 模型标识
|
||||
# 模型展示名称,可设置 en_US 英文、zh_Hans 中文两种语言,zh_Hans 不设置将默认使用 en_US。
|
||||
# 也可不设置 label,则使用 model 标识内容。
|
||||
label:
|
||||
en_US: claude-2.1
|
||||
model_type: llm # 模型类型,claude-2.1 为 LLM
|
||||
features: # 支持功能,agent-thought 为支持 Agent 推理,vision 为支持图片理解
|
||||
- agent-thought
|
||||
model_properties: # 模型属性
|
||||
mode: chat # LLM 模式,complete 文本补全模型,chat 对话模型
|
||||
context_size: 200000 # 支持最大上下文大小
|
||||
parameter_rules: # 模型调用参数规则,仅 LLM 需要提供
|
||||
- name: temperature # 调用参数变量名
|
||||
# 默认预置了 5 种变量内容配置模板,temperature/top_p/max_tokens/presence_penalty/frequency_penalty
|
||||
# 可在 use_template 中直接设置模板变量名,将会使用 entities.defaults.PARAMETER_RULE_TEMPLATE 中的默认配置
|
||||
# 若设置了额外的配置参数,将覆盖默认配置
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label: # 调用参数展示名称
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int # 参数类型,支持 float/int/string/boolean
|
||||
help: # 帮助信息,描述参数作用
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false # 是否必填,可不设置
|
||||
- name: max_tokens_to_sample
|
||||
use_template: max_tokens
|
||||
default: 4096 # 参数默认值
|
||||
min: 1 # 参数最小值,仅 float/int 可用
|
||||
max: 4096 # 参数最大值,仅 float/int 可用
|
||||
pricing: # 价格信息
|
||||
input: '8.00' # 输入单价,即 Prompt 单价
|
||||
output: '24.00' # 输出单价,即返回内容单价
|
||||
unit: '0.000001' # 价格单位,即上述价格为每 100K 的单价
|
||||
currency: USD # 价格货币
|
||||
```
|
||||
|
||||
建议将所有模型配置都准备完毕后再开始模型代码的实现。
|
||||
|
||||
同样,也可以参考 `model_providers` 目录下其他供应商对应模型类型目录下的 YAML 配置信息,完整的 YAML 规则见:[Schema](schema.md#AIModel)。
|
||||
|
||||
### 实现模型调用代码
|
||||
|
||||
接下来需要在 `llm` `module` 下创建一个同名的 python 文件 `llm.py` 来编写代码实现。
|
||||
|
||||
在 `llm.py` 中创建一个 Anthropic LLM 类,我们取名为 `AnthropicLargeLanguageModel`(随意),继承 `__base.large_language_model.LargeLanguageModel` 基类,实现以下几个方法:
|
||||
|
||||
- LLM 调用
|
||||
|
||||
实现 LLM 调用的核心方法,可同时支持流式和同步返回。
|
||||
|
||||
```python
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
```
|
||||
|
||||
在实现时,需要注意使用两个函数来返回数据,分别用于处理同步返回和流式返回,因为Python会将函数中包含 `yield` 关键字的函数识别为生成器函数,返回的数据类型固定为 `Generator`,因此同步和流式返回需要分别实现,就像下面这样(注意下面例子使用了简化参数,实际实现时需要按照上面的参数列表进行实现):
|
||||
|
||||
```python
|
||||
def _invoke(self, stream: bool, **kwargs) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
if stream:
|
||||
return self._handle_stream_response(**kwargs)
|
||||
return self._handle_sync_response(**kwargs)
|
||||
|
||||
def _handle_stream_response(self, **kwargs) -> Generator:
|
||||
for chunk in response:
|
||||
yield chunk
|
||||
def _handle_sync_response(self, **kwargs) -> LLMResult:
|
||||
return LLMResult(**response)
|
||||
```
|
||||
|
||||
- 预计算输入 tokens
|
||||
|
||||
若模型未提供预计算 tokens 接口,可直接返回 0。
|
||||
|
||||
```python
|
||||
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param tools: tools for tool calling
|
||||
:return:
|
||||
"""
|
||||
```
|
||||
|
||||
- 模型凭据校验
|
||||
|
||||
与供应商凭据校验类似,这里针对单个模型进行校验。
|
||||
|
||||
```python
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
```
|
||||
|
||||
- 调用异常错误映射表
|
||||
|
||||
当模型调用异常时需要映射到 Runtime 指定的 `InvokeError` 类型,方便 Dify 针对不同错误做不同后续处理。
|
||||
|
||||
Runtime Errors:
|
||||
|
||||
- `InvokeConnectionError` 调用连接错误
|
||||
- `InvokeServerUnavailableError ` 调用服务方不可用
|
||||
- `InvokeRateLimitError ` 调用达到限额
|
||||
- `InvokeAuthorizationError` 调用鉴权失败
|
||||
- `InvokeBadRequestError ` 调用传参有误
|
||||
|
||||
```python
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
The value is the error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
```
|
||||
|
||||
接口方法说明见:[Interfaces](./interfaces.md),具体实现可参考:[llm.py](https://github.com/langgenius/dify-runtime/blob/main/lib/model_providers/anthropic/llm/llm.py)。
|
||||
188
api/core/model_runtime/docs/zh_Hans/provider_scale_out.md
Normal file
@ -0,0 +1,188 @@
|
||||
## 增加新供应商
|
||||
|
||||
供应商支持三种模型配置方式:
|
||||
|
||||
- `predefined-model ` 预定义模型
|
||||
|
||||
表示用户只需要配置统一的供应商凭据即可使用供应商下的预定义模型。
|
||||
|
||||
- `customizable-model` 自定义模型
|
||||
|
||||
用户需要新增每个模型的凭据配置,如Xinference,它同时支持 LLM 和 Text Embedding,但是每个模型都有唯一的**model_uid**,如果想要将两者同时接入,就需要为每个模型配置一个**model_uid**。
|
||||
|
||||
- `fetch-from-remote` 从远程获取
|
||||
|
||||
与 `predefined-model` 配置方式一致,只需要配置统一的供应商凭据即可,模型通过凭据信息从供应商获取。
|
||||
|
||||
如OpenAI,我们可以基于gpt-turbo-3.5来Fine Tune多个模型,而他们都位于同一个**api_key**下,当配置为 `fetch-from-remote` 时,开发者只需要配置统一的**api_key**即可让DifyRuntime获取到开发者所有的微调模型并接入Dify。
|
||||
|
||||
这三种配置方式**支持共存**,即存在供应商支持 `predefined-model` + `customizable-model` 或 `predefined-model` + `fetch-from-remote` 等,也就是配置了供应商统一凭据可以使用预定义模型和从远程获取的模型,若新增了模型,则可以在此基础上额外使用自定义的模型。
|
||||
|
||||
## 开始
|
||||
|
||||
### 介绍
|
||||
|
||||
#### 名词解释
|
||||
- `module`: 一个`module`即为一个Python Package,或者通俗一点,称为一个文件夹,里面包含了一个`__init__.py`文件,以及其他的`.py`文件。
|
||||
|
||||
#### 步骤
|
||||
新增一个供应商主要分为几步,这里简单列出,帮助大家有一个大概的认识,具体的步骤会在下面详细介绍。
|
||||
|
||||
- 创建供应商yaml文件,根据[ProviderSchema](./schema.md#provider)编写
|
||||
- 创建供应商代码,实现一个`class`。
|
||||
- 根据模型类型,在供应商`module`下创建对应的模型类型 `module`,如`llm`或`text_embedding`。
|
||||
- 根据模型类型,在对应的模型`module`下创建同名的代码文件,如`llm.py`,并实现一个`class`。
|
||||
- 如果有预定义模型,根据模型名称创建同名的yaml文件在模型`module`下,如`claude-2.1.yaml`,根据[AIModelEntity](./schema.md#aimodelentity)编写。
|
||||
- 编写测试代码,确保功能可用。
|
||||
|
||||
### 开始吧
|
||||
|
||||
增加一个新的供应商需要先确定供应商的英文标识,如 `anthropic`,使用该标识在 `model_providers` 创建以此为名称的 `module`。
|
||||
|
||||
在此 `module` 下,我们需要先准备供应商的 YAML 配置。
|
||||
|
||||
#### 准备供应商 YAML
|
||||
|
||||
此处以 `Anthropic` 为例,预设了供应商基础信息、支持的模型类型、配置方式、凭据规则。
|
||||
|
||||
```YAML
|
||||
provider: anthropic # 供应商标识
|
||||
label: # 供应商展示名称,可设置 en_US 英文、zh_Hans 中文两种语言,zh_Hans 不设置将默认使用 en_US。
|
||||
en_US: Anthropic
|
||||
icon_small: # 供应商小图标,存储在对应供应商实现目录下的 _assets 目录,中英文策略同 label
|
||||
en_US: icon_s_en.png
|
||||
icon_large: # 供应商大图标,存储在对应供应商实现目录下的 _assets 目录,中英文策略同 label
|
||||
en_US: icon_l_en.png
|
||||
supported_model_types: # 支持的模型类型,Anthropic 仅支持 LLM
|
||||
- llm
|
||||
configurate_methods: # 支持的配置方式,Anthropic 仅支持预定义模型
|
||||
- predefined-model
|
||||
provider_credential_schema: # 供应商凭据规则,由于 Anthropic 仅支持预定义模型,则需要定义统一供应商凭据规则
|
||||
credential_form_schemas: # 凭据表单项列表
|
||||
- variable: anthropic_api_key # 凭据参数变量名
|
||||
label: # 展示名称
|
||||
en_US: API Key
|
||||
type: secret-input # 表单类型,此处 secret-input 代表加密信息输入框,编辑时只展示屏蔽后的信息。
|
||||
required: true # 是否必填
|
||||
placeholder: # PlaceHolder 信息
|
||||
zh_Hans: 在此输入您的 API Key
|
||||
en_US: Enter your API Key
|
||||
- variable: anthropic_api_url
|
||||
label:
|
||||
en_US: API URL
|
||||
type: text-input # 表单类型,此处 text-input 代表文本输入框
|
||||
required: false
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的 API URL
|
||||
en_US: Enter your API URL
|
||||
```
|
||||
|
||||
如果接入的供应商提供自定义模型,比如`OpenAI`提供微调模型,那么我们就需要添加[`model_credential_schema`](./schema.md#modelcredentialschema),以`OpenAI`为例:
|
||||
|
||||
```yaml
|
||||
model_credential_schema:
|
||||
model: # 微调模型名称
|
||||
label:
|
||||
en_US: Model Name
|
||||
zh_Hans: 模型名称
|
||||
placeholder:
|
||||
en_US: Enter your model name
|
||||
zh_Hans: 输入模型名称
|
||||
credential_form_schemas:
|
||||
- variable: openai_api_key
|
||||
label:
|
||||
en_US: API Key
|
||||
type: secret-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的 API Key
|
||||
en_US: Enter your API Key
|
||||
- variable: openai_organization
|
||||
label:
|
||||
zh_Hans: 组织 ID
|
||||
en_US: Organization
|
||||
type: text-input
|
||||
required: false
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的组织 ID
|
||||
en_US: Enter your Organization ID
|
||||
- variable: openai_api_base
|
||||
label:
|
||||
zh_Hans: API Base
|
||||
en_US: API Base
|
||||
type: text-input
|
||||
required: false
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的 API Base
|
||||
en_US: Enter your API Base
|
||||
```
|
||||
|
||||
也可以参考 `model_providers` 目录下其他供应商目录下的 YAML 配置信息,完整的 YAML 规则见:[Schema](schema.md#Provider)。
|
||||
|
||||
#### 实现供应商代码
|
||||
|
||||
我们需要在`model_providers`下创建一个同名的python文件,如`anthropic.py`,并实现一个`class`,继承`__base.provider.Provider`基类,如`AnthropicProvider`。
|
||||
|
||||
##### 自定义模型供应商
|
||||
|
||||
当供应商为Xinference等自定义模型供应商时,可跳过该步骤,仅创建一个空的`XinferenceProvider`类即可,并实现一个空的`validate_provider_credentials`方法,该方法并不会被实际使用,仅用作避免抽象类无法实例化。
|
||||
|
||||
```python
|
||||
class XinferenceProvider(Provider):
|
||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||
pass
|
||||
```
|
||||
|
||||
##### 预定义模型供应商
|
||||
|
||||
供应商需要继承 `__base.model_provider.ModelProvider` 基类,实现 `validate_provider_credentials` 供应商统一凭据校验方法即可,可参考 [AnthropicProvider](https://github.com/langgenius/dify-runtime/blob/main/lib/model_providers/anthropic/anthropic.py)。
|
||||
|
||||
```python
|
||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||
"""
|
||||
Validate provider credentials
|
||||
You can choose any validate_credentials method of model type or implement validate method by yourself,
|
||||
such as: get model list api
|
||||
|
||||
if validate failed, raise exception
|
||||
|
||||
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
|
||||
"""
|
||||
```
|
||||
|
||||
当然也可以先预留 `validate_provider_credentials` 实现,在模型凭据校验方法实现后直接复用。
|
||||
|
||||
#### 增加模型
|
||||
|
||||
#### [增加预定义模型 👈🏻](./predefined_model_scale_out.md)
|
||||
对于预定义模型,我们可以通过简单定义一个yaml,并通过实现调用代码来接入。
|
||||
|
||||
#### [增加自定义模型 👈🏻](./customizable_model_scale_out.md)
|
||||
对于自定义模型,我们只需要实现调用代码即可接入,但是它需要处理的参数可能会更加复杂。
|
||||
|
||||
---
|
||||
|
||||
### 测试
|
||||
|
||||
为了保证接入供应商/模型的可用性,编写后的每个方法均需要在 `tests` 目录中编写对应的集成测试代码。
|
||||
|
||||
依旧以 `Anthropic` 为例。
|
||||
|
||||
在编写测试代码前,需要先在 `.env.example` 新增测试供应商所需要的凭据环境变量,如:`ANTHROPIC_API_KEY`。
|
||||
|
||||
在执行前需要将 `.env.example` 复制为 `.env` 再执行。
|
||||
|
||||
#### 编写测试代码
|
||||
|
||||
在 `tests` 目录下创建供应商同名的 `module`: `anthropic`,继续在此模块中创建 `test_provider.py` 以及对应模型类型的 test py 文件,如下所示:
|
||||
|
||||
```shell
|
||||
.
|
||||
├── __init__.py
|
||||
├── anthropic
|
||||
│ ├── __init__.py
|
||||
│ ├── test_llm.py # LLM 测试
|
||||
│ └── test_provider.py # 供应商测试
|
||||
```
|
||||
|
||||
针对上面实现的代码的各种情况进行测试代码编写,并测试通过后提交代码。
|
||||
196
api/core/model_runtime/docs/zh_Hans/schema.md
Normal file
@ -0,0 +1,196 @@
|
||||
# 配置规则
|
||||
|
||||
- 供应商规则基于 [Provider](#Provider) 实体。
|
||||
|
||||
- 模型规则基于 [AIModelEntity](#AIModelEntity) 实体。
|
||||
|
||||
> 以下所有实体均基于 `Pydantic BaseModel`,可在 `entities` 模块中找到对应实体。
|
||||
|
||||
### Provider
|
||||
|
||||
- `provider` (string) 供应商标识,如:`openai`
|
||||
- `label` (object) 供应商展示名称,i18n,可设置 `en_US` 英文、`zh_Hans` 中文两种语言
|
||||
- `zh_Hans ` (string) [optional] 中文标签名,`zh_Hans` 不设置将默认使用 `en_US`。
|
||||
- `en_US` (string) 英文标签名
|
||||
- `description` (object) [optional] 供应商描述,i18n
|
||||
- `zh_Hans` (string) [optional] 中文描述
|
||||
- `en_US` (string) 英文描述
|
||||
- `icon_small` (string) [optional] 供应商小 ICON,存储在对应供应商实现目录下的 `_assets` 目录,中英文策略同 `label`
|
||||
- `zh_Hans` (string) [optional] 中文 ICON
|
||||
- `en_US` (string) 英文 ICON
|
||||
- `icon_large` (string) [optional] 供应商大 ICON,存储在对应供应商实现目录下的 _assets 目录,中英文策略同 label
|
||||
- `zh_Hans `(string) [optional] 中文 ICON
|
||||
- `en_US` (string) 英文 ICON
|
||||
- `background` (string) [optional] 背景颜色色值,例:#FFFFFF,为空则展示前端默认色值。
|
||||
- `help` (object) [optional] 帮助信息
|
||||
- `title` (object) 帮助标题,i18n
|
||||
- `zh_Hans` (string) [optional] 中文标题
|
||||
- `en_US` (string) 英文标题
|
||||
- `url` (object) 帮助链接,i18n
|
||||
- `zh_Hans` (string) [optional] 中文链接
|
||||
- `en_US` (string) 英文链接
|
||||
- `supported_model_types` (array[[ModelType](#ModelType)]) 支持的模型类型
|
||||
- `configurate_methods` (array[[ConfigurateMethod](#ConfigurateMethod)]) 配置方式
|
||||
- `provider_credential_schema` ([ProviderCredentialSchema](#ProviderCredentialSchema)) 供应商凭据规格
|
||||
- `model_credential_schema` ([ModelCredentialSchema](#ModelCredentialSchema)) 模型凭据规格
|
||||
|
||||
### AIModelEntity
|
||||
|
||||
- `model` (string) 模型标识,如:`gpt-3.5-turbo`
|
||||
- `label` (object) [optional] 模型展示名称,i18n,可设置 `en_US` 英文、`zh_Hans` 中文两种语言
|
||||
- `zh_Hans `(string) [optional] 中文标签名
|
||||
- `en_US` (string) 英文标签名
|
||||
- `model_type` ([ModelType](#ModelType)) 模型类型
|
||||
- `features` (array[[ModelFeature](#ModelFeature)]) [optional] 支持功能列表
|
||||
- `model_properties` (object) 模型属性
|
||||
- `mode` ([LLMMode](#LLMMode)) 模式 (模型类型 `llm` 可用)
|
||||
- `context_size` (int) 上下文大小 (模型类型 `llm` `text-embedding` 可用)
|
||||
- `max_chunks` (int) 最大分块数量 (模型类型 `text-embedding ` `moderation` 可用)
|
||||
- `file_upload_limit` (int) 文件最大上传限制,单位:MB。(模型类型 `speech2text` 可用)
|
||||
- `supported_file_extensions` (string) 支持文件扩展格式,如:mp3,mp4(模型类型 `speech2text` 可用)
|
||||
- `max_characters_per_chunk` (int) 每块最大字符数 (模型类型 `moderation` 可用)
|
||||
- `parameter_rules` (array[[ParameterRule](#ParameterRule)]) [optional] 模型调用参数规则
|
||||
- `pricing` ([PriceConfig](#PriceConfig)) [optional] 价格信息
|
||||
- `deprecated` (bool) 是否废弃。若废弃,模型列表将不再展示,但已经配置的可以继续使用,默认 False。
|
||||
|
||||
### ModelType
|
||||
|
||||
- `llm` 文本生成模型
|
||||
- `text-embedding` 文本 Embedding 模型
|
||||
- `rerank` Rerank 模型
|
||||
- `speech2text` 语音转文字
|
||||
- `moderation` 审查
|
||||
|
||||
### ConfigurateMethod
|
||||
|
||||
- `predefined-model ` 预定义模型
|
||||
|
||||
表示用户只需要配置统一的供应商凭据即可使用供应商下的预定义模型。
|
||||
- `customizable-model` 自定义模型
|
||||
|
||||
用户需要新增每个模型的凭据配置。
|
||||
|
||||
- `fetch-from-remote` 从远程获取
|
||||
|
||||
与 `predefined-model` 配置方式一致,只需要配置统一的供应商凭据即可,模型通过凭据信息从供应商获取。
|
||||
|
||||
### ModelFeature
|
||||
|
||||
- `agent-thought` Agent 推理,一般超过 70B 有思维链能力。
|
||||
- `vision` 视觉,即:图像理解。
|
||||
|
||||
### FetchFrom
|
||||
|
||||
- `predefined-model` 预定义模型
|
||||
- `fetch-from-remote` 远程模型
|
||||
|
||||
### LLMMode
|
||||
|
||||
- `completion` 文本补全
|
||||
- `chat` 对话
|
||||
|
||||
### ParameterRule
|
||||
|
||||
- `name` (string) 调用模型实际参数名
|
||||
|
||||
- `use_template` (string) [optional] 使用模板
|
||||
|
||||
默认预置了 5 种变量内容配置模板:
|
||||
|
||||
- `temperature`
|
||||
- `top_p`
|
||||
- `frequency_penalty`
|
||||
- `presence_penalty`
|
||||
- `max_tokens`
|
||||
|
||||
可在 use_template 中直接设置模板变量名,将会使用 entities.defaults.PARAMETER_RULE_TEMPLATE 中的默认配置
|
||||
不用设置除 `name` 和 `use_template` 之外的所有参数,若设置了额外的配置参数,将覆盖默认配置。
|
||||
可参考 `openai/llm/gpt-3.5-turbo.yaml`。
|
||||
|
||||
- `label` (object) [optional] 标签,i18n
|
||||
|
||||
- `zh_Hans`(string) [optional] 中文标签名
|
||||
- `en_US` (string) 英文标签名
|
||||
|
||||
- `type`(string) [optional] 参数类型
|
||||
|
||||
- `int` 整数
|
||||
- `float` 浮点数
|
||||
- `string` 字符串
|
||||
- `boolean` 布尔型
|
||||
|
||||
- `help` (string) [optional] 帮助信息
|
||||
|
||||
- `zh_Hans` (string) [optional] 中文帮助信息
|
||||
- `en_US` (string) 英文帮助信息
|
||||
|
||||
- `required` (bool) 是否必填,默认 False。
|
||||
|
||||
- `default`(int/float/string/bool) [optional] 默认值
|
||||
|
||||
- `min`(int/float) [optional] 最小值,仅数字类型适用
|
||||
|
||||
- `max`(int/float) [optional] 最大值,仅数字类型适用
|
||||
|
||||
- `precision`(int) [optional] 精度,保留小数位数,仅数字类型适用
|
||||
|
||||
- `options` (array[string]) [optional] 下拉选项值,仅当 `type` 为 `string` 时适用,若不设置或为 null 则不限制选项值
|
||||
|
||||
### PriceConfig
|
||||
|
||||
- `input` (float) 输入单价,即 Prompt 单价
|
||||
- `output` (float) 输出单价,即返回内容单价
|
||||
- `unit` (float) 价格单位,如:每 100K 的单价为 `0.000001`
|
||||
- `currency` (string) 货币单位
|
||||
|
||||
### ProviderCredentialSchema
|
||||
|
||||
- `credential_form_schemas` (array[[CredentialFormSchema](#CredentialFormSchema)]) 凭据表单规范
|
||||
|
||||
### ModelCredentialSchema
|
||||
|
||||
- `model` (object) 模型标识,变量名默认 `model`
|
||||
- `label` (object) 模型表单项展示名称
|
||||
- `en_US` (string) 英文
|
||||
- `zh_Hans`(string) [optional] 中文
|
||||
- `placeholder` (object) 模型提示内容
|
||||
- `en_US`(string) 英文
|
||||
- `zh_Hans`(string) [optional] 中文
|
||||
- `credential_form_schemas` (array[[CredentialFormSchema](#CredentialFormSchema)]) 凭据表单规范
|
||||
|
||||
### CredentialFormSchema
|
||||
|
||||
- `variable` (string) 表单项变量名
|
||||
- `label` (object) 表单项标签名
|
||||
- `en_US`(string) 英文
|
||||
- `zh_Hans` (string) [optional] 中文
|
||||
- `type` ([FormType](#FormType)) 表单项类型
|
||||
- `required` (bool) 是否必填
|
||||
- `default`(string) 默认值
|
||||
- `options` (array[[FormOption](#FormOption)]) 表单项为 `select` 或 `radio` 专有属性,定义下拉内容
|
||||
- `placeholder`(object) 表单项为 `text-input `专有属性,表单项 PlaceHolder
|
||||
- `en_US`(string) 英文
|
||||
- `zh_Hans` (string) [optional] 中文
|
||||
- `max_length` (int) 表单项为`text-input`专有属性,定义输入最大长度,0 为不限制。
|
||||
- `show_on` (array[[FormShowOnObject](#FormShowOnObject)]) 当其他表单项值符合条件时显示,为空则始终显示。
|
||||
|
||||
### FormType
|
||||
|
||||
- `text-input` 文本输入组件
|
||||
- `secret-input` 密码输入组件
|
||||
- `select` 单选下拉
|
||||
- `radio` Radio 组件
|
||||
- `switch` 开关组件,仅支持 `true` 和 `false`
|
||||
|
||||
### FormOption
|
||||
|
||||
- `label` (object) 标签
|
||||
- `en_US`(string) 英文
|
||||
- `zh_Hans`(string) [optional] 中文
|
||||
- `value` (string) 下拉选项值
|
||||
- `show_on` (array[[FormShowOnObject](#FormShowOnObject)]) 当其他表单项值符合条件时显示,为空则始终显示。
|
||||
|
||||
### FormShowOnObject
|
||||
|
||||
- `variable` (string) 其他表单项变量名
|
||||
- `value` (string) 其他表单项变量值
|
||||
0
api/core/model_runtime/entities/__init__.py
Normal file
16
api/core/model_runtime/entities/common_entities.py
Normal file
@ -0,0 +1,16 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class I18nObject(BaseModel):
|
||||
"""
|
||||
Model class for i18n object.
|
||||
"""
|
||||
zh_Hans: Optional[str] = None
|
||||
en_US: str
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
if not self.zh_Hans:
|
||||
self.zh_Hans = self.en_US
|
||||
87
api/core/model_runtime/entities/defaults.py
Normal file
@ -0,0 +1,87 @@
|
||||
from typing import Dict
|
||||
|
||||
from core.model_runtime.entities.model_entities import DefaultParameterName
|
||||
|
||||
|
||||
PARAMETER_RULE_TEMPLATE: Dict[DefaultParameterName, dict] = {
|
||||
DefaultParameterName.TEMPERATURE: {
|
||||
'label': {
|
||||
'en_US': 'Temperature',
|
||||
'zh_Hans': '温度',
|
||||
},
|
||||
'type': 'float',
|
||||
'help': {
|
||||
'en_US': 'Controls randomness. Lower temperature results in less random completions. As the temperature approaches zero, the model will become deterministic and repetitive. Higher temperature results in more random completions.',
|
||||
'zh_Hans': '温度控制随机性。较低的温度会导致较少的随机完成。随着温度接近零,模型将变得确定性和重复性。较高的温度会导致更多的随机完成。',
|
||||
},
|
||||
'required': False,
|
||||
'default': 0.0,
|
||||
'min': 0.0,
|
||||
'max': 1.0,
|
||||
'precision': 1,
|
||||
},
|
||||
DefaultParameterName.TOP_P: {
|
||||
'label': {
|
||||
'en_US': 'Top P',
|
||||
'zh_Hans': 'Top P',
|
||||
},
|
||||
'type': 'float',
|
||||
'help': {
|
||||
'en_US': 'Controls diversity via nucleus sampling: 0.5 means half of all likelihood-weighted options are considered.',
|
||||
'zh_Hans': '通过核心采样控制多样性:0.5表示考虑了一半的所有可能性加权选项。',
|
||||
},
|
||||
'required': False,
|
||||
'default': 1.0,
|
||||
'min': 0.0,
|
||||
'max': 1.0,
|
||||
'precision': 1,
|
||||
},
|
||||
DefaultParameterName.PRESENCE_PENALTY: {
|
||||
'label': {
|
||||
'en_US': 'Presence Penalty',
|
||||
'zh_Hans': '存在惩罚',
|
||||
},
|
||||
'type': 'float',
|
||||
'help': {
|
||||
'en_US': 'Applies a penalty to the log-probability of tokens already in the text.',
|
||||
'zh_Hans': '对文本中已有的标记的对数概率施加惩罚。',
|
||||
},
|
||||
'required': False,
|
||||
'default': 0.0,
|
||||
'min': 0.0,
|
||||
'max': 1.0,
|
||||
'precision': 1,
|
||||
},
|
||||
DefaultParameterName.FREQUENCY_PENALTY: {
|
||||
'label': {
|
||||
'en_US': 'Frequency Penalty',
|
||||
'zh_Hans': '频率惩罚',
|
||||
},
|
||||
'type': 'float',
|
||||
'help': {
|
||||
'en_US': 'Applies a penalty to the log-probability of tokens that appear in the text.',
|
||||
'zh_Hans': '对文本中出现的标记的对数概率施加惩罚。',
|
||||
},
|
||||
'required': False,
|
||||
'default': 0.0,
|
||||
'min': 0.0,
|
||||
'max': 1.0,
|
||||
'precision': 1,
|
||||
},
|
||||
DefaultParameterName.MAX_TOKENS: {
|
||||
'label': {
|
||||
'en_US': 'Max Tokens',
|
||||
'zh_Hans': '最大标记',
|
||||
},
|
||||
'type': 'int',
|
||||
'help': {
|
||||
'en_US': 'The maximum number of tokens to generate. Requests can use up to 2048 tokens shared between prompt and completion.',
|
||||
'zh_Hans': '要生成的标记的最大数量。请求可以使用最多2048个标记,这些标记在提示和完成之间共享。',
|
||||
},
|
||||
'required': False,
|
||||
'default': 64,
|
||||
'min': 1,
|
||||
'max': 2048,
|
||||
'precision': 0,
|
||||
}
|
||||
}
|
||||
102
api/core/model_runtime/entities/llm_entities.py
Normal file
@ -0,0 +1,102 @@
|
||||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage
|
||||
from core.model_runtime.entities.model_entities import ModelUsage, PriceInfo
|
||||
|
||||
|
||||
class LLMMode(Enum):
|
||||
"""
|
||||
Enum class for large language model mode.
|
||||
"""
|
||||
COMPLETION = "completion"
|
||||
CHAT = "chat"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> 'LLMMode':
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
:param value: mode value
|
||||
:return: mode
|
||||
"""
|
||||
for mode in cls:
|
||||
if mode.value == value:
|
||||
return mode
|
||||
raise ValueError(f'invalid mode value {value}')
|
||||
|
||||
|
||||
class LLMUsage(ModelUsage):
|
||||
"""
|
||||
Model class for llm usage.
|
||||
"""
|
||||
prompt_tokens: int
|
||||
prompt_unit_price: Decimal
|
||||
prompt_price_unit: Decimal
|
||||
prompt_price: Decimal
|
||||
completion_tokens: int
|
||||
completion_unit_price: Decimal
|
||||
completion_price_unit: Decimal
|
||||
completion_price: Decimal
|
||||
total_tokens: int
|
||||
total_price: Decimal
|
||||
currency: str
|
||||
latency: float
|
||||
|
||||
@classmethod
|
||||
def empty_usage(cls):
|
||||
return cls(
|
||||
prompt_tokens=0,
|
||||
prompt_unit_price=Decimal('0.0'),
|
||||
prompt_price_unit=Decimal('0.0'),
|
||||
prompt_price=Decimal('0.0'),
|
||||
completion_tokens=0,
|
||||
completion_unit_price=Decimal('0.0'),
|
||||
completion_price_unit=Decimal('0.0'),
|
||||
completion_price=Decimal('0.0'),
|
||||
total_tokens=0,
|
||||
total_price=Decimal('0.0'),
|
||||
currency='USD',
|
||||
latency=0.0
|
||||
)
|
||||
|
||||
|
||||
class LLMResult(BaseModel):
|
||||
"""
|
||||
Model class for llm result.
|
||||
"""
|
||||
model: str
|
||||
prompt_messages: list[PromptMessage]
|
||||
message: AssistantPromptMessage
|
||||
usage: LLMUsage
|
||||
system_fingerprint: Optional[str] = None
|
||||
|
||||
|
||||
class LLMResultChunkDelta(BaseModel):
|
||||
"""
|
||||
Model class for llm result chunk delta.
|
||||
"""
|
||||
index: int
|
||||
message: AssistantPromptMessage
|
||||
usage: Optional[LLMUsage] = None
|
||||
finish_reason: Optional[str] = None
|
||||
|
||||
|
||||
class LLMResultChunk(BaseModel):
|
||||
"""
|
||||
Model class for llm result chunk.
|
||||
"""
|
||||
model: str
|
||||
prompt_messages: list[PromptMessage]
|
||||
system_fingerprint: Optional[str] = None
|
||||
delta: LLMResultChunkDelta
|
||||
|
||||
|
||||
class NumTokensResult(PriceInfo):
|
||||
"""
|
||||
Model class for number of tokens result.
|
||||
"""
|
||||
tokens: int
|
||||
134
api/core/model_runtime/entities/message_entities.py
Normal file
@ -0,0 +1,134 @@
|
||||
from abc import ABC
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class PromptMessageRole(Enum):
|
||||
"""
|
||||
Enum class for prompt message.
|
||||
"""
|
||||
SYSTEM = "system"
|
||||
USER = "user"
|
||||
ASSISTANT = "assistant"
|
||||
TOOL = "tool"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> 'PromptMessageRole':
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
:param value: mode value
|
||||
:return: mode
|
||||
"""
|
||||
for mode in cls:
|
||||
if mode.value == value:
|
||||
return mode
|
||||
raise ValueError(f'invalid prompt message type value {value}')
|
||||
|
||||
|
||||
class PromptMessageTool(BaseModel):
|
||||
"""
|
||||
Model class for prompt message tool.
|
||||
"""
|
||||
name: str
|
||||
description: str
|
||||
parameters: dict
|
||||
|
||||
|
||||
class PromptMessageFunction(BaseModel):
|
||||
"""
|
||||
Model class for prompt message function.
|
||||
"""
|
||||
type: str = 'function'
|
||||
function: PromptMessageTool
|
||||
|
||||
|
||||
class PromptMessageContentType(Enum):
|
||||
"""
|
||||
Enum class for prompt message content type.
|
||||
"""
|
||||
TEXT = 'text'
|
||||
IMAGE = 'image'
|
||||
|
||||
|
||||
class PromptMessageContent(BaseModel):
|
||||
"""
|
||||
Model class for prompt message content.
|
||||
"""
|
||||
type: PromptMessageContentType
|
||||
data: str
|
||||
|
||||
|
||||
class TextPromptMessageContent(PromptMessageContent):
|
||||
"""
|
||||
Model class for text prompt message content.
|
||||
"""
|
||||
type: PromptMessageContentType = PromptMessageContentType.TEXT
|
||||
|
||||
|
||||
class ImagePromptMessageContent(PromptMessageContent):
|
||||
"""
|
||||
Model class for image prompt message content.
|
||||
"""
|
||||
class DETAIL(Enum):
|
||||
LOW = 'low'
|
||||
HIGH = 'high'
|
||||
|
||||
type: PromptMessageContentType = PromptMessageContentType.IMAGE
|
||||
detail: DETAIL = DETAIL.LOW
|
||||
|
||||
|
||||
class PromptMessage(ABC, BaseModel):
|
||||
"""
|
||||
Model class for prompt message.
|
||||
"""
|
||||
role: PromptMessageRole
|
||||
content: Optional[str | list[PromptMessageContent]] = None
|
||||
name: Optional[str] = None
|
||||
|
||||
|
||||
class UserPromptMessage(PromptMessage):
|
||||
"""
|
||||
Model class for user prompt message.
|
||||
"""
|
||||
role: PromptMessageRole = PromptMessageRole.USER
|
||||
|
||||
|
||||
class AssistantPromptMessage(PromptMessage):
|
||||
"""
|
||||
Model class for assistant prompt message.
|
||||
"""
|
||||
class ToolCall(BaseModel):
|
||||
"""
|
||||
Model class for assistant prompt message tool call.
|
||||
"""
|
||||
class ToolCallFunction(BaseModel):
|
||||
"""
|
||||
Model class for assistant prompt message tool call function.
|
||||
"""
|
||||
name: str
|
||||
arguments: str
|
||||
|
||||
id: str
|
||||
type: str
|
||||
function: ToolCallFunction
|
||||
|
||||
role: PromptMessageRole = PromptMessageRole.ASSISTANT
|
||||
tool_calls: list[ToolCall] = []
|
||||
|
||||
|
||||
class SystemPromptMessage(PromptMessage):
|
||||
"""
|
||||
Model class for system prompt message.
|
||||
"""
|
||||
role: PromptMessageRole = PromptMessageRole.SYSTEM
|
||||
|
||||
|
||||
class ToolPromptMessage(PromptMessage):
|
||||
"""
|
||||
Model class for tool prompt message.
|
||||
"""
|
||||
role: PromptMessageRole = PromptMessageRole.TOOL
|
||||
tool_call_id: str
|
||||
196
api/core/model_runtime/entities/model_entities.py
Normal file
@ -0,0 +1,196 @@
|
||||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
|
||||
|
||||
class ModelType(Enum):
|
||||
"""
|
||||
Enum class for model type.
|
||||
"""
|
||||
LLM = "llm"
|
||||
TEXT_EMBEDDING = "text-embedding"
|
||||
RERANK = "rerank"
|
||||
SPEECH2TEXT = "speech2text"
|
||||
MODERATION = "moderation"
|
||||
# TTS = "tts"
|
||||
# TEXT2IMG = "text2img"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, origin_model_type: str) -> "ModelType":
|
||||
"""
|
||||
Get model type from origin model type.
|
||||
|
||||
:return: model type
|
||||
"""
|
||||
if origin_model_type == 'text-generation' or origin_model_type == cls.LLM.value:
|
||||
return cls.LLM
|
||||
elif origin_model_type == 'embeddings' or origin_model_type == cls.TEXT_EMBEDDING.value:
|
||||
return cls.TEXT_EMBEDDING
|
||||
elif origin_model_type == 'reranking' or origin_model_type == cls.RERANK.value:
|
||||
return cls.RERANK
|
||||
elif origin_model_type == cls.SPEECH2TEXT.value:
|
||||
return cls.SPEECH2TEXT
|
||||
elif origin_model_type == cls.MODERATION.value:
|
||||
return cls.MODERATION
|
||||
else:
|
||||
raise ValueError(f'invalid origin model type {origin_model_type}')
|
||||
|
||||
def to_origin_model_type(self) -> str:
|
||||
"""
|
||||
Get origin model type from model type.
|
||||
|
||||
:return: origin model type
|
||||
"""
|
||||
if self == self.LLM:
|
||||
return 'text-generation'
|
||||
elif self == self.TEXT_EMBEDDING:
|
||||
return 'embeddings'
|
||||
elif self == self.RERANK:
|
||||
return 'reranking'
|
||||
elif self == self.SPEECH2TEXT:
|
||||
return 'speech2text'
|
||||
elif self == self.MODERATION:
|
||||
return 'moderation'
|
||||
else:
|
||||
raise ValueError(f'invalid model type {self}')
|
||||
|
||||
|
||||
class FetchFrom(Enum):
|
||||
"""
|
||||
Enum class for fetch from.
|
||||
"""
|
||||
PREDEFINED_MODEL = "predefined-model"
|
||||
CUSTOMIZABLE_MODEL = "customizable-model"
|
||||
|
||||
|
||||
class ModelFeature(Enum):
|
||||
"""
|
||||
Enum class for llm feature.
|
||||
"""
|
||||
TOOL_CALL = "tool-call"
|
||||
MULTI_TOOL_CALL = "multi-tool-call"
|
||||
AGENT_THOUGHT = "agent-thought"
|
||||
VISION = "vision"
|
||||
|
||||
|
||||
class DefaultParameterName(Enum):
|
||||
"""
|
||||
Enum class for parameter template variable.
|
||||
"""
|
||||
TEMPERATURE = "temperature"
|
||||
TOP_P = "top_p"
|
||||
PRESENCE_PENALTY = "presence_penalty"
|
||||
FREQUENCY_PENALTY = "frequency_penalty"
|
||||
MAX_TOKENS = "max_tokens"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: Any) -> 'DefaultParameterName':
|
||||
"""
|
||||
Get parameter name from value.
|
||||
|
||||
:param value: parameter value
|
||||
:return: parameter name
|
||||
"""
|
||||
for name in cls:
|
||||
if name.value == value:
|
||||
return name
|
||||
raise ValueError(f'invalid parameter name {value}')
|
||||
|
||||
|
||||
class ParameterType(Enum):
|
||||
"""
|
||||
Enum class for parameter type.
|
||||
"""
|
||||
FLOAT = "float"
|
||||
INT = "int"
|
||||
STRING = "string"
|
||||
BOOLEAN = "boolean"
|
||||
|
||||
|
||||
class ModelPropertyKey(Enum):
|
||||
"""
|
||||
Enum class for model property key.
|
||||
"""
|
||||
MODE = "mode"
|
||||
CONTEXT_SIZE = "context_size"
|
||||
MAX_CHUNKS = "max_chunks"
|
||||
FILE_UPLOAD_LIMIT = "file_upload_limit"
|
||||
SUPPORTED_FILE_EXTENSIONS = "supported_file_extensions"
|
||||
MAX_CHARACTERS_PER_CHUNK = "max_characters_per_chunk"
|
||||
|
||||
|
||||
class ProviderModel(BaseModel):
|
||||
"""
|
||||
Model class for provider model.
|
||||
"""
|
||||
model: str
|
||||
label: I18nObject
|
||||
model_type: ModelType
|
||||
features: Optional[list[ModelFeature]] = None
|
||||
fetch_from: FetchFrom
|
||||
model_properties: dict[ModelPropertyKey, Any]
|
||||
deprecated: bool = False
|
||||
|
||||
class Config:
|
||||
protected_namespaces = ()
|
||||
|
||||
|
||||
class ParameterRule(BaseModel):
|
||||
"""
|
||||
Model class for parameter rule.
|
||||
"""
|
||||
name: str
|
||||
use_template: Optional[str] = None
|
||||
label: I18nObject
|
||||
type: ParameterType
|
||||
help: Optional[I18nObject] = None
|
||||
required: bool = False
|
||||
default: Optional[Any] = None
|
||||
min: Optional[float | int] = None
|
||||
max: Optional[float | int] = None
|
||||
precision: Optional[int] = None
|
||||
options: list[str] = []
|
||||
|
||||
|
||||
class PriceConfig(BaseModel):
|
||||
"""
|
||||
Model class for pricing info.
|
||||
"""
|
||||
input: Decimal
|
||||
output: Optional[Decimal] = None
|
||||
unit: Decimal
|
||||
currency: str
|
||||
|
||||
|
||||
class AIModelEntity(ProviderModel):
|
||||
"""
|
||||
Model class for AI model.
|
||||
"""
|
||||
parameter_rules: list[ParameterRule] = []
|
||||
pricing: Optional[PriceConfig] = None
|
||||
|
||||
|
||||
class ModelUsage(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
class PriceType(Enum):
|
||||
"""
|
||||
Enum class for price type.
|
||||
"""
|
||||
INPUT = "input"
|
||||
OUTPUT = "output"
|
||||
|
||||
|
||||
class PriceInfo(BaseModel):
|
||||
"""
|
||||
Model class for price info.
|
||||
"""
|
||||
unit_price: Decimal
|
||||
unit: Decimal
|
||||
total_amount: Decimal
|
||||
currency: str
|
||||
149
api/core/model_runtime/entities/provider_entities.py
Normal file
@ -0,0 +1,149 @@
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import ModelType, ProviderModel, AIModelEntity
|
||||
|
||||
|
||||
class ConfigurateMethod(Enum):
|
||||
"""
|
||||
Enum class for configurate method of provider model.
|
||||
"""
|
||||
PREDEFINED_MODEL = "predefined-model"
|
||||
CUSTOMIZABLE_MODEL = "customizable-model"
|
||||
|
||||
|
||||
class FormType(Enum):
|
||||
"""
|
||||
Enum class for form type.
|
||||
"""
|
||||
TEXT_INPUT = "text-input"
|
||||
SECRET_INPUT = "secret-input"
|
||||
SELECT = "select"
|
||||
RADIO = "radio"
|
||||
SWITCH = "switch"
|
||||
|
||||
|
||||
class FormShowOnObject(BaseModel):
|
||||
"""
|
||||
Model class for form show on.
|
||||
"""
|
||||
variable: str
|
||||
value: str
|
||||
|
||||
|
||||
class FormOption(BaseModel):
|
||||
"""
|
||||
Model class for form option.
|
||||
"""
|
||||
label: I18nObject
|
||||
value: str
|
||||
show_on: list[FormShowOnObject] = []
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
if not self.label:
|
||||
self.label = I18nObject(
|
||||
en_US=self.value
|
||||
)
|
||||
|
||||
|
||||
class CredentialFormSchema(BaseModel):
|
||||
"""
|
||||
Model class for credential form schema.
|
||||
"""
|
||||
variable: str
|
||||
label: I18nObject
|
||||
type: FormType
|
||||
required: bool = True
|
||||
default: Optional[str] = None
|
||||
options: Optional[list[FormOption]] = None
|
||||
placeholder: Optional[I18nObject] = None
|
||||
max_length: int = 0
|
||||
show_on: list[FormShowOnObject] = []
|
||||
|
||||
|
||||
class ProviderCredentialSchema(BaseModel):
|
||||
"""
|
||||
Model class for provider credential schema.
|
||||
"""
|
||||
credential_form_schemas: list[CredentialFormSchema]
|
||||
|
||||
|
||||
class FieldModelSchema(BaseModel):
|
||||
label: I18nObject
|
||||
placeholder: Optional[I18nObject] = None
|
||||
|
||||
|
||||
class ModelCredentialSchema(BaseModel):
|
||||
"""
|
||||
Model class for model credential schema.
|
||||
"""
|
||||
model: FieldModelSchema
|
||||
credential_form_schemas: list[CredentialFormSchema]
|
||||
|
||||
|
||||
class SimpleProviderEntity(BaseModel):
|
||||
"""
|
||||
Simple model class for provider.
|
||||
"""
|
||||
provider: str
|
||||
label: I18nObject
|
||||
icon_small: Optional[I18nObject] = None
|
||||
icon_large: Optional[I18nObject] = None
|
||||
supported_model_types: list[ModelType]
|
||||
models: list[AIModelEntity] = []
|
||||
|
||||
|
||||
class ProviderHelpEntity(BaseModel):
|
||||
"""
|
||||
Model class for provider help.
|
||||
"""
|
||||
title: I18nObject
|
||||
url: I18nObject
|
||||
|
||||
|
||||
class ProviderEntity(BaseModel):
|
||||
"""
|
||||
Model class for provider.
|
||||
"""
|
||||
provider: str
|
||||
label: I18nObject
|
||||
description: Optional[I18nObject] = None
|
||||
icon_small: Optional[I18nObject] = None
|
||||
icon_large: Optional[I18nObject] = None
|
||||
background: Optional[str] = None
|
||||
help: Optional[ProviderHelpEntity] = None
|
||||
supported_model_types: list[ModelType]
|
||||
configurate_methods: list[ConfigurateMethod]
|
||||
models: list[ProviderModel] = []
|
||||
provider_credential_schema: Optional[ProviderCredentialSchema] = None
|
||||
model_credential_schema: Optional[ModelCredentialSchema] = None
|
||||
|
||||
class Config:
|
||||
protected_namespaces = ()
|
||||
|
||||
def to_simple_provider(self) -> SimpleProviderEntity:
|
||||
"""
|
||||
Convert to simple provider.
|
||||
|
||||
:return: simple provider
|
||||
"""
|
||||
return SimpleProviderEntity(
|
||||
provider=self.provider,
|
||||
label=self.label,
|
||||
icon_small=self.icon_small,
|
||||
icon_large=self.icon_large,
|
||||
supported_model_types=self.supported_model_types,
|
||||
models=self.models
|
||||
)
|
||||
|
||||
|
||||
class ProviderConfig(BaseModel):
|
||||
"""
|
||||
Model class for provider config.
|
||||
"""
|
||||
provider: str
|
||||
credentials: dict
|
||||
18
api/core/model_runtime/entities/rerank_entities.py
Normal file
@ -0,0 +1,18 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class RerankDocument(BaseModel):
|
||||
"""
|
||||
Model class for rerank document.
|
||||
"""
|
||||
index: int
|
||||
text: str
|
||||
score: float
|
||||
|
||||
|
||||
class RerankResult(BaseModel):
|
||||
"""
|
||||
Model class for rerank result.
|
||||
"""
|
||||
model: str
|
||||
docs: list[RerankDocument]
|
||||
28
api/core/model_runtime/entities/text_embedding_entities.py
Normal file
@ -0,0 +1,28 @@
|
||||
from decimal import Decimal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelUsage
|
||||
|
||||
|
||||
class EmbeddingUsage(ModelUsage):
|
||||
"""
|
||||
Model class for embedding usage.
|
||||
"""
|
||||
tokens: int
|
||||
total_tokens: int
|
||||
unit_price: Decimal
|
||||
price_unit: Decimal
|
||||
total_price: Decimal
|
||||
currency: str
|
||||
latency: float
|
||||
|
||||
|
||||
class TextEmbeddingResult(BaseModel):
|
||||
"""
|
||||
Model class for text embedding result.
|
||||
"""
|
||||
model: str
|
||||
embeddings: list[list[float]]
|
||||
usage: EmbeddingUsage
|
||||
|
||||
0
api/core/model_runtime/errors/__init__.py
Normal file
34
api/core/model_runtime/errors/invoke.py
Normal file
@ -0,0 +1,34 @@
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class InvokeError(Exception):
|
||||
"""Base class for all LLM exceptions."""
|
||||
description: Optional[str] = None
|
||||
|
||||
def __init__(self, description: Optional[str] = None) -> None:
|
||||
self.description = description
|
||||
|
||||
|
||||
class InvokeConnectionError(InvokeError):
|
||||
"""Raised when the Invoke returns connection error."""
|
||||
description = "Connection Error"
|
||||
|
||||
|
||||
class InvokeServerUnavailableError(InvokeError):
|
||||
"""Raised when the Invoke returns server unavailable error."""
|
||||
description = "Server Unavailable Error"
|
||||
|
||||
|
||||
class InvokeRateLimitError(InvokeError):
|
||||
"""Raised when the Invoke returns rate limit error."""
|
||||
description = "Rate Limit Error"
|
||||
|
||||
|
||||
class InvokeAuthorizationError(InvokeError):
|
||||
"""Raised when the Invoke returns authorization error."""
|
||||
description = "Incorrect model credentials provided, please check and try again. "
|
||||
|
||||
|
||||
class InvokeBadRequestError(InvokeError):
|
||||
"""Raised when the Invoke returns bad request."""
|
||||
description = "Bad Request Error"
|
||||
5
api/core/model_runtime/errors/validate.py
Normal file
@ -0,0 +1,5 @@
|
||||
class CredentialsValidateFailedError(Exception):
|
||||
"""
|
||||
Credentials validate failed error
|
||||
"""
|
||||
pass
|
||||
328
api/core/model_runtime/model_providers/__base/ai_model.py
Normal file
@ -0,0 +1,328 @@
|
||||
import decimal
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
import yaml
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
|
||||
from core.model_runtime.entities.model_entities import PriceInfo, AIModelEntity, PriceType, PriceConfig, \
|
||||
DefaultParameterName, FetchFrom, ModelType
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.errors.invoke import InvokeError, InvokeAuthorizationError
|
||||
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
|
||||
|
||||
|
||||
class AIModel(ABC):
|
||||
"""
|
||||
Base class for all models.
|
||||
"""
|
||||
model_type: ModelType
|
||||
model_schemas: list[AIModelEntity] = None
|
||||
started_at: float = 0
|
||||
|
||||
@abstractmethod
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
The value is the error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def _transform_invoke_error(self, error: Exception) -> InvokeError:
|
||||
"""
|
||||
Transform invoke error to unified error
|
||||
|
||||
:param error: model invoke error
|
||||
:return: unified error
|
||||
"""
|
||||
for invoke_error, model_errors in self._invoke_error_mapping.items():
|
||||
if isinstance(error, tuple(model_errors)):
|
||||
if invoke_error == InvokeAuthorizationError:
|
||||
return invoke_error(description="Incorrect model credentials provided, please check and try again. ")
|
||||
|
||||
return invoke_error(description=f"{invoke_error.description}: {str(error)}")
|
||||
|
||||
return InvokeError(description=f"Error: {str(error)}")
|
||||
|
||||
def get_price(self, model: str, credentials: dict, price_type: PriceType, tokens: int) -> PriceInfo:
|
||||
"""
|
||||
Get price for given model and tokens
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param price_type: price type
|
||||
:param tokens: number of tokens
|
||||
:return: price info
|
||||
"""
|
||||
# get model schema
|
||||
model_schema = self.get_model_schema(model, credentials)
|
||||
|
||||
# get price info from predefined model schema
|
||||
price_config: Optional[PriceConfig] = None
|
||||
if model_schema:
|
||||
price_config: PriceConfig = model_schema.pricing
|
||||
|
||||
# get unit price
|
||||
unit_price = None
|
||||
if price_config:
|
||||
if price_type == PriceType.INPUT:
|
||||
unit_price = price_config.input
|
||||
elif price_type == PriceType.OUTPUT and price_config.output is not None:
|
||||
unit_price = price_config.output
|
||||
|
||||
if unit_price is None:
|
||||
return PriceInfo(
|
||||
unit_price=decimal.Decimal('0.0'),
|
||||
unit=decimal.Decimal('0.0'),
|
||||
total_amount=decimal.Decimal('0.0'),
|
||||
currency="USD",
|
||||
)
|
||||
|
||||
# calculate total amount
|
||||
total_amount = tokens * unit_price * price_config.unit
|
||||
total_amount = total_amount.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
|
||||
|
||||
return PriceInfo(
|
||||
unit_price=unit_price,
|
||||
unit=price_config.unit,
|
||||
total_amount=total_amount,
|
||||
currency=price_config.currency,
|
||||
)
|
||||
|
||||
def predefined_models(self) -> list[AIModelEntity]:
|
||||
"""
|
||||
Get all predefined models for given provider.
|
||||
|
||||
:return:
|
||||
"""
|
||||
if self.model_schemas:
|
||||
return self.model_schemas
|
||||
|
||||
model_schemas = []
|
||||
|
||||
# get module name
|
||||
model_type = self.__class__.__module__.split('.')[-1]
|
||||
|
||||
# get provider name
|
||||
provider_name = self.__class__.__module__.split('.')[-3]
|
||||
|
||||
# get the path of current classes
|
||||
current_path = os.path.abspath(__file__)
|
||||
# get parent path of the current path
|
||||
provider_model_type_path = os.path.join(os.path.dirname(os.path.dirname(current_path)), provider_name, model_type)
|
||||
|
||||
# get all yaml files path under provider_model_type_path that do not start with __
|
||||
model_schema_yaml_paths = [
|
||||
os.path.join(provider_model_type_path, model_schema_yaml)
|
||||
for model_schema_yaml in os.listdir(provider_model_type_path)
|
||||
if not model_schema_yaml.startswith('__')
|
||||
and not model_schema_yaml.startswith('_')
|
||||
and os.path.isfile(os.path.join(provider_model_type_path, model_schema_yaml))
|
||||
and model_schema_yaml.endswith('.yaml')
|
||||
]
|
||||
|
||||
# get _position.yaml file path
|
||||
position_file_path = os.path.join(provider_model_type_path, '_position.yaml')
|
||||
|
||||
# read _position.yaml file
|
||||
position_map = {}
|
||||
if os.path.exists(position_file_path):
|
||||
with open(position_file_path, 'r') as f:
|
||||
position_map = yaml.safe_load(f)
|
||||
|
||||
# traverse all model_schema_yaml_paths
|
||||
for model_schema_yaml_path in model_schema_yaml_paths:
|
||||
# read yaml data from yaml file
|
||||
with open(model_schema_yaml_path, 'r') as f:
|
||||
yaml_data = yaml.safe_load(f)
|
||||
|
||||
new_parameter_rules = []
|
||||
for parameter_rule in yaml_data.get('parameter_rules', []):
|
||||
if 'use_template' in parameter_rule:
|
||||
try:
|
||||
default_parameter_name = DefaultParameterName.value_of(parameter_rule['use_template'])
|
||||
default_parameter_rule = self._get_default_parameter_rule_variable_map(default_parameter_name)
|
||||
copy_default_parameter_rule = default_parameter_rule.copy()
|
||||
copy_default_parameter_rule.update(parameter_rule)
|
||||
parameter_rule = copy_default_parameter_rule
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if 'label' not in parameter_rule:
|
||||
parameter_rule['label'] = {
|
||||
'zh_Hans': parameter_rule['name'],
|
||||
'en_US': parameter_rule['name']
|
||||
}
|
||||
|
||||
new_parameter_rules.append(parameter_rule)
|
||||
|
||||
yaml_data['parameter_rules'] = new_parameter_rules
|
||||
|
||||
if 'label' not in yaml_data:
|
||||
yaml_data['label'] = {
|
||||
'zh_Hans': yaml_data['model'],
|
||||
'en_US': yaml_data['model']
|
||||
}
|
||||
|
||||
yaml_data['fetch_from'] = FetchFrom.PREDEFINED_MODEL.value
|
||||
|
||||
try:
|
||||
# yaml_data to entity
|
||||
model_schema = AIModelEntity(**yaml_data)
|
||||
except Exception as e:
|
||||
model_schema_yaml_file_name = os.path.basename(model_schema_yaml_path).rstrip(".yaml")
|
||||
raise Exception(f'Invalid model schema for {provider_name}.{model_type}.{model_schema_yaml_file_name}:'
|
||||
f' {str(e)}')
|
||||
|
||||
# cache model schema
|
||||
model_schemas.append(model_schema)
|
||||
|
||||
# resort model schemas by position
|
||||
if position_map:
|
||||
model_schemas.sort(key=lambda x: position_map.get(x.model, 999))
|
||||
|
||||
# cache model schemas
|
||||
self.model_schemas = model_schemas
|
||||
|
||||
return model_schemas
|
||||
|
||||
def get_model_schema(self, model: str, credentials: Optional[dict] = None) -> Optional[AIModelEntity]:
|
||||
"""
|
||||
Get model schema by model name and credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return: model schema
|
||||
"""
|
||||
# get predefined models (predefined_models)
|
||||
models = self.predefined_models()
|
||||
|
||||
model_map = {model.model: model for model in models}
|
||||
if model in model_map:
|
||||
return model_map[model]
|
||||
|
||||
if credentials:
|
||||
model_schema = self.get_customizable_model_schema_from_credentials(model, credentials)
|
||||
if model_schema:
|
||||
return model_schema
|
||||
|
||||
return None
|
||||
|
||||
def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
"""
|
||||
Get customizable model schema from credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return: model schema
|
||||
"""
|
||||
if 'schema' in credentials:
|
||||
schema_dict = json.loads(credentials['schema'])
|
||||
|
||||
try:
|
||||
model_instance = AIModelEntity.parse_obj(schema_dict)
|
||||
return model_instance
|
||||
except ValidationError as e:
|
||||
logging.exception(f"Invalid model schema for {model}")
|
||||
return self._get_customizable_model_schema(model, credentials)
|
||||
|
||||
return self._get_customizable_model_schema(model, credentials)
|
||||
|
||||
def _get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
"""
|
||||
Get customizable model schema and fill in the template
|
||||
"""
|
||||
schema = self.get_customizable_model_schema(model, credentials)
|
||||
|
||||
if not schema:
|
||||
return None
|
||||
|
||||
# fill in the template
|
||||
new_parameter_rules = []
|
||||
for parameter_rule in schema.parameter_rules:
|
||||
if parameter_rule.use_template:
|
||||
try:
|
||||
default_parameter_name = DefaultParameterName.value_of(parameter_rule.use_template)
|
||||
default_parameter_rule = self._get_default_parameter_rule_variable_map(default_parameter_name)
|
||||
if not parameter_rule.max:
|
||||
parameter_rule.max = default_parameter_rule['max']
|
||||
if not parameter_rule.min:
|
||||
parameter_rule.min = default_parameter_rule['min']
|
||||
if not parameter_rule.precision:
|
||||
parameter_rule.default = default_parameter_rule['default']
|
||||
if not parameter_rule.precision:
|
||||
parameter_rule.precision = default_parameter_rule['precision']
|
||||
if not parameter_rule.required:
|
||||
parameter_rule.required = default_parameter_rule['required']
|
||||
if not parameter_rule.help:
|
||||
parameter_rule.help = I18nObject(
|
||||
en_US=default_parameter_rule['help']['en_US'],
|
||||
)
|
||||
if not parameter_rule.help.en_US:
|
||||
parameter_rule.help.en_US = default_parameter_rule['help']['en_US']
|
||||
if not parameter_rule.help.zh_Hans:
|
||||
parameter_rule.help.zh_Hans = default_parameter_rule['help'].get('zh_Hans', default_parameter_rule['help']['en_US'])
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
new_parameter_rules.append(parameter_rule)
|
||||
|
||||
schema.parameter_rules = new_parameter_rules
|
||||
|
||||
return schema
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
"""
|
||||
Get customizable model schema
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return: model schema
|
||||
"""
|
||||
return None
|
||||
|
||||
def _get_default_parameter_rule_variable_map(self, name: DefaultParameterName) -> dict:
|
||||
"""
|
||||
Get default parameter rule for given name
|
||||
|
||||
:param name: parameter name
|
||||
:return: parameter rule
|
||||
"""
|
||||
default_parameter_rule = PARAMETER_RULE_TEMPLATE.get(name)
|
||||
|
||||
if not default_parameter_rule:
|
||||
raise Exception(f'Invalid model parameter rule name {name}')
|
||||
|
||||
return default_parameter_rule
|
||||
|
||||
def _get_num_tokens_by_gpt2(self, text: str) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages by gpt2
|
||||
Some provider models do not provide an interface for obtaining the number of tokens.
|
||||
Here, the gpt2 tokenizer is used to calculate the number of tokens.
|
||||
This method can be executed offline, and the gpt2 tokenizer has been cached in the project.
|
||||
|
||||
:param text: plain text of prompt. You need to convert the original message to plain text
|
||||
:return: number of tokens
|
||||
"""
|
||||
return GPT2Tokenizer.get_num_tokens(text)
|
||||
BIN
api/core/model_runtime/model_providers/__base/audio.mp3
Normal file
@ -0,0 +1,557 @@
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from typing import Optional, Generator, Union, List
|
||||
|
||||
from core.model_runtime.callbacks.base_callback import Callback
|
||||
from core.model_runtime.callbacks.logging_callback import LoggingCallback
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, AssistantPromptMessage
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey, PriceType, ParameterType, ParameterRule, \
|
||||
ModelType
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMMode, LLMUsage, \
|
||||
LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LargeLanguageModel(AIModel):
|
||||
"""
|
||||
Model class for large language model.
|
||||
"""
|
||||
model_type: ModelType = ModelType.LLM
|
||||
|
||||
def invoke(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
:param callbacks: callbacks
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
# validate and filter model parameters
|
||||
if model_parameters is None:
|
||||
model_parameters = {}
|
||||
|
||||
model_parameters = self._validate_and_filter_model_parameters(model, model_parameters, credentials)
|
||||
|
||||
self.started_at = time.perf_counter()
|
||||
|
||||
callbacks = callbacks or []
|
||||
|
||||
if bool(os.environ.get("DEBUG")):
|
||||
callbacks.append(LoggingCallback())
|
||||
|
||||
# trigger before invoke callbacks
|
||||
self._trigger_before_invoke_callbacks(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
callbacks=callbacks
|
||||
)
|
||||
|
||||
try:
|
||||
result = self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
|
||||
except Exception as e:
|
||||
self._trigger_invoke_error_callbacks(
|
||||
model=model,
|
||||
ex=e,
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
callbacks=callbacks
|
||||
)
|
||||
|
||||
raise self._transform_invoke_error(e)
|
||||
|
||||
if stream and isinstance(result, Generator):
|
||||
return self._invoke_result_generator(
|
||||
model=model,
|
||||
result=result,
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
callbacks=callbacks
|
||||
)
|
||||
else:
|
||||
self._trigger_after_invoke_callbacks(
|
||||
model=model,
|
||||
result=result,
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
callbacks=callbacks
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _invoke_result_generator(self, model: str, result: Generator, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[List[str]] = None, stream: bool = True,
|
||||
user: Optional[str] = None, callbacks: list[Callback] = None) -> Generator:
|
||||
"""
|
||||
Invoke result generator
|
||||
|
||||
:param result: result generator
|
||||
:return: result generator
|
||||
"""
|
||||
prompt_message = AssistantPromptMessage(
|
||||
content=""
|
||||
)
|
||||
usage = None
|
||||
system_fingerprint = None
|
||||
real_model = model
|
||||
|
||||
for chunk in result:
|
||||
try:
|
||||
yield chunk
|
||||
|
||||
self._trigger_new_chunk_callbacks(
|
||||
chunk=chunk,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
callbacks=callbacks
|
||||
)
|
||||
|
||||
prompt_message.content += chunk.delta.message.content
|
||||
real_model = chunk.model
|
||||
if chunk.delta.usage:
|
||||
usage = chunk.delta.usage
|
||||
|
||||
if chunk.system_fingerprint:
|
||||
system_fingerprint = chunk.system_fingerprint
|
||||
except Exception as e:
|
||||
raise self._transform_invoke_error(e)
|
||||
|
||||
self._trigger_after_invoke_callbacks(
|
||||
model=model,
|
||||
result=LLMResult(
|
||||
model=real_model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=prompt_message,
|
||||
usage=usage,
|
||||
system_fingerprint=system_fingerprint
|
||||
),
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
callbacks=callbacks
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param tools: tools for tool calling
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def _llm_result_to_stream(self, result: LLMResult) -> Generator:
|
||||
"""
|
||||
Transform llm result to stream
|
||||
|
||||
:param result: llm result
|
||||
:return: stream
|
||||
"""
|
||||
index = 0
|
||||
|
||||
tool_calls = result.message.tool_calls
|
||||
|
||||
for word in result.message.content:
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=word,
|
||||
tool_calls=tool_calls if index == (len(result.message.content) - 1) else []
|
||||
)
|
||||
|
||||
yield LLMResultChunk(
|
||||
model=result.model,
|
||||
prompt_messages=result.prompt_messages,
|
||||
system_fingerprint=result.system_fingerprint,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=index,
|
||||
message=assistant_prompt_message,
|
||||
)
|
||||
)
|
||||
|
||||
index += 1
|
||||
time.sleep(0.01)
|
||||
|
||||
def get_parameter_rules(self, model: str, credentials: dict) -> list[ParameterRule]:
|
||||
"""
|
||||
Get parameter rules
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return: parameter rules
|
||||
"""
|
||||
model_schema = self.get_model_schema(model, credentials)
|
||||
if model_schema:
|
||||
return model_schema.parameter_rules
|
||||
|
||||
return []
|
||||
|
||||
def get_model_mode(self, model: str, credentials: Optional[dict] = None) -> LLMMode:
|
||||
"""
|
||||
Get model mode
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return: model mode
|
||||
"""
|
||||
model_schema = self.get_model_schema(model, credentials)
|
||||
|
||||
mode = LLMMode.CHAT
|
||||
if model_schema and model_schema.model_properties.get(ModelPropertyKey.MODE):
|
||||
mode = LLMMode.value_of(model_schema.model_properties[ModelPropertyKey.MODE])
|
||||
|
||||
return mode
|
||||
|
||||
def _calc_response_usage(self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int) -> LLMUsage:
|
||||
"""
|
||||
Calculate response usage
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_tokens: prompt tokens
|
||||
:param completion_tokens: completion tokens
|
||||
:return: usage
|
||||
"""
|
||||
# get prompt price info
|
||||
prompt_price_info = self.get_price(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
price_type=PriceType.INPUT,
|
||||
tokens=prompt_tokens,
|
||||
)
|
||||
|
||||
# get completion price info
|
||||
completion_price_info = self.get_price(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
price_type=PriceType.OUTPUT,
|
||||
tokens=completion_tokens
|
||||
)
|
||||
|
||||
# transform usage
|
||||
usage = LLMUsage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
prompt_unit_price=prompt_price_info.unit_price,
|
||||
prompt_price_unit=prompt_price_info.unit,
|
||||
prompt_price=prompt_price_info.total_amount,
|
||||
completion_tokens=completion_tokens,
|
||||
completion_unit_price=completion_price_info.unit_price,
|
||||
completion_price_unit=completion_price_info.unit,
|
||||
completion_price=completion_price_info.total_amount,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
total_price=prompt_price_info.total_amount + completion_price_info.total_amount,
|
||||
currency=prompt_price_info.currency,
|
||||
latency=time.perf_counter() - self.started_at
|
||||
)
|
||||
|
||||
return usage
|
||||
|
||||
def _trigger_before_invoke_callbacks(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[List[str]] = None, stream: bool = True,
|
||||
user: Optional[str] = None, callbacks: list[Callback] = None) -> None:
|
||||
"""
|
||||
Trigger before invoke callbacks
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
:param callbacks: callbacks
|
||||
"""
|
||||
if callbacks:
|
||||
for callback in callbacks:
|
||||
try:
|
||||
callback.on_before_invoke(
|
||||
llm_instance=self,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user
|
||||
)
|
||||
except Exception as e:
|
||||
if callback.raise_error:
|
||||
raise e
|
||||
else:
|
||||
logger.warning(f"Callback {callback.__class__.__name__} on_before_invoke failed with error {e}")
|
||||
|
||||
def _trigger_new_chunk_callbacks(self, chunk: LLMResultChunk, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[List[str]] = None, stream: bool = True,
|
||||
user: Optional[str] = None, callbacks: list[Callback] = None) -> None:
|
||||
"""
|
||||
Trigger new chunk callbacks
|
||||
|
||||
:param chunk: chunk
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
"""
|
||||
if callbacks:
|
||||
for callback in callbacks:
|
||||
try:
|
||||
callback.on_new_chunk(
|
||||
llm_instance=self,
|
||||
chunk=chunk,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user
|
||||
)
|
||||
except Exception as e:
|
||||
if callback.raise_error:
|
||||
raise e
|
||||
else:
|
||||
logger.warning(f"Callback {callback.__class__.__name__} on_new_chunk failed with error {e}")
|
||||
|
||||
def _trigger_after_invoke_callbacks(self, model: str, result: LLMResult, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[List[str]] = None, stream: bool = True,
|
||||
user: Optional[str] = None, callbacks: list[Callback] = None) -> None:
|
||||
"""
|
||||
Trigger after invoke callbacks
|
||||
|
||||
:param model: model name
|
||||
:param result: result
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
:param callbacks: callbacks
|
||||
"""
|
||||
if callbacks:
|
||||
for callback in callbacks:
|
||||
try:
|
||||
callback.on_after_invoke(
|
||||
llm_instance=self,
|
||||
result=result,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user
|
||||
)
|
||||
except Exception as e:
|
||||
if callback.raise_error:
|
||||
raise e
|
||||
else:
|
||||
logger.warning(f"Callback {callback.__class__.__name__} on_after_invoke failed with error {e}")
|
||||
|
||||
def _trigger_invoke_error_callbacks(self, model: str, ex: Exception, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[List[str]] = None, stream: bool = True,
|
||||
user: Optional[str] = None, callbacks: list[Callback] = None) -> None:
|
||||
"""
|
||||
Trigger invoke error callbacks
|
||||
|
||||
:param model: model name
|
||||
:param ex: exception
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
:param callbacks: callbacks
|
||||
"""
|
||||
if callbacks:
|
||||
for callback in callbacks:
|
||||
try:
|
||||
callback.on_invoke_error(
|
||||
llm_instance=self,
|
||||
ex=ex,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user
|
||||
)
|
||||
except Exception as e:
|
||||
if callback.raise_error:
|
||||
raise e
|
||||
else:
|
||||
logger.warning(f"Callback {callback.__class__.__name__} on_invoke_error failed with error {e}")
|
||||
|
||||
def _validate_and_filter_model_parameters(self, model: str, model_parameters: dict, credentials: dict) -> dict:
|
||||
"""
|
||||
Validate model parameters
|
||||
|
||||
:param model: model name
|
||||
:param model_parameters: model parameters
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
parameter_rules = self.get_parameter_rules(model, credentials)
|
||||
|
||||
# validate model parameters
|
||||
filtered_model_parameters = {}
|
||||
for parameter_rule in parameter_rules:
|
||||
parameter_name = parameter_rule.name
|
||||
parameter_value = model_parameters.get(parameter_name)
|
||||
if parameter_value is None:
|
||||
if parameter_rule.use_template and parameter_rule.use_template in model_parameters:
|
||||
# if parameter value is None, use template value variable name instead
|
||||
parameter_value = model_parameters[parameter_rule.use_template]
|
||||
else:
|
||||
if parameter_rule.required:
|
||||
if parameter_rule.default is not None:
|
||||
filtered_model_parameters[parameter_name] = parameter_rule.default
|
||||
continue
|
||||
else:
|
||||
raise ValueError(f"Model Parameter {parameter_name} is required.")
|
||||
else:
|
||||
continue
|
||||
|
||||
# validate parameter value type
|
||||
if parameter_rule.type == ParameterType.INT:
|
||||
if not isinstance(parameter_value, int):
|
||||
raise ValueError(f"Model Parameter {parameter_name} should be int.")
|
||||
|
||||
# validate parameter value range
|
||||
if parameter_rule.min is not None and parameter_value < parameter_rule.min:
|
||||
raise ValueError(
|
||||
f"Model Parameter {parameter_name} should be greater than or equal to {parameter_rule.min}.")
|
||||
|
||||
if parameter_rule.max is not None and parameter_value > parameter_rule.max:
|
||||
raise ValueError(
|
||||
f"Model Parameter {parameter_name} should be less than or equal to {parameter_rule.max}.")
|
||||
elif parameter_rule.type == ParameterType.FLOAT:
|
||||
if not isinstance(parameter_value, (float, int)):
|
||||
raise ValueError(f"Model Parameter {parameter_name} should be float.")
|
||||
|
||||
# validate parameter value precision
|
||||
if parameter_rule.precision is not None:
|
||||
if parameter_rule.precision == 0:
|
||||
if parameter_value != int(parameter_value):
|
||||
raise ValueError(f"Model Parameter {parameter_name} should be int.")
|
||||
else:
|
||||
if parameter_value != round(parameter_value, parameter_rule.precision):
|
||||
raise ValueError(
|
||||
f"Model Parameter {parameter_name} should be round to {parameter_rule.precision} decimal places.")
|
||||
|
||||
# validate parameter value range
|
||||
if parameter_rule.min is not None and parameter_value < parameter_rule.min:
|
||||
raise ValueError(
|
||||
f"Model Parameter {parameter_name} should be greater than or equal to {parameter_rule.min}.")
|
||||
|
||||
if parameter_rule.max is not None and parameter_value > parameter_rule.max:
|
||||
raise ValueError(
|
||||
f"Model Parameter {parameter_name} should be less than or equal to {parameter_rule.max}.")
|
||||
elif parameter_rule.type == ParameterType.BOOLEAN:
|
||||
if not isinstance(parameter_value, bool):
|
||||
raise ValueError(f"Model Parameter {parameter_name} should be bool.")
|
||||
elif parameter_rule.type == ParameterType.STRING:
|
||||
if not isinstance(parameter_value, str):
|
||||
raise ValueError(f"Model Parameter {parameter_name} should be string.")
|
||||
|
||||
# validate options
|
||||
if parameter_rule.options and parameter_value not in parameter_rule.options:
|
||||
raise ValueError(f"Model Parameter {parameter_name} should be one of {parameter_rule.options}.")
|
||||
else:
|
||||
raise ValueError(f"Model Parameter {parameter_name} type {parameter_rule.type} is not supported.")
|
||||
|
||||
filtered_model_parameters[parameter_name] = parameter_value
|
||||
|
||||
return filtered_model_parameters
|
||||
125
api/core/model_runtime/model_providers/__base/model_provider.py
Normal file
@ -0,0 +1,125 @@
|
||||
import importlib
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Dict
|
||||
|
||||
import yaml
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelType, AIModelEntity
|
||||
from core.model_runtime.entities.provider_entities import ProviderEntity
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
|
||||
|
||||
class ModelProvider(ABC):
|
||||
provider_schema: ProviderEntity = None
|
||||
model_instance_map: Dict[str, AIModel] = {}
|
||||
|
||||
@abstractmethod
|
||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||
"""
|
||||
Validate provider credentials
|
||||
You can choose any validate_credentials method of model type or implement validate method by yourself,
|
||||
such as: get model list api
|
||||
|
||||
if validate failed, raise exception
|
||||
|
||||
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_provider_schema(self) -> ProviderEntity:
|
||||
"""
|
||||
Get provider schema
|
||||
|
||||
:return: provider schema
|
||||
"""
|
||||
if self.provider_schema:
|
||||
return self.provider_schema
|
||||
|
||||
# get dirname of the current path
|
||||
provider_name = self.__class__.__module__.split('.')[-1]
|
||||
|
||||
# get the path of the model_provider classes
|
||||
base_path = os.path.abspath(__file__)
|
||||
current_path = os.path.join(os.path.dirname(os.path.dirname(base_path)), provider_name)
|
||||
|
||||
# read provider schema from yaml file
|
||||
yaml_path = os.path.join(current_path, f'{provider_name}.yaml')
|
||||
yaml_data = {}
|
||||
if os.path.exists(yaml_path):
|
||||
with open(yaml_path, 'r') as f:
|
||||
yaml_data = yaml.safe_load(f)
|
||||
|
||||
try:
|
||||
# yaml_data to entity
|
||||
provider_schema = ProviderEntity(**yaml_data)
|
||||
except Exception as e:
|
||||
raise Exception(f'Invalid provider schema for {provider_name}: {str(e)}')
|
||||
|
||||
# cache schema
|
||||
self.provider_schema = provider_schema
|
||||
|
||||
return provider_schema
|
||||
|
||||
def models(self, model_type: ModelType) -> list[AIModelEntity]:
|
||||
"""
|
||||
Get all models for given model type
|
||||
|
||||
:param model_type: model type defined in `ModelType`
|
||||
:return: list of models
|
||||
"""
|
||||
provider_schema = self.get_provider_schema()
|
||||
if model_type not in provider_schema.supported_model_types:
|
||||
return []
|
||||
|
||||
# get model instance of the model type
|
||||
model_instance = self.get_model_instance(model_type)
|
||||
|
||||
# get predefined models (predefined_models)
|
||||
models = model_instance.predefined_models()
|
||||
|
||||
# return models
|
||||
return models
|
||||
|
||||
def get_model_instance(self, model_type: ModelType) -> AIModel:
|
||||
"""
|
||||
Get model instance
|
||||
|
||||
:param model_type: model type defined in `ModelType`
|
||||
:return:
|
||||
"""
|
||||
# get dirname of the current path
|
||||
provider_name = self.__class__.__module__.split('.')[-1]
|
||||
|
||||
if f"{provider_name}.{model_type.value}" in self.model_instance_map:
|
||||
return self.model_instance_map[f"{provider_name}.{model_type.value}"]
|
||||
|
||||
# get the path of the model type classes
|
||||
base_path = os.path.abspath(__file__)
|
||||
model_type_name = model_type.value.replace('-', '_')
|
||||
model_type_path = os.path.join(os.path.dirname(os.path.dirname(base_path)), provider_name, model_type_name)
|
||||
model_type_py_path = os.path.join(model_type_path, f'{model_type_name}.py')
|
||||
|
||||
if not os.path.isdir(model_type_path) or not os.path.exists(model_type_py_path):
|
||||
raise Exception(f'Invalid model type {model_type} for provider {provider_name}')
|
||||
|
||||
# Dynamic loading {model_type_name}.py file and find the subclass of AIModel
|
||||
parent_module = '.'.join(self.__class__.__module__.split('.')[:-1])
|
||||
spec = importlib.util.spec_from_file_location(f"{parent_module}.{model_type_name}.{model_type_name}", model_type_py_path)
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(mod)
|
||||
|
||||
model_class = None
|
||||
for name, obj in vars(mod).items():
|
||||
if (isinstance(obj, type) and issubclass(obj, AIModel) and not obj.__abstractmethods__
|
||||
and obj != AIModel):
|
||||
model_class = obj
|
||||
break
|
||||
|
||||
if not model_class:
|
||||
raise Exception(f'Missing AIModel Class for model type {model_type} in {model_type_py_path}')
|
||||
|
||||
model_instance_map = model_class()
|
||||
self.model_instance_map[f"{provider_name}.{model_type.value}"] = model_instance_map
|
||||
|
||||
return model_instance_map
|
||||
@ -0,0 +1,48 @@
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
|
||||
|
||||
class ModerationModel(AIModel):
|
||||
"""
|
||||
Model class for moderation model.
|
||||
"""
|
||||
model_type: ModelType = ModelType.MODERATION
|
||||
|
||||
def invoke(self, model: str, credentials: dict,
|
||||
text: str, user: Optional[str] = None) \
|
||||
-> bool:
|
||||
"""
|
||||
Invoke moderation model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param text: text to moderate
|
||||
:param user: unique user id
|
||||
:return: false if text is safe, true otherwise
|
||||
"""
|
||||
self.started_at = time.perf_counter()
|
||||
|
||||
try:
|
||||
return self._invoke(model, credentials, text, user)
|
||||
except Exception as e:
|
||||
raise self._transform_invoke_error(e)
|
||||
|
||||
@abstractmethod
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
text: str, user: Optional[str] = None) \
|
||||
-> bool:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param text: text to moderate
|
||||
:param user: unique user id
|
||||
:return: false if text is safe, true otherwise
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@ -0,0 +1,56 @@
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.entities.rerank_entities import RerankResult
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
|
||||
|
||||
class RerankModel(AIModel):
|
||||
"""
|
||||
Base Model class for rerank model.
|
||||
"""
|
||||
model_type: ModelType = ModelType.RERANK
|
||||
|
||||
def invoke(self, model: str, credentials: dict,
|
||||
query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
|
||||
user: Optional[str] = None) \
|
||||
-> RerankResult:
|
||||
"""
|
||||
Invoke rerank model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param query: search query
|
||||
:param docs: docs for reranking
|
||||
:param score_threshold: score threshold
|
||||
:param top_n: top n
|
||||
:param user: unique user id
|
||||
:return: rerank result
|
||||
"""
|
||||
self.started_at = time.perf_counter()
|
||||
|
||||
try:
|
||||
return self._invoke(model, credentials, query, docs, score_threshold, top_n, user)
|
||||
except Exception as e:
|
||||
raise self._transform_invoke_error(e)
|
||||
|
||||
@abstractmethod
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
|
||||
user: Optional[str] = None) \
|
||||
-> RerankResult:
|
||||
"""
|
||||
Invoke rerank model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param query: search query
|
||||
:param docs: docs for reranking
|
||||
:param score_threshold: score threshold
|
||||
:param top_n: top n
|
||||
:param user: unique user id
|
||||
:return: rerank result
|
||||
"""
|
||||
raise NotImplementedError
|
||||
@ -0,0 +1,57 @@
|
||||
import os
|
||||
from abc import abstractmethod
|
||||
from typing import Optional, IO
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
|
||||
|
||||
class Speech2TextModel(AIModel):
|
||||
"""
|
||||
Model class for speech2text model.
|
||||
"""
|
||||
model_type: ModelType = ModelType.SPEECH2TEXT
|
||||
|
||||
def invoke(self, model: str, credentials: dict,
|
||||
file: IO[bytes], user: Optional[str] = None) \
|
||||
-> str:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param file: audio file
|
||||
:param user: unique user id
|
||||
:return: text for given audio file
|
||||
"""
|
||||
try:
|
||||
return self._invoke(model, credentials, file, user)
|
||||
except Exception as e:
|
||||
raise self._transform_invoke_error(e)
|
||||
|
||||
@abstractmethod
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
file: IO[bytes], user: Optional[str] = None) \
|
||||
-> str:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param file: audio file
|
||||
:param user: unique user id
|
||||
:return: text for given audio file
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_demo_file_path(self) -> str:
|
||||
"""
|
||||
Get demo file for given model
|
||||
|
||||
:return: demo file
|
||||
"""
|
||||
# Get the directory of the current file
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
# Construct the path to the audio file
|
||||
return os.path.join(current_dir, 'audio.mp3')
|
||||
@ -0,0 +1,90 @@
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
|
||||
|
||||
class TextEmbeddingModel(AIModel):
|
||||
"""
|
||||
Model class for text embedding model.
|
||||
"""
|
||||
model_type: ModelType = ModelType.TEXT_EMBEDDING
|
||||
|
||||
def invoke(self, model: str, credentials: dict,
|
||||
texts: list[str], user: Optional[str] = None) \
|
||||
-> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:param user: unique user id
|
||||
:return: embeddings result
|
||||
"""
|
||||
self.started_at = time.perf_counter()
|
||||
|
||||
try:
|
||||
return self._invoke(model, credentials, texts, user)
|
||||
except Exception as e:
|
||||
raise self._transform_invoke_error(e)
|
||||
|
||||
@abstractmethod
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
texts: list[str], user: Optional[str] = None) \
|
||||
-> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:param user: unique user id
|
||||
:return: embeddings result
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_context_size(self, model: str, credentials: dict) -> int:
|
||||
"""
|
||||
Get context size for given embedding model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return: context size
|
||||
"""
|
||||
model_schema = self.get_model_schema(model, credentials)
|
||||
|
||||
if model_schema and ModelPropertyKey.CONTEXT_SIZE in model_schema.model_properties:
|
||||
return model_schema.model_properties[ModelPropertyKey.CONTEXT_SIZE]
|
||||
|
||||
return 1000
|
||||
|
||||
def _get_max_chunks(self, model: str, credentials: dict) -> int:
|
||||
"""
|
||||
Get max chunks for given embedding model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return: max chunks
|
||||
"""
|
||||
model_schema = self.get_model_schema(model, credentials)
|
||||
|
||||
if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties:
|
||||
return model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
|
||||
|
||||
return 1
|
||||
50001
api/core/model_runtime/model_providers/__base/tokenizers/gpt2/merges.txt
Normal file
@ -0,0 +1,23 @@
|
||||
{
|
||||
"bos_token": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"eos_token": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"unk_token": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,33 @@
|
||||
{
|
||||
"add_bos_token": false,
|
||||
"add_prefix_space": false,
|
||||
"bos_token": {
|
||||
"__type": "AddedToken",
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"clean_up_tokenization_spaces": true,
|
||||
"eos_token": {
|
||||
"__type": "AddedToken",
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"errors": "replace",
|
||||
"model_max_length": 1024,
|
||||
"pad_token": null,
|
||||
"tokenizer_class": "GPT2Tokenizer",
|
||||
"unk_token": {
|
||||
"__type": "AddedToken",
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
}
|
||||
}
|
||||
50259
api/core/model_runtime/model_providers/__base/tokenizers/gpt2/vocab.json
Normal file
@ -0,0 +1,32 @@
|
||||
from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer
|
||||
from os.path import join, abspath, dirname
|
||||
from typing import Any
|
||||
from threading import Lock
|
||||
|
||||
_tokenizer = None
|
||||
_lock = Lock()
|
||||
|
||||
class GPT2Tokenizer:
|
||||
@staticmethod
|
||||
def _get_num_tokens_by_gpt2(text: str) -> int:
|
||||
"""
|
||||
use gpt2 tokenizer to get num tokens
|
||||
"""
|
||||
_tokenizer = GPT2Tokenizer.get_encoder()
|
||||
tokens = _tokenizer.encode(text, verbose=False)
|
||||
return len(tokens)
|
||||
|
||||
@staticmethod
|
||||
def get_num_tokens(text: str) -> int:
|
||||
return GPT2Tokenizer._get_num_tokens_by_gpt2(text)
|
||||
|
||||
@staticmethod
|
||||
def get_encoder() -> Any:
|
||||
global _tokenizer, _lock
|
||||
with _lock:
|
||||
if _tokenizer is None:
|
||||
base_path = abspath(__file__)
|
||||
gpt2_tokenizer_path = join(dirname(base_path), 'gpt2')
|
||||
_tokenizer = TransformerGPT2Tokenizer.from_pretrained(gpt2_tokenizer_path)
|
||||
|
||||
return _tokenizer
|
||||
3
api/core/model_runtime/model_providers/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
|
||||
model_provider_factory = ModelProviderFactory()
|
||||
19
api/core/model_runtime/model_providers/_position.yaml
Normal file
@ -0,0 +1,19 @@
|
||||
openai: 0
|
||||
anthropic: 1
|
||||
azure_openai: 2
|
||||
google: 3
|
||||
replicate: 4
|
||||
huggingface_hub: 5
|
||||
cohere: 6
|
||||
zhipuai: 7
|
||||
baichuan: 8
|
||||
spark: 9
|
||||
minimax: 10
|
||||
tongyi: 11
|
||||
wenxin: 12
|
||||
jina: 13
|
||||
chatglm: 14
|
||||
xinference: 15
|
||||
openllm: 16
|
||||
localai: 17
|
||||
openai_api_compatible: 18
|
||||
@ -0,0 +1,78 @@
|
||||
<svg width="90" height="20" viewBox="0 0 90 20" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<g clip-path="url(#clip0_8587_60274)">
|
||||
<mask id="mask0_8587_60274" style="mask-type:luminance" maskUnits="userSpaceOnUse" x="0" y="4" width="90" height="11">
|
||||
<path d="M89.375 4.99805H0V14.998H89.375V4.99805Z" fill="white"/>
|
||||
</mask>
|
||||
<g mask="url(#mask0_8587_60274)">
|
||||
<mask id="mask1_8587_60274" style="mask-type:luminance" maskUnits="userSpaceOnUse" x="0" y="4" width="90" height="11">
|
||||
<path d="M0 4.99609H89.375V14.9961H0V4.99609Z" fill="white"/>
|
||||
</mask>
|
||||
<g mask="url(#mask1_8587_60274)">
|
||||
<mask id="mask2_8587_60274" style="mask-type:luminance" maskUnits="userSpaceOnUse" x="0" y="4" width="90" height="11">
|
||||
<path d="M0 4.99414H89.375V14.9941H0V4.99414Z" fill="white"/>
|
||||
</mask>
|
||||
<g mask="url(#mask2_8587_60274)">
|
||||
<mask id="mask3_8587_60274" style="mask-type:luminance" maskUnits="userSpaceOnUse" x="0" y="4" width="90" height="11">
|
||||
<path d="M0 4.99219H89.375V14.9922H0V4.99219Z" fill="white"/>
|
||||
</mask>
|
||||
<g mask="url(#mask3_8587_60274)">
|
||||
<path d="M18.1273 11.9244L13.7773 5.15625H11.4297V14.825H13.4321V8.05688L17.7821 14.825H20.1297V5.15625H18.1273V11.9244Z" fill="black" fill-opacity="0.92"/>
|
||||
</g>
|
||||
<mask id="mask4_8587_60274" style="mask-type:luminance" maskUnits="userSpaceOnUse" x="0" y="4" width="90" height="11">
|
||||
<path d="M0 4.99219H89.375V14.9922H0V4.99219Z" fill="white"/>
|
||||
</mask>
|
||||
<g mask="url(#mask4_8587_60274)">
|
||||
<path d="M21.7969 7.02094H25.0423V14.825H27.1139V7.02094H30.3594V5.15625H21.7969V7.02094Z" fill="black" fill-opacity="0.92"/>
|
||||
</g>
|
||||
<mask id="mask5_8587_60274" style="mask-type:luminance" maskUnits="userSpaceOnUse" x="0" y="4" width="90" height="11">
|
||||
<path d="M0 4.99219H89.375V14.9922H0V4.99219Z" fill="white"/>
|
||||
</mask>
|
||||
<g mask="url(#mask5_8587_60274)">
|
||||
<path d="M38.6442 9.00994H34.0871V5.15625H32.0156V14.825H34.0871V10.8746H38.6442V14.825H40.7156V5.15625H38.6442V9.00994Z" fill="black" fill-opacity="0.92"/>
|
||||
</g>
|
||||
<mask id="mask6_8587_60274" style="mask-type:luminance" maskUnits="userSpaceOnUse" x="0" y="4" width="90" height="11">
|
||||
<path d="M0 4.99219H89.375V14.9922H0V4.99219Z" fill="white"/>
|
||||
</mask>
|
||||
<g mask="url(#mask6_8587_60274)">
|
||||
<path d="M45.3376 7.02094H47.893C48.9152 7.02094 49.4539 7.39387 49.4539 8.09831C49.4539 8.80275 48.9152 9.17569 47.893 9.17569H45.3376V7.02094ZM51.5259 8.09831C51.5259 6.27506 50.186 5.15625 47.9897 5.15625H43.2656V14.825H45.3376V11.0404H47.6443L49.7164 14.825H52.0094L49.715 10.7521C50.8666 10.3094 51.5259 9.37721 51.5259 8.09831Z" fill="black" fill-opacity="0.92"/>
|
||||
</g>
|
||||
<mask id="mask7_8587_60274" style="mask-type:luminance" maskUnits="userSpaceOnUse" x="0" y="4" width="90" height="11">
|
||||
<path d="M0 4.99219H89.375V14.9922H0V4.99219Z" fill="white"/>
|
||||
</mask>
|
||||
<g mask="url(#mask7_8587_60274)">
|
||||
<path d="M57.8732 13.0565C56.2438 13.0565 55.2496 11.8963 55.2496 10.004C55.2496 8.08416 56.2438 6.92394 57.8732 6.92394C59.4887 6.92394 60.4691 8.08416 60.4691 10.004C60.4691 11.8963 59.4887 13.0565 57.8732 13.0565ZM57.8732 4.99023C55.0839 4.99023 53.1094 7.06206 53.1094 10.004C53.1094 12.9184 55.0839 14.9902 57.8732 14.9902C60.6486 14.9902 62.6094 12.9184 62.6094 10.004C62.6094 7.06206 60.6486 4.99023 57.8732 4.99023Z" fill="black" fill-opacity="0.92"/>
|
||||
</g>
|
||||
<mask id="mask8_8587_60274" style="mask-type:luminance" maskUnits="userSpaceOnUse" x="0" y="4" width="90" height="11">
|
||||
<path d="M0 4.99219H89.375V14.9922H0V4.99219Z" fill="white"/>
|
||||
</mask>
|
||||
<g mask="url(#mask8_8587_60274)">
|
||||
<path d="M69.1794 9.45194H66.6233V7.02094H69.1794C70.2019 7.02094 70.7407 7.43532 70.7407 8.23644C70.7407 9.03756 70.2019 9.45194 69.1794 9.45194ZM69.2762 5.15625H64.5508V14.825H66.6233V11.3166H69.2762C71.473 11.3166 72.8133 10.1564 72.8133 8.23644C72.8133 6.3165 71.473 5.15625 69.2762 5.15625Z" fill="black" fill-opacity="0.92"/>
|
||||
</g>
|
||||
<mask id="mask9_8587_60274" style="mask-type:luminance" maskUnits="userSpaceOnUse" x="0" y="4" width="90" height="11">
|
||||
<path d="M0 4.99219H89.375V14.9922H0V4.99219Z" fill="white"/>
|
||||
</mask>
|
||||
<g mask="url(#mask9_8587_60274)">
|
||||
<path d="M86.8413 11.5786C86.4823 12.5179 85.7642 13.0565 84.7837 13.0565C83.1542 13.0565 82.16 11.8963 82.16 10.004C82.16 8.08416 83.1542 6.92394 84.7837 6.92394C85.7642 6.92394 86.4823 7.46261 86.8413 8.40183H89.0369C88.4984 6.33002 86.8827 4.99023 84.7837 4.99023C81.9942 4.99023 80.0195 7.06206 80.0195 10.004C80.0195 12.9184 81.9942 14.9902 84.7837 14.9902C86.8965 14.9902 88.5122 13.6366 89.0508 11.5786H86.8413Z" fill="black" fill-opacity="0.92"/>
|
||||
</g>
|
||||
<mask id="mask10_8587_60274" style="mask-type:luminance" maskUnits="userSpaceOnUse" x="0" y="4" width="90" height="11">
|
||||
<path d="M0 4.99219H89.375V14.9922H0V4.99219Z" fill="white"/>
|
||||
</mask>
|
||||
<g mask="url(#mask10_8587_60274)">
|
||||
<path d="M73.6484 5.15625L77.5033 14.825H79.6172L75.7624 5.15625H73.6484Z" fill="black" fill-opacity="0.92"/>
|
||||
</g>
|
||||
<mask id="mask11_8587_60274" style="mask-type:luminance" maskUnits="userSpaceOnUse" x="0" y="4" width="90" height="11">
|
||||
<path d="M0 4.99219H89.375V14.9922H0V4.99219Z" fill="white"/>
|
||||
</mask>
|
||||
<g mask="url(#mask11_8587_60274)">
|
||||
<path d="M3.64038 10.9989L4.95938 7.60106L6.27838 10.9989H3.64038ZM3.85422 5.15625L0 14.825H2.15505L2.9433 12.7946H6.97558L7.76371 14.825H9.91875L6.06453 5.15625H3.85422Z" fill="black" fill-opacity="0.92"/>
|
||||
</g>
|
||||
</g>
|
||||
</g>
|
||||
</g>
|
||||
</g>
|
||||
<defs>
|
||||
<clipPath id="clip0_8587_60274">
|
||||
<rect width="89.375" height="10" fill="white" transform="translate(0 5)"/>
|
||||
</clipPath>
|
||||
</defs>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 5.3 KiB |
@ -0,0 +1,4 @@
|
||||
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<rect width="24" height="24" rx="6" fill="#CA9F7B"/>
|
||||
<path d="M15.3843 6.43481H12.9687L17.3739 17.5652H19.7896L15.3843 6.43481ZM8.40522 6.43481L4 17.5652H6.4633L7.36417 15.2279H11.9729L12.8737 17.5652H15.337L10.9318 6.43481H8.40522ZM8.16104 13.1607L9.66852 9.24907L11.176 13.1607H8.16104Z" fill="#191918"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 410 B |
@ -0,0 +1,31 @@
|
||||
import logging
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AnthropicProvider(ModelProvider):
|
||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||
"""
|
||||
Validate provider credentials
|
||||
|
||||
if validate failed, raise exception
|
||||
|
||||
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
|
||||
"""
|
||||
try:
|
||||
model_instance = self.get_model_instance(ModelType.LLM)
|
||||
|
||||
# Use `claude-instant-1` model for validate,
|
||||
model_instance.validate_credentials(
|
||||
model='claude-instant-1',
|
||||
credentials=credentials
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
|
||||
raise ex
|
||||
@ -0,0 +1,39 @@
|
||||
provider: anthropic
|
||||
label:
|
||||
en_US: Anthropic
|
||||
description:
|
||||
en_US: Anthropic’s powerful models, such as Claude 2 and Claude Instant.
|
||||
zh_Hans: Anthropic 的强大模型,例如 Claude 2 和 Claude Instant。
|
||||
icon_small:
|
||||
en_US: icon_s_en.svg
|
||||
icon_large:
|
||||
en_US: icon_l_en.svg
|
||||
background: "#F0F0EB"
|
||||
help:
|
||||
title:
|
||||
en_US: Get your API Key from Anthropic
|
||||
zh_Hans: 从 Anthropic 获取 API Key
|
||||
url:
|
||||
en_US: https://console.anthropic.com/account/keys
|
||||
supported_model_types:
|
||||
- llm
|
||||
configurate_methods:
|
||||
- predefined-model
|
||||
provider_credential_schema:
|
||||
credential_form_schemas:
|
||||
- variable: anthropic_api_key
|
||||
label:
|
||||
en_US: API Key
|
||||
type: secret-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的 API Key
|
||||
en_US: Enter your API Key
|
||||
- variable: anthropic_api_url
|
||||
label:
|
||||
en_US: API URL
|
||||
type: text-input
|
||||
required: false
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的 API URL
|
||||
en_US: Enter your API URL
|
||||
@ -0,0 +1,34 @@
|
||||
model: claude-2.1
|
||||
label:
|
||||
en_US: claude-2.1
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 200000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: max_tokens_to_sample
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 4096
|
||||
min: 1
|
||||
max: 4096
|
||||
pricing:
|
||||
input: '8.00'
|
||||
output: '24.00'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
@ -0,0 +1,34 @@
|
||||
model: claude-2
|
||||
label:
|
||||
en_US: claude-2
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 100000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: max_tokens_to_sample
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 4096
|
||||
min: 1
|
||||
max: 4096
|
||||
pricing:
|
||||
input: '8.00'
|
||||
output: '24.00'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
@ -0,0 +1,33 @@
|
||||
model: claude-instant-1
|
||||
label:
|
||||
en_US: claude-instant-1
|
||||
model_type: llm
|
||||
features: []
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 100000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: max_tokens_to_sample
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 4096
|
||||
min: 1
|
||||
max: 4096
|
||||
pricing:
|
||||
input: '1.63'
|
||||
output: '5.51'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
298
api/core/model_runtime/model_providers/anthropic/llm/llm.py
Normal file
@ -0,0 +1,298 @@
|
||||
from typing import Optional, Generator, Union, List
|
||||
|
||||
import anthropic
|
||||
from anthropic import Anthropic, Stream
|
||||
from anthropic.types import completion_create_params, Completion
|
||||
from httpx import Timeout
|
||||
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, UserPromptMessage, AssistantPromptMessage, \
|
||||
SystemPromptMessage
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, \
|
||||
LLMResultChunkDelta
|
||||
from core.model_runtime.errors.invoke import InvokeConnectionError, InvokeServerUnavailableError, InvokeRateLimitError, \
|
||||
InvokeAuthorizationError, InvokeBadRequestError, InvokeError
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
|
||||
|
||||
class AnthropicLargeLanguageModel(LargeLanguageModel):
|
||||
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
# invoke model
|
||||
return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param tools: tools for tool calling
|
||||
:return:
|
||||
"""
|
||||
prompt = self._convert_messages_to_prompt_anthropic(prompt_messages)
|
||||
|
||||
client = Anthropic(api_key="")
|
||||
return client.count_tokens(prompt)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
self._generate(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=[
|
||||
UserPromptMessage(content="ping"),
|
||||
],
|
||||
model_parameters={
|
||||
"temperature": 0,
|
||||
"max_tokens_to_sample": 20,
|
||||
},
|
||||
stream=False
|
||||
)
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
def _generate(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
stop: Optional[List[str]] = None, stream: bool = True,
|
||||
user: Optional[str] = None) -> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: credentials kwargs
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
# transform credentials to kwargs for model instance
|
||||
credentials_kwargs = self._to_credential_kwargs(credentials)
|
||||
|
||||
client = Anthropic(**credentials_kwargs)
|
||||
|
||||
extra_model_kwargs = {}
|
||||
if stop:
|
||||
extra_model_kwargs['stop_sequences'] = stop
|
||||
|
||||
if user:
|
||||
extra_model_kwargs['metadata'] = completion_create_params.Metadata(user_id=user)
|
||||
|
||||
response = client.completions.create(
|
||||
model=model,
|
||||
prompt=self._convert_messages_to_prompt_anthropic(prompt_messages),
|
||||
stream=stream,
|
||||
**model_parameters,
|
||||
**extra_model_kwargs
|
||||
)
|
||||
|
||||
if stream:
|
||||
return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
|
||||
|
||||
return self._handle_generate_response(model, credentials, response, prompt_messages)
|
||||
|
||||
def _handle_generate_response(self, model: str, credentials: dict, response: Completion,
|
||||
prompt_messages: list[PromptMessage]) -> LLMResult:
|
||||
"""
|
||||
Handle llm response
|
||||
|
||||
:param model: model name
|
||||
:param credentials: credentials
|
||||
:param response: response
|
||||
:param prompt_messages: prompt messages
|
||||
:return: llm response
|
||||
"""
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=response.completion
|
||||
)
|
||||
|
||||
# calculate num tokens
|
||||
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
|
||||
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
|
||||
|
||||
# transform usage
|
||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
||||
|
||||
# transform response
|
||||
result = LLMResult(
|
||||
model=response.model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=assistant_prompt_message,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _handle_generate_stream_response(self, model: str, credentials: dict, response: Stream[Completion],
|
||||
prompt_messages: list[PromptMessage]) -> Generator:
|
||||
"""
|
||||
Handle llm stream response
|
||||
|
||||
:param model: model name
|
||||
:param credentials: credentials
|
||||
:param response: response
|
||||
:param prompt_messages: prompt messages
|
||||
:return: llm response chunk generator result
|
||||
"""
|
||||
index = -1
|
||||
for chunk in response:
|
||||
content = chunk.completion
|
||||
if chunk.stop_reason is None and (content is None or content == ''):
|
||||
continue
|
||||
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=content if content else '',
|
||||
)
|
||||
|
||||
index += 1
|
||||
|
||||
if chunk.stop_reason is not None:
|
||||
# calculate num tokens
|
||||
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
|
||||
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
|
||||
|
||||
# transform usage
|
||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
||||
|
||||
yield LLMResultChunk(
|
||||
model=chunk.model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=index,
|
||||
message=assistant_prompt_message,
|
||||
finish_reason=chunk.stop_reason,
|
||||
usage=usage
|
||||
)
|
||||
)
|
||||
else:
|
||||
yield LLMResultChunk(
|
||||
model=chunk.model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=index,
|
||||
message=assistant_prompt_message
|
||||
)
|
||||
)
|
||||
|
||||
def _to_credential_kwargs(self, credentials: dict) -> dict:
|
||||
"""
|
||||
Transform credentials to kwargs for model instance
|
||||
|
||||
:param credentials:
|
||||
:return:
|
||||
"""
|
||||
credentials_kwargs = {
|
||||
"api_key": credentials['anthropic_api_key'],
|
||||
"timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
|
||||
"max_retries": 1,
|
||||
}
|
||||
|
||||
if 'anthropic_api_url' in credentials and credentials['anthropic_api_url']:
|
||||
credentials['anthropic_api_url'] = credentials['anthropic_api_url'].rstrip('/')
|
||||
credentials_kwargs['base_url'] = credentials['anthropic_api_url']
|
||||
|
||||
return credentials_kwargs
|
||||
|
||||
def _convert_one_message_to_text(self, message: PromptMessage) -> str:
|
||||
"""
|
||||
Convert a single message to a string.
|
||||
|
||||
:param message: PromptMessage to convert.
|
||||
:return: String representation of the message.
|
||||
"""
|
||||
human_prompt = "\n\nHuman:"
|
||||
ai_prompt = "\n\nAssistant:"
|
||||
content = message.content
|
||||
|
||||
if isinstance(message, UserPromptMessage):
|
||||
message_text = f"{human_prompt} {content}"
|
||||
elif isinstance(message, AssistantPromptMessage):
|
||||
message_text = f"{ai_prompt} {content}"
|
||||
elif isinstance(message, SystemPromptMessage):
|
||||
message_text = content
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
|
||||
return message_text
|
||||
|
||||
def _convert_messages_to_prompt_anthropic(self, messages: List[PromptMessage]) -> str:
|
||||
"""
|
||||
Format a list of messages into a full prompt for the Anthropic model
|
||||
|
||||
:param messages: List of PromptMessage to combine.
|
||||
:return: Combined string with necessary human_prompt and ai_prompt tags.
|
||||
"""
|
||||
messages = messages.copy() # don't mutate the original list
|
||||
if not isinstance(messages[-1], AssistantPromptMessage):
|
||||
messages.append(AssistantPromptMessage(content=""))
|
||||
|
||||
text = "".join(
|
||||
self._convert_one_message_to_text(message)
|
||||
for message in messages
|
||||
)
|
||||
|
||||
# trim off the trailing ' ' that might come from the "Assistant: "
|
||||
return text.rstrip()
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
The value is the error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
return {
|
||||
InvokeConnectionError: [
|
||||
anthropic.APIConnectionError,
|
||||
anthropic.APITimeoutError
|
||||
],
|
||||
InvokeServerUnavailableError: [
|
||||
anthropic.InternalServerError
|
||||
],
|
||||
InvokeRateLimitError: [
|
||||
anthropic.RateLimitError
|
||||
],
|
||||
InvokeAuthorizationError: [
|
||||
anthropic.AuthenticationError,
|
||||
anthropic.PermissionDeniedError
|
||||
],
|
||||
InvokeBadRequestError: [
|
||||
anthropic.BadRequestError,
|
||||
anthropic.NotFoundError,
|
||||
anthropic.UnprocessableEntityError,
|
||||
anthropic.APIError
|
||||
]
|
||||
}
|
||||
|
After Width: | Height: | Size: 4.9 KiB |
@ -0,0 +1,8 @@
|
||||
<svg width="21" height="22" viewBox="0 0 21 22" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<g id="Microsfot">
|
||||
<rect id="Rectangle 1010" y="0.5" width="10" height="10" fill="#EF4F21"/>
|
||||
<rect id="Rectangle 1012" y="11.5" width="10" height="10" fill="#03A4EE"/>
|
||||
<rect id="Rectangle 1011" x="11" y="0.5" width="10" height="10" fill="#7EB903"/>
|
||||
<rect id="Rectangle 1013" x="11" y="11.5" width="10" height="10" fill="#FBB604"/>
|
||||
</g>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 439 B |
@ -0,0 +1,46 @@
|
||||
import openai
|
||||
from httpx import Timeout
|
||||
|
||||
from core.model_runtime.model_providers.azure_openai._constant import AZURE_OPENAI_API_VERSION
|
||||
|
||||
from core.model_runtime.errors.invoke import InvokeConnectionError, InvokeServerUnavailableError, InvokeRateLimitError, \
|
||||
InvokeAuthorizationError, InvokeBadRequestError, InvokeError
|
||||
|
||||
|
||||
class _CommonAzureOpenAI:
|
||||
@staticmethod
|
||||
def _to_credential_kwargs(credentials: dict) -> dict:
|
||||
credentials_kwargs = {
|
||||
"api_key": credentials['openai_api_key'],
|
||||
"azure_endpoint": credentials['openai_api_base'],
|
||||
"api_version": AZURE_OPENAI_API_VERSION,
|
||||
"timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
|
||||
"max_retries": 1,
|
||||
}
|
||||
|
||||
return credentials_kwargs
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
return {
|
||||
InvokeConnectionError: [
|
||||
openai.APIConnectionError,
|
||||
openai.APITimeoutError
|
||||
],
|
||||
InvokeServerUnavailableError: [
|
||||
openai.InternalServerError
|
||||
],
|
||||
InvokeRateLimitError: [
|
||||
openai.RateLimitError
|
||||
],
|
||||
InvokeAuthorizationError: [
|
||||
openai.AuthenticationError,
|
||||
openai.PermissionDeniedError
|
||||
],
|
||||
InvokeBadRequestError: [
|
||||
openai.BadRequestError,
|
||||
openai.NotFoundError,
|
||||
openai.UnprocessableEntityError,
|
||||
openai.APIError
|
||||
]
|
||||
}
|
||||
475
api/core/model_runtime/model_providers/azure_openai/_constant.py
Normal file
@ -0,0 +1,475 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMMode
|
||||
from core.model_runtime.entities.model_entities import ModelFeature, ModelType, FetchFrom, ParameterRule, \
|
||||
DefaultParameterName, PriceConfig
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, I18nObject
|
||||
from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
|
||||
|
||||
AZURE_OPENAI_API_VERSION = '2023-12-01-preview'
|
||||
|
||||
|
||||
def _get_max_tokens(default: int, min_val: int, max_val: int) -> ParameterRule:
|
||||
rule = ParameterRule(
|
||||
name='max_tokens',
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.MAX_TOKENS],
|
||||
)
|
||||
rule.default = default
|
||||
rule.min = min_val
|
||||
rule.max = max_val
|
||||
return rule
|
||||
|
||||
|
||||
class AzureBaseModel(BaseModel):
|
||||
base_model_name: str
|
||||
entity: AIModelEntity
|
||||
|
||||
|
||||
LLM_BASE_MODELS = [
|
||||
AzureBaseModel(
|
||||
base_model_name='gpt-35-turbo',
|
||||
entity=AIModelEntity(
|
||||
model='fake-deployment-name',
|
||||
label=I18nObject(
|
||||
en_US='fake-deployment-name-label',
|
||||
),
|
||||
model_type=ModelType.LLM,
|
||||
features=[
|
||||
ModelFeature.AGENT_THOUGHT,
|
||||
ModelFeature.MULTI_TOOL_CALL,
|
||||
],
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_properties={
|
||||
'mode': LLMMode.CHAT.value,
|
||||
'context_size': 4096,
|
||||
},
|
||||
parameter_rules=[
|
||||
ParameterRule(
|
||||
name='temperature',
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
|
||||
),
|
||||
ParameterRule(
|
||||
name='top_p',
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
|
||||
),
|
||||
ParameterRule(
|
||||
name='presence_penalty',
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY],
|
||||
),
|
||||
ParameterRule(
|
||||
name='frequency_penalty',
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
|
||||
),
|
||||
_get_max_tokens(default=512, min_val=1, max_val=4096)
|
||||
],
|
||||
pricing=PriceConfig(
|
||||
input=0.001,
|
||||
output=0.002,
|
||||
unit=0.001,
|
||||
currency='USD',
|
||||
)
|
||||
)
|
||||
),
|
||||
AzureBaseModel(
|
||||
base_model_name='gpt-35-turbo-16k',
|
||||
entity=AIModelEntity(
|
||||
model='fake-deployment-name',
|
||||
label=I18nObject(
|
||||
en_US='fake-deployment-name-label',
|
||||
),
|
||||
model_type=ModelType.LLM,
|
||||
features=[
|
||||
ModelFeature.AGENT_THOUGHT,
|
||||
ModelFeature.MULTI_TOOL_CALL,
|
||||
],
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_properties={
|
||||
'mode': LLMMode.CHAT.value,
|
||||
'context_size': 16385,
|
||||
},
|
||||
parameter_rules=[
|
||||
ParameterRule(
|
||||
name='temperature',
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
|
||||
),
|
||||
ParameterRule(
|
||||
name='top_p',
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
|
||||
),
|
||||
ParameterRule(
|
||||
name='presence_penalty',
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY],
|
||||
),
|
||||
ParameterRule(
|
||||
name='frequency_penalty',
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
|
||||
),
|
||||
_get_max_tokens(default=512, min_val=1, max_val=16385)
|
||||
],
|
||||
pricing=PriceConfig(
|
||||
input=0.003,
|
||||
output=0.004,
|
||||
unit=0.001,
|
||||
currency='USD',
|
||||
)
|
||||
)
|
||||
),
|
||||
AzureBaseModel(
|
||||
base_model_name='gpt-4',
|
||||
entity=AIModelEntity(
|
||||
model='fake-deployment-name',
|
||||
label=I18nObject(
|
||||
en_US='fake-deployment-name-label',
|
||||
),
|
||||
model_type=ModelType.LLM,
|
||||
features=[
|
||||
ModelFeature.AGENT_THOUGHT,
|
||||
ModelFeature.MULTI_TOOL_CALL,
|
||||
],
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_properties={
|
||||
'mode': LLMMode.CHAT.value,
|
||||
'context_size': 8192,
|
||||
},
|
||||
parameter_rules=[
|
||||
ParameterRule(
|
||||
name='temperature',
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
|
||||
),
|
||||
ParameterRule(
|
||||
name='top_p',
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
|
||||
),
|
||||
ParameterRule(
|
||||
name='presence_penalty',
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY],
|
||||
),
|
||||
ParameterRule(
|
||||
name='frequency_penalty',
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
|
||||
),
|
||||
_get_max_tokens(default=512, min_val=1, max_val=8192),
|
||||
ParameterRule(
|
||||
name='seed',
|
||||
label=I18nObject(
|
||||
zh_Hans='种子',
|
||||
en_US='Seed'
|
||||
),
|
||||
type='int',
|
||||
help=I18nObject(
|
||||
zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。',
|
||||
en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.'
|
||||
),
|
||||
required=False,
|
||||
precision=2,
|
||||
min=0,
|
||||
max=1,
|
||||
),
|
||||
ParameterRule(
|
||||
name='response_format',
|
||||
label=I18nObject(
|
||||
zh_Hans='回复格式',
|
||||
en_US='response_format'
|
||||
),
|
||||
type='string',
|
||||
help=I18nObject(
|
||||
zh_Hans='指定模型必须输出的格式',
|
||||
en_US='specifying the format that the model must output'
|
||||
),
|
||||
required=False,
|
||||
options=['text', 'json_object']
|
||||
),
|
||||
],
|
||||
pricing=PriceConfig(
|
||||
input=0.03,
|
||||
output=0.06,
|
||||
unit=0.001,
|
||||
currency='USD',
|
||||
)
|
||||
)
|
||||
),
|
||||
AzureBaseModel(
|
||||
base_model_name='gpt-4-32k',
|
||||
entity=AIModelEntity(
|
||||
model='fake-deployment-name',
|
||||
label=I18nObject(
|
||||
en_US='fake-deployment-name-label',
|
||||
),
|
||||
model_type=ModelType.LLM,
|
||||
features=[
|
||||
ModelFeature.AGENT_THOUGHT,
|
||||
ModelFeature.MULTI_TOOL_CALL,
|
||||
],
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_properties={
|
||||
'mode': LLMMode.CHAT.value,
|
||||
'context_size': 32768,
|
||||
},
|
||||
parameter_rules=[
|
||||
ParameterRule(
|
||||
name='temperature',
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
|
||||
),
|
||||
ParameterRule(
|
||||
name='top_p',
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
|
||||
),
|
||||
ParameterRule(
|
||||
name='presence_penalty',
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY],
|
||||
),
|
||||
ParameterRule(
|
||||
name='frequency_penalty',
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
|
||||
),
|
||||
_get_max_tokens(default=512, min_val=1, max_val=32768),
|
||||
ParameterRule(
|
||||
name='seed',
|
||||
label=I18nObject(
|
||||
zh_Hans='种子',
|
||||
en_US='Seed'
|
||||
),
|
||||
type='int',
|
||||
help=I18nObject(
|
||||
zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。',
|
||||
en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.'
|
||||
),
|
||||
required=False,
|
||||
precision=2,
|
||||
min=0,
|
||||
max=1,
|
||||
),
|
||||
ParameterRule(
|
||||
name='response_format',
|
||||
label=I18nObject(
|
||||
zh_Hans='回复格式',
|
||||
en_US='response_format'
|
||||
),
|
||||
type='string',
|
||||
help=I18nObject(
|
||||
zh_Hans='指定模型必须输出的格式',
|
||||
en_US='specifying the format that the model must output'
|
||||
),
|
||||
required=False,
|
||||
options=['text', 'json_object']
|
||||
),
|
||||
],
|
||||
pricing=PriceConfig(
|
||||
input=0.06,
|
||||
output=0.12,
|
||||
unit=0.001,
|
||||
currency='USD',
|
||||
)
|
||||
)
|
||||
),
|
||||
AzureBaseModel(
|
||||
base_model_name='gpt-4-1106-preview',
|
||||
entity=AIModelEntity(
|
||||
model='fake-deployment-name',
|
||||
label=I18nObject(
|
||||
en_US='fake-deployment-name-label',
|
||||
),
|
||||
model_type=ModelType.LLM,
|
||||
features=[
|
||||
ModelFeature.AGENT_THOUGHT,
|
||||
ModelFeature.MULTI_TOOL_CALL,
|
||||
],
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_properties={
|
||||
'mode': LLMMode.CHAT.value,
|
||||
'context_size': 128000,
|
||||
},
|
||||
parameter_rules=[
|
||||
ParameterRule(
|
||||
name='temperature',
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
|
||||
),
|
||||
ParameterRule(
|
||||
name='top_p',
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
|
||||
),
|
||||
ParameterRule(
|
||||
name='presence_penalty',
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY],
|
||||
),
|
||||
ParameterRule(
|
||||
name='frequency_penalty',
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
|
||||
),
|
||||
_get_max_tokens(default=512, min_val=1, max_val=128000),
|
||||
ParameterRule(
|
||||
name='seed',
|
||||
label=I18nObject(
|
||||
zh_Hans='种子',
|
||||
en_US='Seed'
|
||||
),
|
||||
type='int',
|
||||
help=I18nObject(
|
||||
zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。',
|
||||
en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.'
|
||||
),
|
||||
required=False,
|
||||
precision=2,
|
||||
min=0,
|
||||
max=1,
|
||||
),
|
||||
ParameterRule(
|
||||
name='response_format',
|
||||
label=I18nObject(
|
||||
zh_Hans='回复格式',
|
||||
en_US='response_format'
|
||||
),
|
||||
type='string',
|
||||
help=I18nObject(
|
||||
zh_Hans='指定模型必须输出的格式',
|
||||
en_US='specifying the format that the model must output'
|
||||
),
|
||||
required=False,
|
||||
options=['text', 'json_object']
|
||||
),
|
||||
],
|
||||
pricing=PriceConfig(
|
||||
input=0.01,
|
||||
output=0.03,
|
||||
unit=0.001,
|
||||
currency='USD',
|
||||
)
|
||||
)
|
||||
),
|
||||
AzureBaseModel(
|
||||
base_model_name='gpt-4-vision-preview',
|
||||
entity=AIModelEntity(
|
||||
model='fake-deployment-name',
|
||||
label=I18nObject(
|
||||
en_US='fake-deployment-name-label',
|
||||
),
|
||||
model_type=ModelType.LLM,
|
||||
features=[
|
||||
ModelFeature.VISION
|
||||
],
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_properties={
|
||||
'mode': LLMMode.CHAT.value,
|
||||
'context_size': 128000,
|
||||
},
|
||||
parameter_rules=[
|
||||
ParameterRule(
|
||||
name='temperature',
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
|
||||
),
|
||||
ParameterRule(
|
||||
name='top_p',
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
|
||||
),
|
||||
ParameterRule(
|
||||
name='presence_penalty',
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY],
|
||||
),
|
||||
ParameterRule(
|
||||
name='frequency_penalty',
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
|
||||
),
|
||||
_get_max_tokens(default=512, min_val=1, max_val=128000),
|
||||
ParameterRule(
|
||||
name='seed',
|
||||
label=I18nObject(
|
||||
zh_Hans='种子',
|
||||
en_US='Seed'
|
||||
),
|
||||
type='int',
|
||||
help=I18nObject(
|
||||
zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。',
|
||||
en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.'
|
||||
),
|
||||
required=False,
|
||||
precision=2,
|
||||
min=0,
|
||||
max=1,
|
||||
),
|
||||
ParameterRule(
|
||||
name='response_format',
|
||||
label=I18nObject(
|
||||
zh_Hans='回复格式',
|
||||
en_US='response_format'
|
||||
),
|
||||
type='string',
|
||||
help=I18nObject(
|
||||
zh_Hans='指定模型必须输出的格式',
|
||||
en_US='specifying the format that the model must output'
|
||||
),
|
||||
required=False,
|
||||
options=['text', 'json_object']
|
||||
),
|
||||
],
|
||||
pricing=PriceConfig(
|
||||
input=0.01,
|
||||
output=0.03,
|
||||
unit=0.001,
|
||||
currency='USD',
|
||||
)
|
||||
)
|
||||
),
|
||||
AzureBaseModel(
|
||||
base_model_name='gpt-35-turbo-instruct',
|
||||
entity=AIModelEntity(
|
||||
model='fake-deployment-name',
|
||||
label=I18nObject(
|
||||
en_US='fake-deployment-name-label',
|
||||
),
|
||||
model_type=ModelType.LLM,
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_properties={
|
||||
'mode': LLMMode.COMPLETION.value,
|
||||
'context_size': 4096,
|
||||
},
|
||||
parameter_rules=[
|
||||
ParameterRule(
|
||||
name='temperature',
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
|
||||
),
|
||||
ParameterRule(
|
||||
name='top_p',
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
|
||||
),
|
||||
ParameterRule(
|
||||
name='presence_penalty',
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY],
|
||||
),
|
||||
ParameterRule(
|
||||
name='frequency_penalty',
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
|
||||
),
|
||||
_get_max_tokens(default=512, min_val=1, max_val=4096),
|
||||
],
|
||||
pricing=PriceConfig(
|
||||
input=0.0015,
|
||||
output=0.002,
|
||||
unit=0.001,
|
||||
currency='USD',
|
||||
)
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
EMBEDDING_BASE_MODELS = [
|
||||
AzureBaseModel(
|
||||
base_model_name='text-embedding-ada-002',
|
||||
entity=AIModelEntity(
|
||||
model='fake-deployment-name',
|
||||
label=I18nObject(
|
||||
en_US='fake-deployment-name-label'
|
||||
),
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model_properties={
|
||||
'context_size': 8097,
|
||||
'max_chunks': 32,
|
||||
},
|
||||
pricing=PriceConfig(
|
||||
input=0.0001,
|
||||
unit=0.001,
|
||||
currency='USD',
|
||||
)
|
||||
)
|
||||
)
|
||||
]
|
||||
@ -0,0 +1,11 @@
|
||||
import logging
|
||||
|
||||
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AzureOpenAIProvider(ModelProvider):
|
||||
|
||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||
pass
|
||||
@ -0,0 +1,104 @@
|
||||
provider: azure_openai
|
||||
label:
|
||||
en_US: Azure OpenAI Service Model
|
||||
icon_small:
|
||||
en_US: icon_s_en.svg
|
||||
icon_large:
|
||||
en_US: icon_l_en.png
|
||||
background: "#E3F0FF"
|
||||
help:
|
||||
title:
|
||||
en_US: Get your API key from Azure
|
||||
zh_Hans: 从 Azure 获取 API Key
|
||||
url:
|
||||
en_US: https://azure.microsoft.com/en-us/products/ai-services/openai-service
|
||||
supported_model_types:
|
||||
- llm
|
||||
- text-embedding
|
||||
configurate_methods:
|
||||
- customizable-model
|
||||
model_credential_schema:
|
||||
model:
|
||||
label:
|
||||
en_US: Deployment Name
|
||||
zh_Hans: 部署名称
|
||||
placeholder:
|
||||
en_US: Enter your Deployment Name here, matching the Azure deployment name.
|
||||
zh_Hans: 在此输入您的部署名称,与 Azure 部署名称匹配。
|
||||
credential_form_schemas:
|
||||
- variable: openai_api_base
|
||||
label:
|
||||
en_US: API Endpoint URL
|
||||
zh_Hans: API 域名
|
||||
type: text-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: '在此输入您的 API 域名,如:https://example.com/xxx'
|
||||
en_US: 'Enter your API Endpoint, eg: https://example.com/xxx'
|
||||
- variable: openai_api_key
|
||||
label:
|
||||
en_US: API Key
|
||||
zh_Hans: API Key
|
||||
type: secret-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的 API Key
|
||||
en_US: Enter your API key here
|
||||
- variable: base_model_name
|
||||
label:
|
||||
en_US: Base Model
|
||||
zh_Hans: 基础模型
|
||||
type: select
|
||||
required: true
|
||||
options:
|
||||
- label:
|
||||
en_US: gpt-35-turbo
|
||||
value: gpt-35-turbo
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: gpt-35-turbo-16k
|
||||
value: gpt-35-turbo-16k
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: gpt-4
|
||||
value: gpt-4
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: gpt-4-32k
|
||||
value: gpt-4-32k
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: gpt-4-1106-preview
|
||||
value: gpt-4-1106-preview
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: gpt-4-vision-preview
|
||||
value: gpt-4-vision-preview
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: gpt-35-turbo-instruct
|
||||
value: gpt-35-turbo-instruct
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: text-embedding-ada-002
|
||||
value: text-embedding-ada-002
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: text-embedding
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的模型版本
|
||||
en_US: Enter your model version
|
||||
627
api/core/model_runtime/model_providers/azure_openai/llm/llm.py
Normal file
@ -0,0 +1,627 @@
|
||||
import logging
|
||||
from typing import Optional, Generator, Union, List, cast
|
||||
|
||||
import tiktoken
|
||||
from openai import AzureOpenAI, Stream
|
||||
from openai.types import Completion
|
||||
from openai.types.chat import ChatCompletionChunk, ChatCompletion, ChatCompletionMessageToolCall
|
||||
from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall, ChoiceDeltaFunctionCall
|
||||
from openai.types.chat.chat_completion_message import FunctionCall
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, \
|
||||
LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import PromptMessageTool, PromptMessage, AssistantPromptMessage, \
|
||||
UserPromptMessage, PromptMessageContentType, ImagePromptMessageContent, \
|
||||
TextPromptMessageContent, SystemPromptMessage, ToolPromptMessage
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI
|
||||
from core.model_runtime.model_providers.azure_openai._constant import LLM_BASE_MODELS, AzureBaseModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
||||
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
|
||||
ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model)
|
||||
|
||||
if ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
|
||||
# chat model
|
||||
return self._chat_generate(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user
|
||||
)
|
||||
else:
|
||||
# text completion model
|
||||
return self._generate(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user
|
||||
)
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
|
||||
model_mode = self._get_ai_model_entity(credentials['base_model_name'], model).entity.model_properties.get(
|
||||
ModelPropertyKey.MODE)
|
||||
|
||||
if model_mode == LLMMode.CHAT.value:
|
||||
# chat model
|
||||
return self._num_tokens_from_messages(credentials, prompt_messages, tools)
|
||||
else:
|
||||
# text completion model, do not support tool calling
|
||||
return self._num_tokens_from_string(credentials, prompt_messages[0].content)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
if 'openai_api_base' not in credentials:
|
||||
raise CredentialsValidateFailedError('Azure OpenAI API Base Endpoint is required')
|
||||
|
||||
if 'openai_api_key' not in credentials:
|
||||
raise CredentialsValidateFailedError('Azure OpenAI API key is required')
|
||||
|
||||
if 'base_model_name' not in credentials:
|
||||
raise CredentialsValidateFailedError('Base Model Name is required')
|
||||
|
||||
ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model)
|
||||
|
||||
if not ai_model_entity:
|
||||
raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid')
|
||||
|
||||
try:
|
||||
client = AzureOpenAI(**self._to_credential_kwargs(credentials))
|
||||
|
||||
if ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
|
||||
# chat model
|
||||
client.chat.completions.create(
|
||||
messages=[{"role": "user", "content": 'ping'}],
|
||||
model=model,
|
||||
temperature=0,
|
||||
max_tokens=20,
|
||||
stream=False,
|
||||
)
|
||||
else:
|
||||
# text completion model
|
||||
client.completions.create(
|
||||
prompt='ping',
|
||||
model=model,
|
||||
temperature=0,
|
||||
max_tokens=20,
|
||||
stream=False,
|
||||
)
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model)
|
||||
return ai_model_entity.entity
|
||||
|
||||
def _generate(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[List[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
|
||||
|
||||
client = AzureOpenAI(**self._to_credential_kwargs(credentials))
|
||||
|
||||
extra_model_kwargs = {}
|
||||
|
||||
if stop:
|
||||
extra_model_kwargs['stop'] = stop
|
||||
|
||||
if user:
|
||||
extra_model_kwargs['user'] = user
|
||||
|
||||
# text completion model
|
||||
response = client.completions.create(
|
||||
prompt=prompt_messages[0].content,
|
||||
model=model,
|
||||
stream=stream,
|
||||
**model_parameters,
|
||||
**extra_model_kwargs
|
||||
)
|
||||
|
||||
if stream:
|
||||
return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
|
||||
|
||||
return self._handle_generate_response(model, credentials, response, prompt_messages)
|
||||
|
||||
def _handle_generate_response(self, model: str, credentials: dict, response: Completion,
|
||||
prompt_messages: list[PromptMessage]) -> LLMResult:
|
||||
assistant_text = response.choices[0].text
|
||||
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=assistant_text
|
||||
)
|
||||
|
||||
# calculate num tokens
|
||||
if response.usage:
|
||||
# transform usage
|
||||
prompt_tokens = response.usage.prompt_tokens
|
||||
completion_tokens = response.usage.completion_tokens
|
||||
else:
|
||||
# calculate num tokens
|
||||
prompt_tokens = self._num_tokens_from_string(credentials, prompt_messages[0].content)
|
||||
completion_tokens = self._num_tokens_from_string(credentials, assistant_text)
|
||||
|
||||
# transform usage
|
||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
||||
|
||||
# transform response
|
||||
result = LLMResult(
|
||||
model=response.model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=assistant_prompt_message,
|
||||
usage=usage,
|
||||
system_fingerprint=response.system_fingerprint,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _handle_generate_stream_response(self, model: str, credentials: dict, response: Stream[Completion],
|
||||
prompt_messages: list[PromptMessage]) -> Generator:
|
||||
full_text = ''
|
||||
for chunk in response:
|
||||
if len(chunk.choices) == 0:
|
||||
continue
|
||||
|
||||
delta = chunk.choices[0]
|
||||
|
||||
if delta.finish_reason is None and (delta.text is None or delta.text == ''):
|
||||
continue
|
||||
|
||||
# transform assistant message to prompt message
|
||||
text = delta.text if delta.text else ''
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=text
|
||||
)
|
||||
|
||||
full_text += text
|
||||
|
||||
if delta.finish_reason is not None:
|
||||
# calculate num tokens
|
||||
if chunk.usage:
|
||||
# transform usage
|
||||
prompt_tokens = chunk.usage.prompt_tokens
|
||||
completion_tokens = chunk.usage.completion_tokens
|
||||
else:
|
||||
# calculate num tokens
|
||||
prompt_tokens = self._num_tokens_from_string(credentials, prompt_messages[0].content)
|
||||
completion_tokens = self._num_tokens_from_string(credentials, full_text)
|
||||
|
||||
# transform usage
|
||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
||||
|
||||
yield LLMResultChunk(
|
||||
model=chunk.model,
|
||||
prompt_messages=prompt_messages,
|
||||
system_fingerprint=chunk.system_fingerprint,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=delta.index,
|
||||
message=assistant_prompt_message,
|
||||
finish_reason=delta.finish_reason,
|
||||
usage=usage
|
||||
)
|
||||
)
|
||||
else:
|
||||
yield LLMResultChunk(
|
||||
model=chunk.model,
|
||||
prompt_messages=prompt_messages,
|
||||
system_fingerprint=chunk.system_fingerprint,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=delta.index,
|
||||
message=assistant_prompt_message,
|
||||
)
|
||||
)
|
||||
|
||||
def _chat_generate(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
|
||||
|
||||
client = AzureOpenAI(**self._to_credential_kwargs(credentials))
|
||||
|
||||
response_format = model_parameters.get("response_format")
|
||||
if response_format:
|
||||
if response_format == "json_object":
|
||||
response_format = {"type": "json_object"}
|
||||
else:
|
||||
response_format = {"type": "text"}
|
||||
|
||||
model_parameters["response_format"] = response_format
|
||||
|
||||
extra_model_kwargs = {}
|
||||
|
||||
if tools:
|
||||
# extra_model_kwargs['tools'] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools]
|
||||
extra_model_kwargs['functions'] = [{
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": tool.parameters
|
||||
} for tool in tools]
|
||||
|
||||
if stop:
|
||||
extra_model_kwargs['stop'] = stop
|
||||
|
||||
if user:
|
||||
extra_model_kwargs['user'] = user
|
||||
|
||||
# chat model
|
||||
response = client.chat.completions.create(
|
||||
messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
|
||||
model=model,
|
||||
stream=stream,
|
||||
**model_parameters,
|
||||
**extra_model_kwargs,
|
||||
)
|
||||
|
||||
if stream:
|
||||
return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools)
|
||||
|
||||
return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
|
||||
|
||||
def _handle_chat_generate_response(self, model: str, credentials: dict, response: ChatCompletion,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> LLMResult:
|
||||
|
||||
assistant_message = response.choices[0].message
|
||||
# assistant_message_tool_calls = assistant_message.tool_calls
|
||||
assistant_message_function_call = assistant_message.function_call
|
||||
|
||||
# extract tool calls from response
|
||||
# tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
|
||||
function_call = self._extract_response_function_call(assistant_message_function_call)
|
||||
tool_calls = [function_call] if function_call else []
|
||||
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=assistant_message.content,
|
||||
tool_calls=tool_calls
|
||||
)
|
||||
|
||||
# calculate num tokens
|
||||
if response.usage:
|
||||
# transform usage
|
||||
prompt_tokens = response.usage.prompt_tokens
|
||||
completion_tokens = response.usage.completion_tokens
|
||||
else:
|
||||
# calculate num tokens
|
||||
prompt_tokens = self._num_tokens_from_messages(credentials, prompt_messages, tools)
|
||||
completion_tokens = self._num_tokens_from_messages(credentials, [assistant_prompt_message])
|
||||
|
||||
# transform usage
|
||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
||||
|
||||
# transform response
|
||||
response = LLMResult(
|
||||
model=response.model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=assistant_prompt_message,
|
||||
usage=usage,
|
||||
system_fingerprint=response.system_fingerprint,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def _handle_chat_generate_stream_response(self, model: str, credentials: dict,
|
||||
response: Stream[ChatCompletionChunk],
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> Generator:
|
||||
|
||||
full_assistant_content = ''
|
||||
for chunk in response:
|
||||
if len(chunk.choices) == 0:
|
||||
continue
|
||||
|
||||
delta = chunk.choices[0]
|
||||
|
||||
if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''):
|
||||
continue
|
||||
|
||||
# assistant_message_tool_calls = delta.delta.tool_calls
|
||||
assistant_message_function_call = delta.delta.function_call
|
||||
|
||||
# extract tool calls from response
|
||||
# tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
|
||||
function_call = self._extract_response_function_call(assistant_message_function_call)
|
||||
tool_calls = [function_call] if function_call else []
|
||||
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=delta.delta.content if delta.delta.content else '',
|
||||
tool_calls=tool_calls
|
||||
)
|
||||
|
||||
full_assistant_content += delta.delta.content if delta.delta.content else ''
|
||||
|
||||
if delta.finish_reason is not None:
|
||||
# calculate num tokens
|
||||
prompt_tokens = self._num_tokens_from_messages(credentials, prompt_messages, tools)
|
||||
|
||||
full_assistant_prompt_message = AssistantPromptMessage(
|
||||
content=full_assistant_content,
|
||||
tool_calls=tool_calls
|
||||
)
|
||||
completion_tokens = self._num_tokens_from_messages(credentials, [full_assistant_prompt_message])
|
||||
|
||||
# transform usage
|
||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
||||
|
||||
yield LLMResultChunk(
|
||||
model=chunk.model,
|
||||
prompt_messages=prompt_messages,
|
||||
system_fingerprint=chunk.system_fingerprint,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=delta.index,
|
||||
message=assistant_prompt_message,
|
||||
finish_reason=delta.finish_reason,
|
||||
usage=usage
|
||||
)
|
||||
)
|
||||
else:
|
||||
yield LLMResultChunk(
|
||||
model=chunk.model,
|
||||
prompt_messages=prompt_messages,
|
||||
system_fingerprint=chunk.system_fingerprint,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=delta.index,
|
||||
message=assistant_prompt_message,
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _extract_response_tool_calls(response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]) \
|
||||
-> list[AssistantPromptMessage.ToolCall]:
|
||||
|
||||
tool_calls = []
|
||||
if response_tool_calls:
|
||||
for response_tool_call in response_tool_calls:
|
||||
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=response_tool_call.function.name,
|
||||
arguments=response_tool_call.function.arguments
|
||||
)
|
||||
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id=response_tool_call.id,
|
||||
type=response_tool_call.type,
|
||||
function=function
|
||||
)
|
||||
tool_calls.append(tool_call)
|
||||
|
||||
return tool_calls
|
||||
|
||||
@staticmethod
|
||||
def _extract_response_function_call(response_function_call: FunctionCall | ChoiceDeltaFunctionCall) \
|
||||
-> AssistantPromptMessage.ToolCall:
|
||||
|
||||
tool_call = None
|
||||
if response_function_call:
|
||||
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=response_function_call.name,
|
||||
arguments=response_function_call.arguments
|
||||
)
|
||||
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id=response_function_call.name,
|
||||
type="function",
|
||||
function=function
|
||||
)
|
||||
|
||||
return tool_call
|
||||
|
||||
@staticmethod
|
||||
def _convert_prompt_message_to_dict(message: PromptMessage) -> dict:
|
||||
|
||||
if isinstance(message, UserPromptMessage):
|
||||
message = cast(UserPromptMessage, message)
|
||||
if isinstance(message.content, str):
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
else:
|
||||
sub_messages = []
|
||||
for message_content in message.content:
|
||||
if message_content.type == PromptMessageContentType.TEXT:
|
||||
message_content = cast(TextPromptMessageContent, message_content)
|
||||
sub_message_dict = {
|
||||
"type": "text",
|
||||
"text": message_content.data
|
||||
}
|
||||
sub_messages.append(sub_message_dict)
|
||||
elif message_content.type == PromptMessageContentType.IMAGE:
|
||||
message_content = cast(ImagePromptMessageContent, message_content)
|
||||
sub_message_dict = {
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": message_content.data,
|
||||
"detail": message_content.detail.value
|
||||
}
|
||||
}
|
||||
sub_messages.append(sub_message_dict)
|
||||
|
||||
message_dict = {"role": "user", "content": sub_messages}
|
||||
elif isinstance(message, AssistantPromptMessage):
|
||||
message = cast(AssistantPromptMessage, message)
|
||||
message_dict = {"role": "assistant", "content": message.content}
|
||||
if message.tool_calls:
|
||||
# message_dict["tool_calls"] = [helper.dump_model(tool_call) for tool_call in
|
||||
# message.tool_calls]
|
||||
function_call = message.tool_calls[0]
|
||||
message_dict["function_call"] = {
|
||||
"name": function_call.function.name,
|
||||
"arguments": function_call.function.arguments,
|
||||
}
|
||||
elif isinstance(message, SystemPromptMessage):
|
||||
message = cast(SystemPromptMessage, message)
|
||||
message_dict = {"role": "system", "content": message.content}
|
||||
elif isinstance(message, ToolPromptMessage):
|
||||
message = cast(ToolPromptMessage, message)
|
||||
# message_dict = {
|
||||
# "role": "tool",
|
||||
# "content": message.content,
|
||||
# "tool_call_id": message.tool_call_id
|
||||
# }
|
||||
message_dict = {
|
||||
"role": "function",
|
||||
"content": message.content,
|
||||
"name": message.tool_call_id
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
|
||||
if message.name is not None:
|
||||
message_dict["name"] = message.name
|
||||
|
||||
return message_dict
|
||||
|
||||
def _num_tokens_from_string(self, credentials: dict, text: str,
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
try:
|
||||
encoding = tiktoken.encoding_for_model(credentials['base_model_name'])
|
||||
except KeyError:
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
num_tokens = len(encoding.encode(text))
|
||||
|
||||
if tools:
|
||||
num_tokens += self._num_tokens_for_tools(encoding, tools)
|
||||
|
||||
return num_tokens
|
||||
|
||||
def _num_tokens_from_messages(self, credentials: dict, messages: List[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
|
||||
|
||||
Official documentation: https://github.com/openai/openai-cookbook/blob/
|
||||
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
|
||||
model = credentials['base_model_name']
|
||||
try:
|
||||
encoding = tiktoken.encoding_for_model(model)
|
||||
except KeyError:
|
||||
logger.warning("Warning: model not found. Using cl100k_base encoding.")
|
||||
model = "cl100k_base"
|
||||
encoding = tiktoken.get_encoding(model)
|
||||
|
||||
if model.startswith("gpt-35-turbo-0301"):
|
||||
# every message follows <im_start>{role/name}\n{content}<im_end>\n
|
||||
tokens_per_message = 4
|
||||
# if there's a name, the role is omitted
|
||||
tokens_per_name = -1
|
||||
elif model.startswith("gpt-35-turbo") or model.startswith("gpt-4"):
|
||||
tokens_per_message = 3
|
||||
tokens_per_name = 1
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"get_num_tokens_from_messages() is not presently implemented "
|
||||
f"for model {model}."
|
||||
"See https://github.com/openai/openai-python/blob/main/chatml.md for "
|
||||
"information on how messages are converted to tokens."
|
||||
)
|
||||
num_tokens = 0
|
||||
messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
|
||||
for message in messages_dict:
|
||||
num_tokens += tokens_per_message
|
||||
for key, value in message.items():
|
||||
# Cast str(value) in case the message value is not a string
|
||||
# This occurs with function messages
|
||||
# TODO: The current token calculation method for the image type is not implemented,
|
||||
# which need to download the image and then get the resolution for calculation,
|
||||
# and will increase the request delay
|
||||
if isinstance(value, list):
|
||||
text = ''
|
||||
for item in value:
|
||||
if isinstance(item, dict) and item['type'] == 'text':
|
||||
text += item['text']
|
||||
|
||||
value = text
|
||||
|
||||
if key == "tool_calls":
|
||||
for tool_call in value:
|
||||
for t_key, t_value in tool_call.items():
|
||||
num_tokens += len(encoding.encode(t_key))
|
||||
if t_key == "function":
|
||||
for f_key, f_value in t_value.items():
|
||||
num_tokens += len(encoding.encode(f_key))
|
||||
num_tokens += len(encoding.encode(f_value))
|
||||
else:
|
||||
num_tokens += len(encoding.encode(t_key))
|
||||
num_tokens += len(encoding.encode(t_value))
|
||||
else:
|
||||
num_tokens += len(encoding.encode(str(value)))
|
||||
|
||||
if key == "name":
|
||||
num_tokens += tokens_per_name
|
||||
|
||||
# every reply is primed with <im_start>assistant
|
||||
num_tokens += 3
|
||||
|
||||
if tools:
|
||||
num_tokens += self._num_tokens_for_tools(encoding, tools)
|
||||
|
||||
return num_tokens
|
||||
|
||||
@staticmethod
|
||||
def _num_tokens_for_tools(encoding: tiktoken.Encoding, tools: list[PromptMessageTool]) -> int:
|
||||
|
||||
num_tokens = 0
|
||||
for tool in tools:
|
||||
num_tokens += len(encoding.encode('type'))
|
||||
num_tokens += len(encoding.encode(tool.get("type")))
|
||||
num_tokens += len(encoding.encode('function'))
|
||||
|
||||
# calculate num tokens for function object
|
||||
num_tokens += len(encoding.encode('name'))
|
||||
num_tokens += len(encoding.encode(tool.name))
|
||||
num_tokens += len(encoding.encode('description'))
|
||||
num_tokens += len(encoding.encode(tool.description))
|
||||
parameters = tool.parameters
|
||||
num_tokens += len(encoding.encode('parameters'))
|
||||
if 'title' in parameters:
|
||||
num_tokens += len(encoding.encode('title'))
|
||||
num_tokens += len(encoding.encode(parameters.get("title")))
|
||||
num_tokens += len(encoding.encode('type'))
|
||||
num_tokens += len(encoding.encode(parameters.get("type")))
|
||||
if 'properties' in parameters:
|
||||
num_tokens += len(encoding.encode('properties'))
|
||||
for key, value in parameters.get('properties').items():
|
||||
num_tokens += len(encoding.encode(key))
|
||||
for field_key, field_value in value.items():
|
||||
num_tokens += len(encoding.encode(field_key))
|
||||
if field_key == 'enum':
|
||||
for enum_field in field_value:
|
||||
num_tokens += 3
|
||||
num_tokens += len(encoding.encode(enum_field))
|
||||
else:
|
||||
num_tokens += len(encoding.encode(field_key))
|
||||
num_tokens += len(encoding.encode(str(field_value)))
|
||||
if 'required' in parameters:
|
||||
num_tokens += len(encoding.encode('required'))
|
||||
for required_field in parameters['required']:
|
||||
num_tokens += 3
|
||||
num_tokens += len(encoding.encode(required_field))
|
||||
|
||||
return num_tokens
|
||||
|
||||
@staticmethod
|
||||
def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel:
|
||||
for ai_model_entity in LLM_BASE_MODELS:
|
||||
if ai_model_entity.base_model_name == base_model_name:
|
||||
ai_model_entity.entity.model = model
|
||||
ai_model_entity.entity.label.en_US = model
|
||||
ai_model_entity.entity.label.zh_Hans = model
|
||||
return ai_model_entity
|
||||
|
||||
return None
|
||||
@ -0,0 +1,195 @@
|
||||
import base64
|
||||
import time
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import tiktoken
|
||||
from openai import AzureOpenAI
|
||||
|
||||
from core.model_runtime.entities.model_entities import PriceType, AIModelEntity
|
||||
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult, EmbeddingUsage
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI
|
||||
from core.model_runtime.model_providers.azure_openai._constant import EMBEDDING_BASE_MODELS, AzureBaseModel
|
||||
|
||||
|
||||
class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
|
||||
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
texts: list[str], user: Optional[str] = None) \
|
||||
-> TextEmbeddingResult:
|
||||
base_model_name = credentials['base_model_name']
|
||||
credentials_kwargs = self._to_credential_kwargs(credentials)
|
||||
client = AzureOpenAI(**credentials_kwargs)
|
||||
|
||||
extra_model_kwargs = {}
|
||||
if user:
|
||||
extra_model_kwargs['user'] = user
|
||||
|
||||
extra_model_kwargs['encoding_format'] = 'base64'
|
||||
|
||||
context_size = self._get_context_size(model, credentials)
|
||||
max_chunks = self._get_max_chunks(model, credentials)
|
||||
|
||||
embeddings: list[list[float]] = [[] for _ in range(len(texts))]
|
||||
tokens = []
|
||||
indices = []
|
||||
used_tokens = 0
|
||||
|
||||
try:
|
||||
enc = tiktoken.encoding_for_model(base_model_name)
|
||||
except KeyError:
|
||||
enc = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
for i, text in enumerate(texts):
|
||||
token = enc.encode(
|
||||
text
|
||||
)
|
||||
for j in range(0, len(token), context_size):
|
||||
tokens += [token[j: j + context_size]]
|
||||
indices += [i]
|
||||
|
||||
batched_embeddings = []
|
||||
_iter = range(0, len(tokens), max_chunks)
|
||||
|
||||
for i in _iter:
|
||||
embeddings, embedding_used_tokens = self._embedding_invoke(
|
||||
model=model,
|
||||
client=client,
|
||||
texts=tokens[i: i + max_chunks],
|
||||
extra_model_kwargs=extra_model_kwargs
|
||||
)
|
||||
|
||||
used_tokens += embedding_used_tokens
|
||||
batched_embeddings += [data for data in embeddings]
|
||||
|
||||
results: list[list[list[float]]] = [[] for _ in range(len(texts))]
|
||||
num_tokens_in_batch: list[list[int]] = [[] for _ in range(len(texts))]
|
||||
for i in range(len(indices)):
|
||||
results[indices[i]].append(batched_embeddings[i])
|
||||
num_tokens_in_batch[indices[i]].append(len(tokens[i]))
|
||||
|
||||
for i in range(len(texts)):
|
||||
_result = results[i]
|
||||
if len(_result) == 0:
|
||||
embeddings, embedding_used_tokens = self._embedding_invoke(
|
||||
model=model,
|
||||
client=client,
|
||||
texts=[""],
|
||||
extra_model_kwargs=extra_model_kwargs
|
||||
)
|
||||
|
||||
used_tokens += embedding_used_tokens
|
||||
average = embeddings[0]
|
||||
else:
|
||||
average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
|
||||
embeddings[i] = (average / np.linalg.norm(average)).tolist()
|
||||
|
||||
# calc usage
|
||||
usage = self._calc_response_usage(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
tokens=used_tokens
|
||||
)
|
||||
|
||||
return TextEmbeddingResult(
|
||||
embeddings=embeddings,
|
||||
usage=usage,
|
||||
model=base_model_name
|
||||
)
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
if len(texts) == 0:
|
||||
return 0
|
||||
|
||||
try:
|
||||
enc = tiktoken.encoding_for_model(credentials['base_model_name'])
|
||||
except KeyError:
|
||||
enc = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
total_num_tokens = 0
|
||||
for text in texts:
|
||||
# calculate the number of tokens in the encoded text
|
||||
tokenized_text = enc.encode(text)
|
||||
total_num_tokens += len(tokenized_text)
|
||||
|
||||
return total_num_tokens
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
if 'openai_api_base' not in credentials:
|
||||
raise CredentialsValidateFailedError('Azure OpenAI API Base Endpoint is required')
|
||||
|
||||
if 'openai_api_key' not in credentials:
|
||||
raise CredentialsValidateFailedError('Azure OpenAI API key is required')
|
||||
|
||||
if 'base_model_name' not in credentials:
|
||||
raise CredentialsValidateFailedError('Base Model Name is required')
|
||||
|
||||
if not self._get_ai_model_entity(credentials['base_model_name'], model):
|
||||
raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid')
|
||||
|
||||
try:
|
||||
credentials_kwargs = self._to_credential_kwargs(credentials)
|
||||
client = AzureOpenAI(**credentials_kwargs)
|
||||
|
||||
self._embedding_invoke(
|
||||
model=model,
|
||||
client=client,
|
||||
texts=['ping'],
|
||||
extra_model_kwargs={}
|
||||
)
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model)
|
||||
return ai_model_entity.entity
|
||||
|
||||
@staticmethod
|
||||
def _embedding_invoke(model: str, client: AzureOpenAI, texts: list[str],
|
||||
extra_model_kwargs: dict) -> Tuple[list[list[float]], int]:
|
||||
response = client.embeddings.create(
|
||||
input=texts,
|
||||
model=model,
|
||||
**extra_model_kwargs,
|
||||
)
|
||||
|
||||
if 'encoding_format' in extra_model_kwargs and extra_model_kwargs['encoding_format'] == 'base64':
|
||||
# decode base64 embedding
|
||||
return ([list(np.frombuffer(base64.b64decode(data.embedding), dtype="float32")) for data in response.data],
|
||||
response.usage.total_tokens)
|
||||
|
||||
return [data.embedding for data in response.data], response.usage.total_tokens
|
||||
|
||||
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
|
||||
input_price_info = self.get_price(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
price_type=PriceType.INPUT,
|
||||
tokens=tokens
|
||||
)
|
||||
|
||||
# transform usage
|
||||
usage = EmbeddingUsage(
|
||||
tokens=tokens,
|
||||
total_tokens=tokens,
|
||||
unit_price=input_price_info.unit_price,
|
||||
price_unit=input_price_info.unit,
|
||||
total_price=input_price_info.total_amount,
|
||||
currency=input_price_info.currency,
|
||||
latency=time.perf_counter() - self.started_at
|
||||
)
|
||||
|
||||
return usage
|
||||
|
||||
@staticmethod
|
||||
def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel:
|
||||
for ai_model_entity in EMBEDDING_BASE_MODELS:
|
||||
if ai_model_entity.base_model_name == base_model_name:
|
||||
ai_model_entity.entity.model = model
|
||||
ai_model_entity.entity.label.en_US = model
|
||||
ai_model_entity.entity.label.zh_Hans = model
|
||||
return ai_model_entity
|
||||
|
||||
return None
|
||||
@ -0,0 +1,19 @@
|
||||
<svg width="130" height="24" viewBox="0 0 130 24" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path fill-rule="evenodd" clip-rule="evenodd" d="M9.58154 1.7793H6.52779L4.34655 6.20409V17.7335L1.91602 22.2206H7.21333L9.58154 17.7335V1.7793ZM11.5761 1.7793H16.8111V22.2206H11.5761V1.7793ZM23.9166 1.7793H18.6816V6.01712H23.9166V1.7793ZM23.9166 7.38818H18.6816V22.2206H23.9166V7.38818Z" fill="url(#paint0_radial_11622_96091)"/>
|
||||
<path d="M129.722 6.83203V18H127.482V6.83203H129.722Z" fill="#FF6A34"/>
|
||||
<path d="M123.196 15.872H118.748L118.012 18H115.66L119.676 6.81604H122.284L126.3 18H123.932L123.196 15.872ZM122.588 14.08L120.972 9.40804L119.356 14.08H122.588Z" fill="#FF6A34"/>
|
||||
<path d="M110.962 18H108.722L103.65 10.336V18H101.41V6.81598H103.65L108.722 14.496V6.81598H110.962V18Z" fill="#FF6A34"/>
|
||||
<path d="M97.1258 15.872H92.6778L91.9418 18H89.5898L93.6058 6.81604H96.2138L100.23 18H97.8618L97.1258 15.872ZM96.5178 14.08L94.9018 9.40804L93.2858 14.08H96.5178Z" fill="#FF6A34"/>
|
||||
<path d="M81.6482 6.83203V13.744C81.6482 14.5014 81.8455 15.0827 82.2402 15.488C82.6349 15.8827 83.1895 16.08 83.9042 16.08C84.6295 16.08 85.1895 15.8827 85.5842 15.488C85.9789 15.0827 86.1762 14.5014 86.1762 13.744V6.83203H88.4322V13.728C88.4322 14.6774 88.2242 15.4827 87.8082 16.144C87.4029 16.7947 86.8535 17.2854 86.1602 17.616C85.4775 17.9467 84.7149 18.112 83.8722 18.112C83.0402 18.112 82.2829 17.9467 81.6002 17.616C80.9282 17.2854 80.3949 16.7947 80.0002 16.144C79.6055 15.4827 79.4082 14.6774 79.4082 13.728V6.83203H81.6482Z" fill="#FF6A34"/>
|
||||
<path d="M77.557 6.83203V18H75.317V13.248H70.533V18H68.293V6.83203H70.533V11.424H75.317V6.83203H77.557Z" fill="#FF6A34"/>
|
||||
<path d="M55.7871 12.4C55.7871 11.3013 56.0324 10.32 56.5231 9.45599C57.0244 8.58132 57.7018 7.90399 58.5551 7.42399C59.4191 6.93332 60.3844 6.68799 61.4511 6.68799C62.6991 6.68799 63.7924 7.00799 64.7311 7.64799C65.6698 8.28799 66.3258 9.17332 66.6991 10.304H64.1231C63.8671 9.77065 63.5044 9.37065 63.0351 9.10399C62.5764 8.83732 62.0431 8.70399 61.4351 8.70399C60.7844 8.70399 60.2031 8.85865 59.6911 9.16799C59.1898 9.46665 58.7951 9.89332 58.5071 10.448C58.2298 11.0027 58.0911 11.6533 58.0911 12.4C58.0911 13.136 58.2298 13.7867 58.5071 14.352C58.7951 14.9067 59.1898 15.3387 59.6911 15.648C60.2031 15.9467 60.7844 16.096 61.4351 16.096C62.0431 16.096 62.5764 15.9627 63.0351 15.696C63.5044 15.4187 63.8671 15.0133 64.1231 14.48H66.6991C66.3258 15.6213 65.6698 16.512 64.7311 17.152C63.8031 17.7813 62.7098 18.096 61.4511 18.096C60.3844 18.096 59.4191 17.856 58.5551 17.376C57.7018 16.8853 57.0244 16.208 56.5231 15.344C56.0324 14.48 55.7871 13.4987 55.7871 12.4Z" fill="#FF6A34"/>
|
||||
<path d="M54.4373 6.83203V18H52.1973V6.83203H54.4373Z" fill="#FF6A34"/>
|
||||
<path d="M47.913 15.872H43.465L42.729 18H40.377L44.393 6.81598H47.001L51.017 18H48.649L47.913 15.872ZM47.305 14.08L45.689 9.40798L44.073 14.08H47.305Z" fill="#FF6A34"/>
|
||||
<path d="M37.4395 12.272C38.0688 12.3893 38.5862 12.704 38.9915 13.216C39.3968 13.728 39.5995 14.3146 39.5995 14.976C39.5995 15.5733 39.4502 16.1013 39.1515 16.56C38.8635 17.008 38.4422 17.36 37.8875 17.616C37.3328 17.872 36.6768 18 35.9195 18H31.1035V6.83197H35.7115C36.4688 6.83197 37.1195 6.95464 37.6635 7.19997C38.2182 7.4453 38.6342 7.78664 38.9115 8.22397C39.1995 8.6613 39.3435 9.1573 39.3435 9.71197C39.3435 10.3626 39.1675 10.9066 38.8155 11.344C38.4742 11.7813 38.0155 12.0906 37.4395 12.272ZM33.3435 11.44H35.3915C35.9248 11.44 36.3355 11.3226 36.6235 11.088C36.9115 10.8426 37.0555 10.496 37.0555 10.048C37.0555 9.59997 36.9115 9.2533 36.6235 9.00797C36.3355 8.76264 35.9248 8.63997 35.3915 8.63997H33.3435V11.44ZM35.5995 16.176C36.1435 16.176 36.5648 16.048 36.8635 15.792C37.1728 15.536 37.3275 15.1733 37.3275 14.704C37.3275 14.224 37.1675 13.8506 36.8475 13.584C36.5275 13.3066 36.0955 13.168 35.5515 13.168H33.3435V16.176H35.5995Z" fill="#FF6A34"/>
|
||||
<defs>
|
||||
<radialGradient id="paint0_radial_11622_96091" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(6.5 5.5) rotate(45) scale(20.5061 22.0704)">
|
||||
<stop stop-color="#FEBD3F"/>
|
||||
<stop offset="0.77608" stop-color="#FF6933"/>
|
||||
</radialGradient>
|
||||
</defs>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 4.0 KiB |
@ -0,0 +1,11 @@
|
||||
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<g id="Baichuan">
|
||||
<path id="Union" fill-rule="evenodd" clip-rule="evenodd" d="M8.58154 1.7793H5.52779L3.34655 6.20409V17.7335L0.916016 22.2206H6.21333L8.58154 17.7335V1.7793ZM10.5761 1.7793H15.8111V22.2206H10.5761V1.7793ZM22.9166 1.7793H17.6816V6.01712H22.9166V1.7793ZM22.9166 7.38818H17.6816V22.2206H22.9166V7.38818Z" fill="url(#paint0_radial_11622_96084)"/>
|
||||
</g>
|
||||
<defs>
|
||||
<radialGradient id="paint0_radial_11622_96084" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(5.5 5.5) rotate(45) scale(20.5061 22.0704)">
|
||||
<stop stop-color="#FEBD3F"/>
|
||||
<stop offset="0.77608" stop-color="#FF6933"/>
|
||||
</radialGradient>
|
||||
</defs>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 748 B |
29
api/core/model_runtime/model_providers/baichuan/baichuan.py
Normal file
@ -0,0 +1,29 @@
|
||||
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class BaichuanProvider(ModelProvider):
|
||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||
"""
|
||||
Validate provider credentials
|
||||
|
||||
if validate failed, raise exception
|
||||
|
||||
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
|
||||
"""
|
||||
try:
|
||||
model_instance = self.get_model_instance(ModelType.LLM)
|
||||
|
||||
# Use `baichuan2-turbo` model for validate,
|
||||
model_instance.validate_credentials(
|
||||
model='baichuan2-turbo',
|
||||
credentials=credentials
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
|
||||
raise ex
|
||||
@ -0,0 +1,37 @@
|
||||
provider: baichuan
|
||||
label:
|
||||
en_US: Baichuan
|
||||
icon_small:
|
||||
en_US: icon_s_en.svg
|
||||
icon_large:
|
||||
en_US: icon_l_en.svg
|
||||
background: "#FFF6F2"
|
||||
help:
|
||||
title:
|
||||
en_US: Get your API Key from BAICHUAN AI
|
||||
zh_Hans: 从百川智能获取您的 API Key
|
||||
url:
|
||||
en_US: https://www.baichuan-ai.com
|
||||
supported_model_types:
|
||||
- llm
|
||||
- text-embedding
|
||||
configurate_methods:
|
||||
- predefined-model
|
||||
provider_credential_schema:
|
||||
credential_form_schemas:
|
||||
- variable: api_key
|
||||
label:
|
||||
en_US: API Key
|
||||
type: secret-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的 API Key
|
||||
en_US: Enter your API Key
|
||||
- variable: secret_key
|
||||
label:
|
||||
en_US: Secret Key
|
||||
type: secret-input
|
||||
required: false
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的 Secret Key
|
||||
en_US: Enter your Secret Key
|
||||
@ -0,0 +1,42 @@
|
||||
model: baichuan2-53b
|
||||
label:
|
||||
en_US: Baichuan2-53B
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 4000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 1000
|
||||
min: 1
|
||||
max: 4000
|
||||
- name: presence_penalty
|
||||
use_template: presence_penalty
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
- name: with_search_enhance
|
||||
label:
|
||||
zh_Hans: 搜索增强
|
||||
en_US: Search Enhance
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 允许模型自行进行外部搜索,以增强生成结果。
|
||||
en_US: Allow the model to perform external search to enhance the generation results.
|
||||
required: false
|
||||
@ -0,0 +1,42 @@
|
||||
model: baichuan2-turbo-192k
|
||||
label:
|
||||
en_US: Baichuan2-Turbo-192K
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 192000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 8000
|
||||
min: 1
|
||||
max: 192000
|
||||
- name: presence_penalty
|
||||
use_template: presence_penalty
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
- name: with_search_enhance
|
||||
label:
|
||||
zh_Hans: 搜索增强
|
||||
en_US: Search Enhance
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 允许模型自行进行外部搜索,以增强生成结果。
|
||||
en_US: Allow the model to perform external search to enhance the generation results.
|
||||
required: false
|
||||
@ -0,0 +1,42 @@
|
||||
model: baichuan2-turbo
|
||||
label:
|
||||
en_US: Baichuan2-Turbo
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 192000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 8000
|
||||
min: 1
|
||||
max: 192000
|
||||
- name: presence_penalty
|
||||
use_template: presence_penalty
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
- name: with_search_enhance
|
||||
label:
|
||||
zh_Hans: 搜索增强
|
||||
en_US: Search Enhance
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 允许模型自行进行外部搜索,以增强生成结果。
|
||||
en_US: Allow the model to perform external search to enhance the generation results.
|
||||
required: false
|
||||
@ -0,0 +1,19 @@
|
||||
import re
|
||||
|
||||
class BaichuanTokenizer(object):
|
||||
@classmethod
|
||||
def count_chinese_characters(cls, text: str) -> int:
|
||||
return len(re.findall(r'[\u4e00-\u9fa5]', text))
|
||||
|
||||
@classmethod
|
||||
def count_english_vocabularies(cls, text: str) -> int:
|
||||
# remove all non-alphanumeric characters but keep spaces and other symbols like !, ., etc.
|
||||
text = re.sub(r'[^a-zA-Z0-9\s]', '', text)
|
||||
# count the number of words not characters
|
||||
return len(text.split())
|
||||
|
||||
@classmethod
|
||||
def _get_num_tokens(cls, text: str) -> int:
|
||||
# tokens = number of Chinese characters + number of English words * 1.3 (for estimation only, subject to actual return)
|
||||
# https://platform.baichuan-ai.com/docs/text-Embedding
|
||||
return int(cls.count_chinese_characters(text) + cls.count_english_vocabularies(text) * 1.3)
|
||||
@ -0,0 +1,199 @@
|
||||
from os.path import join
|
||||
from typing import List, Optional, Generator, Union, Dict, Any
|
||||
from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import \
|
||||
InsufficientAccountBalance, InvalidAPIKeyError, InternalServerError, RateLimitReachedError, InvalidAuthenticationError, BadRequestError
|
||||
from enum import Enum
|
||||
from json import dumps, loads
|
||||
from requests import post
|
||||
from time import time
|
||||
from hashlib import md5
|
||||
|
||||
class BaichuanMessage:
|
||||
class Role(Enum):
|
||||
USER = 'user'
|
||||
ASSISTANT = 'assistant'
|
||||
# Baichuan does not have system message
|
||||
_SYSTEM = 'system'
|
||||
|
||||
role: str = Role.USER.value
|
||||
content: str
|
||||
usage: Dict[str, int] = None
|
||||
stop_reason: str = ''
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
'role': self.role,
|
||||
'content': self.content,
|
||||
}
|
||||
|
||||
def __init__(self, content: str, role: str = 'user') -> None:
|
||||
self.content = content
|
||||
self.role = role
|
||||
|
||||
class BaichuanModel(object):
|
||||
api_key: str
|
||||
secret_key: str
|
||||
|
||||
def __init__(self, api_key: str, secret_key: str = '') -> None:
|
||||
self.api_key = api_key
|
||||
self.secret_key = secret_key
|
||||
|
||||
def _model_mapping(self, model: str) -> str:
|
||||
return {
|
||||
'baichuan2-turbo': 'Baichuan2-Turbo',
|
||||
'baichuan2-turbo-192k': 'Baichuan2-Turbo-192k',
|
||||
'baichuan2-53b': 'Baichuan2-53B',
|
||||
}[model]
|
||||
|
||||
def _handle_chat_generate_response(self, response) -> BaichuanMessage:
|
||||
resp = response.json()
|
||||
choices = resp.get('choices', [])
|
||||
message = BaichuanMessage(content='', role='assistant')
|
||||
for choice in choices:
|
||||
message.content += choice['message']['content']
|
||||
message.role = choice['message']['role']
|
||||
if choice['finish_reason']:
|
||||
message.stop_reason = choice['finish_reason']
|
||||
|
||||
if 'usage' in resp:
|
||||
message.usage = {
|
||||
'prompt_tokens': resp['usage']['prompt_tokens'],
|
||||
'completion_tokens': resp['usage']['completion_tokens'],
|
||||
'total_tokens': resp['usage']['total_tokens'],
|
||||
}
|
||||
|
||||
return message
|
||||
|
||||
def _handle_chat_stream_generate_response(self, response) -> Generator:
|
||||
for line in response.iter_lines():
|
||||
if not line:
|
||||
continue
|
||||
line = line.decode('utf-8')
|
||||
# remove the first `data: ` prefix
|
||||
if line.startswith('data:'):
|
||||
line = line[5:].strip()
|
||||
try:
|
||||
data = loads(line)
|
||||
except Exception as e:
|
||||
if line.strip() == '[DONE]':
|
||||
return
|
||||
choices = data.get('choices', [])
|
||||
# save stop reason temporarily
|
||||
stop_reason = ''
|
||||
for choice in choices:
|
||||
if 'finish_reason' in choice and choice['finish_reason']:
|
||||
stop_reason = choice['finish_reason']
|
||||
|
||||
if len(choice['delta']['content']) == 0:
|
||||
continue
|
||||
yield BaichuanMessage(**choice['delta'])
|
||||
|
||||
# if there is usage, the response is the last one, yield it and return
|
||||
if 'usage' in data:
|
||||
message = BaichuanMessage(content='', role='assistant')
|
||||
message.usage = {
|
||||
'prompt_tokens': data['usage']['prompt_tokens'],
|
||||
'completion_tokens': data['usage']['completion_tokens'],
|
||||
'total_tokens': data['usage']['total_tokens'],
|
||||
}
|
||||
message.stop_reason = stop_reason
|
||||
yield message
|
||||
|
||||
def _build_parameters(self, model: str, stream: bool, messages: List[BaichuanMessage],
|
||||
parameters: Dict[str, Any]) \
|
||||
-> Dict[str, Any]:
|
||||
if model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b':
|
||||
prompt_messages = []
|
||||
for message in messages:
|
||||
if message.role == BaichuanMessage.Role.USER.value or message.role == BaichuanMessage.Role._SYSTEM.value:
|
||||
# check if the latest message is a user message
|
||||
if len(prompt_messages) > 0 and prompt_messages[-1]['role'] == BaichuanMessage.Role.USER.value:
|
||||
prompt_messages[-1]['content'] += message.content
|
||||
else:
|
||||
prompt_messages.append({
|
||||
'content': message.content,
|
||||
'role': BaichuanMessage.Role.USER.value,
|
||||
})
|
||||
elif message.role == BaichuanMessage.Role.ASSISTANT.value:
|
||||
prompt_messages.append({
|
||||
'content': message.content,
|
||||
'role': message.role,
|
||||
})
|
||||
# turbo api accepts flat parameters
|
||||
return {
|
||||
'model': self._model_mapping(model),
|
||||
'stream': stream,
|
||||
'messages': prompt_messages,
|
||||
**parameters,
|
||||
}
|
||||
else:
|
||||
raise BadRequestError(f"Unknown model: {model}")
|
||||
|
||||
def _build_headers(self, model: str, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
if model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b':
|
||||
# there is no secret key for turbo api
|
||||
return {
|
||||
'Content-Type': 'application/json',
|
||||
'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) ',
|
||||
'Authorization': 'Bearer ' + self.api_key,
|
||||
}
|
||||
else:
|
||||
raise BadRequestError(f"Unknown model: {model}")
|
||||
|
||||
def _calculate_md5(self, input_string):
|
||||
return md5(input_string.encode('utf-8')).hexdigest()
|
||||
|
||||
def generate(self, model: str, stream: bool, messages: List[BaichuanMessage],
|
||||
parameters: Dict[str, Any], timeout: int) \
|
||||
-> Union[Generator, BaichuanMessage]:
|
||||
|
||||
if model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b':
|
||||
api_base = 'https://api.baichuan-ai.com/v1/chat/completions'
|
||||
else:
|
||||
raise BadRequestError(f"Unknown model: {model}")
|
||||
|
||||
try:
|
||||
data = self._build_parameters(model, stream, messages, parameters)
|
||||
headers = self._build_headers(model, data)
|
||||
except KeyError:
|
||||
raise InternalServerError(f"Failed to build parameters for model: {model}")
|
||||
|
||||
try:
|
||||
response = post(
|
||||
url=api_base,
|
||||
headers=headers,
|
||||
data=dumps(data),
|
||||
timeout=timeout,
|
||||
stream=stream
|
||||
)
|
||||
except Exception as e:
|
||||
raise InternalServerError(f"Failed to invoke model: {e}")
|
||||
|
||||
if response.status_code != 200:
|
||||
try:
|
||||
resp = response.json()
|
||||
# try to parse error message
|
||||
err = resp['error']['code']
|
||||
msg = resp['error']['message']
|
||||
except Exception as e:
|
||||
raise InternalServerError(f"Failed to convert response to json: {e} with text: {response.text}")
|
||||
|
||||
if err == 'invalid_api_key':
|
||||
raise InvalidAPIKeyError(msg)
|
||||
elif err == 'insufficient_quota':
|
||||
raise InsufficientAccountBalance(msg)
|
||||
elif err == 'invalid_authentication':
|
||||
raise InvalidAuthenticationError(msg)
|
||||
elif 'rate' in err:
|
||||
raise RateLimitReachedError(msg)
|
||||
elif 'internal' in err:
|
||||
raise InternalServerError(msg)
|
||||
elif err == 'api_key_empty':
|
||||
raise InvalidAPIKeyError(msg)
|
||||
else:
|
||||
raise InternalServerError(f"Unknown error: {err} with message: {msg}")
|
||||
|
||||
if stream:
|
||||
return self._handle_chat_stream_generate_response(response)
|
||||
else:
|
||||
return self._handle_chat_generate_response(response)
|
||||
@ -0,0 +1,17 @@
|
||||
class InvalidAuthenticationError(Exception):
|
||||
pass
|
||||
|
||||
class InvalidAPIKeyError(Exception):
|
||||
pass
|
||||
|
||||
class RateLimitReachedError(Exception):
|
||||
pass
|
||||
|
||||
class InsufficientAccountBalance(Exception):
|
||||
pass
|
||||
|
||||
class InternalServerError(Exception):
|
||||
pass
|
||||
|
||||
class BadRequestError(Exception):
|
||||
pass
|
||||
194
api/core/model_runtime/model_providers/baichuan/llm/llm.py
Normal file
@ -0,0 +1,194 @@
|
||||
from typing import Generator, List, Optional, Union, cast
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, AssistantPromptMessage, UserPromptMessage, SystemPromptMessage
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.errors.invoke import InvokeConnectionError, InvokeServerUnavailableError, InvokeRateLimitError, \
|
||||
InvokeAuthorizationError, InvokeBadRequestError, InvokeError
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import \
|
||||
InsufficientAccountBalance, InvalidAPIKeyError, InternalServerError, RateLimitReachedError, InvalidAuthenticationError, BadRequestError
|
||||
from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo import BaichuanModel, BaichuanMessage
|
||||
from core.model_runtime.model_providers.baichuan.llm.baichuan_tokenizer import BaichuanTokenizer
|
||||
|
||||
class BaichuanLarguageModel(LargeLanguageModel):
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: list[PromptMessageTool] | None = None, stop: List[str] | None = None,
|
||||
stream: bool = True, user: str | None = None) \
|
||||
-> LLMResult | Generator:
|
||||
return self._generate(model=model, credentials=credentials, prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user)
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
tools: list[PromptMessageTool] | None = None) -> int:
|
||||
return self._num_tokens_from_messages(prompt_messages)
|
||||
|
||||
def _num_tokens_from_messages(self, messages: List[PromptMessage],) -> int:
|
||||
"""Calculate num tokens for baichuan model"""
|
||||
def tokens(text: str):
|
||||
return BaichuanTokenizer._get_num_tokens(text)
|
||||
|
||||
tokens_per_message = 3
|
||||
|
||||
num_tokens = 0
|
||||
messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
|
||||
for message in messages_dict:
|
||||
num_tokens += tokens_per_message
|
||||
for key, value in message.items():
|
||||
if isinstance(value, list):
|
||||
text = ''
|
||||
for item in value:
|
||||
if isinstance(item, dict) and item['type'] == 'text':
|
||||
text += item['text']
|
||||
|
||||
value = text
|
||||
|
||||
num_tokens += tokens(str(value))
|
||||
num_tokens += 3
|
||||
|
||||
return num_tokens
|
||||
|
||||
def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
|
||||
"""
|
||||
Convert PromptMessage to dict for Baichuan
|
||||
"""
|
||||
if isinstance(message, UserPromptMessage):
|
||||
message = cast(UserPromptMessage, message)
|
||||
if isinstance(message.content, str):
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
else:
|
||||
raise ValueError("User message content must be str")
|
||||
elif isinstance(message, AssistantPromptMessage):
|
||||
message = cast(AssistantPromptMessage, message)
|
||||
message_dict = {"role": "assistant", "content": message.content}
|
||||
elif isinstance(message, SystemPromptMessage):
|
||||
message = cast(SystemPromptMessage, message)
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
else:
|
||||
raise ValueError(f"Unknown message type {type(message)}")
|
||||
|
||||
return message_dict
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
# ping
|
||||
instance = BaichuanModel(
|
||||
api_key=credentials['api_key'],
|
||||
secret_key=credentials.get('secret_key', '')
|
||||
)
|
||||
|
||||
try:
|
||||
instance.generate(model=model, stream=False, messages=[
|
||||
BaichuanMessage(content='ping', role='user')
|
||||
], parameters={}, timeout=10)
|
||||
except InvalidAPIKeyError as e:
|
||||
raise CredentialsValidateFailedError(f"Invalid API key: {e}")
|
||||
|
||||
def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict, tools: list[PromptMessageTool] | None = None,
|
||||
stop: List[str] | None = None, stream: bool = True, user: str | None = None) \
|
||||
-> LLMResult | Generator:
|
||||
if tools is not None and len(tools) > 0:
|
||||
raise InvokeBadRequestError(f"Baichuan model doesn't support tools")
|
||||
|
||||
instance = BaichuanModel(
|
||||
api_key=credentials['api_key'],
|
||||
secret_key=credentials.get('secret_key', '')
|
||||
)
|
||||
|
||||
# convert prompt messages to baichuan messages
|
||||
messages = [
|
||||
BaichuanMessage(
|
||||
content=message.content if isinstance(message.content, str) else ''.join([
|
||||
content.data for content in message.content
|
||||
]),
|
||||
role=message.role.value
|
||||
) for message in prompt_messages
|
||||
]
|
||||
|
||||
# invoke model
|
||||
response = instance.generate(model=model, stream=stream, messages=messages, parameters=model_parameters, timeout=60)
|
||||
|
||||
if stream:
|
||||
return self._handle_chat_generate_stream_response(model, prompt_messages, credentials, response)
|
||||
|
||||
return self._handle_chat_generate_response(model, prompt_messages, credentials, response)
|
||||
|
||||
def _handle_chat_generate_response(self, model: str,
|
||||
prompt_messages: list[PromptMessage],
|
||||
credentials: dict,
|
||||
response: BaichuanMessage) -> LLMResult:
|
||||
# convert baichuan message to llm result
|
||||
usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=response.usage['prompt_tokens'], completion_tokens=response.usage['completion_tokens'])
|
||||
return LLMResult(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=AssistantPromptMessage(
|
||||
content=response.content,
|
||||
tool_calls=[]
|
||||
),
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
def _handle_chat_generate_stream_response(self, model: str,
|
||||
prompt_messages: list[PromptMessage],
|
||||
credentials: dict,
|
||||
response: Generator[BaichuanMessage, None, None]) -> Generator:
|
||||
for message in response:
|
||||
if message.usage:
|
||||
usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=message.usage['prompt_tokens'], completion_tokens=message.usage['completion_tokens'])
|
||||
yield LLMResultChunk(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(
|
||||
content=message.content,
|
||||
tool_calls=[]
|
||||
),
|
||||
usage=usage,
|
||||
finish_reason=message.stop_reason if message.stop_reason else None,
|
||||
),
|
||||
)
|
||||
else:
|
||||
yield LLMResultChunk(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(
|
||||
content=message.content,
|
||||
tool_calls=[]
|
||||
),
|
||||
finish_reason=message.stop_reason if message.stop_reason else None,
|
||||
),
|
||||
)
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
The value is the error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
return {
|
||||
InvokeConnectionError: [
|
||||
],
|
||||
InvokeServerUnavailableError: [
|
||||
InternalServerError
|
||||
],
|
||||
InvokeRateLimitError: [
|
||||
RateLimitReachedError
|
||||
],
|
||||
InvokeAuthorizationError: [
|
||||
InvalidAuthenticationError,
|
||||
InsufficientAccountBalance,
|
||||
InvalidAPIKeyError,
|
||||
],
|
||||
InvokeBadRequestError: [
|
||||
BadRequestError,
|
||||
KeyError
|
||||
]
|
||||
}
|
||||
@ -0,0 +1,5 @@
|
||||
model: baichuan-text-embedding
|
||||
model_type: text-embedding
|
||||
model_properties:
|
||||
context_size: 512
|
||||
max_chunks: 16
|
||||
@ -0,0 +1,178 @@
|
||||
from typing import Optional
|
||||
|
||||
from core.model_runtime.entities.model_entities import PriceType
|
||||
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult, EmbeddingUsage
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from core.model_runtime.errors.invoke import InvokeError, InvokeConnectionError, InvokeServerUnavailableError, \
|
||||
InvokeRateLimitError, InvokeAuthorizationError, InvokeBadRequestError
|
||||
from core.model_runtime.model_providers.baichuan.llm.baichuan_tokenizer import BaichuanTokenizer
|
||||
from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import InvalidAPIKeyError, InsufficientAccountBalance, \
|
||||
InvalidAuthenticationError, RateLimitReachedError, InternalServerError, BadRequestError
|
||||
|
||||
from requests import post
|
||||
from json import dumps, loads
|
||||
|
||||
import time
|
||||
|
||||
class BaichuanTextEmbeddingModel(TextEmbeddingModel):
|
||||
"""
|
||||
Model class for BaiChuan text embedding model.
|
||||
"""
|
||||
api_base: str = 'http://api.baichuan-ai.com/v1/embeddings'
|
||||
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
texts: list[str], user: Optional[str] = None) \
|
||||
-> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:param user: unique user id
|
||||
:return: embeddings result
|
||||
"""
|
||||
api_key = credentials['api_key']
|
||||
if model != 'baichuan-text-embedding':
|
||||
raise ValueError('Invalid model name')
|
||||
if not api_key:
|
||||
raise CredentialsValidateFailedError('api_key is required')
|
||||
url = self.api_base
|
||||
headers = {
|
||||
'Authorization': 'Bearer ' + api_key,
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
|
||||
data = {
|
||||
'model': 'Baichuan-Text-Embedding',
|
||||
'input': texts
|
||||
}
|
||||
|
||||
try:
|
||||
response = post(url, headers=headers, data=dumps(data))
|
||||
except Exception as e:
|
||||
raise InvokeConnectionError(e)
|
||||
|
||||
if response.status_code != 200:
|
||||
try:
|
||||
resp = response.json()
|
||||
# try to parse error message
|
||||
err = resp['error']['code']
|
||||
msg = resp['error']['message']
|
||||
except Exception as e:
|
||||
raise InternalServerError(f"Failed to convert response to json: {e} with text: {response.text}")
|
||||
|
||||
if err == 'invalid_api_key':
|
||||
raise InvalidAPIKeyError(msg)
|
||||
elif err == 'insufficient_quota':
|
||||
raise InsufficientAccountBalance(msg)
|
||||
elif err == 'invalid_authentication':
|
||||
raise InvalidAuthenticationError(msg)
|
||||
elif 'rate' in err:
|
||||
raise RateLimitReachedError(msg)
|
||||
elif 'internal' in err:
|
||||
raise InternalServerError(msg)
|
||||
elif err == 'api_key_empty':
|
||||
raise InvalidAPIKeyError(msg)
|
||||
else:
|
||||
raise InternalServerError(f"Unknown error: {err} with message: {msg}")
|
||||
|
||||
try:
|
||||
resp = response.json()
|
||||
embeddings = resp['data']
|
||||
usage = resp['usage']
|
||||
except Exception as e:
|
||||
raise InternalServerError(f"Failed to convert response to json: {e} with text: {response.text}")
|
||||
|
||||
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage['total_tokens'])
|
||||
|
||||
result = TextEmbeddingResult(
|
||||
model=model,
|
||||
embeddings=[[
|
||||
float(data) for data in x['embedding']
|
||||
] for x in embeddings],
|
||||
usage=usage
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:return:
|
||||
"""
|
||||
num_tokens = 0
|
||||
for text in texts:
|
||||
# use BaichuanTokenizer to get num tokens
|
||||
num_tokens += BaichuanTokenizer._get_num_tokens(text)
|
||||
return num_tokens
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
self._invoke(model=model, credentials=credentials, texts=['ping'])
|
||||
except InvalidAPIKeyError:
|
||||
raise CredentialsValidateFailedError('Invalid api key')
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
return {
|
||||
InvokeConnectionError: [
|
||||
],
|
||||
InvokeServerUnavailableError: [
|
||||
InternalServerError
|
||||
],
|
||||
InvokeRateLimitError: [
|
||||
RateLimitReachedError
|
||||
],
|
||||
InvokeAuthorizationError: [
|
||||
InvalidAuthenticationError,
|
||||
InsufficientAccountBalance,
|
||||
InvalidAPIKeyError,
|
||||
],
|
||||
InvokeBadRequestError: [
|
||||
BadRequestError,
|
||||
KeyError
|
||||
]
|
||||
}
|
||||
|
||||
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
|
||||
"""
|
||||
Calculate response usage
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param tokens: input tokens
|
||||
:return: usage
|
||||
"""
|
||||
# get input price info
|
||||
input_price_info = self.get_price(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
price_type=PriceType.INPUT,
|
||||
tokens=tokens
|
||||
)
|
||||
|
||||
# transform usage
|
||||
usage = EmbeddingUsage(
|
||||
tokens=tokens,
|
||||
total_tokens=tokens,
|
||||
unit_price=input_price_info.unit_price,
|
||||
price_unit=input_price_info.unit,
|
||||
total_price=input_price_info.total_amount,
|
||||
currency=input_price_info.currency,
|
||||
latency=time.perf_counter() - self.started_at
|
||||
)
|
||||
|
||||
return usage
|
||||
|
After Width: | Height: | Size: 5.4 KiB |
@ -0,0 +1,9 @@
|
||||
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<mask id="mask0_8587_60212" style="mask-type:luminance" maskUnits="userSpaceOnUse" x="1" y="2" width="23" height="21">
|
||||
<path d="M23.8 2H1V22.4H23.8V2Z" fill="white"/>
|
||||
</mask>
|
||||
<g mask="url(#mask0_8587_60212)">
|
||||
<path fill-rule="evenodd" clip-rule="evenodd" d="M3.86378 14.4544C3.86378 13.0981 4.67438 11.737 6.25923 10.6634C7.83827 9.59364 10.0864 8.89368 12.6282 8.89368C15.17 8.89368 17.4182 9.59364 18.9972 10.6634C19.7966 11.2049 20.399 11.8196 20.7998 12.4699C21.2873 11.5802 21.4969 10.6351 21.3835 9.69252C21.3759 9.62928 21.3824 9.56766 21.4005 9.5106C21.0758 9.21852 20.7259 8.94624 20.3558 8.69556C18.3272 7.32126 15.5915 6.50964 12.6282 6.50964C9.66497 6.50964 6.92918 7.32126 4.90058 8.69556C2.8778 10.0659 1.45703 12.0812 1.45703 14.4544C1.45703 16.8275 2.8778 18.8428 4.90058 20.2132C6.92918 21.5875 9.66497 22.3991 12.6282 22.3991C15.5915 22.3991 18.3272 21.5875 20.3558 20.2132C22.3786 18.8428 23.7994 16.8275 23.7994 14.4544C23.7994 12.9455 23.225 11.5813 22.2868 10.4355C22.2377 11.4917 21.8621 12.5072 21.238 13.43C21.3409 13.7686 21.3926 14.1116 21.3926 14.4544C21.3926 15.8107 20.582 17.1717 18.9972 18.2453C17.4182 19.3151 15.17 20.015 12.6282 20.015C10.0864 20.015 7.83827 19.3151 6.25923 18.2453C4.67438 17.1717 3.86378 15.8107 3.86378 14.4544Z" fill="#3762FF"/>
|
||||
<path fill-rule="evenodd" clip-rule="evenodd" d="M3.84445 11.6838C3.20239 13.4885 3.35368 15.1156 4.18868 16.2838C5.02368 17.452 6.52281 18.1339 8.45459 18.1334C10.3826 18.133 12.6296 17.44 14.6939 15.9922C16.7581 14.5444 18.1643 12.6753 18.8052 10.8739C19.4473 9.0692 19.2959 7.44206 18.461 6.27392C17.626 5.10572 16.1269 4.42389 14.1951 4.42431C12.267 4.42475 10.0201 5.11774 7.95575 6.56552C5.89152 8.01332 4.48529 9.8825 3.84445 11.6838ZM1.53559 10.8778C2.36374 8.55002 4.11254 6.28976 6.54117 4.58645C8.96981 2.88312 11.7029 1.99995 14.1945 1.99939C16.6825 1.99884 19.0426 2.8912 20.4589 4.87263C21.8752 6.85406 21.941 9.35564 21.1141 11.6799C20.2859 14.0077 18.5371 16.2679 16.1085 17.9713C13.6798 19.6746 10.9468 20.5578 8.45513 20.5584C5.9672 20.5589 3.60706 19.6665 2.19075 17.6851C0.774446 15.7036 0.708677 13.2021 1.53559 10.8778Z" fill="#1041F3"/>
|
||||
</g>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 2.2 KiB |
31
api/core/model_runtime/model_providers/chatglm/chatglm.py
Normal file
@ -0,0 +1,31 @@
|
||||
import logging
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChatGLMProvider(ModelProvider):
|
||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||
"""
|
||||
Validate provider credentials
|
||||
|
||||
if validate failed, raise exception
|
||||
|
||||
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
|
||||
"""
|
||||
try:
|
||||
model_instance = self.get_model_instance(ModelType.LLM)
|
||||
|
||||
# Use `chatglm3-6b` model for validate,
|
||||
model_instance.validate_credentials(
|
||||
model='chatglm3-6b',
|
||||
credentials=credentials
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
|
||||
raise ex
|
||||
28
api/core/model_runtime/model_providers/chatglm/chatglm.yaml
Normal file
@ -0,0 +1,28 @@
|
||||
provider: chatglm
|
||||
label:
|
||||
en_US: ChatGLM
|
||||
icon_small:
|
||||
en_US: icon_s_en.svg
|
||||
icon_large:
|
||||
en_US: icon_l_en.svg
|
||||
background: "#F4F7FF"
|
||||
help:
|
||||
title:
|
||||
en_US: Deploy ChatGLM to your local
|
||||
zh_Hans: 部署您的本地 ChatGLM
|
||||
url:
|
||||
en_US: https://github.com/THUDM/ChatGLM3
|
||||
supported_model_types:
|
||||
- llm
|
||||
configurate_methods:
|
||||
- predefined-model
|
||||
provider_credential_schema:
|
||||
credential_form_schemas:
|
||||
- variable: api_base
|
||||
label:
|
||||
en_US: API URL
|
||||
type: text-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的 API URL
|
||||
en_US: Enter your API URL
|
||||