Mirror of https://github.com/langgenius/dify.git (synced 2026-01-26 14:55:45 +08:00)

Compare commits: fix/rag-in...1.1.0 (109 commits)
| SHA1 |
|---|
| ac80c04bd3 |
| fa9b767bf2 |
| abeaea4f79 |
| b65f2eb55f |
| 7d620ffd5e |
| 6f6ba2f025 |
| 33ba7e659b |
| 750ec55646 |
| 86d3fff666 |
| e91531fc23 |
| 2524f16525 |
| cefec44070 |
| 20376ca951 |
| 475b8d731e |
| 963b6f628a |
| 63ea6f1ecf |
| 947c9f70fb |
| 5e52d4d6b3 |
| 939dcb4c0a |
| 223ab5a38f |
| db7a37a111 |
| fe0d932f50 |
| 69fb0a4a28 |
| 04a0ae3aa9 |
| e5d6047fb4 |
| 9e782d4c1e |
| 98a4b3e78b |
| 2b4d1cf1db |
| fe76dfe1f8 |
| c3774bef7e |
| 695a7400a9 |
| e6a8800f66 |
| cee8731393 |
| 4ae94dc027 |
| 3a69a6a452 |
| f8f21ef7c0 |
| 0587eb4956 |
| 433374abea |
| 23ed3a520b |
| 5646442931 |
| 1a6298b6ea |
| bf9b572bc3 |
| cf72e53a10 |
| 98bd79f548 |
| 84a866028a |
| 10bd03611c |
| 7c27d4b202 |
| 8165d0b469 |
| e796937d02 |
| 49c952a631 |
| 5f9d236d22 |
| 59f5a82261 |
| f22a1adb8b |
| a8e8c37fdd |
| 37486a9cc6 |
| efebbffe96 |
| 5e035a4209 |
| 12fa517297 |
| 36ae0e5476 |
| 74f66d3119 |
| adfaee7ab5 |
| d37490adc3 |
| 087bb60b31 |
| 5019547d33 |
| 58f012f3de |
| b938c9b7f6 |
| 2b1facc7a6 |
| 1d5ea80a2b |
| 0415cc209d |
| 545e5cbcd6 |
| 1fab02c25a |
| 258736f505 |
| 0bc4da38fc |
| 037f200527 |
| b541792465 |
| eb9b256ee8 |
| 5d8b32a249 |
| c960b364c9 |
| b817036343 |
| 46036e6ce6 |
| 1ffda0dd34 |
| da01b460fe |
| 90a1508b87 |
| b07016113c |
| d8317fcf81 |
| a6bc642721 |
| b730f243dc |
| 71a57275ab |
| 41bf8d925f |
| 6d172498d1 |
| cad58658c2 |
| a58b990855 |
| b6b1903a37 |
| ed5596a8f4 |
| 49d0acd188 |
| 58a74fe1fb |
| a1ab4aec3d |
| f77f7e1437 |
| adda049265 |
| 9b2a9260ef |
| c8cc31af88 |
| d333de274f |
| 9e220d5d30 |
| 2cf0cb471f |
| 269ba6add9 |
| 78d460a6d1 |
| 3254018ddb |
| f2b7df94d7 |
| 59fd3aad31 |
.github/ISSUE_TEMPLATE/tracker.yml (new file, 13 lines)
@@ -0,0 +1,13 @@
name: "👾 Tracker"
description: For inner usages, please donot use this template.
title: "[Tracker] "
labels:
  - tracker
body:
  - type: textarea
    id: content
    attributes:
      label: Blockers
      placeholder: "- [ ] ..."
    validations:
      required: true
.github/workflows/build-push.yml (1 line changed)
@@ -5,6 +5,7 @@ on:
branches:
- "main"
- "deploy/dev"
- "deploy/enterprise"
release:
types: [published]
.github/workflows/deploy-enterprise.yml (new file, 29 lines)
@@ -0,0 +1,29 @@
name: Deploy Enterprise

permissions:
  contents: read

on:
  workflow_run:
    workflows: ["Build and Push API & Web"]
    branches:
      - "deploy/enterprise"
    types:
      - completed

jobs:
  deploy:
    runs-on: ubuntu-latest
    if: |
      github.event.workflow_run.conclusion == 'success' &&
      github.event.workflow_run.head_branch == 'deploy/enterprise'

    steps:
      - name: Deploy to server
        uses: appleboy/ssh-action@v0.1.8
        with:
          host: ${{ secrets.ENTERPRISE_SSH_HOST }}
          username: ${{ secrets.ENTERPRISE_SSH_USER }}
          password: ${{ secrets.ENTERPRISE_SSH_PASSWORD }}
          script: |
            ${{ vars.ENTERPRISE_SSH_SCRIPT || secrets.ENTERPRISE_SSH_SCRIPT }}
.github/workflows/expose_service_ports.sh (3 lines changed)
@@ -10,5 +10,6 @@ yq eval '.services["elasticsearch"].ports += ["9200:9200"]' -i docker/docker-com
yq eval '.services.couchbase-server.ports += ["8091-8096:8091-8096"]' -i docker/docker-compose.yaml
yq eval '.services.couchbase-server.ports += ["11210:11210"]' -i docker/docker-compose.yaml
yq eval '.services.tidb.ports += ["4000:4000"]' -i docker/tidb/docker-compose.yaml
yq eval '.services.opengauss.ports += ["6600:6600"]' -i docker/docker-compose.yaml

echo "Ports exposed for sandbox, weaviate, tidb, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase"
echo "Ports exposed for sandbox, weaviate, tidb, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase, opengauss"
.github/workflows/vdb-tests.yml (1 line changed)
@@ -76,6 +76,7 @@ jobs:
milvus-standalone
pgvecto-rs
pgvector
opengauss
chroma
elasticsearch
.gitignore (3 lines changed)
@@ -202,3 +202,6 @@ api/.vscode

# plugin migrate
plugins.jsonl

# mise
mise.toml
@@ -26,7 +26,7 @@
| [@jyong](https://github.com/JohnJyong) | RAG 流水线设计 |
| [@GarfieldDai](https://github.com/GarfieldDai) | 构建 workflow 编排 |
| [@iamjoel](https://github.com/iamjoel) & [@zxhlyh](https://github.com/zxhlyh) | 让我们的前端更易用 |
| [@guchenhe](https://github.com/guchenhe) & [@crazywoola](https://github.com/crazywoola) | 开发人员体验, 综合事项联系人 |
| [@guchenhe](https://github.com/guchenhe) & [@crazywoola](https://github.com/crazywoola) | 开发人员体验,综合事项联系人 |
| [@takatost](https://github.com/takatost) | 产品整体方向和架构 |

事项优先级:

@@ -47,7 +47,7 @@
| ------------------------------------------------------------ | --------------- |
| 核心功能的 Bugs(例如无法登录、应用无法工作、安全漏洞) | 紧急 |
| 非紧急 bugs, 性能提升 | 中等优先级 |
| 小幅修复(错别字, 能正常工作但存在误导的 UI) | 低优先级 |
| 小幅修复 (错别字,能正常工作但存在误导的 UI) | 低优先级 |

## 安装

@@ -79,7 +79,7 @@ Dify 是一个开源的 LLM 应用开发平台。其直观的界面结合了 AI
广泛的 RAG 功能,涵盖从文档摄入到检索的所有内容,支持从 PDF、PPT 和其他常见文档格式中提取文本的开箱即用的支持。

**5. Agent 智能体**:
您可以基于 LLM 函数调用或 ReAct 定义 Agent,并为 Agent 添加预构建或自定义工具。Dify 为 AI Agent 提供了50多种内置工具,如谷歌搜索、DALL·E、Stable Diffusion 和 WolframAlpha 等。
您可以基于 LLM 函数调用或 ReAct 定义 Agent,并为 Agent 添加预构建或自定义工具。Dify 为 AI Agent 提供了 50 多种内置工具,如谷歌搜索、DALL·E、Stable Diffusion 和 WolframAlpha 等。

**6. LLMOps**:
随时间监视和分析应用程序日志和性能。您可以根据生产数据和标注持续改进提示、数据集和模型。

@@ -112,7 +112,7 @@ Dify 是一个开源的 LLM 应用开发平台。其直观的界面结合了 AI
<td align="center">仅限 OpenAI</td>
</tr>
<tr>
<td align="center">RAG引擎</td>
<td align="center">RAG 引擎</td>
<td align="center">✅</td>
<td align="center">✅</td>
<td align="center">✅</td>

@@ -234,7 +234,7 @@ docker compose up -d
对于那些想要贡献代码的人,请参阅我们的[贡献指南](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)。
同时,请考虑通过社交媒体、活动和会议来支持 Dify 的分享。

> 我们正在寻找贡献者来帮助将Dify翻译成除了中文和英文之外的其他语言。如果您有兴趣帮助,请参阅我们的[i18n README](https://github.com/langgenius/dify/blob/main/web/i18n/README.md)获取更多信息,并在我们的[Discord社区服务器](https://discord.gg/8Tpq4AcN9c)的`global-users`频道中留言。
> 我们正在寻找贡献者来帮助将 Dify 翻译成除了中文和英文之外的其他语言。如果您有兴趣帮助,请参阅我们的[i18n README](https://github.com/langgenius/dify/blob/main/web/i18n/README.md)获取更多信息,并在我们的[Discord 社区服务器](https://discord.gg/8Tpq4AcN9c)的`global-users`频道中留言。

**Contributors**
@@ -137,7 +137,7 @@ WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*

# Vector database configuration
# support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase
# support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase, opengauss
VECTOR_STORE=weaviate

# Weaviate configuration

@@ -298,6 +298,14 @@ OCEANBASE_VECTOR_PASSWORD=difyai123456
OCEANBASE_VECTOR_DATABASE=test
OCEANBASE_MEMORY_LIMIT=6G

# openGauss configuration
OPENGAUSS_HOST=127.0.0.1
OPENGAUSS_PORT=6600
OPENGAUSS_USER=postgres
OPENGAUSS_PASSWORD=Dify@123
OPENGAUSS_DATABASE=dify
OPENGAUSS_MIN_CONNECTION=1
OPENGAUSS_MAX_CONNECTION=5

# Upload configuration
UPLOAD_FILE_SIZE_LIMIT=15

@@ -378,6 +386,7 @@ HTTP_REQUEST_MAX_READ_TIMEOUT=600
HTTP_REQUEST_MAX_WRITE_TIMEOUT=600
HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760
HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576
HTTP_REQUEST_NODE_SSL_VERIFY=True

# Respect X-* headers to redirect clients
RESPECT_XFORWARD_HEADERS_ENABLED=false

@@ -444,4 +453,4 @@ CREATE_TIDB_SERVICE_JOB_ENABLED=false
# Maximum number of submitted thread count in a ThreadPool for parallel node execution
MAX_SUBMIT_COUNT=100
# Lockout duration in seconds
LOGIN_LOCKOUT_DURATION=86400
LOGIN_LOCKOUT_DURATION=86400
@@ -56,8 +56,6 @@ RUN \
curl nodejs libgmp-dev libmpfr-dev libmpc-dev \
# For Security
expat libldap-2.5-0 perl libsqlite3-0 zlib1g \
# install a chinese font to support the use of tools like matplotlib
fonts-noto-cjk \
# install a package to improve the accuracy of guessing mime type and file extension
media-types \
# install libmagic to support the use of python-magic guess MIMETYPE
@@ -160,11 +160,17 @@ def migrate_annotation_vector_database():
while True:
try:
# get apps info
per_page = 50
apps = (
App.query.filter(App.status == "normal")
db.session.query(App)
.filter(App.status == "normal")
.order_by(App.created_at.desc())
.paginate(page=page, per_page=50)
.limit(per_page)
.offset((page - 1) * per_page)
.all()
)
if not apps:
break
except NotFound:
break

@@ -267,6 +273,7 @@ def migrate_knowledge_vector_database():
VectorType.WEAVIATE,
VectorType.ORACLE,
VectorType.ELASTICSEARCH,
VectorType.OPENGAUSS,
}
lower_collection_vector_types = {
VectorType.ANALYTICDB,
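The hunk above replaces Flask-SQLAlchemy's `paginate()` call with an explicit `limit`/`offset` query inside the migration loop. A minimal sketch of that pattern, using the same `db.session` and `App` objects that appear in the diff (the per-app work is a placeholder, not the actual migration body):

```python
from extensions.ext_database import db  # project session, as imported elsewhere in this diff
from models.model import App            # App model, as imported elsewhere in this diff

page = 1
per_page = 50
while True:
    # Fetch one page of "normal" apps, newest first.
    apps = (
        db.session.query(App)
        .filter(App.status == "normal")
        .order_by(App.created_at.desc())
        .limit(per_page)
        .offset((page - 1) * per_page)
        .all()
    )
    if not apps:
        break  # no more rows: the loop ends instead of raising NotFound
    for app in apps:
        pass  # placeholder: per-app migration work goes here
    page += 1
```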
@@ -332,6 +332,11 @@ class HttpConfig(BaseSettings):
default=1 * 1024 * 1024,
)

HTTP_REQUEST_NODE_SSL_VERIFY: bool = Field(
description="Enable or disable SSL verification for HTTP requests",
default=True,
)

SSRF_DEFAULT_MAX_RETRIES: PositiveInt = Field(
description="Maximum number of retries for network requests (SSRF)",
default=3,
@@ -26,6 +26,7 @@ from .vdb.lindorm_config import LindormConfig
from .vdb.milvus_config import MilvusConfig
from .vdb.myscale_config import MyScaleConfig
from .vdb.oceanbase_config import OceanBaseVectorConfig
from .vdb.opengauss_config import OpenGaussConfig
from .vdb.opensearch_config import OpenSearchConfig
from .vdb.oracle_config import OracleConfig
from .vdb.pgvector_config import PGVectorConfig

@@ -281,5 +282,6 @@ class MiddlewareConfig(
LindormConfig,
OceanBaseVectorConfig,
BaiduVectorDBConfig,
OpenGaussConfig,
):
pass
api/configs/middleware/vdb/opengauss_config.py (new file, 45 lines)
@@ -0,0 +1,45 @@
from typing import Optional

from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings


class OpenGaussConfig(BaseSettings):
    """
    Configuration settings for OpenGauss
    """

    OPENGAUSS_HOST: Optional[str] = Field(
        description="Hostname or IP address of the OpenGauss server(e.g., 'localhost')",
        default=None,
    )

    OPENGAUSS_PORT: PositiveInt = Field(
        description="Port number on which the OpenGauss server is listening (default is 6600)",
        default=6600,
    )

    OPENGAUSS_USER: Optional[str] = Field(
        description="Username for authenticating with the OpenGauss database",
        default=None,
    )

    OPENGAUSS_PASSWORD: Optional[str] = Field(
        description="Password for authenticating with the OpenGauss database",
        default=None,
    )

    OPENGAUSS_DATABASE: Optional[str] = Field(
        description="Name of the OpenGauss database to connect to",
        default=None,
    )

    OPENGAUSS_MIN_CONNECTION: PositiveInt = Field(
        description="Min connection of the OpenGauss database",
        default=1,
    )

    OPENGAUSS_MAX_CONNECTION: PositiveInt = Field(
        description="Max connection of the OpenGauss database",
        default=5,
    )
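The new `OpenGaussConfig` class is a pydantic-settings model, so its fields resolve from the `OPENGAUSS_*` variables added to the environment file above once it is composed into `MiddlewareConfig`. As a rough, hedged illustration of how such settings are typically consumed, the min/max connection fields map naturally onto a psycopg2 connection pool; the pool helper below is an assumption for illustration only and is not part of this diff:

```python
# Hypothetical consumer of the OpenGauss settings; not taken from the diff.
from psycopg2 import pool

from configs import dify_config  # project settings object, as imported elsewhere in this diff


def create_opengauss_pool() -> pool.ThreadedConnectionPool:
    # ThreadedConnectionPool(minconn, maxconn, **connection_kwargs)
    return pool.ThreadedConnectionPool(
        dify_config.OPENGAUSS_MIN_CONNECTION,
        dify_config.OPENGAUSS_MAX_CONNECTION,
        host=dify_config.OPENGAUSS_HOST,
        port=dify_config.OPENGAUSS_PORT,
        user=dify_config.OPENGAUSS_USER,
        password=dify_config.OPENGAUSS_PASSWORD,
        dbname=dify_config.OPENGAUSS_DATABASE,
    )
```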
@@ -43,3 +43,8 @@ class PGVectorConfig(BaseSettings):
description="Max connection of the PostgreSQL database",
default=5,
)

PGVECTOR_PG_BIGM: bool = Field(
description="Whether to use pg_bigm module for full text search",
default=False,
)
@@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):

CURRENT_VERSION: str = Field(
description="Dify version",
default="1.0.0",
default="1.1.0",
)

COMMIT_SHA: str = Field(
@@ -81,6 +81,7 @@ from .datasets import (
datasets_segments,
external,
hit_testing,
metadata,
website,
)
@@ -316,7 +316,7 @@ class AppTraceApi(Resource):
@account_initialization_required
def post(self, app_id):
# add app trace
if not current_user.is_admin_or_owner:
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("enabled", type=bool, required=True, location="json")
@@ -1,8 +1,10 @@
import json
import logging
from typing import cast

from flask import abort, request
from flask_restful import Resource, inputs, marshal_with, reqparse # type: ignore
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound

import services

@@ -13,6 +15,7 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from factories import variable_factory
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
from fields.workflow_run_fields import workflow_run_node_execution_fields

@@ -24,7 +27,7 @@ from models.account import Account
from models.model import AppMode
from services.app_generate_service import AppGenerateService
from services.errors.app import WorkflowHashNotEqualError
from services.workflow_service import WorkflowService
from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService

logger = logging.getLogger(__name__)

@@ -439,10 +442,38 @@ class PublishedWorkflowApi(Resource):
if not isinstance(current_user, Account):
raise Forbidden()

workflow_service = WorkflowService()
workflow = workflow_service.publish_workflow(app_model=app_model, account=current_user)
parser = reqparse.RequestParser()
parser.add_argument("marked_name", type=str, required=False, default="", location="json")
parser.add_argument("marked_comment", type=str, required=False, default="", location="json")
args = parser.parse_args()

return {"result": "success", "created_at": TimestampField().format(workflow.created_at)}
# Validate name and comment length
if args.marked_name and len(args.marked_name) > 20:
raise ValueError("Marked name cannot exceed 20 characters")
if args.marked_comment and len(args.marked_comment) > 100:
raise ValueError("Marked comment cannot exceed 100 characters")

workflow_service = WorkflowService()
with Session(db.engine) as session:
workflow = workflow_service.publish_workflow(
session=session,
app_model=app_model,
account=current_user,
marked_name=args.marked_name or "",
marked_comment=args.marked_comment or "",
)

app_model.workflow_id = workflow.id
db.session.commit()

workflow_created_at = TimestampField().format(workflow.created_at)

session.commit()

return {
"result": "success",
"created_at": workflow_created_at,
}


class DefaultBlockConfigsApi(Resource):

@@ -564,37 +595,193 @@ class PublishedAllWorkflowApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
parser.add_argument("user_id", type=str, required=False, location="args")
parser.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
args = parser.parse_args()
page = args.get("page")
limit = args.get("limit")
page = int(args.get("page", 1))
limit = int(args.get("limit", 10))
user_id = args.get("user_id")
named_only = args.get("named_only", False)

if user_id:
if user_id != current_user.id:
raise Forbidden()
user_id = cast(str, user_id)

workflow_service = WorkflowService()
workflows, has_more = workflow_service.get_all_published_workflow(app_model=app_model, page=page, limit=limit)
with Session(db.engine) as session:
workflows, has_more = workflow_service.get_all_published_workflow(
session=session,
app_model=app_model,
page=page,
limit=limit,
user_id=user_id,
named_only=named_only,
)

return {"items": workflows, "page": page, "limit": limit, "has_more": has_more}
return {
"items": workflows,
"page": page,
"limit": limit,
"has_more": has_more,
}


api.add_resource(DraftWorkflowApi, "/apps/<uuid:app_id>/workflows/draft")
api.add_resource(WorkflowConfigApi, "/apps/<uuid:app_id>/workflows/draft/config")
api.add_resource(AdvancedChatDraftWorkflowRunApi, "/apps/<uuid:app_id>/advanced-chat/workflows/draft/run")
api.add_resource(DraftWorkflowRunApi, "/apps/<uuid:app_id>/workflows/draft/run")
api.add_resource(WorkflowTaskStopApi, "/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop")
api.add_resource(DraftWorkflowNodeRunApi, "/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run")
class WorkflowByIdApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_fields)
def patch(self, app_model: App, workflow_id: str):
"""
Update workflow attributes
"""
# Check permission
if not current_user.is_editor:
raise Forbidden()

if not isinstance(current_user, Account):
raise Forbidden()

parser = reqparse.RequestParser()
parser.add_argument("marked_name", type=str, required=False, location="json")
parser.add_argument("marked_comment", type=str, required=False, location="json")
args = parser.parse_args()

# Validate name and comment length
if args.marked_name and len(args.marked_name) > 20:
raise ValueError("Marked name cannot exceed 20 characters")
if args.marked_comment and len(args.marked_comment) > 100:
raise ValueError("Marked comment cannot exceed 100 characters")
args = parser.parse_args()

# Prepare update data
update_data = {}
if args.get("marked_name") is not None:
update_data["marked_name"] = args["marked_name"]
if args.get("marked_comment") is not None:
update_data["marked_comment"] = args["marked_comment"]

if not update_data:
return {"message": "No valid fields to update"}, 400

workflow_service = WorkflowService()

# Create a session and manage the transaction
with Session(db.engine, expire_on_commit=False) as session:
workflow = workflow_service.update_workflow(
session=session,
workflow_id=workflow_id,
tenant_id=app_model.tenant_id,
account_id=current_user.id,
data=update_data,
)

if not workflow:
raise NotFound("Workflow not found")

# Commit the transaction in the controller
session.commit()

return workflow

@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def delete(self, app_model: App, workflow_id: str):
"""
Delete workflow
"""
# Check permission
if not current_user.is_editor:
raise Forbidden()

if not isinstance(current_user, Account):
raise Forbidden()

workflow_service = WorkflowService()

# Create a session and manage the transaction
with Session(db.engine) as session:
try:
workflow_service.delete_workflow(
session=session, workflow_id=workflow_id, tenant_id=app_model.tenant_id
)
# Commit the transaction in the controller
session.commit()
except WorkflowInUseError as e:
abort(400, description=str(e))
except DraftWorkflowDeletionError as e:
abort(400, description=str(e))
except ValueError as e:
raise NotFound(str(e))

return None, 204


api.add_resource(
DraftWorkflowApi,
"/apps/<uuid:app_id>/workflows/draft",
)
api.add_resource(
WorkflowConfigApi,
"/apps/<uuid:app_id>/workflows/draft/config",
)
api.add_resource(
AdvancedChatDraftWorkflowRunApi,
"/apps/<uuid:app_id>/advanced-chat/workflows/draft/run",
)
api.add_resource(
DraftWorkflowRunApi,
"/apps/<uuid:app_id>/workflows/draft/run",
)
api.add_resource(
WorkflowTaskStopApi,
"/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop",
)
api.add_resource(
DraftWorkflowNodeRunApi,
"/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run",
)
api.add_resource(
AdvancedChatDraftRunIterationNodeApi,
"/apps/<uuid:app_id>/advanced-chat/workflows/draft/iteration/nodes/<string:node_id>/run",
)
api.add_resource(
WorkflowDraftRunIterationNodeApi, "/apps/<uuid:app_id>/workflows/draft/iteration/nodes/<string:node_id>/run"
WorkflowDraftRunIterationNodeApi,
"/apps/<uuid:app_id>/workflows/draft/iteration/nodes/<string:node_id>/run",
)
api.add_resource(
AdvancedChatDraftRunLoopNodeApi,
"/apps/<uuid:app_id>/advanced-chat/workflows/draft/loop/nodes/<string:node_id>/run",
)
api.add_resource(WorkflowDraftRunLoopNodeApi, "/apps/<uuid:app_id>/workflows/draft/loop/nodes/<string:node_id>/run")
api.add_resource(PublishedWorkflowApi, "/apps/<uuid:app_id>/workflows/publish")
api.add_resource(PublishedAllWorkflowApi, "/apps/<uuid:app_id>/workflows")
api.add_resource(DefaultBlockConfigsApi, "/apps/<uuid:app_id>/workflows/default-workflow-block-configs")
api.add_resource(
DefaultBlockConfigApi, "/apps/<uuid:app_id>/workflows/default-workflow-block-configs/<string:block_type>"
WorkflowDraftRunLoopNodeApi,
"/apps/<uuid:app_id>/workflows/draft/loop/nodes/<string:node_id>/run",
)
api.add_resource(
PublishedWorkflowApi,
"/apps/<uuid:app_id>/workflows/publish",
)
api.add_resource(
PublishedAllWorkflowApi,
"/apps/<uuid:app_id>/workflows",
)
api.add_resource(
DefaultBlockConfigsApi,
"/apps/<uuid:app_id>/workflows/default-workflow-block-configs",
)
api.add_resource(
DefaultBlockConfigApi,
"/apps/<uuid:app_id>/workflows/default-workflow-block-configs/<string:block_type>",
)
api.add_resource(
ConvertToWorkflowApi,
"/apps/<uuid:app_id>/convert-to-workflow",
)
api.add_resource(
WorkflowByIdApi,
"/apps/<uuid:app_id>/workflows/<string:workflow_id>",
)
api.add_resource(ConvertToWorkflowApi, "/apps/<uuid:app_id>/convert-to-workflow")
@@ -1,13 +1,18 @@
from datetime import datetime

from flask_restful import Resource, marshal_with, reqparse # type: ignore
from flask_restful.inputs import int_range # type: ignore
from sqlalchemy.orm import Session

from controllers.console import api
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_database import db
from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
from libs.login import login_required
from models import App
from models.model import AppMode
from models.workflow import WorkflowRunStatus
from services.workflow_app_service import WorkflowAppService


@@ -24,17 +29,38 @@ class WorkflowAppLogApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument("keyword", type=str, location="args")
parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args")
parser.add_argument(
"created_at__before", type=str, location="args", help="Filter logs created before this timestamp"
)
parser.add_argument(
"created_at__after", type=str, location="args", help="Filter logs created after this timestamp"
)
parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
args = parser.parse_args()

args.status = WorkflowRunStatus(args.status) if args.status else None
if args.created_at__before:
args.created_at__before = datetime.fromisoformat(args.created_at__before.replace("Z", "+00:00"))

if args.created_at__after:
args.created_at__after = datetime.fromisoformat(args.created_at__after.replace("Z", "+00:00"))

# get paginate workflow app logs
workflow_app_service = WorkflowAppService()
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs(
app_model=app_model, args=args
)
with Session(db.engine) as session:
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs(
session=session,
app_model=app_model,
keyword=args.keyword,
status=args.status,
created_at_before=args.created_at__before,
created_at_after=args.created_at__after,
page=args.page,
limit=args.limit,
)

return workflow_app_log_pagination
return workflow_app_log_pagination


api.add_resource(WorkflowAppLogApi, "/apps/<uuid:app_id>/workflow-app-logs")
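The new `created_at__before` / `created_at__after` filters accept ISO-8601 timestamps; the `replace("Z", "+00:00")` step exists because `datetime.fromisoformat` rejects a trailing `Z` on Python versions before 3.11. A small standalone illustration of the same parsing step:

```python
from datetime import datetime

# A UTC "Z" suffix is rewritten to "+00:00" so fromisoformat can parse it
# on older Python versions as well.
raw = "2025-02-28T12:30:00Z"  # example value, not taken from the diff
parsed = datetime.fromisoformat(raw.replace("Z", "+00:00"))
print(parsed.isoformat())  # 2025-02-28T12:30:00+00:00
```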
@@ -122,7 +122,7 @@ class DataSourceNotionListApi(Resource):
if dataset.data_source_type != "notion_import":
raise ValueError("Dataset is not notion type.")

documents = session.execute(
documents = session.scalars(
select(Document).filter_by(
dataset_id=dataset_id,
tenant_id=current_user.current_tenant_id,
@@ -10,7 +10,12 @@ from controllers.console import api
from controllers.console.apikey import api_key_fields, api_key_list
from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_rate_limit_check,
enterprise_license_required,
setup_required,
)
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.indexing_runner import IndexingRunner
from core.model_runtime.entities.model_entities import ModelType

@@ -96,6 +101,7 @@ class DatasetListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self):
parser = reqparse.RequestParser()
parser.add_argument(

@@ -178,6 +184,10 @@ class DatasetApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
data = marshal(dataset, dataset_detail_fields)
if dataset.indexing_technique == "high_quality":
if dataset.embedding_model_provider:
provider_id = ModelProviderID(dataset.embedding_model_provider)
data["embedding_model_provider"] = str(provider_id)
if data.get("permission") == "partial_members":
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
data.update({"partial_member_list": part_users_list})

@@ -210,6 +220,7 @@ class DatasetApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)

@@ -276,7 +287,11 @@ class DatasetApi(Resource):
data = request.get_json()

# check embedding model setting
if data.get("indexing_technique") == "high_quality":
if (
data.get("indexing_technique") == "high_quality"
and data.get("embedding_model_provider") is not None
and data.get("embedding_model") is not None
):
DatasetService.check_embedding_model_setting(
dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model")
)

@@ -313,6 +328,7 @@ class DatasetApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id):
dataset_id_str = str(dataset_id)

@@ -647,6 +663,7 @@ class DatasetRetrievalSettingApi(Resource):
| VectorType.LINDORM
| VectorType.COUCHBASE
| VectorType.MILVUS
| VectorType.OPENGAUSS
):
return {
"retrieval_method": [

@@ -690,6 +707,7 @@ class DatasetRetrievalSettingMockApi(Resource):
| VectorType.COUCHBASE
| VectorType.PGVECTOR
| VectorType.LINDORM
| VectorType.OPENGAUSS
):
return {
"retrieval_method": [
@@ -26,6 +26,7 @@ from controllers.console.datasets.error import (
)
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_rate_limit_check,
cloud_edition_billing_resource_check,
setup_required,
)

@@ -242,6 +243,7 @@ class DatasetDocumentListApi(Resource):
@account_initialization_required
@marshal_with(documents_and_batch_fields)
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id):
dataset_id = str(dataset_id)

@@ -297,6 +299,7 @@ class DatasetDocumentListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id):
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)

@@ -320,9 +323,10 @@ class DatasetInitApi(Resource):
@account_initialization_required
@marshal_with(dataset_and_document_fields)
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
if not current_user.is_dataset_editor:
raise Forbidden()

parser = reqparse.RequestParser()

@@ -617,7 +621,7 @@ class DocumentDetailApi(DocumentResource):
raise InvalidMetadataError(f"Invalid metadata value: {metadata}")

if metadata == "only":
response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata}
response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details}
elif metadata == "without":
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
document_process_rules = document.dataset_process_rule.to_dict()

@@ -678,7 +682,7 @@ class DocumentDetailApi(DocumentResource):
"disabled_by": document.disabled_by,
"archived": document.archived,
"doc_type": document.doc_type,
"doc_metadata": document.doc_metadata,
"doc_metadata": document.doc_metadata_details,
"segment_count": document.segment_count,
"average_segment_length": document.average_segment_length,
"hit_count": document.hit_count,

@@ -694,13 +698,14 @@ class DocumentProcessingApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, action):
dataset_id = str(dataset_id)
document_id = str(document_id)
document = self.get_document(dataset_id, document_id)

# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
if not current_user.is_dataset_editor:
raise Forbidden()

if action == "pause":

@@ -730,6 +735,7 @@ class DocumentDeleteApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id, document_id):
dataset_id = str(dataset_id)
document_id = str(document_id)

@@ -763,8 +769,8 @@ class DocumentMetadataApi(DocumentResource):
doc_type = req_data.get("doc_type")
doc_metadata = req_data.get("doc_metadata")

# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
if not current_user.is_dataset_editor:
raise Forbidden()

if doc_type is None or doc_metadata is None:

@@ -798,6 +804,7 @@ class DocumentStatusApi(DocumentResource):
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, action):
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)

@@ -893,6 +900,7 @@ class DocumentPauseApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id):
"""pause document."""
dataset_id = str(dataset_id)

@@ -925,6 +933,7 @@ class DocumentRecoverApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id):
"""recover document."""
dataset_id = str(dataset_id)

@@ -954,6 +963,7 @@ class DocumentRetryApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id):
"""retry document."""
@@ -19,6 +19,7 @@ from controllers.console.datasets.error import (
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_knowledge_limit_check,
cloud_edition_billing_rate_limit_check,
cloud_edition_billing_resource_check,
setup_required,
)

@@ -106,6 +107,7 @@ class DatasetDocumentSegmentListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id, document_id):
# check dataset
dataset_id = str(dataset_id)

@@ -121,8 +123,8 @@ class DatasetDocumentSegmentListApi(Resource):
raise NotFound("Document not found.")
segment_ids = request.args.getlist("segment_id")

# The role of the current user in the ta table must be admin or owner
if not current_user.is_editor:
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
if not current_user.is_dataset_editor:
raise Forbidden()
try:
DatasetService.check_dataset_permission(dataset, current_user)

@@ -137,6 +139,7 @@ class DatasetDocumentSegmentApi(Resource):
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, action):
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)

@@ -148,8 +151,8 @@ class DatasetDocumentSegmentApi(Resource):
raise NotFound("Document not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
if not current_user.is_dataset_editor:
raise Forbidden()

try:

@@ -191,6 +194,7 @@ class DatasetDocumentSegmentAddApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id, document_id):
# check dataset
dataset_id = str(dataset_id)

@@ -202,7 +206,7 @@ class DatasetDocumentSegmentAddApi(Resource):
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
if not current_user.is_editor:
if not current_user.is_dataset_editor:
raise Forbidden()
# check embedding model setting
if dataset.indexing_technique == "high_quality":

@@ -240,6 +244,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, segment_id):
# check dataset
dataset_id = str(dataset_id)

@@ -276,8 +281,8 @@ class DatasetDocumentSegmentUpdateApi(Resource):
).first()
if not segment:
raise NotFound("Segment not found.")
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
if not current_user.is_dataset_editor:
raise Forbidden()
try:
DatasetService.check_dataset_permission(dataset, current_user)

@@ -299,6 +304,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id, document_id, segment_id):
# check dataset
dataset_id = str(dataset_id)

@@ -319,8 +325,8 @@ class DatasetDocumentSegmentUpdateApi(Resource):
).first()
if not segment:
raise NotFound("Segment not found.")
# The role of the current user in the ta table must be admin or owner
if not current_user.is_editor:
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
if not current_user.is_dataset_editor:
raise Forbidden()
try:
DatasetService.check_dataset_permission(dataset, current_user)

@@ -336,6 +342,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id, document_id):
# check dataset
dataset_id = str(dataset_id)

@@ -402,6 +409,7 @@ class ChildChunkAddApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id, document_id, segment_id):
# check dataset
dataset_id = str(dataset_id)

@@ -420,7 +428,7 @@ class ChildChunkAddApi(Resource):
).first()
if not segment:
raise NotFound("Segment not found.")
if not current_user.is_editor:
if not current_user.is_dataset_editor:
raise Forbidden()
# check embedding model setting
if dataset.indexing_technique == "high_quality":

@@ -499,6 +507,7 @@ class ChildChunkAddApi(Resource):
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, segment_id):
# check dataset
dataset_id = str(dataset_id)

@@ -519,8 +528,8 @@ class ChildChunkAddApi(Resource):
).first()
if not segment:
raise NotFound("Segment not found.")
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
if not current_user.is_dataset_editor:
raise Forbidden()
try:
DatasetService.check_dataset_permission(dataset, current_user)

@@ -542,6 +551,7 @@ class ChildChunkUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id, document_id, segment_id, child_chunk_id):
# check dataset
dataset_id = str(dataset_id)

@@ -569,8 +579,8 @@ class ChildChunkUpdateApi(Resource):
).first()
if not child_chunk:
raise NotFound("Child chunk not found.")
# The role of the current user in the ta table must be admin or owner
if not current_user.is_editor:
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
if not current_user.is_dataset_editor:
raise Forbidden()
try:
DatasetService.check_dataset_permission(dataset, current_user)

@@ -586,6 +596,7 @@ class ChildChunkUpdateApi(Resource):
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, segment_id, child_chunk_id):
# check dataset
dataset_id = str(dataset_id)

@@ -613,8 +624,8 @@ class ChildChunkUpdateApi(Resource):
).first()
if not child_chunk:
raise NotFound("Child chunk not found.")
# The role of the current user in the ta table must be admin or owner
if not current_user.is_editor:
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
if not current_user.is_dataset_editor:
raise Forbidden()
try:
DatasetService.check_dataset_permission(dataset, current_user)
@@ -2,7 +2,11 @@ from flask_restful import Resource # type: ignore

from controllers.console import api
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_rate_limit_check,
setup_required,
)
from libs.login import login_required


@@ -10,6 +14,7 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id):
dataset_id_str = str(dataset_id)
api/controllers/console/datasets/metadata.py (new file, 155 lines)
@@ -0,0 +1,155 @@
from flask_login import current_user # type: ignore
from flask_restful import Resource, marshal_with, reqparse # type: ignore
from werkzeug.exceptions import NotFound

from controllers.console import api
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from fields.dataset_fields import dataset_metadata_fields
from libs.login import login_required
from services.dataset_service import DatasetService
from services.entities.knowledge_entities.knowledge_entities import (
    MetadataArgs,
    MetadataOperationData,
)
from services.metadata_service import MetadataService


def _validate_name(name):
    if not name or len(name) < 1 or len(name) > 40:
        raise ValueError("Name must be between 1 to 40 characters.")
    return name


def _validate_description_length(description):
    if len(description) > 400:
        raise ValueError("Description cannot exceed 400 characters.")
    return description


class DatasetMetadataCreateApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @enterprise_license_required
    @marshal_with(dataset_metadata_fields)
    def post(self, dataset_id):
        parser = reqparse.RequestParser()
        parser.add_argument("type", type=str, required=True, nullable=True, location="json")
        parser.add_argument("name", type=str, required=True, nullable=True, location="json")
        args = parser.parse_args()
        metadata_args = MetadataArgs(**args)

        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")
        DatasetService.check_dataset_permission(dataset, current_user)

        metadata = MetadataService.create_metadata(dataset_id_str, metadata_args)
        return metadata, 201

    @setup_required
    @login_required
    @account_initialization_required
    @enterprise_license_required
    def get(self, dataset_id):
        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")
        return MetadataService.get_dataset_metadatas(dataset), 200


class DatasetMetadataApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @enterprise_license_required
    @marshal_with(dataset_metadata_fields)
    def patch(self, dataset_id, metadata_id):
        parser = reqparse.RequestParser()
        parser.add_argument("name", type=str, required=True, nullable=True, location="json")
        args = parser.parse_args()

        dataset_id_str = str(dataset_id)
        metadata_id_str = str(metadata_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")
        DatasetService.check_dataset_permission(dataset, current_user)

        metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args.get("name"))
        return metadata, 200

    @setup_required
    @login_required
    @account_initialization_required
    @enterprise_license_required
    def delete(self, dataset_id, metadata_id):
        dataset_id_str = str(dataset_id)
        metadata_id_str = str(metadata_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")
        DatasetService.check_dataset_permission(dataset, current_user)

        MetadataService.delete_metadata(dataset_id_str, metadata_id_str)
        return 200


class DatasetMetadataBuiltInFieldApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @enterprise_license_required
    def get(self):
        built_in_fields = MetadataService.get_built_in_fields()
        return {"fields": built_in_fields}, 200


class DatasetMetadataBuiltInFieldActionApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @enterprise_license_required
    def post(self, dataset_id, action):
        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")
        DatasetService.check_dataset_permission(dataset, current_user)

        if action == "enable":
            MetadataService.enable_built_in_field(dataset)
        elif action == "disable":
            MetadataService.disable_built_in_field(dataset)
        return 200


class DocumentMetadataEditApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @enterprise_license_required
    def post(self, dataset_id):
        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")
        DatasetService.check_dataset_permission(dataset, current_user)

        parser = reqparse.RequestParser()
        parser.add_argument("operation_data", type=list, required=True, nullable=True, location="json")
        args = parser.parse_args()
        metadata_args = MetadataOperationData(**args)

        MetadataService.update_documents_metadata(dataset, metadata_args)

        return 200


api.add_resource(DatasetMetadataCreateApi, "/datasets/<uuid:dataset_id>/metadata")
api.add_resource(DatasetMetadataApi, "/datasets/<uuid:dataset_id>/metadata/<uuid:metadata_id>")
api.add_resource(DatasetMetadataBuiltInFieldApi, "/datasets/metadata/built-in")
api.add_resource(DatasetMetadataBuiltInFieldActionApi, "/datasets/<uuid:dataset_id>/metadata/built-in/<string:action>")
api.add_resource(DocumentMetadataEditApi, "/datasets/<uuid:dataset_id>/documents/metadata")
@@ -26,6 +26,7 @@ from libs.helper import TimestampField
from libs.login import login_required
from models.account import Tenant, TenantStatus
from services.account_service import TenantService
from services.feature_service import FeatureService
from services.file_service import FileService
from services.workspace_service import WorkspaceService

@@ -68,6 +69,11 @@ class TenantListApi(Resource):
tenants = TenantService.get_join_tenants(current_user)

for tenant in tenants:
features = FeatureService.get_features(tenant.id)
if features.billing.enabled:
tenant.plan = features.billing.subscription.plan
else:
tenant.plan = "sandbox"
if tenant.id == current_user.current_tenant_id:
tenant.current = True # Set current=True for current tenant
return {"workspaces": marshal(tenants, tenants_fields)}, 200

@@ -82,28 +88,20 @@ class WorkspaceListApi(Resource):
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
args = parser.parse_args()

tenants = Tenant.query.order_by(Tenant.created_at.desc()).paginate(page=args["page"], per_page=args["limit"])

tenants = Tenant.query.order_by(Tenant.created_at.desc()).paginate(
page=args["page"], per_page=args["limit"], error_out=False
)
has_more = False
if len(tenants.items) == args["limit"]:
current_page_first_tenant = tenants[-1]
rest_count = (
db.session.query(Tenant)
.filter(
Tenant.created_at < current_page_first_tenant.created_at, Tenant.id != current_page_first_tenant.id
)
.count()
)

if rest_count > 0:
has_more = True
total = db.session.query(Tenant).count()
if tenants.has_next:
has_more = True

return {
"data": marshal(tenants.items, workspace_fields),
"has_more": has_more,
"limit": args["limit"],
"page": args["page"],
"total": total,
"total": tenants.total,
}, 200
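The `WorkspaceListApi` hunk drops the manual `has_more` count query in favour of the pagination object's own `has_next` and `total` attributes, and passes `error_out=False` so an out-of-range page yields an empty result instead of a 404. A reduced sketch of the resulting shape, assuming the same Flask-SQLAlchemy `Tenant` query shown in the diff (the surrounding view code is omitted):

```python
# Sketch only: relies on Flask-SQLAlchemy's Pagination object.
tenants = Tenant.query.order_by(Tenant.created_at.desc()).paginate(
    page=page, per_page=limit, error_out=False
)

response = {
    "data": marshal(tenants.items, workspace_fields),
    "has_more": tenants.has_next,  # True when another page exists
    "limit": limit,
    "page": page,
    "total": tenants.total,        # total row count across all pages
}
```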
@@ -1,5 +1,6 @@
import json
import os
import time
from functools import wraps

from flask import abort, request
@@ -8,6 +9,8 @@ from flask_login import current_user # type: ignore
from configs import dify_config
from controllers.console.workspace.error import AccountNotInitializedError
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import RateLimitLog
from models.model import DifySetup
from services.feature_service import FeatureService, LicenseStatus
from services.operation_service import OperationService
@@ -67,7 +70,9 @@ def cloud_edition_billing_resource_check(resource: str):
elif resource == "apps" and 0 < apps.limit <= apps.size:
abort(403, "The number of apps has reached the limit of your subscription.")
elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size:
abort(403, "The capacity of the vector space has reached the limit of your subscription.")
abort(
403, "The capacity of the knowledge storage space has reached the limit of your subscription."
)
elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size:
# The api of file upload is used in the multiple places,
# so we need to check the source of the request from datasets
@@ -112,6 +117,41 @@ def cloud_edition_billing_knowledge_limit_check(resource: str):
return interceptor

def cloud_edition_billing_rate_limit_check(resource: str):
def interceptor(view):
@wraps(view)
def decorated(*args, **kwargs):
if resource == "knowledge":
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(current_user.current_tenant_id)
if knowledge_rate_limit.enabled:
current_time = int(time.time() * 1000)
key = f"rate_limit_{current_user.current_tenant_id}"

redis_client.zadd(key, {current_time: current_time})

redis_client.zremrangebyscore(key, 0, current_time - 60000)

request_count = redis_client.zcard(key)

if request_count > knowledge_rate_limit.limit:
# add ratelimit record
rate_limit_log = RateLimitLog(
tenant_id=current_user.current_tenant_id,
subscription_plan=knowledge_rate_limit.subscription_plan,
operation="knowledge",
)
db.session.add(rate_limit_log)
db.session.commit()
abort(
403, "Sorry, you have reached the knowledge base request rate limit of your subscription."
)
return view(*args, **kwargs)

return decorated

return interceptor

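The new rate-limit decorator records each request's millisecond timestamp in a Redis sorted set, trims entries older than 60 seconds, and counts what remains. A standalone sketch of that sliding-window check, assuming a reachable local Redis; `allow_request` and the 100 requests/minute limit are illustrative, not Dify names:

```python
# Sliding-window rate limit sketch (illustrative, not the Dify decorator itself).
import time
import redis

redis_client = redis.Redis(host="localhost", port=6379, db=0)

def allow_request(tenant_id: str, limit: int = 100, window_ms: int = 60_000) -> bool:
    now_ms = int(time.time() * 1000)
    key = f"rate_limit_{tenant_id}"
    redis_client.zadd(key, {now_ms: now_ms})                   # record this request
    redis_client.zremrangebyscore(key, 0, now_ms - window_ms)  # drop entries older than the window
    return redis_client.zcard(key) <= limit                    # count requests still inside the window
```

A decorator can call such a helper before invoking the view and respond with `abort(403, ...)` when it returns False, which mirrors what `cloud_edition_billing_rate_limit_check` does above.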
def cloud_utm_record(view):
@wraps(view)
def decorated(*args, **kwargs):
@@ -10,7 +10,7 @@ from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.app.entities.app_invoke_entities import InvokeFrom
from fields.conversation_fields import message_file_fields
from fields.message_fields import feedback_fields, retriever_resource_fields
from fields.message_fields import agent_thought_fields, feedback_fields, retriever_resource_fields
from fields.raws import FilesContainedField
from libs.helper import TimestampField, uuid_value
from models.model import App, AppMode, EndUser
@@ -19,20 +19,6 @@ from services.message_service import MessageService


class MessageListApi(Resource):
agent_thought_fields = {
"id": fields.String,
"chain_id": fields.String,
"message_id": fields.String,
"position": fields.Integer,
"thought": fields.String,
"tool": fields.String,
"tool_labels": fields.Raw,
"tool_input": fields.String,
"created_at": TimestampField,
"observation": fields.String,
"message_files": fields.List(fields.Nested(message_file_fields)),
}

message_fields = {
"id": fields.String,
"conversation_id": fields.String,
@@ -1,7 +1,9 @@
import logging
from datetime import datetime

from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore
from flask_restful.inputs import int_range # type: ignore
from sqlalchemy.orm import Session
from werkzeug.exceptions import InternalServerError

from controllers.service_api import api
@@ -25,7 +27,7 @@ from extensions.ext_database import db
from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
from libs import helper
from models.model import App, AppMode, EndUser
from models.workflow import WorkflowRun
from models.workflow import WorkflowRun, WorkflowRunStatus
from services.app_generate_service import AppGenerateService
from services.workflow_app_service import WorkflowAppService

@@ -125,17 +127,34 @@ class WorkflowAppLogApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument("keyword", type=str, location="args")
parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args")
parser.add_argument("created_at__before", type=str, location="args")
parser.add_argument("created_at__after", type=str, location="args")
parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
args = parser.parse_args()

args.status = WorkflowRunStatus(args.status) if args.status else None
if args.created_at__before:
args.created_at__before = datetime.fromisoformat(args.created_at__before.replace("Z", "+00:00"))

if args.created_at__after:
args.created_at__after = datetime.fromisoformat(args.created_at__after.replace("Z", "+00:00"))

# get paginate workflow app logs
workflow_app_service = WorkflowAppService()
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs(
app_model=app_model, args=args
)
with Session(db.engine) as session:
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs(
session=session,
app_model=app_model,
keyword=args.keyword,
status=args.status,
created_at_before=args.created_at__before,
created_at_after=args.created_at__after,
page=args.page,
limit=args.limit,
)

return workflow_app_log_pagination
return workflow_app_log_pagination


api.add_resource(WorkflowRunApi, "/workflows/run")
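The new `created_at__before` / `created_at__after` filters accept ISO 8601 timestamps. Before Python 3.11, `datetime.fromisoformat` does not accept a trailing "Z", hence the `replace("Z", "+00:00")`. A small self-contained example of that handling; the timestamp value is illustrative:

```python
# Standalone sketch of the timestamp normalization used for the log filters.
from datetime import datetime

raw = "2025-03-01T12:30:00Z"                         # value as a client might send it
parsed = datetime.fromisoformat(raw.replace("Z", "+00:00"))
print(parsed.isoformat())                            # 2025-03-01T12:30:00+00:00 (timezone-aware)
```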
@ -1,3 +1,4 @@
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from enum import Enum
|
||||
@ -13,8 +14,10 @@ from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, Unauthorized
|
||||
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.login import _get_user
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantStatus
|
||||
from models.dataset import RateLimitLog
|
||||
from models.model import ApiToken, App, EndUser
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
@ -139,6 +142,43 @@ def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: s
|
||||
return interceptor
|
||||
|
||||
|
||||
def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str):
|
||||
def interceptor(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
api_token = validate_and_get_api_token(api_token_type)
|
||||
|
||||
if resource == "knowledge":
|
||||
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(api_token.tenant_id)
|
||||
if knowledge_rate_limit.enabled:
|
||||
current_time = int(time.time() * 1000)
|
||||
key = f"rate_limit_{api_token.tenant_id}"
|
||||
|
||||
redis_client.zadd(key, {current_time: current_time})
|
||||
|
||||
redis_client.zremrangebyscore(key, 0, current_time - 60000)
|
||||
|
||||
request_count = redis_client.zcard(key)
|
||||
|
||||
if request_count > knowledge_rate_limit.limit:
|
||||
# add ratelimit record
|
||||
rate_limit_log = RateLimitLog(
|
||||
tenant_id=api_token.tenant_id,
|
||||
subscription_plan=knowledge_rate_limit.subscription_plan,
|
||||
operation="knowledge",
|
||||
)
|
||||
db.session.add(rate_limit_log)
|
||||
db.session.commit()
|
||||
raise Forbidden(
|
||||
"Sorry, you have reached the knowledge base request rate limit of your subscription."
|
||||
)
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
return interceptor
|
||||
|
||||
|
||||
def validate_dataset_token(view=None):
|
||||
def decorator(view):
|
||||
@wraps(view)
|
||||
|
||||
@ -1,7 +1,12 @@
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity
|
||||
from core.app.app_config.entities import (
|
||||
DatasetEntity,
|
||||
DatasetRetrieveConfigEntity,
|
||||
MetadataFilteringCondition,
|
||||
ModelConfig,
|
||||
)
|
||||
from core.entities.agent_entities import PlanningStrategy
|
||||
from models.model import AppMode
|
||||
from services.dataset_service import DatasetService
|
||||
@ -78,6 +83,15 @@ class DatasetConfigManager:
|
||||
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
|
||||
dataset_configs["retrieval_model"]
|
||||
),
|
||||
metadata_filtering_mode=dataset_configs.get("metadata_filtering_mode", "disabled"),
|
||||
metadata_model_config=ModelConfig(**dataset_configs.get("metadata_model_config"))
|
||||
if dataset_configs.get("metadata_model_config")
|
||||
else None,
|
||||
metadata_filtering_conditions=MetadataFilteringCondition(
|
||||
**dataset_configs.get("metadata_filtering_conditions", {})
|
||||
)
|
||||
if dataset_configs.get("metadata_filtering_conditions")
|
||||
else None,
|
||||
),
|
||||
)
|
||||
else:
|
||||
@ -89,11 +103,22 @@ class DatasetConfigManager:
|
||||
dataset_configs["retrieval_model"]
|
||||
),
|
||||
top_k=dataset_configs.get("top_k", 4),
|
||||
score_threshold=dataset_configs.get("score_threshold"),
|
||||
score_threshold=dataset_configs.get("score_threshold")
|
||||
if dataset_configs.get("score_threshold_enabled", False)
|
||||
else None,
|
||||
reranking_model=dataset_configs.get("reranking_model"),
|
||||
weights=dataset_configs.get("weights"),
|
||||
reranking_enabled=dataset_configs.get("reranking_enabled", True),
|
||||
rerank_mode=dataset_configs.get("reranking_mode", "reranking_model"),
|
||||
metadata_filtering_mode=dataset_configs.get("metadata_filtering_mode", "disabled"),
|
||||
metadata_model_config=ModelConfig(**dataset_configs.get("metadata_model_config"))
|
||||
if dataset_configs.get("metadata_model_config")
|
||||
else None,
|
||||
metadata_filtering_conditions=MetadataFilteringCondition(
|
||||
**dataset_configs.get("metadata_filtering_conditions", {})
|
||||
)
|
||||
if dataset_configs.get("metadata_filtering_conditions")
|
||||
else None,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -1,10 +1,11 @@
from collections.abc import Sequence
from enum import Enum, StrEnum
from typing import Any, Optional
from typing import Any, Literal, Optional

from pydantic import BaseModel, Field, field_validator

from core.file import FileTransferMethod, FileType, FileUploadConfig
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.message_entities import PromptMessageRole
from models.model import AppMode

@@ -135,6 +136,55 @@ class ExternalDataVariableEntity(BaseModel):
config: dict[str, Any] = Field(default_factory=dict)


SupportedComparisonOperator = Literal[
# for string or array
"contains",
"not contains",
"start with",
"end with",
"is",
"is not",
"empty",
"not empty",
# for number
"=",
"≠",
">",
"<",
"≥",
"≤",
# for time
"before",
"after",
]


class ModelConfig(BaseModel):
provider: str
name: str
mode: LLMMode
completion_params: dict[str, Any] = {}


class Condition(BaseModel):
"""
Conditon detail
"""

name: str
comparison_operator: SupportedComparisonOperator
value: str | Sequence[str] | None | int | float = None


class MetadataFilteringCondition(BaseModel):
"""
Metadata Filtering Condition.
"""

logical_operator: Optional[Literal["and", "or"]] = "and"
conditions: Optional[list[Condition]] = Field(default=None, deprecated=True)


class DatasetRetrieveConfigEntity(BaseModel):
"""
Dataset Retrieve Config Entity.
@@ -171,6 +221,9 @@ class DatasetRetrieveConfigEntity(BaseModel):
reranking_model: Optional[dict] = None
weights: Optional[dict] = None
reranking_enabled: Optional[bool] = True
metadata_filtering_mode: Optional[Literal["disabled", "automatic", "manual"]] = "disabled"
metadata_model_config: Optional[ModelConfig] = None
metadata_filtering_conditions: Optional[MetadataFilteringCondition] = None


class DatasetEntity(BaseModel):
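A hedged sketch of how these new filtering entities compose. The classes are the ones defined in the hunk above (assumed importable from `core.app.app_config.entities`); the metadata names and values are made-up examples, and note the hunk already marks `conditions` as deprecated:

```python
# Illustrative construction of the metadata-filtering entities introduced above.
from core.app.app_config.entities import Condition, MetadataFilteringCondition

filtering = MetadataFilteringCondition(
    logical_operator="and",
    conditions=[
        Condition(name="category", comparison_operator="is", value="faq"),
        Condition(name="published_at", comparison_operator="after", value="2024-01-01"),
    ],
)
print(filtering.model_dump())
```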
@@ -17,17 +17,15 @@ class FileUploadConfigManager:
if file_upload_dict:
if file_upload_dict.get("enabled"):
transform_methods = file_upload_dict.get("allowed_file_upload_methods", [])
data = {
"image_config": {
"number_limits": file_upload_dict["number_limits"],
"transfer_methods": transform_methods,
}
file_upload_dict["image_config"] = {
"number_limits": file_upload_dict.get("number_limits", 1),
"transfer_methods": transform_methods,
}

if is_vision:
data["image_config"]["detail"] = file_upload_dict.get("image", {}).get("detail", "low")
file_upload_dict["image_config"]["detail"] = file_upload_dict.get("image", {}).get("detail", "high")

return FileUploadConfig.model_validate(data)
return FileUploadConfig.model_validate(file_upload_dict)

@classmethod
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
@ -151,7 +151,7 @@ class BaseAppGenerator:
|
||||
|
||||
def gen():
|
||||
for message in generator:
|
||||
if isinstance(message, (Mapping, dict)):
|
||||
if isinstance(message, Mapping | dict):
|
||||
yield f"data: {json.dumps(message)}\n\n"
|
||||
else:
|
||||
yield f"event: {message}\n\n"
|
||||
|
||||
@ -17,7 +17,11 @@ from core.external_data_tool.external_data_fetch import ExternalDataFetch
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
ImagePromptMessageContent,
|
||||
PromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey
|
||||
from core.model_runtime.errors.invoke import InvokeBadRequestError
|
||||
from core.moderation.input_moderation import InputModeration
|
||||
@ -141,6 +145,7 @@ class AppRunner:
|
||||
query: Optional[str] = None,
|
||||
context: Optional[str] = None,
|
||||
memory: Optional[TokenBufferMemory] = None,
|
||||
image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None,
|
||||
) -> tuple[list[PromptMessage], Optional[list[str]]]:
|
||||
"""
|
||||
Organize prompt messages
|
||||
@ -167,6 +172,7 @@ class AppRunner:
|
||||
context=context,
|
||||
memory=memory,
|
||||
model_config=model_config,
|
||||
image_detail_config=image_detail_config,
|
||||
)
|
||||
else:
|
||||
memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False))
|
||||
@ -201,6 +207,7 @@ class AppRunner:
|
||||
memory_config=memory_config,
|
||||
memory=memory,
|
||||
model_config=model_config,
|
||||
image_detail_config=image_detail_config,
|
||||
)
|
||||
stop = model_config.stop
|
||||
|
||||
|
||||
@ -11,6 +11,7 @@ from core.app.entities.queue_entities import QueueAnnotationReplyEvent
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||
from core.moderation.base import ModerationError
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from extensions.ext_database import db
|
||||
@ -50,6 +51,16 @@ class ChatAppRunner(AppRunner):
|
||||
query = application_generate_entity.query
|
||||
files = application_generate_entity.files
|
||||
|
||||
image_detail_config = (
|
||||
application_generate_entity.file_upload_config.image_config.detail
|
||||
if (
|
||||
application_generate_entity.file_upload_config
|
||||
and application_generate_entity.file_upload_config.image_config
|
||||
)
|
||||
else None
|
||||
)
|
||||
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
|
||||
|
||||
# Pre-calculate the number of tokens of the prompt messages,
|
||||
# and return the rest number of tokens by model context token size limit and max token size limit.
|
||||
# If the rest number of tokens is not enough, raise exception.
|
||||
@ -85,6 +96,7 @@ class ChatAppRunner(AppRunner):
|
||||
files=files,
|
||||
query=query,
|
||||
memory=memory,
|
||||
image_detail_config=image_detail_config,
|
||||
)
|
||||
|
||||
# moderation
|
||||
@ -168,6 +180,7 @@ class ChatAppRunner(AppRunner):
|
||||
hit_callback=hit_callback,
|
||||
memory=memory,
|
||||
message_id=message.id,
|
||||
inputs=inputs,
|
||||
)
|
||||
|
||||
# reorganize all inputs and template to prompt messages
|
||||
@ -182,6 +195,7 @@ class ChatAppRunner(AppRunner):
|
||||
query=query,
|
||||
context=context,
|
||||
memory=memory,
|
||||
image_detail_config=image_detail_config,
|
||||
)
|
||||
|
||||
# check hosting moderation
|
||||
|
||||
@ -9,6 +9,7 @@ from core.app.entities.app_invoke_entities import (
|
||||
)
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||
from core.moderation.base import ModerationError
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from extensions.ext_database import db
|
||||
@ -43,6 +44,16 @@ class CompletionAppRunner(AppRunner):
|
||||
query = application_generate_entity.query
|
||||
files = application_generate_entity.files
|
||||
|
||||
image_detail_config = (
|
||||
application_generate_entity.file_upload_config.image_config.detail
|
||||
if (
|
||||
application_generate_entity.file_upload_config
|
||||
and application_generate_entity.file_upload_config.image_config
|
||||
)
|
||||
else None
|
||||
)
|
||||
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
|
||||
|
||||
# Pre-calculate the number of tokens of the prompt messages,
|
||||
# and return the rest number of tokens by model context token size limit and max token size limit.
|
||||
# If the rest number of tokens is not enough, raise exception.
|
||||
@ -66,6 +77,7 @@ class CompletionAppRunner(AppRunner):
|
||||
inputs=inputs,
|
||||
files=files,
|
||||
query=query,
|
||||
image_detail_config=image_detail_config,
|
||||
)
|
||||
|
||||
# moderation
|
||||
@ -127,6 +139,7 @@ class CompletionAppRunner(AppRunner):
|
||||
show_retrieve_source=app_config.additional_features.show_retrieve_source,
|
||||
hit_callback=hit_callback,
|
||||
message_id=message.id,
|
||||
inputs=inputs,
|
||||
)
|
||||
|
||||
# reorganize all inputs and template to prompt messages
|
||||
@ -140,6 +153,7 @@ class CompletionAppRunner(AppRunner):
|
||||
files=files,
|
||||
query=query,
|
||||
context=context,
|
||||
image_detail_config=image_detail_config,
|
||||
)
|
||||
|
||||
# check hosting moderation
|
||||
|
||||
@ -7,7 +7,6 @@ from json import JSONDecodeError
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from sqlalchemy import or_
|
||||
|
||||
from constants import HIDDEN_VALUE
|
||||
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
|
||||
@ -180,37 +179,35 @@ class ProviderConfiguration(BaseModel):
|
||||
else [],
|
||||
)
|
||||
|
||||
def _get_custom_provider_credentials(self) -> Provider | None:
|
||||
"""
|
||||
Get custom provider credentials.
|
||||
"""
|
||||
# get provider
|
||||
model_provider_id = ModelProviderID(self.provider.provider)
|
||||
provider_names = [self.provider.provider]
|
||||
if model_provider_id.is_langgenius():
|
||||
provider_names.append(model_provider_id.provider_name)
|
||||
|
||||
provider_record = (
|
||||
db.session.query(Provider)
|
||||
.filter(
|
||||
Provider.tenant_id == self.tenant_id,
|
||||
Provider.provider_type == ProviderType.CUSTOM.value,
|
||||
Provider.provider_name.in_(provider_names),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
return provider_record
|
||||
|
||||
def custom_credentials_validate(self, credentials: dict) -> tuple[Provider | None, dict]:
|
||||
"""
|
||||
Validate custom credentials.
|
||||
:param credentials: provider credentials
|
||||
:return:
|
||||
"""
|
||||
# get provider
|
||||
model_provider_id = ModelProviderID(self.provider.provider)
|
||||
if model_provider_id.is_langgenius():
|
||||
provider_record = (
|
||||
db.session.query(Provider)
|
||||
.filter(
|
||||
Provider.tenant_id == self.tenant_id,
|
||||
Provider.provider_type == ProviderType.CUSTOM.value,
|
||||
or_(
|
||||
Provider.provider_name == model_provider_id.provider_name,
|
||||
Provider.provider_name == self.provider.provider,
|
||||
),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
else:
|
||||
provider_record = (
|
||||
db.session.query(Provider)
|
||||
.filter(
|
||||
Provider.tenant_id == self.tenant_id,
|
||||
Provider.provider_type == ProviderType.CUSTOM.value,
|
||||
Provider.provider_name == self.provider.provider,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
provider_record = self._get_custom_provider_credentials()
|
||||
|
||||
# Get provider credential secret variables
|
||||
provider_credential_secret_variables = self.extract_secret_variables(
|
||||
@ -291,18 +288,7 @@ class ProviderConfiguration(BaseModel):
|
||||
:return:
|
||||
"""
|
||||
# get provider
|
||||
provider_record = (
|
||||
db.session.query(Provider)
|
||||
.filter(
|
||||
Provider.tenant_id == self.tenant_id,
|
||||
or_(
|
||||
Provider.provider_name == ModelProviderID(self.provider.provider).plugin_name,
|
||||
Provider.provider_name == self.provider.provider,
|
||||
),
|
||||
Provider.provider_type == ProviderType.CUSTOM.value,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
provider_record = self._get_custom_provider_credentials()
|
||||
|
||||
# delete provider
|
||||
if provider_record:
|
||||
@ -349,6 +335,33 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
return None
|
||||
|
||||
def _get_custom_model_credentials(
|
||||
self,
|
||||
model_type: ModelType,
|
||||
model: str,
|
||||
) -> ProviderModel | None:
|
||||
"""
|
||||
Get custom model credentials.
|
||||
"""
|
||||
# get provider model
|
||||
model_provider_id = ModelProviderID(self.provider.provider)
|
||||
provider_names = [self.provider.provider]
|
||||
if model_provider_id.is_langgenius():
|
||||
provider_names.append(model_provider_id.provider_name)
|
||||
|
||||
provider_model_record = (
|
||||
db.session.query(ProviderModel)
|
||||
.filter(
|
||||
ProviderModel.tenant_id == self.tenant_id,
|
||||
ProviderModel.provider_name.in_(provider_names),
|
||||
ProviderModel.model_name == model,
|
||||
ProviderModel.model_type == model_type.to_origin_model_type(),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
return provider_model_record
|
||||
|
||||
def custom_model_credentials_validate(
|
||||
self, model_type: ModelType, model: str, credentials: dict
|
||||
) -> tuple[ProviderModel | None, dict]:
|
||||
@ -361,16 +374,7 @@ class ProviderConfiguration(BaseModel):
|
||||
:return:
|
||||
"""
|
||||
# get provider model
|
||||
provider_model_record = (
|
||||
db.session.query(ProviderModel)
|
||||
.filter(
|
||||
ProviderModel.tenant_id == self.tenant_id,
|
||||
ProviderModel.provider_name == self.provider.provider,
|
||||
ProviderModel.model_name == model,
|
||||
ProviderModel.model_type == model_type.to_origin_model_type(),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
provider_model_record = self._get_custom_model_credentials(model_type, model)
|
||||
|
||||
# Get provider credential secret variables
|
||||
provider_credential_secret_variables = self.extract_secret_variables(
|
||||
@ -451,16 +455,7 @@ class ProviderConfiguration(BaseModel):
|
||||
:return:
|
||||
"""
|
||||
# get provider model
|
||||
provider_model_record = (
|
||||
db.session.query(ProviderModel)
|
||||
.filter(
|
||||
ProviderModel.tenant_id == self.tenant_id,
|
||||
ProviderModel.provider_name == self.provider.provider,
|
||||
ProviderModel.model_name == model,
|
||||
ProviderModel.model_type == model_type.to_origin_model_type(),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
provider_model_record = self._get_custom_model_credentials(model_type, model)
|
||||
|
||||
# delete provider model
|
||||
if provider_model_record:
|
||||
@ -475,6 +470,26 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
provider_model_credentials_cache.delete()
|
||||
|
||||
def _get_provider_model_setting(self, model_type: ModelType, model: str) -> ProviderModelSetting | None:
|
||||
"""
|
||||
Get provider model setting.
|
||||
"""
|
||||
model_provider_id = ModelProviderID(self.provider.provider)
|
||||
provider_names = [self.provider.provider]
|
||||
if model_provider_id.is_langgenius():
|
||||
provider_names.append(model_provider_id.provider_name)
|
||||
|
||||
return (
|
||||
db.session.query(ProviderModelSetting)
|
||||
.filter(
|
||||
ProviderModelSetting.tenant_id == self.tenant_id,
|
||||
ProviderModelSetting.provider_name.in_(provider_names),
|
||||
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelSetting.model_name == model,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
|
||||
"""
|
||||
Enable model.
|
||||
@ -482,16 +497,7 @@ class ProviderConfiguration(BaseModel):
|
||||
:param model: model name
|
||||
:return:
|
||||
"""
|
||||
model_setting = (
|
||||
db.session.query(ProviderModelSetting)
|
||||
.filter(
|
||||
ProviderModelSetting.tenant_id == self.tenant_id,
|
||||
ProviderModelSetting.provider_name == self.provider.provider,
|
||||
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelSetting.model_name == model,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
model_setting = self._get_provider_model_setting(model_type, model)
|
||||
|
||||
if model_setting:
|
||||
model_setting.enabled = True
|
||||
@ -516,16 +522,7 @@ class ProviderConfiguration(BaseModel):
|
||||
:param model: model name
|
||||
:return:
|
||||
"""
|
||||
model_setting = (
|
||||
db.session.query(ProviderModelSetting)
|
||||
.filter(
|
||||
ProviderModelSetting.tenant_id == self.tenant_id,
|
||||
ProviderModelSetting.provider_name == self.provider.provider,
|
||||
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelSetting.model_name == model,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
model_setting = self._get_provider_model_setting(model_type, model)
|
||||
|
||||
if model_setting:
|
||||
model_setting.enabled = False
|
||||
@ -550,13 +547,24 @@ class ProviderConfiguration(BaseModel):
|
||||
:param model: model name
|
||||
:return:
|
||||
"""
|
||||
return self._get_provider_model_setting(model_type, model)
|
||||
|
||||
def _get_load_balancing_config(self, model_type: ModelType, model: str) -> Optional[LoadBalancingModelConfig]:
|
||||
"""
|
||||
Get load balancing config.
|
||||
"""
|
||||
model_provider_id = ModelProviderID(self.provider.provider)
|
||||
provider_names = [self.provider.provider]
|
||||
if model_provider_id.is_langgenius():
|
||||
provider_names.append(model_provider_id.provider_name)
|
||||
|
||||
return (
|
||||
db.session.query(ProviderModelSetting)
|
||||
db.session.query(LoadBalancingModelConfig)
|
||||
.filter(
|
||||
ProviderModelSetting.tenant_id == self.tenant_id,
|
||||
ProviderModelSetting.provider_name == self.provider.provider,
|
||||
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelSetting.model_name == model,
|
||||
LoadBalancingModelConfig.tenant_id == self.tenant_id,
|
||||
LoadBalancingModelConfig.provider_name.in_(provider_names),
|
||||
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
|
||||
LoadBalancingModelConfig.model_name == model,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
@ -568,11 +576,16 @@ class ProviderConfiguration(BaseModel):
|
||||
:param model: model name
|
||||
:return:
|
||||
"""
|
||||
model_provider_id = ModelProviderID(self.provider.provider)
|
||||
provider_names = [self.provider.provider]
|
||||
if model_provider_id.is_langgenius():
|
||||
provider_names.append(model_provider_id.provider_name)
|
||||
|
||||
load_balancing_config_count = (
|
||||
db.session.query(LoadBalancingModelConfig)
|
||||
.filter(
|
||||
LoadBalancingModelConfig.tenant_id == self.tenant_id,
|
||||
LoadBalancingModelConfig.provider_name == self.provider.provider,
|
||||
LoadBalancingModelConfig.provider_name.in_(provider_names),
|
||||
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
|
||||
LoadBalancingModelConfig.model_name == model,
|
||||
)
|
||||
@ -582,16 +595,7 @@ class ProviderConfiguration(BaseModel):
|
||||
if load_balancing_config_count <= 1:
|
||||
raise ValueError("Model load balancing configuration must be more than 1.")
|
||||
|
||||
model_setting = (
|
||||
db.session.query(ProviderModelSetting)
|
||||
.filter(
|
||||
ProviderModelSetting.tenant_id == self.tenant_id,
|
||||
ProviderModelSetting.provider_name == self.provider.provider,
|
||||
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelSetting.model_name == model,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
model_setting = self._get_provider_model_setting(model_type, model)
|
||||
|
||||
if model_setting:
|
||||
model_setting.load_balancing_enabled = True
|
||||
@ -616,11 +620,16 @@ class ProviderConfiguration(BaseModel):
|
||||
:param model: model name
|
||||
:return:
|
||||
"""
|
||||
model_provider_id = ModelProviderID(self.provider.provider)
|
||||
provider_names = [self.provider.provider]
|
||||
if model_provider_id.is_langgenius():
|
||||
provider_names.append(model_provider_id.provider_name)
|
||||
|
||||
model_setting = (
|
||||
db.session.query(ProviderModelSetting)
|
||||
.filter(
|
||||
ProviderModelSetting.tenant_id == self.tenant_id,
|
||||
ProviderModelSetting.provider_name == self.provider.provider,
|
||||
ProviderModelSetting.provider_name.in_(provider_names),
|
||||
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelSetting.model_name == model,
|
||||
)
|
||||
@ -677,11 +686,16 @@ class ProviderConfiguration(BaseModel):
|
||||
return
|
||||
|
||||
# get preferred provider
|
||||
model_provider_id = ModelProviderID(self.provider.provider)
|
||||
provider_names = [self.provider.provider]
|
||||
if model_provider_id.is_langgenius():
|
||||
provider_names.append(model_provider_id.provider_name)
|
||||
|
||||
preferred_model_provider = (
|
||||
db.session.query(TenantPreferredModelProvider)
|
||||
.filter(
|
||||
TenantPreferredModelProvider.tenant_id == self.tenant_id,
|
||||
TenantPreferredModelProvider.provider_name == self.provider.provider,
|
||||
TenantPreferredModelProvider.provider_name.in_(provider_names),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
@ -63,7 +63,9 @@ class File(BaseModel):
|
||||
extension: Optional[str] = None,
|
||||
mime_type: Optional[str] = None,
|
||||
size: int = -1,
|
||||
storage_key: str,
|
||||
storage_key: Optional[str] = None,
|
||||
dify_model_identity: Optional[str] = FILE_MODEL_IDENTITY,
|
||||
url: Optional[str] = None,
|
||||
):
|
||||
super().__init__(
|
||||
id=id,
|
||||
@ -76,8 +78,10 @@ class File(BaseModel):
|
||||
extension=extension,
|
||||
mime_type=mime_type,
|
||||
size=size,
|
||||
dify_model_identity=dify_model_identity,
|
||||
url=url,
|
||||
)
|
||||
self._storage_key = storage_key
|
||||
self._storage_key = str(storage_key)
|
||||
|
||||
def to_dict(self) -> Mapping[str, str | int | None]:
|
||||
data = self.model_dump(mode="json")
|
||||
|
||||
@@ -11,6 +11,19 @@ from configs import dify_config

SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES

HTTP_REQUEST_NODE_SSL_VERIFY = True # Default value for HTTP_REQUEST_NODE_SSL_VERIFY is True
try:
HTTP_REQUEST_NODE_SSL_VERIFY = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY
http_request_node_ssl_verify_lower = str(HTTP_REQUEST_NODE_SSL_VERIFY).lower()
if http_request_node_ssl_verify_lower == "true":
HTTP_REQUEST_NODE_SSL_VERIFY = True
elif http_request_node_ssl_verify_lower == "false":
HTTP_REQUEST_NODE_SSL_VERIFY = False
else:
raise ValueError("Invalid value. HTTP_REQUEST_NODE_SSL_VERIFY should be 'True' or 'False'")
except NameError:
HTTP_REQUEST_NODE_SSL_VERIFY = True

BACKOFF_FACTOR = 0.5
STATUS_FORCELIST = [429, 500, 502, 503, 504]

@@ -39,17 +52,17 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
while retries <= max_retries:
try:
if dify_config.SSRF_PROXY_ALL_URL:
with httpx.Client(proxy=dify_config.SSRF_PROXY_ALL_URL) as client:
with httpx.Client(proxy=dify_config.SSRF_PROXY_ALL_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY) as client:
response = client.request(method=method, url=url, **kwargs)
elif dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
proxy_mounts = {
"http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL),
"https://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTPS_URL),
}
with httpx.Client(mounts=proxy_mounts) as client:
with httpx.Client(mounts=proxy_mounts, verify=HTTP_REQUEST_NODE_SSL_VERIFY) as client:
response = client.request(method=method, url=url, **kwargs)
else:
with httpx.Client() as client:
with httpx.Client(verify=HTTP_REQUEST_NODE_SSL_VERIFY) as client:
response = client.request(method=method, url=url, **kwargs)

if response.status_code not in STATUS_FORCELIST:
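This change threads an `HTTP_REQUEST_NODE_SSL_VERIFY` flag, parsed from a string-typed setting, into every `httpx.Client`. A minimal sketch of the same string-to-bool handling and how it feeds `verify=`; the setting value, helper name, and URL are illustrative:

```python
# Minimal sketch (not the Dify module): parse a "True"/"False" setting string
# and pass it to httpx as the TLS verification flag.
import httpx

def parse_ssl_verify(raw: str | bool) -> bool:
    value = str(raw).lower()
    if value == "true":
        return True
    if value == "false":
        return False
    raise ValueError("HTTP_REQUEST_NODE_SSL_VERIFY should be 'True' or 'False'")

verify = parse_ssl_verify("False")           # illustrative setting value
with httpx.Client(verify=verify) as client:  # verify=False skips certificate checks
    response = client.get("https://example.com")
    print(response.status_code)
```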
@ -493,7 +493,7 @@ If inputting a combination of text and images, the images need to be constructed
|
||||
The base class for all Role message bodies, used only for parameter declaration and cannot be initialized.
|
||||
|
||||
```python
|
||||
class PromptMessage(ABC, BaseModel):
|
||||
class PromptMessage(BaseModel):
|
||||
"""
|
||||
Model class for prompt message.
|
||||
"""
|
||||
|
||||
@ -533,7 +533,7 @@ class ImagePromptMessageContent(PromptMessageContent):
|
||||
所有 Role 消息体的基类,仅作为参数声明用,不可初始化。
|
||||
|
||||
```python
|
||||
class PromptMessage(ABC, BaseModel):
|
||||
class PromptMessage(BaseModel):
|
||||
"""
|
||||
Model class for prompt message.
|
||||
"""
|
||||
|
||||
@ -31,11 +31,9 @@ __all__ = [
|
||||
"ModelPropertyKey",
|
||||
"MultiModalPromptMessageContent",
|
||||
"PromptMessage",
|
||||
"PromptMessage",
|
||||
"PromptMessageContent",
|
||||
"PromptMessageContentType",
|
||||
"PromptMessageRole",
|
||||
"PromptMessageRole",
|
||||
"PromptMessageTool",
|
||||
"SystemPromptMessage",
|
||||
"TextPromptMessageContent",
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
from abc import ABC
|
||||
from collections.abc import Sequence
|
||||
from enum import Enum, StrEnum
|
||||
from typing import Optional
|
||||
@ -119,7 +118,7 @@ class DocumentPromptMessageContent(MultiModalPromptMessageContent):
|
||||
type: PromptMessageContentType = PromptMessageContentType.DOCUMENT
|
||||
|
||||
|
||||
class PromptMessage(ABC, BaseModel):
|
||||
class PromptMessage(BaseModel):
|
||||
"""
|
||||
Model class for prompt message.
|
||||
"""
|
||||
|
||||
@ -80,7 +80,7 @@ class AIModel(BaseModel):
|
||||
)
|
||||
)
|
||||
elif isinstance(invoke_error, InvokeError):
|
||||
return invoke_error(description=f"[{self.provider_name}] {invoke_error.description}, {str(error)}")
|
||||
return InvokeError(description=f"[{self.provider_name}] {invoke_error.description}, {str(error)}")
|
||||
else:
|
||||
return error
|
||||
|
||||
|
||||
@ -214,6 +214,8 @@ class OpsTraceManager:
|
||||
provider_config_map[tracing_provider]["trace_instance"],
|
||||
provider_config_map[tracing_provider]["config_class"],
|
||||
)
|
||||
if not decrypt_trace_config:
|
||||
return None
|
||||
tracing_instance = trace_instance(config_class(**decrypt_trace_config))
|
||||
return tracing_instance
|
||||
|
||||
|
||||
@ -3,7 +3,7 @@ from binascii import hexlify, unhexlify
|
||||
from collections.abc import Generator
|
||||
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
PromptMessage,
|
||||
SystemPromptMessage,
|
||||
@ -46,7 +46,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
||||
model_parameters=payload.completion_params,
|
||||
tools=payload.tools,
|
||||
stop=payload.stop,
|
||||
stream=payload.stream or True,
|
||||
stream=True if payload.stream is None else payload.stream,
|
||||
user=user_id,
|
||||
)
|
||||
|
||||
@ -64,7 +64,21 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
||||
else:
|
||||
if response.usage:
|
||||
LLMNode.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
|
||||
return response
|
||||
|
||||
def handle_non_streaming(response: LLMResult) -> Generator[LLMResultChunk, None, None]:
|
||||
yield LLMResultChunk(
|
||||
model=response.model,
|
||||
prompt_messages=response.prompt_messages,
|
||||
system_fingerprint=response.system_fingerprint,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=response.message,
|
||||
usage=response.usage,
|
||||
finish_reason="",
|
||||
),
|
||||
)
|
||||
|
||||
return handle_non_streaming(response)
|
||||
|
||||
@classmethod
|
||||
def invoke_text_embedding(cls, user_id: str, tenant: Tenant, payload: RequestInvokeTextEmbedding):
|
||||
|
||||
@ -147,7 +147,7 @@ def init_frontend_parameter(rule: PluginParameter, type: enum.StrEnum, value: An
|
||||
init frontend parameter by rule
|
||||
"""
|
||||
parameter_value = value
|
||||
if not parameter_value and parameter_value != 0:
|
||||
if not parameter_value and parameter_value != 0 and type != PluginParameterType.TOOLS_SELECTOR:
|
||||
# get default value
|
||||
parameter_value = rule.default
|
||||
if not parameter_value and rule.required:
|
||||
|
||||
@ -46,6 +46,7 @@ class AdvancedPromptTransform(PromptTransform):
|
||||
memory_config: Optional[MemoryConfig],
|
||||
memory: Optional[TokenBufferMemory],
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None,
|
||||
) -> list[PromptMessage]:
|
||||
prompt_messages = []
|
||||
|
||||
@ -59,6 +60,7 @@ class AdvancedPromptTransform(PromptTransform):
|
||||
memory_config=memory_config,
|
||||
memory=memory,
|
||||
model_config=model_config,
|
||||
image_detail_config=image_detail_config,
|
||||
)
|
||||
elif isinstance(prompt_template, list) and all(isinstance(item, ChatModelMessage) for item in prompt_template):
|
||||
prompt_messages = self._get_chat_model_prompt_messages(
|
||||
@ -70,6 +72,7 @@ class AdvancedPromptTransform(PromptTransform):
|
||||
memory_config=memory_config,
|
||||
memory=memory,
|
||||
model_config=model_config,
|
||||
image_detail_config=image_detail_config,
|
||||
)
|
||||
|
||||
return prompt_messages
|
||||
@ -84,6 +87,7 @@ class AdvancedPromptTransform(PromptTransform):
|
||||
memory_config: Optional[MemoryConfig],
|
||||
memory: Optional[TokenBufferMemory],
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None,
|
||||
) -> list[PromptMessage]:
|
||||
"""
|
||||
Get completion model prompt messages.
|
||||
@ -124,7 +128,9 @@ class AdvancedPromptTransform(PromptTransform):
|
||||
prompt_message_contents: list[PromptMessageContent] = []
|
||||
prompt_message_contents.append(TextPromptMessageContent(data=prompt))
|
||||
for file in files:
|
||||
prompt_message_contents.append(file_manager.to_prompt_message_content(file))
|
||||
prompt_message_contents.append(
|
||||
file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config)
|
||||
)
|
||||
|
||||
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
||||
else:
|
||||
@ -142,6 +148,7 @@ class AdvancedPromptTransform(PromptTransform):
|
||||
memory_config: Optional[MemoryConfig],
|
||||
memory: Optional[TokenBufferMemory],
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None,
|
||||
) -> list[PromptMessage]:
|
||||
"""
|
||||
Get chat model prompt messages.
|
||||
@ -197,7 +204,9 @@ class AdvancedPromptTransform(PromptTransform):
|
||||
prompt_message_contents: list[PromptMessageContent] = []
|
||||
prompt_message_contents.append(TextPromptMessageContent(data=query))
|
||||
for file in files:
|
||||
prompt_message_contents.append(file_manager.to_prompt_message_content(file))
|
||||
prompt_message_contents.append(
|
||||
file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config)
|
||||
)
|
||||
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
||||
else:
|
||||
prompt_messages.append(UserPromptMessage(content=query))
|
||||
@ -209,19 +218,25 @@ class AdvancedPromptTransform(PromptTransform):
|
||||
# get last user message content and add files
|
||||
prompt_message_contents = [TextPromptMessageContent(data=cast(str, last_message.content))]
|
||||
for file in files:
|
||||
prompt_message_contents.append(file_manager.to_prompt_message_content(file))
|
||||
prompt_message_contents.append(
|
||||
file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config)
|
||||
)
|
||||
|
||||
last_message.content = prompt_message_contents
|
||||
else:
|
||||
prompt_message_contents = [TextPromptMessageContent(data="")] # not for query
|
||||
for file in files:
|
||||
prompt_message_contents.append(file_manager.to_prompt_message_content(file))
|
||||
prompt_message_contents.append(
|
||||
file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config)
|
||||
)
|
||||
|
||||
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
||||
else:
|
||||
prompt_message_contents = [TextPromptMessageContent(data=query)]
|
||||
for file in files:
|
||||
prompt_message_contents.append(file_manager.to_prompt_message_content(file))
|
||||
prompt_message_contents.append(
|
||||
file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config)
|
||||
)
|
||||
|
||||
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
||||
elif query:
|
||||
|
||||
@ -9,6 +9,7 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti
|
||||
from core.file import file_manager
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
ImagePromptMessageContent,
|
||||
PromptMessage,
|
||||
PromptMessageContent,
|
||||
SystemPromptMessage,
|
||||
@ -60,6 +61,7 @@ class SimplePromptTransform(PromptTransform):
|
||||
context: Optional[str],
|
||||
memory: Optional[TokenBufferMemory],
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None,
|
||||
) -> tuple[list[PromptMessage], Optional[list[str]]]:
|
||||
inputs = {key: str(value) for key, value in inputs.items()}
|
||||
|
||||
@ -74,6 +76,7 @@ class SimplePromptTransform(PromptTransform):
|
||||
context=context,
|
||||
memory=memory,
|
||||
model_config=model_config,
|
||||
image_detail_config=image_detail_config,
|
||||
)
|
||||
else:
|
||||
prompt_messages, stops = self._get_completion_model_prompt_messages(
|
||||
@ -85,6 +88,7 @@ class SimplePromptTransform(PromptTransform):
|
||||
context=context,
|
||||
memory=memory,
|
||||
model_config=model_config,
|
||||
image_detail_config=image_detail_config,
|
||||
)
|
||||
|
||||
return prompt_messages, stops
|
||||
@ -175,6 +179,7 @@ class SimplePromptTransform(PromptTransform):
|
||||
files: Sequence["File"],
|
||||
memory: Optional[TokenBufferMemory],
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None,
|
||||
) -> tuple[list[PromptMessage], Optional[list[str]]]:
|
||||
prompt_messages: list[PromptMessage] = []
|
||||
|
||||
@ -204,9 +209,9 @@ class SimplePromptTransform(PromptTransform):
|
||||
)
|
||||
|
||||
if query:
|
||||
prompt_messages.append(self.get_last_user_message(query, files))
|
||||
prompt_messages.append(self.get_last_user_message(query, files, image_detail_config))
|
||||
else:
|
||||
prompt_messages.append(self.get_last_user_message(prompt, files))
|
||||
prompt_messages.append(self.get_last_user_message(prompt, files, image_detail_config))
|
||||
|
||||
return prompt_messages, None
|
||||
|
||||
@ -220,6 +225,7 @@ class SimplePromptTransform(PromptTransform):
|
||||
files: Sequence["File"],
|
||||
memory: Optional[TokenBufferMemory],
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None,
|
||||
) -> tuple[list[PromptMessage], Optional[list[str]]]:
|
||||
# get prompt
|
||||
prompt, prompt_rules = self.get_prompt_str_and_rules(
|
||||
@ -262,14 +268,21 @@ class SimplePromptTransform(PromptTransform):
|
||||
if stops is not None and len(stops) == 0:
|
||||
stops = None
|
||||
|
||||
return [self.get_last_user_message(prompt, files)], stops
|
||||
return [self.get_last_user_message(prompt, files, image_detail_config)], stops
|
||||
|
||||
def get_last_user_message(self, prompt: str, files: Sequence["File"]) -> UserPromptMessage:
|
||||
def get_last_user_message(
|
||||
self,
|
||||
prompt: str,
|
||||
files: Sequence["File"],
|
||||
image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None,
|
||||
) -> UserPromptMessage:
|
||||
if files:
|
||||
prompt_message_contents: list[PromptMessageContent] = []
|
||||
prompt_message_contents.append(TextPromptMessageContent(data=prompt))
|
||||
for file in files:
|
||||
prompt_message_contents.append(file_manager.to_prompt_message_content(file))
|
||||
prompt_message_contents.append(
|
||||
file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config)
|
||||
)
|
||||
|
||||
prompt_message = UserPromptMessage(content=prompt_message_contents)
|
||||
else:
|
||||
|
||||
@ -149,6 +149,11 @@ class ProviderManager:
|
||||
provider_name = provider_entity.provider
|
||||
provider_records = provider_name_to_provider_records_dict.get(provider_entity.provider, [])
|
||||
provider_model_records = provider_name_to_provider_model_records_dict.get(provider_entity.provider, [])
|
||||
provider_id_entity = ModelProviderID(provider_name)
|
||||
if provider_id_entity.is_langgenius():
|
||||
provider_model_records.extend(
|
||||
provider_name_to_provider_model_records_dict.get(provider_id_entity.provider_name, [])
|
||||
)
|
||||
|
||||
# Convert to custom configuration
|
||||
custom_configuration = self._to_custom_configuration(
|
||||
@ -190,6 +195,20 @@ class ProviderManager:
|
||||
provider_name
|
||||
)
|
||||
|
||||
provider_id_entity = ModelProviderID(provider_name)
|
||||
|
||||
if provider_id_entity.is_langgenius():
|
||||
if provider_model_settings is not None:
|
||||
provider_model_settings.extend(
|
||||
provider_name_to_provider_model_settings_dict.get(provider_id_entity.provider_name, [])
|
||||
)
|
||||
if provider_load_balancing_configs is not None:
|
||||
provider_load_balancing_configs.extend(
|
||||
provider_name_to_provider_load_balancing_model_configs_dict.get(
|
||||
provider_id_entity.provider_name, []
|
||||
)
|
||||
)
|
||||
|
||||
# Convert to model settings
|
||||
model_settings = self._to_model_settings(
|
||||
provider_entity=provider_entity,
|
||||
@ -207,7 +226,7 @@ class ProviderManager:
|
||||
model_settings=model_settings,
|
||||
)
|
||||
|
||||
provider_configurations[str(ModelProviderID(provider_name))] = provider_configuration
|
||||
provider_configurations[str(provider_id_entity)] = provider_configuration
|
||||
|
||||
# Return the encapsulated object
|
||||
return provider_configurations
|
||||
|
||||
@ -88,16 +88,17 @@ class Jieba(BaseKeyword):
|
||||
keyword_table = self._get_dataset_keyword_table()
|
||||
|
||||
k = kwargs.get("top_k", 4)
|
||||
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table or {}, query, k)
|
||||
|
||||
documents = []
|
||||
for chunk_index in sorted_chunk_indices:
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.filter(DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id == chunk_index)
|
||||
.first()
|
||||
segment_query = db.session.query(DocumentSegment).filter(
|
||||
DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id == chunk_index
|
||||
)
|
||||
if document_ids_filter:
|
||||
segment_query = segment_query.filter(DocumentSegment.document_id.in_(document_ids_filter))
|
||||
segment = segment_query.first()
|
||||
|
||||
if segment:
|
||||
documents.append(
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
import concurrent.futures
|
||||
import json
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Optional
|
||||
|
||||
@ -42,6 +41,7 @@ class RetrievalService:
|
||||
reranking_model: Optional[dict] = None,
|
||||
reranking_mode: str = "reranking_model",
|
||||
weights: Optional[dict] = None,
|
||||
document_ids_filter: Optional[list[str]] = None,
|
||||
):
|
||||
if not query:
|
||||
return []
|
||||
@ -65,6 +65,7 @@ class RetrievalService:
|
||||
top_k=top_k,
|
||||
all_documents=all_documents,
|
||||
exceptions=exceptions,
|
||||
document_ids_filter=document_ids_filter,
|
||||
)
|
||||
)
|
||||
if RetrievalMethod.is_support_semantic_search(retrieval_method):
|
||||
@ -80,6 +81,7 @@ class RetrievalService:
|
||||
all_documents=all_documents,
|
||||
retrieval_method=retrieval_method,
|
||||
exceptions=exceptions,
|
||||
document_ids_filter=document_ids_filter,
|
||||
)
|
||||
)
|
||||
if RetrievalMethod.is_support_fulltext_search(retrieval_method):
|
||||
@ -131,7 +133,14 @@ class RetrievalService:
|
||||
|
||||
@classmethod
|
||||
def keyword_search(
|
||||
cls, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list, exceptions: list
|
||||
cls,
|
||||
flask_app: Flask,
|
||||
dataset_id: str,
|
||||
query: str,
|
||||
top_k: int,
|
||||
all_documents: list,
|
||||
exceptions: list,
|
||||
document_ids_filter: Optional[list[str]] = None,
|
||||
):
|
||||
with flask_app.app_context():
|
||||
try:
|
||||
@ -140,7 +149,10 @@ class RetrievalService:
|
||||
raise ValueError("dataset not found")
|
||||
|
||||
keyword = Keyword(dataset=dataset)
|
||||
documents = keyword.search(cls.escape_query_for_search(query), top_k=top_k)
|
||||
|
||||
documents = keyword.search(
|
||||
cls.escape_query_for_search(query), top_k=top_k, document_ids_filter=document_ids_filter
|
||||
)
|
||||
all_documents.extend(documents)
|
||||
except Exception as e:
|
||||
exceptions.append(str(e))
|
||||
@ -157,6 +169,7 @@ class RetrievalService:
|
||||
all_documents: list,
|
||||
retrieval_method: str,
|
||||
exceptions: list,
|
||||
document_ids_filter: Optional[list[str]] = None,
|
||||
):
|
||||
with flask_app.app_context():
|
||||
try:
|
||||
@ -171,6 +184,7 @@ class RetrievalService:
|
||||
top_k=top_k,
|
||||
score_threshold=score_threshold,
|
||||
filter={"group_id": [dataset.id]},
|
||||
document_ids_filter=document_ids_filter,
|
||||
)
|
||||
|
||||
if documents:
|
||||
@@ -243,7 +257,7 @@ class RetrievalService:

@staticmethod
def escape_query_for_search(query: str) -> str:
return json.dumps(query).strip('"')
return query.replace('"', '\\"')

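The rewritten helper only escapes double quotes, whereas the old `json.dumps` route also escaped backslashes and other characters. A small before/after comparison; the example query string is illustrative:

```python
# Compare the old json.dumps-based escaping with the new quote-only escaping.
import json

query = 'say "hello" to C:\\temp'

old = json.dumps(query).strip('"')   # escapes quotes AND backslashes: say \"hello\" to C:\\temp
new = query.replace('"', '\\"')      # escapes only double quotes:     say \"hello\" to C:\temp

print(old)
print(new)
```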
@classmethod
|
||||
def format_retrieval_documents(cls, documents: list[Document]) -> list[RetrievalSegments]:
|
||||
@ -277,6 +291,8 @@ class RetrievalService:
|
||||
continue
|
||||
|
||||
dataset_document = dataset_documents[document_id]
|
||||
if not dataset_document:
|
||||
continue
|
||||
|
||||
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
|
||||
# Handle parent-child documents
|
||||
|
||||
@ -53,7 +53,7 @@ class AnalyticdbVector(BaseVector):
|
||||
self.analyticdb_vector.delete_by_metadata_field(key, value)
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
return self.analyticdb_vector.search_by_vector(query_vector)
|
||||
return self.analyticdb_vector.search_by_vector(query_vector, **kwargs)
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
return self.analyticdb_vector.search_by_full_text(query, **kwargs)
|
||||
|
||||
@ -194,6 +194,13 @@ class AnalyticdbVectorBySql:
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
top_k = kwargs.get("top_k", 4)
|
||||
if not isinstance(top_k, int) or top_k <= 0:
|
||||
raise ValueError("top_k must be a positive integer")
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
where_clause = "WHERE 1=1"
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
where_clause += f"AND metadata_->>'document_id' IN ({document_ids})"
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
with self._get_cursor() as cur:
|
||||
query_vector_str = json.dumps(query_vector)
|
||||
@ -202,7 +209,7 @@ class AnalyticdbVectorBySql:
|
||||
f"SELECT t.id AS id, t.vector AS vector, (1.0 - t.score) AS score, "
|
||||
f"t.page_content as page_content, t.metadata_ AS metadata_ "
|
||||
f"FROM (SELECT id, vector, page_content, metadata_, vector <=> %s AS score "
|
||||
f"FROM {self.table_name} ORDER BY score LIMIT {top_k} ) t",
|
||||
f"FROM {self.table_name} {where_clause} ORDER BY score LIMIT {top_k} ) t",
|
||||
(query_vector_str,),
|
||||
)
|
||||
documents = []
|
||||
@ -220,12 +227,19 @@ class AnalyticdbVectorBySql:
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
top_k = kwargs.get("top_k", 4)
|
||||
if not isinstance(top_k, int) or top_k <= 0:
|
||||
raise ValueError("top_k must be a positive integer")
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
where_clause = ""
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
where_clause += f"AND metadata_->>'document_id' IN ({document_ids})"
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(
|
||||
f"""SELECT id, vector, page_content, metadata_,
|
||||
ts_rank(to_tsvector, to_tsquery_from_text(%s, 'zh_cn'), 32) AS score
|
||||
FROM {self.table_name}
|
||||
WHERE to_tsvector@@to_tsquery_from_text(%s, 'zh_cn')
|
||||
WHERE to_tsvector@@to_tsquery_from_text(%s, 'zh_cn') {where_clause}
|
||||
ORDER BY score DESC
|
||||
LIMIT {top_k}""",
|
||||
(f"'{query}'", f"'{query}'"),
|
||||
|
||||
@ -123,11 +123,21 @@ class BaiduVector(BaseVector):
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
query_vector = [float(val) if isinstance(val, np.float64) else val for val in query_vector]
|
||||
anns = AnnSearch(
|
||||
vector_field=self.field_vector,
|
||||
vector_floats=query_vector,
|
||||
params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)),
|
||||
)
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
anns = AnnSearch(
|
||||
vector_field=self.field_vector,
|
||||
vector_floats=query_vector,
|
||||
params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)),
|
||||
filter=f"document_id IN ({document_ids})",
|
||||
)
|
||||
else:
|
||||
anns = AnnSearch(
|
||||
vector_field=self.field_vector,
|
||||
vector_floats=query_vector,
|
||||
params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)),
|
||||
)
|
||||
res = self._db.table(self._collection_name).search(
|
||||
anns=anns,
|
||||
projections=[self.field_id, self.field_text, self.field_metadata],
|
||||
|
||||
@@ -95,7 +95,15 @@ class ChromaVector(BaseVector):

    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        collection = self._client.get_or_create_collection(self._collection_name)
        results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4))
        document_ids_filter = kwargs.get("document_ids_filter")
        if document_ids_filter:
            results: QueryResult = collection.query(
                query_embeddings=query_vector,
                n_results=kwargs.get("top_k", 4),
                where={"document_id": {"$in": document_ids_filter}},  # type: ignore
            )
        else:
            results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4))  # type: ignore
        score_threshold = float(kwargs.get("score_threshold") or 0.0)

        # Check if results contain data
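For Chroma the document filter maps directly onto the client's where parameter. A usage sketch, assuming a local chromadb client and an existing collection (names and vectors are placeholders):

import chromadb

client = chromadb.Client()
collection = client.get_or_create_collection("dataset_embeddings")
# Restrict the nearest-neighbour query to chunks whose document_id is in the allowed set.
results = collection.query(
    query_embeddings=[[0.1, 0.2, 0.3]],
    n_results=4,
    where={"document_id": {"$in": ["doc-1", "doc-2"]}},
)
print(results["ids"])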
@@ -117,6 +117,9 @@ class ElasticSearchVector(BaseVector):
        top_k = kwargs.get("top_k", 4)
        num_candidates = math.ceil(top_k * 1.5)
        knn = {"field": Field.VECTOR.value, "query_vector": query_vector, "k": top_k, "num_candidates": num_candidates}
        document_ids_filter = kwargs.get("document_ids_filter")
        if document_ids_filter:
            knn["filter"] = {"terms": {"metadata.document_id": document_ids_filter}}

        results = self._client.search(index=self._collection_name, knn=knn, size=top_k)

@@ -145,6 +148,9 @@ class ElasticSearchVector(BaseVector):

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        query_str = {"match": {Field.CONTENT_KEY.value: query}}
        document_ids_filter = kwargs.get("document_ids_filter")
        if document_ids_filter:
            query_str["filter"] = {"terms": {"metadata.document_id": document_ids_filter}}  # type: ignore
        results = self._client.search(index=self._collection_name, query=query_str, size=kwargs.get("top_k", 4))
        docs = []
        for hit in results["hits"]["hits"]:
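In Elasticsearch the document filter becomes a terms pre-filter on the kNN request. A minimal sketch, assuming an Elasticsearch 8.x client and a hypothetical index name:

from elasticsearch import Elasticsearch

client = Elasticsearch("http://localhost:9200")
knn = {
    "field": "vector",
    "query_vector": [0.1, 0.2, 0.3],
    "k": 4,
    "num_candidates": 6,
    # Pre-filter the candidate set by document id, as the patch above does.
    "filter": {"terms": {"metadata.document_id": ["doc-1", "doc-2"]}},
}
response = client.search(index="dataset_embeddings", knn=knn, size=4)
print(response["hits"]["hits"])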
@ -168,7 +168,12 @@ class LindormVectorStore(BaseVector):
|
||||
raise ValueError("All elements in query_vector should be floats")
|
||||
|
||||
top_k = kwargs.get("top_k", 10)
|
||||
query = default_vector_search_query(query_vector=query_vector, k=top_k, **kwargs)
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
filters = []
|
||||
if document_ids_filter:
|
||||
filters.append({"terms": {"metadata.document_id": document_ids_filter}})
|
||||
query = default_vector_search_query(query_vector=query_vector, k=top_k, filters=filters, **kwargs)
|
||||
|
||||
try:
|
||||
params = {}
|
||||
if self._using_ugc:
|
||||
@ -206,7 +211,10 @@ class LindormVectorStore(BaseVector):
|
||||
should = kwargs.get("should")
|
||||
minimum_should_match = kwargs.get("minimum_should_match", 0)
|
||||
top_k = kwargs.get("top_k", 10)
|
||||
filters = kwargs.get("filter")
|
||||
filters = kwargs.get("filter", [])
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
if document_ids_filter:
|
||||
filters.append({"terms": {"metadata.document_id": document_ids_filter}})
|
||||
routing = self._routing
|
||||
full_text_query = default_text_search_query(
|
||||
query_text=query,
|
||||
|
||||
@ -228,12 +228,18 @@ class MilvusVector(BaseVector):
|
||||
"""
|
||||
Search for documents by vector similarity.
|
||||
"""
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
filter = ""
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
filter = f'metadata["document_id"] in ({document_ids})'
|
||||
results = self._client.search(
|
||||
collection_name=self._collection_name,
|
||||
data=[query_vector],
|
||||
anns_field=Field.VECTOR.value,
|
||||
limit=kwargs.get("top_k", 4),
|
||||
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
|
||||
filter=filter,
|
||||
)
|
||||
|
||||
return self._process_search_results(
|
||||
@ -249,6 +255,11 @@ class MilvusVector(BaseVector):
|
||||
if not self._hybrid_search_enabled or not self.field_exists(Field.SPARSE_VECTOR.value):
|
||||
logger.warning("Full-text search is not supported in current Milvus version (requires >= 2.5.0)")
|
||||
return []
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
filter = ""
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
filter = f'metadata["document_id"] in ({document_ids})'
|
||||
|
||||
results = self._client.search(
|
||||
collection_name=self._collection_name,
|
||||
@ -256,6 +267,7 @@ class MilvusVector(BaseVector):
|
||||
anns_field=Field.SPARSE_VECTOR.value,
|
||||
limit=kwargs.get("top_k", 4),
|
||||
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
|
||||
filter=filter,
|
||||
)
|
||||
|
||||
return self._process_search_results(
|
||||
|
||||
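Milvus expresses the same restriction as a boolean filter expression over the JSON metadata field. A usage sketch, assuming pymilvus >= 2.4 and a hypothetical collection (the expression mirrors the one built in the diff above):

from pymilvus import MilvusClient

client = MilvusClient(uri="http://localhost:19530")
document_ids = ", ".join(f"'{doc_id}'" for doc_id in ["doc-1", "doc-2"])
results = client.search(
    collection_name="dataset_embeddings",
    data=[[0.1, 0.2, 0.3]],
    anns_field="vector",
    limit=4,
    output_fields=["page_content", "metadata"],
    # Milvus boolean expression over the JSON metadata field.
    filter=f'metadata["document_id"] in ({document_ids})',
)
print(results)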
@ -125,12 +125,18 @@ class MyScaleVector(BaseVector):
|
||||
|
||||
def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]:
|
||||
top_k = kwargs.get("top_k", 4)
|
||||
if not isinstance(top_k, int) or top_k <= 0:
|
||||
raise ValueError("top_k must be a positive integer")
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
where_str = (
|
||||
f"WHERE dist < {1 - score_threshold}"
|
||||
if self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0
|
||||
else ""
|
||||
)
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
where_str = f"{where_str} AND metadata['document_id'] in ({document_ids})"
|
||||
sql = f"""
|
||||
SELECT text, vector, metadata, {dist} as dist FROM {self._config.database}.{self._collection_name}
|
||||
{where_str} ORDER BY dist {order.value} LIMIT {top_k}
|
||||
|
||||
@ -154,6 +154,11 @@ class OceanBaseVector(BaseVector):
|
||||
return []
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
where_clause = None
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
where_clause = f"metadata->>'$.document_id' in ({document_ids})"
|
||||
ef_search = kwargs.get("ef_search", self._hnsw_ef_search)
|
||||
if ef_search != self._hnsw_ef_search:
|
||||
self._client.set_ob_hnsw_ef_search(ef_search)
|
||||
@ -167,6 +172,7 @@ class OceanBaseVector(BaseVector):
|
||||
distance_func=func.l2_distance,
|
||||
output_column_names=["text", "metadata"],
|
||||
with_dist=True,
|
||||
where_clause=where_clause,
|
||||
)
|
||||
docs = []
|
||||
for text, metadata, distance in cur:
|
||||
|
||||
api/core/rag/datasource/vdb/opengauss/opengauss.py (new file, 240 lines)
@@ -0,0 +1,240 @@
|
||||
import json
|
||||
import uuid
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
import psycopg2.extras # type: ignore
|
||||
import psycopg2.pool # type: ignore
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.embedding.embedding_base import Embeddings
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import Dataset
|
||||
|
||||
|
||||
class OpenGaussConfig(BaseModel):
|
||||
host: str
|
||||
port: int
|
||||
user: str
|
||||
password: str
|
||||
database: str
|
||||
min_connection: int
|
||||
max_connection: int
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values["host"]:
|
||||
raise ValueError("config OPENGAUSS_HOST is required")
|
||||
if not values["port"]:
|
||||
raise ValueError("config OPENGAUSS_PORT is required")
|
||||
if not values["user"]:
|
||||
raise ValueError("config OPENGAUSS_USER is required")
|
||||
if not values["password"]:
|
||||
raise ValueError("config OPENGAUSS_PASSWORD is required")
|
||||
if not values["database"]:
|
||||
raise ValueError("config OPENGAUSS_DATABASE is required")
|
||||
if not values["min_connection"]:
|
||||
raise ValueError("config OPENGAUSS_MIN_CONNECTION is required")
|
||||
if not values["max_connection"]:
|
||||
raise ValueError("config OPENGAUSS_MAX_CONNECTION is required")
|
||||
if values["min_connection"] > values["max_connection"]:
|
||||
raise ValueError("config OPENGAUSS_MIN_CONNECTION should less than OPENGAUSS_MAX_CONNECTION")
|
||||
return values
|
||||
|
||||
|
||||
SQL_CREATE_TABLE = """
|
||||
CREATE TABLE IF NOT EXISTS {table_name} (
|
||||
id UUID PRIMARY KEY,
|
||||
text TEXT NOT NULL,
|
||||
meta JSONB NOT NULL,
|
||||
embedding vector({dimension}) NOT NULL
|
||||
);
|
||||
"""
|
||||
|
||||
SQL_CREATE_INDEX = """
|
||||
CREATE INDEX IF NOT EXISTS embedding_cosine_{table_name}_idx ON {table_name}
|
||||
USING hnsw (embedding vector_cosine_ops) WITH (m = 16, ef_construction = 64);
|
||||
"""
|
||||
|
||||
|
||||
class OpenGauss(BaseVector):
|
||||
def __init__(self, collection_name: str, config: OpenGaussConfig):
|
||||
super().__init__(collection_name)
|
||||
self.pool = self._create_connection_pool(config)
|
||||
self.table_name = f"embedding_{collection_name}"
|
||||
|
||||
def get_type(self) -> str:
|
||||
return VectorType.OPENGAUSS
|
||||
|
||||
def _create_connection_pool(self, config: OpenGaussConfig):
|
||||
return psycopg2.pool.SimpleConnectionPool(
|
||||
config.min_connection,
|
||||
config.max_connection,
|
||||
host=config.host,
|
||||
port=config.port,
|
||||
user=config.user,
|
||||
password=config.password,
|
||||
database=config.database,
|
||||
)
|
||||
|
||||
@contextmanager
|
||||
def _get_cursor(self):
|
||||
conn = self.pool.getconn()
|
||||
cur = conn.cursor()
|
||||
try:
|
||||
yield cur
|
||||
finally:
|
||||
cur.close()
|
||||
conn.commit()
|
||||
self.pool.putconn(conn)
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
dimension = len(embeddings[0])
|
||||
self._create_collection(dimension)
|
||||
return self.add_texts(texts, embeddings)
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
values = []
|
||||
pks = []
|
||||
for i, doc in enumerate(documents):
|
||||
if doc.metadata is not None:
|
||||
doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
|
||||
pks.append(doc_id)
|
||||
values.append(
|
||||
(
|
||||
doc_id,
|
||||
doc.page_content,
|
||||
json.dumps(doc.metadata),
|
||||
embeddings[i],
|
||||
)
|
||||
)
|
||||
with self._get_cursor() as cur:
|
||||
psycopg2.extras.execute_values(
|
||||
cur, f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES %s", values
|
||||
)
|
||||
return pks
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(f"SELECT id FROM {self.table_name} WHERE id = %s", (id,))
|
||||
return cur.fetchone() is not None
|
||||
|
||||
def get_by_ids(self, ids: list[str]) -> list[Document]:
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
|
||||
docs = []
|
||||
for record in cur:
|
||||
docs.append(Document(page_content=record[1], metadata=record[0]))
|
||||
return docs
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
# Avoiding crashes caused by performing delete operations on empty lists in certain scenarios
|
||||
# Scenario 1: extracting a document fails, so the table was never created.
|
||||
# Then clicking the retry button triggers a delete operation on an empty list.
|
||||
if not ids:
|
||||
return
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str) -> None:
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(f"DELETE FROM {self.table_name} WHERE meta->>%s = %s", (key, value))
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
"""
|
||||
Search the nearest neighbors to a vector.
|
||||
|
||||
:param query_vector: The input vector to search for similar items.
|
||||
:param top_k: The number of nearest neighbors to return, default is 4.
|
||||
:return: List of Documents that are nearest to the query vector.
|
||||
"""
|
||||
top_k = kwargs.get("top_k", 4)
|
||||
if not isinstance(top_k, int) or top_k <= 0:
|
||||
raise ValueError("top_k must be a positive integer")
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(
|
||||
f"SELECT meta, text, embedding <=> %s AS distance FROM {self.table_name}"
|
||||
f" ORDER BY distance LIMIT {top_k}",
|
||||
(json.dumps(query_vector),),
|
||||
)
|
||||
docs = []
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
for record in cur:
|
||||
metadata, text, distance = record
|
||||
score = 1 - distance
|
||||
metadata["score"] = score
|
||||
if score > score_threshold:
|
||||
docs.append(Document(page_content=text, metadata=metadata))
|
||||
return docs
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
top_k = kwargs.get("top_k", 5)
|
||||
if not isinstance(top_k, int) or top_k <= 0:
|
||||
raise ValueError("top_k must be a positive integer")
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(
|
||||
f"""SELECT meta, text, ts_rank(to_tsvector(coalesce(text, '')), plainto_tsquery(%s)) AS score
|
||||
FROM {self.table_name}
|
||||
WHERE to_tsvector(text) @@ plainto_tsquery(%s)
|
||||
ORDER BY score DESC
|
||||
LIMIT {top_k}""",
|
||||
# f"'{query}'" is required in order to account for whitespace in query
|
||||
(f"'{query}'", f"'{query}'"),
|
||||
)
|
||||
|
||||
docs = []
|
||||
|
||||
for record in cur:
|
||||
metadata, text, score = record
|
||||
metadata["score"] = score
|
||||
docs.append(Document(page_content=text, metadata=metadata))
|
||||
|
||||
return docs
|
||||
|
||||
def delete(self) -> None:
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
|
||||
|
||||
def _create_collection(self, dimension: int):
|
||||
cache_key = f"vector_indexing_{self._collection_name}"
|
||||
lock_name = f"{cache_key}_lock"
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
|
||||
if redis_client.get(collection_exist_cache_key):
|
||||
return
|
||||
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension))
|
||||
if dimension <= 2000:
|
||||
cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name))
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
||||
|
||||
class OpenGaussFactory(AbstractVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> OpenGauss:
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
collection_name = class_prefix
|
||||
else:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.OPENGAUSS, collection_name))
|
||||
|
||||
return OpenGauss(
|
||||
collection_name=collection_name,
|
||||
config=OpenGaussConfig(
|
||||
host=dify_config.OPENGAUSS_HOST or "localhost",
|
||||
port=dify_config.OPENGAUSS_PORT,
|
||||
user=dify_config.OPENGAUSS_USER or "postgres",
|
||||
password=dify_config.OPENGAUSS_PASSWORD or "",
|
||||
database=dify_config.OPENGAUSS_DATABASE or "dify",
|
||||
min_connection=dify_config.OPENGAUSS_MIN_CONNECTION,
|
||||
max_connection=dify_config.OPENGAUSS_MAX_CONNECTION,
|
||||
),
|
||||
)
|
||||
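A usage sketch for the new openGauss store defined above, assuming a reachable openGauss instance with vector support, the app's redis_client available for the indexing lock, and placeholder connection values:

from core.rag.datasource.vdb.opengauss.opengauss import OpenGauss, OpenGaussConfig
from core.rag.models.document import Document

config = OpenGaussConfig(
    host="localhost",
    port=6600,
    user="dify",
    password="Dify123456",
    database="dify",
    min_connection=1,
    max_connection=5,
)
store = OpenGauss(collection_name="demo_dataset", config=config)
doc = Document(page_content="hello openGauss", metadata={"doc_id": "doc-1", "document_id": "doc-1"})
store.create([doc], embeddings=[[0.1] * 768])
print(store.search_by_vector([0.1] * 768, top_k=4))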
@ -154,6 +154,9 @@ class OpenSearchVector(BaseVector):
|
||||
"size": kwargs.get("top_k", 4),
|
||||
"query": {"knn": {Field.VECTOR.value: {Field.VECTOR.value: query_vector, "k": kwargs.get("top_k", 4)}}},
|
||||
}
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
if document_ids_filter:
|
||||
query["query"] = {"terms": {"metadata.document_id": document_ids_filter}}
|
||||
|
||||
try:
|
||||
response = self._client.search(index=self._collection_name.lower(), body=query)
|
||||
@ -179,6 +182,9 @@ class OpenSearchVector(BaseVector):
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
full_text_query = {"query": {"match": {Field.CONTENT_KEY.value: query}}}
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
if document_ids_filter:
|
||||
full_text_query["query"]["terms"] = {"metadata.document_id": document_ids_filter}
|
||||
|
||||
response = self._client.search(index=self._collection_name.lower(), body=full_text_query)
|
||||
|
||||
|
||||
@ -201,10 +201,15 @@ class OracleVector(BaseVector):
|
||||
:return: List of Documents that are nearest to the query vector.
|
||||
"""
|
||||
top_k = kwargs.get("top_k", 4)
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
where_clause = ""
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
where_clause = f"WHERE metadata->>'document_id' in ({document_ids})"
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(
|
||||
f"SELECT meta, text, vector_distance(embedding,:1) AS distance FROM {self.table_name}"
|
||||
f" ORDER BY distance fetch first {top_k} rows only",
|
||||
f" {where_clause} ORDER BY distance fetch first {top_k} rows only",
|
||||
[numpy.array(query_vector)],
|
||||
)
|
||||
docs = []
|
||||
@ -257,9 +262,15 @@ class OracleVector(BaseVector):
|
||||
if token not in stop_words:
|
||||
entities.append(token)
|
||||
with self._get_cursor() as cur:
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
where_clause = ""
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
where_clause = f" AND metadata->>'document_id' in ({document_ids}) "
|
||||
cur.execute(
|
||||
f"select meta, text, embedding FROM {self.table_name}"
|
||||
f" WHERE CONTAINS(text, :1, 1) > 0 order by score(1) desc fetch first {top_k} rows only",
|
||||
f"WHERE CONTAINS(text, :1, 1) > 0 {where_clause} "
|
||||
f"order by score(1) desc fetch first {top_k} rows only",
|
||||
[" ACCUM ".join(entities)],
|
||||
)
|
||||
docs = []
|
||||
|
||||
@@ -189,6 +189,9 @@ class PGVectoRS(BaseVector):
                .limit(kwargs.get("top_k", 4))
                .order_by("distance")
            )
            document_ids_filter = kwargs.get("document_ids_filter")
            if document_ids_filter:
                stmt = stmt.where(self._table.meta["document_id"].in_(document_ids_filter))
            res = session.execute(stmt)
            results = [(row[0], row[1]) for row in res]

@ -1,8 +1,10 @@
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
import psycopg2.errors
|
||||
import psycopg2.extras # type: ignore
|
||||
import psycopg2.pool # type: ignore
|
||||
from pydantic import BaseModel, model_validator
|
||||
@ -25,6 +27,7 @@ class PGVectorConfig(BaseModel):
|
||||
database: str
|
||||
min_connection: int
|
||||
max_connection: int
|
||||
pg_bigm: bool = False
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
@ -62,12 +65,18 @@ CREATE INDEX IF NOT EXISTS embedding_cosine_v1_idx ON {table_name}
|
||||
USING hnsw (embedding vector_cosine_ops) WITH (m = 16, ef_construction = 64);
|
||||
"""
|
||||
|
||||
SQL_CREATE_INDEX_PG_BIGM = """
|
||||
CREATE INDEX IF NOT EXISTS bigm_idx ON {table_name}
|
||||
USING gin (text gin_bigm_ops);
|
||||
"""
|
||||
|
||||
|
||||
class PGVector(BaseVector):
|
||||
def __init__(self, collection_name: str, config: PGVectorConfig):
|
||||
super().__init__(collection_name)
|
||||
self.pool = self._create_connection_pool(config)
|
||||
self.table_name = f"embedding_{collection_name}"
|
||||
self.pg_bigm = config.pg_bigm
|
||||
|
||||
def get_type(self) -> str:
|
||||
return VectorType.PGVECTOR
|
||||
@ -140,7 +149,14 @@ class PGVector(BaseVector):
|
||||
if not ids:
|
||||
return
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
|
||||
try:
|
||||
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
|
||||
except psycopg2.errors.UndefinedTable:
|
||||
# table not exists
|
||||
logging.warning(f"Table {self.table_name} not found, skipping delete operation.")
|
||||
return
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str) -> None:
|
||||
with self._get_cursor() as cur:
|
||||
@ -155,10 +171,18 @@ class PGVector(BaseVector):
|
||||
:return: List of Documents that are nearest to the query vector.
|
||||
"""
|
||||
top_k = kwargs.get("top_k", 4)
|
||||
if not isinstance(top_k, int) or top_k <= 0:
|
||||
raise ValueError("top_k must be a positive integer")
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
where_clause = ""
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
where_clause = f" WHERE metadata->>'document_id' in ({document_ids}) "
|
||||
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(
|
||||
f"SELECT meta, text, embedding <=> %s AS distance FROM {self.table_name}"
|
||||
f" {where_clause}"
|
||||
f" ORDER BY distance LIMIT {top_k}",
|
||||
(json.dumps(query_vector),),
|
||||
)
|
||||
@ -174,17 +198,37 @@ class PGVector(BaseVector):
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
top_k = kwargs.get("top_k", 5)
|
||||
|
||||
if not isinstance(top_k, int) or top_k <= 0:
|
||||
raise ValueError("top_k must be a positive integer")
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(
|
||||
f"""SELECT meta, text, ts_rank(to_tsvector(coalesce(text, '')), plainto_tsquery(%s)) AS score
|
||||
FROM {self.table_name}
|
||||
WHERE to_tsvector(text) @@ plainto_tsquery(%s)
|
||||
ORDER BY score DESC
|
||||
LIMIT {top_k}""",
|
||||
# f"'{query}'" is required in order to account for whitespace in query
|
||||
(f"'{query}'", f"'{query}'"),
|
||||
)
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
where_clause = ""
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
where_clause = f" AND metadata->>'document_id' in ({document_ids}) "
|
||||
if self.pg_bigm:
|
||||
cur.execute("SET pg_bigm.similarity_limit TO 0.000001")
|
||||
cur.execute(
|
||||
f"""SELECT meta, text, bigm_similarity(unistr(%s), coalesce(text, '')) AS score
|
||||
FROM {self.table_name}
|
||||
WHERE text =%% unistr(%s)
|
||||
{where_clause}
|
||||
ORDER BY score DESC
|
||||
LIMIT {top_k}""",
|
||||
# f"'{query}'" is required in order to account for whitespace in query
|
||||
(f"'{query}'", f"'{query}'"),
|
||||
)
|
||||
else:
|
||||
cur.execute(
|
||||
f"""SELECT meta, text, ts_rank(to_tsvector(coalesce(text, '')), plainto_tsquery(%s)) AS score
|
||||
FROM {self.table_name}
|
||||
WHERE to_tsvector(text) @@ plainto_tsquery(%s)
|
||||
{where_clause}
|
||||
ORDER BY score DESC
|
||||
LIMIT {top_k}""",
|
||||
# f"'{query}'" is required in order to account for whitespace in query
|
||||
(f"'{query}'", f"'{query}'"),
|
||||
)
|
||||
|
||||
docs = []
|
||||
|
||||
@ -214,6 +258,9 @@ class PGVector(BaseVector):
|
||||
# ref: https://github.com/pgvector/pgvector?tab=readme-ov-file#indexing
|
||||
if dimension <= 2000:
|
||||
cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name))
|
||||
if self.pg_bigm:
|
||||
cur.execute("CREATE EXTENSION IF NOT EXISTS pg_bigm")
|
||||
cur.execute(SQL_CREATE_INDEX_PG_BIGM.format(table_name=self.table_name))
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
||||
|
||||
@ -237,5 +284,6 @@ class PGVectorFactory(AbstractVectorFactory):
|
||||
database=dify_config.PGVECTOR_DATABASE or "postgres",
|
||||
min_connection=dify_config.PGVECTOR_MIN_CONNECTION,
|
||||
max_connection=dify_config.PGVECTOR_MAX_CONNECTION,
|
||||
pg_bigm=dify_config.PGVECTOR_PG_BIGM,
|
||||
),
|
||||
)
|
||||
|
||||
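The new pg_bigm path above swaps ts_rank for bigram similarity when PGVECTOR_PG_BIGM is enabled. A minimal sketch of the query it issues, assuming psycopg2, the pg_bigm extension installed, and a hypothetical embedding table:

import psycopg2

conn = psycopg2.connect(host="localhost", port=5432, user="postgres", password="difyai123456", dbname="dify")
with conn.cursor() as cur:
    cur.execute("SET pg_bigm.similarity_limit TO 0.000001")
    # "=%" is pg_bigm's similarity operator; it is written "=%%" so psycopg2 does not read it as a placeholder.
    cur.execute(
        """SELECT meta, text, bigm_similarity(unistr(%s), coalesce(text, '')) AS score
           FROM embedding_demo
           WHERE text =%% unistr(%s)
           ORDER BY score DESC
           LIMIT 4""",
        ("'dify'", "'dify'"),
    )
    print(cur.fetchall())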
@ -286,27 +286,26 @@ class QdrantVector(BaseVector):
|
||||
from qdrant_client.http import models
|
||||
from qdrant_client.http.exceptions import UnexpectedResponse
|
||||
|
||||
for node_id in ids:
|
||||
try:
|
||||
filter = models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key="metadata.doc_id",
|
||||
match=models.MatchValue(value=node_id),
|
||||
),
|
||||
],
|
||||
)
|
||||
self._client.delete(
|
||||
collection_name=self._collection_name,
|
||||
points_selector=FilterSelector(filter=filter),
|
||||
)
|
||||
except UnexpectedResponse as e:
|
||||
# Collection does not exist, so return
|
||||
if e.status_code == 404:
|
||||
return
|
||||
# Some other error occurred, so re-raise the exception
|
||||
else:
|
||||
raise e
|
||||
try:
|
||||
filter = models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key="metadata.doc_id",
|
||||
match=models.MatchAny(any=ids),
|
||||
),
|
||||
],
|
||||
)
|
||||
self._client.delete(
|
||||
collection_name=self._collection_name,
|
||||
points_selector=FilterSelector(filter=filter),
|
||||
)
|
||||
except UnexpectedResponse as e:
|
||||
# Collection does not exist, so return
|
||||
if e.status_code == 404:
|
||||
return
|
||||
# Some other error occurred, so re-raise the exception
|
||||
else:
|
||||
raise e
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
all_collection_name = []
|
||||
@ -331,6 +330,15 @@ class QdrantVector(BaseVector):
|
||||
),
|
||||
],
|
||||
)
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
if document_ids_filter:
|
||||
if filter.must:
|
||||
filter.must.append(
|
||||
models.FieldCondition(
|
||||
key="metadata.document_id",
|
||||
match=models.MatchAny(any=document_ids_filter),
|
||||
)
|
||||
)
|
||||
results = self._client.search(
|
||||
collection_name=self._collection_name,
|
||||
query_vector=query_vector,
|
||||
@ -377,6 +385,15 @@ class QdrantVector(BaseVector):
|
||||
),
|
||||
]
|
||||
)
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
if document_ids_filter:
|
||||
if scroll_filter.must:
|
||||
scroll_filter.must.append(
|
||||
models.FieldCondition(
|
||||
key="metadata.document_id",
|
||||
match=models.MatchAny(any=document_ids_filter),
|
||||
)
|
||||
)
|
||||
response = self._client.scroll(
|
||||
collection_name=self._collection_name,
|
||||
scroll_filter=scroll_filter,
|
||||
|
||||
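The Qdrant change above replaces the per-id delete loop with a single MatchAny filter over the whole batch. A sketch of the equivalent client call, assuming qdrant-client and a hypothetical collection name:

from qdrant_client import QdrantClient
from qdrant_client.http import models

client = QdrantClient(url="http://localhost:6333")
# One delete request for the whole batch of node ids, instead of one request per id.
client.delete(
    collection_name="dataset_embeddings",
    points_selector=models.FilterSelector(
        filter=models.Filter(
            must=[
                models.FieldCondition(
                    key="metadata.doc_id",
                    match=models.MatchAny(any=["node-1", "node-2", "node-3"]),
                )
            ]
        )
    ),
)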
@ -223,8 +223,12 @@ class RelytVector(BaseVector):
|
||||
return len(result) > 0
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
filter = kwargs.get("filter", {})
|
||||
if document_ids_filter:
|
||||
filter["document_id"] = document_ids_filter
|
||||
results = self.similarity_search_with_score_by_vector(
|
||||
k=int(kwargs.get("top_k", 4)), embedding=query_vector, filter=kwargs.get("filter")
|
||||
k=int(kwargs.get("top_k", 4)), embedding=query_vector, filter=filter
|
||||
)
|
||||
|
||||
# Organize results.
|
||||
@ -246,9 +250,9 @@ class RelytVector(BaseVector):
|
||||
filter_condition = ""
|
||||
if filter is not None:
|
||||
conditions = [
|
||||
f"metadata->>{key!r} in ({', '.join(map(repr, value))})"
|
||||
f"metadata->>'{key!r}' in ({', '.join(map(repr, value))})"
|
||||
if len(value) > 1
|
||||
else f"metadata->>{key!r} = {value[0]!r}"
|
||||
else f"metadata->>'{key!r}' = {value[0]!r}"
|
||||
for key, value in filter.items()
|
||||
]
|
||||
filter_condition = f"WHERE {' AND '.join(conditions)}"
|
||||
|
||||
@ -145,11 +145,16 @@ class TencentVector(BaseVector):
|
||||
self._db.collection(self._collection_name).delete(document_ids=ids)
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str) -> None:
|
||||
self._db.collection(self._collection_name).delete(filter=Filter(Filter.In(key, [value])))
|
||||
self._db.collection(self._collection_name).delete(filter=Filter(Filter.In(f"metadata.{key}", [value])))
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
filter = None
|
||||
if document_ids_filter:
|
||||
filter = Filter(Filter.In("metadata.document_id", document_ids_filter))
|
||||
res = self._db.collection(self._collection_name).search(
|
||||
vectors=[query_vector],
|
||||
filter=filter,
|
||||
params=document.HNSWSearchParams(ef=kwargs.get("ef", 10)),
|
||||
retrieve_vector=False,
|
||||
limit=kwargs.get("top_k", 4),
|
||||
|
||||
@ -326,6 +326,18 @@ class TidbOnQdrantVector(BaseVector):
|
||||
),
|
||||
],
|
||||
)
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
if document_ids_filter:
|
||||
should_conditions = []
|
||||
for document_id_filter in document_ids_filter:
|
||||
should_conditions.append(
|
||||
models.FieldCondition(
|
||||
key="metadata.document_id",
|
||||
match=models.MatchValue(value=document_id_filter),
|
||||
)
|
||||
)
|
||||
if should_conditions:
|
||||
filter.should = should_conditions # type: ignore
|
||||
results = self._client.search(
|
||||
collection_name=self._collection_name,
|
||||
query_vector=query_vector,
|
||||
@ -368,6 +380,18 @@ class TidbOnQdrantVector(BaseVector):
|
||||
)
|
||||
]
|
||||
)
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
if document_ids_filter:
|
||||
should_conditions = []
|
||||
for document_id_filter in document_ids_filter:
|
||||
should_conditions.append(
|
||||
models.FieldCondition(
|
||||
key="metadata.document_id",
|
||||
match=models.MatchValue(value=document_id_filter),
|
||||
)
|
||||
)
|
||||
if should_conditions:
|
||||
scroll_filter.should = should_conditions # type: ignore
|
||||
response = self._client.scroll(
|
||||
collection_name=self._collection_name,
|
||||
scroll_filter=scroll_filter,
|
||||
|
||||
@ -196,6 +196,11 @@ class TiDBVector(BaseVector):
|
||||
|
||||
docs = []
|
||||
tidb_dist_func = self._get_distance_func()
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
where_clause = ""
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
where_clause = f" WHERE meta->>'$.document_id' in ({document_ids}) "
|
||||
|
||||
with Session(self._engine) as session:
|
||||
select_statement = sql_text(f"""
|
||||
@ -206,6 +211,7 @@ class TiDBVector(BaseVector):
|
||||
text,
|
||||
{tidb_dist_func}(vector, :query_vector_str) AS distance
|
||||
FROM {self._collection_name}
|
||||
{where_clause}
|
||||
ORDER BY distance ASC
|
||||
LIMIT :top_k
|
||||
) t
|
||||
|
||||
@ -88,7 +88,20 @@ class UpstashVector(BaseVector):
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
top_k = kwargs.get("top_k", 4)
|
||||
result = self.index.query(vector=query_vector, top_k=top_k, include_metadata=True, include_data=True)
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
filter = f"document_id in ({document_ids})"
|
||||
else:
|
||||
filter = ""
|
||||
result = self.index.query(
|
||||
vector=query_vector,
|
||||
top_k=top_k,
|
||||
include_metadata=True,
|
||||
include_data=True,
|
||||
include_vectors=False,
|
||||
filter=filter,
|
||||
)
|
||||
docs = []
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
for record in result:
|
||||
|
||||
@@ -148,6 +148,10 @@ class Vector:
                from core.rag.datasource.vdb.oceanbase.oceanbase_vector import OceanBaseVectorFactory

                return OceanBaseVectorFactory
            case VectorType.OPENGAUSS:
                from core.rag.datasource.vdb.opengauss.opengauss import OpenGaussFactory

                return OpenGaussFactory
            case _:
                raise ValueError(f"Vector store {vector_type} is not supported.")


@@ -24,3 +24,4 @@ class VectorType(StrEnum):
    UPSTASH = "upstash"
    TIDB_ON_QDRANT = "tidb_on_qdrant"
    OCEANBASE = "oceanbase"
    OPENGAUSS = "opengauss"

@@ -177,7 +177,11 @@ class VikingDBVector(BaseVector):
            query_vector, limit=kwargs.get("top_k", 4)
        )
        score_threshold = float(kwargs.get("score_threshold") or 0.0)
        return self._get_search_res(results, score_threshold)
        docs = self._get_search_res(results, score_threshold)
        document_ids_filter = kwargs.get("document_ids_filter")
        if document_ids_filter:
            docs = [doc for doc in docs if doc.metadata.get("document_id") in document_ids_filter]
        return docs

    def _get_search_res(self, results, score_threshold) -> list[Document]:
        if len(results) == 0:
@ -187,8 +187,10 @@ class WeaviateVector(BaseVector):
|
||||
query_obj = self._client.query.get(collection_name, properties)
|
||||
|
||||
vector = {"vector": query_vector}
|
||||
if kwargs.get("where_filter"):
|
||||
query_obj = query_obj.with_where(kwargs.get("where_filter"))
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
if document_ids_filter:
|
||||
where_filter = {"operator": "ContainsAny", "path": ["document_id"], "valueTextArray": document_ids_filter}
|
||||
query_obj = query_obj.with_where(where_filter)
|
||||
result = (
|
||||
query_obj.with_near_vector(vector)
|
||||
.with_limit(kwargs.get("top_k", 4))
|
||||
@ -233,8 +235,10 @@ class WeaviateVector(BaseVector):
|
||||
if kwargs.get("search_distance"):
|
||||
content["certainty"] = kwargs.get("search_distance")
|
||||
query_obj = self._client.query.get(collection_name, properties)
|
||||
if kwargs.get("where_filter"):
|
||||
query_obj = query_obj.with_where(kwargs.get("where_filter"))
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
if document_ids_filter:
|
||||
where_filter = {"operator": "ContainsAny", "path": ["document_id"], "valueTextArray": document_ids_filter}
|
||||
query_obj = query_obj.with_where(where_filter)
|
||||
query_obj = query_obj.with_additional(["vector"])
|
||||
properties = ["text"]
|
||||
result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get("top_k", 4)).do()
|
||||
|
||||
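Weaviate applies the same restriction through a ContainsAny where filter, as added above for both vector and BM25 search. A usage sketch in the v3 weaviate-client syntax this code uses, with a hypothetical class name:

import weaviate

client = weaviate.Client("http://localhost:8080")
where_filter = {
    "operator": "ContainsAny",
    "path": ["document_id"],
    "valueTextArray": ["doc-1", "doc-2"],
}
result = (
    client.query.get("DatasetEmbedding", ["text", "document_id"])
    .with_where(where_filter)
    .with_near_vector({"vector": [0.1, 0.2, 0.3]})
    .with_limit(4)
    .do()
)
print(result)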
api/core/rag/entities/metadata_entities.py (new file, 45 lines)
@@ -0,0 +1,45 @@
from collections.abc import Sequence
from typing import Literal, Optional

from pydantic import BaseModel, Field

SupportedComparisonOperator = Literal[
    # for string or array
    "contains",
    "not contains",
    "start with",
    "end with",
    "is",
    "is not",
    "empty",
    "not empty",
    # for number
    "=",
    "≠",
    ">",
    "<",
    "≥",
    "≤",
    # for time
    "before",
    "after",
]


class Condition(BaseModel):
    """
    Condition detail
    """

    name: str
    comparison_operator: SupportedComparisonOperator
    value: str | Sequence[str] | None | int | float = None


class MetadataCondition(BaseModel):
    """
    Metadata Condition.
    """

    logical_operator: Optional[Literal["and", "or"]] = "and"
    conditions: Optional[list[Condition]] = Field(default=None, deprecated=True)
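A short sketch of how these models compose (field values are illustrative; the conditions field is marked deprecated but still accepted):

from core.rag.entities.metadata_entities import Condition, MetadataCondition

metadata_condition = MetadataCondition(
    logical_operator="and",
    conditions=[
        Condition(name="uploader", comparison_operator="is", value="alice"),
        Condition(name="upload_date", comparison_operator="after", value=1700000000),
    ],
)
print(metadata_condition.model_dump())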
api/core/rag/index_processor/constant/built_in_field.py (new file, 15 lines)
@@ -0,0 +1,15 @@
from enum import Enum


class BuiltInField(str, Enum):
    document_name = "document_name"
    uploader = "uploader"
    upload_date = "upload_date"
    last_update_date = "last_update_date"
    source = "source"


class MetadataDataSource(Enum):
    upload_file = "file_upload"
    website_crawl = "website"
    notion_import = "notion"
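The enums map a document's origin and built-in attributes to stable metadata keys; for example:

from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource

print(BuiltInField.document_name.value)      # "document_name"
print(MetadataDataSource.upload_file.value)  # "file_upload"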
@ -1,35 +1,61 @@
|
||||
import json
|
||||
import math
|
||||
import re
|
||||
import threading
|
||||
from collections import Counter
|
||||
from typing import Any, Optional, cast
|
||||
from collections import Counter, defaultdict
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, Optional, Union, cast
|
||||
|
||||
from flask import Flask, current_app
|
||||
from sqlalchemy import Integer, and_, or_, text
|
||||
from sqlalchemy import cast as sqlalchemy_cast
|
||||
|
||||
from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity
|
||||
from core.app.app_config.entities import (
|
||||
DatasetEntity,
|
||||
DatasetRetrieveConfigEntity,
|
||||
MetadataFilteringCondition,
|
||||
ModelConfig,
|
||||
)
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.entities.agent_entities import PlanningStrategy
|
||||
from core.entities.model_entities import ModelStatus
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_runtime.entities.message_entities import PromptMessageTool
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
|
||||
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.ops.utils import measure_time
|
||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
|
||||
from core.prompt.simple_prompt_transform import ModelMode
|
||||
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
|
||||
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.entities.context_entities import DocumentContext
|
||||
from core.rag.entities.metadata_entities import Condition, MetadataCondition
|
||||
from core.rag.index_processor.constant.index_type import IndexType
|
||||
from core.rag.models.document import Document
|
||||
from core.rag.rerank.rerank_type import RerankMode
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
|
||||
from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
|
||||
from core.rag.retrieval.template_prompts import (
|
||||
METADATA_FILTER_ASSISTANT_PROMPT_1,
|
||||
METADATA_FILTER_ASSISTANT_PROMPT_2,
|
||||
METADATA_FILTER_COMPLETION_PROMPT,
|
||||
METADATA_FILTER_SYSTEM_PROMPT,
|
||||
METADATA_FILTER_USER_PROMPT_1,
|
||||
METADATA_FILTER_USER_PROMPT_2,
|
||||
METADATA_FILTER_USER_PROMPT_3,
|
||||
)
|
||||
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import ChildChunk, Dataset, DatasetQuery, DocumentSegment
|
||||
from libs.json_in_md_parser import parse_and_check_json_markdown
|
||||
from models.dataset import ChildChunk, Dataset, DatasetMetadata, DatasetQuery, DocumentSegment
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
|
||||
@ -59,6 +85,7 @@ class DatasetRetrieval:
|
||||
hit_callback: DatasetIndexToolCallbackHandler,
|
||||
message_id: str,
|
||||
memory: Optional[TokenBufferMemory] = None,
|
||||
inputs: Optional[Mapping[str, Any]] = None,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Retrieve dataset.
|
||||
@ -116,6 +143,22 @@ class DatasetRetrieval:
|
||||
continue
|
||||
|
||||
available_datasets.append(dataset)
|
||||
if inputs:
|
||||
inputs = {key: str(value) for key, value in inputs.items()}
|
||||
else:
|
||||
inputs = {}
|
||||
available_datasets_ids = [dataset.id for dataset in available_datasets]
|
||||
metadata_filter_document_ids, metadata_condition = self._get_metadata_filter_condition(
|
||||
available_datasets_ids,
|
||||
query,
|
||||
tenant_id,
|
||||
user_id,
|
||||
retrieve_config.metadata_filtering_mode, # type: ignore
|
||||
retrieve_config.metadata_model_config, # type: ignore
|
||||
retrieve_config.metadata_filtering_conditions,
|
||||
inputs,
|
||||
)
|
||||
|
||||
all_documents = []
|
||||
user_from = "account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user"
|
||||
if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
|
||||
@ -130,6 +173,8 @@ class DatasetRetrieval:
|
||||
model_config,
|
||||
planning_strategy,
|
||||
message_id,
|
||||
metadata_filter_document_ids,
|
||||
metadata_condition,
|
||||
)
|
||||
elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
|
||||
all_documents = self.multiple_retrieve(
|
||||
@ -146,6 +191,8 @@ class DatasetRetrieval:
|
||||
retrieve_config.weights,
|
||||
retrieve_config.reranking_enabled or True,
|
||||
message_id,
|
||||
metadata_filter_document_ids,
|
||||
metadata_condition,
|
||||
)
|
||||
|
||||
dify_documents = [item for item in all_documents if item.provider == "dify"]
|
||||
@ -239,6 +286,8 @@ class DatasetRetrieval:
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
planning_strategy: PlanningStrategy,
|
||||
message_id: Optional[str] = None,
|
||||
metadata_filter_document_ids: Optional[dict[str, list[str]]] = None,
|
||||
metadata_condition: Optional[MetadataCondition] = None,
|
||||
):
|
||||
tools = []
|
||||
for dataset in available_datasets:
|
||||
@ -279,6 +328,7 @@ class DatasetRetrieval:
|
||||
dataset_id=dataset_id,
|
||||
query=query,
|
||||
external_retrieval_parameters=dataset.retrieval_model,
|
||||
metadata_condition=metadata_condition,
|
||||
)
|
||||
for external_document in external_documents:
|
||||
document = Document(
|
||||
@ -293,6 +343,15 @@ class DatasetRetrieval:
|
||||
document.metadata["dataset_name"] = dataset.name
|
||||
results.append(document)
|
||||
else:
|
||||
if metadata_condition and not metadata_filter_document_ids:
|
||||
return []
|
||||
document_ids_filter = None
|
||||
if metadata_filter_document_ids:
|
||||
document_ids = metadata_filter_document_ids.get(dataset.id, [])
|
||||
if document_ids:
|
||||
document_ids_filter = document_ids
|
||||
else:
|
||||
return []
|
||||
retrieval_model_config = dataset.retrieval_model or default_retrieval_model
|
||||
|
||||
# get top k
|
||||
@ -324,6 +383,7 @@ class DatasetRetrieval:
|
||||
reranking_model=reranking_model,
|
||||
reranking_mode=retrieval_model_config.get("reranking_mode", "reranking_model"),
|
||||
weights=retrieval_model_config.get("weights", None),
|
||||
document_ids_filter=document_ids_filter,
|
||||
)
|
||||
self._on_query(query, [dataset_id], app_id, user_from, user_id)
|
||||
|
||||
@ -348,6 +408,8 @@ class DatasetRetrieval:
|
||||
weights: Optional[dict[str, Any]] = None,
|
||||
reranking_enable: bool = True,
|
||||
message_id: Optional[str] = None,
|
||||
metadata_filter_document_ids: Optional[dict[str, list[str]]] = None,
|
||||
metadata_condition: Optional[MetadataCondition] = None,
|
||||
):
|
||||
if not available_datasets:
|
||||
return []
|
||||
@ -387,6 +449,16 @@ class DatasetRetrieval:
|
||||
|
||||
for dataset in available_datasets:
|
||||
index_type = dataset.indexing_technique
|
||||
document_ids_filter = None
|
||||
if dataset.provider != "external":
|
||||
if metadata_condition and not metadata_filter_document_ids:
|
||||
continue
|
||||
if metadata_filter_document_ids:
|
||||
document_ids = metadata_filter_document_ids.get(dataset.id, [])
|
||||
if document_ids:
|
||||
document_ids_filter = document_ids
|
||||
else:
|
||||
continue
|
||||
retrieval_thread = threading.Thread(
|
||||
target=self._retriever,
|
||||
kwargs={
|
||||
@ -395,6 +467,8 @@ class DatasetRetrieval:
|
||||
"query": query,
|
||||
"top_k": top_k,
|
||||
"all_documents": all_documents,
|
||||
"document_ids_filter": document_ids_filter,
|
||||
"metadata_condition": metadata_condition,
|
||||
},
|
||||
)
|
||||
threads.append(retrieval_thread)
|
||||
@ -433,30 +507,33 @@ class DatasetRetrieval:
|
||||
dataset_document = DatasetDocument.query.filter(
|
||||
DatasetDocument.id == document.metadata["document_id"]
|
||||
).first()
|
||||
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
|
||||
child_chunk = ChildChunk.query.filter(
|
||||
ChildChunk.index_node_id == document.metadata["doc_id"],
|
||||
ChildChunk.dataset_id == dataset_document.dataset_id,
|
||||
ChildChunk.document_id == dataset_document.id,
|
||||
).first()
|
||||
if child_chunk:
|
||||
segment = DocumentSegment.query.filter(DocumentSegment.id == child_chunk.segment_id).update(
|
||||
if dataset_document:
|
||||
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
|
||||
child_chunk = ChildChunk.query.filter(
|
||||
ChildChunk.index_node_id == document.metadata["doc_id"],
|
||||
ChildChunk.dataset_id == dataset_document.dataset_id,
|
||||
ChildChunk.document_id == dataset_document.id,
|
||||
).first()
|
||||
if child_chunk:
|
||||
segment = DocumentSegment.query.filter(DocumentSegment.id == child_chunk.segment_id).update(
|
||||
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False
|
||||
)
|
||||
db.session.commit()
|
||||
else:
|
||||
query = db.session.query(DocumentSegment).filter(
|
||||
DocumentSegment.index_node_id == document.metadata["doc_id"]
|
||||
)
|
||||
|
||||
# if 'dataset_id' in document.metadata:
|
||||
if "dataset_id" in document.metadata:
|
||||
query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"])
|
||||
|
||||
# add hit count to document segment
|
||||
query.update(
|
||||
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False
|
||||
)
|
||||
db.session.commit()
|
||||
else:
|
||||
query = db.session.query(DocumentSegment).filter(
|
||||
DocumentSegment.index_node_id == document.metadata["doc_id"]
|
||||
)
|
||||
|
||||
# if 'dataset_id' in document.metadata:
|
||||
if "dataset_id" in document.metadata:
|
||||
query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"])
|
||||
|
||||
# add hit count to document segment
|
||||
query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False)
|
||||
|
||||
db.session.commit()
|
||||
db.session.commit()
|
||||
|
||||
# get tracing instance
|
||||
trace_manager: TraceQueueManager | None = (
|
||||
@ -490,7 +567,16 @@ class DatasetRetrieval:
|
||||
db.session.add_all(dataset_queries)
|
||||
db.session.commit()
|
||||
|
||||
def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list):
|
||||
def _retriever(
|
||||
self,
|
||||
flask_app: Flask,
|
||||
dataset_id: str,
|
||||
query: str,
|
||||
top_k: int,
|
||||
all_documents: list,
|
||||
document_ids_filter: Optional[list[str]] = None,
|
||||
metadata_condition: Optional[MetadataCondition] = None,
|
||||
):
|
||||
with flask_app.app_context():
|
||||
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
||||
|
||||
@ -503,6 +589,7 @@ class DatasetRetrieval:
|
||||
dataset_id=dataset_id,
|
||||
query=query,
|
||||
external_retrieval_parameters=dataset.retrieval_model,
|
||||
metadata_condition=metadata_condition,
|
||||
)
|
||||
for external_document in external_documents:
|
||||
document = Document(
|
||||
@ -543,6 +630,7 @@ class DatasetRetrieval:
|
||||
else None,
|
||||
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
|
||||
weights=retrieval_model.get("weights", None),
|
||||
document_ids_filter=document_ids_filter,
|
||||
)
|
||||
|
||||
all_documents.extend(documents)
|
||||
@ -730,3 +818,340 @@ class DatasetRetrieval:
|
||||
filter_documents, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True
|
||||
)
|
||||
return filter_documents[:top_k] if top_k else filter_documents
|
||||
|
||||
def _get_metadata_filter_condition(
|
||||
self,
|
||||
dataset_ids: list,
|
||||
query: str,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
metadata_filtering_mode: str,
|
||||
metadata_model_config: ModelConfig,
|
||||
metadata_filtering_conditions: Optional[MetadataFilteringCondition],
|
||||
inputs: dict,
|
||||
) -> tuple[Optional[dict[str, list[str]]], Optional[MetadataCondition]]:
|
||||
document_query = db.session.query(DatasetDocument).filter(
|
||||
DatasetDocument.dataset_id.in_(dataset_ids),
|
||||
DatasetDocument.indexing_status == "completed",
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
)
|
||||
filters = [] # type: ignore
|
||||
metadata_condition = None
|
||||
if metadata_filtering_mode == "disabled":
|
||||
return None, None
|
||||
elif metadata_filtering_mode == "automatic":
|
||||
automatic_metadata_filters = self._automatic_metadata_filter_func(
|
||||
dataset_ids, query, tenant_id, user_id, metadata_model_config
|
||||
)
|
||||
if automatic_metadata_filters:
|
||||
conditions = []
|
||||
for filter in automatic_metadata_filters:
|
||||
self._process_metadata_filter_func(
|
||||
filter.get("condition"), # type: ignore
|
||||
filter.get("metadata_name"), # type: ignore
|
||||
filter.get("value"),
|
||||
filters, # type: ignore
|
||||
)
|
||||
conditions.append(
|
||||
Condition(
|
||||
name=filter.get("metadata_name"), # type: ignore
|
||||
comparison_operator=filter.get("condition"), # type: ignore
|
||||
value=filter.get("value"),
|
||||
)
|
||||
)
|
||||
metadata_condition = MetadataCondition(
|
||||
logical_operator=metadata_filtering_conditions.logical_operator, # type: ignore
|
||||
conditions=conditions,
|
||||
)
|
||||
elif metadata_filtering_mode == "manual":
|
||||
if metadata_filtering_conditions:
|
||||
metadata_condition = MetadataCondition(**metadata_filtering_conditions.model_dump())
|
||||
for condition in metadata_filtering_conditions.conditions: # type: ignore
|
||||
metadata_name = condition.name
|
||||
expected_value = condition.value
|
||||
if expected_value or condition.comparison_operator in ("empty", "not empty"):
|
||||
if isinstance(expected_value, str):
|
||||
expected_value = self._replace_metadata_filter_value(expected_value, inputs)
|
||||
filters = self._process_metadata_filter_func(
|
||||
condition.comparison_operator, metadata_name, expected_value, filters
|
||||
)
|
||||
else:
|
||||
raise ValueError("Invalid metadata filtering mode")
|
||||
if filters:
|
||||
if metadata_filtering_conditions.logical_operator == "or": # type: ignore
|
||||
document_query = document_query.filter(or_(*filters))
|
||||
else:
|
||||
document_query = document_query.filter(and_(*filters))
|
||||
documents = document_query.all()
|
||||
# group by dataset_id
|
||||
metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore
|
||||
for document in documents:
|
||||
metadata_filter_document_ids[document.dataset_id].append(document.id) # type: ignore
|
||||
return metadata_filter_document_ids, metadata_condition
|
||||
|
||||
def _replace_metadata_filter_value(self, text: str, inputs: dict) -> str:
|
||||
def replacer(match):
|
||||
key = match.group(1)
|
||||
return str(inputs.get(key, f"{{{{{key}}}}}"))
|
||||
|
||||
pattern = re.compile(r"\{\{(\w+)\}\}")
|
||||
return pattern.sub(replacer, text)
|
||||
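The helper above substitutes {{variable}} placeholders in manual filter values with the conversation inputs, leaving unknown keys untouched. A self-contained sketch of the same behaviour:

import re

def replace_metadata_filter_value(text: str, inputs: dict) -> str:
    # Substitute {{variable}} placeholders with the provided inputs; unknown keys stay as-is.
    def replacer(match):
        key = match.group(1)
        return str(inputs.get(key, f"{{{{{key}}}}}"))

    return re.compile(r"\{\{(\w+)\}\}").sub(replacer, text)

print(replace_metadata_filter_value("uploaded by {{author}} in {{year}}", {"author": "alice"}))
# -> uploaded by alice in {{year}}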
def _automatic_metadata_filter_func(
|
||||
self, dataset_ids: list, query: str, tenant_id: str, user_id: str, metadata_model_config: ModelConfig
|
||||
) -> Optional[list[dict[str, Any]]]:
|
||||
# get all metadata field
|
||||
metadata_fields = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
|
||||
all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields]
|
||||
# get metadata model config
|
||||
if metadata_model_config is None:
|
||||
raise ValueError("metadata_model_config is required")
|
||||
# get metadata model instance
|
||||
# fetch model config
|
||||
model_instance, model_config = self._fetch_model_config(tenant_id, metadata_model_config)
|
||||
|
||||
# fetch prompt messages
|
||||
prompt_messages, stop = self._get_prompt_template(
|
||||
model_config=model_config,
|
||||
mode=metadata_model_config.mode,
|
||||
metadata_fields=all_metadata_fields,
|
||||
query=query or "",
|
||||
)
|
||||
|
||||
result_text = ""
|
||||
try:
|
||||
# handle invoke result
|
||||
invoke_result = cast(
|
||||
Generator[LLMResult, None, None],
|
||||
model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_config.parameters,
|
||||
stop=stop,
|
||||
stream=True,
|
||||
user=user_id,
|
||||
),
|
||||
)
|
||||
|
||||
# handle invoke result
|
||||
result_text, usage = self._handle_invoke_result(invoke_result=invoke_result)
|
||||
|
||||
result_text_json = parse_and_check_json_markdown(result_text, [])
|
||||
automatic_metadata_filters = []
|
||||
if "metadata_map" in result_text_json:
|
||||
metadata_map = result_text_json["metadata_map"]
|
||||
for item in metadata_map:
|
||||
if item.get("metadata_field_name") in all_metadata_fields:
|
||||
automatic_metadata_filters.append(
|
||||
{
|
||||
"metadata_name": item.get("metadata_field_name"),
|
||||
"value": item.get("metadata_field_value"),
|
||||
"condition": item.get("comparison_operator"),
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
return None
|
||||
return automatic_metadata_filters
|
||||
|
||||
def _process_metadata_filter_func(self, condition: str, metadata_name: str, value: Optional[Any], filters: list):
|
||||
match condition:
|
||||
case "contains":
|
||||
filters.append(
|
||||
(text("documents.doc_metadata ->> :key LIKE :value")).params(key=metadata_name, value=f"%{value}%")
|
||||
)
|
||||
case "not contains":
|
||||
filters.append(
|
||||
(text("documents.doc_metadata ->> :key NOT LIKE :value")).params(
|
||||
key=metadata_name, value=f"%{value}%"
|
||||
)
|
||||
)
|
||||
case "start with":
|
||||
filters.append(
|
||||
(text("documents.doc_metadata ->> :key LIKE :value")).params(key=metadata_name, value=f"{value}%")
|
||||
)
|
||||
|
||||
case "end with":
|
||||
filters.append(
|
||||
(text("documents.doc_metadata ->> :key LIKE :value")).params(key=metadata_name, value=f"%{value}")
|
||||
)
|
||||
case "is" | "=":
|
||||
if isinstance(value, str):
|
||||
filters.append(DatasetDocument.doc_metadata[metadata_name] == f'"{value}"')
|
||||
else:
|
||||
filters.append(
|
||||
sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) == value
|
||||
)
|
||||
case "is not" | "≠":
|
||||
if isinstance(value, str):
|
||||
filters.append(DatasetDocument.doc_metadata[metadata_name] != f'"{value}"')
|
||||
else:
|
||||
filters.append(
|
||||
sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) != value
|
||||
)
|
||||
case "empty":
|
||||
filters.append(DatasetDocument.doc_metadata[metadata_name].is_(None))
|
||||
case "not empty":
|
||||
filters.append(DatasetDocument.doc_metadata[metadata_name].isnot(None))
|
||||
case "before" | "<":
|
||||
filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) < value)
|
||||
case "after" | ">":
|
||||
filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) > value)
|
||||
case "≤" | ">=":
|
||||
filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) <= value)
|
||||
case "≥" | ">=":
|
||||
filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) >= value)
|
||||
case _:
|
||||
pass
|
||||
return filters
|
||||
|
||||
def _fetch_model_config(
|
||||
self, tenant_id: str, model: ModelConfig
|
||||
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
|
||||
"""
|
||||
Fetch model config
|
||||
:param tenant_id: tenant id
:param model: model config
:return: model instance and model config entity
|
||||
"""
|
||||
if model is None:
|
||||
raise ValueError("single_retrieval_config is required")
|
||||
model_name = model.name
|
||||
provider_name = model.provider
|
||||
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name
|
||||
)
|
||||
|
||||
provider_model_bundle = model_instance.provider_model_bundle
|
||||
model_type_instance = model_instance.model_type_instance
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
|
||||
model_credentials = model_instance.credentials
|
||||
|
||||
# check model
|
||||
provider_model = provider_model_bundle.configuration.get_provider_model(
|
||||
model=model_name, model_type=ModelType.LLM
|
||||
)
|
||||
|
||||
if provider_model is None:
|
||||
raise ValueError(f"Model {model_name} not exist.")
|
||||
|
||||
if provider_model.status == ModelStatus.NO_CONFIGURE:
|
||||
raise ValueError(f"Model {model_name} credentials is not initialized.")
|
||||
elif provider_model.status == ModelStatus.NO_PERMISSION:
|
||||
raise ValueError(f"Dify Hosted OpenAI {model_name} currently not support.")
|
||||
elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
|
||||
raise ValueError(f"Model provider {provider_name} quota exceeded.")
|
||||
|
||||
# model config
|
||||
completion_params = model.completion_params
|
||||
stop = []
|
||||
if "stop" in completion_params:
|
||||
stop = completion_params["stop"]
|
||||
del completion_params["stop"]
|
||||
|
||||
# get model mode
|
||||
model_mode = model.mode
|
||||
if not model_mode:
|
||||
raise ValueError("LLM mode is required.")
|
||||
|
||||
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
|
||||
|
||||
if not model_schema:
|
||||
raise ValueError(f"Model {model_name} not exist.")
|
||||
|
||||
return model_instance, ModelConfigWithCredentialsEntity(
|
||||
provider=provider_name,
|
||||
model=model_name,
|
||||
model_schema=model_schema,
|
||||
mode=model_mode,
|
||||
provider_model_bundle=provider_model_bundle,
|
||||
credentials=model_credentials,
|
||||
parameters=completion_params,
|
||||
stop=stop,
|
||||
)
|
||||
|
||||
def _get_prompt_template(
|
||||
self, model_config: ModelConfigWithCredentialsEntity, mode: str, metadata_fields: list, query: str
|
||||
):
|
||||
model_mode = ModelMode.value_of(mode)
|
||||
input_text = query
|
||||
|
||||
prompt_template: Union[CompletionModelPromptTemplate, list[ChatModelMessage]]
|
||||
if model_mode == ModelMode.CHAT:
|
||||
prompt_template = []
|
||||
system_prompt_messages = ChatModelMessage(role=PromptMessageRole.SYSTEM, text=METADATA_FILTER_SYSTEM_PROMPT)
|
||||
prompt_template.append(system_prompt_messages)
|
||||
user_prompt_message_1 = ChatModelMessage(role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_1)
|
||||
prompt_template.append(user_prompt_message_1)
|
||||
assistant_prompt_message_1 = ChatModelMessage(
|
||||
role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_1
|
||||
)
|
||||
prompt_template.append(assistant_prompt_message_1)
|
||||
user_prompt_message_2 = ChatModelMessage(role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_2)
|
||||
prompt_template.append(user_prompt_message_2)
|
||||
assistant_prompt_message_2 = ChatModelMessage(
|
||||
role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_2
|
||||
)
|
||||
prompt_template.append(assistant_prompt_message_2)
|
||||
user_prompt_message_3 = ChatModelMessage(
|
||||
role=PromptMessageRole.USER,
|
||||
text=METADATA_FILTER_USER_PROMPT_3.format(
|
||||
input_text=input_text,
|
||||
metadata_fields=json.dumps(metadata_fields, ensure_ascii=False),
|
||||
),
|
||||
)
|
||||
prompt_template.append(user_prompt_message_3)
|
||||
elif model_mode == ModelMode.COMPLETION:
|
||||
prompt_template = CompletionModelPromptTemplate(
|
||||
text=METADATA_FILTER_COMPLETION_PROMPT.format(
|
||||
input_text=input_text,
|
||||
metadata_fields=json.dumps(metadata_fields, ensure_ascii=False),
|
||||
)
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Model mode {model_mode} not support.")
|
||||
|
||||
prompt_transform = AdvancedPromptTransform()
|
||||
prompt_messages = prompt_transform.get_prompt(
|
||||
prompt_template=prompt_template,
|
||||
inputs={},
|
||||
query=query or "",
|
||||
files=[],
|
||||
context=None,
|
||||
memory_config=None,
|
||||
memory=None,
|
||||
model_config=model_config,
|
||||
)
|
||||
stop = model_config.stop
|
||||
|
||||
return prompt_messages, stop
|
||||
|
||||
def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage]:
|
||||
"""
|
||||
Handle invoke result
|
||||
:param invoke_result: invoke result
|
||||
:return:
|
||||
"""
|
||||
model = None
|
||||
prompt_messages: list[PromptMessage] = []
|
||||
full_text = ""
|
||||
usage = None
|
||||
for result in invoke_result:
|
||||
text = result.delta.message.content
|
||||
full_text += text
|
||||
|
||||
if not model:
|
||||
model = result.model
|
||||
|
||||
if not prompt_messages:
|
||||
prompt_messages = result.prompt_messages
|
||||
|
||||
if not usage and result.delta.usage:
|
||||
usage = result.delta.usage
|
||||
|
||||
if not usage:
|
||||
usage = LLMUsage.empty_usage()
|
||||
|
||||
return full_text, usage
|
||||
|
||||
66
api/core/rag/retrieval/template_prompts.py
Normal file
@ -0,0 +1,66 @@
|
||||
METADATA_FILTER_SYSTEM_PROMPT = """
|
||||
### Job Description
|
||||
You are a text metadata extraction engine that extracts a text's metadata based on user input and sets the metadata value
|
||||
### Task
|
||||
Your task is to ONLY extract the metadata fields that exist in the input text from the provided metadata list, use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, and then return the result in JSON format under the key "metadata_map", where each item has a "metadata_field_name", a "metadata_field_value", and a "comparison_operator".
|
||||
### Format
|
||||
The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields.
|
||||
### Constraint
|
||||
DO NOT include anything other than the JSON array in your response.
|
||||
""" # noqa: E501
|
||||
|
||||
METADATA_FILTER_USER_PROMPT_1 = """
|
||||
{ "input_text": "I want to know which company’s email address test@example.com is?",
|
||||
"metadata_fields": ["filename", "email", "phone", "address"]
|
||||
}
|
||||
"""
|
||||
|
||||
METADATA_FILTER_ASSISTANT_PROMPT_1 = """
|
||||
```json
|
||||
{"metadata_map": [
|
||||
{"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="}
|
||||
]
|
||||
}
|
||||
```
|
||||
"""
|
||||
|
||||
METADATA_FILTER_USER_PROMPT_2 = """
|
||||
{"input_text": "What are the movies with a score of more than 9 in 2024?",
|
||||
"metadata_fields": ["name", "year", "rating", "country"]}
|
||||
"""
|
||||
|
||||
METADATA_FILTER_ASSISTANT_PROMPT_2 = """
|
||||
```json
|
||||
{"metadata_map": [
|
||||
{"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="},
|
||||
{"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"},
|
||||
]}
|
||||
```
|
||||
"""
|
||||
|
||||
METADATA_FILTER_USER_PROMPT_3 = """
|
||||
'{{"input_text": "{input_text}",',
|
||||
'"metadata_fields": {metadata_fields}}}'
|
||||
"""
|
||||
|
||||
METADATA_FILTER_COMPLETION_PROMPT = """
|
||||
### Job Description
|
||||
You are a text metadata extraction engine that extracts a text's metadata based on user input and sets the metadata value
|
||||
### Task
|
||||
Your task is to ONLY extract the metadata fields that exist in the input text from the provided metadata list, use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, and then return the result in JSON format under the key "metadata_map", where each item has a "metadata_field_name", a "metadata_field_value", and a "comparison_operator".
|
||||
### Format
|
||||
The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields.
|
||||
### Constraint
|
||||
DO NOT include anything other than the JSON array in your response.
|
||||
### Example
|
||||
Here is the chat example between human and assistant, inside <example></example> XML tags.
|
||||
<example>
|
||||
User:{{"input_text": ["I want to know which company’s email address test@example.com is?"], "metadata_fields": ["filename", "email", "phone", "address"]}}
|
||||
Assistant:{{"metadata_map": [{{"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="}}]}}
|
||||
User:{{"input_text": "What are the movies with a score of more than 9 in 2024?", "metadata_fields": ["name", "year", "rating", "country"]}}
|
||||
Assistant:{{"metadata_map": [{{"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="}, {{"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"}}]}}
|
||||
</example>
|
||||
### User Input
|
||||
{{"input_text" : "{input_text}", "metadata_fields" : {metadata_fields}}}
|
||||
### Assistant Output
|
||||
""" # noqa: E501
|
||||
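Both `METADATA_FILTER_USER_PROMPT_3` and `METADATA_FILTER_COMPLETION_PROMPT` are plain `str.format()` templates: the doubled braces survive formatting as literal JSON braces, and only `{input_text}` and `{metadata_fields}` are substituted, which is what `_get_prompt_template` above does. A minimal sketch of that usage, with an illustrative field list (the import path assumes the module location shown in this diff, `api/core/rag/retrieval/template_prompts.py`):

```python
import json

from core.rag.retrieval.template_prompts import METADATA_FILTER_COMPLETION_PROMPT

# Illustrative field list; in Dify these come from the dataset's metadata config.
metadata_fields = ["name", "year", "rating", "country"]

# After format(), the doubled {{ }} become literal braces, so the rendered
# prompt contains valid JSON examples for the model to imitate.
prompt = METADATA_FILTER_COMPLETION_PROMPT.format(
    input_text="What are the movies with a score of more than 9 in 2024?",
    metadata_fields=json.dumps(metadata_fields, ensure_ascii=False),
)
print(prompt)
```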
@ -76,38 +76,74 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
|
||||
|
||||
def recursive_split_text(self, text: str) -> list[str]:
|
||||
"""Split incoming text and return chunks."""
|
||||
|
||||
final_chunks = []
|
||||
# Get appropriate separator to use
|
||||
separator = self._separators[-1]
|
||||
for _s in self._separators:
|
||||
new_separators = []
|
||||
|
||||
for i, _s in enumerate(self._separators):
|
||||
if _s == "":
|
||||
separator = _s
|
||||
break
|
||||
if _s in text:
|
||||
separator = _s
|
||||
new_separators = self._separators[i + 1 :]
|
||||
break
|
||||
|
||||
# Now that we have the separator, split the text
|
||||
if separator:
|
||||
splits = text.split(separator)
|
||||
if separator == " ":
|
||||
splits = text.split()
|
||||
else:
|
||||
splits = text.split(separator)
|
||||
else:
|
||||
splits = list(text)
|
||||
# Now go merging things, recursively splitting longer texts.
|
||||
splits = [s for s in splits if (s not in {"", "\n"})]
|
||||
_good_splits = []
|
||||
_good_splits_lengths = [] # cache the lengths of the splits
|
||||
_separator = "" if self._keep_separator else separator
|
||||
s_lens = self._length_function(splits)
|
||||
for s, s_len in zip(splits, s_lens):
|
||||
if s_len < self._chunk_size:
|
||||
_good_splits.append(s)
|
||||
_good_splits_lengths.append(s_len)
|
||||
else:
|
||||
if _good_splits:
|
||||
merged_text = self._merge_splits(_good_splits, separator, _good_splits_lengths)
|
||||
final_chunks.extend(merged_text)
|
||||
_good_splits = []
|
||||
_good_splits_lengths = []
|
||||
other_info = self.recursive_split_text(s)
|
||||
final_chunks.extend(other_info)
|
||||
if _good_splits:
|
||||
merged_text = self._merge_splits(_good_splits, separator, _good_splits_lengths)
|
||||
final_chunks.extend(merged_text)
|
||||
if _separator != "":
|
||||
for s, s_len in zip(splits, s_lens):
|
||||
if s_len < self._chunk_size:
|
||||
_good_splits.append(s)
|
||||
_good_splits_lengths.append(s_len)
|
||||
else:
|
||||
if _good_splits:
|
||||
merged_text = self._merge_splits(_good_splits, _separator, _good_splits_lengths)
|
||||
final_chunks.extend(merged_text)
|
||||
_good_splits = []
|
||||
_good_splits_lengths = []
|
||||
if not new_separators:
|
||||
final_chunks.append(s)
|
||||
else:
|
||||
other_info = self._split_text(s, new_separators)
|
||||
final_chunks.extend(other_info)
|
||||
|
||||
if _good_splits:
|
||||
merged_text = self._merge_splits(_good_splits, _separator, _good_splits_lengths)
|
||||
final_chunks.extend(merged_text)
|
||||
else:
|
||||
current_part = ""
|
||||
current_length = 0
|
||||
overlap_part = ""
|
||||
overlap_part_length = 0
|
||||
for s, s_len in zip(splits, s_lens):
|
||||
if current_length + s_len <= self._chunk_size - self._chunk_overlap:
|
||||
current_part += s
|
||||
current_length += s_len
|
||||
elif current_length + s_len <= self._chunk_size:
|
||||
current_part += s
|
||||
current_length += s_len
|
||||
overlap_part += s
|
||||
overlap_part_length += s_len
|
||||
else:
|
||||
final_chunks.append(current_part)
|
||||
current_part = overlap_part + s
|
||||
current_length = s_len + overlap_part_length
|
||||
overlap_part = ""
|
||||
overlap_part_length = 0
|
||||
if current_part:
|
||||
final_chunks.append(current_part)
|
||||
|
||||
return final_chunks
|
||||
|
||||
@ -1,25 +0,0 @@
|
||||
# Tools
|
||||
|
||||
This module implements the built-in tools used in Agent Assistants and Workflows within Dify. You can define and display your own tools in this module without modifying the frontend logic. This decoupling allows for easier horizontal scaling of Dify's capabilities.
|
||||
|
||||
## Feature Introduction
|
||||
|
||||
The tools provided for Agents and Workflows are currently divided into two categories:
|
||||
- `Built-in Tools` are internally implemented within our product and are hardcoded for use in Agents and Workflows.
|
||||
- `Api-Based Tools` leverage third-party APIs for implementation. You don't need to code to integrate these -- simply provide interface definitions in formats like `OpenAPI` , `Swagger`, or the `OpenAI-plugin` on the front-end.
|
||||
|
||||
### Built-in Tool Providers
|
||||

|
||||
|
||||
### API Tool Providers
|
||||

|
||||
|
||||
## Tool Integration
|
||||
|
||||
To enable developers to build flexible and powerful tools, we provide two guides:
|
||||
|
||||
### [Quick Integration 👈🏻](./docs/en_US/tool_scale_out.md)
|
||||
Quick integration gets you up to speed with tool integration by walking through an example Google Search tool.
|
||||
|
||||
### [Advanced Integration 👈🏻](./docs/en_US/advanced_scale_out.md)
|
||||
Advanced integration will offer a deeper dive into the module interfaces, and explain how to implement more complex capabilities, such as generating images, combining multiple tools, and managing the flow of parameters, images, and files between different tools.
|
||||
@ -1,27 +0,0 @@
|
||||
# Tools

This module provides the invocation and authentication interfaces for the built-in tools used by Agents and Workflows, and gives Dify a unified set of rules for tool provider information and credential forms.

- On one hand, it decouples tools from the business code, making it easier for developers to extend Dify's capabilities horizontally;
- On the other hand, providers and tools only need to be defined in the backend to be displayed directly on the frontend, with no changes to the frontend logic.

## Feature Introduction

The tools provided to Agents and Workflows currently fall into two categories:
- `Built-in Tools` are implemented inside Dify and are hardcoded for use by Agents and Workflows.
- `Api-Based Tools` are implemented by calling third-party APIs; they require no extra definition, only an interface document such as `OpenAPI`, `Swagger`, or `OpenAI plugin`.

### Built-in Tool Providers

### API Tool Providers

## Tool Integration
To enable more flexible and powerful capabilities, Tools provides a series of interfaces that help developers quickly build the tools they want. As a getting-started guide for developers, this document covers tool integration in two parts: [Quick Integration](./docs/zh_Hans/tool_scale_out.md) and [Advanced Integration](./docs/zh_Hans/advanced_scale_out.md).

### [Quick Integration 👈🏻](./docs/zh_Hans/tool_scale_out.md)
Quick integration helps you finish integrating a tool within 10 to 20 minutes, but it only covers simple functionality. If you want to implement more complex capabilities, refer to the advanced integration below.

### [Advanced Integration 👈🏻](./docs/zh_Hans/advanced_scale_out.md)
Advanced integration explains how to implement more complex configurations, including image-to-image generation, combining multiple tools, and passing parameters, images, and files between multiple tools.
|
||||
@ -1,31 +0,0 @@
|
||||
# Tools

This module implements the built-in tools used in Dify's Agent Assistants and Workflows. It lets you define and display your own tools without modifying the frontend logic, and this decoupling makes it easy to scale Dify's capabilities horizontally.

## Feature Introduction

The tools provided for Agents and Workflows currently fall into two categories.

- `Built-in Tools` are implemented inside Dify and hardcoded for use by Agents and Workflows.
- `Api-Based Tools` are implemented using third-party APIs. No coding is required to integrate them; simply provide an interface definition in a format such as `OpenAPI`, `Swagger`, or `OpenAI-plugin` on the frontend.

### Built-in Tool Providers

### API Tool Providers

## Tool Integration

To help developers build flexible and powerful tools, we provide two guides.

### [Quick Integration 👈🏻](./docs/ja_JP/tool_scale_out.md)

Quick integration aims to get you up to speed with tool integration quickly by walking through a Google Search tool example.

### [Advanced Integration 👈🏻](./docs/ja_JP/advanced_scale_out.md)

Advanced integration dives deeper into the module interfaces and explains how to implement more complex capabilities, such as generating images, combining multiple tools, and managing the flow of parameters, images, and files between different tools.
|
||||
@ -179,6 +179,18 @@ class ApiTool(Tool):
|
||||
for content_type in self.api_bundle.openapi["requestBody"]["content"]:
|
||||
headers["Content-Type"] = content_type
|
||||
body_schema = self.api_bundle.openapi["requestBody"]["content"][content_type]["schema"]
|
||||
|
||||
# handle ref schema
|
||||
if "$ref" in body_schema:
|
||||
ref_path = body_schema["$ref"].split("/")
|
||||
ref_name = ref_path[-1]
|
||||
if (
|
||||
"components" in self.api_bundle.openapi
|
||||
and "schemas" in self.api_bundle.openapi["components"]
|
||||
):
|
||||
if ref_name in self.api_bundle.openapi["components"]["schemas"]:
|
||||
body_schema = self.api_bundle.openapi["components"]["schemas"][ref_name]
|
||||
|
||||
required = body_schema.get("required", [])
|
||||
properties = body_schema.get("properties", {})
|
||||
for name, property in properties.items():
|
||||
@ -186,6 +198,8 @@ class ApiTool(Tool):
|
||||
if property.get("format") == "binary":
|
||||
f = parameters[name]
|
||||
files.append((name, (f.filename, download(f), f.mime_type)))
|
||||
elif "$ref" in property:
|
||||
body[name] = parameters[name]
|
||||
else:
|
||||
# convert type
|
||||
body[name] = self._convert_body_property_type(property, parameters[name])
|
||||
|
||||
@ -1,278 +0,0 @@
|
||||
# Advanced Tool Integration
|
||||
|
||||
Before starting with this advanced guide, please make sure you have a basic understanding of the tool integration process in Dify. Check out [Quick Integration](./tool_scale_out.md) for a quick runthrough.
|
||||
|
||||
## Tool Interface
|
||||
|
||||
We have defined a series of helper methods in the `Tool` class to help developers quickly build more complex tools.
|
||||
|
||||
### Message Return
|
||||
|
||||
Dify supports various message types such as `text`, `link`, `json`, `image`, and `file BLOB`. You can return different types of messages to the LLM and users through the following interfaces.
|
||||
|
||||
Please note, some parameters in the following interfaces will be introduced in later sections.
|
||||
|
||||
#### Image URL
|
||||
You only need to pass the URL of the image, and Dify will automatically download the image and return it to the user.
|
||||
|
||||
```python
|
||||
def create_image_message(self, image: str, save_as: str = '') -> ToolInvokeMessage:
|
||||
"""
|
||||
create an image message
|
||||
|
||||
:param image: the url of the image
|
||||
:return: the image message
|
||||
"""
|
||||
```
|
||||
|
||||
#### Link
|
||||
If you need to return a link, you can use the following interface.
|
||||
|
||||
```python
|
||||
def create_link_message(self, link: str, save_as: str = '') -> ToolInvokeMessage:
|
||||
"""
|
||||
create a link message
|
||||
|
||||
:param link: the url of the link
|
||||
:return: the link message
|
||||
"""
|
||||
```
|
||||
|
||||
#### Text
|
||||
If you need to return a text message, you can use the following interface.
|
||||
|
||||
```python
|
||||
def create_text_message(self, text: str, save_as: str = '') -> ToolInvokeMessage:
|
||||
"""
|
||||
create a text message
|
||||
|
||||
:param text: the text of the message
|
||||
:return: the text message
|
||||
"""
|
||||
```
|
||||
|
||||
#### File BLOB
|
||||
If you need to return the raw data of a file, such as images, audio, video, PPT, Word, Excel, etc., you can use the following interface.
|
||||
|
||||
- `blob` The raw data of the file, of bytes type
|
||||
- `meta` The metadata of the file, if you know the type of the file, it is best to pass a `mime_type`, otherwise Dify will use `application/octet-stream` as the default type
|
||||
|
||||
```python
|
||||
def create_blob_message(self, blob: bytes, meta: dict = None, save_as: str = '') -> ToolInvokeMessage:
|
||||
"""
|
||||
create a blob message
|
||||
|
||||
:param blob: the blob
|
||||
:return: the blob message
|
||||
"""
|
||||
```
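As a quick illustration, a tool that has already produced a PNG could hand it back roughly like this (a sketch of code inside a tool's `_invoke`; `image_bytes` is a placeholder for data your tool generated):

```python
# Hypothetical: image_bytes holds PNG bytes produced earlier in _invoke.
return self.create_blob_message(
    blob=image_bytes,
    meta={'mime_type': 'image/png'},
)
```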
|
||||
|
||||
#### JSON
|
||||
If you need to return formatted JSON, you can use the following interface. This is commonly used for data transmission between nodes in a workflow; in agent mode, most LLMs can also read and understand JSON.
|
||||
|
||||
- `object` A Python dictionary object will be automatically serialized into JSON
|
||||
|
||||
```python
|
||||
def create_json_message(self, object: dict) -> ToolInvokeMessage:
|
||||
"""
|
||||
create a json message
|
||||
"""
|
||||
```
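For example, a tool could pass structured results to the next workflow node like this (a sketch of code inside `_invoke`, with made-up fields):

```python
# Hypothetical payload; the dict is serialized to JSON automatically.
return self.create_json_message({
    'status': 'ok',
    'results': [{'title': 'Dify', 'url': 'https://dify.ai'}],
})
```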
|
||||
|
||||
### Shortcut Tools
|
||||
|
||||
In large model applications, we have two common needs:
|
||||
- First, summarize a long text in advance, and then pass the summary content to the LLM to prevent the original text from being too long for the LLM to handle
|
||||
- The content obtained by the tool is a link, and the web page information needs to be crawled before it can be returned to the LLM
|
||||
|
||||
To help developers quickly implement these two needs, we provide the following two shortcut tools.
|
||||
|
||||
#### Text Summary Tool
|
||||
|
||||
This tool takes in a user_id and the text to be summarized, and returns the summarized text. Dify will use the default model of the current workspace to summarize the long text.
|
||||
|
||||
```python
|
||||
def summary(self, user_id: str, content: str) -> str:
|
||||
"""
|
||||
summary the content
|
||||
|
||||
:param user_id: the user id
|
||||
:param content: the content
|
||||
:return: the summary
|
||||
"""
|
||||
```
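For example, inside `_invoke` a long result could be condensed before being returned (a sketch; `long_text` stands for content fetched in an earlier step):

```python
# Hypothetical: long_text was fetched or generated earlier in _invoke.
summary = self.summary(user_id=user_id, content=long_text)
return self.create_text_message(summary)
```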
|
||||
|
||||
#### Web Page Crawling Tool
|
||||
|
||||
This tool takes in the web page link to be crawled and a user_agent (which can be empty), and returns a string containing the information of the web page. The `user_agent` is an optional parameter that can be used to identify the tool. If it is not passed, Dify will use the default `user_agent`.
|
||||
|
||||
```python
|
||||
def get_url(self, url: str, user_agent: str = None) -> str:
|
||||
"""
|
||||
get url and return the crawled result
"""
|
||||
```
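A typical use inside `_invoke` is to crawl a link produced by an earlier step and hand the page text back to the LLM (a sketch; the URL is illustrative):

```python
# Hypothetical usage: crawl a page and return its text content.
page_text = self.get_url('https://example.com/article')
return self.create_text_message(page_text)
```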
|
||||
|
||||
### Variable Pool
|
||||
|
||||
We have introduced a variable pool in `Tool` to store variables, files, etc. generated during the tool's operation. These variables can be used by other tools during the tool's operation.
|
||||
|
||||
Next, we will use `DallE3` and `Vectorizer.AI` as examples to introduce how to use the variable pool.
|
||||
|
||||
- `DallE3` is an image generation tool that can generate images based on text. Here, we will let `DallE3` generate a logo for a coffee shop
|
||||
- `Vectorizer.AI` is a vector image conversion tool that can convert images into vector images, so that the images can be infinitely enlarged without distortion. Here, we will convert the PNG icon generated by `DallE3` into a vector image, so that it can be truly used by designers.
|
||||
|
||||
#### DallE3
|
||||
First, we use DallE3. After creating the image, we save the image to the variable pool. The code is as follows:
|
||||
|
||||
```python
|
||||
from typing import Any, Dict, List, Union
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
from base64 import b64decode
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
class DallE3Tool(BuiltinTool):
|
||||
def _invoke(self,
|
||||
user_id: str,
|
||||
tool_parameters: Dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
client = OpenAI(
|
||||
api_key=self.runtime.credentials['openai_api_key'],
|
||||
)
|
||||
|
||||
# prompt
|
||||
prompt = tool_parameters.get('prompt', '')
|
||||
if not prompt:
|
||||
return self.create_text_message('Please input prompt')
|
||||
|
||||
# call openapi dalle3
|
||||
response = client.images.generate(
|
||||
prompt=prompt, model='dall-e-3',
|
||||
size='1024x1024', n=1, style='vivid', quality='standard',
|
||||
response_format='b64_json'
|
||||
)
|
||||
|
||||
result = []
|
||||
for image in response.data:
|
||||
# Save all images to the variable pool through the save_as parameter. The variable name is self.VARIABLE_KEY.IMAGE.value. If new images are generated later, they will overwrite the previous images.
|
||||
result.append(self.create_blob_message(blob=b64decode(image.b64_json),
|
||||
meta={ 'mime_type': 'image/png' },
|
||||
save_as=self.VARIABLE_KEY.IMAGE.value))
|
||||
|
||||
return result
|
||||
```
|
||||
|
||||
Note that we used `self.VARIABLE_KEY.IMAGE.value` as the variable name of the image. In order for developers' tools to cooperate with each other, we defined this `KEY`. You can use it freely, or you can choose not to use this `KEY`. Passing a custom KEY is also acceptable.
|
||||
|
||||
#### Vectorizer.AI
|
||||
Next, we use Vectorizer.AI to convert the PNG icon generated by DallE3 into a vector image. Let's go through the functions we defined here. The code is as follows:
|
||||
|
||||
```python
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
|
||||
from typing import Any, Dict, List, Union
|
||||
from httpx import post
|
||||
from base64 import b64decode
|
||||
|
||||
class VectorizerTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: Dict[str, Any]) \
|
||||
-> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
|
||||
"""
|
||||
Tool invocation, the image variable name needs to be passed in from here, so that we can get the image from the variable pool
|
||||
"""
|
||||
|
||||
|
||||
def get_runtime_parameters(self) -> List[ToolParameter]:
|
||||
"""
|
||||
Override the tool parameter list, we can dynamically generate the parameter list based on the actual situation in the current variable pool, so that the LLM can generate the form based on the parameter list
|
||||
"""
|
||||
|
||||
|
||||
def is_tool_available(self) -> bool:
|
||||
"""
|
||||
Whether the current tool is available, if there is no image in the current variable pool, then we don't need to display this tool, just return False here
|
||||
"""
|
||||
```
|
||||
|
||||
Next, let's implement these three functions
|
||||
|
||||
```python
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
|
||||
from typing import Any, Dict, List, Union
|
||||
from httpx import post
|
||||
from base64 import b64decode
|
||||
|
||||
class VectorizerTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: Dict[str, Any]) \
|
||||
-> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
api_key_name = self.runtime.credentials.get('api_key_name', None)
|
||||
api_key_value = self.runtime.credentials.get('api_key_value', None)
|
||||
|
||||
if not api_key_name or not api_key_value:
|
||||
raise ToolProviderCredentialValidationError('Please input api key name and value')
|
||||
|
||||
# Get image_id, the definition of image_id can be found in get_runtime_parameters
|
||||
image_id = tool_parameters.get('image_id', '')
|
||||
if not image_id:
|
||||
return self.create_text_message('Please input image id')
|
||||
|
||||
# Get the image generated by DallE from the variable pool
|
||||
image_binary = self.get_variable_file(self.VARIABLE_KEY.IMAGE)
|
||||
if not image_binary:
|
||||
return self.create_text_message('Image not found, please request user to generate image firstly.')
|
||||
|
||||
# Generate vector image
|
||||
response = post(
|
||||
'https://vectorizer.ai/api/v1/vectorize',
|
||||
files={ 'image': image_binary },
|
||||
data={ 'mode': 'test' },
|
||||
auth=(api_key_name, api_key_value),
|
||||
timeout=30
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(response.text)
|
||||
|
||||
return [
|
||||
self.create_text_message('the vectorized svg is saved as an image.'),
|
||||
self.create_blob_message(blob=response.content,
|
||||
meta={'mime_type': 'image/svg+xml'})
|
||||
]
|
||||
|
||||
def get_runtime_parameters(self) -> List[ToolParameter]:
|
||||
"""
|
||||
override the runtime parameters
|
||||
"""
|
||||
# Here, we override the tool parameter list, define the image_id, and set its option list to all images in the current variable pool. The configuration here is consistent with the configuration in yaml.
|
||||
return [
|
||||
ToolParameter.get_simple_instance(
|
||||
name='image_id',
|
||||
llm_description=f'the image id that you want to vectorize, \
|
||||
and the image id should be specified in \
|
||||
{[i.name for i in self.list_default_image_variables()]}',
|
||||
type=ToolParameter.ToolParameterType.SELECT,
|
||||
required=True,
|
||||
options=[i.name for i in self.list_default_image_variables()]
|
||||
)
|
||||
]
|
||||
|
||||
def is_tool_available(self) -> bool:
|
||||
# Only when there are images in the variable pool, the LLM needs to use this tool
|
||||
return len(self.list_default_image_variables()) > 0
|
||||
```
|
||||
|
||||
It's worth noting that we didn't actually use `image_id` here. We assumed that there must be an image in the default variable pool when calling this tool, so we directly used `image_binary = self.get_variable_file(self.VARIABLE_KEY.IMAGE)` to get the image. In cases where the model's capabilities are weak, we recommend developers to do the same, which can effectively improve fault tolerance and avoid the model passing incorrect parameters.
|
||||
@ -1,248 +0,0 @@
|
||||
# Quick Tool Integration
|
||||
|
||||
Here, we will use GoogleSearch as an example to demonstrate how to quickly integrate a tool.
|
||||
|
||||
## 1. Prepare the Tool Provider yaml
|
||||
|
||||
### Introduction
|
||||
|
||||
This yaml declares a new tool provider, and includes information like the provider's name, icon, author, and other details that are fetched by the frontend for display.
|
||||
|
||||
### Example
|
||||
|
||||
We need to create a `google` module (folder) under `core/tools/provider/builtin`, and create `google.yaml`. The name must be consistent with the module name.
|
||||
|
||||
Subsequently, all operations related to this tool will be carried out under this module.
|
||||
|
||||
```yaml
|
||||
identity: # Basic information of the tool provider
|
||||
author: Dify # Author
|
||||
name: google # Name, unique, no duplication with other providers
|
||||
label: # Label for frontend display
|
||||
en_US: Google # English label
|
||||
zh_Hans: Google # Chinese label
|
||||
description: # Description for frontend display
|
||||
en_US: Google # English description
|
||||
zh_Hans: Google # Chinese description
|
||||
icon: icon.svg # Icon, needs to be placed in the _assets folder of the current module
|
||||
tags:
|
||||
- search
|
||||
|
||||
```
|
||||
|
||||
- The `identity` field is mandatory, it contains the basic information of the tool provider, including author, name, label, description, icon, etc.
|
||||
- The icon needs to be placed in the `_assets` folder of the current module, you can refer to [here](../../provider/builtin/google/_assets/icon.svg).
|
||||
- The `tags` field is optional; it is used to classify the provider so the frontend can filter providers by tag. All supported tags are listed below:
|
||||
|
||||
```python
|
||||
class ToolLabelEnum(Enum):
|
||||
SEARCH = 'search'
|
||||
IMAGE = 'image'
|
||||
VIDEOS = 'videos'
|
||||
WEATHER = 'weather'
|
||||
FINANCE = 'finance'
|
||||
DESIGN = 'design'
|
||||
TRAVEL = 'travel'
|
||||
SOCIAL = 'social'
|
||||
NEWS = 'news'
|
||||
MEDICAL = 'medical'
|
||||
PRODUCTIVITY = 'productivity'
|
||||
EDUCATION = 'education'
|
||||
BUSINESS = 'business'
|
||||
ENTERTAINMENT = 'entertainment'
|
||||
UTILITIES = 'utilities'
|
||||
OTHER = 'other'
|
||||
```
|
||||
|
||||
## 2. Prepare Provider Credentials
|
||||
|
||||
Google, as a third-party tool, uses the API provided by SerpApi, which requires an API Key to use. This means that this tool needs a credential to use. For tools like `wikipedia`, there is no need to fill in the credential field, you can refer to [here](../../provider/builtin/wikipedia/wikipedia.yaml).
|
||||
|
||||
After configuring the credential field, the effect is as follows:
|
||||
|
||||
```yaml
|
||||
identity:
|
||||
author: Dify
|
||||
name: google
|
||||
label:
|
||||
en_US: Google
|
||||
zh_Hans: Google
|
||||
description:
|
||||
en_US: Google
|
||||
zh_Hans: Google
|
||||
icon: icon.svg
|
||||
credentials_for_provider: # Credential field
|
||||
serpapi_api_key: # Credential field name
|
||||
type: secret-input # Credential field type
|
||||
required: true # Required or not
|
||||
label: # Credential field label
|
||||
en_US: SerpApi API key # English label
|
||||
zh_Hans: SerpApi API key # Chinese label
|
||||
placeholder: # Credential field placeholder
|
||||
en_US: Please input your SerpApi API key # English placeholder
|
||||
zh_Hans: 请输入你的 SerpApi API key # Chinese placeholder
|
||||
help: # Credential field help text
|
||||
en_US: Get your SerpApi API key from SerpApi # English help text
|
||||
zh_Hans: 从 SerpApi 获取您的 SerpApi API key # Chinese help text
|
||||
url: https://serpapi.com/manage-api-key # Credential field help link
|
||||
|
||||
```
|
||||
|
||||
- `type`: Credential field type, currently can be either `secret-input`, `text-input`, or `select` , corresponding to password input box, text input box, and drop-down box, respectively. If set to `secret-input`, it will mask the input content on the frontend, and the backend will encrypt the input content.
|
||||
|
||||
## 3. Prepare Tool yaml
|
||||
|
||||
A provider can have multiple tools, and each tool needs a yaml file to describe it. This file contains the tool's basic information, parameters, output, and so on.
|
||||
|
||||
Still taking GoogleSearch as an example, we need to create a `tools` module under the `google` module, and create `tools/google_search.yaml`, the content is as follows.
|
||||
|
||||
```yaml
|
||||
identity: # Basic information of the tool
|
||||
name: google_search # Tool name, unique, no duplication with other tools
|
||||
author: Dify # Author
|
||||
label: # Label for frontend display
|
||||
en_US: GoogleSearch # English label
|
||||
zh_Hans: 谷歌搜索 # Chinese label
|
||||
description: # Description for frontend display
|
||||
human: # Introduction for frontend display, supports multiple languages
|
||||
en_US: A tool for performing a Google SERP search and extracting snippets and webpages. Input should be a search query.
|
||||
zh_Hans: 一个用于执行 Google SERP 搜索并提取片段和网页的工具。输入应该是一个搜索查询。
|
||||
llm: A tool for performing a Google SERP search and extracting snippets and webpages. Input should be a search query. # Description passed to the LLM. To help the LLM better understand and use this tool, we suggest writing as detailed a description as possible here.
|
||||
parameters: # Parameter list
|
||||
- name: query # Parameter name
|
||||
type: string # Parameter type
|
||||
required: true # Required or not
|
||||
label: # Parameter label
|
||||
en_US: Query string # English label
|
||||
zh_Hans: 查询语句 # Chinese label
|
||||
human_description: # Introduction for frontend display, supports multiple languages
|
||||
en_US: used for searching
|
||||
zh_Hans: 用于搜索网页内容
|
||||
llm_description: key words for searching # Introduction passed to LLM, similarly, in order to make LLM better understand this parameter, we suggest to write as detailed information about this parameter as possible here, so that LLM can understand this parameter
|
||||
form: llm # Form type, llm means this parameter needs to be inferred by Agent, the frontend will not display this parameter
|
||||
- name: result_type
|
||||
type: select # Parameter type
|
||||
required: true
|
||||
options: # Drop-down box options
|
||||
- value: text
|
||||
label:
|
||||
en_US: text
|
||||
zh_Hans: 文本
|
||||
- value: link
|
||||
label:
|
||||
en_US: link
|
||||
zh_Hans: 链接
|
||||
default: link
|
||||
label:
|
||||
en_US: Result type
|
||||
zh_Hans: 结果类型
|
||||
human_description:
|
||||
en_US: used for selecting the result type, text or link
|
||||
zh_Hans: 用于选择结果类型,使用文本还是链接进行展示
|
||||
form: form # Form type, form means this parameter needs to be filled in by the user on the frontend before the conversation starts
|
||||
|
||||
```
|
||||
|
||||
- The `identity` field is mandatory, it contains the basic information of the tool, including name, author, label, description, etc.
|
||||
- `parameters` Parameter list
|
||||
- `name` (Mandatory) Parameter name, must be unique and not duplicate with other parameters.
|
||||
- `type` (Mandatory) Parameter type, currently supports `string`, `number`, `boolean`, `select`, `secret-input` five types, corresponding to string, number, boolean, drop-down box, and encrypted input box, respectively. For sensitive information, we recommend using the `secret-input` type
|
||||
- `label` (Mandatory) Parameter label, for frontend display
|
||||
- `form` (Mandatory) Form type, currently supports `llm`, `form` two types.
|
||||
- In an agent app, `llm` indicates that the parameter is inferred by the LLM itself, while `form` indicates that the parameter can be pre-set for the tool.
|
||||
- In a workflow app, both `llm` and `form` need to be filled out by the front end, but the parameters of `llm` will be used as input variables for the tool node.
|
||||
- `required` Indicates whether the parameter is required or not
|
||||
- In `llm` mode, if the parameter is required, the Agent is required to infer this parameter
|
||||
- In `form` mode, if the parameter is required, the user is required to fill in this parameter on the frontend before the conversation starts
|
||||
- `options` Parameter options
|
||||
- In `llm` mode, Dify will pass all options to LLM, LLM can infer based on these options
|
||||
- In `form` mode, when `type` is `select`, the frontend will display these options
|
||||
- `default` Default value
|
||||
- `min` Minimum value, can be set when the parameter type is `number`.
|
||||
- `max` Maximum value, can be set when the parameter type is `number`.
|
||||
- `placeholder` The prompt text for input boxes. It can be set when the form type is `form`, and the parameter type is `string`, `number`, or `secret-input`. It supports multiple languages.
|
||||
- `human_description` Introduction for frontend display, supports multiple languages
|
||||
- `llm_description` Description passed to the LLM; to help the LLM better understand this parameter, we suggest writing as detailed a description as possible here.
|
||||
|
||||
|
||||
## 4. Add Tool Logic
|
||||
|
||||
After completing the tool configuration, we can start writing the tool code that defines how it is invoked.
|
||||
|
||||
Create `google_search.py` under the `google/tools` module, the content is as follows.
|
||||
|
||||
```python
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
class GoogleSearchTool(BuiltinTool):
|
||||
def _invoke(self,
|
||||
user_id: str,
|
||||
tool_parameters: Dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
query = tool_parameters['query']
|
||||
result_type = tool_parameters['result_type']
|
||||
api_key = self.runtime.credentials['serpapi_api_key']
|
||||
# Search with serpapi
|
||||
result = SerpAPI(api_key).run(query, result_type=result_type)
|
||||
|
||||
if result_type == 'text':
|
||||
return self.create_text_message(text=result)
|
||||
return self.create_link_message(link=result)
|
||||
```
|
||||
|
||||
### Parameters
|
||||
|
||||
The overall logic of the tool lives in the `_invoke` method. This method accepts two parameters, `user_id` and `tool_parameters`, which represent the user ID and the tool parameters respectively.
|
||||
|
||||
### Return Data
|
||||
|
||||
When the tool returns, you can choose to return one message or multiple messages. Here we return a single message: `create_text_message` and `create_link_message` create a text message and a link message respectively. If you want to return multiple messages, return a list such as `[self.create_text_message('msg1'), self.create_text_message('msg2')]`, as sketched below.
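For example, returning a text result together with its source link could look like this (a sketch; `result` and `source_url` are placeholders):

```python
# Hypothetical: return a text summary plus the source link in one response.
return [
    self.create_text_message(text=result),
    self.create_link_message(link=source_url),
]
```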
|
||||
|
||||
## 5. Add Provider Code
|
||||
|
||||
Finally, we need to create a provider class under the provider module to implement the provider's credential verification logic. If the credential verification fails, it will throw a `ToolProviderCredentialValidationError` exception.
|
||||
|
||||
Create `google.py` under the `google` module, the content is as follows.
|
||||
|
||||
```python
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
|
||||
from core.tools.provider.builtin.google.tools.google_search import GoogleSearchTool
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
class GoogleProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: Dict[str, Any]) -> None:
|
||||
try:
|
||||
# 1. Here you need to instantiate a GoogleSearchTool with GoogleSearchTool(), it will automatically load the yaml configuration of GoogleSearchTool, but at this time it does not have credential information inside
|
||||
# 2. Then you need to use the fork_tool_runtime method to pass the current credential information to GoogleSearchTool
|
||||
# 3. Finally, invoke it, the parameters need to be passed according to the parameter rules configured in the yaml of GoogleSearchTool
|
||||
GoogleSearchTool().fork_tool_runtime(
|
||||
meta={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
user_id='',
|
||||
tool_parameters={
|
||||
"query": "test",
|
||||
"result_type": "link"
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
```
|
||||
|
||||
## Completion
|
||||
|
||||
After the above steps are completed, we can see this tool on the frontend, and it can be used in the Agent.
|
||||
|
||||
Of course, because google_search requires a credential, you also need to enter your credentials on the frontend before using it.
|
||||
|
||||

|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 242 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 407 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 266 KiB |
@ -1,283 +0,0 @@
|
||||
# 高度なツール統合
|
||||
|
||||
このガイドを始める前に、Difyのツール統合プロセスの基本を理解していることを確認してください。簡単な概要については[クイック統合](./tool_scale_out.md)をご覧ください。
|
||||
|
||||
## ツールインターフェース
|
||||
|
||||
より複雑なツールを迅速に構築するのを支援するため、`Tool`クラスに一連のヘルパーメソッドを定義しています。
|
||||
|
||||
### メッセージの返却
|
||||
|
||||
Difyは`テキスト`、`リンク`、`画像`、`ファイルBLOB`、`JSON`などの様々なメッセージタイプをサポートしています。以下のインターフェースを通じて、異なるタイプのメッセージをLLMとユーザーに返すことができます。
|
||||
|
||||
注意:以下のインターフェースの一部のパラメータについては、後のセクションで説明します。
|
||||
|
||||
#### 画像URL
|
||||
画像のURLを渡すだけで、Difyが自動的に画像をダウンロードしてユーザーに返します。
|
||||
|
||||
```python
|
||||
def create_image_message(self, image: str, save_as: str = '') -> ToolInvokeMessage:
|
||||
"""
|
||||
create an image message
|
||||
|
||||
:param image: the url of the image
|
||||
:param save_as: save as
|
||||
:return: the image message
|
||||
"""
|
||||
```
|
||||
|
||||
#### リンク
|
||||
リンクを返す必要がある場合は、以下のインターフェースを使用できます。
|
||||
|
||||
```python
|
||||
def create_link_message(self, link: str, save_as: str = '') -> ToolInvokeMessage:
|
||||
"""
|
||||
create a link message
|
||||
|
||||
:param link: the url of the link
|
||||
:param save_as: save as
|
||||
:return: the link message
|
||||
"""
|
||||
```
|
||||
|
||||
#### テキスト
|
||||
テキストメッセージを返す必要がある場合は、以下のインターフェースを使用できます。
|
||||
|
||||
```python
|
||||
def create_text_message(self, text: str, save_as: str = '') -> ToolInvokeMessage:
|
||||
"""
|
||||
create a text message
|
||||
|
||||
:param text: the text of the message
|
||||
:param save_as: save as
|
||||
:return: the text message
|
||||
"""
|
||||
```
|
||||
|
||||
#### ファイルBLOB
|
||||
画像、音声、動画、PPT、Word、Excelなどのファイルの生データを返す必要がある場合は、以下のインターフェースを使用できます。
|
||||
|
||||
- `blob` ファイルの生データ(bytes型)
|
||||
- `meta` ファイルのメタデータ。ファイルの種類が分かっている場合は、`mime_type`を渡すことをお勧めします。そうでない場合、Difyはデフォルトタイプとして`application/octet-stream`を使用します。
|
||||
|
||||
```python
|
||||
def create_blob_message(self, blob: bytes, meta: dict = None, save_as: str = '') -> ToolInvokeMessage:
|
||||
"""
|
||||
create a blob message
|
||||
|
||||
:param blob: the blob
|
||||
:param meta: meta
|
||||
:param save_as: save as
|
||||
:return: the blob message
|
||||
"""
|
||||
```
|
||||
|
||||
#### JSON
|
||||
フォーマットされたJSONを返す必要がある場合は、以下のインターフェースを使用できます。これは通常、ワークフロー内のノード間のデータ伝送に使用されますが、エージェントモードでは、ほとんどの大規模言語モデルもJSONを読み取り、理解することができます。
|
||||
|
||||
- `object` Pythonの辞書オブジェクトで、自動的にJSONにシリアライズされます。
|
||||
|
||||
```python
|
||||
def create_json_message(self, object: dict) -> ToolInvokeMessage:
|
||||
"""
|
||||
create a json message
|
||||
"""
|
||||
```
|
||||
|
||||
### ショートカットツール
|
||||
|
||||
大規模モデルアプリケーションでは、以下の2つの一般的なニーズがあります:
|
||||
- まず長いテキストを事前に要約し、その要約内容をLLMに渡すことで、元のテキストが長すぎてLLMが処理できない問題を防ぐ
|
||||
- ツールが取得したコンテンツがリンクである場合、Webページ情報をクロールしてからLLMに返す必要がある
|
||||
|
||||
開発者がこれら2つのニーズを迅速に実装できるよう、以下の2つのショートカットツールを提供しています。
|
||||
|
||||
#### テキスト要約ツール
|
||||
|
||||
このツールはuser_idと要約するテキストを入力として受け取り、要約されたテキストを返します。Difyは現在のワークスペースのデフォルトモデルを使用して長文を要約します。
|
||||
|
||||
```python
|
||||
def summary(self, user_id: str, content: str) -> str:
|
||||
"""
|
||||
summary the content
|
||||
|
||||
:param user_id: the user id
|
||||
:param content: the content
|
||||
:return: the summary
|
||||
"""
|
||||
```
|
||||
|
||||
#### Webページクローリングツール
|
||||
|
||||
このツールはクロールするWebページのリンクとユーザーエージェント(空でも可)を入力として受け取り、そのWebページの情報を含む文字列を返します。`user_agent`はオプションのパラメータで、ツールを識別するために使用できます。渡さない場合、Difyはデフォルトの`user_agent`を使用します。
|
||||
|
||||
```python
|
||||
def get_url(self, url: str, user_agent: str = None) -> str:
|
||||
"""
|
||||
get url from the crawled result
|
||||
"""
|
||||
```
|
||||
|
||||
### 変数プール
|
||||
|
||||
`Tool`内に変数プールを導入し、ツールの実行中に生成された変数やファイルなどを保存します。これらの変数は、ツールの実行中に他のツールが使用することができます。
|
||||
|
||||
次に、`DallE3`と`Vectorizer.AI`を例に、変数プールの使用方法を紹介します。
|
||||
|
||||
- `DallE3`は画像生成ツールで、テキストに基づいて画像を生成できます。ここでは、`DallE3`にカフェのロゴを生成させます。
|
||||
- `Vectorizer.AI`はベクター画像変換ツールで、画像をベクター画像に変換できるため、画像を無限に拡大しても品質が損なわれません。ここでは、`DallE3`が生成したPNGアイコンをベクター画像に変換し、デザイナーが実際に使用できるようにします。
|
||||
|
||||
#### DallE3
|
||||
まず、DallE3を使用します。画像を作成した後、その画像を変数プールに保存します。コードは以下の通りです:
|
||||
|
||||
```python
|
||||
from typing import Any, Dict, List, Union
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
from base64 import b64decode
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
class DallE3Tool(BuiltinTool):
|
||||
def _invoke(self,
|
||||
user_id: str,
|
||||
tool_parameters: Dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
client = OpenAI(
|
||||
api_key=self.runtime.credentials['openai_api_key'],
|
||||
)
|
||||
|
||||
# prompt
|
||||
prompt = tool_parameters.get('prompt', '')
|
||||
if not prompt:
|
||||
return self.create_text_message('Please input prompt')
|
||||
|
||||
# call openapi dalle3
|
||||
response = client.images.generate(
|
||||
prompt=prompt, model='dall-e-3',
|
||||
size='1024x1024', n=1, style='vivid', quality='standard',
|
||||
response_format='b64_json'
|
||||
)
|
||||
|
||||
result = []
|
||||
for image in response.data:
|
||||
# Save all images to the variable pool through the save_as parameter. The variable name is self.VARIABLE_KEY.IMAGE.value. If new images are generated later, they will overwrite the previous images.
|
||||
result.append(self.create_blob_message(blob=b64decode(image.b64_json),
|
||||
meta={ 'mime_type': 'image/png' },
|
||||
save_as=self.VARIABLE_KEY.IMAGE.value))
|
||||
|
||||
return result
|
||||
```
|
||||
|
||||
ここでは画像の変数名として`self.VARIABLE_KEY.IMAGE.value`を使用していることに注意してください。開発者のツールが互いに連携できるよう、この`KEY`を定義しました。自由に使用することも、この`KEY`を使用しないこともできます。カスタムのKEYを渡すこともできます。
|
||||
|
||||
#### Vectorizer.AI
|
||||
次に、Vectorizer.AIを使用して、DallE3が生成したPNGアイコンをベクター画像に変換します。ここで定義した関数を見てみましょう。コードは以下の通りです:
|
||||
|
||||
```python
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
|
||||
from typing import Any, Dict, List, Union
|
||||
from httpx import post
|
||||
from base64 import b64decode
|
||||
|
||||
class VectorizerTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: Dict[str, Any]) \
|
||||
-> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
|
||||
"""
|
||||
Tool invocation, the image variable name needs to be passed in from here, so that we can get the image from the variable pool
|
||||
"""
|
||||
|
||||
|
||||
def get_runtime_parameters(self) -> List[ToolParameter]:
|
||||
"""
|
||||
Override the tool parameter list, we can dynamically generate the parameter list based on the actual situation in the current variable pool, so that the LLM can generate the form based on the parameter list
|
||||
"""
|
||||
|
||||
|
||||
def is_tool_available(self) -> bool:
|
||||
"""
|
||||
Whether the current tool is available, if there is no image in the current variable pool, then we don't need to display this tool, just return False here
|
||||
"""
|
||||
```
|
||||
|
||||
次に、これら3つの関数を実装します:
|
||||
|
||||
```python
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
|
||||
from typing import Any, Dict, List, Union
|
||||
from httpx import post
|
||||
from base64 import b64decode
|
||||
|
||||
class VectorizerTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: Dict[str, Any]) \
|
||||
-> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
api_key_name = self.runtime.credentials.get('api_key_name', None)
|
||||
api_key_value = self.runtime.credentials.get('api_key_value', None)
|
||||
|
||||
if not api_key_name or not api_key_value:
|
||||
raise ToolProviderCredentialValidationError('Please input api key name and value')
|
||||
|
||||
# Get image_id, the definition of image_id can be found in get_runtime_parameters
|
||||
image_id = tool_parameters.get('image_id', '')
|
||||
if not image_id:
|
||||
return self.create_text_message('Please input image id')
|
||||
|
||||
# Get the image generated by DallE from the variable pool
|
||||
image_binary = self.get_variable_file(self.VARIABLE_KEY.IMAGE)
|
||||
if not image_binary:
|
||||
return self.create_text_message('Image not found, please request user to generate image firstly.')
|
||||
|
||||
# Generate vector image
|
||||
response = post(
|
||||
'https://vectorizer.ai/api/v1/vectorize',
|
||||
files={ 'image': image_binary },
|
||||
data={ 'mode': 'test' },
|
||||
auth=(api_key_name, api_key_value),
|
||||
timeout=30
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(response.text)
|
||||
|
||||
return [
|
||||
self.create_text_message('the vectorized svg is saved as an image.'),
|
||||
self.create_blob_message(blob=response.content,
|
||||
meta={'mime_type': 'image/svg+xml'})
|
||||
]
|
||||
|
||||
def get_runtime_parameters(self) -> List[ToolParameter]:
|
||||
"""
|
||||
override the runtime parameters
|
||||
"""
|
||||
# Here, we override the tool parameter list, define the image_id, and set its option list to all images in the current variable pool. The configuration here is consistent with the configuration in yaml.
|
||||
return [
|
||||
ToolParameter.get_simple_instance(
|
||||
name='image_id',
|
||||
llm_description=f'the image id that you want to vectorize, \
|
||||
and the image id should be specified in \
|
||||
{[i.name for i in self.list_default_image_variables()]}',
|
||||
type=ToolParameter.ToolParameterType.SELECT,
|
||||
required=True,
|
||||
options=[i.name for i in self.list_default_image_variables()]
|
||||
)
|
||||
]
|
||||
|
||||
def is_tool_available(self) -> bool:
|
||||
# Only when there are images in the variable pool, the LLM needs to use this tool
|
||||
return len(self.list_default_image_variables()) > 0
|
||||
```
|
||||
|
||||
ここで注目すべきは、実際には`image_id`を使用していないことです。このツールを呼び出す際には、デフォルトの変数プールに必ず画像があると仮定し、直接`image_binary = self.get_variable_file(self.VARIABLE_KEY.IMAGE)`を使用して画像を取得しています。モデルの能力が弱い場合、開発者にもこの方法を推奨します。これにより、エラー許容度を効果的に向上させ、モデルが誤ったパラメータを渡すのを防ぐことができます。
|
||||
@ -1,240 +0,0 @@
|
||||
# ツールの迅速な統合
|
||||
|
||||
ここでは、GoogleSearchを例にツールを迅速に統合する方法を紹介します。
|
||||
|
||||
## 1. ツールプロバイダーのyamlを準備する
|
||||
|
||||
### 概要
|
||||
|
||||
このyamlファイルには、プロバイダー名、アイコン、作者などの詳細情報が含まれ、フロントエンドでの柔軟な表示を可能にします。
|
||||
|
||||
### 例
|
||||
|
||||
`core/tools/provider/builtin`の下に`google`モジュール(フォルダ)を作成し、`google.yaml`を作成します。名前はモジュール名と一致している必要があります。
|
||||
|
||||
以降、このツールに関するすべての操作はこのモジュール内で行います。
|
||||
|
||||
```yaml
|
||||
identity: # ツールプロバイダーの基本情報
|
||||
author: Dify # 作者
|
||||
name: google # 名前(一意、他のプロバイダーと重複不可)
|
||||
label: # フロントエンド表示用のラベル
|
||||
en_US: Google # 英語ラベル
|
||||
zh_Hans: Google # 中国語ラベル
|
||||
description: # フロントエンド表示用の説明
|
||||
en_US: Google # 英語説明
|
||||
zh_Hans: Google # 中国語説明
|
||||
icon: icon.svg # アイコン(現在のモジュールの_assetsフォルダに配置)
|
||||
tags: # タグ(フロントエンド表示用)
|
||||
- search
|
||||
```
|
||||
|
||||
- `identity`フィールドは必須で、ツールプロバイダーの基本情報(作者、名前、ラベル、説明、アイコンなど)が含まれます。
|
||||
- アイコンは現在のモジュールの`_assets`フォルダに配置する必要があります。[こちら](../../provider/builtin/google/_assets/icon.svg)を参照してください。
|
||||
- タグはフロントエンドでの表示に使用され、ユーザーがこのツールプロバイダーを素早く見つけるのに役立ちます。現在サポートされているすべてのタグは以下の通りです:
|
||||
```python
|
||||
class ToolLabelEnum(Enum):
|
||||
SEARCH = 'search'
|
||||
IMAGE = 'image'
|
||||
VIDEOS = 'videos'
|
||||
WEATHER = 'weather'
|
||||
FINANCE = 'finance'
|
||||
DESIGN = 'design'
|
||||
TRAVEL = 'travel'
|
||||
SOCIAL = 'social'
|
||||
NEWS = 'news'
|
||||
MEDICAL = 'medical'
|
||||
PRODUCTIVITY = 'productivity'
|
||||
EDUCATION = 'education'
|
||||
BUSINESS = 'business'
|
||||
ENTERTAINMENT = 'entertainment'
|
||||
UTILITIES = 'utilities'
|
||||
OTHER = 'other'
|
||||
```
|
||||
|
||||
## 2. プロバイダーの認証情報を準備する
|
||||
|
||||
GoogleはSerpApiが提供するAPIを使用するサードパーティツールであり、SerpApiを使用するにはAPI Keyが必要です。つまり、このツールを使用するには認証情報が必要です。一方、`wikipedia`のようなツールでは認証情報フィールドを記入する必要はありません。[こちら](../../provider/builtin/wikipedia/wikipedia.yaml)を参照してください。
|
||||
|
||||
認証情報フィールドを設定すると、以下のようになります:
|
||||
|
||||
```yaml
|
||||
identity:
|
||||
author: Dify
|
||||
name: google
|
||||
label:
|
||||
en_US: Google
|
||||
zh_Hans: Google
|
||||
description:
|
||||
en_US: Google
|
||||
zh_Hans: Google
|
||||
icon: icon.svg
|
||||
credentials_for_provider: # 認証情報フィールド
|
||||
serpapi_api_key: # 認証情報フィールド名
|
||||
type: secret-input # 認証情報フィールドタイプ
|
||||
required: true # 必須かどうか
|
||||
label: # 認証情報フィールドラベル
|
||||
en_US: SerpApi API key # 英語ラベル
|
||||
zh_Hans: SerpApi API key # 中国語ラベル
|
||||
placeholder: # 認証情報フィールドプレースホルダー
|
||||
en_US: Please input your SerpApi API key # 英語プレースホルダー
|
||||
zh_Hans: 请输入你的 SerpApi API key # 中国語プレースホルダー
|
||||
help: # 認証情報フィールドヘルプテキスト
|
||||
en_US: Get your SerpApi API key from SerpApi # 英語ヘルプテキスト
|
||||
zh_Hans: 从 SerpApi 获取您的 SerpApi API key # 中国語ヘルプテキスト
|
||||
url: https://serpapi.com/manage-api-key # 認証情報フィールドヘルプリンク
|
||||
```
|
||||
|
||||
- `type`:認証情報フィールドタイプ。現在、`secret-input`、`text-input`、`select`の3種類をサポートしており、それぞれパスワード入力ボックス、テキスト入力ボックス、ドロップダウンボックスに対応します。`secret-input`の場合、フロントエンドで入力内容が隠され、バックエンドで入力内容が暗号化されます。

## 3. Prepare the tool yaml

A provider can contain multiple tools, and each tool needs a yaml file describing its basic information, parameters, output, and so on.

Staying with the GoogleSearch example, create a `tools` module under the `google` module and add `tools/google_search.yaml` with the following content:

```yaml
identity: # basic information about the tool
  name: google_search # tool name, unique, must not clash with other tools
  author: Dify # author
  label: # label for frontend display
    en_US: GoogleSearch # English label
    zh_Hans: 谷歌搜索 # Chinese label
description: # description for frontend display
  human: # introduction for frontend display, multi-language supported
    en_US: A tool for performing a Google SERP search and extracting snippets and webpages. Input should be a search query.
    zh_Hans: 一个用于执行 Google SERP 搜索并提取片段和网页的工具。输入应该是一个搜索查询。
  llm: A tool for performing a Google SERP search and extracting snippets and webpages. Input should be a search query. # introduction passed to the LLM; write as much detail as possible here so the LLM can understand the tool better.
parameters: # parameter list
  - name: query # parameter name
    type: string # parameter type
    required: true # whether it is required
    label: # parameter label
      en_US: Query string # English label
      zh_Hans: 查询语句 # Chinese label
    human_description: # introduction for frontend display, multi-language supported
      en_US: used for searching
      zh_Hans: 用于搜索网页内容
    llm_description: key words for searching # introduction passed to the LLM; again, write as much detail as possible here so the LLM can understand this parameter better.
    form: llm # form type; llm means this parameter must be inferred by the Agent, and the frontend will not show it.
  - name: result_type
    type: select # parameter type
    required: true
    options: # dropdown options
      - value: text
        label:
          en_US: text
          zh_Hans: 文本
      - value: link
        label:
          en_US: link
          zh_Hans: 链接
    default: link
    label:
      en_US: Result type
      zh_Hans: 结果类型
    human_description:
      en_US: used for selecting the result type, text or link
      zh_Hans: 用于选择结果类型,使用文本还是链接进行展示
    form: form # form type; form means this parameter must be filled in by the user on the frontend before the conversation starts.
```

- The `identity` field is required; it contains the tool's basic information such as name, author, label, and description.
- `parameters` parameter list
  - `name` (required) parameter name, unique, must not clash with other parameters
  - `type` (required) parameter type. Five types are currently supported: `string`, `number`, `boolean`, `select`, and `secret-input`, corresponding to string, number, boolean, dropdown, and encrypted input box. Use `secret-input` for sensitive information.
  - `label` (required) parameter label, used for frontend display
  - `form` (required) form type. Two types are currently supported: `llm` and `form`.
    - In an Agent application, `llm` means the parameter is inferred by the LLM itself, while `form` means a parameter that can be set in advance in order to use the tool.
    - In a workflow application, both `llm` and `form` parameters are filled in on the frontend, but `llm` parameters are used as input variables of the tool node.
  - `required` whether the parameter is required
    - In `llm` mode, if the parameter is required, the Agent must infer it.
    - In `form` mode, if the parameter is required, the user must fill it in on the frontend before the conversation starts.
  - `options` parameter options
    - In `llm` mode, Dify passes all options to the LLM, and the LLM can reason over them.
    - In `form` mode, when `type` is `select`, the frontend shows these options.
  - `default` default value
  - `min` minimum value; can be set when the parameter type is `number` (see the sketch after this list for `min`, `max`, `default`, and `placeholder`)
  - `max` maximum value; can be set when the parameter type is `number`
  - `human_description` introduction for frontend display, multi-language supported
  - `placeholder` prompt text for the input box; can be set when the form type is `form` and the parameter type is `string`, `number`, or `secret-input`; multi-language supported
  - `llm_description` introduction passed to the LLM; write as much detail as possible here so the LLM can understand this parameter better.
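The example yaml above does not exercise every field described in this list. The following sketch shows what a hypothetical `number` parameter using `default`, `min`, `max`, and `placeholder` might look like; the `max_results` parameter is invented purely for illustration:

```yaml
  - name: max_results      # hypothetical parameter, for illustration only
    type: number
    required: false
    default: 5
    min: 1
    max: 20
    label:
      en_US: Max results
      zh_Hans: 最大结果数
    human_description:
      en_US: how many results to return
      zh_Hans: 返回的结果数量
    placeholder:
      en_US: e.g. 5
      zh_Hans: 例如 5
    form: form
```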

## 4. Prepare the tool code

Once the tool is configured, write the code that implements the tool's logic.

Create `google_search.py` under the `google/tools` module with the following content:

```python
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage

from typing import Any, Dict, List, Union


class GoogleSearchTool(BuiltinTool):
    def _invoke(self,
                user_id: str,
                tool_parameters: Dict[str, Any],
                ) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
        """
        invoke the tool
        """
        query = tool_parameters['query']
        result_type = tool_parameters['result_type']
        api_key = self.runtime.credentials['serpapi_api_key']
        # SerpAPI is a small helper around the SerpApi HTTP API; define or import it
        # in this module (its implementation is omitted in this guide).
        result = SerpAPI(api_key).run(query, result_type=result_type)

        if result_type == 'text':
            return self.create_text_message(text=result)
        return self.create_link_message(link=result)
```

### Parameters

The tool's overall logic lives in the `_invoke` method. It takes two parameters, `user_id` and `tool_parameters`, which are the user ID and the tool parameters respectively.

### Return value

A tool may return either a single message or multiple messages; here we return a single message. Use `create_text_message` or `create_link_message` to create a text or link message. To return multiple messages, build a list, for example `[self.create_text_message('msg1'), self.create_text_message('msg2')]`.
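For instance, a sketch of a variant that always returns both a snippet and a link in one call (the search URL is illustrative, not something the original tool builds):

```python
# Sketch: return several messages at once by building a list.
return [
    self.create_text_message(text=result),
    self.create_link_message(link='https://www.google.com/search?q=' + query),
]
```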

## 5. Prepare the provider code

Finally, create a provider class under the provider module and implement the provider's credential validation logic. If credential validation fails, a `ToolProviderCredentialValidationError` exception is raised.

Create `google.py` under the `google` module with the following content:

```python
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
from core.tools.errors import ToolProviderCredentialValidationError

from core.tools.provider.builtin.google.tools.google_search import GoogleSearchTool

from typing import Any, Dict


class GoogleProvider(BuiltinToolProviderController):
    def _validate_credentials(self, credentials: Dict[str, Any]) -> None:
        try:
            # 1. Instantiate a GoogleSearchTool with GoogleSearchTool(); this automatically loads the tool's yaml configuration, but at this point it holds no credentials.
            # 2. Then use the fork_tool_runtime method to pass the current credentials into the GoogleSearchTool.
            # 3. Finally, call invoke; the parameters must follow the parameter rules configured in GoogleSearchTool's yaml.
            GoogleSearchTool().fork_tool_runtime(
                meta={
                    "credentials": credentials,
                }
            ).invoke(
                user_id='',
                tool_parameters={
                    "query": "test",
                    "result_type": "link"
                },
            )
        except Exception as e:
            raise ToolProviderCredentialValidationError(str(e))
```

## Done

Once the steps above are complete, the tool appears on the frontend and can be used in an Agent.

Of course, since google_search requires credentials, you need to enter them on the frontend before using it.

![Alt text](images/index/image-2.png)
@@ -1,283 +0,0 @@

# Advanced Tool Integration

Before starting this advanced guide, make sure you have read [Quick Tool Integration](./tool_scale_out.md) and have a basic understanding of Dify's tool integration process.

## Tool interfaces

The `Tool` class defines a series of helper methods to help developers quickly build more complex tools.

### Returning messages

Dify supports several message types such as `text`, `link`, `image`, `file BLOB`, and `JSON`. You can return different types of messages to the LLM and the user through the following interfaces.

Note: some parameters of the interfaces below are explained in later sections.

#### Image URL
Just pass the image URL; Dify downloads the image automatically and returns it to the user.

```python
def create_image_message(self, image: str, save_as: str = '') -> ToolInvokeMessage:
    """
    create an image message

    :param image: the url of the image
    :param save_as: save as
    :return: the image message
    """
```

#### Link
If you need to return a link, use this interface.

```python
def create_link_message(self, link: str, save_as: str = '') -> ToolInvokeMessage:
    """
    create a link message

    :param link: the url of the link
    :param save_as: save as
    :return: the link message
    """
```

#### Text
If you need to return a text message, use this interface.

```python
def create_text_message(self, text: str, save_as: str = '') -> ToolInvokeMessage:
    """
    create a text message

    :param text: the text of the message
    :param save_as: save as
    :return: the text message
    """
```

#### File BLOB
If you need to return a file's raw data, such as an image, audio, video, PPT, Word, or Excel file, use this interface.

- `blob` the raw bytes of the file
- `meta` the file's metadata; if you know the file's type, pass a `mime_type`, otherwise Dify falls back to `application/octet-stream`

```python
def create_blob_message(self, blob: bytes, meta: dict = None, save_as: str = '') -> ToolInvokeMessage:
    """
    create a blob message

    :param blob: the blob
    :param meta: meta
    :param save_as: save as
    :return: the blob message
    """
```
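For example, a sketch of a tool step that returns generated PNG bytes (the `render_chart` helper is hypothetical and only stands in for whatever produces the bytes):

```python
# Sketch: return raw PNG bytes as a blob message with an explicit mime_type.
png_bytes = render_chart(tool_parameters)  # hypothetical helper returning bytes
return self.create_blob_message(
    blob=png_bytes,
    meta={'mime_type': 'image/png'},  # omit this and Dify falls back to application/octet-stream
)
```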

#### JSON
If you need to return formatted JSON, use this interface. It is usually used to pass data between nodes in a workflow, although in agent mode most models can also read and understand JSON.

- `object` a Python dict, serialized to JSON automatically

```python
def create_json_message(self, object: dict) -> ToolInvokeMessage:
    """
    create a json message
    """
```
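A sketch of returning structured output for a downstream workflow node; the field names are illustrative rather than a fixed schema:

```python
# Sketch: return structured data so a downstream node can pick out individual fields.
return self.create_json_message({
    'query': query,
    'total': len(items),
    'items': [{'title': item['title'], 'url': item['url']} for item in items],
})
```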

### Shortcut tools

In LLM applications there are two common needs:

- Summarize a long text up front and pass the summary to the LLM, so the original text does not exceed what the LLM can handle
- A tool returns a link, and the page behind it needs to be crawled before the content is handed back to the LLM

To help developers implement these two needs quickly, we provide the following two shortcut methods.

#### Text summary tool

This method takes a user_id and the text to summarize and returns the summarized text; Dify uses the default model of the current workspace to summarize the long text.

```python
def summary(self, user_id: str, content: str) -> str:
    """
    summary the content

    :param user_id: the user id
    :param content: the content
    :return: the summary
    """
```
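A typical use, sketched (the length threshold is arbitrary and only for illustration):

```python
# Sketch: shrink an over-long text before handing it back to the LLM.
if len(content) > 4000:  # illustrative threshold
    content = self.summary(user_id=user_id, content=content)
return self.create_text_message(text=content)
```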

#### Web crawler tool

This method takes the URL to crawl and a user_agent (which may be empty) and returns a string containing the page content. `user_agent` is optional and can be used to identify your tool; if it is not passed, Dify uses its default `user_agent`.

```python
def get_url(self, url: str, user_agent: str = None) -> str:
    """
    get url from the crawled result
    """
```
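Combining the two helpers, a sketch of a tool that receives a link, crawls it, and returns a condensed result (the user agent string is made up):

```python
# Sketch: crawl a page, then summarize it so the context stays small.
page_text = self.get_url(url, user_agent='DifyExampleBot/0.1')  # user_agent is optional
summary_text = self.summary(user_id=user_id, content=page_text)
return self.create_text_message(text=summary_text)
```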

### Variable pool

The `Tool` class includes a variable pool that stores variables, files, and other data produced while a tool runs; these variables can then be used by other tools during the same run.

Below we use `DallE3` and `Vectorizer.AI` as an example of how to use the variable pool.

- `DallE3` is an image generation tool that generates images from text. Here we have `DallE3` generate a logo for a coffee shop.
- `Vectorizer.AI` is a vectorization tool that converts images into vector graphics, so an image can be scaled indefinitely without losing quality. Here we convert the PNG icon generated by `DallE3` into a vector image that designers can actually use.

#### DallE3
First we use DallE3. After the image is created, we save it into the variable pool. The code is as follows:

```python
from typing import Any, Dict, List, Union
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

from base64 import b64decode

from openai import OpenAI


class DallE3Tool(BuiltinTool):
    def _invoke(self,
                user_id: str,
                tool_parameters: Dict[str, Any],
                ) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
        """
        invoke tools
        """
        client = OpenAI(
            api_key=self.runtime.credentials['openai_api_key'],
        )

        # prompt
        prompt = tool_parameters.get('prompt', '')
        if not prompt:
            return self.create_text_message('Please input prompt')

        # call openai dall-e-3
        response = client.images.generate(
            prompt=prompt, model='dall-e-3',
            size='1024x1024', n=1, style='vivid', quality='standard',
            response_format='b64_json'
        )

        result = []
        for image in response.data:
            # Save every image into the variable pool via the save_as parameter, under the
            # variable name self.VARIABLE_KEY.IMAGE.value. If another image is generated
            # later, it overwrites the previous one.
            result.append(self.create_blob_message(blob=b64decode(image.b64_json),
                                                   meta={ 'mime_type': 'image/png' },
                                                   save_as=self.VARIABLE_KEY.IMAGE.value))

        return result
```

Note that we use `self.VARIABLE_KEY.IMAGE.value` as the image's variable name. This `KEY` is defined so that different developers' tools can cooperate with each other; you are free to use it or not, and passing a custom key also works.

#### Vectorizer.AI
Next we use Vectorizer.AI to convert the PNG icon generated by DallE3 into a vector image. Let's first walk through the methods we define here; the code is as follows:

```python
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.errors import ToolProviderCredentialValidationError

from typing import Any, Dict, List, Union
from httpx import post
from base64 import b64decode


class VectorizerTool(BuiltinTool):
    def _invoke(self, user_id: str, tool_parameters: Dict[str, Any]) \
        -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
        """
        Invoke the tool. The image variable name is passed in here, so we can fetch the
        image from the variable pool.
        """

    def get_runtime_parameters(self) -> List[ToolParameter]:
        """
        Override the tool's parameter list. We can generate the parameter list dynamically
        based on what is currently in the variable pool, so the LLM can build a form from it.
        """

    def is_tool_available(self) -> bool:
        """
        Whether the tool is currently available. If there is no image in the variable pool,
        we do not need to show this tool, so returning False is enough here.
        """
```

Next, let's implement these three methods:

```python
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.errors import ToolProviderCredentialValidationError

from typing import Any, Dict, List, Union
from httpx import post
from base64 import b64decode


class VectorizerTool(BuiltinTool):
    def _invoke(self, user_id: str, tool_parameters: Dict[str, Any]) \
        -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
        """
        invoke tools
        """
        api_key_name = self.runtime.credentials.get('api_key_name', None)
        api_key_value = self.runtime.credentials.get('api_key_value', None)

        if not api_key_name or not api_key_value:
            raise ToolProviderCredentialValidationError('Please input api key name and value')

        # Get image_id; the definition of image_id can be found in get_runtime_parameters.
        image_id = tool_parameters.get('image_id', '')
        if not image_id:
            return self.create_text_message('Please input image id')

        # Fetch the image previously generated by DallE from the variable pool.
        image_binary = self.get_variable_file(self.VARIABLE_KEY.IMAGE)
        if not image_binary:
            return self.create_text_message('Image not found, please request user to generate image firstly.')

        # Vectorize the image.
        response = post(
            'https://vectorizer.ai/api/v1/vectorize',
            files={ 'image': image_binary },
            data={ 'mode': 'test' },
            auth=(api_key_name, api_key_value),
            timeout=30
        )

        if response.status_code != 200:
            raise Exception(response.text)

        return [
            self.create_text_message('the vectorized svg is saved as an image.'),
            self.create_blob_message(blob=response.content,
                                     meta={'mime_type': 'image/svg+xml'})
        ]

    def get_runtime_parameters(self) -> List[ToolParameter]:
        """
        override the runtime parameters
        """
        # Here we override the tool's parameter list: we define image_id and set its options
        # to all images currently in the variable pool. This configuration is consistent with
        # the one in the yaml.
        return [
            ToolParameter.get_simple_instance(
                name='image_id',
                llm_description=f'the image id that you want to vectorize, \
                    and the image id should be specified in \
                        {[i.name for i in self.list_default_image_variables()]}',
                type=ToolParameter.ToolParameterType.SELECT,
                required=True,
                options=[i.name for i in self.list_default_image_variables()]
            )
        ]

    def is_tool_available(self) -> bool:
        # The LLM only needs to use this tool when there is an image in the variable pool.
        return len(self.list_default_image_variables()) > 0
```

Note that we do not actually use `image_id` here. We assume that whenever this tool is called there is already an image in the default variable pool, so we fetch it directly with `image_binary = self.get_variable_file(self.VARIABLE_KEY.IMAGE)`. When the model's capabilities are weak, we recommend developers do the same: it effectively improves fault tolerance and prevents the model from passing wrong parameters.
@@ -1,237 +0,0 @@

# Quick Tool Integration

Here we use GoogleSearch as an example to show how to quickly integrate a tool.

## 1. Prepare the tool provider yaml

### Overview
This yaml contains the tool provider's information, including the provider name, icon, author, and other details, so the frontend can display it flexibly.

### Example

Create a `google` module (folder) under `core/tools/provider/builtin` and create `google.yaml` inside it; the file name must match the module name.

From here on, all work on this tool happens inside this module.

```yaml
identity: # basic information about the tool provider
  author: Dify # author
  name: google # name, unique, must not clash with other providers
  label: # label for frontend display
    en_US: Google # English label
    zh_Hans: Google # Chinese label
  description: # description for frontend display
    en_US: Google # English description
    zh_Hans: Google # Chinese description
  icon: icon.svg # icon, placed in the _assets folder of the current module
  tags: # tags, used for frontend display
    - search
```

- The `identity` field is required; it contains the provider's basic information such as author, name, label, description, and icon.
- The icon must be placed in the `_assets` folder of the current module; see [here](../../provider/builtin/google/_assets/icon.svg).
- Tags are used for frontend display and help users find this tool provider quickly. All currently supported tags are listed below:

```python
class ToolLabelEnum(Enum):
    SEARCH = 'search'
    IMAGE = 'image'
    VIDEOS = 'videos'
    WEATHER = 'weather'
    FINANCE = 'finance'
    DESIGN = 'design'
    TRAVEL = 'travel'
    SOCIAL = 'social'
    NEWS = 'news'
    MEDICAL = 'medical'
    PRODUCTIVITY = 'productivity'
    EDUCATION = 'education'
    BUSINESS = 'business'
    ENTERTAINMENT = 'entertainment'
    UTILITIES = 'utilities'
    OTHER = 'other'
```

## 2. Prepare provider credentials

Google is a third-party tool that uses the API provided by SerpApi, and SerpApi requires an API key. In other words, this tool needs credentials before it can be used, whereas a tool like `wikipedia` needs no credential fields at all; see [here](../../provider/builtin/wikipedia/wikipedia.yaml).

With the credential fields configured, the yaml looks like this:

```yaml
identity:
  author: Dify
  name: google
  label:
    en_US: Google
    zh_Hans: Google
  description:
    en_US: Google
    zh_Hans: Google
  icon: icon.svg
credentials_for_provider: # credential fields
  serpapi_api_key: # credential field name
    type: secret-input # credential field type
    required: true # whether it is required
    label: # credential field label
      en_US: SerpApi API key # English label
      zh_Hans: SerpApi API key # Chinese label
    placeholder: # credential field placeholder
      en_US: Please input your SerpApi API key # English placeholder
      zh_Hans: 请输入你的 SerpApi API key # Chinese placeholder
    help: # credential field help text
      en_US: Get your SerpApi API key from SerpApi # English help text
      zh_Hans: 从 SerpApi 获取您的 SerpApi API key # Chinese help text
    url: https://serpapi.com/manage-api-key # credential field help link
```

- `type`: the credential field type. Three types are currently supported: `secret-input`, `text-input`, and `select`, corresponding to a password input box, a text input box, and a dropdown. For `secret-input`, the frontend hides what is typed and the backend encrypts the stored value; a sketch of a plain `text-input` field follows below.
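For comparison, a non-secret field simply uses `text-input` with the same surrounding fields; the `instance_url` credential below is invented purely for illustration:

```yaml
credentials_for_provider:
  instance_url:              # hypothetical credential field, for illustration only
    type: text-input         # shown as typed, not encrypted
    required: false
    label:
      en_US: Instance URL
      zh_Hans: 实例地址
    placeholder:
      en_US: e.g. https://example.com
      zh_Hans: 例如 https://example.com
    help:
      en_US: The base URL of your self-hosted service
      zh_Hans: 自建服务的基础地址
```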

## 3. Prepare the tool yaml
A provider can contain multiple tools, and each tool needs a yaml file describing its basic information, parameters, output, and so on.

Staying with the GoogleSearch example, create a `tools` module under the `google` module and add `tools/google_search.yaml` with the following content:

```yaml
identity: # basic information about the tool
  name: google_search # tool name, unique, must not clash with other tools
  author: Dify # author
  label: # label for frontend display
    en_US: GoogleSearch # English label
    zh_Hans: 谷歌搜索 # Chinese label
description: # description for frontend display
  human: # introduction for frontend display, multi-language supported
    en_US: A tool for performing a Google SERP search and extracting snippets and webpages. Input should be a search query.
    zh_Hans: 一个用于执行 Google SERP 搜索并提取片段和网页的工具。输入应该是一个搜索查询。
  llm: A tool for performing a Google SERP search and extracting snippets and webpages. Input should be a search query. # introduction passed to the LLM; write as much detail as possible here so the LLM can understand and use the tool.
parameters: # parameter list
  - name: query # parameter name
    type: string # parameter type
    required: true # whether it is required
    label: # parameter label
      en_US: Query string # English label
      zh_Hans: 查询语句 # Chinese label
    human_description: # introduction for frontend display, multi-language supported
      en_US: used for searching
      zh_Hans: 用于搜索网页内容
    llm_description: key words for searching # introduction passed to the LLM; as above, write as much detail as possible here so the LLM can understand this parameter.
    form: llm # form type; llm means this parameter must be inferred by the Agent, and the frontend will not show it.
  - name: result_type
    type: select # parameter type
    required: true
    options: # dropdown options
      - value: text
        label:
          en_US: text
          zh_Hans: 文本
      - value: link
        label:
          en_US: link
          zh_Hans: 链接
    default: link
    label:
      en_US: Result type
      zh_Hans: 结果类型
    human_description:
      en_US: used for selecting the result type, text or link
      zh_Hans: 用于选择结果类型,使用文本还是链接进行展示
    form: form # form type; form means this parameter must be filled in by the user on the frontend before the conversation starts.
```

- The `identity` field is required; it contains the tool's basic information such as name, author, label, and description.
- `parameters` parameter list
  - `name` (required) parameter name, unique, must not clash with other parameters
  - `type` (required) parameter type. Five types are currently supported: `string`, `number`, `boolean`, `select`, and `secret-input`, corresponding to string, number, boolean, dropdown, and encrypted input box. Use `secret-input` for sensitive information.
  - `label` (required) parameter label, used for frontend display
  - `form` (required) form type. Two types are currently supported: `llm` and `form`.
    - In an Agent application, `llm` means the parameter is inferred by the LLM itself, while `form` means a parameter that can be set in advance in order to use the tool.
    - In a workflow application, both `llm` and `form` parameters are filled in on the frontend, but `llm` parameters are used as input variables of the tool node.
  - `required` whether the parameter is required
    - In `llm` mode, if the parameter is required, the Agent must infer it.
    - In `form` mode, if the parameter is required, the user must fill it in on the frontend before the conversation starts.
  - `options` parameter options
    - In `llm` mode, Dify passes all options to the LLM, and the LLM can reason over them.
    - In `form` mode, when `type` is `select`, the frontend shows these options.
  - `default` default value
  - `min` minimum value; can be set when the parameter type is `number`
  - `max` maximum value; can be set when the parameter type is `number`
  - `human_description` introduction for frontend display, multi-language supported
  - `placeholder` prompt text for the input box; can be set when the form type is `form` and the parameter type is `string`, `number`, or `secret-input`; multi-language supported
  - `llm_description` introduction passed to the LLM; write as much detail as possible here so the LLM can understand this parameter.


## 4. Prepare the tool code
Once the tool is configured, write the tool code that implements its logic.

Create `google_search.py` under the `google/tools` module with the following content:

```python
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage

from typing import Any, Dict, List, Union


class GoogleSearchTool(BuiltinTool):
    def _invoke(self,
                user_id: str,
                tool_parameters: Dict[str, Any],
                ) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
        """
        invoke tools
        """
        query = tool_parameters['query']
        result_type = tool_parameters['result_type']
        api_key = self.runtime.credentials['serpapi_api_key']
        # SerpAPI is a small helper around the SerpApi HTTP API; define or import it
        # in this module (its implementation is omitted in this guide).
        result = SerpAPI(api_key).run(query, result_type=result_type)

        if result_type == 'text':
            return self.create_text_message(text=result)
        return self.create_link_message(link=result)
```

### Parameters
The tool's overall logic lives in the `_invoke` method. It takes two parameters, `user_id` and `tool_parameters`, which are the user ID and the tool parameters respectively.

### Return value
A tool may return either a single message or multiple messages; here we return a single message. Use `create_text_message` or `create_link_message` to create a text or link message. To return multiple messages, build a list, for example `[self.create_text_message('msg1'), self.create_text_message('msg2')]`.
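Since `llm` parameters are inferred by the model, a sketch of reading them defensively rather than indexing directly (not part of the original example):

```python
# Sketch: read model-inferred parameters defensively and fall back to sensible defaults.
query = tool_parameters.get('query', '').strip()
if not query:
    return self.create_text_message(text='Please provide a search query.')
result_type = tool_parameters.get('result_type', 'link')  # default mirrors the yaml above
```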

## 5. Prepare the provider code
Finally, create a provider class under the provider module and implement the provider's credential validation logic. If credential validation fails, a `ToolProviderCredentialValidationError` exception is raised.

Create `google.py` under the `google` module with the following content:

```python
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
from core.tools.errors import ToolProviderCredentialValidationError

from core.tools.provider.builtin.google.tools.google_search import GoogleSearchTool

from typing import Any, Dict


class GoogleProvider(BuiltinToolProviderController):
    def _validate_credentials(self, credentials: Dict[str, Any]) -> None:
        try:
            # 1. Instantiate a GoogleSearchTool with GoogleSearchTool(); this automatically loads the tool's yaml configuration, but at this point it holds no credentials.
            # 2. Then use the fork_tool_runtime method to pass the current credentials into the GoogleSearchTool.
            # 3. Finally, call invoke; the parameters must follow the parameter rules configured in GoogleSearchTool's yaml.
            GoogleSearchTool().fork_tool_runtime(
                meta={
                    "credentials": credentials,
                }
            ).invoke(
                user_id='',
                tool_parameters={
                    "query": "test",
                    "result_type": "link"
                },
            )
        except Exception as e:
            raise ToolProviderCredentialValidationError(str(e))
```

## Done
Once the steps above are complete, the tool appears on the frontend and can be used in an Agent.

Of course, since google_search requires credentials, you need to configure them on the frontend before using it.

![Alt text](images/index/image-2.png)
@@ -63,11 +63,18 @@ class ToolFileManager:
         conversation_id: Optional[str],
         file_binary: bytes,
         mimetype: str,
+        filename: Optional[str] = None,
     ) -> ToolFile:
         extension = guess_extension(mimetype) or ".bin"
         unique_name = uuid4().hex
-        filename = f"{unique_name}{extension}"
-        filepath = f"tools/{tenant_id}/{filename}"
+        unique_filename = f"{unique_name}{extension}"
+        # default just as before
+        present_filename = unique_filename
+        if filename is not None:
+            has_extension = len(filename.split(".")) > 1
+            # Add extension flexibly
+            present_filename = filename if has_extension else f"{filename}{extension}"
+        filepath = f"tools/{tenant_id}/{unique_filename}"
         storage.save(filepath, file_binary)

         tool_file = ToolFile(
@@ -76,7 +83,7 @@ class ToolFileManager:
             conversation_id=conversation_id,
             file_key=filepath,
             mimetype=mimetype,
-            name=filename,
+            name=present_filename,
             size=len(file_binary),
         )

@@ -765,17 +765,22 @@ class ToolManager:

     @classmethod
     def generate_builtin_tool_icon_url(cls, provider_id: str) -> str:
-        return (
-            dify_config.CONSOLE_API_URL
-            + "/console/api/workspaces/current/tool-provider/builtin/"
-            + provider_id
-            + "/icon"
+        return str(
+            URL(dify_config.CONSOLE_API_URL or "/")
+            / "console"
+            / "api"
+            / "workspaces"
+            / "current"
+            / "tool-provider"
+            / "builtin"
+            / provider_id
+            / "icon"
         )

     @classmethod
     def generate_plugin_tool_icon_url(cls, tenant_id: str, filename: str) -> str:
         return str(
-            URL(dify_config.CONSOLE_API_URL)
+            URL(dify_config.CONSOLE_API_URL or "/")
             / "console"
             / "api"
             / "workspaces"
@@ -59,6 +59,8 @@ class ToolFileMessageTransformer:
                 meta = message.meta or {}

                 mimetype = meta.get("mime_type", "application/octet-stream")
+                # get filename from meta
+                filename = meta.get("file_name", None)
                 # if message is str, encode it to bytes

             if not isinstance(message.message, ToolInvokeMessage.BlobMessage):
@@ -72,6 +74,7 @@
                 conversation_id=conversation_id,
                 file_binary=message.message.blob,
                 mimetype=mimetype,
+                filename=filename,
             )

             url = cls.get_tool_file_url(tool_file_id=file.id, extension=guess_extension(file.mimetype))