mirror of
https://github.com/langgenius/dify.git
synced 2026-03-19 13:47:37 +08:00
Merge branch 'main' into feat/rag-2
This commit is contained in:
@ -3,7 +3,7 @@
|
||||
## Usage
|
||||
|
||||
> [!IMPORTANT]
|
||||
>
|
||||
>
|
||||
> In the v1.3.0 release, `poetry` has been replaced with
|
||||
> [`uv`](https://docs.astral.sh/uv/) as the package manager
|
||||
> for Dify API backend service.
|
||||
@ -20,25 +20,29 @@
|
||||
cd ../api
|
||||
```
|
||||
|
||||
2. Copy `.env.example` to `.env`
|
||||
1. Copy `.env.example` to `.env`
|
||||
|
||||
```cli
|
||||
cp .env.example .env
|
||||
cp .env.example .env
|
||||
```
|
||||
3. Generate a `SECRET_KEY` in the `.env` file.
|
||||
|
||||
1. Generate a `SECRET_KEY` in the `.env` file.
|
||||
|
||||
bash for Linux
|
||||
|
||||
```bash for Linux
|
||||
sed -i "/^SECRET_KEY=/c\SECRET_KEY=$(openssl rand -base64 42)" .env
|
||||
```
|
||||
|
||||
bash for Mac
|
||||
|
||||
```bash for Mac
|
||||
secret_key=$(openssl rand -base64 42)
|
||||
sed -i '' "/^SECRET_KEY=/c\\
|
||||
SECRET_KEY=${secret_key}" .env
|
||||
```
|
||||
|
||||
4. Create environment.
|
||||
1. Create environment.
|
||||
|
||||
Dify API service uses [UV](https://docs.astral.sh/uv/) to manage dependencies.
|
||||
First, you need to add the uv package manager, if you don't have it already.
|
||||
@ -49,13 +53,13 @@
|
||||
brew install uv
|
||||
```
|
||||
|
||||
5. Install dependencies
|
||||
1. Install dependencies
|
||||
|
||||
```bash
|
||||
uv sync --dev
|
||||
```
|
||||
|
||||
6. Run migrate
|
||||
1. Run migrate
|
||||
|
||||
Before the first launch, migrate the database to the latest version.
|
||||
|
||||
@ -63,24 +67,27 @@
|
||||
uv run flask db upgrade
|
||||
```
|
||||
|
||||
7. Start backend
|
||||
1. Start backend
|
||||
|
||||
```bash
|
||||
uv run flask run --host 0.0.0.0 --port=5001 --debug
|
||||
```
|
||||
|
||||
8. Start Dify [web](../web) service.
|
||||
9. Setup your application by visiting `http://localhost:3000`.
|
||||
10. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service.
|
||||
1. Start Dify [web](../web) service.
|
||||
|
||||
```bash
|
||||
uv run celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage
|
||||
```
|
||||
1. Setup your application by visiting `http://localhost:3000`.
|
||||
|
||||
Addition, if you want to debug the celery scheduled tasks, you can use the following command in another terminal:
|
||||
```bash
|
||||
uv run celery -A app.celery beat
|
||||
```
|
||||
1. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service.
|
||||
|
||||
```bash
|
||||
uv run celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage
|
||||
```
|
||||
|
||||
Addition, if you want to debug the celery scheduled tasks, you can use the following command in another terminal:
|
||||
|
||||
```bash
|
||||
uv run celery -A app.celery beat
|
||||
```
|
||||
|
||||
## Testing
|
||||
|
||||
@ -90,9 +97,8 @@
|
||||
uv sync --dev
|
||||
```
|
||||
|
||||
2. Run the tests locally with mocked system environment variables in `tool.pytest_env` section in `pyproject.toml`
|
||||
1. Run the tests locally with mocked system environment variables in `tool.pytest_env` section in `pyproject.toml`
|
||||
|
||||
```bash
|
||||
uv run -P api bash dev/pytest/pytest_all_tests.sh
|
||||
```
|
||||
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from datetime import datetime
|
||||
|
||||
import pytz
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
@ -327,6 +329,9 @@ class EducationVerifyApi(Resource):
|
||||
class EducationApi(Resource):
|
||||
status_fields = {
|
||||
"result": fields.Boolean,
|
||||
"is_student": fields.Boolean,
|
||||
"expire_at": TimestampField,
|
||||
"allow_refresh": fields.Boolean,
|
||||
}
|
||||
|
||||
@setup_required
|
||||
@ -354,7 +359,11 @@ class EducationApi(Resource):
|
||||
def get(self):
|
||||
account = current_user
|
||||
|
||||
return BillingService.EducationIdentity.is_active(account.id)
|
||||
res = BillingService.EducationIdentity.status(account.id)
|
||||
# convert expire_at to UTC timestamp from isoformat
|
||||
if res and "expire_at" in res:
|
||||
res["expire_at"] = datetime.fromisoformat(res["expire_at"]).astimezone(pytz.utc)
|
||||
return res
|
||||
|
||||
|
||||
class EducationAutoCompleteApi(Resource):
|
||||
|
||||
@ -1,27 +1,38 @@
|
||||
from flask_restful import (
|
||||
Resource, # type: ignore
|
||||
reqparse,
|
||||
)
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
from controllers.console.wraps import setup_required
|
||||
from controllers.inner_api import api
|
||||
from controllers.inner_api.wraps import enterprise_inner_api_only
|
||||
from services.enterprise.mail_service import DifyMail, EnterpriseMailService
|
||||
from controllers.inner_api.wraps import billing_inner_api_only, enterprise_inner_api_only
|
||||
from tasks.mail_inner_task import send_inner_email_task
|
||||
|
||||
_mail_parser = reqparse.RequestParser()
|
||||
_mail_parser.add_argument("to", type=str, action="append", required=True)
|
||||
_mail_parser.add_argument("subject", type=str, required=True)
|
||||
_mail_parser.add_argument("body", type=str, required=True)
|
||||
_mail_parser.add_argument("substitutions", type=dict, required=False)
|
||||
|
||||
|
||||
class EnterpriseMail(Resource):
|
||||
@setup_required
|
||||
@enterprise_inner_api_only
|
||||
class BaseMail(Resource):
|
||||
"""Shared logic for sending an inner email."""
|
||||
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("to", type=str, action="append", required=True)
|
||||
parser.add_argument("subject", type=str, required=True)
|
||||
parser.add_argument("body", type=str, required=True)
|
||||
parser.add_argument("substitutions", type=dict, required=False)
|
||||
args = parser.parse_args()
|
||||
|
||||
EnterpriseMailService.send_mail(DifyMail(**args))
|
||||
args = _mail_parser.parse_args()
|
||||
send_inner_email_task.delay(
|
||||
to=args["to"],
|
||||
subject=args["subject"],
|
||||
body=args["body"],
|
||||
substitutions=args["substitutions"],
|
||||
)
|
||||
return {"message": "success"}, 200
|
||||
|
||||
|
||||
class EnterpriseMail(BaseMail):
|
||||
method_decorators = [setup_required, enterprise_inner_api_only]
|
||||
|
||||
|
||||
class BillingMail(BaseMail):
|
||||
method_decorators = [setup_required, billing_inner_api_only]
|
||||
|
||||
|
||||
api.add_resource(EnterpriseMail, "/enterprise/mail")
|
||||
api.add_resource(BillingMail, "/billing/mail")
|
||||
|
||||
@ -10,6 +10,22 @@ from extensions.ext_database import db
|
||||
from models.model import EndUser
|
||||
|
||||
|
||||
def billing_inner_api_only(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
if not dify_config.INNER_API:
|
||||
abort(404)
|
||||
|
||||
# get header 'X-Inner-Api-Key'
|
||||
inner_api_key = request.headers.get("X-Inner-Api-Key")
|
||||
if not inner_api_key or inner_api_key != dify_config.INNER_API_KEY:
|
||||
abort(401)
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
def enterprise_inner_api_only(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
|
||||
@ -30,7 +30,7 @@ from core.rag.splitter.fixed_text_splitter import (
|
||||
FixedRecursiveCharacterTextSplitter,
|
||||
)
|
||||
from core.rag.splitter.text_splitter import TextSplitter
|
||||
from core.tools.utils.rag_web_reader import get_image_upload_file_ids
|
||||
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from extensions.ext_storage import storage
|
||||
|
||||
@ -532,7 +532,7 @@ class LLMGenerator:
|
||||
model=model_config.get("name", ""),
|
||||
)
|
||||
match node_type:
|
||||
case "llm", "agent":
|
||||
case "llm" | "agent":
|
||||
system_prompt = LLM_MODIFY_PROMPT_SYSTEM
|
||||
case "code":
|
||||
system_prompt = LLM_MODIFY_CODE_SYSTEM
|
||||
|
||||
@ -7,7 +7,7 @@ import urllib.parse
|
||||
from typing import Optional
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import requests
|
||||
import httpx
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from core.mcp.auth.auth_provider import OAuthClientProvider
|
||||
@ -105,18 +105,18 @@ def discover_oauth_metadata(server_url: str, protocol_version: Optional[str] = N
|
||||
|
||||
try:
|
||||
headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION}
|
||||
response = requests.get(url, headers=headers)
|
||||
response = httpx.get(url, headers=headers)
|
||||
if response.status_code == 404:
|
||||
return None
|
||||
if not response.ok:
|
||||
if not response.is_success:
|
||||
raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
|
||||
return OAuthMetadata.model_validate(response.json())
|
||||
except requests.RequestException as e:
|
||||
if isinstance(e, requests.ConnectionError):
|
||||
response = requests.get(url)
|
||||
except httpx.RequestError as e:
|
||||
if isinstance(e, httpx.ConnectError):
|
||||
response = httpx.get(url)
|
||||
if response.status_code == 404:
|
||||
return None
|
||||
if not response.ok:
|
||||
if not response.is_success:
|
||||
raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
|
||||
return OAuthMetadata.model_validate(response.json())
|
||||
raise
|
||||
@ -206,8 +206,8 @@ def exchange_authorization(
|
||||
if client_information.client_secret:
|
||||
params["client_secret"] = client_information.client_secret
|
||||
|
||||
response = requests.post(token_url, data=params)
|
||||
if not response.ok:
|
||||
response = httpx.post(token_url, data=params)
|
||||
if not response.is_success:
|
||||
raise ValueError(f"Token exchange failed: HTTP {response.status_code}")
|
||||
return OAuthTokens.model_validate(response.json())
|
||||
|
||||
@ -237,8 +237,8 @@ def refresh_authorization(
|
||||
if client_information.client_secret:
|
||||
params["client_secret"] = client_information.client_secret
|
||||
|
||||
response = requests.post(token_url, data=params)
|
||||
if not response.ok:
|
||||
response = httpx.post(token_url, data=params)
|
||||
if not response.is_success:
|
||||
raise ValueError(f"Token refresh failed: HTTP {response.status_code}")
|
||||
return OAuthTokens.model_validate(response.json())
|
||||
|
||||
@ -256,12 +256,12 @@ def register_client(
|
||||
else:
|
||||
registration_url = urljoin(server_url, "/register")
|
||||
|
||||
response = requests.post(
|
||||
response = httpx.post(
|
||||
registration_url,
|
||||
json=client_metadata.model_dump(),
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
if not response.ok:
|
||||
if not response.is_success:
|
||||
response.raise_for_status()
|
||||
return OAuthClientInformationFull.model_validate(response.json())
|
||||
|
||||
@ -283,7 +283,7 @@ def auth(
|
||||
raise ValueError("Existing OAuth client information is required when exchanging an authorization code")
|
||||
try:
|
||||
full_information = register_client(server_url, metadata, provider.client_metadata)
|
||||
except requests.RequestException as e:
|
||||
except httpx.RequestError as e:
|
||||
raise ValueError(f"Could not register OAuth client: {e}")
|
||||
provider.save_client_information(full_information)
|
||||
client_information = full_information
|
||||
|
||||
@ -30,7 +30,7 @@ This module provides the interface for invoking and authenticating various model
|
||||
|
||||
In addition, this list also returns configurable parameter information and rules for LLM, as shown below:
|
||||
|
||||

|
||||

|
||||
|
||||
These parameters are all defined in the backend, allowing different settings for various parameters supported by different models, as detailed in: [Schema](./docs/en_US/schema.md#ParameterRule).
|
||||
|
||||
@ -60,8 +60,6 @@ Model Runtime is divided into three layers:
|
||||
|
||||
It offers direct invocation of various model types, predefined model configuration information, getting predefined/remote model lists, model credential authentication methods. Different models provide additional special methods, like LLM's pre-computed tokens method, cost information obtaining method, etc., **allowing horizontal expansion** for different models under the same provider (within supported model types).
|
||||
|
||||
|
||||
|
||||
## Next Steps
|
||||
|
||||
- Add new provider configuration: [Link](./docs/en_US/provider_scale_out.md)
|
||||
|
||||
@ -20,19 +20,19 @@
|
||||
|
||||

|
||||
|
||||
展示所有已支持的供应商列表,除了返回供应商名称、图标之外,还提供了支持的模型类型列表,预定义模型列表、配置方式以及配置凭据的表单规则等等,规则设计详见:[Schema](./docs/zh_Hans/schema.md)。
|
||||
展示所有已支持的供应商列表,除了返回供应商名称、图标之外,还提供了支持的模型类型列表,预定义模型列表、配置方式以及配置凭据的表单规则等等,规则设计详见:[Schema](./docs/zh_Hans/schema.md)。
|
||||
|
||||
- 可选择的模型列表展示
|
||||
|
||||

|
||||
|
||||
配置供应商/模型凭据后,可在此下拉(应用编排界面/默认模型)查看可用的 LLM 列表,其中灰色的为未配置凭据供应商的预定义模型列表,方便用户查看已支持的模型。
|
||||
配置供应商/模型凭据后,可在此下拉(应用编排界面/默认模型)查看可用的 LLM 列表,其中灰色的为未配置凭据供应商的预定义模型列表,方便用户查看已支持的模型。
|
||||
|
||||
除此之外,该列表还返回了 LLM 可配置的参数信息和规则,如下图:
|
||||
除此之外,该列表还返回了 LLM 可配置的参数信息和规则,如下图:
|
||||
|
||||

|
||||

|
||||
|
||||
这里的参数均为后端定义,相比之前只有 5 种固定参数,这里可为不同模型设置所支持的各种参数,详见:[Schema](./docs/zh_Hans/schema.md#ParameterRule)。
|
||||
这里的参数均为后端定义,相比之前只有 5 种固定参数,这里可为不同模型设置所支持的各种参数,详见:[Schema](./docs/zh_Hans/schema.md#ParameterRule)。
|
||||
|
||||
- 供应商/模型凭据鉴权
|
||||
|
||||
@ -40,7 +40,7 @@
|
||||
|
||||

|
||||
|
||||
供应商列表返回了凭据表单的配置信息,可通过 Runtime 提供的接口对凭据进行鉴权,上图 1 为供应商凭据 DEMO,上图 2 为模型凭据 DEMO。
|
||||
供应商列表返回了凭据表单的配置信息,可通过 Runtime 提供的接口对凭据进行鉴权,上图 1 为供应商凭据 DEMO,上图 2 为模型凭据 DEMO。
|
||||
|
||||
## 结构
|
||||
|
||||
@ -57,9 +57,10 @@ Model Runtime 分三层:
|
||||
提供获取当前供应商模型列表、获取模型实例、供应商凭据鉴权、供应商配置规则信息,**可横向扩展**以支持不同的供应商。
|
||||
|
||||
对于供应商/模型凭据,有两种情况
|
||||
|
||||
- 如 OpenAI 这类中心化供应商,需要定义如**api_key**这类的鉴权凭据
|
||||
- 如[**Xinference**](https://github.com/xorbitsai/inference)这类本地部署的供应商,需要定义如**server_url**这类的地址凭据,有时候还需要定义**model_uid**之类的模型类型凭据,就像下面这样,当在供应商层定义了这些凭据后,就可以在前端页面上直接展示,无需修改前端逻辑。
|
||||

|
||||

|
||||
|
||||
当配置好凭据后,就可以通过 DifyRuntime 的外部接口直接获取到对应供应商所需要的**Schema**(凭据表单规则),从而在可以在不修改前端逻辑的情况下,提供新的供应商/模型的支持。
|
||||
|
||||
@ -76,14 +77,17 @@ Model Runtime 分三层:
|
||||
## 下一步
|
||||
|
||||
### [增加新的供应商配置 👈🏻](./docs/zh_Hans/provider_scale_out.md)
|
||||
|
||||
当添加后,这里将会出现一个新的供应商
|
||||
|
||||

|
||||
|
||||
### [为已存在的供应商新增模型 👈🏻](./docs/zh_Hans/provider_scale_out.md#增加模型)
|
||||
### [为已存在的供应商新增模型 👈🏻](./docs/zh_Hans/provider_scale_out.md#%E5%A2%9E%E5%8A%A0%E6%A8%A1%E5%9E%8B)
|
||||
|
||||
当添加后,对应供应商的模型列表中将会出现一个新的预定义模型供用户选择,如 GPT-3.5 GPT-4 ChatGLM3-6b 等,而对于支持自定义模型的供应商,则不需要新增模型。
|
||||
|
||||

|
||||
|
||||
### [接口的具体实现 👈🏻](./docs/zh_Hans/interfaces.md)
|
||||
|
||||
你可以在这里找到你想要查看的接口的具体实现,以及接口的参数和返回值的具体含义。
|
||||
|
||||
@ -56,7 +56,6 @@ provider_credential_schema:
|
||||
credential_form_schemas:
|
||||
```
|
||||
|
||||
|
||||
Then, we need to determine what credentials are required to define a model in Xinference.
|
||||
|
||||
- Since it supports three different types of models, we need to specify the model_type to denote the model type. Here is how we can define it:
|
||||
@ -191,7 +190,6 @@ def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[Pr
|
||||
"""
|
||||
```
|
||||
|
||||
|
||||
Sometimes, you might not want to return 0 directly. In such cases, you can use `self._get_num_tokens_by_gpt2(text: str)` to get pre-computed tokens and ensure environment variable `PLUGIN_BASED_TOKEN_COUNTING_ENABLED` is set to `true`, This method is provided by the `AIModel` base class, and it uses GPT2's Tokenizer for calculation. However, it should be noted that this is only a substitute and may not be fully accurate.
|
||||
|
||||
- Model Credentials Validation
|
||||
|
||||
@ -35,12 +35,11 @@ All models need to uniformly implement the following 2 methods:
|
||||
|
||||
Similar to provider credential verification, this step involves verification for an individual model.
|
||||
|
||||
|
||||
```python
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
@ -77,12 +76,12 @@ All models need to uniformly implement the following 2 methods:
|
||||
The key is the error type thrown to the caller
|
||||
The value is the error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
```
|
||||
|
||||
You can refer to OpenAI's `_invoke_error_mapping` for an example.
|
||||
You can refer to OpenAI's `_invoke_error_mapping` for an example.
|
||||
|
||||
### LLM
|
||||
|
||||
@ -92,7 +91,6 @@ Inherit the `__base.large_language_model.LargeLanguageModel` base class and impl
|
||||
|
||||
Implement the core method for LLM invocation, which can support both streaming and synchronous returns.
|
||||
|
||||
|
||||
```python
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
@ -101,7 +99,7 @@ Inherit the `__base.large_language_model.LargeLanguageModel` base class and impl
|
||||
-> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
@ -122,7 +120,7 @@ Inherit the `__base.large_language_model.LargeLanguageModel` base class and impl
|
||||
|
||||
The parameters of credential information are defined by either the `provider_credential_schema` or `model_credential_schema` in the provider's YAML configuration file. Inputs such as `api_key` are included.
|
||||
|
||||
- `prompt_messages` (array[[PromptMessage](#PromptMessage)]) List of prompts
|
||||
- `prompt_messages` (array\[[PromptMessage](#PromptMessage)\]) List of prompts
|
||||
|
||||
If the model is of the `Completion` type, the list only needs to include one [UserPromptMessage](#UserPromptMessage) element;
|
||||
|
||||
@ -132,7 +130,7 @@ Inherit the `__base.large_language_model.LargeLanguageModel` base class and impl
|
||||
|
||||
The model parameters are defined by the `parameter_rules` in the model's YAML configuration.
|
||||
|
||||
- `tools` (array[[PromptMessageTool](#PromptMessageTool)]) [optional] List of tools, equivalent to the `function` in `function calling`.
|
||||
- `tools` (array\[[PromptMessageTool](#PromptMessageTool)\]) [optional] List of tools, equivalent to the `function` in `function calling`.
|
||||
|
||||
That is, the tool list for tool calling.
|
||||
|
||||
@ -142,7 +140,7 @@ Inherit the `__base.large_language_model.LargeLanguageModel` base class and impl
|
||||
|
||||
- `stream` (bool) Whether to output in a streaming manner, default is True
|
||||
|
||||
Streaming output returns Generator[[LLMResultChunk](#LLMResultChunk)], non-streaming output returns [LLMResult](#LLMResult).
|
||||
Streaming output returns Generator\[[LLMResultChunk](#LLMResultChunk)\], non-streaming output returns [LLMResult](#LLMResult).
|
||||
|
||||
- `user` (string) [optional] Unique identifier of the user
|
||||
|
||||
@ -150,7 +148,7 @@ Inherit the `__base.large_language_model.LargeLanguageModel` base class and impl
|
||||
|
||||
- Returns
|
||||
|
||||
Streaming output returns Generator[[LLMResultChunk](#LLMResultChunk)], non-streaming output returns [LLMResult](#LLMResult).
|
||||
Streaming output returns Generator\[[LLMResultChunk](#LLMResultChunk)\], non-streaming output returns [LLMResult](#LLMResult).
|
||||
|
||||
- Pre-calculating Input Tokens
|
||||
|
||||
@ -187,7 +185,6 @@ Inherit the `__base.large_language_model.LargeLanguageModel` base class and impl
|
||||
|
||||
When the provider supports adding custom LLMs, this method can be implemented to allow custom models to fetch model schema. The default return null.
|
||||
|
||||
|
||||
### TextEmbedding
|
||||
|
||||
Inherit the `__base.text_embedding_model.TextEmbeddingModel` base class and implement the following interfaces:
|
||||
@ -200,7 +197,7 @@ Inherit the `__base.text_embedding_model.TextEmbeddingModel` base class and impl
|
||||
-> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
@ -256,7 +253,7 @@ Inherit the `__base.rerank_model.RerankModel` base class and implement the follo
|
||||
-> RerankResult:
|
||||
"""
|
||||
Invoke rerank model
|
||||
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param query: search query
|
||||
@ -302,7 +299,7 @@ Inherit the `__base.speech2text_model.Speech2TextModel` base class and implement
|
||||
def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param file: audio file
|
||||
@ -339,7 +336,7 @@ Inherit the `__base.text2speech_model.Text2SpeechModel` base class and implement
|
||||
def _invoke(self, model: str, credentials: dict, content_text: str, streaming: bool, user: Optional[str] = None):
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param content_text: text content to be translated
|
||||
@ -381,7 +378,7 @@ Inherit the `__base.moderation_model.ModerationModel` base class and implement t
|
||||
-> bool:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param text: text to moderate
|
||||
@ -408,11 +405,9 @@ Inherit the `__base.moderation_model.ModerationModel` base class and implement t
|
||||
|
||||
False indicates that the input text is safe, True indicates otherwise.
|
||||
|
||||
|
||||
|
||||
## Entities
|
||||
|
||||
### PromptMessageRole
|
||||
### PromptMessageRole
|
||||
|
||||
Message role
|
||||
|
||||
@ -583,7 +578,7 @@ class PromptMessageTool(BaseModel):
|
||||
parameters: dict
|
||||
```
|
||||
|
||||
---
|
||||
______________________________________________________________________
|
||||
|
||||
### LLMResult
|
||||
|
||||
@ -650,7 +645,7 @@ class LLMUsage(ModelUsage):
|
||||
latency: float # Request latency (s)
|
||||
```
|
||||
|
||||
---
|
||||
______________________________________________________________________
|
||||
|
||||
### TextEmbeddingResult
|
||||
|
||||
@ -680,7 +675,7 @@ class EmbeddingUsage(ModelUsage):
|
||||
latency: float # Request latency (s)
|
||||
```
|
||||
|
||||
---
|
||||
______________________________________________________________________
|
||||
|
||||
### RerankResult
|
||||
|
||||
|
||||
@ -153,8 +153,11 @@ Runtime Errors:
|
||||
- `InvokeConnectionError` Connection error
|
||||
|
||||
- `InvokeServerUnavailableError` Service provider unavailable
|
||||
|
||||
- `InvokeRateLimitError` Rate limit reached
|
||||
|
||||
- `InvokeAuthorizationError` Authorization failed
|
||||
|
||||
- `InvokeBadRequestError` Parameter error
|
||||
|
||||
```python
|
||||
|
||||
@ -63,6 +63,7 @@ You can also refer to the YAML configuration information under other provider di
|
||||
### Implementing Provider Code
|
||||
|
||||
Providers need to inherit the `__base.model_provider.ModelProvider` base class and implement the `validate_provider_credentials` method for unified provider credential verification. For reference, see [AnthropicProvider](https://github.com/langgenius/dify-runtime/blob/main/lib/model_providers/anthropic/anthropic.py).
|
||||
|
||||
> If the provider is the type of `customizable-model`, there is no need to implement the `validate_provider_credentials` method.
|
||||
|
||||
```python
|
||||
@ -80,7 +81,7 @@ def validate_provider_credentials(self, credentials: dict) -> None:
|
||||
|
||||
Of course, you can also preliminarily reserve the implementation of `validate_provider_credentials` and directly reuse it after the model credential verification method is implemented.
|
||||
|
||||
---
|
||||
______________________________________________________________________
|
||||
|
||||
### Adding Models
|
||||
|
||||
@ -166,7 +167,7 @@ In `llm.py`, create an Anthropic LLM class, which we name `AnthropicLargeLanguag
|
||||
-> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
@ -205,7 +206,7 @@ In `llm.py`, create an Anthropic LLM class, which we name `AnthropicLargeLanguag
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
@ -232,7 +233,7 @@ In `llm.py`, create an Anthropic LLM class, which we name `AnthropicLargeLanguag
|
||||
The key is the error type thrown to the caller
|
||||
The value is the error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
```
|
||||
|
||||
@ -28,8 +28,8 @@
|
||||
- `url` (object) help link, i18n
|
||||
- `zh_Hans` (string) [optional] Chinese link
|
||||
- `en_US` (string) English link
|
||||
- `supported_model_types` (array[[ModelType](#ModelType)]) Supported model types
|
||||
- `configurate_methods` (array[[ConfigurateMethod](#ConfigurateMethod)]) Configuration methods
|
||||
- `supported_model_types` (array\[[ModelType](#ModelType)\]) Supported model types
|
||||
- `configurate_methods` (array\[[ConfigurateMethod](#ConfigurateMethod)\]) Configuration methods
|
||||
- `provider_credential_schema` ([ProviderCredentialSchema](#ProviderCredentialSchema)) Provider credential specification
|
||||
- `model_credential_schema` ([ModelCredentialSchema](#ModelCredentialSchema)) Model credential specification
|
||||
|
||||
@ -40,23 +40,23 @@
|
||||
- `zh_Hans` (string) [optional] Chinese label name
|
||||
- `en_US` (string) English label name
|
||||
- `model_type` ([ModelType](#ModelType)) Model type
|
||||
- `features` (array[[ModelFeature](#ModelFeature)]) [optional] Supported feature list
|
||||
- `features` (array\[[ModelFeature](#ModelFeature)\]) [optional] Supported feature list
|
||||
- `model_properties` (object) Model properties
|
||||
- `mode` ([LLMMode](#LLMMode)) Mode (available for model type `llm`)
|
||||
- `context_size` (int) Context size (available for model types `llm`, `text-embedding`)
|
||||
- `max_chunks` (int) Maximum number of chunks (available for model types `text-embedding`, `moderation`)
|
||||
- `file_upload_limit` (int) Maximum file upload limit, in MB (available for model type `speech2text`)
|
||||
- `supported_file_extensions` (string) Supported file extension formats, e.g., mp3, mp4 (available for model type `speech2text`)
|
||||
- `default_voice` (string) default voice, e.g.:alloy,echo,fable,onyx,nova,shimmer(available for model type `tts`)
|
||||
- `voices` (list) List of available voice.(available for model type `tts`)
|
||||
- `mode` (string) voice model.(available for model type `tts`)
|
||||
- `name` (string) voice model display name.(available for model type `tts`)
|
||||
- `language` (string) the voice model supports languages.(available for model type `tts`)
|
||||
- `word_limit` (int) Single conversion word limit, paragraph-wise by default(available for model type `tts`)
|
||||
- `audio_type` (string) Support audio file extension format, e.g.:mp3,wav(available for model type `tts`)
|
||||
- `max_workers` (int) Number of concurrent workers supporting text and audio conversion(available for model type`tts`)
|
||||
- `default_voice` (string) default voice, e.g.:alloy,echo,fable,onyx,nova,shimmer(available for model type `tts`)
|
||||
- `voices` (list) List of available voice.(available for model type `tts`)
|
||||
- `mode` (string) voice model.(available for model type `tts`)
|
||||
- `name` (string) voice model display name.(available for model type `tts`)
|
||||
- `language` (string) the voice model supports languages.(available for model type `tts`)
|
||||
- `word_limit` (int) Single conversion word limit, paragraph-wise by default(available for model type `tts`)
|
||||
- `audio_type` (string) Support audio file extension format, e.g.:mp3,wav(available for model type `tts`)
|
||||
- `max_workers` (int) Number of concurrent workers supporting text and audio conversion(available for model type`tts`)
|
||||
- `max_characters_per_chunk` (int) Maximum characters per chunk (available for model type `moderation`)
|
||||
- `parameter_rules` (array[[ParameterRule](#ParameterRule)]) [optional] Model invocation parameter rules
|
||||
- `parameter_rules` (array\[[ParameterRule](#ParameterRule)\]) [optional] Model invocation parameter rules
|
||||
- `pricing` ([PriceConfig](#PriceConfig)) [optional] Pricing information
|
||||
- `deprecated` (bool) Whether deprecated. If deprecated, the model will no longer be displayed in the list, but those already configured can continue to be used. Default False.
|
||||
|
||||
@ -74,6 +74,7 @@
|
||||
- `predefined-model` Predefined model
|
||||
|
||||
Indicates that users can use the predefined models under the provider by configuring the unified provider credentials.
|
||||
|
||||
- `customizable-model` Customizable model
|
||||
|
||||
Users need to add credential configuration for each model.
|
||||
@ -103,6 +104,7 @@
|
||||
### ParameterRule
|
||||
|
||||
- `name` (string) Actual model invocation parameter name
|
||||
|
||||
- `use_template` (string) [optional] Using template
|
||||
|
||||
By default, 5 variable content configuration templates are preset:
|
||||
@ -112,7 +114,7 @@
|
||||
- `frequency_penalty`
|
||||
- `presence_penalty`
|
||||
- `max_tokens`
|
||||
|
||||
|
||||
In use_template, you can directly set the template variable name, which will use the default configuration in entities.defaults.PARAMETER_RULE_TEMPLATE
|
||||
No need to set any parameters other than `name` and `use_template`. If additional configuration parameters are set, they will override the default configuration.
|
||||
Refer to `openai/llm/gpt-3.5-turbo.yaml`.
|
||||
@ -155,7 +157,7 @@
|
||||
|
||||
### ProviderCredentialSchema
|
||||
|
||||
- `credential_form_schemas` (array[[CredentialFormSchema](#CredentialFormSchema)]) Credential form standard
|
||||
- `credential_form_schemas` (array\[[CredentialFormSchema](#CredentialFormSchema)\]) Credential form standard
|
||||
|
||||
### ModelCredentialSchema
|
||||
|
||||
@ -166,7 +168,7 @@
|
||||
- `placeholder` (object) Model prompt content
|
||||
- `en_US`(string) English
|
||||
- `zh_Hans`(string) [optional] Chinese
|
||||
- `credential_form_schemas` (array[[CredentialFormSchema](#CredentialFormSchema)]) Credential form standard
|
||||
- `credential_form_schemas` (array\[[CredentialFormSchema](#CredentialFormSchema)\]) Credential form standard
|
||||
|
||||
### CredentialFormSchema
|
||||
|
||||
@ -177,12 +179,12 @@
|
||||
- `type` ([FormType](#FormType)) Form item type
|
||||
- `required` (bool) Whether required
|
||||
- `default`(string) Default value
|
||||
- `options` (array[[FormOption](#FormOption)]) Specific property of form items of type `select` or `radio`, defining dropdown content
|
||||
- `options` (array\[[FormOption](#FormOption)\]) Specific property of form items of type `select` or `radio`, defining dropdown content
|
||||
- `placeholder`(object) Specific property of form items of type `text-input`, placeholder content
|
||||
- `en_US`(string) English
|
||||
- `zh_Hans` (string) [optional] Chinese
|
||||
- `max_length` (int) Specific property of form items of type `text-input`, defining maximum input length, 0 for no limit.
|
||||
- `show_on` (array[[FormShowOnObject](#FormShowOnObject)]) Displayed when other form item values meet certain conditions, displayed always if empty.
|
||||
- `show_on` (array\[[FormShowOnObject](#FormShowOnObject)\]) Displayed when other form item values meet certain conditions, displayed always if empty.
|
||||
|
||||
### FormType
|
||||
|
||||
@ -198,7 +200,7 @@
|
||||
- `en_US`(string) English
|
||||
- `zh_Hans`(string) [optional] Chinese
|
||||
- `value` (string) Dropdown option value
|
||||
- `show_on` (array[[FormShowOnObject](#FormShowOnObject)]) Displayed when other form item values meet certain conditions, displayed always if empty.
|
||||
- `show_on` (array\[[FormShowOnObject](#FormShowOnObject)\]) Displayed when other form item values meet certain conditions, displayed always if empty.
|
||||
|
||||
### FormShowOnObject
|
||||
|
||||
|
||||
@ -10,7 +10,6 @@
|
||||
|
||||

|
||||
|
||||
|
||||
在前文中,我们已经知道了供应商无需实现`validate_provider_credential`,Runtime 会自行根据用户在此选择的模型类型和模型名称调用对应的模型层的`validate_credentials`来进行验证。
|
||||
|
||||
### 编写供应商 yaml
|
||||
@ -55,6 +54,7 @@ provider_credential_schema:
|
||||
随后,我们需要思考在 Xinference 中定义一个模型需要哪些凭据
|
||||
|
||||
- 它支持三种不同的模型,因此,我们需要有`model_type`来指定这个模型的类型,它有三种类型,所以我们这么编写
|
||||
|
||||
```yaml
|
||||
provider_credential_schema:
|
||||
credential_form_schemas:
|
||||
@ -76,7 +76,9 @@ provider_credential_schema:
|
||||
label:
|
||||
en_US: Rerank
|
||||
```
|
||||
|
||||
- 每一个模型都有自己的名称`model_name`,因此需要在这里定义
|
||||
|
||||
```yaml
|
||||
- variable: model_name
|
||||
type: text-input
|
||||
@ -88,7 +90,9 @@ provider_credential_schema:
|
||||
zh_Hans: 填写模型名称
|
||||
en_US: Input model name
|
||||
```
|
||||
|
||||
- 填写 Xinference 本地部署的地址
|
||||
|
||||
```yaml
|
||||
- variable: server_url
|
||||
label:
|
||||
@ -100,7 +104,9 @@ provider_credential_schema:
|
||||
zh_Hans: 在此输入 Xinference 的服务器地址,如 https://example.com/xxx
|
||||
en_US: Enter the url of your Xinference, for example https://example.com/xxx
|
||||
```
|
||||
|
||||
- 每个模型都有唯一的 model_uid,因此需要在这里定义
|
||||
|
||||
```yaml
|
||||
- variable: model_uid
|
||||
label:
|
||||
@ -112,6 +118,7 @@ provider_credential_schema:
|
||||
zh_Hans: 在此输入您的 Model UID
|
||||
en_US: Enter the model uid
|
||||
```
|
||||
|
||||
现在,我们就完成了供应商的基础定义。
|
||||
|
||||
### 编写模型代码
|
||||
@ -132,7 +139,7 @@ provider_credential_schema:
|
||||
-> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
@ -189,7 +196,7 @@ provider_credential_schema:
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
@ -197,78 +204,78 @@ provider_credential_schema:
|
||||
```
|
||||
|
||||
- 模型参数 Schema
|
||||
|
||||
|
||||
与自定义类型不同,由于没有在 yaml 文件中定义一个模型支持哪些参数,因此,我们需要动态时间模型参数的 Schema。
|
||||
|
||||
|
||||
如 Xinference 支持`max_tokens` `temperature` `top_p` 这三个模型参数。
|
||||
|
||||
|
||||
但是有的供应商根据不同的模型支持不同的参数,如供应商`OpenLLM`支持`top_k`,但是并不是这个供应商提供的所有模型都支持`top_k`,我们这里举例 A 模型支持`top_k`,B 模型不支持`top_k`,那么我们需要在这里动态生成模型参数的 Schema,如下所示:
|
||||
|
||||
```python
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
"""
|
||||
used to define customizable model schema
|
||||
"""
|
||||
rules = [
|
||||
ParameterRule(
|
||||
name='temperature', type=ParameterType.FLOAT,
|
||||
use_template='temperature',
|
||||
label=I18nObject(
|
||||
zh_Hans='温度', en_US='Temperature'
|
||||
)
|
||||
),
|
||||
ParameterRule(
|
||||
name='top_p', type=ParameterType.FLOAT,
|
||||
use_template='top_p',
|
||||
label=I18nObject(
|
||||
zh_Hans='Top P', en_US='Top P'
|
||||
)
|
||||
),
|
||||
ParameterRule(
|
||||
name='max_tokens', type=ParameterType.INT,
|
||||
use_template='max_tokens',
|
||||
min=1,
|
||||
default=512,
|
||||
label=I18nObject(
|
||||
zh_Hans='最大生成长度', en_US='Max Tokens'
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
# if model is A, add top_k to rules
|
||||
if model == 'A':
|
||||
rules.append(
|
||||
ParameterRule(
|
||||
name='top_k', type=ParameterType.INT,
|
||||
use_template='top_k',
|
||||
min=1,
|
||||
default=50,
|
||||
label=I18nObject(
|
||||
zh_Hans='Top K', en_US='Top K'
|
||||
)
|
||||
)
|
||||
)
|
||||
```python
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
"""
|
||||
used to define customizable model schema
|
||||
"""
|
||||
rules = [
|
||||
ParameterRule(
|
||||
name='temperature', type=ParameterType.FLOAT,
|
||||
use_template='temperature',
|
||||
label=I18nObject(
|
||||
zh_Hans='温度', en_US='Temperature'
|
||||
)
|
||||
),
|
||||
ParameterRule(
|
||||
name='top_p', type=ParameterType.FLOAT,
|
||||
use_template='top_p',
|
||||
label=I18nObject(
|
||||
zh_Hans='Top P', en_US='Top P'
|
||||
)
|
||||
),
|
||||
ParameterRule(
|
||||
name='max_tokens', type=ParameterType.INT,
|
||||
use_template='max_tokens',
|
||||
min=1,
|
||||
default=512,
|
||||
label=I18nObject(
|
||||
zh_Hans='最大生成长度', en_US='Max Tokens'
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
"""
|
||||
some NOT IMPORTANT code here
|
||||
"""
|
||||
# if model is A, add top_k to rules
|
||||
if model == 'A':
|
||||
rules.append(
|
||||
ParameterRule(
|
||||
name='top_k', type=ParameterType.INT,
|
||||
use_template='top_k',
|
||||
min=1,
|
||||
default=50,
|
||||
label=I18nObject(
|
||||
zh_Hans='Top K', en_US='Top K'
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(
|
||||
en_US=model
|
||||
),
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_type=model_type,
|
||||
model_properties={
|
||||
ModelPropertyKey.MODE: ModelType.LLM,
|
||||
},
|
||||
parameter_rules=rules
|
||||
)
|
||||
"""
|
||||
some NOT IMPORTANT code here
|
||||
"""
|
||||
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(
|
||||
en_US=model
|
||||
),
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_type=model_type,
|
||||
model_properties={
|
||||
ModelPropertyKey.MODE: ModelType.LLM,
|
||||
},
|
||||
parameter_rules=rules
|
||||
)
|
||||
|
||||
return entity
|
||||
```
|
||||
|
||||
return entity
|
||||
```
|
||||
|
||||
- 调用异常错误映射表
|
||||
|
||||
当模型调用异常时需要映射到 Runtime 指定的 `InvokeError` 类型,方便 Dify 针对不同错误做不同后续处理。
|
||||
@ -278,7 +285,7 @@ provider_credential_schema:
|
||||
- `InvokeConnectionError` 调用连接错误
|
||||
- `InvokeServerUnavailableError ` 调用服务方不可用
|
||||
- `InvokeRateLimitError ` 调用达到限额
|
||||
- `InvokeAuthorizationError` 调用鉴权失败
|
||||
- `InvokeAuthorizationError` 调用鉴权失败
|
||||
- `InvokeBadRequestError ` 调用传参有误
|
||||
|
||||
```python
|
||||
@ -289,7 +296,7 @@ provider_credential_schema:
|
||||
The key is the error type thrown to the caller
|
||||
The value is the error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
```
|
||||
|
||||
@ -49,7 +49,7 @@ class XinferenceProvider(Provider):
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
@ -75,7 +75,7 @@ class XinferenceProvider(Provider):
|
||||
- `InvokeConnectionError` 调用连接错误
|
||||
- `InvokeServerUnavailableError ` 调用服务方不可用
|
||||
- `InvokeRateLimitError ` 调用达到限额
|
||||
- `InvokeAuthorizationError` 调用鉴权失败
|
||||
- `InvokeAuthorizationError` 调用鉴权失败
|
||||
- `InvokeBadRequestError ` 调用传参有误
|
||||
|
||||
```python
|
||||
@ -86,36 +86,36 @@ class XinferenceProvider(Provider):
|
||||
The key is the error type thrown to the caller
|
||||
The value is the error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
```
|
||||
|
||||
也可以直接抛出对应 Errors,并做如下定义,这样在之后的调用中可以直接抛出`InvokeConnectionError`等异常。
|
||||
|
||||
```python
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
return {
|
||||
InvokeConnectionError: [
|
||||
InvokeConnectionError
|
||||
],
|
||||
InvokeServerUnavailableError: [
|
||||
InvokeServerUnavailableError
|
||||
],
|
||||
InvokeRateLimitError: [
|
||||
InvokeRateLimitError
|
||||
],
|
||||
InvokeAuthorizationError: [
|
||||
InvokeAuthorizationError
|
||||
],
|
||||
InvokeBadRequestError: [
|
||||
InvokeBadRequestError
|
||||
],
|
||||
}
|
||||
```
|
||||
|
||||
可参考 OpenAI `_invoke_error_mapping`。
|
||||
```python
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
return {
|
||||
InvokeConnectionError: [
|
||||
InvokeConnectionError
|
||||
],
|
||||
InvokeServerUnavailableError: [
|
||||
InvokeServerUnavailableError
|
||||
],
|
||||
InvokeRateLimitError: [
|
||||
InvokeRateLimitError
|
||||
],
|
||||
InvokeAuthorizationError: [
|
||||
InvokeAuthorizationError
|
||||
],
|
||||
InvokeBadRequestError: [
|
||||
InvokeBadRequestError
|
||||
],
|
||||
}
|
||||
```
|
||||
|
||||
可参考 OpenAI `_invoke_error_mapping`。
|
||||
|
||||
### LLM
|
||||
|
||||
@ -133,7 +133,7 @@ class XinferenceProvider(Provider):
|
||||
-> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
@ -151,38 +151,38 @@ class XinferenceProvider(Provider):
|
||||
- `model` (string) 模型名称
|
||||
|
||||
- `credentials` (object) 凭据信息
|
||||
|
||||
|
||||
凭据信息的参数由供应商 YAML 配置文件的 `provider_credential_schema` 或 `model_credential_schema` 定义,传入如:`api_key` 等。
|
||||
|
||||
- `prompt_messages` (array[[PromptMessage](#PromptMessage)]) Prompt 列表
|
||||
|
||||
- `prompt_messages` (array\[[PromptMessage](#PromptMessage)\]) Prompt 列表
|
||||
|
||||
若模型为 `Completion` 类型,则列表只需要传入一个 [UserPromptMessage](#UserPromptMessage) 元素即可;
|
||||
|
||||
|
||||
若模型为 `Chat` 类型,需要根据消息不同传入 [SystemPromptMessage](#SystemPromptMessage), [UserPromptMessage](#UserPromptMessage), [AssistantPromptMessage](#AssistantPromptMessage), [ToolPromptMessage](#ToolPromptMessage) 元素列表
|
||||
|
||||
- `model_parameters` (object) 模型参数
|
||||
|
||||
|
||||
模型参数由模型 YAML 配置的 `parameter_rules` 定义。
|
||||
|
||||
- `tools` (array[[PromptMessageTool](#PromptMessageTool)]) [optional] 工具列表,等同于 `function calling` 中的 `function`。
|
||||
|
||||
- `tools` (array\[[PromptMessageTool](#PromptMessageTool)\]) [optional] 工具列表,等同于 `function calling` 中的 `function`。
|
||||
|
||||
即传入 tool calling 的工具列表。
|
||||
|
||||
- `stop` (array[string]) [optional] 停止序列
|
||||
|
||||
|
||||
模型返回将在停止序列定义的字符串之前停止输出。
|
||||
|
||||
- `stream` (bool) 是否流式输出,默认 True
|
||||
|
||||
流式输出返回 Generator[[LLMResultChunk](#LLMResultChunk)],非流式输出返回 [LLMResult](#LLMResult)。
|
||||
|
||||
流式输出返回 Generator\[[LLMResultChunk](#LLMResultChunk)\],非流式输出返回 [LLMResult](#LLMResult)。
|
||||
|
||||
- `user` (string) [optional] 用户的唯一标识符
|
||||
|
||||
|
||||
可以帮助供应商监控和检测滥用行为。
|
||||
|
||||
- 返回
|
||||
|
||||
流式输出返回 Generator[[LLMResultChunk](#LLMResultChunk)],非流式输出返回 [LLMResult](#LLMResult)。
|
||||
流式输出返回 Generator\[[LLMResultChunk](#LLMResultChunk)\],非流式输出返回 [LLMResult](#LLMResult)。
|
||||
|
||||
- 预计算输入 tokens
|
||||
|
||||
@ -236,7 +236,7 @@ class XinferenceProvider(Provider):
|
||||
-> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
@ -294,7 +294,7 @@ class XinferenceProvider(Provider):
|
||||
-> RerankResult:
|
||||
"""
|
||||
Invoke rerank model
|
||||
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param query: search query
|
||||
@ -342,7 +342,7 @@ class XinferenceProvider(Provider):
|
||||
-> str:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param file: audio file
|
||||
@ -379,7 +379,7 @@ class XinferenceProvider(Provider):
|
||||
def _invoke(self, model: str, credentials: dict, content_text: str, streaming: bool, user: Optional[str] = None):
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param content_text: text content to be translated
|
||||
@ -421,7 +421,7 @@ class XinferenceProvider(Provider):
|
||||
-> bool:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param text: text to moderate
|
||||
@ -448,11 +448,9 @@ class XinferenceProvider(Provider):
|
||||
|
||||
False 代表传入的文本安全,True 则反之。
|
||||
|
||||
|
||||
|
||||
## 实体
|
||||
|
||||
### PromptMessageRole
|
||||
### PromptMessageRole
|
||||
|
||||
消息角色
|
||||
|
||||
@ -623,7 +621,7 @@ class PromptMessageTool(BaseModel):
|
||||
parameters: dict # 工具参数 dict
|
||||
```
|
||||
|
||||
---
|
||||
______________________________________________________________________
|
||||
|
||||
### LLMResult
|
||||
|
||||
@ -690,7 +688,7 @@ class LLMUsage(ModelUsage):
|
||||
latency: float # 请求耗时 (s)
|
||||
```
|
||||
|
||||
---
|
||||
______________________________________________________________________
|
||||
|
||||
### TextEmbeddingResult
|
||||
|
||||
@ -720,7 +718,7 @@ class EmbeddingUsage(ModelUsage):
|
||||
latency: float # 请求耗时 (s)
|
||||
```
|
||||
|
||||
---
|
||||
______________________________________________________________________
|
||||
|
||||
### RerankResult
|
||||
|
||||
|
||||
@ -62,7 +62,7 @@ pricing: # 价格信息
|
||||
|
||||
建议将所有模型配置都准备完毕后再开始模型代码的实现。
|
||||
|
||||
同样,也可以参考 `model_providers` 目录下其他供应商对应模型类型目录下的 YAML 配置信息,完整的 YAML 规则见:[Schema](schema.md#aimodelentity)。
|
||||
同样,也可以参考 `model_providers` 目录下其他供应商对应模型类型目录下的 YAML 配置信息,完整的 YAML 规则见:[Schema](schema.md#aimodelentity)。
|
||||
|
||||
### 实现模型调用代码
|
||||
|
||||
@ -82,7 +82,7 @@ pricing: # 价格信息
|
||||
-> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
@ -137,7 +137,7 @@ pricing: # 价格信息
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
@ -153,7 +153,7 @@ pricing: # 价格信息
|
||||
- `InvokeConnectionError` 调用连接错误
|
||||
- `InvokeServerUnavailableError ` 调用服务方不可用
|
||||
- `InvokeRateLimitError ` 调用达到限额
|
||||
- `InvokeAuthorizationError` 调用鉴权失败
|
||||
- `InvokeAuthorizationError` 调用鉴权失败
|
||||
- `InvokeBadRequestError ` 调用传参有误
|
||||
|
||||
```python
|
||||
@ -164,7 +164,7 @@ pricing: # 价格信息
|
||||
The key is the error type thrown to the caller
|
||||
The value is the error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
```
|
||||
|
||||
@ -5,7 +5,7 @@
|
||||
- `predefined-model ` 预定义模型
|
||||
|
||||
表示用户只需要配置统一的供应商凭据即可使用供应商下的预定义模型。
|
||||
|
||||
|
||||
- `customizable-model` 自定义模型
|
||||
|
||||
用户需要新增每个模型的凭据配置,如 Xinference,它同时支持 LLM 和 Text Embedding,但是每个模型都有唯一的**model_uid**,如果想要将两者同时接入,就需要为每个模型配置一个**model_uid**。
|
||||
@ -23,9 +23,11 @@
|
||||
### 介绍
|
||||
|
||||
#### 名词解释
|
||||
- `module`: 一个`module`即为一个 Python Package,或者通俗一点,称为一个文件夹,里面包含了一个`__init__.py`文件,以及其他的`.py`文件。
|
||||
|
||||
- `module`: 一个`module`即为一个 Python Package,或者通俗一点,称为一个文件夹,里面包含了一个`__init__.py`文件,以及其他的`.py`文件。
|
||||
|
||||
#### 步骤
|
||||
|
||||
新增一个供应商主要分为几步,这里简单列出,帮助大家有一个大概的认识,具体的步骤会在下面详细介绍。
|
||||
|
||||
- 创建供应商 yaml 文件,根据[ProviderSchema](./schema.md#provider)编写
|
||||
@ -117,7 +119,7 @@ model_credential_schema:
|
||||
en_US: Enter your API Base
|
||||
```
|
||||
|
||||
也可以参考 `model_providers` 目录下其他供应商目录下的 YAML 配置信息,完整的 YAML 规则见:[Schema](schema.md#provider)。
|
||||
也可以参考 `model_providers` 目录下其他供应商目录下的 YAML 配置信息,完整的 YAML 规则见:[Schema](schema.md#provider)。
|
||||
|
||||
#### 实现供应商代码
|
||||
|
||||
@ -155,12 +157,14 @@ def validate_provider_credentials(self, credentials: dict) -> None:
|
||||
#### 增加模型
|
||||
|
||||
#### [增加预定义模型 👈🏻](./predefined_model_scale_out.md)
|
||||
|
||||
对于预定义模型,我们可以通过简单定义一个 yaml,并通过实现调用代码来接入。
|
||||
|
||||
#### [增加自定义模型 👈🏻](./customizable_model_scale_out.md)
|
||||
|
||||
对于自定义模型,我们只需要实现调用代码即可接入,但是它需要处理的参数可能会更加复杂。
|
||||
|
||||
---
|
||||
______________________________________________________________________
|
||||
|
||||
### 测试
|
||||
|
||||
|
||||
@ -16,9 +16,9 @@
|
||||
- `zh_Hans` (string) [optional] 中文描述
|
||||
- `en_US` (string) 英文描述
|
||||
- `icon_small` (string) [optional] 供应商小 ICON,存储在对应供应商实现目录下的 `_assets` 目录,中英文策略同 `label`
|
||||
- `zh_Hans` (string) [optional] 中文 ICON
|
||||
- `zh_Hans` (string) [optional] 中文 ICON
|
||||
- `en_US` (string) 英文 ICON
|
||||
- `icon_large` (string) [optional] 供应商大 ICON,存储在对应供应商实现目录下的 _assets 目录,中英文策略同 label
|
||||
- `icon_large` (string) [optional] 供应商大 ICON,存储在对应供应商实现目录下的 \_assets 目录,中英文策略同 label
|
||||
- `zh_Hans `(string) [optional] 中文 ICON
|
||||
- `en_US` (string) 英文 ICON
|
||||
- `background` (string) [optional] 背景颜色色值,例:#FFFFFF,为空则展示前端默认色值。
|
||||
@ -29,8 +29,8 @@
|
||||
- `url` (object) 帮助链接,i18n
|
||||
- `zh_Hans` (string) [optional] 中文链接
|
||||
- `en_US` (string) 英文链接
|
||||
- `supported_model_types` (array[[ModelType](#ModelType)]) 支持的模型类型
|
||||
- `configurate_methods` (array[[ConfigurateMethod](#ConfigurateMethod)]) 配置方式
|
||||
- `supported_model_types` (array\[[ModelType](#ModelType)\]) 支持的模型类型
|
||||
- `configurate_methods` (array\[[ConfigurateMethod](#ConfigurateMethod)\]) 配置方式
|
||||
- `provider_credential_schema` ([ProviderCredentialSchema](#ProviderCredentialSchema)) 供应商凭据规格
|
||||
- `model_credential_schema` ([ModelCredentialSchema](#ModelCredentialSchema)) 模型凭据规格
|
||||
|
||||
@ -41,23 +41,23 @@
|
||||
- `zh_Hans `(string) [optional] 中文标签名
|
||||
- `en_US` (string) 英文标签名
|
||||
- `model_type` ([ModelType](#ModelType)) 模型类型
|
||||
- `features` (array[[ModelFeature](#ModelFeature)]) [optional] 支持功能列表
|
||||
- `features` (array\[[ModelFeature](#ModelFeature)\]) [optional] 支持功能列表
|
||||
- `model_properties` (object) 模型属性
|
||||
- `mode` ([LLMMode](#LLMMode)) 模式 (模型类型 `llm` 可用)
|
||||
- `context_size` (int) 上下文大小 (模型类型 `llm` `text-embedding` 可用)
|
||||
- `max_chunks` (int) 最大分块数量 (模型类型 `text-embedding ` `moderation` 可用)
|
||||
- `file_upload_limit` (int) 文件最大上传限制,单位:MB。(模型类型 `speech2text` 可用)
|
||||
- `supported_file_extensions` (string) 支持文件扩展格式,如:mp3,mp4(模型类型 `speech2text` 可用)
|
||||
- `default_voice` (string) 缺省音色,必选:alloy,echo,fable,onyx,nova,shimmer(模型类型 `tts` 可用)
|
||||
- `voices` (list) 可选音色列表。
|
||||
- `mode` (string) 音色模型。(模型类型 `tts` 可用)
|
||||
- `name` (string) 音色模型显示名称。(模型类型 `tts` 可用)
|
||||
- `language` (string) 音色模型支持语言。(模型类型 `tts` 可用)
|
||||
- `word_limit` (int) 单次转换字数限制,默认按段落分段(模型类型 `tts` 可用)
|
||||
- `audio_type` (string) 支持音频文件扩展格式,如:mp3,wav(模型类型 `tts` 可用)
|
||||
- `max_workers` (int) 支持文字音频转换并发任务数(模型类型 `tts` 可用)
|
||||
- `max_characters_per_chunk` (int) 每块最大字符数 (模型类型 `moderation` 可用)
|
||||
- `parameter_rules` (array[[ParameterRule](#ParameterRule)]) [optional] 模型调用参数规则
|
||||
- `supported_file_extensions` (string) 支持文件扩展格式,如:mp3,mp4(模型类型 `speech2text` 可用)
|
||||
- `default_voice` (string) 缺省音色,必选:alloy,echo,fable,onyx,nova,shimmer(模型类型 `tts` 可用)
|
||||
- `voices` (list) 可选音色列表。
|
||||
- `mode` (string) 音色模型。(模型类型 `tts` 可用)
|
||||
- `name` (string) 音色模型显示名称。(模型类型 `tts` 可用)
|
||||
- `language` (string) 音色模型支持语言。(模型类型 `tts` 可用)
|
||||
- `word_limit` (int) 单次转换字数限制,默认按段落分段(模型类型 `tts` 可用)
|
||||
- `audio_type` (string) 支持音频文件扩展格式,如:mp3,wav(模型类型 `tts` 可用)
|
||||
- `max_workers` (int) 支持文字音频转换并发任务数(模型类型 `tts` 可用)
|
||||
- `max_characters_per_chunk` (int) 每块最大字符数 (模型类型 `moderation` 可用)
|
||||
- `parameter_rules` (array\[[ParameterRule](#ParameterRule)\]) [optional] 模型调用参数规则
|
||||
- `pricing` ([PriceConfig](#PriceConfig)) [optional] 价格信息
|
||||
- `deprecated` (bool) 是否废弃。若废弃,模型列表将不再展示,但已经配置的可以继续使用,默认 False。
|
||||
|
||||
@ -75,6 +75,7 @@
|
||||
- `predefined-model ` 预定义模型
|
||||
|
||||
表示用户只需要配置统一的供应商凭据即可使用供应商下的预定义模型。
|
||||
|
||||
- `customizable-model` 自定义模型
|
||||
|
||||
用户需要新增每个模型的凭据配置。
|
||||
@ -106,7 +107,7 @@
|
||||
- `name` (string) 调用模型实际参数名
|
||||
|
||||
- `use_template` (string) [optional] 使用模板
|
||||
|
||||
|
||||
默认预置了 5 种变量内容配置模板:
|
||||
|
||||
- `temperature`
|
||||
@ -114,7 +115,7 @@
|
||||
- `frequency_penalty`
|
||||
- `presence_penalty`
|
||||
- `max_tokens`
|
||||
|
||||
|
||||
可在 use_template 中直接设置模板变量名,将会使用 entities.defaults.PARAMETER_RULE_TEMPLATE 中的默认配置
|
||||
不用设置除 `name` 和 `use_template` 之外的所有参数,若设置了额外的配置参数,将覆盖默认配置。
|
||||
可参考 `openai/llm/gpt-3.5-turbo.yaml`。
|
||||
@ -157,7 +158,7 @@
|
||||
|
||||
### ProviderCredentialSchema
|
||||
|
||||
- `credential_form_schemas` (array[[CredentialFormSchema](#CredentialFormSchema)]) 凭据表单规范
|
||||
- `credential_form_schemas` (array\[[CredentialFormSchema](#CredentialFormSchema)\]) 凭据表单规范
|
||||
|
||||
### ModelCredentialSchema
|
||||
|
||||
@ -168,7 +169,7 @@
|
||||
- `placeholder` (object) 模型提示内容
|
||||
- `en_US`(string) 英文
|
||||
- `zh_Hans`(string) [optional] 中文
|
||||
- `credential_form_schemas` (array[[CredentialFormSchema](#CredentialFormSchema)]) 凭据表单规范
|
||||
- `credential_form_schemas` (array\[[CredentialFormSchema](#CredentialFormSchema)\]) 凭据表单规范
|
||||
|
||||
### CredentialFormSchema
|
||||
|
||||
@ -179,12 +180,12 @@
|
||||
- `type` ([FormType](#FormType)) 表单项类型
|
||||
- `required` (bool) 是否必填
|
||||
- `default`(string) 默认值
|
||||
- `options` (array[[FormOption](#FormOption)]) 表单项为 `select` 或 `radio` 专有属性,定义下拉内容
|
||||
- `options` (array\[[FormOption](#FormOption)\]) 表单项为 `select` 或 `radio` 专有属性,定义下拉内容
|
||||
- `placeholder`(object) 表单项为 `text-input `专有属性,表单项 PlaceHolder
|
||||
- `en_US`(string) 英文
|
||||
- `zh_Hans` (string) [optional] 中文
|
||||
- `max_length` (int) 表单项为`text-input`专有属性,定义输入最大长度,0 为不限制。
|
||||
- `show_on` (array[[FormShowOnObject](#FormShowOnObject)]) 当其他表单项值符合条件时显示,为空则始终显示。
|
||||
- `show_on` (array\[[FormShowOnObject](#FormShowOnObject)\]) 当其他表单项值符合条件时显示,为空则始终显示。
|
||||
|
||||
### FormType
|
||||
|
||||
@ -200,7 +201,7 @@
|
||||
- `en_US`(string) 英文
|
||||
- `zh_Hans`(string) [optional] 中文
|
||||
- `value` (string) 下拉选项值
|
||||
- `show_on` (array[[FormShowOnObject](#FormShowOnObject)]) 当其他表单项值符合条件时显示,为空则始终显示。
|
||||
- `show_on` (array\[[FormShowOnObject](#FormShowOnObject)\]) 当其他表单项值符合条件时显示,为空则始终显示。
|
||||
|
||||
### FormShowOnObject
|
||||
|
||||
|
||||
@ -98,14 +98,19 @@ class AnalyticdbVectorBySql:
|
||||
try:
|
||||
cur.execute(f"CREATE DATABASE {self.databaseName}")
|
||||
except Exception as e:
|
||||
if "already exists" in str(e):
|
||||
return
|
||||
raise e
|
||||
if "already exists" not in str(e):
|
||||
raise e
|
||||
finally:
|
||||
cur.close()
|
||||
conn.close()
|
||||
self.pool = self._create_connection_pool()
|
||||
with self._get_cursor() as cur:
|
||||
try:
|
||||
cur.execute("CREATE EXTENSION IF NOT EXISTS zhparser;")
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
"Failed to create zhparser extension. Please ensure it is available in your AnalyticDB."
|
||||
) from e
|
||||
try:
|
||||
cur.execute("CREATE TEXT SEARCH CONFIGURATION zh_cn (PARSER = zhparser)")
|
||||
cur.execute("ALTER TEXT SEARCH CONFIGURATION zh_cn ADD MAPPING FOR n,v,a,i,e,l,x WITH simple")
|
||||
|
||||
@ -92,17 +92,21 @@ Clickzetta supports advanced full-text search with multiple analyzers:
|
||||
### Analyzer Types
|
||||
|
||||
1. **keyword**: No tokenization, treats the entire string as a single token
|
||||
|
||||
- Best for: Exact matching, IDs, codes
|
||||
|
||||
2. **english**: Designed for English text
|
||||
1. **english**: Designed for English text
|
||||
|
||||
- Features: Recognizes ASCII letters and numbers, converts to lowercase
|
||||
- Best for: English content
|
||||
|
||||
3. **chinese**: Chinese text tokenizer
|
||||
1. **chinese**: Chinese text tokenizer
|
||||
|
||||
- Features: Recognizes Chinese and English characters, removes punctuation
|
||||
- Best for: Chinese or mixed Chinese-English content
|
||||
|
||||
4. **unicode**: Multi-language tokenizer based on Unicode
|
||||
1. **unicode**: Multi-language tokenizer based on Unicode
|
||||
|
||||
- Features: Recognizes text boundaries in multiple languages
|
||||
- Best for: Multi-language content
|
||||
|
||||
@ -124,21 +128,25 @@ Clickzetta supports advanced full-text search with multiple analyzers:
|
||||
### Vector Search
|
||||
|
||||
1. **Adjust exploration factor** for accuracy vs speed trade-off:
|
||||
|
||||
```sql
|
||||
SET cz.vector.index.search.ef=64;
|
||||
```
|
||||
|
||||
2. **Use appropriate distance functions**:
|
||||
1. **Use appropriate distance functions**:
|
||||
|
||||
- `cosine_distance`: Best for normalized embeddings (e.g., from language models)
|
||||
- `l2_distance`: Best for raw feature vectors
|
||||
|
||||
### Full-Text Search
|
||||
|
||||
1. **Choose the right analyzer**:
|
||||
|
||||
- Use `keyword` for exact matching
|
||||
- Use language-specific analyzers for better tokenization
|
||||
|
||||
2. **Combine with vector search**:
|
||||
1. **Combine with vector search**:
|
||||
|
||||
- Pre-filter with full-text search for better performance
|
||||
- Use hybrid search for improved relevance
|
||||
|
||||
@ -147,27 +155,30 @@ Clickzetta supports advanced full-text search with multiple analyzers:
|
||||
### Connection Issues
|
||||
|
||||
1. Verify all 7 required configuration parameters are set
|
||||
2. Check network connectivity to Clickzetta service
|
||||
3. Ensure the user has proper permissions on the schema
|
||||
1. Check network connectivity to Clickzetta service
|
||||
1. Ensure the user has proper permissions on the schema
|
||||
|
||||
### Search Performance
|
||||
|
||||
1. Verify vector index exists:
|
||||
|
||||
```sql
|
||||
SHOW INDEX FROM <schema>.<table_name>;
|
||||
```
|
||||
|
||||
2. Check if vector index is being used:
|
||||
1. Check if vector index is being used:
|
||||
|
||||
```sql
|
||||
EXPLAIN SELECT ... WHERE l2_distance(...) < threshold;
|
||||
```
|
||||
|
||||
Look for `vector_index_search_type` in the execution plan.
|
||||
|
||||
### Full-Text Search Not Working
|
||||
|
||||
1. Verify inverted index is created
|
||||
2. Check analyzer configuration matches your content language
|
||||
3. Use `TOKENIZE()` function to test tokenization:
|
||||
1. Check analyzer configuration matches your content language
|
||||
1. Use `TOKENIZE()` function to test tokenization:
|
||||
```sql
|
||||
SELECT TOKENIZE('your text', map('analyzer', 'chinese', 'mode', 'smart'));
|
||||
```
|
||||
@ -175,13 +186,13 @@ Clickzetta supports advanced full-text search with multiple analyzers:
|
||||
## Limitations
|
||||
|
||||
1. Vector operations don't support `ORDER BY` or `GROUP BY` directly on vector columns
|
||||
2. Full-text search relevance scores are not provided by Clickzetta
|
||||
3. Inverted index creation may fail for very large existing tables (continue without error)
|
||||
4. Index naming constraints:
|
||||
1. Full-text search relevance scores are not provided by Clickzetta
|
||||
1. Inverted index creation may fail for very large existing tables (continue without error)
|
||||
1. Index naming constraints:
|
||||
- Index names must be unique within a schema
|
||||
- Only one vector index can be created per column
|
||||
- The implementation uses timestamps to ensure unique index names
|
||||
5. A column can only have one vector index at a time
|
||||
1. A column can only have one vector index at a time
|
||||
|
||||
## References
|
||||
|
||||
|
||||
@ -81,14 +81,11 @@ class ApiTool(Tool):
|
||||
return ToolProviderType.API
|
||||
|
||||
def assembling_request(self, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
headers = {}
|
||||
if self.runtime is None:
|
||||
raise ToolProviderCredentialValidationError("runtime not initialized")
|
||||
|
||||
headers = {}
|
||||
if self.runtime is None:
|
||||
raise ValueError("runtime is required")
|
||||
credentials = self.runtime.credentials or {}
|
||||
|
||||
if "auth_type" not in credentials:
|
||||
raise ToolProviderCredentialValidationError("Missing auth_type")
|
||||
|
||||
|
||||
@ -62,7 +62,7 @@ class ToolProviderApiEntity(BaseModel):
|
||||
parameter.pop("input_schema", None)
|
||||
# -------------
|
||||
optional_fields = self.optional_field("server_url", self.server_url)
|
||||
if self.type == ToolProviderType.MCP.value:
|
||||
if self.type == ToolProviderType.MCP:
|
||||
optional_fields.update(self.optional_field("updated_at", self.updated_at))
|
||||
optional_fields.update(self.optional_field("server_identifier", self.server_identifier))
|
||||
return {
|
||||
|
||||
@ -959,7 +959,7 @@ class ToolManager:
|
||||
elif provider_type == ToolProviderType.WORKFLOW:
|
||||
return cls.generate_workflow_tool_icon_url(tenant_id, provider_id)
|
||||
elif provider_type == ToolProviderType.PLUGIN:
|
||||
provider = ToolManager.get_builtin_provider(provider_id, tenant_id)
|
||||
provider = ToolManager.get_plugin_provider(provider_id, tenant_id)
|
||||
if isinstance(provider, PluginToolProviderController):
|
||||
try:
|
||||
return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon)
|
||||
|
||||
@ -1,17 +0,0 @@
|
||||
import re
|
||||
|
||||
|
||||
def get_image_upload_file_ids(content):
|
||||
pattern = r"!\[image\]\((http?://.*?(file-preview|image-preview))\)"
|
||||
matches = re.findall(pattern, content)
|
||||
image_upload_file_ids = []
|
||||
for match in matches:
|
||||
if match[1] == "file-preview":
|
||||
content_pattern = r"files/([^/]+)/file-preview"
|
||||
else:
|
||||
content_pattern = r"files/([^/]+)/image-preview"
|
||||
content_match = re.search(content_pattern, match[0])
|
||||
if content_match:
|
||||
image_upload_file_id = content_match.group(1)
|
||||
image_upload_file_ids.append(image_upload_file_id)
|
||||
return image_upload_file_ids
|
||||
@ -80,14 +80,14 @@ def get_url(url: str, user_agent: Optional[str] = None) -> str:
|
||||
else:
|
||||
content = response.text
|
||||
|
||||
article = extract_using_readabilipy(content)
|
||||
article = extract_using_readability(content)
|
||||
|
||||
if not article.text:
|
||||
return ""
|
||||
|
||||
res = FULL_TEMPLATE.format(
|
||||
title=article.title,
|
||||
author=article.auther,
|
||||
author=article.author,
|
||||
text=article.text,
|
||||
)
|
||||
|
||||
@ -97,15 +97,15 @@ def get_url(url: str, user_agent: Optional[str] = None) -> str:
|
||||
@dataclass
|
||||
class Article:
|
||||
title: str
|
||||
auther: str
|
||||
author: str
|
||||
text: Sequence[dict]
|
||||
|
||||
|
||||
def extract_using_readabilipy(html: str):
|
||||
def extract_using_readability(html: str):
|
||||
json_article: dict[str, Any] = simple_json_from_html_string(html, use_readability=True)
|
||||
article = Article(
|
||||
title=json_article.get("title") or "",
|
||||
auther=json_article.get("byline") or "",
|
||||
author=json_article.get("byline") or "",
|
||||
text=json_article.get("plain_text") or [],
|
||||
)
|
||||
|
||||
@ -113,7 +113,7 @@ def extract_using_readabilipy(html: str):
|
||||
|
||||
|
||||
def get_image_upload_file_ids(content):
|
||||
pattern = r"!\[image\]\((http?://.*?(file-preview|image-preview))\)"
|
||||
pattern = r"!\[image\]\((https?://.*?(file-preview|image-preview))\)"
|
||||
matches = re.findall(pattern, content)
|
||||
image_upload_file_ids = []
|
||||
for match in matches:
|
||||
|
||||
@ -203,9 +203,6 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
raise ValueError("app not found")
|
||||
|
||||
app = db_providers.app
|
||||
if not app:
|
||||
raise ValueError("can not read app of workflow")
|
||||
|
||||
self.tools = [self._get_db_provider_tool(db_providers, app)]
|
||||
|
||||
return self.tools
|
||||
|
||||
@ -123,7 +123,7 @@ class BillingService:
|
||||
return BillingService._send_request("GET", "/education/verify", params=params)
|
||||
|
||||
@classmethod
|
||||
def is_active(cls, account_id: str):
|
||||
def status(cls, account_id: str):
|
||||
params = {"account_id": account_id}
|
||||
return BillingService._send_request("GET", "/education/status", params=params)
|
||||
|
||||
|
||||
@ -294,6 +294,11 @@ class DatasetService:
|
||||
dataset: Optional[Dataset] = db.session.query(Dataset).filter_by(id=dataset_id).first()
|
||||
return dataset
|
||||
|
||||
@staticmethod
|
||||
def check_doc_form(dataset: Dataset, doc_form: str):
|
||||
if dataset.doc_form and doc_form != dataset.doc_form:
|
||||
raise ValueError("doc_form is different from the dataset doc_form.")
|
||||
|
||||
@staticmethod
|
||||
def check_dataset_model_setting(dataset):
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
@ -1265,6 +1270,8 @@ class DocumentService:
|
||||
dataset_process_rule: Optional[DatasetProcessRule] = None,
|
||||
created_from: str = "web",
|
||||
):
|
||||
# check doc_form
|
||||
DatasetService.check_doc_form(dataset, knowledge_config.doc_form)
|
||||
# check document limit
|
||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||
|
||||
|
||||
@ -1,18 +0,0 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from tasks.mail_enterprise_task import send_enterprise_email_task
|
||||
|
||||
|
||||
class DifyMail(BaseModel):
|
||||
to: list[str]
|
||||
subject: str
|
||||
body: str
|
||||
substitutions: dict[str, str] = {}
|
||||
|
||||
|
||||
class EnterpriseMailService:
|
||||
@classmethod
|
||||
def send_mail(cls, mail: DifyMail):
|
||||
send_enterprise_email_task.delay(
|
||||
to=mail.to, subject=mail.subject, body=mail.body, substitutions=mail.substitutions
|
||||
)
|
||||
@ -5,7 +5,7 @@ import click
|
||||
from celery import shared_task # type: ignore
|
||||
|
||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
from core.tools.utils.rag_web_reader import get_image_upload_file_ids
|
||||
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from models.dataset import (
|
||||
|
||||
@ -6,7 +6,7 @@ import click
|
||||
from celery import shared_task # type: ignore
|
||||
|
||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
from core.tools.utils.rag_web_reader import get_image_upload_file_ids
|
||||
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from models.dataset import Dataset, DatasetMetadataBinding, DocumentSegment
|
||||
|
||||
@ -11,7 +11,7 @@ from libs.email_i18n import get_email_i18n_service
|
||||
|
||||
|
||||
@shared_task(queue="mail")
|
||||
def send_enterprise_email_task(to: list[str], subject: str, body: str, substitutions: Mapping[str, str]):
|
||||
def send_inner_email_task(to: list[str], subject: str, body: str, substitutions: Mapping[str, str]):
|
||||
if not mail.is_inited():
|
||||
return
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,620 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
|
||||
from models.model import EndUser, Message
|
||||
from models.web import SavedMessage
|
||||
from services.app_service import AppService
|
||||
from services.saved_message_service import SavedMessageService
|
||||
|
||||
|
||||
class TestSavedMessageService:
|
||||
"""Integration tests for SavedMessageService using testcontainers."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_external_service_dependencies(self):
|
||||
"""Mock setup for external service dependencies."""
|
||||
with (
|
||||
patch("services.account_service.FeatureService") as mock_account_feature_service,
|
||||
patch("services.app_service.ModelManager") as mock_model_manager,
|
||||
patch("services.saved_message_service.MessageService") as mock_message_service,
|
||||
):
|
||||
# Setup default mock returns
|
||||
mock_account_feature_service.get_system_features.return_value.is_allow_register = True
|
||||
|
||||
# Mock ModelManager for app creation
|
||||
mock_model_instance = mock_model_manager.return_value
|
||||
mock_model_instance.get_default_model_instance.return_value = None
|
||||
mock_model_instance.get_default_provider_model_name.return_value = ("openai", "gpt-3.5-turbo")
|
||||
|
||||
# Mock MessageService
|
||||
mock_message_service.get_message.return_value = None
|
||||
mock_message_service.pagination_by_last_id.return_value = None
|
||||
|
||||
yield {
|
||||
"account_feature_service": mock_account_feature_service,
|
||||
"model_manager": mock_model_manager,
|
||||
"message_service": mock_message_service,
|
||||
}
|
||||
|
||||
def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Helper method to create a test app and account for testing.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
mock_external_service_dependencies: Mock dependencies
|
||||
|
||||
Returns:
|
||||
tuple: (app, account) - Created app and account instances
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Setup mocks for account creation
|
||||
mock_external_service_dependencies[
|
||||
"account_feature_service"
|
||||
].get_system_features.return_value.is_allow_register = True
|
||||
|
||||
# Create account and tenant first
|
||||
from services.account_service import AccountService, TenantService
|
||||
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create app with realistic data
|
||||
app_args = {
|
||||
"name": fake.company(),
|
||||
"description": fake.text(max_nb_chars=100),
|
||||
"mode": "chat",
|
||||
"icon_type": "emoji",
|
||||
"icon": "🤖",
|
||||
"icon_background": "#FF6B6B",
|
||||
"api_rph": 100,
|
||||
"api_rpm": 10,
|
||||
}
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
return app, account
|
||||
|
||||
def _create_test_end_user(self, db_session_with_containers, app):
|
||||
"""
|
||||
Helper method to create a test end user for testing.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
app: App instance to associate the end user with
|
||||
|
||||
Returns:
|
||||
EndUser: Created end user instance
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
end_user = EndUser(
|
||||
tenant_id=app.tenant_id,
|
||||
app_id=app.id,
|
||||
external_user_id=fake.uuid4(),
|
||||
name=fake.name(),
|
||||
type="normal",
|
||||
session_id=fake.uuid4(),
|
||||
is_anonymous=False,
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(end_user)
|
||||
db.session.commit()
|
||||
|
||||
return end_user
|
||||
|
||||
def _create_test_message(self, db_session_with_containers, app, user):
|
||||
"""
|
||||
Helper method to create a test message for testing.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
app: App instance to associate the message with
|
||||
user: User instance (Account or EndUser) to associate the message with
|
||||
|
||||
Returns:
|
||||
Message: Created message instance
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create a simple conversation first
|
||||
from models.model import Conversation
|
||||
|
||||
conversation = Conversation(
|
||||
app_id=app.id,
|
||||
from_source="account" if hasattr(user, "current_tenant") else "end_user",
|
||||
from_end_user_id=user.id if not hasattr(user, "current_tenant") else None,
|
||||
from_account_id=user.id if hasattr(user, "current_tenant") else None,
|
||||
name=fake.sentence(nb_words=3),
|
||||
inputs={},
|
||||
status="normal",
|
||||
mode="chat",
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(conversation)
|
||||
db.session.commit()
|
||||
|
||||
# Create message
|
||||
message = Message(
|
||||
app_id=app.id,
|
||||
conversation_id=conversation.id,
|
||||
from_source="account" if hasattr(user, "current_tenant") else "end_user",
|
||||
from_end_user_id=user.id if not hasattr(user, "current_tenant") else None,
|
||||
from_account_id=user.id if hasattr(user, "current_tenant") else None,
|
||||
inputs={},
|
||||
query=fake.sentence(nb_words=5),
|
||||
message=fake.text(max_nb_chars=100),
|
||||
answer=fake.text(max_nb_chars=200),
|
||||
message_tokens=50,
|
||||
answer_tokens=100,
|
||||
message_unit_price=0.001,
|
||||
answer_unit_price=0.002,
|
||||
total_price=0.003,
|
||||
currency="USD",
|
||||
status="success",
|
||||
)
|
||||
|
||||
db.session.add(message)
|
||||
db.session.commit()
|
||||
|
||||
return message
|
||||
|
||||
def test_pagination_by_last_id_success_with_account_user(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful pagination by last ID with account user.
|
||||
|
||||
This test verifies:
|
||||
- Proper pagination with account user
|
||||
- Correct filtering by app_id and user
|
||||
- Proper role identification for account users
|
||||
- MessageService integration
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create test messages
|
||||
message1 = self._create_test_message(db_session_with_containers, app, account)
|
||||
message2 = self._create_test_message(db_session_with_containers, app, account)
|
||||
|
||||
# Create saved messages
|
||||
saved_message1 = SavedMessage(
|
||||
app_id=app.id,
|
||||
message_id=message1.id,
|
||||
created_by_role="account",
|
||||
created_by=account.id,
|
||||
)
|
||||
saved_message2 = SavedMessage(
|
||||
app_id=app.id,
|
||||
message_id=message2.id,
|
||||
created_by_role="account",
|
||||
created_by=account.id,
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add_all([saved_message1, saved_message2])
|
||||
db.session.commit()
|
||||
|
||||
# Mock MessageService.pagination_by_last_id return value
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
|
||||
mock_pagination = InfiniteScrollPagination(data=[message1, message2], limit=10, has_more=False)
|
||||
mock_external_service_dependencies["message_service"].pagination_by_last_id.return_value = mock_pagination
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = SavedMessageService.pagination_by_last_id(app_model=app, user=account, last_id=None, limit=10)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert result.data == [message1, message2]
|
||||
assert result.limit == 10
|
||||
assert result.has_more is False
|
||||
|
||||
# Verify MessageService was called with correct parameters
|
||||
# Sort the IDs to handle database query order variations
|
||||
expected_include_ids = sorted([message1.id, message2.id])
|
||||
actual_call = mock_external_service_dependencies["message_service"].pagination_by_last_id.call_args
|
||||
actual_include_ids = sorted(actual_call.kwargs.get("include_ids", []))
|
||||
|
||||
assert actual_call.kwargs["app_model"] == app
|
||||
assert actual_call.kwargs["user"] == account
|
||||
assert actual_call.kwargs["last_id"] is None
|
||||
assert actual_call.kwargs["limit"] == 10
|
||||
assert actual_include_ids == expected_include_ids
|
||||
|
||||
# Verify database state
|
||||
db.session.refresh(saved_message1)
|
||||
db.session.refresh(saved_message2)
|
||||
assert saved_message1.id is not None
|
||||
assert saved_message2.id is not None
|
||||
assert saved_message1.created_by_role == "account"
|
||||
assert saved_message2.created_by_role == "account"
|
||||
|
||||
def test_pagination_by_last_id_success_with_end_user(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful pagination by last ID with end user.
|
||||
|
||||
This test verifies:
|
||||
- Proper pagination with end user
|
||||
- Correct filtering by app_id and user
|
||||
- Proper role identification for end users
|
||||
- MessageService integration
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
end_user = self._create_test_end_user(db_session_with_containers, app)
|
||||
|
||||
# Create test messages
|
||||
message1 = self._create_test_message(db_session_with_containers, app, end_user)
|
||||
message2 = self._create_test_message(db_session_with_containers, app, end_user)
|
||||
|
||||
# Create saved messages
|
||||
saved_message1 = SavedMessage(
|
||||
app_id=app.id,
|
||||
message_id=message1.id,
|
||||
created_by_role="end_user",
|
||||
created_by=end_user.id,
|
||||
)
|
||||
saved_message2 = SavedMessage(
|
||||
app_id=app.id,
|
||||
message_id=message2.id,
|
||||
created_by_role="end_user",
|
||||
created_by=end_user.id,
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add_all([saved_message1, saved_message2])
|
||||
db.session.commit()
|
||||
|
||||
# Mock MessageService.pagination_by_last_id return value
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
|
||||
mock_pagination = InfiniteScrollPagination(data=[message1, message2], limit=5, has_more=True)
|
||||
mock_external_service_dependencies["message_service"].pagination_by_last_id.return_value = mock_pagination
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = SavedMessageService.pagination_by_last_id(
|
||||
app_model=app, user=end_user, last_id="test_last_id", limit=5
|
||||
)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert result.data == [message1, message2]
|
||||
assert result.limit == 5
|
||||
assert result.has_more is True
|
||||
|
||||
# Verify MessageService was called with correct parameters
|
||||
# Sort the IDs to handle database query order variations
|
||||
expected_include_ids = sorted([message1.id, message2.id])
|
||||
actual_call = mock_external_service_dependencies["message_service"].pagination_by_last_id.call_args
|
||||
actual_include_ids = sorted(actual_call.kwargs.get("include_ids", []))
|
||||
|
||||
assert actual_call.kwargs["app_model"] == app
|
||||
assert actual_call.kwargs["user"] == end_user
|
||||
assert actual_call.kwargs["last_id"] == "test_last_id"
|
||||
assert actual_call.kwargs["limit"] == 5
|
||||
assert actual_include_ids == expected_include_ids
|
||||
|
||||
# Verify database state
|
||||
db.session.refresh(saved_message1)
|
||||
db.session.refresh(saved_message2)
|
||||
assert saved_message1.id is not None
|
||||
assert saved_message2.id is not None
|
||||
assert saved_message1.created_by_role == "end_user"
|
||||
assert saved_message2.created_by_role == "end_user"
|
||||
|
||||
def test_save_success_with_new_message(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful save of a new message.
|
||||
|
||||
This test verifies:
|
||||
- Proper creation of new saved message
|
||||
- Correct database state after save
|
||||
- Proper relationship establishment
|
||||
- MessageService integration for message retrieval
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
message = self._create_test_message(db_session_with_containers, app, account)
|
||||
|
||||
# Mock MessageService.get_message return value
|
||||
mock_external_service_dependencies["message_service"].get_message.return_value = message
|
||||
|
||||
# Act: Execute the method under test
|
||||
SavedMessageService.save(app_model=app, user=account, message_id=message.id)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
# Check if saved message was created in database
|
||||
from extensions.ext_database import db
|
||||
|
||||
saved_message = (
|
||||
db.session.query(SavedMessage)
|
||||
.where(
|
||||
SavedMessage.app_id == app.id,
|
||||
SavedMessage.message_id == message.id,
|
||||
SavedMessage.created_by_role == "account",
|
||||
SavedMessage.created_by == account.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
assert saved_message is not None
|
||||
assert saved_message.app_id == app.id
|
||||
assert saved_message.message_id == message.id
|
||||
assert saved_message.created_by_role == "account"
|
||||
assert saved_message.created_by == account.id
|
||||
assert saved_message.created_at is not None
|
||||
|
||||
# Verify MessageService.get_message was called
|
||||
mock_external_service_dependencies["message_service"].get_message.assert_called_once_with(
|
||||
app_model=app, user=account, message_id=message.id
|
||||
)
|
||||
|
||||
# Verify database state
|
||||
db.session.refresh(saved_message)
|
||||
assert saved_message.id is not None
|
||||
|
||||
def test_pagination_by_last_id_error_no_user(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test error handling when no user is provided.
|
||||
|
||||
This test verifies:
|
||||
- Proper error handling for missing user
|
||||
- ValueError is raised when user is None
|
||||
- No database operations are performed
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Act & Assert: Verify proper error handling
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
SavedMessageService.pagination_by_last_id(app_model=app, user=None, last_id=None, limit=10)
|
||||
|
||||
assert "User is required" in str(exc_info.value)
|
||||
|
||||
# Verify no database operations were performed
|
||||
from extensions.ext_database import db
|
||||
|
||||
saved_messages = db.session.query(SavedMessage).all()
|
||||
assert len(saved_messages) == 0
|
||||
|
||||
def test_save_error_no_user(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test error handling when saving message with no user.
|
||||
|
||||
This test verifies:
|
||||
- Method returns early when user is None
|
||||
- No database operations are performed
|
||||
- No exceptions are raised
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
message = self._create_test_message(db_session_with_containers, app, account)
|
||||
|
||||
# Act: Execute the method under test with None user
|
||||
result = SavedMessageService.save(app_model=app, user=None, message_id=message.id)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is None
|
||||
|
||||
# Verify no saved message was created
|
||||
from extensions.ext_database import db
|
||||
|
||||
saved_message = (
|
||||
db.session.query(SavedMessage)
|
||||
.where(
|
||||
SavedMessage.app_id == app.id,
|
||||
SavedMessage.message_id == message.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
assert saved_message is None
|
||||
|
||||
def test_delete_success_existing_message(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful deletion of an existing saved message.
|
||||
|
||||
This test verifies:
|
||||
- Proper deletion of existing saved message
|
||||
- Correct database state after deletion
|
||||
- No errors during deletion process
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
message = self._create_test_message(db_session_with_containers, app, account)
|
||||
|
||||
# Create a saved message first
|
||||
saved_message = SavedMessage(
|
||||
app_id=app.id,
|
||||
message_id=message.id,
|
||||
created_by_role="account",
|
||||
created_by=account.id,
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(saved_message)
|
||||
db.session.commit()
|
||||
|
||||
# Verify saved message exists
|
||||
assert (
|
||||
db.session.query(SavedMessage)
|
||||
.where(
|
||||
SavedMessage.app_id == app.id,
|
||||
SavedMessage.message_id == message.id,
|
||||
SavedMessage.created_by_role == "account",
|
||||
SavedMessage.created_by == account.id,
|
||||
)
|
||||
.first()
|
||||
is not None
|
||||
)
|
||||
|
||||
# Act: Execute the method under test
|
||||
SavedMessageService.delete(app_model=app, user=account, message_id=message.id)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
# Check if saved message was deleted from database
|
||||
deleted_saved_message = (
|
||||
db.session.query(SavedMessage)
|
||||
.where(
|
||||
SavedMessage.app_id == app.id,
|
||||
SavedMessage.message_id == message.id,
|
||||
SavedMessage.created_by_role == "account",
|
||||
SavedMessage.created_by == account.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
assert deleted_saved_message is None
|
||||
|
||||
# Verify database state
|
||||
db.session.commit()
|
||||
# The message should still exist, only the saved_message should be deleted
|
||||
assert db.session.query(Message).where(Message.id == message.id).first() is not None
|
||||
|
||||
def test_pagination_by_last_id_error_no_user(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test error handling when no user is provided.
|
||||
|
||||
This test verifies:
|
||||
- Proper error handling for missing user
|
||||
- ValueError is raised when user is None
|
||||
- No database operations are performed
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Act & Assert: Verify proper error handling
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
SavedMessageService.pagination_by_last_id(app_model=app, user=None, last_id=None, limit=10)
|
||||
|
||||
assert "User is required" in str(exc_info.value)
|
||||
|
||||
# Verify no database operations were performed for this specific test
|
||||
# Note: We don't check total count as other tests may have created data
|
||||
# Instead, we verify that the error was properly raised
|
||||
pass
|
||||
|
||||
def test_save_error_no_user(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test error handling when saving message with no user.
|
||||
|
||||
This test verifies:
|
||||
- Method returns early when user is None
|
||||
- No database operations are performed
|
||||
- No exceptions are raised
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
message = self._create_test_message(db_session_with_containers, app, account)
|
||||
|
||||
# Act: Execute the method under test with None user
|
||||
result = SavedMessageService.save(app_model=app, user=None, message_id=message.id)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is None
|
||||
|
||||
# Verify no saved message was created
|
||||
from extensions.ext_database import db
|
||||
|
||||
saved_message = (
|
||||
db.session.query(SavedMessage)
|
||||
.where(
|
||||
SavedMessage.app_id == app.id,
|
||||
SavedMessage.message_id == message.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
assert saved_message is None
|
||||
|
||||
def test_delete_success_existing_message(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful deletion of an existing saved message.
|
||||
|
||||
This test verifies:
|
||||
- Proper deletion of existing saved message
|
||||
- Correct database state after deletion
|
||||
- No errors during deletion process
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
message = self._create_test_message(db_session_with_containers, app, account)
|
||||
|
||||
# Create a saved message first
|
||||
saved_message = SavedMessage(
|
||||
app_id=app.id,
|
||||
message_id=message.id,
|
||||
created_by_role="account",
|
||||
created_by=account.id,
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(saved_message)
|
||||
db.session.commit()
|
||||
|
||||
# Verify saved message exists
|
||||
assert (
|
||||
db.session.query(SavedMessage)
|
||||
.where(
|
||||
SavedMessage.app_id == app.id,
|
||||
SavedMessage.message_id == message.id,
|
||||
SavedMessage.created_by_role == "account",
|
||||
SavedMessage.created_by == account.id,
|
||||
)
|
||||
.first()
|
||||
is not None
|
||||
)
|
||||
|
||||
# Act: Execute the method under test
|
||||
SavedMessageService.delete(app_model=app, user=account, message_id=message.id)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
# Check if saved message was deleted from database
|
||||
deleted_saved_message = (
|
||||
db.session.query(SavedMessage)
|
||||
.where(
|
||||
SavedMessage.app_id == app.id,
|
||||
SavedMessage.message_id == message.id,
|
||||
SavedMessage.created_by_role == "account",
|
||||
SavedMessage.created_by == account.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
assert deleted_saved_message is None
|
||||
|
||||
# Verify database state
|
||||
db.session.commit()
|
||||
# The message should still exist, only the saved_message should be deleted
|
||||
assert db.session.query(Message).where(Message.id == message.id).first() is not None
|
||||
@ -0,0 +1,25 @@
|
||||
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
|
||||
|
||||
|
||||
def test_get_image_upload_file_ids():
|
||||
# should extract id from https + file-preview
|
||||
content = ""
|
||||
assert get_image_upload_file_ids(content) == ["abc123"]
|
||||
|
||||
# should extract id from http + image-preview
|
||||
content = ""
|
||||
assert get_image_upload_file_ids(content) == ["xyz789"]
|
||||
|
||||
# should not match invalid scheme 'htt://'
|
||||
content = ""
|
||||
assert get_image_upload_file_ids(content) == []
|
||||
|
||||
# should extract multiple ids in order
|
||||
content = """
|
||||
some text
|
||||

|
||||
middle
|
||||

|
||||
end
|
||||
"""
|
||||
assert get_image_upload_file_ids(content) == ["id1", "id2"]
|
||||
2
api/uv.lock
generated
2
api/uv.lock
generated
@ -1,5 +1,5 @@
|
||||
version = 1
|
||||
revision = 2
|
||||
revision = 3
|
||||
requires-python = ">=3.11, <3.13"
|
||||
resolution-markers = [
|
||||
"python_full_version >= '3.12.4' and platform_python_implementation != 'PyPy' and sys_platform == 'linux'",
|
||||
|
||||
Reference in New Issue
Block a user