Merge branch 'main' into feat/rag-2

This commit is contained in:
twwu
2025-08-21 09:43:51 +08:00
131 changed files with 3855 additions and 1081 deletions

View File

@ -3,7 +3,7 @@
## Usage
> [!IMPORTANT]
>
>
> In the v1.3.0 release, `poetry` has been replaced with
> [`uv`](https://docs.astral.sh/uv/) as the package manager
> for Dify API backend service.
@ -20,25 +20,29 @@
cd ../api
```
2. Copy `.env.example` to `.env`
1. Copy `.env.example` to `.env`
```cli
cp .env.example .env
cp .env.example .env
```
3. Generate a `SECRET_KEY` in the `.env` file.
1. Generate a `SECRET_KEY` in the `.env` file.
bash for Linux
```bash for Linux
sed -i "/^SECRET_KEY=/c\SECRET_KEY=$(openssl rand -base64 42)" .env
```
bash for Mac
```bash for Mac
secret_key=$(openssl rand -base64 42)
sed -i '' "/^SECRET_KEY=/c\\
SECRET_KEY=${secret_key}" .env
```
4. Create environment.
1. Create environment.
Dify API service uses [UV](https://docs.astral.sh/uv/) to manage dependencies.
First, install the uv package manager if you don't already have it.
@ -49,13 +53,13 @@
brew install uv
```
5. Install dependencies
1. Install dependencies
```bash
uv sync --dev
```
6. Run migrate
1. Run migrate
Before the first launch, migrate the database to the latest version.
@ -63,24 +67,27 @@
uv run flask db upgrade
```
7. Start backend
1. Start backend
```bash
uv run flask run --host 0.0.0.0 --port=5001 --debug
```
8. Start Dify [web](../web) service.
9. Setup your application by visiting `http://localhost:3000`.
10. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service.
1. Start Dify [web](../web) service.
```bash
uv run celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage
```
1. Setup your application by visiting `http://localhost:3000`.
In addition, if you want to debug the celery scheduled tasks, you can use the following command in another terminal:
```bash
uv run celery -A app.celery beat
```
1. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service.
```bash
uv run celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage
```
In addition, if you want to debug the celery scheduled tasks, you can use the following command in another terminal:
```bash
uv run celery -A app.celery beat
```
## Testing
@ -90,9 +97,8 @@
uv sync --dev
```
2. Run the tests locally with mocked system environment variables in `tool.pytest_env` section in `pyproject.toml`
1. Run the tests locally with mocked system environment variables in `tool.pytest_env` section in `pyproject.toml`
```bash
uv run -P api bash dev/pytest/pytest_all_tests.sh
```

View File

@ -1,3 +1,5 @@
from datetime import datetime
import pytz
from flask import request
from flask_login import current_user
@ -327,6 +329,9 @@ class EducationVerifyApi(Resource):
class EducationApi(Resource):
status_fields = {
"result": fields.Boolean,
"is_student": fields.Boolean,
"expire_at": TimestampField,
"allow_refresh": fields.Boolean,
}
@setup_required
@ -354,7 +359,11 @@ class EducationApi(Resource):
def get(self):
account = current_user
return BillingService.EducationIdentity.is_active(account.id)
res = BillingService.EducationIdentity.status(account.id)
# convert expire_at to UTC timestamp from isoformat
if res and "expire_at" in res:
res["expire_at"] = datetime.fromisoformat(res["expire_at"]).astimezone(pytz.utc)
return res
class EducationAutoCompleteApi(Resource):

View File

@ -1,27 +1,38 @@
from flask_restful import (
Resource, # type: ignore
reqparse,
)
from flask_restful import Resource, reqparse
from controllers.console.wraps import setup_required
from controllers.inner_api import api
from controllers.inner_api.wraps import enterprise_inner_api_only
from services.enterprise.mail_service import DifyMail, EnterpriseMailService
from controllers.inner_api.wraps import billing_inner_api_only, enterprise_inner_api_only
from tasks.mail_inner_task import send_inner_email_task
_mail_parser = reqparse.RequestParser()
_mail_parser.add_argument("to", type=str, action="append", required=True)
_mail_parser.add_argument("subject", type=str, required=True)
_mail_parser.add_argument("body", type=str, required=True)
_mail_parser.add_argument("substitutions", type=dict, required=False)
class EnterpriseMail(Resource):
@setup_required
@enterprise_inner_api_only
class BaseMail(Resource):
"""Shared logic for sending an inner email."""
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("to", type=str, action="append", required=True)
parser.add_argument("subject", type=str, required=True)
parser.add_argument("body", type=str, required=True)
parser.add_argument("substitutions", type=dict, required=False)
args = parser.parse_args()
EnterpriseMailService.send_mail(DifyMail(**args))
args = _mail_parser.parse_args()
send_inner_email_task.delay(
to=args["to"],
subject=args["subject"],
body=args["body"],
substitutions=args["substitutions"],
)
return {"message": "success"}, 200
class EnterpriseMail(BaseMail):
method_decorators = [setup_required, enterprise_inner_api_only]
class BillingMail(BaseMail):
method_decorators = [setup_required, billing_inner_api_only]
api.add_resource(EnterpriseMail, "/enterprise/mail")
api.add_resource(BillingMail, "/billing/mail")

View File

@ -10,6 +10,22 @@ from extensions.ext_database import db
from models.model import EndUser
def billing_inner_api_only(view):
    """Restrict a view to internal billing-API callers.

    Responds 404 when the inner API feature is disabled (hiding the
    endpoint entirely), and 401 when the caller does not present the
    expected shared secret in the ``X-Inner-Api-Key`` header.
    """

    @wraps(view)
    def decorated(*args, **kwargs):
        # Pretend the endpoint does not exist unless the inner API is enabled.
        if not dify_config.INNER_API:
            abort(404)
        # The shared secret is carried in the 'X-Inner-Api-Key' request header.
        provided_key = request.headers.get("X-Inner-Api-Key")
        if not provided_key or provided_key != dify_config.INNER_API_KEY:
            abort(401)
        return view(*args, **kwargs)

    return decorated
def enterprise_inner_api_only(view):
@wraps(view)
def decorated(*args, **kwargs):

View File

@ -30,7 +30,7 @@ from core.rag.splitter.fixed_text_splitter import (
FixedRecursiveCharacterTextSplitter,
)
from core.rag.splitter.text_splitter import TextSplitter
from core.tools.utils.rag_web_reader import get_image_upload_file_ids
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage

View File

@ -532,7 +532,7 @@ class LLMGenerator:
model=model_config.get("name", ""),
)
match node_type:
case "llm", "agent":
case "llm" | "agent":
system_prompt = LLM_MODIFY_PROMPT_SYSTEM
case "code":
system_prompt = LLM_MODIFY_CODE_SYSTEM

View File

@ -7,7 +7,7 @@ import urllib.parse
from typing import Optional
from urllib.parse import urljoin
import requests
import httpx
from pydantic import BaseModel, ValidationError
from core.mcp.auth.auth_provider import OAuthClientProvider
@ -105,18 +105,18 @@ def discover_oauth_metadata(server_url: str, protocol_version: Optional[str] = N
try:
headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION}
response = requests.get(url, headers=headers)
response = httpx.get(url, headers=headers)
if response.status_code == 404:
return None
if not response.ok:
if not response.is_success:
raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
return OAuthMetadata.model_validate(response.json())
except requests.RequestException as e:
if isinstance(e, requests.ConnectionError):
response = requests.get(url)
except httpx.RequestError as e:
if isinstance(e, httpx.ConnectError):
response = httpx.get(url)
if response.status_code == 404:
return None
if not response.ok:
if not response.is_success:
raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
return OAuthMetadata.model_validate(response.json())
raise
@ -206,8 +206,8 @@ def exchange_authorization(
if client_information.client_secret:
params["client_secret"] = client_information.client_secret
response = requests.post(token_url, data=params)
if not response.ok:
response = httpx.post(token_url, data=params)
if not response.is_success:
raise ValueError(f"Token exchange failed: HTTP {response.status_code}")
return OAuthTokens.model_validate(response.json())
@ -237,8 +237,8 @@ def refresh_authorization(
if client_information.client_secret:
params["client_secret"] = client_information.client_secret
response = requests.post(token_url, data=params)
if not response.ok:
response = httpx.post(token_url, data=params)
if not response.is_success:
raise ValueError(f"Token refresh failed: HTTP {response.status_code}")
return OAuthTokens.model_validate(response.json())
@ -256,12 +256,12 @@ def register_client(
else:
registration_url = urljoin(server_url, "/register")
response = requests.post(
response = httpx.post(
registration_url,
json=client_metadata.model_dump(),
headers={"Content-Type": "application/json"},
)
if not response.ok:
if not response.is_success:
response.raise_for_status()
return OAuthClientInformationFull.model_validate(response.json())
@ -283,7 +283,7 @@ def auth(
raise ValueError("Existing OAuth client information is required when exchanging an authorization code")
try:
full_information = register_client(server_url, metadata, provider.client_metadata)
except requests.RequestException as e:
except httpx.RequestError as e:
raise ValueError(f"Could not register OAuth client: {e}")
provider.save_client_information(full_information)
client_information = full_information

View File

@ -30,7 +30,7 @@ This module provides the interface for invoking and authenticating various model
In addition, this list also returns configurable parameter information and rules for LLM, as shown below:
![image-20231210144814617](./docs/en_US/images/index/image-20231210144814617.png)
![image-20231210144814617](./docs/en_US/images/index/image-20231210144814617.png)
These parameters are all defined in the backend, allowing different settings for various parameters supported by different models, as detailed in: [Schema](./docs/en_US/schema.md#ParameterRule).
@ -60,8 +60,6 @@ Model Runtime is divided into three layers:
It offers direct invocation of various model types, predefined model configuration information, getting predefined/remote model lists, model credential authentication methods. Different models provide additional special methods, like LLM's pre-computed tokens method, cost information obtaining method, etc., **allowing horizontal expansion** for different models under the same provider (within supported model types).
## Next Steps
- Add new provider configuration: [Link](./docs/en_US/provider_scale_out.md)

View File

@ -20,19 +20,19 @@
![image-20231210143654461](./docs/zh_Hans/images/index/image-20231210143654461.png)
展示所有已支持的供应商列表,除了返回供应商名称、图标之外,还提供了支持的模型类型列表,预定义模型列表、配置方式以及配置凭据的表单规则等等,规则设计详见:[Schema](./docs/zh_Hans/schema.md)。
展示所有已支持的供应商列表,除了返回供应商名称、图标之外,还提供了支持的模型类型列表,预定义模型列表、配置方式以及配置凭据的表单规则等等,规则设计详见:[Schema](./docs/zh_Hans/schema.md)。
- 可选择的模型列表展示
![image-20231210144229650](./docs/zh_Hans/images/index/image-20231210144229650.png)
配置供应商/模型凭据后,可在此下拉(应用编排界面/默认模型)查看可用的 LLM 列表,其中灰色的为未配置凭据供应商的预定义模型列表,方便用户查看已支持的模型。
配置供应商/模型凭据后,可在此下拉(应用编排界面/默认模型)查看可用的 LLM 列表,其中灰色的为未配置凭据供应商的预定义模型列表,方便用户查看已支持的模型。
除此之外,该列表还返回了 LLM 可配置的参数信息和规则,如下图:
除此之外,该列表还返回了 LLM 可配置的参数信息和规则,如下图:
![image-20231210144814617](./docs/zh_Hans/images/index/image-20231210144814617.png)
![image-20231210144814617](./docs/zh_Hans/images/index/image-20231210144814617.png)
这里的参数均为后端定义,相比之前只有 5 种固定参数,这里可为不同模型设置所支持的各种参数,详见:[Schema](./docs/zh_Hans/schema.md#ParameterRule)。
这里的参数均为后端定义,相比之前只有 5 种固定参数,这里可为不同模型设置所支持的各种参数,详见:[Schema](./docs/zh_Hans/schema.md#ParameterRule)。
- 供应商/模型凭据鉴权
@ -40,7 +40,7 @@
![image-20231210151628992](./docs/zh_Hans/images/index/image-20231210151628992.png)
供应商列表返回了凭据表单的配置信息,可通过 Runtime 提供的接口对凭据进行鉴权,上图 1 为供应商凭据 DEMO上图 2 为模型凭据 DEMO。
供应商列表返回了凭据表单的配置信息,可通过 Runtime 提供的接口对凭据进行鉴权,上图 1 为供应商凭据 DEMO,上图 2 为模型凭据 DEMO。
## 结构
@ -57,9 +57,10 @@ Model Runtime 分三层:
提供获取当前供应商模型列表、获取模型实例、供应商凭据鉴权、供应商配置规则信息,**可横向扩展**以支持不同的供应商。
对于供应商/模型凭据,有两种情况
- 如 OpenAI 这类中心化供应商,需要定义如**api_key**这类的鉴权凭据
- 如[**Xinference**](https://github.com/xorbitsai/inference)这类本地部署的供应商,需要定义如**server_url**这类的地址凭据,有时候还需要定义**model_uid**之类的模型类型凭据,就像下面这样,当在供应商层定义了这些凭据后,就可以在前端页面上直接展示,无需修改前端逻辑。
![Alt text](docs/zh_Hans/images/index/image.png)
![Alt text](docs/zh_Hans/images/index/image.png)
当配置好凭据后,就可以通过 DifyRuntime 的外部接口直接获取到对应供应商所需要的**Schema**(凭据表单规则),从而在可以在不修改前端逻辑的情况下,提供新的供应商/模型的支持。
@ -76,14 +77,17 @@ Model Runtime 分三层:
## 下一步
### [增加新的供应商配置 👈🏻](./docs/zh_Hans/provider_scale_out.md)
当添加后,这里将会出现一个新的供应商
![Alt text](docs/zh_Hans/images/index/image-1.png)
### [为已存在的供应商新增模型 👈🏻](./docs/zh_Hans/provider_scale_out.md#增加模型)
### [为已存在的供应商新增模型 👈🏻](./docs/zh_Hans/provider_scale_out.md#%E5%A2%9E%E5%8A%A0%E6%A8%A1%E5%9E%8B)
当添加后,对应供应商的模型列表中将会出现一个新的预定义模型供用户选择,如 GPT-3.5 GPT-4 ChatGLM3-6b 等,而对于支持自定义模型的供应商,则不需要新增模型。
![Alt text](docs/zh_Hans/images/index/image-2.png)
### [接口的具体实现 👈🏻](./docs/zh_Hans/interfaces.md)
你可以在这里找到你想要查看的接口的具体实现,以及接口的参数和返回值的具体含义。

View File

@ -56,7 +56,6 @@ provider_credential_schema:
credential_form_schemas:
```
Then, we need to determine what credentials are required to define a model in Xinference.
- Since it supports three different types of models, we need to specify the model_type to denote the model type. Here is how we can define it:
@ -191,7 +190,6 @@ def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[Pr
"""
```
Sometimes, you might not want to return 0 directly. In such cases, you can use `self._get_num_tokens_by_gpt2(text: str)` to get pre-computed tokens and ensure the environment variable `PLUGIN_BASED_TOKEN_COUNTING_ENABLED` is set to `true`. This method is provided by the `AIModel` base class, and it uses GPT2's Tokenizer for calculation. However, it should be noted that this is only a substitute and may not be fully accurate.
- Model Credentials Validation

View File

@ -35,12 +35,11 @@ All models need to uniformly implement the following 2 methods:
Similar to provider credential verification, this step involves verification for an individual model.
```python
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
@ -77,12 +76,12 @@ All models need to uniformly implement the following 2 methods:
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
```
You can refer to OpenAI's `_invoke_error_mapping` for an example.
You can refer to OpenAI's `_invoke_error_mapping` for an example.
### LLM
@ -92,7 +91,6 @@ Inherit the `__base.large_language_model.LargeLanguageModel` base class and impl
Implement the core method for LLM invocation, which can support both streaming and synchronous returns.
```python
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
@ -101,7 +99,7 @@ Inherit the `__base.large_language_model.LargeLanguageModel` base class and impl
-> Union[LLMResult, Generator]:
"""
Invoke large language model
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
@ -122,7 +120,7 @@ Inherit the `__base.large_language_model.LargeLanguageModel` base class and impl
The parameters of credential information are defined by either the `provider_credential_schema` or `model_credential_schema` in the provider's YAML configuration file. Inputs such as `api_key` are included.
- `prompt_messages` (array[[PromptMessage](#PromptMessage)]) List of prompts
- `prompt_messages` (array\[[PromptMessage](#PromptMessage)\]) List of prompts
If the model is of the `Completion` type, the list only needs to include one [UserPromptMessage](#UserPromptMessage) element;
@ -132,7 +130,7 @@ Inherit the `__base.large_language_model.LargeLanguageModel` base class and impl
The model parameters are defined by the `parameter_rules` in the model's YAML configuration.
- `tools` (array[[PromptMessageTool](#PromptMessageTool)]) [optional] List of tools, equivalent to the `function` in `function calling`.
- `tools` (array\[[PromptMessageTool](#PromptMessageTool)\]) [optional] List of tools, equivalent to the `function` in `function calling`.
That is, the tool list for tool calling.
@ -142,7 +140,7 @@ Inherit the `__base.large_language_model.LargeLanguageModel` base class and impl
- `stream` (bool) Whether to output in a streaming manner, default is True
Streaming output returns Generator[[LLMResultChunk](#LLMResultChunk)], non-streaming output returns [LLMResult](#LLMResult).
Streaming output returns Generator\[[LLMResultChunk](#LLMResultChunk)\], non-streaming output returns [LLMResult](#LLMResult).
- `user` (string) [optional] Unique identifier of the user
@ -150,7 +148,7 @@ Inherit the `__base.large_language_model.LargeLanguageModel` base class and impl
- Returns
Streaming output returns Generator[[LLMResultChunk](#LLMResultChunk)], non-streaming output returns [LLMResult](#LLMResult).
Streaming output returns Generator\[[LLMResultChunk](#LLMResultChunk)\], non-streaming output returns [LLMResult](#LLMResult).
- Pre-calculating Input Tokens
@ -187,7 +185,6 @@ Inherit the `__base.large_language_model.LargeLanguageModel` base class and impl
When the provider supports adding custom LLMs, this method can be implemented to allow custom models to fetch model schema. The default return null.
### TextEmbedding
Inherit the `__base.text_embedding_model.TextEmbeddingModel` base class and implement the following interfaces:
@ -200,7 +197,7 @@ Inherit the `__base.text_embedding_model.TextEmbeddingModel` base class and impl
-> TextEmbeddingResult:
"""
Invoke large language model
:param model: model name
:param credentials: model credentials
:param texts: texts to embed
@ -256,7 +253,7 @@ Inherit the `__base.rerank_model.RerankModel` base class and implement the follo
-> RerankResult:
"""
Invoke rerank model
:param model: model name
:param credentials: model credentials
:param query: search query
@ -302,7 +299,7 @@ Inherit the `__base.speech2text_model.Speech2TextModel` base class and implement
def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str:
"""
Invoke large language model
:param model: model name
:param credentials: model credentials
:param file: audio file
@ -339,7 +336,7 @@ Inherit the `__base.text2speech_model.Text2SpeechModel` base class and implement
def _invoke(self, model: str, credentials: dict, content_text: str, streaming: bool, user: Optional[str] = None):
"""
Invoke large language model
:param model: model name
:param credentials: model credentials
:param content_text: text content to be translated
@ -381,7 +378,7 @@ Inherit the `__base.moderation_model.ModerationModel` base class and implement t
-> bool:
"""
Invoke large language model
:param model: model name
:param credentials: model credentials
:param text: text to moderate
@ -408,11 +405,9 @@ Inherit the `__base.moderation_model.ModerationModel` base class and implement t
False indicates that the input text is safe, True indicates otherwise.
## Entities
### PromptMessageRole
### PromptMessageRole
Message role
@ -583,7 +578,7 @@ class PromptMessageTool(BaseModel):
parameters: dict
```
---
______________________________________________________________________
### LLMResult
@ -650,7 +645,7 @@ class LLMUsage(ModelUsage):
latency: float # Request latency (s)
```
---
______________________________________________________________________
### TextEmbeddingResult
@ -680,7 +675,7 @@ class EmbeddingUsage(ModelUsage):
latency: float # Request latency (s)
```
---
______________________________________________________________________
### RerankResult

View File

@ -153,8 +153,11 @@ Runtime Errors:
- `InvokeConnectionError` Connection error
- `InvokeServerUnavailableError` Service provider unavailable
- `InvokeRateLimitError` Rate limit reached
- `InvokeAuthorizationError` Authorization failed
- `InvokeBadRequestError` Parameter error
```python

View File

@ -63,6 +63,7 @@ You can also refer to the YAML configuration information under other provider di
### Implementing Provider Code
Providers need to inherit the `__base.model_provider.ModelProvider` base class and implement the `validate_provider_credentials` method for unified provider credential verification. For reference, see [AnthropicProvider](https://github.com/langgenius/dify-runtime/blob/main/lib/model_providers/anthropic/anthropic.py).
> If the provider is the type of `customizable-model`, there is no need to implement the `validate_provider_credentials` method.
```python
@ -80,7 +81,7 @@ def validate_provider_credentials(self, credentials: dict) -> None:
Of course, you can also preliminarily reserve the implementation of `validate_provider_credentials` and directly reuse it after the model credential verification method is implemented.
---
______________________________________________________________________
### Adding Models
@ -166,7 +167,7 @@ In `llm.py`, create an Anthropic LLM class, which we name `AnthropicLargeLanguag
-> Union[LLMResult, Generator]:
"""
Invoke large language model
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
@ -205,7 +206,7 @@ In `llm.py`, create an Anthropic LLM class, which we name `AnthropicLargeLanguag
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
@ -232,7 +233,7 @@ In `llm.py`, create an Anthropic LLM class, which we name `AnthropicLargeLanguag
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
```

View File

@ -28,8 +28,8 @@
- `url` (object) help link, i18n
- `zh_Hans` (string) [optional] Chinese link
- `en_US` (string) English link
- `supported_model_types` (array[[ModelType](#ModelType)]) Supported model types
- `configurate_methods` (array[[ConfigurateMethod](#ConfigurateMethod)]) Configuration methods
- `supported_model_types` (array\[[ModelType](#ModelType)\]) Supported model types
- `configurate_methods` (array\[[ConfigurateMethod](#ConfigurateMethod)\]) Configuration methods
- `provider_credential_schema` ([ProviderCredentialSchema](#ProviderCredentialSchema)) Provider credential specification
- `model_credential_schema` ([ModelCredentialSchema](#ModelCredentialSchema)) Model credential specification
@ -40,23 +40,23 @@
- `zh_Hans` (string) [optional] Chinese label name
- `en_US` (string) English label name
- `model_type` ([ModelType](#ModelType)) Model type
- `features` (array[[ModelFeature](#ModelFeature)]) [optional] Supported feature list
- `features` (array\[[ModelFeature](#ModelFeature)\]) [optional] Supported feature list
- `model_properties` (object) Model properties
- `mode` ([LLMMode](#LLMMode)) Mode (available for model type `llm`)
- `context_size` (int) Context size (available for model types `llm`, `text-embedding`)
- `max_chunks` (int) Maximum number of chunks (available for model types `text-embedding`, `moderation`)
- `file_upload_limit` (int) Maximum file upload limit, in MB (available for model type `speech2text`)
- `supported_file_extensions` (string) Supported file extension formats, e.g., mp3, mp4 (available for model type `speech2text`)
- `default_voice` (string) default voice, e.g.alloy,echo,fable,onyx,nova,shimmeravailable for model type `tts`
- `voices` (list) List of available voice.available for model type `tts`
- `mode` (string) voice model.available for model type `tts`
- `name` (string) voice model display name.available for model type `tts`
- `language` (string) the voice model supports languages.available for model type `tts`
- `word_limit` (int) Single conversion word limit, paragraph-wise by defaultavailable for model type `tts`
- `audio_type` (string) Support audio file extension format, e.g.mp3,wavavailable for model type `tts`
- `max_workers` (int) Number of concurrent workers supporting text and audio conversionavailable for model type`tts`
- `default_voice` (string) default voice, e.g. alloy, echo, fable, onyx, nova, shimmer (available for model type `tts`)
- `voices` (list) List of available voices (available for model type `tts`)
- `mode` (string) voice model (available for model type `tts`)
- `name` (string) voice model display name (available for model type `tts`)
- `language` (string) the languages the voice model supports (available for model type `tts`)
- `word_limit` (int) Single conversion word limit, paragraph-wise by default (available for model type `tts`)
- `audio_type` (string) Supported audio file extension format, e.g. mp3, wav (available for model type `tts`)
- `max_workers` (int) Number of concurrent workers supporting text and audio conversion (available for model type `tts`)
- `max_characters_per_chunk` (int) Maximum characters per chunk (available for model type `moderation`)
- `parameter_rules` (array[[ParameterRule](#ParameterRule)]) [optional] Model invocation parameter rules
- `parameter_rules` (array\[[ParameterRule](#ParameterRule)\]) [optional] Model invocation parameter rules
- `pricing` ([PriceConfig](#PriceConfig)) [optional] Pricing information
- `deprecated` (bool) Whether deprecated. If deprecated, the model will no longer be displayed in the list, but those already configured can continue to be used. Default False.
@ -74,6 +74,7 @@
- `predefined-model` Predefined model
Indicates that users can use the predefined models under the provider by configuring the unified provider credentials.
- `customizable-model` Customizable model
Users need to add credential configuration for each model.
@ -103,6 +104,7 @@
### ParameterRule
- `name` (string) Actual model invocation parameter name
- `use_template` (string) [optional] Using template
By default, 5 variable content configuration templates are preset:
@ -112,7 +114,7 @@
- `frequency_penalty`
- `presence_penalty`
- `max_tokens`
In use_template, you can directly set the template variable name, which will use the default configuration in entities.defaults.PARAMETER_RULE_TEMPLATE
No need to set any parameters other than `name` and `use_template`. If additional configuration parameters are set, they will override the default configuration.
Refer to `openai/llm/gpt-3.5-turbo.yaml`.
@ -155,7 +157,7 @@
### ProviderCredentialSchema
- `credential_form_schemas` (array[[CredentialFormSchema](#CredentialFormSchema)]) Credential form standard
- `credential_form_schemas` (array\[[CredentialFormSchema](#CredentialFormSchema)\]) Credential form standard
### ModelCredentialSchema
@ -166,7 +168,7 @@
- `placeholder` (object) Model prompt content
- `en_US`(string) English
- `zh_Hans`(string) [optional] Chinese
- `credential_form_schemas` (array[[CredentialFormSchema](#CredentialFormSchema)]) Credential form standard
- `credential_form_schemas` (array\[[CredentialFormSchema](#CredentialFormSchema)\]) Credential form standard
### CredentialFormSchema
@ -177,12 +179,12 @@
- `type` ([FormType](#FormType)) Form item type
- `required` (bool) Whether required
- `default`(string) Default value
- `options` (array[[FormOption](#FormOption)]) Specific property of form items of type `select` or `radio`, defining dropdown content
- `options` (array\[[FormOption](#FormOption)\]) Specific property of form items of type `select` or `radio`, defining dropdown content
- `placeholder`(object) Specific property of form items of type `text-input`, placeholder content
- `en_US`(string) English
- `zh_Hans` (string) [optional] Chinese
- `max_length` (int) Specific property of form items of type `text-input`, defining maximum input length, 0 for no limit.
- `show_on` (array[[FormShowOnObject](#FormShowOnObject)]) Displayed when other form item values meet certain conditions, displayed always if empty.
- `show_on` (array\[[FormShowOnObject](#FormShowOnObject)\]) Displayed when other form item values meet certain conditions, displayed always if empty.
### FormType
@ -198,7 +200,7 @@
- `en_US`(string) English
- `zh_Hans`(string) [optional] Chinese
- `value` (string) Dropdown option value
- `show_on` (array[[FormShowOnObject](#FormShowOnObject)]) Displayed when other form item values meet certain conditions, displayed always if empty.
- `show_on` (array\[[FormShowOnObject](#FormShowOnObject)\]) Displayed when other form item values meet certain conditions, displayed always if empty.
### FormShowOnObject

View File

@ -10,7 +10,6 @@
![Alt text](images/index/image-3.png)
在前文中,我们已经知道了供应商无需实现`validate_provider_credential`,Runtime 会自行根据用户在此选择的模型类型和模型名称调用对应的模型层的`validate_credentials`来进行验证。
### 编写供应商 yaml
@ -55,6 +54,7 @@ provider_credential_schema:
随后,我们需要思考在 Xinference 中定义一个模型需要哪些凭据
- 它支持三种不同的模型,因此,我们需要有`model_type`来指定这个模型的类型,它有三种类型,所以我们这么编写
```yaml
provider_credential_schema:
credential_form_schemas:
@ -76,7 +76,9 @@ provider_credential_schema:
label:
en_US: Rerank
```
- 每一个模型都有自己的名称`model_name`,因此需要在这里定义
```yaml
- variable: model_name
type: text-input
@ -88,7 +90,9 @@ provider_credential_schema:
zh_Hans: 填写模型名称
en_US: Input model name
```
- 填写 Xinference 本地部署的地址
```yaml
- variable: server_url
label:
@ -100,7 +104,9 @@ provider_credential_schema:
zh_Hans: 在此输入 Xinference 的服务器地址,如 https://example.com/xxx
en_US: Enter the url of your Xinference, for example https://example.com/xxx
```
- 每个模型都有唯一的 model_uid因此需要在这里定义
```yaml
- variable: model_uid
label:
@ -112,6 +118,7 @@ provider_credential_schema:
zh_Hans: 在此输入您的 Model UID
en_US: Enter the model uid
```
现在,我们就完成了供应商的基础定义。
### 编写模型代码
@ -132,7 +139,7 @@ provider_credential_schema:
-> Union[LLMResult, Generator]:
"""
Invoke large language model
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
@ -189,7 +196,7 @@ provider_credential_schema:
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
@ -197,78 +204,78 @@ provider_credential_schema:
```
- 模型参数 Schema
与自定义类型不同,由于没有在 yaml 文件中定义一个模型支持哪些参数,因此,我们需要动态生成模型参数的 Schema。
如 Xinference 支持`max_tokens` `temperature` `top_p` 这三个模型参数。
但是有的供应商根据不同的模型支持不同的参数,如供应商`OpenLLM`支持`top_k`,但是并不是这个供应商提供的所有模型都支持`top_k`,我们这里举例 A 模型支持`top_k`B 模型不支持`top_k`,那么我们需要在这里动态生成模型参数的 Schema如下所示
```python
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
"""
used to define customizable model schema
"""
rules = [
ParameterRule(
name='temperature', type=ParameterType.FLOAT,
use_template='temperature',
label=I18nObject(
zh_Hans='温度', en_US='Temperature'
)
),
ParameterRule(
name='top_p', type=ParameterType.FLOAT,
use_template='top_p',
label=I18nObject(
zh_Hans='Top P', en_US='Top P'
)
),
ParameterRule(
name='max_tokens', type=ParameterType.INT,
use_template='max_tokens',
min=1,
default=512,
label=I18nObject(
zh_Hans='最大生成长度', en_US='Max Tokens'
)
)
]
# if model is A, add top_k to rules
if model == 'A':
rules.append(
ParameterRule(
name='top_k', type=ParameterType.INT,
use_template='top_k',
min=1,
default=50,
label=I18nObject(
zh_Hans='Top K', en_US='Top K'
)
)
)
```python
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
"""
used to define customizable model schema
"""
rules = [
ParameterRule(
name='temperature', type=ParameterType.FLOAT,
use_template='temperature',
label=I18nObject(
zh_Hans='温度', en_US='Temperature'
)
),
ParameterRule(
name='top_p', type=ParameterType.FLOAT,
use_template='top_p',
label=I18nObject(
zh_Hans='Top P', en_US='Top P'
)
),
ParameterRule(
name='max_tokens', type=ParameterType.INT,
use_template='max_tokens',
min=1,
default=512,
label=I18nObject(
zh_Hans='最大生成长度', en_US='Max Tokens'
)
)
]
"""
some NOT IMPORTANT code here
"""
# if model is A, add top_k to rules
if model == 'A':
rules.append(
ParameterRule(
name='top_k', type=ParameterType.INT,
use_template='top_k',
min=1,
default=50,
label=I18nObject(
zh_Hans='Top K', en_US='Top K'
)
)
)
entity = AIModelEntity(
model=model,
label=I18nObject(
en_US=model
),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=model_type,
model_properties={
ModelPropertyKey.MODE: ModelType.LLM,
},
parameter_rules=rules
)
"""
some NOT IMPORTANT code here
"""
entity = AIModelEntity(
model=model,
label=I18nObject(
en_US=model
),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=model_type,
model_properties={
ModelPropertyKey.MODE: ModelType.LLM,
},
parameter_rules=rules
)
return entity
```
return entity
```
- 调用异常错误映射表
当模型调用异常时需要映射到 Runtime 指定的 `InvokeError` 类型,方便 Dify 针对不同错误做不同后续处理。
@ -278,7 +285,7 @@ provider_credential_schema:
- `InvokeConnectionError` 调用连接错误
- `InvokeServerUnavailableError ` 调用服务方不可用
- `InvokeRateLimitError ` 调用达到限额
- `InvokeAuthorizationError` 调用鉴权失败
- `InvokeAuthorizationError` 调用鉴权失败
- `InvokeBadRequestError ` 调用传参有误
```python
@ -289,7 +296,7 @@ provider_credential_schema:
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
```

View File

@ -49,7 +49,7 @@ class XinferenceProvider(Provider):
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
@ -75,7 +75,7 @@ class XinferenceProvider(Provider):
- `InvokeConnectionError` 调用连接错误
- `InvokeServerUnavailableError ` 调用服务方不可用
- `InvokeRateLimitError ` 调用达到限额
- `InvokeAuthorizationError` 调用鉴权失败
- `InvokeAuthorizationError` 调用鉴权失败
- `InvokeBadRequestError ` 调用传参有误
```python
@ -86,36 +86,36 @@ class XinferenceProvider(Provider):
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
```
也可以直接抛出对应 Errors并做如下定义这样在之后的调用中可以直接抛出`InvokeConnectionError`等异常。
```python
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
return {
InvokeConnectionError: [
InvokeConnectionError
],
InvokeServerUnavailableError: [
InvokeServerUnavailableError
],
InvokeRateLimitError: [
InvokeRateLimitError
],
InvokeAuthorizationError: [
InvokeAuthorizationError
],
InvokeBadRequestError: [
InvokeBadRequestError
],
}
```
可参考 OpenAI `_invoke_error_mapping`。
```python
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
return {
InvokeConnectionError: [
InvokeConnectionError
],
InvokeServerUnavailableError: [
InvokeServerUnavailableError
],
InvokeRateLimitError: [
InvokeRateLimitError
],
InvokeAuthorizationError: [
InvokeAuthorizationError
],
InvokeBadRequestError: [
InvokeBadRequestError
],
}
```
可参考 OpenAI `_invoke_error_mapping`。
### LLM
@ -133,7 +133,7 @@ class XinferenceProvider(Provider):
-> Union[LLMResult, Generator]:
"""
Invoke large language model
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
@ -151,38 +151,38 @@ class XinferenceProvider(Provider):
- `model` (string) 模型名称
- `credentials` (object) 凭据信息
凭据信息的参数由供应商 YAML 配置文件的 `provider_credential_schema` 或 `model_credential_schema` 定义,传入如:`api_key` 等。
- `prompt_messages` (array[[PromptMessage](#PromptMessage)]) Prompt 列表
- `prompt_messages` (array\[[PromptMessage](#PromptMessage)\]) Prompt 列表
若模型为 `Completion` 类型,则列表只需要传入一个 [UserPromptMessage](#UserPromptMessage) 元素即可;
若模型为 `Chat` 类型,需要根据消息不同传入 [SystemPromptMessage](#SystemPromptMessage), [UserPromptMessage](#UserPromptMessage), [AssistantPromptMessage](#AssistantPromptMessage), [ToolPromptMessage](#ToolPromptMessage) 元素列表
- `model_parameters` (object) 模型参数
模型参数由模型 YAML 配置的 `parameter_rules` 定义。
- `tools` (array[[PromptMessageTool](#PromptMessageTool)]) [optional] 工具列表,等同于 `function calling` 中的 `function`。
- `tools` (array\[[PromptMessageTool](#PromptMessageTool)\]) [optional] 工具列表,等同于 `function calling` 中的 `function`。
即传入 tool calling 的工具列表。
- `stop` (array[string]) [optional] 停止序列
模型返回将在停止序列定义的字符串之前停止输出。
- `stream` (bool) 是否流式输出,默认 True
流式输出返回 Generator[[LLMResultChunk](#LLMResultChunk)],非流式输出返回 [LLMResult](#LLMResult)。
流式输出返回 Generator\[[LLMResultChunk](#LLMResultChunk)\],非流式输出返回 [LLMResult](#LLMResult)。
- `user` (string) [optional] 用户的唯一标识符
可以帮助供应商监控和检测滥用行为。
- 返回
流式输出返回 Generator[[LLMResultChunk](#LLMResultChunk)],非流式输出返回 [LLMResult](#LLMResult)。
流式输出返回 Generator\[[LLMResultChunk](#LLMResultChunk)\],非流式输出返回 [LLMResult](#LLMResult)。
- 预计算输入 tokens
@ -236,7 +236,7 @@ class XinferenceProvider(Provider):
-> TextEmbeddingResult:
"""
Invoke large language model
:param model: model name
:param credentials: model credentials
:param texts: texts to embed
@ -294,7 +294,7 @@ class XinferenceProvider(Provider):
-> RerankResult:
"""
Invoke rerank model
:param model: model name
:param credentials: model credentials
:param query: search query
@ -342,7 +342,7 @@ class XinferenceProvider(Provider):
-> str:
"""
Invoke large language model
:param model: model name
:param credentials: model credentials
:param file: audio file
@ -379,7 +379,7 @@ class XinferenceProvider(Provider):
def _invoke(self, model: str, credentials: dict, content_text: str, streaming: bool, user: Optional[str] = None):
"""
Invoke large language model
:param model: model name
:param credentials: model credentials
:param content_text: text content to be translated
@ -421,7 +421,7 @@ class XinferenceProvider(Provider):
-> bool:
"""
Invoke large language model
:param model: model name
:param credentials: model credentials
:param text: text to moderate
@ -448,11 +448,9 @@ class XinferenceProvider(Provider):
False 代表传入的文本安全,True 则反之。
## 实体
### PromptMessageRole
### PromptMessageRole
消息角色
@ -623,7 +621,7 @@ class PromptMessageTool(BaseModel):
parameters: dict # 工具参数 dict
```
---
______________________________________________________________________
### LLMResult
@ -690,7 +688,7 @@ class LLMUsage(ModelUsage):
latency: float # 请求耗时 (s)
```
---
______________________________________________________________________
### TextEmbeddingResult
@ -720,7 +718,7 @@ class EmbeddingUsage(ModelUsage):
latency: float # 请求耗时 (s)
```
---
______________________________________________________________________
### RerankResult

View File

@ -62,7 +62,7 @@ pricing: # 价格信息
建议将所有模型配置都准备完毕后再开始模型代码的实现。
同样,也可以参考 `model_providers` 目录下其他供应商对应模型类型目录下的 YAML 配置信息,完整的 YAML 规则见:[Schema](schema.md#aimodelentity)。
同样,也可以参考 `model_providers` 目录下其他供应商对应模型类型目录下的 YAML 配置信息,完整的 YAML 规则见:[Schema](schema.md#aimodelentity)。
### 实现模型调用代码
@ -82,7 +82,7 @@ pricing: # 价格信息
-> Union[LLMResult, Generator]:
"""
Invoke large language model
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
@ -137,7 +137,7 @@ pricing: # 价格信息
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
@ -153,7 +153,7 @@ pricing: # 价格信息
- `InvokeConnectionError` 调用连接错误
- `InvokeServerUnavailableError ` 调用服务方不可用
- `InvokeRateLimitError ` 调用达到限额
- `InvokeAuthorizationError` 调用鉴权失败
- `InvokeAuthorizationError` 调用鉴权失败
- `InvokeBadRequestError ` 调用传参有误
```python
@ -164,7 +164,7 @@ pricing: # 价格信息
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
```

View File

@ -5,7 +5,7 @@
- `predefined-model ` 预定义模型
表示用户只需要配置统一的供应商凭据即可使用供应商下的预定义模型。
- `customizable-model` 自定义模型
用户需要新增每个模型的凭据配置,如 Xinference它同时支持 LLM 和 Text Embedding但是每个模型都有唯一的**model_uid**,如果想要将两者同时接入,就需要为每个模型配置一个**model_uid**。
@ -23,9 +23,11 @@
### 介绍
#### 名词解释
- `module`: 一个`module`即为一个 Python Package或者通俗一点称为一个文件夹里面包含了一个`__init__.py`文件,以及其他的`.py`文件。
- `module`: 一个`module`即为一个 Python Package或者通俗一点称为一个文件夹里面包含了一个`__init__.py`文件,以及其他的`.py`文件。
#### 步骤
新增一个供应商主要分为几步,这里简单列出,帮助大家有一个大概的认识,具体的步骤会在下面详细介绍。
- 创建供应商 yaml 文件,根据[ProviderSchema](./schema.md#provider)编写
@ -117,7 +119,7 @@ model_credential_schema:
en_US: Enter your API Base
```
也可以参考 `model_providers` 目录下其他供应商目录下的 YAML 配置信息,完整的 YAML 规则见:[Schema](schema.md#provider)。
也可以参考 `model_providers` 目录下其他供应商目录下的 YAML 配置信息,完整的 YAML 规则见:[Schema](schema.md#provider)。
#### 实现供应商代码
@ -155,12 +157,14 @@ def validate_provider_credentials(self, credentials: dict) -> None:
#### 增加模型
#### [增加预定义模型 👈🏻](./predefined_model_scale_out.md)
对于预定义模型,我们可以通过简单定义一个 yaml并通过实现调用代码来接入。
#### [增加自定义模型 👈🏻](./customizable_model_scale_out.md)
对于自定义模型,我们只需要实现调用代码即可接入,但是它需要处理的参数可能会更加复杂。
---
______________________________________________________________________
### 测试

View File

@ -16,9 +16,9 @@
- `zh_Hans` (string) [optional] 中文描述
- `en_US` (string) 英文描述
- `icon_small` (string) [optional] 供应商小 ICON存储在对应供应商实现目录下的 `_assets` 目录,中英文策略同 `label`
- `zh_Hans` (string) [optional] 中文 ICON
- `zh_Hans` (string) [optional] 中文 ICON
- `en_US` (string) 英文 ICON
- `icon_large` (string) [optional] 供应商大 ICON存储在对应供应商实现目录下的 _assets 目录,中英文策略同 label
- `icon_large` (string) [optional] 供应商大 ICON存储在对应供应商实现目录下的 \_assets 目录,中英文策略同 label
- `zh_Hans `(string) [optional] 中文 ICON
- `en_US` (string) 英文 ICON
- `background` (string) [optional] 背景颜色色值,例:#FFFFFF,为空则展示前端默认色值。
@ -29,8 +29,8 @@
- `url` (object) 帮助链接i18n
- `zh_Hans` (string) [optional] 中文链接
- `en_US` (string) 英文链接
- `supported_model_types` (array[[ModelType](#ModelType)]) 支持的模型类型
- `configurate_methods` (array[[ConfigurateMethod](#ConfigurateMethod)]) 配置方式
- `supported_model_types` (array\[[ModelType](#ModelType)\]) 支持的模型类型
- `configurate_methods` (array\[[ConfigurateMethod](#ConfigurateMethod)\]) 配置方式
- `provider_credential_schema` ([ProviderCredentialSchema](#ProviderCredentialSchema)) 供应商凭据规格
- `model_credential_schema` ([ModelCredentialSchema](#ModelCredentialSchema)) 模型凭据规格
@ -41,23 +41,23 @@
- `zh_Hans `(string) [optional] 中文标签名
- `en_US` (string) 英文标签名
- `model_type` ([ModelType](#ModelType)) 模型类型
- `features` (array[[ModelFeature](#ModelFeature)]) [optional] 支持功能列表
- `features` (array\[[ModelFeature](#ModelFeature)\]) [optional] 支持功能列表
- `model_properties` (object) 模型属性
- `mode` ([LLMMode](#LLMMode)) 模式 (模型类型 `llm` 可用)
- `context_size` (int) 上下文大小 (模型类型 `llm` `text-embedding` 可用)
- `max_chunks` (int) 最大分块数量 (模型类型 `text-embedding ` `moderation` 可用)
- `file_upload_limit` (int) 文件最大上传限制单位MB。模型类型 `speech2text` 可用)
- `supported_file_extensions` (string) 支持文件扩展格式mp3,mp4模型类型 `speech2text` 可用)
- `default_voice` (string) 缺省音色必选alloy,echo,fable,onyx,nova,shimmer模型类型 `tts` 可用)
- `voices` (list) 可选音色列表。
- `mode` (string) 音色模型。(模型类型 `tts` 可用)
- `name` (string) 音色模型显示名称。(模型类型 `tts` 可用)
- `language` (string) 音色模型支持语言。(模型类型 `tts` 可用)
- `word_limit` (int) 单次转换字数限制,默认按段落分段(模型类型 `tts` 可用)
- `audio_type` (string) 支持音频文件扩展格式mp3,wav模型类型 `tts` 可用)
- `max_workers` (int) 支持文字音频转换并发任务数(模型类型 `tts` 可用)
- `max_characters_per_chunk` (int) 每块最大字符数 (模型类型 `moderation` 可用)
- `parameter_rules` (array[[ParameterRule](#ParameterRule)]) [optional] 模型调用参数规则
- `supported_file_extensions` (string) 支持文件扩展格式mp3,mp4模型类型 `speech2text` 可用)
- `default_voice` (string) 缺省音色必选alloy,echo,fable,onyx,nova,shimmer模型类型 `tts` 可用)
- `voices` (list) 可选音色列表。
- `mode` (string) 音色模型。(模型类型 `tts` 可用)
- `name` (string) 音色模型显示名称。(模型类型 `tts` 可用)
- `language` (string) 音色模型支持语言。(模型类型 `tts` 可用)
- `word_limit` (int) 单次转换字数限制,默认按段落分段(模型类型 `tts` 可用)
- `audio_type` (string) 支持音频文件扩展格式mp3,wav模型类型 `tts` 可用)
- `max_workers` (int) 支持文字音频转换并发任务数(模型类型 `tts` 可用)
- `max_characters_per_chunk` (int) 每块最大字符数 (模型类型 `moderation` 可用)
- `parameter_rules` (array\[[ParameterRule](#ParameterRule)\]) [optional] 模型调用参数规则
- `pricing` ([PriceConfig](#PriceConfig)) [optional] 价格信息
- `deprecated` (bool) 是否废弃。若废弃,模型列表将不再展示,但已经配置的可以继续使用,默认 False。
@ -75,6 +75,7 @@
- `predefined-model ` 预定义模型
表示用户只需要配置统一的供应商凭据即可使用供应商下的预定义模型。
- `customizable-model` 自定义模型
用户需要新增每个模型的凭据配置。
@ -106,7 +107,7 @@
- `name` (string) 调用模型实际参数名
- `use_template` (string) [optional] 使用模板
默认预置了 5 种变量内容配置模板:
- `temperature`
@ -114,7 +115,7 @@
- `frequency_penalty`
- `presence_penalty`
- `max_tokens`
可在 use_template 中直接设置模板变量名,将会使用 entities.defaults.PARAMETER_RULE_TEMPLATE 中的默认配置
不用设置除 `name` 和 `use_template` 之外的所有参数,若设置了额外的配置参数,将覆盖默认配置。
可参考 `openai/llm/gpt-3.5-turbo.yaml`
@ -157,7 +158,7 @@
### ProviderCredentialSchema
- `credential_form_schemas` (array[[CredentialFormSchema](#CredentialFormSchema)]) 凭据表单规范
- `credential_form_schemas` (array\[[CredentialFormSchema](#CredentialFormSchema)\]) 凭据表单规范
### ModelCredentialSchema
@ -168,7 +169,7 @@
- `placeholder` (object) 模型提示内容
- `en_US`(string) 英文
- `zh_Hans`(string) [optional] 中文
- `credential_form_schemas` (array[[CredentialFormSchema](#CredentialFormSchema)]) 凭据表单规范
- `credential_form_schemas` (array\[[CredentialFormSchema](#CredentialFormSchema)\]) 凭据表单规范
### CredentialFormSchema
@ -179,12 +180,12 @@
- `type` ([FormType](#FormType)) 表单项类型
- `required` (bool) 是否必填
- `default`(string) 默认值
- `options` (array[[FormOption](#FormOption)]) 表单项为 `select``radio` 专有属性,定义下拉内容
- `options` (array\[[FormOption](#FormOption)\]) 表单项为 `select``radio` 专有属性,定义下拉内容
- `placeholder`(object) 表单项为 `text-input `专有属性,表单项 PlaceHolder
- `en_US`(string) 英文
- `zh_Hans` (string) [optional] 中文
- `max_length` (int) 表单项为`text-input`专有属性定义输入最大长度0 为不限制。
- `show_on` (array[[FormShowOnObject](#FormShowOnObject)]) 当其他表单项值符合条件时显示,为空则始终显示。
- `show_on` (array\[[FormShowOnObject](#FormShowOnObject)\]) 当其他表单项值符合条件时显示,为空则始终显示。
### FormType
@ -200,7 +201,7 @@
- `en_US`(string) 英文
- `zh_Hans`(string) [optional] 中文
- `value` (string) 下拉选项值
- `show_on` (array[[FormShowOnObject](#FormShowOnObject)]) 当其他表单项值符合条件时显示,为空则始终显示。
- `show_on` (array\[[FormShowOnObject](#FormShowOnObject)\]) 当其他表单项值符合条件时显示,为空则始终显示。
### FormShowOnObject

View File

@ -98,14 +98,19 @@ class AnalyticdbVectorBySql:
try:
cur.execute(f"CREATE DATABASE {self.databaseName}")
except Exception as e:
if "already exists" in str(e):
return
raise e
if "already exists" not in str(e):
raise e
finally:
cur.close()
conn.close()
self.pool = self._create_connection_pool()
with self._get_cursor() as cur:
try:
cur.execute("CREATE EXTENSION IF NOT EXISTS zhparser;")
except Exception as e:
raise RuntimeError(
"Failed to create zhparser extension. Please ensure it is available in your AnalyticDB."
) from e
try:
cur.execute("CREATE TEXT SEARCH CONFIGURATION zh_cn (PARSER = zhparser)")
cur.execute("ALTER TEXT SEARCH CONFIGURATION zh_cn ADD MAPPING FOR n,v,a,i,e,l,x WITH simple")

View File

@ -92,17 +92,21 @@ Clickzetta supports advanced full-text search with multiple analyzers:
### Analyzer Types
1. **keyword**: No tokenization, treats the entire string as a single token
- Best for: Exact matching, IDs, codes
2. **english**: Designed for English text
1. **english**: Designed for English text
- Features: Recognizes ASCII letters and numbers, converts to lowercase
- Best for: English content
3. **chinese**: Chinese text tokenizer
1. **chinese**: Chinese text tokenizer
- Features: Recognizes Chinese and English characters, removes punctuation
- Best for: Chinese or mixed Chinese-English content
4. **unicode**: Multi-language tokenizer based on Unicode
1. **unicode**: Multi-language tokenizer based on Unicode
- Features: Recognizes text boundaries in multiple languages
- Best for: Multi-language content
@ -124,21 +128,25 @@ Clickzetta supports advanced full-text search with multiple analyzers:
### Vector Search
1. **Adjust exploration factor** for accuracy vs speed trade-off:
```sql
SET cz.vector.index.search.ef=64;
```
2. **Use appropriate distance functions**:
1. **Use appropriate distance functions**:
- `cosine_distance`: Best for normalized embeddings (e.g., from language models)
- `l2_distance`: Best for raw feature vectors
### Full-Text Search
1. **Choose the right analyzer**:
- Use `keyword` for exact matching
- Use language-specific analyzers for better tokenization
2. **Combine with vector search**:
1. **Combine with vector search**:
- Pre-filter with full-text search for better performance
- Use hybrid search for improved relevance
@ -147,27 +155,30 @@ Clickzetta supports advanced full-text search with multiple analyzers:
### Connection Issues
1. Verify all 7 required configuration parameters are set
2. Check network connectivity to Clickzetta service
3. Ensure the user has proper permissions on the schema
1. Check network connectivity to Clickzetta service
1. Ensure the user has proper permissions on the schema
### Search Performance
1. Verify vector index exists:
```sql
SHOW INDEX FROM <schema>.<table_name>;
```
2. Check if vector index is being used:
1. Check if vector index is being used:
```sql
EXPLAIN SELECT ... WHERE l2_distance(...) < threshold;
```
Look for `vector_index_search_type` in the execution plan.
### Full-Text Search Not Working
1. Verify inverted index is created
2. Check analyzer configuration matches your content language
3. Use `TOKENIZE()` function to test tokenization:
1. Check analyzer configuration matches your content language
1. Use `TOKENIZE()` function to test tokenization:
```sql
SELECT TOKENIZE('your text', map('analyzer', 'chinese', 'mode', 'smart'));
```
@ -175,13 +186,13 @@ Clickzetta supports advanced full-text search with multiple analyzers:
## Limitations
1. Vector operations don't support `ORDER BY` or `GROUP BY` directly on vector columns
2. Full-text search relevance scores are not provided by Clickzetta
3. Inverted index creation may fail for very large existing tables (continue without error)
4. Index naming constraints:
1. Full-text search relevance scores are not provided by Clickzetta
1. Inverted index creation may fail for very large existing tables (continue without error)
1. Index naming constraints:
- Index names must be unique within a schema
- Only one vector index can be created per column
- The implementation uses timestamps to ensure unique index names
5. A column can only have one vector index at a time
1. A column can only have one vector index at a time
## References

View File

@ -81,14 +81,11 @@ class ApiTool(Tool):
return ToolProviderType.API
def assembling_request(self, parameters: dict[str, Any]) -> dict[str, Any]:
headers = {}
if self.runtime is None:
raise ToolProviderCredentialValidationError("runtime not initialized")
headers = {}
if self.runtime is None:
raise ValueError("runtime is required")
credentials = self.runtime.credentials or {}
if "auth_type" not in credentials:
raise ToolProviderCredentialValidationError("Missing auth_type")

View File

@ -62,7 +62,7 @@ class ToolProviderApiEntity(BaseModel):
parameter.pop("input_schema", None)
# -------------
optional_fields = self.optional_field("server_url", self.server_url)
if self.type == ToolProviderType.MCP.value:
if self.type == ToolProviderType.MCP:
optional_fields.update(self.optional_field("updated_at", self.updated_at))
optional_fields.update(self.optional_field("server_identifier", self.server_identifier))
return {

View File

@ -959,7 +959,7 @@ class ToolManager:
elif provider_type == ToolProviderType.WORKFLOW:
return cls.generate_workflow_tool_icon_url(tenant_id, provider_id)
elif provider_type == ToolProviderType.PLUGIN:
provider = ToolManager.get_builtin_provider(provider_id, tenant_id)
provider = ToolManager.get_plugin_provider(provider_id, tenant_id)
if isinstance(provider, PluginToolProviderController):
try:
return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon)

View File

@ -1,17 +0,0 @@
import re
def get_image_upload_file_ids(content):
    """Extract uploaded-file IDs from markdown image links in *content*.

    Scans for links of the form ``![image](<url>)`` whose URL ends with
    ``file-preview`` or ``image-preview``, and returns the file ID segment
    (``files/<id>/...``) of each match, in order of appearance.

    :param content: markdown text to scan
    :return: list of file ID strings (empty when nothing matches)
    """
    # BUGFIX: use "https?" (optional "s"), not "http?" (optional "p"),
    # so both http:// and https:// URLs are matched.
    pattern = r"!\[image\]\((https?://.*?(file-preview|image-preview))\)"
    matches = re.findall(pattern, content)
    image_upload_file_ids = []
    # re.findall returns (url, preview_kind) tuples, one per group pair.
    for url, preview_kind in matches:
        if preview_kind == "file-preview":
            content_pattern = r"files/([^/]+)/file-preview"
        else:
            content_pattern = r"files/([^/]+)/image-preview"
        content_match = re.search(content_pattern, url)
        if content_match:
            image_upload_file_ids.append(content_match.group(1))
    return image_upload_file_ids

View File

@ -80,14 +80,14 @@ def get_url(url: str, user_agent: Optional[str] = None) -> str:
else:
content = response.text
article = extract_using_readabilipy(content)
article = extract_using_readability(content)
if not article.text:
return ""
res = FULL_TEMPLATE.format(
title=article.title,
author=article.auther,
author=article.author,
text=article.text,
)
@ -97,15 +97,15 @@ def get_url(url: str, user_agent: Optional[str] = None) -> str:
@dataclass
class Article:
title: str
auther: str
author: str
text: Sequence[dict]
def extract_using_readabilipy(html: str):
def extract_using_readability(html: str):
json_article: dict[str, Any] = simple_json_from_html_string(html, use_readability=True)
article = Article(
title=json_article.get("title") or "",
auther=json_article.get("byline") or "",
author=json_article.get("byline") or "",
text=json_article.get("plain_text") or [],
)
@ -113,7 +113,7 @@ def extract_using_readabilipy(html: str):
def get_image_upload_file_ids(content):
pattern = r"!\[image\]\((http?://.*?(file-preview|image-preview))\)"
pattern = r"!\[image\]\((https?://.*?(file-preview|image-preview))\)"
matches = re.findall(pattern, content)
image_upload_file_ids = []
for match in matches:

View File

@ -203,9 +203,6 @@ class WorkflowToolProviderController(ToolProviderController):
raise ValueError("app not found")
app = db_providers.app
if not app:
raise ValueError("can not read app of workflow")
self.tools = [self._get_db_provider_tool(db_providers, app)]
return self.tools

View File

@ -123,7 +123,7 @@ class BillingService:
return BillingService._send_request("GET", "/education/verify", params=params)
@classmethod
def is_active(cls, account_id: str):
def status(cls, account_id: str):
params = {"account_id": account_id}
return BillingService._send_request("GET", "/education/status", params=params)

View File

@ -294,6 +294,11 @@ class DatasetService:
dataset: Optional[Dataset] = db.session.query(Dataset).filter_by(id=dataset_id).first()
return dataset
@staticmethod
def check_doc_form(dataset: Dataset, doc_form: str):
if dataset.doc_form and doc_form != dataset.doc_form:
raise ValueError("doc_form is different from the dataset doc_form.")
@staticmethod
def check_dataset_model_setting(dataset):
if dataset.indexing_technique == "high_quality":
@ -1265,6 +1270,8 @@ class DocumentService:
dataset_process_rule: Optional[DatasetProcessRule] = None,
created_from: str = "web",
):
# check doc_form
DatasetService.check_doc_form(dataset, knowledge_config.doc_form)
# check document limit
features = FeatureService.get_features(current_user.current_tenant_id)

View File

@ -1,18 +0,0 @@
from pydantic import BaseModel
from tasks.mail_enterprise_task import send_enterprise_email_task
class DifyMail(BaseModel):
    """Payload describing one outgoing enterprise email."""

    # recipient email addresses
    to: list[str]
    # subject line
    subject: str
    # message body (may contain template placeholders)
    body: str
    # placeholder -> value substitutions applied to the body template;
    # empty by default
    substitutions: dict[str, str] = {}
class EnterpriseMailService:
    """Service facade for sending enterprise emails asynchronously."""

    @classmethod
    def send_mail(cls, mail: DifyMail):
        """Queue *mail* for delivery via the enterprise email task.

        Non-blocking: ``.delay`` enqueues the task and returns immediately;
        actual delivery happens in the worker that consumes the mail queue.

        :param mail: the email payload to send
        """
        send_enterprise_email_task.delay(
            to=mail.to, subject=mail.subject, body=mail.body, substitutions=mail.substitutions
        )

View File

@ -5,7 +5,7 @@ import click
from celery import shared_task # type: ignore
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.tools.utils.rag_web_reader import get_image_upload_file_ids
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.dataset import (

View File

@ -6,7 +6,7 @@ import click
from celery import shared_task # type: ignore
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.tools.utils.rag_web_reader import get_image_upload_file_ids
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.dataset import Dataset, DatasetMetadataBinding, DocumentSegment

View File

@ -11,7 +11,7 @@ from libs.email_i18n import get_email_i18n_service
@shared_task(queue="mail")
def send_enterprise_email_task(to: list[str], subject: str, body: str, substitutions: Mapping[str, str]):
def send_inner_email_task(to: list[str], subject: str, body: str, substitutions: Mapping[str, str]):
if not mail.is_inited():
return

View File

@ -0,0 +1,620 @@
from unittest.mock import patch
import pytest
from faker import Faker
from models.model import EndUser, Message
from models.web import SavedMessage
from services.app_service import AppService
from services.saved_message_service import SavedMessageService
class TestSavedMessageService:
"""Integration tests for SavedMessageService using testcontainers."""
@pytest.fixture
def mock_external_service_dependencies(self):
"""Mock setup for external service dependencies."""
with (
patch("services.account_service.FeatureService") as mock_account_feature_service,
patch("services.app_service.ModelManager") as mock_model_manager,
patch("services.saved_message_service.MessageService") as mock_message_service,
):
# Setup default mock returns
mock_account_feature_service.get_system_features.return_value.is_allow_register = True
# Mock ModelManager for app creation
mock_model_instance = mock_model_manager.return_value
mock_model_instance.get_default_model_instance.return_value = None
mock_model_instance.get_default_provider_model_name.return_value = ("openai", "gpt-3.5-turbo")
# Mock MessageService
mock_message_service.get_message.return_value = None
mock_message_service.pagination_by_last_id.return_value = None
yield {
"account_feature_service": mock_account_feature_service,
"model_manager": mock_model_manager,
"message_service": mock_message_service,
}
def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies):
"""
Helper method to create a test app and account for testing.
Args:
db_session_with_containers: Database session from testcontainers infrastructure
mock_external_service_dependencies: Mock dependencies
Returns:
tuple: (app, account) - Created app and account instances
"""
fake = Faker()
# Setup mocks for account creation
mock_external_service_dependencies[
"account_feature_service"
].get_system_features.return_value.is_allow_register = True
# Create account and tenant first
from services.account_service import AccountService, TenantService
account = AccountService.create_account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
password=fake.password(length=12),
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
# Create app with realistic data
app_args = {
"name": fake.company(),
"description": fake.text(max_nb_chars=100),
"mode": "chat",
"icon_type": "emoji",
"icon": "🤖",
"icon_background": "#FF6B6B",
"api_rph": 100,
"api_rpm": 10,
}
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
return app, account
def _create_test_end_user(self, db_session_with_containers, app):
"""
Helper method to create a test end user for testing.
Args:
db_session_with_containers: Database session from testcontainers infrastructure
app: App instance to associate the end user with
Returns:
EndUser: Created end user instance
"""
fake = Faker()
end_user = EndUser(
tenant_id=app.tenant_id,
app_id=app.id,
external_user_id=fake.uuid4(),
name=fake.name(),
type="normal",
session_id=fake.uuid4(),
is_anonymous=False,
)
from extensions.ext_database import db
db.session.add(end_user)
db.session.commit()
return end_user
def _create_test_message(self, db_session_with_containers, app, user):
"""
Helper method to create a test message for testing.
Args:
db_session_with_containers: Database session from testcontainers infrastructure
app: App instance to associate the message with
user: User instance (Account or EndUser) to associate the message with
Returns:
Message: Created message instance
"""
fake = Faker()
# Create a simple conversation first
from models.model import Conversation
conversation = Conversation(
app_id=app.id,
from_source="account" if hasattr(user, "current_tenant") else "end_user",
from_end_user_id=user.id if not hasattr(user, "current_tenant") else None,
from_account_id=user.id if hasattr(user, "current_tenant") else None,
name=fake.sentence(nb_words=3),
inputs={},
status="normal",
mode="chat",
)
from extensions.ext_database import db
db.session.add(conversation)
db.session.commit()
# Create message
message = Message(
app_id=app.id,
conversation_id=conversation.id,
from_source="account" if hasattr(user, "current_tenant") else "end_user",
from_end_user_id=user.id if not hasattr(user, "current_tenant") else None,
from_account_id=user.id if hasattr(user, "current_tenant") else None,
inputs={},
query=fake.sentence(nb_words=5),
message=fake.text(max_nb_chars=100),
answer=fake.text(max_nb_chars=200),
message_tokens=50,
answer_tokens=100,
message_unit_price=0.001,
answer_unit_price=0.002,
total_price=0.003,
currency="USD",
status="success",
)
db.session.add(message)
db.session.commit()
return message
def test_pagination_by_last_id_success_with_account_user(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test successful pagination by last ID with account user.
This test verifies:
- Proper pagination with account user
- Correct filtering by app_id and user
- Proper role identification for account users
- MessageService integration
"""
# Arrange: Create test data
fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create test messages
message1 = self._create_test_message(db_session_with_containers, app, account)
message2 = self._create_test_message(db_session_with_containers, app, account)
# Create saved messages
saved_message1 = SavedMessage(
app_id=app.id,
message_id=message1.id,
created_by_role="account",
created_by=account.id,
)
saved_message2 = SavedMessage(
app_id=app.id,
message_id=message2.id,
created_by_role="account",
created_by=account.id,
)
from extensions.ext_database import db
db.session.add_all([saved_message1, saved_message2])
db.session.commit()
# Mock MessageService.pagination_by_last_id return value
from libs.infinite_scroll_pagination import InfiniteScrollPagination
mock_pagination = InfiniteScrollPagination(data=[message1, message2], limit=10, has_more=False)
mock_external_service_dependencies["message_service"].pagination_by_last_id.return_value = mock_pagination
# Act: Execute the method under test
result = SavedMessageService.pagination_by_last_id(app_model=app, user=account, last_id=None, limit=10)
# Assert: Verify the expected outcomes
assert result is not None
assert result.data == [message1, message2]
assert result.limit == 10
assert result.has_more is False
# Verify MessageService was called with correct parameters
# Sort the IDs to handle database query order variations
expected_include_ids = sorted([message1.id, message2.id])
actual_call = mock_external_service_dependencies["message_service"].pagination_by_last_id.call_args
actual_include_ids = sorted(actual_call.kwargs.get("include_ids", []))
assert actual_call.kwargs["app_model"] == app
assert actual_call.kwargs["user"] == account
assert actual_call.kwargs["last_id"] is None
assert actual_call.kwargs["limit"] == 10
assert actual_include_ids == expected_include_ids
# Verify database state
db.session.refresh(saved_message1)
db.session.refresh(saved_message2)
assert saved_message1.id is not None
assert saved_message2.id is not None
assert saved_message1.created_by_role == "account"
assert saved_message2.created_by_role == "account"
def test_pagination_by_last_id_success_with_end_user(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test successful pagination by last ID with end user.

    This test verifies:
    - Proper pagination with end user
    - Correct filtering by app_id and user
    - Proper role identification for end users
    - MessageService integration
    """
    # Arrange: create an app, an end user, and two messages saved by that end user.
    # (Removed an unused `fake = Faker()` local — nothing in this test uses it.)
    app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
    end_user = self._create_test_end_user(db_session_with_containers, app)

    message1 = self._create_test_message(db_session_with_containers, app, end_user)
    message2 = self._create_test_message(db_session_with_containers, app, end_user)

    saved_message1 = SavedMessage(
        app_id=app.id,
        message_id=message1.id,
        created_by_role="end_user",
        created_by=end_user.id,
    )
    saved_message2 = SavedMessage(
        app_id=app.id,
        message_id=message2.id,
        created_by_role="end_user",
        created_by=end_user.id,
    )

    from extensions.ext_database import db

    db.session.add_all([saved_message1, saved_message2])
    db.session.commit()

    # Mock MessageService.pagination_by_last_id so the service's return value is deterministic.
    from libs.infinite_scroll_pagination import InfiniteScrollPagination

    mock_pagination = InfiniteScrollPagination(data=[message1, message2], limit=5, has_more=True)
    mock_external_service_dependencies["message_service"].pagination_by_last_id.return_value = mock_pagination

    # Act: execute the method under test.
    result = SavedMessageService.pagination_by_last_id(
        app_model=app, user=end_user, last_id="test_last_id", limit=5
    )

    # Assert: the mocked pagination result is passed through unchanged.
    assert result is not None
    assert result.data == [message1, message2]
    assert result.limit == 5
    assert result.has_more is True

    # Verify MessageService was called with correct parameters.
    # IDs are sorted to tolerate database query order variations.
    expected_include_ids = sorted([message1.id, message2.id])
    actual_call = mock_external_service_dependencies["message_service"].pagination_by_last_id.call_args
    actual_include_ids = sorted(actual_call.kwargs.get("include_ids", []))
    assert actual_call.kwargs["app_model"] == app
    assert actual_call.kwargs["user"] == end_user
    assert actual_call.kwargs["last_id"] == "test_last_id"
    assert actual_call.kwargs["limit"] == 5
    assert actual_include_ids == expected_include_ids

    # Verify database state: both SavedMessage rows persisted with the end-user role.
    db.session.refresh(saved_message1)
    db.session.refresh(saved_message2)
    assert saved_message1.id is not None
    assert saved_message2.id is not None
    assert saved_message1.created_by_role == "end_user"
    assert saved_message2.created_by_role == "end_user"
def test_save_success_with_new_message(self, db_session_with_containers, mock_external_service_dependencies):
    """
    Test successful save of a new message.

    This test verifies:
    - Proper creation of new saved message
    - Correct database state after save
    - Proper relationship establishment
    - MessageService integration for message retrieval
    """
    # Arrange: create an app/account pair and a message to save.
    # (Removed an unused `fake = Faker()` local — nothing in this test uses it.)
    app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
    message = self._create_test_message(db_session_with_containers, app, account)

    # Mock MessageService.get_message so the service can resolve the message.
    mock_external_service_dependencies["message_service"].get_message.return_value = message

    # Act: execute the method under test.
    SavedMessageService.save(app_model=app, user=account, message_id=message.id)

    # Assert: a SavedMessage row keyed by (app, message, role, user) was created.
    from extensions.ext_database import db

    saved_message = (
        db.session.query(SavedMessage)
        .where(
            SavedMessage.app_id == app.id,
            SavedMessage.message_id == message.id,
            SavedMessage.created_by_role == "account",
            SavedMessage.created_by == account.id,
        )
        .first()
    )
    assert saved_message is not None
    assert saved_message.app_id == app.id
    assert saved_message.message_id == message.id
    assert saved_message.created_by_role == "account"
    assert saved_message.created_by == account.id
    assert saved_message.created_at is not None

    # Verify MessageService.get_message was called exactly once with the expected args.
    mock_external_service_dependencies["message_service"].get_message.assert_called_once_with(
        app_model=app, user=account, message_id=message.id
    )

    # Verify the row survives a refresh, i.e. it was committed, not just pending.
    db.session.refresh(saved_message)
    assert saved_message.id is not None
def test_pagination_by_last_id_error_no_user(self, db_session_with_containers, mock_external_service_dependencies):
    """
    Test error handling when no user is provided.

    This test verifies:
    - Proper error handling for missing user
    - ValueError is raised when user is None
    - No database operations are performed
    """
    # NOTE(review): a method with this exact name is defined again later in this
    # class; under Python class-body semantics the later definition shadows this
    # one, so pytest never collects this copy. The later copy also drops the
    # global row-count assertion below, which is fragile when other tests in the
    # same session have created SavedMessage rows. Consider deleting this copy.
    # Arrange: Create test data
    fake = Faker()
    app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)

    # Act & Assert: Verify proper error handling
    with pytest.raises(ValueError) as exc_info:
        SavedMessageService.pagination_by_last_id(app_model=app, user=None, last_id=None, limit=10)
    assert "User is required" in str(exc_info.value)

    # Verify no database operations were performed
    # (global count — see the shadowing note above for why this is fragile)
    from extensions.ext_database import db

    saved_messages = db.session.query(SavedMessage).all()
    assert len(saved_messages) == 0
def test_save_error_no_user(self, db_session_with_containers, mock_external_service_dependencies):
    """
    Test error handling when saving message with no user.

    This test verifies:
    - Method returns early when user is None
    - No database operations are performed
    - No exceptions are raised
    """
    # NOTE(review): a method with this exact name is defined again later in this
    # class; the later definition shadows this one, so pytest never collects
    # this copy. The two copies are token-identical — consider deleting one.
    # Arrange: Create test data
    fake = Faker()
    app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
    message = self._create_test_message(db_session_with_containers, app, account)

    # Act: Execute the method under test with None user
    result = SavedMessageService.save(app_model=app, user=None, message_id=message.id)

    # Assert: Verify the expected outcomes
    assert result is None

    # Verify no saved message was created
    from extensions.ext_database import db

    saved_message = (
        db.session.query(SavedMessage)
        .where(
            SavedMessage.app_id == app.id,
            SavedMessage.message_id == message.id,
        )
        .first()
    )
    assert saved_message is None
def test_delete_success_existing_message(self, db_session_with_containers, mock_external_service_dependencies):
    """
    Test successful deletion of an existing saved message.

    This test verifies:
    - Proper deletion of existing saved message
    - Correct database state after deletion
    - No errors during deletion process
    """
    # NOTE(review): a method with this exact name is defined again later in this
    # class; the later definition shadows this one, so pytest never collects
    # this copy. The two copies are token-identical — consider deleting one.
    # Arrange: Create test data
    fake = Faker()
    app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
    message = self._create_test_message(db_session_with_containers, app, account)

    # Create a saved message first
    saved_message = SavedMessage(
        app_id=app.id,
        message_id=message.id,
        created_by_role="account",
        created_by=account.id,
    )

    from extensions.ext_database import db

    db.session.add(saved_message)
    db.session.commit()

    # Verify saved message exists (precondition for the deletion under test)
    assert (
        db.session.query(SavedMessage)
        .where(
            SavedMessage.app_id == app.id,
            SavedMessage.message_id == message.id,
            SavedMessage.created_by_role == "account",
            SavedMessage.created_by == account.id,
        )
        .first()
        is not None
    )

    # Act: Execute the method under test
    SavedMessageService.delete(app_model=app, user=account, message_id=message.id)

    # Assert: Verify the expected outcomes
    # Check if saved message was deleted from database
    deleted_saved_message = (
        db.session.query(SavedMessage)
        .where(
            SavedMessage.app_id == app.id,
            SavedMessage.message_id == message.id,
            SavedMessage.created_by_role == "account",
            SavedMessage.created_by == account.id,
        )
        .first()
    )
    assert deleted_saved_message is None

    # Verify database state
    db.session.commit()

    # The message should still exist, only the saved_message should be deleted
    assert db.session.query(Message).where(Message.id == message.id).first() is not None
def test_pagination_by_last_id_error_no_user(self, db_session_with_containers, mock_external_service_dependencies):
    """
    Test error handling when no user is provided.

    This test verifies:
    - Proper error handling for missing user
    - ValueError is raised when user is None

    We deliberately do not assert on the total SavedMessage row count here:
    other tests in the session may have created rows, so raising the error
    is the whole contract under test.
    """
    # Arrange: create an app; no user will be passed to the service.
    # (Removed an unused `fake = Faker()` local and a dead trailing `pass`.)
    app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)

    # Act & Assert: pagination without a user must raise a descriptive ValueError.
    with pytest.raises(ValueError) as exc_info:
        SavedMessageService.pagination_by_last_id(app_model=app, user=None, last_id=None, limit=10)
    assert "User is required" in str(exc_info.value)
def test_save_error_no_user(self, db_session_with_containers, mock_external_service_dependencies):
    """
    Test error handling when saving message with no user.

    This test verifies:
    - Method returns early when user is None
    - No database operations are performed
    - No exceptions are raised
    """
    # Arrange: create an app and a message that could be saved.
    # (Removed an unused `fake = Faker()` local — nothing in this test uses it.)
    app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
    message = self._create_test_message(db_session_with_containers, app, account)

    # Act: saving with user=None should be a silent no-op.
    result = SavedMessageService.save(app_model=app, user=None, message_id=message.id)

    # Assert: nothing is returned and no SavedMessage row was written for this message.
    assert result is None

    from extensions.ext_database import db

    saved_message = (
        db.session.query(SavedMessage)
        .where(
            SavedMessage.app_id == app.id,
            SavedMessage.message_id == message.id,
        )
        .first()
    )
    assert saved_message is None
def test_delete_success_existing_message(self, db_session_with_containers, mock_external_service_dependencies):
    """
    Test successful deletion of an existing saved message.

    This test verifies:
    - Proper deletion of existing saved message
    - Correct database state after deletion
    - The underlying Message row is untouched by the deletion
    """
    # Arrange: create an app/account, a message, and a SavedMessage row for it.
    # (Removed an unused `fake = Faker()` local — nothing in this test uses it.)
    app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
    message = self._create_test_message(db_session_with_containers, app, account)

    saved_message = SavedMessage(
        app_id=app.id,
        message_id=message.id,
        created_by_role="account",
        created_by=account.id,
    )

    from extensions.ext_database import db

    db.session.add(saved_message)
    db.session.commit()

    def _query_saved_message():
        # The exact (app, message, role, user) tuple the service is expected to delete;
        # hoisted into a helper because the same query is needed before and after the Act.
        return (
            db.session.query(SavedMessage)
            .where(
                SavedMessage.app_id == app.id,
                SavedMessage.message_id == message.id,
                SavedMessage.created_by_role == "account",
                SavedMessage.created_by == account.id,
            )
            .first()
        )

    # Precondition: the saved message exists before deletion.
    assert _query_saved_message() is not None

    # Act: execute the method under test.
    SavedMessageService.delete(app_model=app, user=account, message_id=message.id)

    # Assert: the SavedMessage row is gone...
    assert _query_saved_message() is None

    db.session.commit()

    # ...while the underlying Message itself still exists.
    assert db.session.query(Message).where(Message.id == message.id).first() is not None

View File

@ -0,0 +1,25 @@
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
def test_get_image_upload_file_ids():
    """Exercise get_image_upload_file_ids across schemes, preview suffixes, and multi-match input."""
    # Table of (markdown content, expected extracted upload-file ids).
    single_cases = [
        # https scheme with the /file-preview suffix
        ("![image](https://example.com/a/b/files/abc123/file-preview)", ["abc123"]),
        # http scheme with the /image-preview suffix
        ("![image](http://host/files/xyz789/image-preview)", ["xyz789"]),
        # malformed scheme 'htt://' must not match at all
        ("![image](htt://host/files/bad/file-preview)", []),
    ]
    for markdown, expected_ids in single_cases:
        assert get_image_upload_file_ids(markdown) == expected_ids

    # Multiple embedded images: ids are returned in document order.
    multi_image = """
some text
![image](https://h/files/id1/file-preview)
middle
![image](http://h/files/id2/image-preview)
end
"""
    assert get_image_upload_file_ids(multi_image) == ["id1", "id2"]

2
api/uv.lock generated
View File

@ -1,5 +1,5 @@
version = 1
revision = 2
revision = 3
requires-python = ">=3.11, <3.13"
resolution-markers = [
"python_full_version >= '3.12.4' and platform_python_implementation != 'PyPy' and sys_platform == 'linux'",