Compare commits

...

109 Commits

Author SHA1 Message Date
ac80c04bd3 chore: bump version to 1.1.0 (#16128)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-03-18 20:01:24 +08:00
fa9b767bf2 fix chatflow metadata field name (#16130) 2025-03-18 19:40:42 +08:00
abeaea4f79 Support knowledge metadata filter (#15982) 2025-03-18 16:42:19 +08:00
b65f2eb55f fix embedding model name translate issue (#16111) 2025-03-18 16:41:35 +08:00
7d620ffd5e Feat:app list dark mode (#16110) 2025-03-18 16:21:53 +08:00
6f6ba2f025 fix(api): enhance provider model records handling for missing langgenius providers (#16089) 2025-03-18 15:07:53 +08:00
33ba7e659b fix vector db sql injection (#16096) 2025-03-18 15:07:29 +08:00
750ec55646 doc: auto correct the doc using autocorrect close #16091 (#16092)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
2025-03-18 14:57:14 +08:00
86d3fff666 fix: respect resolution settings for vision for basic chatbot, text generator, and parameter extractor node (#16041) 2025-03-18 14:37:07 +08:00
e91531fc23 fix: error in migrate_annotation_vector_database when exec vdb-migrate (#15937)
Co-authored-by: crazywoola <427733928@qq.com>
2025-03-18 14:15:48 +08:00
2524f16525 support config filename in meta for create_blob_message (#15605)
Co-authored-by: StoneFancyX <kindbin@qq.com>
Co-authored-by: crazywoola <427733928@qq.com>
2025-03-18 13:59:00 +08:00
cefec44070 feat: add app_mode field to app import and model definitions (#15729)
Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: twwu <twwu@dify.ai>
2025-03-18 11:12:25 +08:00
20376ca951 feat: upgrade knowledge metadata (#16063)
Support filter knowledge by metadata.

Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: NFish <douxc512@gmail.com>
2025-03-18 11:01:06 +08:00
475b8d731e Fix HTTP Request node to give priority to file extension of content-disposition (#12653) 2025-03-18 11:00:20 +08:00
963b6f628a Chore: PromptMessage is not an abstract base class (#15965) 2025-03-18 10:57:52 +08:00
63ea6f1ecf Fixed: Run failed: Failed to invoke tool: File.__init__() got an unexpected keyword argument (#14073)
Co-authored-by: hobo.l <hobo.l@binance.com>
2025-03-18 10:55:58 +08:00
947c9f70fb fix: improve InputNumber component step behavior and disabled state (#16044) 2025-03-18 10:42:29 +08:00
5e52d4d6b3 feat: add Maximum number of Parallelism branches to env (#15964)
Co-authored-by: Xiaoba Yu <xb1823725853@gmail.com>
2025-03-18 09:32:47 +08:00
939dcb4c0a chore: enhance ListWrapper and PluginPage components with stable scro… (#16048) 2025-03-18 09:12:49 +08:00
223ab5a38f feat: support openGauss vector database (#15865) 2025-03-17 19:42:54 +08:00
db7a37a111 fix: adjust position of table of contents in Doc component (#15996) 2025-03-17 19:37:21 +08:00
fe0d932f50 fix: fail-branch stream output error (#13401)
Co-authored-by: Novice Lee <novicelee@NoviPro.local>
2025-03-17 19:35:37 +08:00
69fb0a4a28 chore: use POSIX shell syntax in pre-commit script (#16025) 2025-03-17 19:28:25 +08:00
04a0ae3aa9 feat: add llm blocking invoke (#15732) 2025-03-17 16:47:10 +08:00
e5d6047fb4 chore(api): Disable preview rules of Ruff while running pre-commit hook (#15999) 2025-03-17 16:40:27 +08:00
9e782d4c1e chore: bump ruff to 0.11.0 and fix linting violations (#15953) 2025-03-17 16:13:11 +08:00
98a4b3e78b fix: typo when assign doc_metadata when non-empty (#15975) 2025-03-17 14:14:07 +08:00
2b4d1cf1db fix(api): fix fail branch functionality for WorkflowTool (#15966) 2025-03-17 11:53:32 +08:00
fe76dfe1f8 When decrypt_trace_config is empty, it should be skipped directly (#15870) 2025-03-17 11:29:20 +08:00
c3774bef7e fix: api error of get all workspaces (#15880) 2025-03-17 11:22:27 +08:00
695a7400a9 fix:delete empty table bug (#15517)
Co-authored-by: huangzhuo <huangzhuo1@xiaomi.com>
2025-03-17 10:53:26 +08:00
e6a8800f66 fix: validation for upload methods of non-image files within the work… (#15932) 2025-03-17 09:50:21 +08:00
cee8731393 fix:Nginx template not replace env correctly (#15651) 2025-03-16 11:19:09 +08:00
4ae94dc027 Chore: fix wrong annotations (#15871) 2025-03-16 11:16:28 +08:00
3a69a6a452 Fix/enable marketplace bug (#15895) 2025-03-16 11:14:12 +08:00
f8f21ef7c0 fix: node use vision model may caused page crash (#15921) 2025-03-16 08:54:18 +08:00
0587eb4956 FIX:microsoft word text copy and paste error (#14905)
Co-authored-by: LinYing <linying@momenta.ai>
2025-03-14 18:31:20 +08:00
433374abea Chore: remove unused fields (#15764) 2025-03-14 18:13:25 +08:00
23ed3a520b chore(api): improve type hints for BaseNode and its subclasses (#15826) 2025-03-14 18:09:11 +08:00
5646442931 fix: iteration total tokens calculate error (#15813)
Co-authored-by: 刘江波 <jiangbo721@163.com>
2025-03-14 17:44:24 +08:00
1a6298b6ea fix: Remove any extra Spaces in the title (#15841) 2025-03-14 17:12:29 +08:00
bf9b572bc3 fix tool selector with empty tools raise error (#15829) 2025-03-14 16:47:52 +08:00
cf72e53a10 chore: remove useless doc and font (#15838) 2025-03-14 16:47:42 +08:00
98bd79f548 fix: update Knowledge Api doc: 【Update a Chunk in a Document】 (#15823) 2025-03-14 16:45:20 +08:00
84a866028a fix document could be None (#15818) 2025-03-14 16:40:01 +08:00
10bd03611c Fix style of opening statement (#15821) 2025-03-14 15:50:28 +08:00
7c27d4b202 feat: add Http Request Node to skip ssl verify function #15177 (#15664) 2025-03-14 10:05:37 +08:00
8165d0b469 fix: http_request node form-data support array[file] (#15731) 2025-03-14 09:58:18 +08:00
e796937d02 feat: add keyboard shortcuts support for dialog confirmation (#15752) 2025-03-13 21:42:53 +08:00
49c952a631 fix: streamline file upload configuration handling in manager.py (#15714)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-03-13 16:32:49 +08:00
5f9d236d22 Feat: Add pg_bigm for keyword search in pgvector (#13876)
Signed-off-by: Yuichiro Utsumi <utsumi.yuichiro@fujitsu.com>
2025-03-13 16:32:34 +08:00
59f5a82261 fix: Resolve errors in SQL queries caused by SELECT fields not appearing in the GROUP BY clause. (#15659)
Co-authored-by: yuhang2.zhang <yuhang2.zhang@ly.com>
2025-03-13 16:06:42 +08:00
f22a1adb8b fix: Integration langfuse, front-end error( #15695) (#15709)
Co-authored-by: Xiaoba Yu <xb1823725853@gmail.com>
2025-03-13 15:43:41 +08:00
a8e8c37fdd improve text split (#15719) 2025-03-13 15:29:33 +08:00
37486a9cc6 fix: update default github star count value (#15708) 2025-03-13 14:39:26 +08:00
efebbffe96 Fix:webapp UI issues (#15601) 2025-03-13 14:23:41 +08:00
5e035a4209 Ci/deploy enterprise (#15699) 2025-03-13 02:22:21 -04:00
12fa517297 fix: if-else-node handles missing optional file variables (#15693) 2025-03-13 13:11:49 +08:00
36ae0e5476 fix: set score_threshold only when score_threshold_enabled is true. (#14221) 2025-03-12 20:55:57 +08:00
74f66d3119 Update .env.example to fix MILVUS_URI default value (#13140)
Signed-off-by: ChengZi <chen.zhang@zilliz.com>
Co-authored-by: ChengZi <chen.zhang@zilliz.com>
2025-03-12 20:31:45 +08:00
Lam
adfaee7ab5 fix: prevent AppIconPicker click event from propagating (#15575) (#15647) 2025-03-12 20:03:09 +08:00
d37490adc3 fix dataset reranking mode miss (#15643) 2025-03-12 18:44:10 +08:00
087bb60b31 fix: preserve Unicode characters in keyword search queries (#15522)
Signed-off-by: kenwoodjw <blackxin55+@gmail.com>
2025-03-12 18:34:42 +08:00
5019547d33 fix: can not test custom tool (#15606) 2025-03-12 16:34:56 +08:00
Joe
58f012f3de fix: no attribute error (#15597) 2025-03-12 15:27:42 +08:00
b938c9b7f6 fix: trace return null cause page crash (#15588) 2025-03-12 14:40:43 +08:00
2b1facc7a6 fix: set marketplace feature to false in feature_service.py (#15578) 2025-03-12 14:13:41 +08:00
1d5ea80a2b feat: env MAX_TOOLS_NUM (#15431)
Co-authored-by: crazywoola <427733928@qq.com>
2025-03-12 12:57:05 +08:00
0415cc209d chore: use TenantAccountRole instead of TenantAccountJoinRole (#15514)
Co-authored-by: 刘江波 <jiangbo721@163.com>
2025-03-12 12:56:30 +08:00
Joe
545e5cbcd6 fix: dataset editor (#15218) 2025-03-12 12:51:00 +08:00
1fab02c25a fix:message api doc (#15568)
Co-authored-by: mars <linjx2@by-health.com>
2025-03-12 12:38:23 +08:00
258736f505 chore: remove unused parameter (#15558) 2025-03-12 12:09:39 +08:00
Lam
0bc4da38fc feat: add debounced enter key submission to install form (#15445) (#15542) 2025-03-12 11:25:54 +08:00
037f200527 fix: invoke_error is not callable (#15555)
Co-authored-by: 刘江波 <jiangbo721@163.com>
2025-03-12 10:58:44 +08:00
b541792465 fix: workflow loop node break conditions (#15549) 2025-03-12 10:10:51 +08:00
eb9b256ee8 fix: remove size prop in PlanBadge component because UpgradeBtn size … (#15544) 2025-03-12 09:49:15 +08:00
5d8b32a249 feat: add click-away and mounting logic to agent setting component (#15521) 2025-03-11 22:23:06 +08:00
c960b364c9 chore: update opendal version (#14343)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
2025-03-11 20:44:09 +08:00
b817036343 fix: nesting of conditional branches causing streaming output error (#14065) 2025-03-11 20:30:03 +08:00
46036e6ce6 fix: update version to 1.0.1 in configuration and Docker files (#15478)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-03-11 18:50:42 +08:00
1ffda0dd34 fix notion page display (#15508) 2025-03-11 18:40:02 +08:00
da01b460fe support workspace billing info (#15510) 2025-03-11 18:38:23 +08:00
90a1508b87 fix: update placeholders in version info modal to indicate optional field (#15499) 2025-03-11 18:30:47 +08:00
b07016113c fix: add animation to workflow process loader icon (#15497) 2025-03-11 18:04:58 +08:00
d8317fcf81 fix: remove unnecessary modal (#15493) 2025-03-11 17:18:23 +08:00
a6bc642721 refactor: optimize provider configuration queries with provider name … (#15491) 2025-03-11 17:09:51 +08:00
b730f243dc fix: displan badge based on workspace plan (#15489) 2025-03-11 17:01:17 +08:00
71a57275ab fix: improve selection of variable in workflow (#15484)
Signed-off-by: Yuichiro Utsumi <utsumi.yuichiro@fujitsu.com>
2025-03-11 16:57:45 +08:00
41bf8d925f fix:To fix the issue of missing reference to body parameter (#15443)
Co-authored-by: crazywoola <427733928@qq.com>
2025-03-11 16:16:53 +08:00
6d172498d1 Update the provider_id validation to fix the error message displayed … (#15466)
Co-authored-by: Kyle Chang <kylechang@91app.com>
2025-03-11 16:11:24 +08:00
cad58658c2 fix: simplify S3 client configuration by removing redundant checksum settings (#15474)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-03-11 14:50:03 +08:00
a58b990855 fix agent_execution_metadata (#15444) 2025-03-11 14:35:08 +08:00
b6b1903a37 fix: fix chatbot publish and restore handling (#15462) 2025-03-11 13:36:45 +08:00
ed5596a8f4 fix: avoid llm node result var not init issue while do retry. (#14286) 2025-03-11 12:43:24 +08:00
49d0acd188 fix: replace old-style <br> tags to fix Mermaid rendering issues (#13792) 2025-03-11 12:40:55 +08:00
58a74fe1fb chore: add comment to the PLUGIN_DIFY_INNER_API_KEY key (#15381) 2025-03-11 00:25:11 +08:00
a1ab4aec3d fix db migration (#15422) 2025-03-11 00:24:57 +08:00
f77f7e1437 fix text split (#15426) 2025-03-11 00:24:27 +08:00
adda049265 fix kb permission (#15199)
Signed-off-by: kenwoodjw <blackxin55@gmail.com>
Signed-off-by: kenwoodjw <blackxin55+@gmail.com>
2025-03-10 23:47:45 +08:00
9b2a9260ef Feat/new saas billing (#14996) 2025-03-10 19:50:11 +08:00
Joe
c8cc31af88 fix: app trace permission (#15397) 2025-03-10 18:45:25 +08:00
d333de274f chore(.github): add a new tracker template (#15391)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-03-10 18:39:35 +08:00
9e220d5d30 Feat: configure dark mode legacy (#15394) 2025-03-10 16:41:06 +08:00
2cf0cb471f fix: fix document list overlap and optimize document list fetching (#15377) 2025-03-10 15:34:40 +08:00
269ba6add9 fix: remove port expose on db (#15286) 2025-03-10 15:01:34 +08:00
78d460a6d1 Feat: time period filter for workflow logs (#14271)
Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
2025-03-10 14:02:58 +08:00
3254018ddb feat(workflow_service): workflow version control api. (#14860)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-03-10 13:34:31 +08:00
f2b7df94d7 fix: return absolute path as the icon url if CONSOLE_API_URL is empty (#15279) 2025-03-10 13:15:06 +08:00
59fd3aad31 feat: add PIP_MIRROR_URL environment variable support (#15353) 2025-03-10 12:59:31 +08:00
355 changed files with 11272 additions and 4226 deletions

.github/ISSUE_TEMPLATE/tracker.yml (new file, +13 lines)
View File

@ -0,0 +1,13 @@
name: "👾 Tracker"
description: For inner usages, please donot use this template.
title: "[Tracker] "
labels:
- tracker
body:
- type: textarea
id: content
attributes:
label: Blockers
placeholder: "- [ ] ..."
validations:
required: true

View File

@ -5,6 +5,7 @@ on:
branches:
- "main"
- "deploy/dev"
- "deploy/enterprise"
release:
types: [published]

.github/workflows/deploy-enterprise.yml (new file, +29 lines)
View File

@ -0,0 +1,29 @@
name: Deploy Enterprise
permissions:
contents: read
on:
workflow_run:
workflows: ["Build and Push API & Web"]
branches:
- "deploy/enterprise"
types:
- completed
jobs:
deploy:
runs-on: ubuntu-latest
if: |
github.event.workflow_run.conclusion == 'success' &&
github.event.workflow_run.head_branch == 'deploy/enterprise'
steps:
- name: Deploy to server
uses: appleboy/ssh-action@v0.1.8
with:
host: ${{ secrets.ENTERPRISE_SSH_HOST }}
username: ${{ secrets.ENTERPRISE_SSH_USER }}
password: ${{ secrets.ENTERPRISE_SSH_PASSWORD }}
script: |
${{ vars.ENTERPRISE_SSH_SCRIPT || secrets.ENTERPRISE_SSH_SCRIPT }}

View File

@ -10,5 +10,6 @@ yq eval '.services["elasticsearch"].ports += ["9200:9200"]' -i docker/docker-com
yq eval '.services.couchbase-server.ports += ["8091-8096:8091-8096"]' -i docker/docker-compose.yaml
yq eval '.services.couchbase-server.ports += ["11210:11210"]' -i docker/docker-compose.yaml
yq eval '.services.tidb.ports += ["4000:4000"]' -i docker/tidb/docker-compose.yaml
yq eval '.services.opengauss.ports += ["6600:6600"]' -i docker/docker-compose.yaml
echo "Ports exposed for sandbox, weaviate, tidb, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase"
echo "Ports exposed for sandbox, weaviate, tidb, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase, opengauss"

View File

@ -76,6 +76,7 @@ jobs:
milvus-standalone
pgvecto-rs
pgvector
opengauss
chroma
elasticsearch

.gitignore (+3 lines)
View File

@ -202,3 +202,6 @@ api/.vscode
# plugin migrate
plugins.jsonl
# mise
mise.toml

View File

@ -26,7 +26,7 @@
| [@jyong](https://github.com/JohnJyong) | RAG 流水线设计 |
| [@GarfieldDai](https://github.com/GarfieldDai) | 构建 workflow 编排 |
| [@iamjoel](https://github.com/iamjoel) & [@zxhlyh](https://github.com/zxhlyh) | 让我们的前端更易用 |
| [@guchenhe](https://github.com/guchenhe) & [@crazywoola](https://github.com/crazywoola) | 开发人员体验, 综合事项联系人 |
| [@guchenhe](https://github.com/guchenhe) & [@crazywoola](https://github.com/crazywoola) | 开发人员体验综合事项联系人 |
| [@takatost](https://github.com/takatost) | 产品整体方向和架构 |
事项优先级:
@ -47,7 +47,7 @@
| ------------------------------------------------------------ | --------------- |
| 核心功能的 Bugs例如无法登录、应用无法工作、安全漏洞 | 紧急 |
| 非紧急 bugs, 性能提升 | 中等优先级 |
| 小幅修复(错别字, 能正常工作但存在误导的 UI) | 低优先级 |
| 小幅修复 (错别字能正常工作但存在误导的 UI) | 低优先级 |
## 安装

View File

@ -79,7 +79,7 @@ Dify 是一个开源的 LLM 应用开发平台。其直观的界面结合了 AI
广泛的 RAG 功能,涵盖从文档摄入到检索的所有内容,支持从 PDF、PPT 和其他常见文档格式中提取文本的开箱即用的支持。
**5. Agent 智能体**:
您可以基于 LLM 函数调用或 ReAct 定义 Agent并为 Agent 添加预构建或自定义工具。Dify 为 AI Agent 提供了50多种内置工具如谷歌搜索、DALL·E、Stable Diffusion 和 WolframAlpha 等。
您可以基于 LLM 函数调用或 ReAct 定义 Agent并为 Agent 添加预构建或自定义工具。Dify 为 AI Agent 提供了 50 多种内置工具如谷歌搜索、DALL·E、Stable Diffusion 和 WolframAlpha 等。
**6. LLMOps**:
随时间监视和分析应用程序日志和性能。您可以根据生产数据和标注持续改进提示、数据集和模型。
@ -112,7 +112,7 @@ Dify 是一个开源的 LLM 应用开发平台。其直观的界面结合了 AI
<td align="center">仅限 OpenAI</td>
</tr>
<tr>
<td align="center">RAG引擎</td>
<td align="center">RAG 引擎</td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
@ -234,7 +234,7 @@ docker compose up -d
对于那些想要贡献代码的人,请参阅我们的[贡献指南](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)。
同时,请考虑通过社交媒体、活动和会议来支持 Dify 的分享。
> 我们正在寻找贡献者来帮助将Dify翻译成除了中文和英文之外的其他语言。如果您有兴趣帮助请参阅我们的[i18n README](https://github.com/langgenius/dify/blob/main/web/i18n/README.md)获取更多信息,并在我们的[Discord社区服务器](https://discord.gg/8Tpq4AcN9c)的`global-users`频道中留言。
> 我们正在寻找贡献者来帮助将 Dify 翻译成除了中文和英文之外的其他语言。如果您有兴趣帮助,请参阅我们的[i18n README](https://github.com/langgenius/dify/blob/main/web/i18n/README.md)获取更多信息,并在我们的[Discord 社区服务器](https://discord.gg/8Tpq4AcN9c)的`global-users`频道中留言。
**Contributors**

View File

@ -137,7 +137,7 @@ WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
# Vector database configuration
# support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase
# support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase, opengauss
VECTOR_STORE=weaviate
# Weaviate configuration
@ -298,6 +298,14 @@ OCEANBASE_VECTOR_PASSWORD=difyai123456
OCEANBASE_VECTOR_DATABASE=test
OCEANBASE_MEMORY_LIMIT=6G
# openGauss configuration
OPENGAUSS_HOST=127.0.0.1
OPENGAUSS_PORT=6600
OPENGAUSS_USER=postgres
OPENGAUSS_PASSWORD=Dify@123
OPENGAUSS_DATABASE=dify
OPENGAUSS_MIN_CONNECTION=1
OPENGAUSS_MAX_CONNECTION=5
# Upload configuration
UPLOAD_FILE_SIZE_LIMIT=15
@ -378,6 +386,7 @@ HTTP_REQUEST_MAX_READ_TIMEOUT=600
HTTP_REQUEST_MAX_WRITE_TIMEOUT=600
HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760
HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576
HTTP_REQUEST_NODE_SSL_VERIFY=True
# Respect X-* headers to redirect clients
RESPECT_XFORWARD_HEADERS_ENABLED=false
@ -444,4 +453,4 @@ CREATE_TIDB_SERVICE_JOB_ENABLED=false
# Maximum number of submitted thread count in a ThreadPool for parallel node execution
MAX_SUBMIT_COUNT=100
# Lockout duration in seconds
LOGIN_LOCKOUT_DURATION=86400
LOGIN_LOCKOUT_DURATION=86400

View File

@ -56,8 +56,6 @@ RUN \
curl nodejs libgmp-dev libmpfr-dev libmpc-dev \
# For Security
expat libldap-2.5-0 perl libsqlite3-0 zlib1g \
# install a chinese font to support the use of tools like matplotlib
fonts-noto-cjk \
# install a package to improve the accuracy of guessing mime type and file extension
media-types \
# install libmagic to support the use of python-magic guess MIMETYPE

View File

@ -160,11 +160,17 @@ def migrate_annotation_vector_database():
while True:
try:
# get apps info
per_page = 50
apps = (
App.query.filter(App.status == "normal")
db.session.query(App)
.filter(App.status == "normal")
.order_by(App.created_at.desc())
.paginate(page=page, per_page=50)
.limit(per_page)
.offset((page - 1) * per_page)
.all()
)
if not apps:
break
except NotFound:
break
@ -267,6 +273,7 @@ def migrate_knowledge_vector_database():
VectorType.WEAVIATE,
VectorType.ORACLE,
VectorType.ELASTICSEARCH,
VectorType.OPENGAUSS,
}
lower_collection_vector_types = {
VectorType.ANALYTICDB,
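
The first hunk above replaces Flask-SQLAlchemy's `paginate()` with an explicit `limit()`/`offset()` loop that stops when a page comes back empty. A minimal sketch of that batching pattern, using a hypothetical mapped model and session rather than Dify's actual code:

```python
def iter_in_batches(session, model, order_col, per_page=50):
    """Yield rows in fixed-size pages using limit/offset, stopping on an empty page."""
    page = 1
    while True:
        batch = (
            session.query(model)
            .order_by(order_col.desc())        # stable ordering keeps pages consistent
            .limit(per_page)
            .offset((page - 1) * per_page)
            .all()
        )
        if not batch:                          # empty page -> no more rows
            break
        yield from batch
        page += 1
```

Unlike `paginate()`, this issues no extra COUNT query; the empty-page check is what terminates the loop.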

View File

@ -332,6 +332,11 @@ class HttpConfig(BaseSettings):
default=1 * 1024 * 1024,
)
HTTP_REQUEST_NODE_SSL_VERIFY: bool = Field(
description="Enable or disable SSL verification for HTTP requests",
default=True,
)
SSRF_DEFAULT_MAX_RETRIES: PositiveInt = Field(
description="Maximum number of retries for network requests (SSRF)",
default=3,

View File

@ -26,6 +26,7 @@ from .vdb.lindorm_config import LindormConfig
from .vdb.milvus_config import MilvusConfig
from .vdb.myscale_config import MyScaleConfig
from .vdb.oceanbase_config import OceanBaseVectorConfig
from .vdb.opengauss_config import OpenGaussConfig
from .vdb.opensearch_config import OpenSearchConfig
from .vdb.oracle_config import OracleConfig
from .vdb.pgvector_config import PGVectorConfig
@ -281,5 +282,6 @@ class MiddlewareConfig(
LindormConfig,
OceanBaseVectorConfig,
BaiduVectorDBConfig,
OpenGaussConfig,
):
pass

View File

@ -0,0 +1,45 @@
from typing import Optional
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
class OpenGaussConfig(BaseSettings):
"""
Configuration settings for OpenGauss
"""
OPENGAUSS_HOST: Optional[str] = Field(
description="Hostname or IP address of the OpenGauss server(e.g., 'localhost')",
default=None,
)
OPENGAUSS_PORT: PositiveInt = Field(
description="Port number on which the OpenGauss server is listening (default is 6600)",
default=6600,
)
OPENGAUSS_USER: Optional[str] = Field(
description="Username for authenticating with the OpenGauss database",
default=None,
)
OPENGAUSS_PASSWORD: Optional[str] = Field(
description="Password for authenticating with the OpenGauss database",
default=None,
)
OPENGAUSS_DATABASE: Optional[str] = Field(
description="Name of the OpenGauss database to connect to",
default=None,
)
OPENGAUSS_MIN_CONNECTION: PositiveInt = Field(
description="Min connection of the OpenGauss database",
default=1,
)
OPENGAUSS_MAX_CONNECTION: PositiveInt = Field(
description="Max connection of the OpenGauss database",
default=5,
)

View File

@ -43,3 +43,8 @@ class PGVectorConfig(BaseSettings):
description="Max connection of the PostgreSQL database",
default=5,
)
PGVECTOR_PG_BIGM: bool = Field(
description="Whether to use pg_bigm module for full text search",
default=False,
)

View File

@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
CURRENT_VERSION: str = Field(
description="Dify version",
default="1.0.0",
default="1.1.0",
)
COMMIT_SHA: str = Field(

View File

@ -81,6 +81,7 @@ from .datasets import (
datasets_segments,
external,
hit_testing,
metadata,
website,
)

View File

@ -316,7 +316,7 @@ class AppTraceApi(Resource):
@account_initialization_required
def post(self, app_id):
# add app trace
if not current_user.is_admin_or_owner:
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("enabled", type=bool, required=True, location="json")

View File

@ -1,8 +1,10 @@
import json
import logging
from typing import cast
from flask import abort, request
from flask_restful import Resource, inputs, marshal_with, reqparse # type: ignore
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
@ -13,6 +15,7 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from factories import variable_factory
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
from fields.workflow_run_fields import workflow_run_node_execution_fields
@ -24,7 +27,7 @@ from models.account import Account
from models.model import AppMode
from services.app_generate_service import AppGenerateService
from services.errors.app import WorkflowHashNotEqualError
from services.workflow_service import WorkflowService
from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService
logger = logging.getLogger(__name__)
@ -439,10 +442,38 @@ class PublishedWorkflowApi(Resource):
if not isinstance(current_user, Account):
raise Forbidden()
workflow_service = WorkflowService()
workflow = workflow_service.publish_workflow(app_model=app_model, account=current_user)
parser = reqparse.RequestParser()
parser.add_argument("marked_name", type=str, required=False, default="", location="json")
parser.add_argument("marked_comment", type=str, required=False, default="", location="json")
args = parser.parse_args()
return {"result": "success", "created_at": TimestampField().format(workflow.created_at)}
# Validate name and comment length
if args.marked_name and len(args.marked_name) > 20:
raise ValueError("Marked name cannot exceed 20 characters")
if args.marked_comment and len(args.marked_comment) > 100:
raise ValueError("Marked comment cannot exceed 100 characters")
workflow_service = WorkflowService()
with Session(db.engine) as session:
workflow = workflow_service.publish_workflow(
session=session,
app_model=app_model,
account=current_user,
marked_name=args.marked_name or "",
marked_comment=args.marked_comment or "",
)
app_model.workflow_id = workflow.id
db.session.commit()
workflow_created_at = TimestampField().format(workflow.created_at)
session.commit()
return {
"result": "success",
"created_at": workflow_created_at,
}
class DefaultBlockConfigsApi(Resource):
@ -564,37 +595,193 @@ class PublishedAllWorkflowApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
parser.add_argument("user_id", type=str, required=False, location="args")
parser.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
args = parser.parse_args()
page = args.get("page")
limit = args.get("limit")
page = int(args.get("page", 1))
limit = int(args.get("limit", 10))
user_id = args.get("user_id")
named_only = args.get("named_only", False)
if user_id:
if user_id != current_user.id:
raise Forbidden()
user_id = cast(str, user_id)
workflow_service = WorkflowService()
workflows, has_more = workflow_service.get_all_published_workflow(app_model=app_model, page=page, limit=limit)
with Session(db.engine) as session:
workflows, has_more = workflow_service.get_all_published_workflow(
session=session,
app_model=app_model,
page=page,
limit=limit,
user_id=user_id,
named_only=named_only,
)
return {"items": workflows, "page": page, "limit": limit, "has_more": has_more}
return {
"items": workflows,
"page": page,
"limit": limit,
"has_more": has_more,
}
api.add_resource(DraftWorkflowApi, "/apps/<uuid:app_id>/workflows/draft")
api.add_resource(WorkflowConfigApi, "/apps/<uuid:app_id>/workflows/draft/config")
api.add_resource(AdvancedChatDraftWorkflowRunApi, "/apps/<uuid:app_id>/advanced-chat/workflows/draft/run")
api.add_resource(DraftWorkflowRunApi, "/apps/<uuid:app_id>/workflows/draft/run")
api.add_resource(WorkflowTaskStopApi, "/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop")
api.add_resource(DraftWorkflowNodeRunApi, "/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run")
class WorkflowByIdApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_fields)
def patch(self, app_model: App, workflow_id: str):
"""
Update workflow attributes
"""
# Check permission
if not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("marked_name", type=str, required=False, location="json")
parser.add_argument("marked_comment", type=str, required=False, location="json")
args = parser.parse_args()
# Validate name and comment length
if args.marked_name and len(args.marked_name) > 20:
raise ValueError("Marked name cannot exceed 20 characters")
if args.marked_comment and len(args.marked_comment) > 100:
raise ValueError("Marked comment cannot exceed 100 characters")
args = parser.parse_args()
# Prepare update data
update_data = {}
if args.get("marked_name") is not None:
update_data["marked_name"] = args["marked_name"]
if args.get("marked_comment") is not None:
update_data["marked_comment"] = args["marked_comment"]
if not update_data:
return {"message": "No valid fields to update"}, 400
workflow_service = WorkflowService()
# Create a session and manage the transaction
with Session(db.engine, expire_on_commit=False) as session:
workflow = workflow_service.update_workflow(
session=session,
workflow_id=workflow_id,
tenant_id=app_model.tenant_id,
account_id=current_user.id,
data=update_data,
)
if not workflow:
raise NotFound("Workflow not found")
# Commit the transaction in the controller
session.commit()
return workflow
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def delete(self, app_model: App, workflow_id: str):
"""
Delete workflow
"""
# Check permission
if not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account):
raise Forbidden()
workflow_service = WorkflowService()
# Create a session and manage the transaction
with Session(db.engine) as session:
try:
workflow_service.delete_workflow(
session=session, workflow_id=workflow_id, tenant_id=app_model.tenant_id
)
# Commit the transaction in the controller
session.commit()
except WorkflowInUseError as e:
abort(400, description=str(e))
except DraftWorkflowDeletionError as e:
abort(400, description=str(e))
except ValueError as e:
raise NotFound(str(e))
return None, 204
api.add_resource(
DraftWorkflowApi,
"/apps/<uuid:app_id>/workflows/draft",
)
api.add_resource(
WorkflowConfigApi,
"/apps/<uuid:app_id>/workflows/draft/config",
)
api.add_resource(
AdvancedChatDraftWorkflowRunApi,
"/apps/<uuid:app_id>/advanced-chat/workflows/draft/run",
)
api.add_resource(
DraftWorkflowRunApi,
"/apps/<uuid:app_id>/workflows/draft/run",
)
api.add_resource(
WorkflowTaskStopApi,
"/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop",
)
api.add_resource(
DraftWorkflowNodeRunApi,
"/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run",
)
api.add_resource(
AdvancedChatDraftRunIterationNodeApi,
"/apps/<uuid:app_id>/advanced-chat/workflows/draft/iteration/nodes/<string:node_id>/run",
)
api.add_resource(
WorkflowDraftRunIterationNodeApi, "/apps/<uuid:app_id>/workflows/draft/iteration/nodes/<string:node_id>/run"
WorkflowDraftRunIterationNodeApi,
"/apps/<uuid:app_id>/workflows/draft/iteration/nodes/<string:node_id>/run",
)
api.add_resource(
AdvancedChatDraftRunLoopNodeApi,
"/apps/<uuid:app_id>/advanced-chat/workflows/draft/loop/nodes/<string:node_id>/run",
)
api.add_resource(WorkflowDraftRunLoopNodeApi, "/apps/<uuid:app_id>/workflows/draft/loop/nodes/<string:node_id>/run")
api.add_resource(PublishedWorkflowApi, "/apps/<uuid:app_id>/workflows/publish")
api.add_resource(PublishedAllWorkflowApi, "/apps/<uuid:app_id>/workflows")
api.add_resource(DefaultBlockConfigsApi, "/apps/<uuid:app_id>/workflows/default-workflow-block-configs")
api.add_resource(
DefaultBlockConfigApi, "/apps/<uuid:app_id>/workflows/default-workflow-block-configs/<string:block_type>"
WorkflowDraftRunLoopNodeApi,
"/apps/<uuid:app_id>/workflows/draft/loop/nodes/<string:node_id>/run",
)
api.add_resource(
PublishedWorkflowApi,
"/apps/<uuid:app_id>/workflows/publish",
)
api.add_resource(
PublishedAllWorkflowApi,
"/apps/<uuid:app_id>/workflows",
)
api.add_resource(
DefaultBlockConfigsApi,
"/apps/<uuid:app_id>/workflows/default-workflow-block-configs",
)
api.add_resource(
DefaultBlockConfigApi,
"/apps/<uuid:app_id>/workflows/default-workflow-block-configs/<string:block_type>",
)
api.add_resource(
ConvertToWorkflowApi,
"/apps/<uuid:app_id>/convert-to-workflow",
)
api.add_resource(
WorkflowByIdApi,
"/apps/<uuid:app_id>/workflows/<string:workflow_id>",
)
api.add_resource(ConvertToWorkflowApi, "/apps/<uuid:app_id>/convert-to-workflow")

View File

@ -1,13 +1,18 @@
from datetime import datetime
from flask_restful import Resource, marshal_with, reqparse # type: ignore
from flask_restful.inputs import int_range # type: ignore
from sqlalchemy.orm import Session
from controllers.console import api
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_database import db
from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
from libs.login import login_required
from models import App
from models.model import AppMode
from models.workflow import WorkflowRunStatus
from services.workflow_app_service import WorkflowAppService
@ -24,17 +29,38 @@ class WorkflowAppLogApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument("keyword", type=str, location="args")
parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args")
parser.add_argument(
"created_at__before", type=str, location="args", help="Filter logs created before this timestamp"
)
parser.add_argument(
"created_at__after", type=str, location="args", help="Filter logs created after this timestamp"
)
parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
args = parser.parse_args()
args.status = WorkflowRunStatus(args.status) if args.status else None
if args.created_at__before:
args.created_at__before = datetime.fromisoformat(args.created_at__before.replace("Z", "+00:00"))
if args.created_at__after:
args.created_at__after = datetime.fromisoformat(args.created_at__after.replace("Z", "+00:00"))
# get paginate workflow app logs
workflow_app_service = WorkflowAppService()
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs(
app_model=app_model, args=args
)
with Session(db.engine) as session:
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs(
session=session,
app_model=app_model,
keyword=args.keyword,
status=args.status,
created_at_before=args.created_at__before,
created_at_after=args.created_at__after,
page=args.page,
limit=args.limit,
)
return workflow_app_log_pagination
return workflow_app_log_pagination
api.add_resource(WorkflowAppLogApi, "/apps/<uuid:app_id>/workflow-app-logs")
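
The new `created_at__before` / `created_at__after` filters accept ISO-8601 timestamps, and a trailing `Z` is normalized to `+00:00` before parsing. A hedged example of calling the endpoint registered above, assuming the console API is mounted under `/console/api` and that a valid console token is available (both assumptions, not shown in this diff):

```python
import requests

# Hypothetical base URL, app id, and token.
resp = requests.get(
    "https://dify.example.com/console/api/apps/APP_ID/workflow-app-logs",
    headers={"Authorization": "Bearer CONSOLE_ACCESS_TOKEN"},
    params={
        "status": "failed",                           # succeeded | failed | stopped
        "created_at__after": "2025-03-01T00:00:00Z",  # trailing Z is accepted
        "created_at__before": "2025-03-10T00:00:00Z",
        "page": 1,
        "limit": 20,
    },
)
print(resp.json())
```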

View File

@ -122,7 +122,7 @@ class DataSourceNotionListApi(Resource):
if dataset.data_source_type != "notion_import":
raise ValueError("Dataset is not notion type.")
documents = session.execute(
documents = session.scalars(
select(Document).filter_by(
dataset_id=dataset_id,
tenant_id=current_user.current_tenant_id,

View File

@ -10,7 +10,12 @@ from controllers.console import api
from controllers.console.apikey import api_key_fields, api_key_list
from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_rate_limit_check,
enterprise_license_required,
setup_required,
)
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.indexing_runner import IndexingRunner
from core.model_runtime.entities.model_entities import ModelType
@ -96,6 +101,7 @@ class DatasetListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self):
parser = reqparse.RequestParser()
parser.add_argument(
@ -178,6 +184,10 @@ class DatasetApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
data = marshal(dataset, dataset_detail_fields)
if dataset.indexing_technique == "high_quality":
if dataset.embedding_model_provider:
provider_id = ModelProviderID(dataset.embedding_model_provider)
data["embedding_model_provider"] = str(provider_id)
if data.get("permission") == "partial_members":
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
data.update({"partial_member_list": part_users_list})
@ -210,6 +220,7 @@ class DatasetApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -276,7 +287,11 @@ class DatasetApi(Resource):
data = request.get_json()
# check embedding model setting
if data.get("indexing_technique") == "high_quality":
if (
data.get("indexing_technique") == "high_quality"
and data.get("embedding_model_provider") is not None
and data.get("embedding_model") is not None
):
DatasetService.check_embedding_model_setting(
dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model")
)
@ -313,6 +328,7 @@ class DatasetApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id):
dataset_id_str = str(dataset_id)
@ -647,6 +663,7 @@ class DatasetRetrievalSettingApi(Resource):
| VectorType.LINDORM
| VectorType.COUCHBASE
| VectorType.MILVUS
| VectorType.OPENGAUSS
):
return {
"retrieval_method": [
@ -690,6 +707,7 @@ class DatasetRetrievalSettingMockApi(Resource):
| VectorType.COUCHBASE
| VectorType.PGVECTOR
| VectorType.LINDORM
| VectorType.OPENGAUSS
):
return {
"retrieval_method": [

View File

@ -26,6 +26,7 @@ from controllers.console.datasets.error import (
)
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_rate_limit_check,
cloud_edition_billing_resource_check,
setup_required,
)
@ -242,6 +243,7 @@ class DatasetDocumentListApi(Resource):
@account_initialization_required
@marshal_with(documents_and_batch_fields)
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id):
dataset_id = str(dataset_id)
@ -297,6 +299,7 @@ class DatasetDocumentListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id):
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
@ -320,9 +323,10 @@ class DatasetInitApi(Resource):
@account_initialization_required
@marshal_with(dataset_and_document_fields)
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
if not current_user.is_dataset_editor:
raise Forbidden()
parser = reqparse.RequestParser()
@ -617,7 +621,7 @@ class DocumentDetailApi(DocumentResource):
raise InvalidMetadataError(f"Invalid metadata value: {metadata}")
if metadata == "only":
response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata}
response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details}
elif metadata == "without":
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
document_process_rules = document.dataset_process_rule.to_dict()
@ -678,7 +682,7 @@ class DocumentDetailApi(DocumentResource):
"disabled_by": document.disabled_by,
"archived": document.archived,
"doc_type": document.doc_type,
"doc_metadata": document.doc_metadata,
"doc_metadata": document.doc_metadata_details,
"segment_count": document.segment_count,
"average_segment_length": document.average_segment_length,
"hit_count": document.hit_count,
@ -694,13 +698,14 @@ class DocumentProcessingApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, action):
dataset_id = str(dataset_id)
document_id = str(document_id)
document = self.get_document(dataset_id, document_id)
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
if not current_user.is_dataset_editor:
raise Forbidden()
if action == "pause":
@ -730,6 +735,7 @@ class DocumentDeleteApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id, document_id):
dataset_id = str(dataset_id)
document_id = str(document_id)
@ -763,8 +769,8 @@ class DocumentMetadataApi(DocumentResource):
doc_type = req_data.get("doc_type")
doc_metadata = req_data.get("doc_metadata")
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
if not current_user.is_dataset_editor:
raise Forbidden()
if doc_type is None or doc_metadata is None:
@ -798,6 +804,7 @@ class DocumentStatusApi(DocumentResource):
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, action):
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
@ -893,6 +900,7 @@ class DocumentPauseApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id):
"""pause document."""
dataset_id = str(dataset_id)
@ -925,6 +933,7 @@ class DocumentRecoverApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id):
"""recover document."""
dataset_id = str(dataset_id)
@ -954,6 +963,7 @@ class DocumentRetryApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id):
"""retry document."""

View File

@ -19,6 +19,7 @@ from controllers.console.datasets.error import (
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_knowledge_limit_check,
cloud_edition_billing_rate_limit_check,
cloud_edition_billing_resource_check,
setup_required,
)
@ -106,6 +107,7 @@ class DatasetDocumentSegmentListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id, document_id):
# check dataset
dataset_id = str(dataset_id)
@ -121,8 +123,8 @@ class DatasetDocumentSegmentListApi(Resource):
raise NotFound("Document not found.")
segment_ids = request.args.getlist("segment_id")
# The role of the current user in the ta table must be admin or owner
if not current_user.is_editor:
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
if not current_user.is_dataset_editor:
raise Forbidden()
try:
DatasetService.check_dataset_permission(dataset, current_user)
@ -137,6 +139,7 @@ class DatasetDocumentSegmentApi(Resource):
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, action):
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
@ -148,8 +151,8 @@ class DatasetDocumentSegmentApi(Resource):
raise NotFound("Document not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
if not current_user.is_dataset_editor:
raise Forbidden()
try:
@ -191,6 +194,7 @@ class DatasetDocumentSegmentAddApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id, document_id):
# check dataset
dataset_id = str(dataset_id)
@ -202,7 +206,7 @@ class DatasetDocumentSegmentAddApi(Resource):
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
if not current_user.is_editor:
if not current_user.is_dataset_editor:
raise Forbidden()
# check embedding model setting
if dataset.indexing_technique == "high_quality":
@ -240,6 +244,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, segment_id):
# check dataset
dataset_id = str(dataset_id)
@ -276,8 +281,8 @@ class DatasetDocumentSegmentUpdateApi(Resource):
).first()
if not segment:
raise NotFound("Segment not found.")
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
if not current_user.is_dataset_editor:
raise Forbidden()
try:
DatasetService.check_dataset_permission(dataset, current_user)
@ -299,6 +304,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id, document_id, segment_id):
# check dataset
dataset_id = str(dataset_id)
@ -319,8 +325,8 @@ class DatasetDocumentSegmentUpdateApi(Resource):
).first()
if not segment:
raise NotFound("Segment not found.")
# The role of the current user in the ta table must be admin or owner
if not current_user.is_editor:
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
if not current_user.is_dataset_editor:
raise Forbidden()
try:
DatasetService.check_dataset_permission(dataset, current_user)
@ -336,6 +342,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id, document_id):
# check dataset
dataset_id = str(dataset_id)
@ -402,6 +409,7 @@ class ChildChunkAddApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id, document_id, segment_id):
# check dataset
dataset_id = str(dataset_id)
@ -420,7 +428,7 @@ class ChildChunkAddApi(Resource):
).first()
if not segment:
raise NotFound("Segment not found.")
if not current_user.is_editor:
if not current_user.is_dataset_editor:
raise Forbidden()
# check embedding model setting
if dataset.indexing_technique == "high_quality":
@ -499,6 +507,7 @@ class ChildChunkAddApi(Resource):
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, segment_id):
# check dataset
dataset_id = str(dataset_id)
@ -519,8 +528,8 @@ class ChildChunkAddApi(Resource):
).first()
if not segment:
raise NotFound("Segment not found.")
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
if not current_user.is_dataset_editor:
raise Forbidden()
try:
DatasetService.check_dataset_permission(dataset, current_user)
@ -542,6 +551,7 @@ class ChildChunkUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id, document_id, segment_id, child_chunk_id):
# check dataset
dataset_id = str(dataset_id)
@ -569,8 +579,8 @@ class ChildChunkUpdateApi(Resource):
).first()
if not child_chunk:
raise NotFound("Child chunk not found.")
# The role of the current user in the ta table must be admin or owner
if not current_user.is_editor:
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
if not current_user.is_dataset_editor:
raise Forbidden()
try:
DatasetService.check_dataset_permission(dataset, current_user)
@ -586,6 +596,7 @@ class ChildChunkUpdateApi(Resource):
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, segment_id, child_chunk_id):
# check dataset
dataset_id = str(dataset_id)
@ -613,8 +624,8 @@ class ChildChunkUpdateApi(Resource):
).first()
if not child_chunk:
raise NotFound("Child chunk not found.")
# The role of the current user in the ta table must be admin or owner
if not current_user.is_editor:
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
if not current_user.is_dataset_editor:
raise Forbidden()
try:
DatasetService.check_dataset_permission(dataset, current_user)

View File

@ -2,7 +2,11 @@ from flask_restful import Resource # type: ignore
from controllers.console import api
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_rate_limit_check,
setup_required,
)
from libs.login import login_required
@ -10,6 +14,7 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id):
dataset_id_str = str(dataset_id)

View File

@ -0,0 +1,155 @@
from flask_login import current_user # type: ignore # type: ignore
from flask_restful import Resource, marshal_with, reqparse # type: ignore
from werkzeug.exceptions import NotFound
from controllers.console import api
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from fields.dataset_fields import dataset_metadata_fields
from libs.login import login_required
from services.dataset_service import DatasetService
from services.entities.knowledge_entities.knowledge_entities import (
MetadataArgs,
MetadataOperationData,
)
from services.metadata_service import MetadataService
def _validate_name(name):
if not name or len(name) < 1 or len(name) > 40:
raise ValueError("Name must be between 1 to 40 characters.")
return name
def _validate_description_length(description):
if len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description
class DatasetMetadataCreateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
@marshal_with(dataset_metadata_fields)
def post(self, dataset_id):
parser = reqparse.RequestParser()
parser.add_argument("type", type=str, required=True, nullable=True, location="json")
parser.add_argument("name", type=str, required=True, nullable=True, location="json")
args = parser.parse_args()
metadata_args = MetadataArgs(**args)
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
DatasetService.check_dataset_permission(dataset, current_user)
metadata = MetadataService.create_metadata(dataset_id_str, metadata_args)
return metadata, 201
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
def get(self, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
return MetadataService.get_dataset_metadatas(dataset), 200
class DatasetMetadataApi(Resource):
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
@marshal_with(dataset_metadata_fields)
def patch(self, dataset_id, metadata_id):
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, nullable=True, location="json")
args = parser.parse_args()
dataset_id_str = str(dataset_id)
metadata_id_str = str(metadata_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
DatasetService.check_dataset_permission(dataset, current_user)
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args.get("name"))
return metadata, 200
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
def delete(self, dataset_id, metadata_id):
dataset_id_str = str(dataset_id)
metadata_id_str = str(metadata_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
DatasetService.check_dataset_permission(dataset, current_user)
MetadataService.delete_metadata(dataset_id_str, metadata_id_str)
return 200
class DatasetMetadataBuiltInFieldApi(Resource):
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
def get(self):
built_in_fields = MetadataService.get_built_in_fields()
return {"fields": built_in_fields}, 200
class DatasetMetadataBuiltInFieldActionApi(Resource):
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
def post(self, dataset_id, action):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
DatasetService.check_dataset_permission(dataset, current_user)
if action == "enable":
MetadataService.enable_built_in_field(dataset)
elif action == "disable":
MetadataService.disable_built_in_field(dataset)
return 200
class DocumentMetadataEditApi(Resource):
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
def post(self, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
DatasetService.check_dataset_permission(dataset, current_user)
parser = reqparse.RequestParser()
parser.add_argument("operation_data", type=list, required=True, nullable=True, location="json")
args = parser.parse_args()
metadata_args = MetadataOperationData(**args)
MetadataService.update_documents_metadata(dataset, metadata_args)
return 200
api.add_resource(DatasetMetadataCreateApi, "/datasets/<uuid:dataset_id>/metadata")
api.add_resource(DatasetMetadataApi, "/datasets/<uuid:dataset_id>/metadata/<uuid:metadata_id>")
api.add_resource(DatasetMetadataBuiltInFieldApi, "/datasets/metadata/built-in")
api.add_resource(DatasetMetadataBuiltInFieldActionApi, "/datasets/<uuid:dataset_id>/metadata/built-in/<string:action>")
api.add_resource(DocumentMetadataEditApi, "/datasets/<uuid:dataset_id>/documents/metadata")

View File

@ -26,6 +26,7 @@ from libs.helper import TimestampField
from libs.login import login_required
from models.account import Tenant, TenantStatus
from services.account_service import TenantService
from services.feature_service import FeatureService
from services.file_service import FileService
from services.workspace_service import WorkspaceService
@ -68,6 +69,11 @@ class TenantListApi(Resource):
tenants = TenantService.get_join_tenants(current_user)
for tenant in tenants:
features = FeatureService.get_features(tenant.id)
if features.billing.enabled:
tenant.plan = features.billing.subscription.plan
else:
tenant.plan = "sandbox"
if tenant.id == current_user.current_tenant_id:
tenant.current = True # Set current=True for current tenant
return {"workspaces": marshal(tenants, tenants_fields)}, 200
@ -82,28 +88,20 @@ class WorkspaceListApi(Resource):
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
args = parser.parse_args()
tenants = Tenant.query.order_by(Tenant.created_at.desc()).paginate(page=args["page"], per_page=args["limit"])
tenants = Tenant.query.order_by(Tenant.created_at.desc()).paginate(
page=args["page"], per_page=args["limit"], error_out=False
)
has_more = False
if len(tenants.items) == args["limit"]:
current_page_first_tenant = tenants[-1]
rest_count = (
db.session.query(Tenant)
.filter(
Tenant.created_at < current_page_first_tenant.created_at, Tenant.id != current_page_first_tenant.id
)
.count()
)
if rest_count > 0:
has_more = True
total = db.session.query(Tenant).count()
if tenants.has_next:
has_more = True
return {
"data": marshal(tenants.items, workspace_fields),
"has_more": has_more,
"limit": args["limit"],
"page": args["page"],
"total": total,
"total": tenants.total,
}, 200

View File

@ -1,5 +1,6 @@
import json
import os
import time
from functools import wraps
from flask import abort, request
@ -8,6 +9,8 @@ from flask_login import current_user # type: ignore
from configs import dify_config
from controllers.console.workspace.error import AccountNotInitializedError
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import RateLimitLog
from models.model import DifySetup
from services.feature_service import FeatureService, LicenseStatus
from services.operation_service import OperationService
@ -67,7 +70,9 @@ def cloud_edition_billing_resource_check(resource: str):
elif resource == "apps" and 0 < apps.limit <= apps.size:
abort(403, "The number of apps has reached the limit of your subscription.")
elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size:
abort(403, "The capacity of the vector space has reached the limit of your subscription.")
abort(
403, "The capacity of the knowledge storage space has reached the limit of your subscription."
)
elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size:
# The file upload API is used in multiple places,
# so we need to check whether the request originates from datasets
@ -112,6 +117,41 @@ def cloud_edition_billing_knowledge_limit_check(resource: str):
return interceptor
def cloud_edition_billing_rate_limit_check(resource: str):
def interceptor(view):
@wraps(view)
def decorated(*args, **kwargs):
if resource == "knowledge":
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(current_user.current_tenant_id)
if knowledge_rate_limit.enabled:
current_time = int(time.time() * 1000)
key = f"rate_limit_{current_user.current_tenant_id}"
redis_client.zadd(key, {current_time: current_time})
redis_client.zremrangebyscore(key, 0, current_time - 60000)
request_count = redis_client.zcard(key)
if request_count > knowledge_rate_limit.limit:
# add ratelimit record
rate_limit_log = RateLimitLog(
tenant_id=current_user.current_tenant_id,
subscription_plan=knowledge_rate_limit.subscription_plan,
operation="knowledge",
)
db.session.add(rate_limit_log)
db.session.commit()
abort(
403, "Sorry, you have reached the knowledge base request rate limit of your subscription."
)
return view(*args, **kwargs)
return decorated
return interceptor
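The decorator above implements a rolling one-minute window with a Redis sorted set: each request is recorded with its millisecond timestamp as both member and score, entries older than 60 seconds are pruned, and the remaining cardinality is the request count. A standalone sketch of the same pattern, assuming a plain redis-py client and a hypothetical limit:

```python
# Sliding-window rate limit sketch (assumes a local Redis and redis-py).
import time
import redis

r = redis.Redis()

def within_rate_limit(tenant_id: str, limit: int = 60) -> bool:
    now_ms = int(time.time() * 1000)
    key = f"rate_limit_{tenant_id}"
    r.zadd(key, {now_ms: now_ms})                # record this request
    r.zremrangebyscore(key, 0, now_ms - 60_000)  # drop entries older than 60s
    return r.zcard(key) <= limit                 # remaining members = requests in window
```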
def cloud_utm_record(view):
@wraps(view)
def decorated(*args, **kwargs):

View File

@ -10,7 +10,7 @@ from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.app.entities.app_invoke_entities import InvokeFrom
from fields.conversation_fields import message_file_fields
from fields.message_fields import feedback_fields, retriever_resource_fields
from fields.message_fields import agent_thought_fields, feedback_fields, retriever_resource_fields
from fields.raws import FilesContainedField
from libs.helper import TimestampField, uuid_value
from models.model import App, AppMode, EndUser
@ -19,20 +19,6 @@ from services.message_service import MessageService
class MessageListApi(Resource):
agent_thought_fields = {
"id": fields.String,
"chain_id": fields.String,
"message_id": fields.String,
"position": fields.Integer,
"thought": fields.String,
"tool": fields.String,
"tool_labels": fields.Raw,
"tool_input": fields.String,
"created_at": TimestampField,
"observation": fields.String,
"message_files": fields.List(fields.Nested(message_file_fields)),
}
message_fields = {
"id": fields.String,
"conversation_id": fields.String,

View File

@ -1,7 +1,9 @@
import logging
from datetime import datetime
from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore
from flask_restful.inputs import int_range # type: ignore
from sqlalchemy.orm import Session
from werkzeug.exceptions import InternalServerError
from controllers.service_api import api
@ -25,7 +27,7 @@ from extensions.ext_database import db
from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
from libs import helper
from models.model import App, AppMode, EndUser
from models.workflow import WorkflowRun
from models.workflow import WorkflowRun, WorkflowRunStatus
from services.app_generate_service import AppGenerateService
from services.workflow_app_service import WorkflowAppService
@ -125,17 +127,34 @@ class WorkflowAppLogApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument("keyword", type=str, location="args")
parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args")
parser.add_argument("created_at__before", type=str, location="args")
parser.add_argument("created_at__after", type=str, location="args")
parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
args = parser.parse_args()
args.status = WorkflowRunStatus(args.status) if args.status else None
if args.created_at__before:
args.created_at__before = datetime.fromisoformat(args.created_at__before.replace("Z", "+00:00"))
if args.created_at__after:
args.created_at__after = datetime.fromisoformat(args.created_at__after.replace("Z", "+00:00"))
# get paginate workflow app logs
workflow_app_service = WorkflowAppService()
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs(
app_model=app_model, args=args
)
with Session(db.engine) as session:
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs(
session=session,
app_model=app_model,
keyword=args.keyword,
status=args.status,
created_at_before=args.created_at__before,
created_at_after=args.created_at__after,
page=args.page,
limit=args.limit,
)
return workflow_app_log_pagination
return workflow_app_log_pagination
api.add_resource(WorkflowRunApi, "/workflows/run")
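The new `created_at__before` / `created_at__after` filters are parsed with `datetime.fromisoformat()` after swapping a trailing `Z` for `+00:00`, since `fromisoformat` only accepts the `Z` suffix natively from Python 3.11 onward. A quick illustration:

```python
from datetime import datetime

raw = "2025-03-18T10:00:00Z"  # example timestamp with a UTC "Z" suffix
parsed = datetime.fromisoformat(raw.replace("Z", "+00:00"))
print(parsed.isoformat())  # 2025-03-18T10:00:00+00:00
```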

View File

@ -1,3 +1,4 @@
import time
from collections.abc import Callable
from datetime import UTC, datetime, timedelta
from enum import Enum
@ -13,8 +14,10 @@ from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, Unauthorized
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.login import _get_user
from models.account import Account, Tenant, TenantAccountJoin, TenantStatus
from models.dataset import RateLimitLog
from models.model import ApiToken, App, EndUser
from services.feature_service import FeatureService
@ -139,6 +142,43 @@ def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: s
return interceptor
def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str):
def interceptor(view):
@wraps(view)
def decorated(*args, **kwargs):
api_token = validate_and_get_api_token(api_token_type)
if resource == "knowledge":
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(api_token.tenant_id)
if knowledge_rate_limit.enabled:
current_time = int(time.time() * 1000)
key = f"rate_limit_{api_token.tenant_id}"
redis_client.zadd(key, {current_time: current_time})
redis_client.zremrangebyscore(key, 0, current_time - 60000)
request_count = redis_client.zcard(key)
if request_count > knowledge_rate_limit.limit:
# add ratelimit record
rate_limit_log = RateLimitLog(
tenant_id=api_token.tenant_id,
subscription_plan=knowledge_rate_limit.subscription_plan,
operation="knowledge",
)
db.session.add(rate_limit_log)
db.session.commit()
raise Forbidden(
"Sorry, you have reached the knowledge base request rate limit of your subscription."
)
return view(*args, **kwargs)
return decorated
return interceptor
def validate_dataset_token(view=None):
def decorator(view):
@wraps(view)

View File

@ -1,7 +1,12 @@
import uuid
from typing import Optional
from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity
from core.app.app_config.entities import (
DatasetEntity,
DatasetRetrieveConfigEntity,
MetadataFilteringCondition,
ModelConfig,
)
from core.entities.agent_entities import PlanningStrategy
from models.model import AppMode
from services.dataset_service import DatasetService
@ -78,6 +83,15 @@ class DatasetConfigManager:
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
dataset_configs["retrieval_model"]
),
metadata_filtering_mode=dataset_configs.get("metadata_filtering_mode", "disabled"),
metadata_model_config=ModelConfig(**dataset_configs.get("metadata_model_config"))
if dataset_configs.get("metadata_model_config")
else None,
metadata_filtering_conditions=MetadataFilteringCondition(
**dataset_configs.get("metadata_filtering_conditions", {})
)
if dataset_configs.get("metadata_filtering_conditions")
else None,
),
)
else:
@ -89,11 +103,22 @@ class DatasetConfigManager:
dataset_configs["retrieval_model"]
),
top_k=dataset_configs.get("top_k", 4),
score_threshold=dataset_configs.get("score_threshold"),
score_threshold=dataset_configs.get("score_threshold")
if dataset_configs.get("score_threshold_enabled", False)
else None,
reranking_model=dataset_configs.get("reranking_model"),
weights=dataset_configs.get("weights"),
reranking_enabled=dataset_configs.get("reranking_enabled", True),
rerank_mode=dataset_configs.get("reranking_mode", "reranking_model"),
metadata_filtering_mode=dataset_configs.get("metadata_filtering_mode", "disabled"),
metadata_model_config=ModelConfig(**dataset_configs.get("metadata_model_config"))
if dataset_configs.get("metadata_model_config")
else None,
metadata_filtering_conditions=MetadataFilteringCondition(
**dataset_configs.get("metadata_filtering_conditions", {})
)
if dataset_configs.get("metadata_filtering_conditions")
else None,
),
)

View File

@ -1,10 +1,11 @@
from collections.abc import Sequence
from enum import Enum, StrEnum
from typing import Any, Optional
from typing import Any, Literal, Optional
from pydantic import BaseModel, Field, field_validator
from core.file import FileTransferMethod, FileType, FileUploadConfig
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.message_entities import PromptMessageRole
from models.model import AppMode
@ -135,6 +136,55 @@ class ExternalDataVariableEntity(BaseModel):
config: dict[str, Any] = Field(default_factory=dict)
SupportedComparisonOperator = Literal[
# for string or array
"contains",
"not contains",
"start with",
"end with",
"is",
"is not",
"empty",
"not empty",
# for number
"=",
"",
">",
"<",
"",
"",
# for time
"before",
"after",
]
class ModelConfig(BaseModel):
provider: str
name: str
mode: LLMMode
completion_params: dict[str, Any] = {}
class Condition(BaseModel):
"""
Condition detail
"""
name: str
comparison_operator: SupportedComparisonOperator
value: str | Sequence[str] | None | int | float = None
class MetadataFilteringCondition(BaseModel):
"""
Metadata Filtering Condition.
"""
logical_operator: Optional[Literal["and", "or"]] = "and"
conditions: Optional[list[Condition]] = Field(default=None, deprecated=True)
class DatasetRetrieveConfigEntity(BaseModel):
"""
Dataset Retrieve Config Entity.
@ -171,6 +221,9 @@ class DatasetRetrieveConfigEntity(BaseModel):
reranking_model: Optional[dict] = None
weights: Optional[dict] = None
reranking_enabled: Optional[bool] = True
metadata_filtering_mode: Optional[Literal["disabled", "automatic", "manual"]] = "disabled"
metadata_model_config: Optional[ModelConfig] = None
metadata_filtering_conditions: Optional[MetadataFilteringCondition] = None
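Putting the new entities together, a manual metadata filter could be constructed as in the sketch below; the field names and values are illustrative only, while the classes are the ones defined in this file.

```python
# Hedged sketch using the Condition / MetadataFilteringCondition models above.
condition = MetadataFilteringCondition(
    logical_operator="and",
    conditions=[
        Condition(name="author", comparison_operator="is", value="alice"),
        Condition(name="year", comparison_operator=">", value=2023),
    ],
)
```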
class DatasetEntity(BaseModel):

View File

@ -17,17 +17,15 @@ class FileUploadConfigManager:
if file_upload_dict:
if file_upload_dict.get("enabled"):
transform_methods = file_upload_dict.get("allowed_file_upload_methods", [])
data = {
"image_config": {
"number_limits": file_upload_dict["number_limits"],
"transfer_methods": transform_methods,
}
file_upload_dict["image_config"] = {
"number_limits": file_upload_dict.get("number_limits", 1),
"transfer_methods": transform_methods,
}
if is_vision:
data["image_config"]["detail"] = file_upload_dict.get("image", {}).get("detail", "low")
file_upload_dict["image_config"]["detail"] = file_upload_dict.get("image", {}).get("detail", "high")
return FileUploadConfig.model_validate(data)
return FileUploadConfig.model_validate(file_upload_dict)
@classmethod
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:

View File

@ -151,7 +151,7 @@ class BaseAppGenerator:
def gen():
for message in generator:
if isinstance(message, (Mapping, dict)):
if isinstance(message, Mapping | dict):
yield f"data: {json.dumps(message)}\n\n"
else:
yield f"event: {message}\n\n"

View File

@ -17,7 +17,11 @@ from core.external_data_tool.external_data_fetch import ExternalDataFetch
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage,
)
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.moderation.input_moderation import InputModeration
@ -141,6 +145,7 @@ class AppRunner:
query: Optional[str] = None,
context: Optional[str] = None,
memory: Optional[TokenBufferMemory] = None,
image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None,
) -> tuple[list[PromptMessage], Optional[list[str]]]:
"""
Organize prompt messages
@ -167,6 +172,7 @@ class AppRunner:
context=context,
memory=memory,
model_config=model_config,
image_detail_config=image_detail_config,
)
else:
memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False))
@ -201,6 +207,7 @@ class AppRunner:
memory_config=memory_config,
memory=memory,
model_config=model_config,
image_detail_config=image_detail_config,
)
stop = model_config.stop

View File

@ -11,6 +11,7 @@ from core.app.entities.queue_entities import QueueAnnotationReplyEvent
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
from core.moderation.base import ModerationError
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from extensions.ext_database import db
@ -50,6 +51,16 @@ class ChatAppRunner(AppRunner):
query = application_generate_entity.query
files = application_generate_entity.files
image_detail_config = (
application_generate_entity.file_upload_config.image_config.detail
if (
application_generate_entity.file_upload_config
and application_generate_entity.file_upload_config.image_config
)
else None
)
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
# Pre-calculate the number of tokens of the prompt messages,
# and return the rest number of tokens by model context token size limit and max token size limit.
# If the rest number of tokens is not enough, raise exception.
@ -85,6 +96,7 @@ class ChatAppRunner(AppRunner):
files=files,
query=query,
memory=memory,
image_detail_config=image_detail_config,
)
# moderation
@ -168,6 +180,7 @@ class ChatAppRunner(AppRunner):
hit_callback=hit_callback,
memory=memory,
message_id=message.id,
inputs=inputs,
)
# reorganize all inputs and template to prompt messages
@ -182,6 +195,7 @@ class ChatAppRunner(AppRunner):
query=query,
context=context,
memory=memory,
image_detail_config=image_detail_config,
)
# check hosting moderation

View File

@ -9,6 +9,7 @@ from core.app.entities.app_invoke_entities import (
)
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
from core.moderation.base import ModerationError
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from extensions.ext_database import db
@ -43,6 +44,16 @@ class CompletionAppRunner(AppRunner):
query = application_generate_entity.query
files = application_generate_entity.files
image_detail_config = (
application_generate_entity.file_upload_config.image_config.detail
if (
application_generate_entity.file_upload_config
and application_generate_entity.file_upload_config.image_config
)
else None
)
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
# Pre-calculate the number of tokens of the prompt messages,
# and return the rest number of tokens by model context token size limit and max token size limit.
# If the rest number of tokens is not enough, raise exception.
@ -66,6 +77,7 @@ class CompletionAppRunner(AppRunner):
inputs=inputs,
files=files,
query=query,
image_detail_config=image_detail_config,
)
# moderation
@ -127,6 +139,7 @@ class CompletionAppRunner(AppRunner):
show_retrieve_source=app_config.additional_features.show_retrieve_source,
hit_callback=hit_callback,
message_id=message.id,
inputs=inputs,
)
# reorganize all inputs and template to prompt messages
@ -140,6 +153,7 @@ class CompletionAppRunner(AppRunner):
files=files,
query=query,
context=context,
image_detail_config=image_detail_config,
)
# check hosting moderation

View File

@ -7,7 +7,6 @@ from json import JSONDecodeError
from typing import Optional
from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy import or_
from constants import HIDDEN_VALUE
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
@ -180,37 +179,35 @@ class ProviderConfiguration(BaseModel):
else [],
)
def _get_custom_provider_credentials(self) -> Provider | None:
"""
Get custom provider credentials.
"""
# get provider
model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
provider_names.append(model_provider_id.provider_name)
provider_record = (
db.session.query(Provider)
.filter(
Provider.tenant_id == self.tenant_id,
Provider.provider_type == ProviderType.CUSTOM.value,
Provider.provider_name.in_(provider_names),
)
.first()
)
return provider_record
def custom_credentials_validate(self, credentials: dict) -> tuple[Provider | None, dict]:
"""
Validate custom credentials.
:param credentials: provider credentials
:return:
"""
# get provider
model_provider_id = ModelProviderID(self.provider.provider)
if model_provider_id.is_langgenius():
provider_record = (
db.session.query(Provider)
.filter(
Provider.tenant_id == self.tenant_id,
Provider.provider_type == ProviderType.CUSTOM.value,
or_(
Provider.provider_name == model_provider_id.provider_name,
Provider.provider_name == self.provider.provider,
),
)
.first()
)
else:
provider_record = (
db.session.query(Provider)
.filter(
Provider.tenant_id == self.tenant_id,
Provider.provider_type == ProviderType.CUSTOM.value,
Provider.provider_name == self.provider.provider,
)
.first()
)
provider_record = self._get_custom_provider_credentials()
# Get provider credential secret variables
provider_credential_secret_variables = self.extract_secret_variables(
@ -291,18 +288,7 @@ class ProviderConfiguration(BaseModel):
:return:
"""
# get provider
provider_record = (
db.session.query(Provider)
.filter(
Provider.tenant_id == self.tenant_id,
or_(
Provider.provider_name == ModelProviderID(self.provider.provider).plugin_name,
Provider.provider_name == self.provider.provider,
),
Provider.provider_type == ProviderType.CUSTOM.value,
)
.first()
)
provider_record = self._get_custom_provider_credentials()
# delete provider
if provider_record:
@ -349,6 +335,33 @@ class ProviderConfiguration(BaseModel):
return None
def _get_custom_model_credentials(
self,
model_type: ModelType,
model: str,
) -> ProviderModel | None:
"""
Get custom model credentials.
"""
# get provider model
model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
provider_names.append(model_provider_id.provider_name)
provider_model_record = (
db.session.query(ProviderModel)
.filter(
ProviderModel.tenant_id == self.tenant_id,
ProviderModel.provider_name.in_(provider_names),
ProviderModel.model_name == model,
ProviderModel.model_type == model_type.to_origin_model_type(),
)
.first()
)
return provider_model_record
def custom_model_credentials_validate(
self, model_type: ModelType, model: str, credentials: dict
) -> tuple[ProviderModel | None, dict]:
@ -361,16 +374,7 @@ class ProviderConfiguration(BaseModel):
:return:
"""
# get provider model
provider_model_record = (
db.session.query(ProviderModel)
.filter(
ProviderModel.tenant_id == self.tenant_id,
ProviderModel.provider_name == self.provider.provider,
ProviderModel.model_name == model,
ProviderModel.model_type == model_type.to_origin_model_type(),
)
.first()
)
provider_model_record = self._get_custom_model_credentials(model_type, model)
# Get provider credential secret variables
provider_credential_secret_variables = self.extract_secret_variables(
@ -451,16 +455,7 @@ class ProviderConfiguration(BaseModel):
:return:
"""
# get provider model
provider_model_record = (
db.session.query(ProviderModel)
.filter(
ProviderModel.tenant_id == self.tenant_id,
ProviderModel.provider_name == self.provider.provider,
ProviderModel.model_name == model,
ProviderModel.model_type == model_type.to_origin_model_type(),
)
.first()
)
provider_model_record = self._get_custom_model_credentials(model_type, model)
# delete provider model
if provider_model_record:
@ -475,6 +470,26 @@ class ProviderConfiguration(BaseModel):
provider_model_credentials_cache.delete()
def _get_provider_model_setting(self, model_type: ModelType, model: str) -> ProviderModelSetting | None:
"""
Get provider model setting.
"""
model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
provider_names.append(model_provider_id.provider_name)
return (
db.session.query(ProviderModelSetting)
.filter(
ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name.in_(provider_names),
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
ProviderModelSetting.model_name == model,
)
.first()
)
def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
"""
Enable model.
@ -482,16 +497,7 @@ class ProviderConfiguration(BaseModel):
:param model: model name
:return:
"""
model_setting = (
db.session.query(ProviderModelSetting)
.filter(
ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name == self.provider.provider,
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
ProviderModelSetting.model_name == model,
)
.first()
)
model_setting = self._get_provider_model_setting(model_type, model)
if model_setting:
model_setting.enabled = True
@ -516,16 +522,7 @@ class ProviderConfiguration(BaseModel):
:param model: model name
:return:
"""
model_setting = (
db.session.query(ProviderModelSetting)
.filter(
ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name == self.provider.provider,
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
ProviderModelSetting.model_name == model,
)
.first()
)
model_setting = self._get_provider_model_setting(model_type, model)
if model_setting:
model_setting.enabled = False
@ -550,13 +547,24 @@ class ProviderConfiguration(BaseModel):
:param model: model name
:return:
"""
return self._get_provider_model_setting(model_type, model)
def _get_load_balancing_config(self, model_type: ModelType, model: str) -> Optional[LoadBalancingModelConfig]:
"""
Get load balancing config.
"""
model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
provider_names.append(model_provider_id.provider_name)
return (
db.session.query(ProviderModelSetting)
db.session.query(LoadBalancingModelConfig)
.filter(
ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name == self.provider.provider,
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
ProviderModelSetting.model_name == model,
LoadBalancingModelConfig.tenant_id == self.tenant_id,
LoadBalancingModelConfig.provider_name.in_(provider_names),
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
LoadBalancingModelConfig.model_name == model,
)
.first()
)
@ -568,11 +576,16 @@ class ProviderConfiguration(BaseModel):
:param model: model name
:return:
"""
model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
provider_names.append(model_provider_id.provider_name)
load_balancing_config_count = (
db.session.query(LoadBalancingModelConfig)
.filter(
LoadBalancingModelConfig.tenant_id == self.tenant_id,
LoadBalancingModelConfig.provider_name == self.provider.provider,
LoadBalancingModelConfig.provider_name.in_(provider_names),
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
LoadBalancingModelConfig.model_name == model,
)
@ -582,16 +595,7 @@ class ProviderConfiguration(BaseModel):
if load_balancing_config_count <= 1:
raise ValueError("Model load balancing configuration must be more than 1.")
model_setting = (
db.session.query(ProviderModelSetting)
.filter(
ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name == self.provider.provider,
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
ProviderModelSetting.model_name == model,
)
.first()
)
model_setting = self._get_provider_model_setting(model_type, model)
if model_setting:
model_setting.load_balancing_enabled = True
@ -616,11 +620,16 @@ class ProviderConfiguration(BaseModel):
:param model: model name
:return:
"""
model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
provider_names.append(model_provider_id.provider_name)
model_setting = (
db.session.query(ProviderModelSetting)
.filter(
ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name == self.provider.provider,
ProviderModelSetting.provider_name.in_(provider_names),
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
ProviderModelSetting.model_name == model,
)
@ -677,11 +686,16 @@ class ProviderConfiguration(BaseModel):
return
# get preferred provider
model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
provider_names.append(model_provider_id.provider_name)
preferred_model_provider = (
db.session.query(TenantPreferredModelProvider)
.filter(
TenantPreferredModelProvider.tenant_id == self.tenant_id,
TenantPreferredModelProvider.provider_name == self.provider.provider,
TenantPreferredModelProvider.provider_name.in_(provider_names),
)
.first()
)

View File

@ -63,7 +63,9 @@ class File(BaseModel):
extension: Optional[str] = None,
mime_type: Optional[str] = None,
size: int = -1,
storage_key: str,
storage_key: Optional[str] = None,
dify_model_identity: Optional[str] = FILE_MODEL_IDENTITY,
url: Optional[str] = None,
):
super().__init__(
id=id,
@ -76,8 +78,10 @@ class File(BaseModel):
extension=extension,
mime_type=mime_type,
size=size,
dify_model_identity=dify_model_identity,
url=url,
)
self._storage_key = storage_key
self._storage_key = str(storage_key)
def to_dict(self) -> Mapping[str, str | int | None]:
data = self.model_dump(mode="json")

View File

@ -11,6 +11,19 @@ from configs import dify_config
SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES
HTTP_REQUEST_NODE_SSL_VERIFY = True # Default value for HTTP_REQUEST_NODE_SSL_VERIFY is True
try:
HTTP_REQUEST_NODE_SSL_VERIFY = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY
http_request_node_ssl_verify_lower = str(HTTP_REQUEST_NODE_SSL_VERIFY).lower()
if http_request_node_ssl_verify_lower == "true":
HTTP_REQUEST_NODE_SSL_VERIFY = True
elif http_request_node_ssl_verify_lower == "false":
HTTP_REQUEST_NODE_SSL_VERIFY = False
else:
raise ValueError("Invalid value. HTTP_REQUEST_NODE_SSL_VERIFY should be 'True' or 'False'")
except NameError:
HTTP_REQUEST_NODE_SSL_VERIFY = True
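The resulting flag is passed straight through to `httpx.Client(verify=...)` in the request helpers below; a minimal sketch of the effect, assuming the `httpx` package and an illustrative URL:

```python
import httpx

# verify=False skips TLS certificate verification; keep the default True unless
# the deployment explicitly opts out via HTTP_REQUEST_NODE_SSL_VERIFY.
with httpx.Client(verify=False) as client:
    resp = client.get("https://self-signed.internal.example/health")
print(resp.status_code)
```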
BACKOFF_FACTOR = 0.5
STATUS_FORCELIST = [429, 500, 502, 503, 504]
@ -39,17 +52,17 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
while retries <= max_retries:
try:
if dify_config.SSRF_PROXY_ALL_URL:
with httpx.Client(proxy=dify_config.SSRF_PROXY_ALL_URL) as client:
with httpx.Client(proxy=dify_config.SSRF_PROXY_ALL_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY) as client:
response = client.request(method=method, url=url, **kwargs)
elif dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
proxy_mounts = {
"http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL),
"https://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTPS_URL),
}
with httpx.Client(mounts=proxy_mounts) as client:
with httpx.Client(mounts=proxy_mounts, verify=HTTP_REQUEST_NODE_SSL_VERIFY) as client:
response = client.request(method=method, url=url, **kwargs)
else:
with httpx.Client() as client:
with httpx.Client(verify=HTTP_REQUEST_NODE_SSL_VERIFY) as client:
response = client.request(method=method, url=url, **kwargs)
if response.status_code not in STATUS_FORCELIST:

View File

@ -493,7 +493,7 @@ If inputting a combination of text and images, the images need to be constructed
The base class for all Role message bodies, used only for parameter declaration and cannot be initialized.
```python
class PromptMessage(ABC, BaseModel):
class PromptMessage(BaseModel):
"""
Model class for prompt message.
"""

View File

@ -533,7 +533,7 @@ class ImagePromptMessageContent(PromptMessageContent):
The base class for all Role message bodies, used only for parameter declaration and cannot be initialized.
```python
class PromptMessage(ABC, BaseModel):
class PromptMessage(BaseModel):
"""
Model class for prompt message.
"""

View File

@ -31,11 +31,9 @@ __all__ = [
"ModelPropertyKey",
"MultiModalPromptMessageContent",
"PromptMessage",
"PromptMessage",
"PromptMessageContent",
"PromptMessageContentType",
"PromptMessageRole",
"PromptMessageRole",
"PromptMessageTool",
"SystemPromptMessage",
"TextPromptMessageContent",

View File

@ -1,4 +1,3 @@
from abc import ABC
from collections.abc import Sequence
from enum import Enum, StrEnum
from typing import Optional
@ -119,7 +118,7 @@ class DocumentPromptMessageContent(MultiModalPromptMessageContent):
type: PromptMessageContentType = PromptMessageContentType.DOCUMENT
class PromptMessage(ABC, BaseModel):
class PromptMessage(BaseModel):
"""
Model class for prompt message.
"""

View File

@ -80,7 +80,7 @@ class AIModel(BaseModel):
)
)
elif isinstance(invoke_error, InvokeError):
return invoke_error(description=f"[{self.provider_name}] {invoke_error.description}, {str(error)}")
return InvokeError(description=f"[{self.provider_name}] {invoke_error.description}, {str(error)}")
else:
return error

View File

@ -214,6 +214,8 @@ class OpsTraceManager:
provider_config_map[tracing_provider]["trace_instance"],
provider_config_map[tracing_provider]["config_class"],
)
if not decrypt_trace_config:
return None
tracing_instance = trace_instance(config_class(**decrypt_trace_config))
return tracing_instance

View File

@ -3,7 +3,7 @@ from binascii import hexlify, unhexlify
from collections.abc import Generator
from core.model_manager import ModelManager
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
PromptMessage,
SystemPromptMessage,
@ -46,7 +46,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
model_parameters=payload.completion_params,
tools=payload.tools,
stop=payload.stop,
stream=payload.stream or True,
stream=True if payload.stream is None else payload.stream,
user=user_id,
)
@ -64,7 +64,21 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
else:
if response.usage:
LLMNode.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
return response
def handle_non_streaming(response: LLMResult) -> Generator[LLMResultChunk, None, None]:
yield LLMResultChunk(
model=response.model,
prompt_messages=response.prompt_messages,
system_fingerprint=response.system_fingerprint,
delta=LLMResultChunkDelta(
index=0,
message=response.message,
usage=response.usage,
finish_reason="",
),
)
return handle_non_streaming(response)
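The change away from `payload.stream or True` matters because `or` treats an explicit `False` the same as `None`; a two-line check shows the difference:

```python
stream = False
print(stream or True)                      # True  -- an explicit False is overridden
print(True if stream is None else stream)  # False -- only None falls back to True
```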
@classmethod
def invoke_text_embedding(cls, user_id: str, tenant: Tenant, payload: RequestInvokeTextEmbedding):

View File

@ -147,7 +147,7 @@ def init_frontend_parameter(rule: PluginParameter, type: enum.StrEnum, value: An
init frontend parameter by rule
"""
parameter_value = value
if not parameter_value and parameter_value != 0:
if not parameter_value and parameter_value != 0 and type != PluginParameterType.TOOLS_SELECTOR:
# get default value
parameter_value = rule.default
if not parameter_value and rule.required:

View File

@ -46,6 +46,7 @@ class AdvancedPromptTransform(PromptTransform):
memory_config: Optional[MemoryConfig],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity,
image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None,
) -> list[PromptMessage]:
prompt_messages = []
@ -59,6 +60,7 @@ class AdvancedPromptTransform(PromptTransform):
memory_config=memory_config,
memory=memory,
model_config=model_config,
image_detail_config=image_detail_config,
)
elif isinstance(prompt_template, list) and all(isinstance(item, ChatModelMessage) for item in prompt_template):
prompt_messages = self._get_chat_model_prompt_messages(
@ -70,6 +72,7 @@ class AdvancedPromptTransform(PromptTransform):
memory_config=memory_config,
memory=memory,
model_config=model_config,
image_detail_config=image_detail_config,
)
return prompt_messages
@ -84,6 +87,7 @@ class AdvancedPromptTransform(PromptTransform):
memory_config: Optional[MemoryConfig],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity,
image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None,
) -> list[PromptMessage]:
"""
Get completion model prompt messages.
@ -124,7 +128,9 @@ class AdvancedPromptTransform(PromptTransform):
prompt_message_contents: list[PromptMessageContent] = []
prompt_message_contents.append(TextPromptMessageContent(data=prompt))
for file in files:
prompt_message_contents.append(file_manager.to_prompt_message_content(file))
prompt_message_contents.append(
file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config)
)
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else:
@ -142,6 +148,7 @@ class AdvancedPromptTransform(PromptTransform):
memory_config: Optional[MemoryConfig],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity,
image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None,
) -> list[PromptMessage]:
"""
Get chat model prompt messages.
@ -197,7 +204,9 @@ class AdvancedPromptTransform(PromptTransform):
prompt_message_contents: list[PromptMessageContent] = []
prompt_message_contents.append(TextPromptMessageContent(data=query))
for file in files:
prompt_message_contents.append(file_manager.to_prompt_message_content(file))
prompt_message_contents.append(
file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config)
)
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else:
prompt_messages.append(UserPromptMessage(content=query))
@ -209,19 +218,25 @@ class AdvancedPromptTransform(PromptTransform):
# get last user message content and add files
prompt_message_contents = [TextPromptMessageContent(data=cast(str, last_message.content))]
for file in files:
prompt_message_contents.append(file_manager.to_prompt_message_content(file))
prompt_message_contents.append(
file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config)
)
last_message.content = prompt_message_contents
else:
prompt_message_contents = [TextPromptMessageContent(data="")] # not for query
for file in files:
prompt_message_contents.append(file_manager.to_prompt_message_content(file))
prompt_message_contents.append(
file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config)
)
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else:
prompt_message_contents = [TextPromptMessageContent(data=query)]
for file in files:
prompt_message_contents.append(file_manager.to_prompt_message_content(file))
prompt_message_contents.append(
file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config)
)
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
elif query:

View File

@ -9,6 +9,7 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti
from core.file import file_manager
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import (
ImagePromptMessageContent,
PromptMessage,
PromptMessageContent,
SystemPromptMessage,
@ -60,6 +61,7 @@ class SimplePromptTransform(PromptTransform):
context: Optional[str],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity,
image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None,
) -> tuple[list[PromptMessage], Optional[list[str]]]:
inputs = {key: str(value) for key, value in inputs.items()}
@ -74,6 +76,7 @@ class SimplePromptTransform(PromptTransform):
context=context,
memory=memory,
model_config=model_config,
image_detail_config=image_detail_config,
)
else:
prompt_messages, stops = self._get_completion_model_prompt_messages(
@ -85,6 +88,7 @@ class SimplePromptTransform(PromptTransform):
context=context,
memory=memory,
model_config=model_config,
image_detail_config=image_detail_config,
)
return prompt_messages, stops
@ -175,6 +179,7 @@ class SimplePromptTransform(PromptTransform):
files: Sequence["File"],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity,
image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None,
) -> tuple[list[PromptMessage], Optional[list[str]]]:
prompt_messages: list[PromptMessage] = []
@ -204,9 +209,9 @@ class SimplePromptTransform(PromptTransform):
)
if query:
prompt_messages.append(self.get_last_user_message(query, files))
prompt_messages.append(self.get_last_user_message(query, files, image_detail_config))
else:
prompt_messages.append(self.get_last_user_message(prompt, files))
prompt_messages.append(self.get_last_user_message(prompt, files, image_detail_config))
return prompt_messages, None
@ -220,6 +225,7 @@ class SimplePromptTransform(PromptTransform):
files: Sequence["File"],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity,
image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None,
) -> tuple[list[PromptMessage], Optional[list[str]]]:
# get prompt
prompt, prompt_rules = self.get_prompt_str_and_rules(
@ -262,14 +268,21 @@ class SimplePromptTransform(PromptTransform):
if stops is not None and len(stops) == 0:
stops = None
return [self.get_last_user_message(prompt, files)], stops
return [self.get_last_user_message(prompt, files, image_detail_config)], stops
def get_last_user_message(self, prompt: str, files: Sequence["File"]) -> UserPromptMessage:
def get_last_user_message(
self,
prompt: str,
files: Sequence["File"],
image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None,
) -> UserPromptMessage:
if files:
prompt_message_contents: list[PromptMessageContent] = []
prompt_message_contents.append(TextPromptMessageContent(data=prompt))
for file in files:
prompt_message_contents.append(file_manager.to_prompt_message_content(file))
prompt_message_contents.append(
file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config)
)
prompt_message = UserPromptMessage(content=prompt_message_contents)
else:

View File

@ -149,6 +149,11 @@ class ProviderManager:
provider_name = provider_entity.provider
provider_records = provider_name_to_provider_records_dict.get(provider_entity.provider, [])
provider_model_records = provider_name_to_provider_model_records_dict.get(provider_entity.provider, [])
provider_id_entity = ModelProviderID(provider_name)
if provider_id_entity.is_langgenius():
provider_model_records.extend(
provider_name_to_provider_model_records_dict.get(provider_id_entity.provider_name, [])
)
# Convert to custom configuration
custom_configuration = self._to_custom_configuration(
@ -190,6 +195,20 @@ class ProviderManager:
provider_name
)
provider_id_entity = ModelProviderID(provider_name)
if provider_id_entity.is_langgenius():
if provider_model_settings is not None:
provider_model_settings.extend(
provider_name_to_provider_model_settings_dict.get(provider_id_entity.provider_name, [])
)
if provider_load_balancing_configs is not None:
provider_load_balancing_configs.extend(
provider_name_to_provider_load_balancing_model_configs_dict.get(
provider_id_entity.provider_name, []
)
)
# Convert to model settings
model_settings = self._to_model_settings(
provider_entity=provider_entity,
@ -207,7 +226,7 @@ class ProviderManager:
model_settings=model_settings,
)
provider_configurations[str(ModelProviderID(provider_name))] = provider_configuration
provider_configurations[str(provider_id_entity)] = provider_configuration
# Return the encapsulated object
return provider_configurations

View File

@ -88,16 +88,17 @@ class Jieba(BaseKeyword):
keyword_table = self._get_dataset_keyword_table()
k = kwargs.get("top_k", 4)
document_ids_filter = kwargs.get("document_ids_filter")
sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table or {}, query, k)
documents = []
for chunk_index in sorted_chunk_indices:
segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id == chunk_index)
.first()
segment_query = db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id == chunk_index
)
if document_ids_filter:
segment_query = segment_query.filter(DocumentSegment.document_id.in_(document_ids_filter))
segment = segment_query.first()
if segment:
documents.append(

View File

@ -1,5 +1,4 @@
import concurrent.futures
import json
from concurrent.futures import ThreadPoolExecutor
from typing import Optional
@ -42,6 +41,7 @@ class RetrievalService:
reranking_model: Optional[dict] = None,
reranking_mode: str = "reranking_model",
weights: Optional[dict] = None,
document_ids_filter: Optional[list[str]] = None,
):
if not query:
return []
@ -65,6 +65,7 @@ class RetrievalService:
top_k=top_k,
all_documents=all_documents,
exceptions=exceptions,
document_ids_filter=document_ids_filter,
)
)
if RetrievalMethod.is_support_semantic_search(retrieval_method):
@ -80,6 +81,7 @@ class RetrievalService:
all_documents=all_documents,
retrieval_method=retrieval_method,
exceptions=exceptions,
document_ids_filter=document_ids_filter,
)
)
if RetrievalMethod.is_support_fulltext_search(retrieval_method):
@ -131,7 +133,14 @@ class RetrievalService:
@classmethod
def keyword_search(
cls, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list, exceptions: list
cls,
flask_app: Flask,
dataset_id: str,
query: str,
top_k: int,
all_documents: list,
exceptions: list,
document_ids_filter: Optional[list[str]] = None,
):
with flask_app.app_context():
try:
@ -140,7 +149,10 @@ class RetrievalService:
raise ValueError("dataset not found")
keyword = Keyword(dataset=dataset)
documents = keyword.search(cls.escape_query_for_search(query), top_k=top_k)
documents = keyword.search(
cls.escape_query_for_search(query), top_k=top_k, document_ids_filter=document_ids_filter
)
all_documents.extend(documents)
except Exception as e:
exceptions.append(str(e))
@ -157,6 +169,7 @@ class RetrievalService:
all_documents: list,
retrieval_method: str,
exceptions: list,
document_ids_filter: Optional[list[str]] = None,
):
with flask_app.app_context():
try:
@ -171,6 +184,7 @@ class RetrievalService:
top_k=top_k,
score_threshold=score_threshold,
filter={"group_id": [dataset.id]},
document_ids_filter=document_ids_filter,
)
if documents:
@ -243,7 +257,7 @@ class RetrievalService:
@staticmethod
def escape_query_for_search(query: str) -> str:
return json.dumps(query).strip('"')
return query.replace('"', '\\"')
@classmethod
def format_retrieval_documents(cls, documents: list[Document]) -> list[RetrievalSegments]:
@ -277,6 +291,8 @@ class RetrievalService:
continue
dataset_document = dataset_documents[document_id]
if not dataset_document:
continue
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
# Handle parent-child documents

View File

@ -53,7 +53,7 @@ class AnalyticdbVector(BaseVector):
self.analyticdb_vector.delete_by_metadata_field(key, value)
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
return self.analyticdb_vector.search_by_vector(query_vector)
return self.analyticdb_vector.search_by_vector(query_vector, **kwargs)
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
return self.analyticdb_vector.search_by_full_text(query, **kwargs)

View File

@ -194,6 +194,13 @@ class AnalyticdbVectorBySql:
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 4)
if not isinstance(top_k, int) or top_k <= 0:
raise ValueError("top_k must be a positive integer")
document_ids_filter = kwargs.get("document_ids_filter")
where_clause = "WHERE 1=1"
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
where_clause += f"AND metadata_->>'document_id' IN ({document_ids})"
score_threshold = float(kwargs.get("score_threshold") or 0.0)
with self._get_cursor() as cur:
query_vector_str = json.dumps(query_vector)
@ -202,7 +209,7 @@ class AnalyticdbVectorBySql:
f"SELECT t.id AS id, t.vector AS vector, (1.0 - t.score) AS score, "
f"t.page_content as page_content, t.metadata_ AS metadata_ "
f"FROM (SELECT id, vector, page_content, metadata_, vector <=> %s AS score "
f"FROM {self.table_name} ORDER BY score LIMIT {top_k} ) t",
f"FROM {self.table_name} {where_clause} ORDER BY score LIMIT {top_k} ) t",
(query_vector_str,),
)
documents = []
@ -220,12 +227,19 @@ class AnalyticdbVectorBySql:
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 4)
if not isinstance(top_k, int) or top_k <= 0:
raise ValueError("top_k must be a positive integer")
document_ids_filter = kwargs.get("document_ids_filter")
where_clause = ""
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
where_clause += f"AND metadata_->>'document_id' IN ({document_ids})"
with self._get_cursor() as cur:
cur.execute(
f"""SELECT id, vector, page_content, metadata_,
ts_rank(to_tsvector, to_tsquery_from_text(%s, 'zh_cn'), 32) AS score
FROM {self.table_name}
WHERE to_tsvector@@to_tsquery_from_text(%s, 'zh_cn')
WHERE to_tsvector@@to_tsquery_from_text(%s, 'zh_cn') {where_clause}
ORDER BY score DESC
LIMIT {top_k}""",
(f"'{query}'", f"'{query}'"),

View File

@ -123,11 +123,21 @@ class BaiduVector(BaseVector):
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
query_vector = [float(val) if isinstance(val, np.float64) else val for val in query_vector]
anns = AnnSearch(
vector_field=self.field_vector,
vector_floats=query_vector,
params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)),
)
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
anns = AnnSearch(
vector_field=self.field_vector,
vector_floats=query_vector,
params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)),
filter=f"document_id IN ({document_ids})",
)
else:
anns = AnnSearch(
vector_field=self.field_vector,
vector_floats=query_vector,
params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)),
)
res = self._db.table(self._collection_name).search(
anns=anns,
projections=[self.field_id, self.field_text, self.field_metadata],

View File

@ -95,7 +95,15 @@ class ChromaVector(BaseVector):
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
collection = self._client.get_or_create_collection(self._collection_name)
results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4))
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
results: QueryResult = collection.query(
query_embeddings=query_vector,
n_results=kwargs.get("top_k", 4),
where={"document_id": {"$in": document_ids_filter}}, # type: ignore
)
else:
results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4)) # type: ignore
score_threshold = float(kwargs.get("score_threshold") or 0.0)
# Check if results contain data
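Chroma's `where` filter accepts Mongo-style operators, so restricting a query to a set of document IDs uses `$in`, as the branch above does. A standalone sketch with chromadb's in-memory client (collection and field names are illustrative):

```python
import chromadb

client = chromadb.EphemeralClient()
collection = client.get_or_create_collection("demo")
collection.add(
    ids=["a", "b"],
    embeddings=[[0.1, 0.2], [0.9, 0.8]],
    metadatas=[{"document_id": "doc-1"}, {"document_id": "doc-2"}],
)
results = collection.query(
    query_embeddings=[[0.1, 0.2]],
    n_results=2,
    where={"document_id": {"$in": ["doc-1"]}},  # only hits from doc-1 come back
)
print(results["ids"])  # [['a']]
```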

View File

@ -117,6 +117,9 @@ class ElasticSearchVector(BaseVector):
top_k = kwargs.get("top_k", 4)
num_candidates = math.ceil(top_k * 1.5)
knn = {"field": Field.VECTOR.value, "query_vector": query_vector, "k": top_k, "num_candidates": num_candidates}
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
knn["filter"] = {"terms": {"metadata.document_id": document_ids_filter}}
results = self._client.search(index=self._collection_name, knn=knn, size=top_k)
@ -145,6 +148,9 @@ class ElasticSearchVector(BaseVector):
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
query_str = {"match": {Field.CONTENT_KEY.value: query}}
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
query_str["filter"] = {"terms": {"metadata.document_id": document_ids_filter}} # type: ignore
results = self._client.search(index=self._collection_name, query=query_str, size=kwargs.get("top_k", 4))
docs = []
for hit in results["hits"]["hits"]:

View File

@ -168,7 +168,12 @@ class LindormVectorStore(BaseVector):
raise ValueError("All elements in query_vector should be floats")
top_k = kwargs.get("top_k", 10)
query = default_vector_search_query(query_vector=query_vector, k=top_k, **kwargs)
document_ids_filter = kwargs.get("document_ids_filter")
filters = []
if document_ids_filter:
filters.append({"terms": {"metadata.document_id": document_ids_filter}})
query = default_vector_search_query(query_vector=query_vector, k=top_k, filters=filters, **kwargs)
try:
params = {}
if self._using_ugc:
@ -206,7 +211,10 @@ class LindormVectorStore(BaseVector):
should = kwargs.get("should")
minimum_should_match = kwargs.get("minimum_should_match", 0)
top_k = kwargs.get("top_k", 10)
filters = kwargs.get("filter")
filters = kwargs.get("filter", [])
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
filters.append({"terms": {"metadata.document_id": document_ids_filter}})
routing = self._routing
full_text_query = default_text_search_query(
query_text=query,

View File

@ -228,12 +228,18 @@ class MilvusVector(BaseVector):
"""
Search for documents by vector similarity.
"""
document_ids_filter = kwargs.get("document_ids_filter")
filter = ""
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
filter = f'metadata["document_id"] in ({document_ids})'
results = self._client.search(
collection_name=self._collection_name,
data=[query_vector],
anns_field=Field.VECTOR.value,
limit=kwargs.get("top_k", 4),
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
filter=filter,
)
return self._process_search_results(
@ -249,6 +255,11 @@ class MilvusVector(BaseVector):
if not self._hybrid_search_enabled or not self.field_exists(Field.SPARSE_VECTOR.value):
logger.warning("Full-text search is not supported in current Milvus version (requires >= 2.5.0)")
return []
document_ids_filter = kwargs.get("document_ids_filter")
filter = ""
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
filter = f'metadata["document_id"] in ({document_ids})'
results = self._client.search(
collection_name=self._collection_name,
@ -256,6 +267,7 @@ class MilvusVector(BaseVector):
anns_field=Field.SPARSE_VECTOR.value,
limit=kwargs.get("top_k", 4),
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
filter=filter,
)
return self._process_search_results(

View File

@ -125,12 +125,18 @@ class MyScaleVector(BaseVector):
def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 4)
if not isinstance(top_k, int) or top_k <= 0:
raise ValueError("top_k must be a positive integer")
score_threshold = float(kwargs.get("score_threshold") or 0.0)
where_str = (
f"WHERE dist < {1 - score_threshold}"
if self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0
else ""
)
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
where_str = f"{where_str} AND metadata['document_id'] in ({document_ids})"
sql = f"""
SELECT text, vector, metadata, {dist} as dist FROM {self._config.database}.{self._collection_name}
{where_str} ORDER BY dist {order.value} LIMIT {top_k}

View File

@ -154,6 +154,11 @@ class OceanBaseVector(BaseVector):
return []
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
document_ids_filter = kwargs.get("document_ids_filter")
where_clause = None
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
where_clause = f"metadata->>'$.document_id' in ({document_ids})"
ef_search = kwargs.get("ef_search", self._hnsw_ef_search)
if ef_search != self._hnsw_ef_search:
self._client.set_ob_hnsw_ef_search(ef_search)
@ -167,6 +172,7 @@ class OceanBaseVector(BaseVector):
distance_func=func.l2_distance,
output_column_names=["text", "metadata"],
with_dist=True,
where_clause=where_clause,
)
docs = []
for text, metadata, distance in cur:

View File

@ -0,0 +1,240 @@
import json
import uuid
from contextlib import contextmanager
from typing import Any
import psycopg2.extras # type: ignore
import psycopg2.pool # type: ignore
from pydantic import BaseModel, model_validator
from configs import dify_config
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset
class OpenGaussConfig(BaseModel):
host: str
port: int
user: str
password: str
database: str
min_connection: int
max_connection: int
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
if not values["host"]:
raise ValueError("config OPENGAUSS_HOST is required")
if not values["port"]:
raise ValueError("config OPENGAUSS_PORT is required")
if not values["user"]:
raise ValueError("config OPENGAUSS_USER is required")
if not values["password"]:
raise ValueError("config OPENGAUSS_PASSWORD is required")
if not values["database"]:
raise ValueError("config OPENGAUSS_DATABASE is required")
if not values["min_connection"]:
raise ValueError("config OPENGAUSS_MIN_CONNECTION is required")
if not values["max_connection"]:
raise ValueError("config OPENGAUSS_MAX_CONNECTION is required")
if values["min_connection"] > values["max_connection"]:
raise ValueError("config OPENGAUSS_MIN_CONNECTION should less than OPENGAUSS_MAX_CONNECTION")
return values
SQL_CREATE_TABLE = """
CREATE TABLE IF NOT EXISTS {table_name} (
id UUID PRIMARY KEY,
text TEXT NOT NULL,
meta JSONB NOT NULL,
embedding vector({dimension}) NOT NULL
);
"""
SQL_CREATE_INDEX = """
CREATE INDEX IF NOT EXISTS embedding_cosine_{table_name}_idx ON {table_name}
USING hnsw (embedding vector_cosine_ops) WITH (m = 16, ef_construction = 64);
"""
class OpenGauss(BaseVector):
def __init__(self, collection_name: str, config: OpenGaussConfig):
super().__init__(collection_name)
self.pool = self._create_connection_pool(config)
self.table_name = f"embedding_{collection_name}"
def get_type(self) -> str:
return VectorType.OPENGAUSS
def _create_connection_pool(self, config: OpenGaussConfig):
return psycopg2.pool.SimpleConnectionPool(
config.min_connection,
config.max_connection,
host=config.host,
port=config.port,
user=config.user,
password=config.password,
database=config.database,
)
@contextmanager
def _get_cursor(self):
conn = self.pool.getconn()
cur = conn.cursor()
try:
yield cur
finally:
cur.close()
conn.commit()
self.pool.putconn(conn)
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
dimension = len(embeddings[0])
self._create_collection(dimension)
return self.add_texts(texts, embeddings)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
values = []
pks = []
for i, doc in enumerate(documents):
if doc.metadata is not None:
doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
pks.append(doc_id)
values.append(
(
doc_id,
doc.page_content,
json.dumps(doc.metadata),
embeddings[i],
)
)
with self._get_cursor() as cur:
psycopg2.extras.execute_values(
cur, f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES %s", values
)
return pks
def text_exists(self, id: str) -> bool:
with self._get_cursor() as cur:
cur.execute(f"SELECT id FROM {self.table_name} WHERE id = %s", (id,))
return cur.fetchone() is not None
def get_by_ids(self, ids: list[str]) -> list[Document]:
with self._get_cursor() as cur:
cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
docs = []
for record in cur:
docs.append(Document(page_content=record[1], metadata=record[0]))
return docs
def delete_by_ids(self, ids: list[str]) -> None:
# Avoid crashes caused by performing delete operations on empty lists in certain scenarios
# Scenario 1: extracting a document fails, so the table is never created.
# Clicking the retry button then triggers a delete operation on an empty list.
if not ids:
return
with self._get_cursor() as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
def delete_by_metadata_field(self, key: str, value: str) -> None:
with self._get_cursor() as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE meta->>%s = %s", (key, value))
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
"""
Search the nearest neighbors to a vector.
:param query_vector: The input vector to search for similar items.
:param top_k: The number of nearest neighbors to return, default is 5.
:return: List of Documents that are nearest to the query vector.
"""
top_k = kwargs.get("top_k", 4)
if not isinstance(top_k, int) or top_k <= 0:
raise ValueError("top_k must be a positive integer")
with self._get_cursor() as cur:
cur.execute(
f"SELECT meta, text, embedding <=> %s AS distance FROM {self.table_name}"
f" ORDER BY distance LIMIT {top_k}",
(json.dumps(query_vector),),
)
docs = []
score_threshold = float(kwargs.get("score_threshold") or 0.0)
for record in cur:
metadata, text, distance = record
score = 1 - distance
metadata["score"] = score
if score > score_threshold:
docs.append(Document(page_content=text, metadata=metadata))
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 5)
if not isinstance(top_k, int) or top_k <= 0:
raise ValueError("top_k must be a positive integer")
with self._get_cursor() as cur:
cur.execute(
f"""SELECT meta, text, ts_rank(to_tsvector(coalesce(text, '')), plainto_tsquery(%s)) AS score
FROM {self.table_name}
WHERE to_tsvector(text) @@ plainto_tsquery(%s)
ORDER BY score DESC
LIMIT {top_k}""",
# f"'{query}'" is required in order to account for whitespace in query
(f"'{query}'", f"'{query}'"),
)
docs = []
for record in cur:
metadata, text, score = record
metadata["score"] = score
docs.append(Document(page_content=text, metadata=metadata))
return docs
def delete(self) -> None:
with self._get_cursor() as cur:
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
def _create_collection(self, dimension: int):
cache_key = f"vector_indexing_{self._collection_name}"
lock_name = f"{cache_key}_lock"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
return
with self._get_cursor() as cur:
cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension))
if dimension <= 2000:
cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name))
redis_client.set(collection_exist_cache_key, 1, ex=3600)
class OpenGaussFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> OpenGauss:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.OPENGAUSS, collection_name))
return OpenGauss(
collection_name=collection_name,
config=OpenGaussConfig(
host=dify_config.OPENGAUSS_HOST or "localhost",
port=dify_config.OPENGAUSS_PORT,
user=dify_config.OPENGAUSS_USER or "postgres",
password=dify_config.OPENGAUSS_PASSWORD or "",
database=dify_config.OPENGAUSS_DATABASE or "dify",
min_connection=dify_config.OPENGAUSS_MIN_CONNECTION,
max_connection=dify_config.OPENGAUSS_MAX_CONNECTION,
),
)
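For orientation, a minimal usage sketch of the new OpenGauss store defined above; the connection values below are placeholders, not defaults taken from this PR:

```python
# Hypothetical standalone usage of the new classes; values are illustrative only
config = OpenGaussConfig(
    host="localhost",
    port=6600,
    user="postgres",
    password="example-password",
    database="dify",
    min_connection=1,
    max_connection=5,
)
store = OpenGauss(collection_name="example_collection", config=config)
# store.create(texts=documents, embeddings=vectors) would create the table,
# insert the rows, and build the HNSW index for dimensions <= 2000.
```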

View File

@ -154,6 +154,9 @@ class OpenSearchVector(BaseVector):
"size": kwargs.get("top_k", 4),
"query": {"knn": {Field.VECTOR.value: {Field.VECTOR.value: query_vector, "k": kwargs.get("top_k", 4)}}},
}
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
query["query"] = {"terms": {"metadata.document_id": document_ids_filter}}
try:
response = self._client.search(index=self._collection_name.lower(), body=query)
@ -179,6 +182,9 @@ class OpenSearchVector(BaseVector):
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
full_text_query = {"query": {"match": {Field.CONTENT_KEY.value: query}}}
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
full_text_query["query"]["terms"] = {"metadata.document_id": document_ids_filter}
response = self._client.search(index=self._collection_name.lower(), body=full_text_query)

View File

@ -201,10 +201,15 @@ class OracleVector(BaseVector):
:return: List of Documents that are nearest to the query vector.
"""
top_k = kwargs.get("top_k", 4)
document_ids_filter = kwargs.get("document_ids_filter")
where_clause = ""
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
where_clause = f"WHERE metadata->>'document_id' in ({document_ids})"
with self._get_cursor() as cur:
cur.execute(
f"SELECT meta, text, vector_distance(embedding,:1) AS distance FROM {self.table_name}"
f" ORDER BY distance fetch first {top_k} rows only",
f" {where_clause} ORDER BY distance fetch first {top_k} rows only",
[numpy.array(query_vector)],
)
docs = []
@ -257,9 +262,15 @@ class OracleVector(BaseVector):
if token not in stop_words:
entities.append(token)
with self._get_cursor() as cur:
document_ids_filter = kwargs.get("document_ids_filter")
where_clause = ""
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
where_clause = f" AND metadata->>'document_id' in ({document_ids}) "
cur.execute(
f"select meta, text, embedding FROM {self.table_name}"
f" WHERE CONTAINS(text, :1, 1) > 0 order by score(1) desc fetch first {top_k} rows only",
f"WHERE CONTAINS(text, :1, 1) > 0 {where_clause} "
f"order by score(1) desc fetch first {top_k} rows only",
[" ACCUM ".join(entities)],
)
docs = []

View File

@ -189,6 +189,9 @@ class PGVectoRS(BaseVector):
.limit(kwargs.get("top_k", 4))
.order_by("distance")
)
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
stmt = stmt.where(self._table.meta["document_id"].in_(document_ids_filter))
res = session.execute(stmt)
results = [(row[0], row[1]) for row in res]

View File

@ -1,8 +1,10 @@
import json
import logging
import uuid
from contextlib import contextmanager
from typing import Any
import psycopg2.errors
import psycopg2.extras # type: ignore
import psycopg2.pool # type: ignore
from pydantic import BaseModel, model_validator
@ -25,6 +27,7 @@ class PGVectorConfig(BaseModel):
database: str
min_connection: int
max_connection: int
pg_bigm: bool = False
@model_validator(mode="before")
@classmethod
@ -62,12 +65,18 @@ CREATE INDEX IF NOT EXISTS embedding_cosine_v1_idx ON {table_name}
USING hnsw (embedding vector_cosine_ops) WITH (m = 16, ef_construction = 64);
"""
SQL_CREATE_INDEX_PG_BIGM = """
CREATE INDEX IF NOT EXISTS bigm_idx ON {table_name}
USING gin (text gin_bigm_ops);
"""
class PGVector(BaseVector):
def __init__(self, collection_name: str, config: PGVectorConfig):
super().__init__(collection_name)
self.pool = self._create_connection_pool(config)
self.table_name = f"embedding_{collection_name}"
self.pg_bigm = config.pg_bigm
def get_type(self) -> str:
return VectorType.PGVECTOR
@ -140,7 +149,14 @@ class PGVector(BaseVector):
if not ids:
return
with self._get_cursor() as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
try:
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
except psycopg2.errors.UndefinedTable:
# table not exists
logging.warning(f"Table {self.table_name} not found, skipping delete operation.")
return
except Exception as e:
raise e
def delete_by_metadata_field(self, key: str, value: str) -> None:
with self._get_cursor() as cur:
@ -155,10 +171,18 @@ class PGVector(BaseVector):
:return: List of Documents that are nearest to the query vector.
"""
top_k = kwargs.get("top_k", 4)
if not isinstance(top_k, int) or top_k <= 0:
raise ValueError("top_k must be a positive integer")
document_ids_filter = kwargs.get("document_ids_filter")
where_clause = ""
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
where_clause = f" WHERE metadata->>'document_id' in ({document_ids}) "
with self._get_cursor() as cur:
cur.execute(
f"SELECT meta, text, embedding <=> %s AS distance FROM {self.table_name}"
f" {where_clause}"
f" ORDER BY distance LIMIT {top_k}",
(json.dumps(query_vector),),
)
@ -174,17 +198,37 @@ class PGVector(BaseVector):
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 5)
if not isinstance(top_k, int) or top_k <= 0:
raise ValueError("top_k must be a positive integer")
with self._get_cursor() as cur:
cur.execute(
f"""SELECT meta, text, ts_rank(to_tsvector(coalesce(text, '')), plainto_tsquery(%s)) AS score
FROM {self.table_name}
WHERE to_tsvector(text) @@ plainto_tsquery(%s)
ORDER BY score DESC
LIMIT {top_k}""",
# f"'{query}'" is required in order to account for whitespace in query
(f"'{query}'", f"'{query}'"),
)
document_ids_filter = kwargs.get("document_ids_filter")
where_clause = ""
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
where_clause = f" AND metadata->>'document_id' in ({document_ids}) "
if self.pg_bigm:
cur.execute("SET pg_bigm.similarity_limit TO 0.000001")
cur.execute(
f"""SELECT meta, text, bigm_similarity(unistr(%s), coalesce(text, '')) AS score
FROM {self.table_name}
WHERE text =%% unistr(%s)
{where_clause}
ORDER BY score DESC
LIMIT {top_k}""",
# f"'{query}'" is required in order to account for whitespace in query
(f"'{query}'", f"'{query}'"),
)
else:
cur.execute(
f"""SELECT meta, text, ts_rank(to_tsvector(coalesce(text, '')), plainto_tsquery(%s)) AS score
FROM {self.table_name}
WHERE to_tsvector(text) @@ plainto_tsquery(%s)
{where_clause}
ORDER BY score DESC
LIMIT {top_k}""",
# f"'{query}'" is required in order to account for whitespace in query
(f"'{query}'", f"'{query}'"),
)
docs = []
@ -214,6 +258,9 @@ class PGVector(BaseVector):
# ref: https://github.com/pgvector/pgvector?tab=readme-ov-file#indexing
if dimension <= 2000:
cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name))
if self.pg_bigm:
cur.execute("CREATE EXTENSION IF NOT EXISTS pg_bigm")
cur.execute(SQL_CREATE_INDEX_PG_BIGM.format(table_name=self.table_name))
redis_client.set(collection_exist_cache_key, 1, ex=3600)
@ -237,5 +284,6 @@ class PGVectorFactory(AbstractVectorFactory):
database=dify_config.PGVECTOR_DATABASE or "postgres",
min_connection=dify_config.PGVECTOR_MIN_CONNECTION,
max_connection=dify_config.PGVECTOR_MAX_CONNECTION,
pg_bigm=dify_config.PGVECTOR_PG_BIGM,
),
)
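The new `pg_bigm` flag switches full-text retrieval from `ts_rank` to `bigm_similarity` and creates the GIN index at collection time. A hedged sketch of opting in via the config object (the `PGVECTOR_PG_BIGM` setting maps to this field; the other values are placeholders):

```python
# Sketch: enabling the new pg_bigm path explicitly
config = PGVectorConfig(
    host="localhost",
    port=5432,
    user="postgres",
    password="example-password",
    database="dify",
    min_connection=1,
    max_connection=5,
    pg_bigm=True,  # use bigm_similarity() and the gin_bigm_ops index for full-text search
)
```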

View File

@ -286,27 +286,26 @@ class QdrantVector(BaseVector):
from qdrant_client.http import models
from qdrant_client.http.exceptions import UnexpectedResponse
for node_id in ids:
try:
filter = models.Filter(
must=[
models.FieldCondition(
key="metadata.doc_id",
match=models.MatchValue(value=node_id),
),
],
)
self._client.delete(
collection_name=self._collection_name,
points_selector=FilterSelector(filter=filter),
)
except UnexpectedResponse as e:
# Collection does not exist, so return
if e.status_code == 404:
return
# Some other error occurred, so re-raise the exception
else:
raise e
try:
filter = models.Filter(
must=[
models.FieldCondition(
key="metadata.doc_id",
match=models.MatchAny(any=ids),
),
],
)
self._client.delete(
collection_name=self._collection_name,
points_selector=FilterSelector(filter=filter),
)
except UnexpectedResponse as e:
# Collection does not exist, so return
if e.status_code == 404:
return
# Some other error occurred, so re-raise the exception
else:
raise e
def text_exists(self, id: str) -> bool:
all_collection_name = []
@ -331,6 +330,15 @@ class QdrantVector(BaseVector):
),
],
)
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
if filter.must:
filter.must.append(
models.FieldCondition(
key="metadata.document_id",
match=models.MatchAny(any=document_ids_filter),
)
)
results = self._client.search(
collection_name=self._collection_name,
query_vector=query_vector,
@ -377,6 +385,15 @@ class QdrantVector(BaseVector):
),
]
)
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
if scroll_filter.must:
scroll_filter.must.append(
models.FieldCondition(
key="metadata.document_id",
match=models.MatchAny(any=document_ids_filter),
)
)
response = self._client.scroll(
collection_name=self._collection_name,
scroll_filter=scroll_filter,

View File

@ -223,8 +223,12 @@ class RelytVector(BaseVector):
return len(result) > 0
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
document_ids_filter = kwargs.get("document_ids_filter")
filter = kwargs.get("filter", {})
if document_ids_filter:
filter["document_id"] = document_ids_filter
results = self.similarity_search_with_score_by_vector(
k=int(kwargs.get("top_k", 4)), embedding=query_vector, filter=kwargs.get("filter")
k=int(kwargs.get("top_k", 4)), embedding=query_vector, filter=filter
)
# Organize results.
@ -246,9 +250,9 @@ class RelytVector(BaseVector):
filter_condition = ""
if filter is not None:
conditions = [
f"metadata->>{key!r} in ({', '.join(map(repr, value))})"
f"metadata->>'{key!r}' in ({', '.join(map(repr, value))})"
if len(value) > 1
else f"metadata->>{key!r} = {value[0]!r}"
else f"metadata->>'{key!r}' = {value[0]!r}"
for key, value in filter.items()
]
filter_condition = f"WHERE {' AND '.join(conditions)}"

View File

@ -145,11 +145,16 @@ class TencentVector(BaseVector):
self._db.collection(self._collection_name).delete(document_ids=ids)
def delete_by_metadata_field(self, key: str, value: str) -> None:
self._db.collection(self._collection_name).delete(filter=Filter(Filter.In(key, [value])))
self._db.collection(self._collection_name).delete(filter=Filter(Filter.In(f"metadata.{key}", [value])))
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
document_ids_filter = kwargs.get("document_ids_filter")
filter = None
if document_ids_filter:
filter = Filter(Filter.In("metadata.document_id", document_ids_filter))
res = self._db.collection(self._collection_name).search(
vectors=[query_vector],
filter=filter,
params=document.HNSWSearchParams(ef=kwargs.get("ef", 10)),
retrieve_vector=False,
limit=kwargs.get("top_k", 4),

View File

@ -326,6 +326,18 @@ class TidbOnQdrantVector(BaseVector):
),
],
)
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
should_conditions = []
for document_id_filter in document_ids_filter:
should_conditions.append(
models.FieldCondition(
key="metadata.document_id",
match=models.MatchValue(value=document_id_filter),
)
)
if should_conditions:
filter.should = should_conditions # type: ignore
results = self._client.search(
collection_name=self._collection_name,
query_vector=query_vector,
@ -368,6 +380,18 @@ class TidbOnQdrantVector(BaseVector):
)
]
)
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
should_conditions = []
for document_id_filter in document_ids_filter:
should_conditions.append(
models.FieldCondition(
key="metadata.document_id",
match=models.MatchValue(value=document_id_filter),
)
)
if should_conditions:
scroll_filter.should = should_conditions # type: ignore
response = self._client.scroll(
collection_name=self._collection_name,
scroll_filter=scroll_filter,

View File

@ -196,6 +196,11 @@ class TiDBVector(BaseVector):
docs = []
tidb_dist_func = self._get_distance_func()
document_ids_filter = kwargs.get("document_ids_filter")
where_clause = ""
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
where_clause = f" WHERE meta->>'$.document_id' in ({document_ids}) "
with Session(self._engine) as session:
select_statement = sql_text(f"""
@ -206,6 +211,7 @@ class TiDBVector(BaseVector):
text,
{tidb_dist_func}(vector, :query_vector_str) AS distance
FROM {self._collection_name}
{where_clause}
ORDER BY distance ASC
LIMIT :top_k
) t

View File

@ -88,7 +88,20 @@ class UpstashVector(BaseVector):
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 4)
result = self.index.query(vector=query_vector, top_k=top_k, include_metadata=True, include_data=True)
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
filter = f"document_id in ({document_ids})"
else:
filter = ""
result = self.index.query(
vector=query_vector,
top_k=top_k,
include_metadata=True,
include_data=True,
include_vectors=False,
filter=filter,
)
docs = []
score_threshold = float(kwargs.get("score_threshold") or 0.0)
for record in result:

View File

@ -148,6 +148,10 @@ class Vector:
from core.rag.datasource.vdb.oceanbase.oceanbase_vector import OceanBaseVectorFactory
return OceanBaseVectorFactory
case VectorType.OPENGAUSS:
from core.rag.datasource.vdb.opengauss.opengauss import OpenGaussFactory
return OpenGaussFactory
case _:
raise ValueError(f"Vector store {vector_type} is not supported.")

View File

@ -24,3 +24,4 @@ class VectorType(StrEnum):
UPSTASH = "upstash"
TIDB_ON_QDRANT = "tidb_on_qdrant"
OCEANBASE = "oceanbase"
OPENGAUSS = "opengauss"

View File

@ -177,7 +177,11 @@ class VikingDBVector(BaseVector):
query_vector, limit=kwargs.get("top_k", 4)
)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
return self._get_search_res(results, score_threshold)
docs = self._get_search_res(results, score_threshold)
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
docs = [doc for doc in docs if doc.metadata.get("document_id") in document_ids_filter]
return docs
def _get_search_res(self, results, score_threshold) -> list[Document]:
if len(results) == 0:

View File

@ -187,8 +187,10 @@ class WeaviateVector(BaseVector):
query_obj = self._client.query.get(collection_name, properties)
vector = {"vector": query_vector}
if kwargs.get("where_filter"):
query_obj = query_obj.with_where(kwargs.get("where_filter"))
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
where_filter = {"operator": "ContainsAny", "path": ["document_id"], "valueTextArray": document_ids_filter}
query_obj = query_obj.with_where(where_filter)
result = (
query_obj.with_near_vector(vector)
.with_limit(kwargs.get("top_k", 4))
@ -233,8 +235,10 @@ class WeaviateVector(BaseVector):
if kwargs.get("search_distance"):
content["certainty"] = kwargs.get("search_distance")
query_obj = self._client.query.get(collection_name, properties)
if kwargs.get("where_filter"):
query_obj = query_obj.with_where(kwargs.get("where_filter"))
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
where_filter = {"operator": "ContainsAny", "path": ["document_id"], "valueTextArray": document_ids_filter}
query_obj = query_obj.with_where(where_filter)
query_obj = query_obj.with_additional(["vector"])
properties = ["text"]
result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get("top_k", 4)).do()

View File

@ -0,0 +1,45 @@
from collections.abc import Sequence
from typing import Literal, Optional
from pydantic import BaseModel, Field
SupportedComparisonOperator = Literal[
# for string or array
"contains",
"not contains",
"start with",
"end with",
"is",
"is not",
"empty",
"not empty",
# for number
"=",
"",
">",
"<",
"",
"",
# for time
"before",
"after",
]
class Condition(BaseModel):
"""
Condition detail
"""
name: str
comparison_operator: SupportedComparisonOperator
value: str | Sequence[str] | None | int | float = None
class MetadataCondition(BaseModel):
"""
Metadata Condition.
"""
logical_operator: Optional[Literal["and", "or"]] = "and"
conditions: Optional[list[Condition]] = Field(default=None, deprecated=True)
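A small sketch of how the new schema composes; the field values are illustrative:

```python
# Illustrative only: two conditions combined with the default "and" operator
metadata_condition = MetadataCondition(
    logical_operator="and",
    conditions=[
        Condition(name="year", comparison_operator="=", value=2024),
        Condition(name="rating", comparison_operator=">", value=9),
    ],
)
```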

View File

@ -0,0 +1,15 @@
from enum import Enum
class BuiltInField(str, Enum):
document_name = "document_name"
uploader = "uploader"
upload_date = "upload_date"
last_update_date = "last_update_date"
source = "source"
class MetadataDataSource(Enum):
upload_file = "file_upload"
website_crawl = "website"
notion_import = "notion"

View File

@ -1,35 +1,61 @@
import json
import math
import re
import threading
from collections import Counter
from typing import Any, Optional, cast
from collections import Counter, defaultdict
from collections.abc import Generator, Mapping
from typing import Any, Optional, Union, cast
from flask import Flask, current_app
from sqlalchemy import Integer, and_, or_, text
from sqlalchemy import cast as sqlalchemy_cast
from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity
from core.app.app_config.entities import (
DatasetEntity,
DatasetRetrieveConfigEntity,
MetadataFilteringCondition,
ModelConfig,
)
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.agent_entities import PlanningStrategy
from core.entities.model_entities import ModelStatus
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.message_entities import PromptMessageTool
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.ops.utils import measure_time
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.prompt.simple_prompt_transform import ModelMode
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.context_entities import DocumentContext
from core.rag.entities.metadata_entities import Condition, MetadataCondition
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.models.document import Document
from core.rag.rerank.rerank_type import RerankMode
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
from core.rag.retrieval.template_prompts import (
METADATA_FILTER_ASSISTANT_PROMPT_1,
METADATA_FILTER_ASSISTANT_PROMPT_2,
METADATA_FILTER_COMPLETION_PROMPT,
METADATA_FILTER_SYSTEM_PROMPT,
METADATA_FILTER_USER_PROMPT_1,
METADATA_FILTER_USER_PROMPT_2,
METADATA_FILTER_USER_PROMPT_3,
)
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from extensions.ext_database import db
from models.dataset import ChildChunk, Dataset, DatasetQuery, DocumentSegment
from libs.json_in_md_parser import parse_and_check_json_markdown
from models.dataset import ChildChunk, Dataset, DatasetMetadata, DatasetQuery, DocumentSegment
from models.dataset import Document as DatasetDocument
from services.external_knowledge_service import ExternalDatasetService
@ -59,6 +85,7 @@ class DatasetRetrieval:
hit_callback: DatasetIndexToolCallbackHandler,
message_id: str,
memory: Optional[TokenBufferMemory] = None,
inputs: Optional[Mapping[str, Any]] = None,
) -> Optional[str]:
"""
Retrieve dataset.
@ -116,6 +143,22 @@ class DatasetRetrieval:
continue
available_datasets.append(dataset)
if inputs:
inputs = {key: str(value) for key, value in inputs.items()}
else:
inputs = {}
available_datasets_ids = [dataset.id for dataset in available_datasets]
metadata_filter_document_ids, metadata_condition = self._get_metadata_filter_condition(
available_datasets_ids,
query,
tenant_id,
user_id,
retrieve_config.metadata_filtering_mode, # type: ignore
retrieve_config.metadata_model_config, # type: ignore
retrieve_config.metadata_filtering_conditions,
inputs,
)
all_documents = []
user_from = "account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user"
if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
@ -130,6 +173,8 @@ class DatasetRetrieval:
model_config,
planning_strategy,
message_id,
metadata_filter_document_ids,
metadata_condition,
)
elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
all_documents = self.multiple_retrieve(
@ -146,6 +191,8 @@ class DatasetRetrieval:
retrieve_config.weights,
retrieve_config.reranking_enabled or True,
message_id,
metadata_filter_document_ids,
metadata_condition,
)
dify_documents = [item for item in all_documents if item.provider == "dify"]
@ -239,6 +286,8 @@ class DatasetRetrieval:
model_config: ModelConfigWithCredentialsEntity,
planning_strategy: PlanningStrategy,
message_id: Optional[str] = None,
metadata_filter_document_ids: Optional[dict[str, list[str]]] = None,
metadata_condition: Optional[MetadataCondition] = None,
):
tools = []
for dataset in available_datasets:
@ -279,6 +328,7 @@ class DatasetRetrieval:
dataset_id=dataset_id,
query=query,
external_retrieval_parameters=dataset.retrieval_model,
metadata_condition=metadata_condition,
)
for external_document in external_documents:
document = Document(
@ -293,6 +343,15 @@ class DatasetRetrieval:
document.metadata["dataset_name"] = dataset.name
results.append(document)
else:
if metadata_condition and not metadata_filter_document_ids:
return []
document_ids_filter = None
if metadata_filter_document_ids:
document_ids = metadata_filter_document_ids.get(dataset.id, [])
if document_ids:
document_ids_filter = document_ids
else:
return []
retrieval_model_config = dataset.retrieval_model or default_retrieval_model
# get top k
@ -324,6 +383,7 @@ class DatasetRetrieval:
reranking_model=reranking_model,
reranking_mode=retrieval_model_config.get("reranking_mode", "reranking_model"),
weights=retrieval_model_config.get("weights", None),
document_ids_filter=document_ids_filter,
)
self._on_query(query, [dataset_id], app_id, user_from, user_id)
@ -348,6 +408,8 @@ class DatasetRetrieval:
weights: Optional[dict[str, Any]] = None,
reranking_enable: bool = True,
message_id: Optional[str] = None,
metadata_filter_document_ids: Optional[dict[str, list[str]]] = None,
metadata_condition: Optional[MetadataCondition] = None,
):
if not available_datasets:
return []
@ -387,6 +449,16 @@ class DatasetRetrieval:
for dataset in available_datasets:
index_type = dataset.indexing_technique
document_ids_filter = None
if dataset.provider != "external":
if metadata_condition and not metadata_filter_document_ids:
continue
if metadata_filter_document_ids:
document_ids = metadata_filter_document_ids.get(dataset.id, [])
if document_ids:
document_ids_filter = document_ids
else:
continue
retrieval_thread = threading.Thread(
target=self._retriever,
kwargs={
@ -395,6 +467,8 @@ class DatasetRetrieval:
"query": query,
"top_k": top_k,
"all_documents": all_documents,
"document_ids_filter": document_ids_filter,
"metadata_condition": metadata_condition,
},
)
threads.append(retrieval_thread)
@ -433,30 +507,33 @@ class DatasetRetrieval:
dataset_document = DatasetDocument.query.filter(
DatasetDocument.id == document.metadata["document_id"]
).first()
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunk = ChildChunk.query.filter(
ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id,
).first()
if child_chunk:
segment = DocumentSegment.query.filter(DocumentSegment.id == child_chunk.segment_id).update(
if dataset_document:
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunk = ChildChunk.query.filter(
ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id,
).first()
if child_chunk:
segment = DocumentSegment.query.filter(DocumentSegment.id == child_chunk.segment_id).update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False
)
db.session.commit()
else:
query = db.session.query(DocumentSegment).filter(
DocumentSegment.index_node_id == document.metadata["doc_id"]
)
# if 'dataset_id' in document.metadata:
if "dataset_id" in document.metadata:
query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"])
# add hit count to document segment
query.update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False
)
db.session.commit()
else:
query = db.session.query(DocumentSegment).filter(
DocumentSegment.index_node_id == document.metadata["doc_id"]
)
# if 'dataset_id' in document.metadata:
if "dataset_id" in document.metadata:
query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"])
# add hit count to document segment
query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False)
db.session.commit()
db.session.commit()
# get tracing instance
trace_manager: TraceQueueManager | None = (
@ -490,7 +567,16 @@ class DatasetRetrieval:
db.session.add_all(dataset_queries)
db.session.commit()
def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list):
def _retriever(
self,
flask_app: Flask,
dataset_id: str,
query: str,
top_k: int,
all_documents: list,
document_ids_filter: Optional[list[str]] = None,
metadata_condition: Optional[MetadataCondition] = None,
):
with flask_app.app_context():
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
@ -503,6 +589,7 @@ class DatasetRetrieval:
dataset_id=dataset_id,
query=query,
external_retrieval_parameters=dataset.retrieval_model,
metadata_condition=metadata_condition,
)
for external_document in external_documents:
document = Document(
@ -543,6 +630,7 @@ class DatasetRetrieval:
else None,
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
weights=retrieval_model.get("weights", None),
document_ids_filter=document_ids_filter,
)
all_documents.extend(documents)
@ -730,3 +818,340 @@ class DatasetRetrieval:
filter_documents, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True
)
return filter_documents[:top_k] if top_k else filter_documents
def _get_metadata_filter_condition(
self,
dataset_ids: list,
query: str,
tenant_id: str,
user_id: str,
metadata_filtering_mode: str,
metadata_model_config: ModelConfig,
metadata_filtering_conditions: Optional[MetadataFilteringCondition],
inputs: dict,
) -> tuple[Optional[dict[str, list[str]]], Optional[MetadataCondition]]:
document_query = db.session.query(DatasetDocument).filter(
DatasetDocument.dataset_id.in_(dataset_ids),
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
filters = [] # type: ignore
metadata_condition = None
if metadata_filtering_mode == "disabled":
return None, None
elif metadata_filtering_mode == "automatic":
automatic_metadata_filters = self._automatic_metadata_filter_func(
dataset_ids, query, tenant_id, user_id, metadata_model_config
)
if automatic_metadata_filters:
conditions = []
for filter in automatic_metadata_filters:
self._process_metadata_filter_func(
filter.get("condition"), # type: ignore
filter.get("metadata_name"), # type: ignore
filter.get("value"),
filters, # type: ignore
)
conditions.append(
Condition(
name=filter.get("metadata_name"), # type: ignore
comparison_operator=filter.get("condition"), # type: ignore
value=filter.get("value"),
)
)
metadata_condition = MetadataCondition(
logical_operator=metadata_filtering_conditions.logical_operator, # type: ignore
conditions=conditions,
)
elif metadata_filtering_mode == "manual":
if metadata_filtering_conditions:
metadata_condition = MetadataCondition(**metadata_filtering_conditions.model_dump())
for condition in metadata_filtering_conditions.conditions: # type: ignore
metadata_name = condition.name
expected_value = condition.value
if expected_value or condition.comparison_operator in ("empty", "not empty"):
if isinstance(expected_value, str):
expected_value = self._replace_metadata_filter_value(expected_value, inputs)
filters = self._process_metadata_filter_func(
condition.comparison_operator, metadata_name, expected_value, filters
)
else:
raise ValueError("Invalid metadata filtering mode")
if filters:
if metadata_filtering_conditions.logical_operator == "or": # type: ignore
document_query = document_query.filter(or_(*filters))
else:
document_query = document_query.filter(and_(*filters))
documents = document_query.all()
# group by dataset_id
metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore
for document in documents:
metadata_filter_document_ids[document.dataset_id].append(document.id) # type: ignore
return metadata_filter_document_ids, metadata_condition
def _replace_metadata_filter_value(self, text: str, inputs: dict) -> str:
def replacer(match):
key = match.group(1)
return str(inputs.get(key, f"{{{{{key}}}}}"))
pattern = re.compile(r"\{\{(\w+)\}\}")
return pattern.sub(replacer, text)
def _automatic_metadata_filter_func(
self, dataset_ids: list, query: str, tenant_id: str, user_id: str, metadata_model_config: ModelConfig
) -> Optional[list[dict[str, Any]]]:
# get all metadata field
metadata_fields = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields]
# get metadata model config
if metadata_model_config is None:
raise ValueError("metadata_model_config is required")
# get metadata model instance
# fetch model config
model_instance, model_config = self._fetch_model_config(tenant_id, metadata_model_config)
# fetch prompt messages
prompt_messages, stop = self._get_prompt_template(
model_config=model_config,
mode=metadata_model_config.mode,
metadata_fields=all_metadata_fields,
query=query or "",
)
result_text = ""
try:
# handle invoke result
invoke_result = cast(
Generator[LLMResult, None, None],
model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=model_config.parameters,
stop=stop,
stream=True,
user=user_id,
),
)
# handle invoke result
result_text, usage = self._handle_invoke_result(invoke_result=invoke_result)
result_text_json = parse_and_check_json_markdown(result_text, [])
automatic_metadata_filters = []
if "metadata_map" in result_text_json:
metadata_map = result_text_json["metadata_map"]
for item in metadata_map:
if item.get("metadata_field_name") in all_metadata_fields:
automatic_metadata_filters.append(
{
"metadata_name": item.get("metadata_field_name"),
"value": item.get("metadata_field_value"),
"condition": item.get("comparison_operator"),
}
)
except Exception as e:
return None
return automatic_metadata_filters
def _process_metadata_filter_func(self, condition: str, metadata_name: str, value: Optional[Any], filters: list):
match condition:
case "contains":
filters.append(
(text("documents.doc_metadata ->> :key LIKE :value")).params(key=metadata_name, value=f"%{value}%")
)
case "not contains":
filters.append(
(text("documents.doc_metadata ->> :key NOT LIKE :value")).params(
key=metadata_name, value=f"%{value}%"
)
)
case "start with":
filters.append(
(text("documents.doc_metadata ->> :key LIKE :value")).params(key=metadata_name, value=f"{value}%")
)
case "end with":
filters.append(
(text("documents.doc_metadata ->> :key LIKE :value")).params(key=metadata_name, value=f"%{value}")
)
case "is" | "=":
if isinstance(value, str):
filters.append(DatasetDocument.doc_metadata[metadata_name] == f'"{value}"')
else:
filters.append(
sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) == value
)
case "is not" | "":
if isinstance(value, str):
filters.append(DatasetDocument.doc_metadata[metadata_name] != f'"{value}"')
else:
filters.append(
sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) != value
)
case "empty":
filters.append(DatasetDocument.doc_metadata[metadata_name].is_(None))
case "not empty":
filters.append(DatasetDocument.doc_metadata[metadata_name].isnot(None))
case "before" | "<":
filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) < value)
case "after" | ">":
filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) > value)
case "" | ">=":
filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) <= value)
case "" | ">=":
filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) >= value)
case _:
pass
return filters
def _fetch_model_config(
self, tenant_id: str, model: ModelConfig
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
"""
Fetch model config
:param node_data: node data
:return:
"""
if model is None:
raise ValueError("single_retrieval_config is required")
model_name = model.name
provider_name = model.provider
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name
)
provider_model_bundle = model_instance.provider_model_bundle
model_type_instance = model_instance.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
model_credentials = model_instance.credentials
# check model
provider_model = provider_model_bundle.configuration.get_provider_model(
model=model_name, model_type=ModelType.LLM
)
if provider_model is None:
raise ValueError(f"Model {model_name} not exist.")
if provider_model.status == ModelStatus.NO_CONFIGURE:
raise ValueError(f"Model {model_name} credentials is not initialized.")
elif provider_model.status == ModelStatus.NO_PERMISSION:
raise ValueError(f"Dify Hosted OpenAI {model_name} currently not support.")
elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
raise ValueError(f"Model provider {provider_name} quota exceeded.")
# model config
completion_params = model.completion_params
stop = []
if "stop" in completion_params:
stop = completion_params["stop"]
del completion_params["stop"]
# get model mode
model_mode = model.mode
if not model_mode:
raise ValueError("LLM mode is required.")
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
if not model_schema:
raise ValueError(f"Model {model_name} not exist.")
return model_instance, ModelConfigWithCredentialsEntity(
provider=provider_name,
model=model_name,
model_schema=model_schema,
mode=model_mode,
provider_model_bundle=provider_model_bundle,
credentials=model_credentials,
parameters=completion_params,
stop=stop,
)
def _get_prompt_template(
self, model_config: ModelConfigWithCredentialsEntity, mode: str, metadata_fields: list, query: str
):
model_mode = ModelMode.value_of(mode)
input_text = query
prompt_template: Union[CompletionModelPromptTemplate, list[ChatModelMessage]]
if model_mode == ModelMode.CHAT:
prompt_template = []
system_prompt_messages = ChatModelMessage(role=PromptMessageRole.SYSTEM, text=METADATA_FILTER_SYSTEM_PROMPT)
prompt_template.append(system_prompt_messages)
user_prompt_message_1 = ChatModelMessage(role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_1)
prompt_template.append(user_prompt_message_1)
assistant_prompt_message_1 = ChatModelMessage(
role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_1
)
prompt_template.append(assistant_prompt_message_1)
user_prompt_message_2 = ChatModelMessage(role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_2)
prompt_template.append(user_prompt_message_2)
assistant_prompt_message_2 = ChatModelMessage(
role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_2
)
prompt_template.append(assistant_prompt_message_2)
user_prompt_message_3 = ChatModelMessage(
role=PromptMessageRole.USER,
text=METADATA_FILTER_USER_PROMPT_3.format(
input_text=input_text,
metadata_fields=json.dumps(metadata_fields, ensure_ascii=False),
),
)
prompt_template.append(user_prompt_message_3)
elif model_mode == ModelMode.COMPLETION:
prompt_template = CompletionModelPromptTemplate(
text=METADATA_FILTER_COMPLETION_PROMPT.format(
input_text=input_text,
metadata_fields=json.dumps(metadata_fields, ensure_ascii=False),
)
)
else:
raise ValueError(f"Model mode {model_mode} not support.")
prompt_transform = AdvancedPromptTransform()
prompt_messages = prompt_transform.get_prompt(
prompt_template=prompt_template,
inputs={},
query=query or "",
files=[],
context=None,
memory_config=None,
memory=None,
model_config=model_config,
)
stop = model_config.stop
return prompt_messages, stop
def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage]:
"""
Handle invoke result
:param invoke_result: invoke result
:return:
"""
model = None
prompt_messages: list[PromptMessage] = []
full_text = ""
usage = None
for result in invoke_result:
text = result.delta.message.content
full_text += text
if not model:
model = result.model
if not prompt_messages:
prompt_messages = result.prompt_messages
if not usage and result.delta.usage:
usage = result.delta.usage
if not usage:
usage = LLMUsage.empty_usage()
return full_text, usage
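Among the methods added above, `_replace_metadata_filter_value` substitutes `{{variable}}` placeholders in manual filter values with the node inputs. A worked example of that substitution (the values are hypothetical):

```python
# Hypothetical inputs for the {{variable}} substitution in _replace_metadata_filter_value
filter_value = "uploaded_by_{{username}}_in_{{year}}"
inputs = {"username": "alice", "year": "2024"}
# With the pattern r"\{\{(\w+)\}\}" and the replacer above, the result is
# "uploaded_by_alice_in_2024"; keys missing from inputs are left as "{{key}}".
```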

View File

@ -0,0 +1,66 @@
METADATA_FILTER_SYSTEM_PROMPT = """
### Job Description
You are a text metadata extract engine that extract text's metadata based on user input and set the metadata value
### Task
Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator".
### Format
The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields.
### Constraint
DO NOT include anything other than the JSON array in your response.
""" # noqa: E501
METADATA_FILTER_USER_PROMPT_1 = """
{ "input_text": "I want to know which companys email address test@example.com is?",
"metadata_fields": ["filename", "email", "phone", "address"]
}
"""
METADATA_FILTER_ASSISTANT_PROMPT_1 = """
```json
{"metadata_map": [
{"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="}
]
}
```
"""
METADATA_FILTER_USER_PROMPT_2 = """
{"input_text": "What are the movies with a score of more than 9 in 2024?",
"metadata_fields": ["name", "year", "rating", "country"]}
"""
METADATA_FILTER_ASSISTANT_PROMPT_2 = """
```json
{"metadata_map": [
{"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="},
{"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"},
]}
```
"""
METADATA_FILTER_USER_PROMPT_3 = """
'{{"input_text": "{input_text}",',
'"metadata_fields": {metadata_fields}}}'
"""
METADATA_FILTER_COMPLETION_PROMPT = """
### Job Description
You are a text metadata extract engine that extract text's metadata based on user input and set the metadata value
### Task
# Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator".
### Format
The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields.
### Constraint
DO NOT include anything other than the JSON array in your response.
### Example
Here is the chat example between human and assistant, inside <example></example> XML tags.
<example>
User:{{"input_text": ["I want to know which companys email address test@example.com is?"], "metadata_fields": ["filename", "email", "phone", "address"]}}
Assistant:{{"metadata_map": [{{"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="}}]}}
User:{{"input_text": "What are the movies with a score of more than 9 in 2024?", "metadata_fields": ["name", "year", "rating", "country"]}}
Assistant:{{"metadata_map": [{{"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="}, {{"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"}}]}}
</example>
### User Input
{{"input_text" : "{input_text}", "metadata_fields" : {metadata_fields}}}
### Assistant Output
""" # noqa: E501

View File

@ -76,38 +76,74 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
def recursive_split_text(self, text: str) -> list[str]:
"""Split incoming text and return chunks."""
final_chunks = []
# Get appropriate separator to use
separator = self._separators[-1]
for _s in self._separators:
new_separators = []
for i, _s in enumerate(self._separators):
if _s == "":
separator = _s
break
if _s in text:
separator = _s
new_separators = self._separators[i + 1 :]
break
# Now that we have the separator, split the text
if separator:
splits = text.split(separator)
if separator == " ":
splits = text.split()
else:
splits = text.split(separator)
else:
splits = list(text)
# Now go merging things, recursively splitting longer texts.
splits = [s for s in splits if (s not in {"", "\n"})]
_good_splits = []
_good_splits_lengths = [] # cache the lengths of the splits
_separator = "" if self._keep_separator else separator
s_lens = self._length_function(splits)
for s, s_len in zip(splits, s_lens):
if s_len < self._chunk_size:
_good_splits.append(s)
_good_splits_lengths.append(s_len)
else:
if _good_splits:
merged_text = self._merge_splits(_good_splits, separator, _good_splits_lengths)
final_chunks.extend(merged_text)
_good_splits = []
_good_splits_lengths = []
other_info = self.recursive_split_text(s)
final_chunks.extend(other_info)
if _good_splits:
merged_text = self._merge_splits(_good_splits, separator, _good_splits_lengths)
final_chunks.extend(merged_text)
if _separator != "":
for s, s_len in zip(splits, s_lens):
if s_len < self._chunk_size:
_good_splits.append(s)
_good_splits_lengths.append(s_len)
else:
if _good_splits:
merged_text = self._merge_splits(_good_splits, _separator, _good_splits_lengths)
final_chunks.extend(merged_text)
_good_splits = []
_good_splits_lengths = []
if not new_separators:
final_chunks.append(s)
else:
other_info = self._split_text(s, new_separators)
final_chunks.extend(other_info)
if _good_splits:
merged_text = self._merge_splits(_good_splits, _separator, _good_splits_lengths)
final_chunks.extend(merged_text)
else:
current_part = ""
current_length = 0
overlap_part = ""
overlap_part_length = 0
for s, s_len in zip(splits, s_lens):
if current_length + s_len <= self._chunk_size - self._chunk_overlap:
current_part += s
current_length += s_len
elif current_length + s_len <= self._chunk_size:
current_part += s
current_length += s_len
overlap_part += s
overlap_part_length += s_len
else:
final_chunks.append(current_part)
current_part = overlap_part + s
current_length = s_len + overlap_part_length
overlap_part = ""
overlap_part_length = 0
if current_part:
final_chunks.append(current_part)
return final_chunks
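The rewritten whitespace branch accumulates an explicit overlap buffer so consecutive chunks share up to `chunk_overlap` characters. A simplified, properly indented restatement of that accumulation as a standalone helper (hypothetical function, not part of the PR):

```python
def merge_with_overlap(splits: list[str], lengths: list[int], chunk_size: int, chunk_overlap: int) -> list[str]:
    # Simplified restatement of the overlap handling in the whitespace-separator branch
    chunks: list[str] = []
    current_part, current_length = "", 0
    overlap_part, overlap_part_length = "", 0
    for s, s_len in zip(splits, lengths):
        if current_length + s_len <= chunk_size - chunk_overlap:
            current_part += s
            current_length += s_len
        elif current_length + s_len <= chunk_size:
            # the tail of the current chunk also seeds the next one
            current_part += s
            current_length += s_len
            overlap_part += s
            overlap_part_length += s_len
        else:
            chunks.append(current_part)
            current_part = overlap_part + s
            current_length = s_len + overlap_part_length
            overlap_part, overlap_part_length = "", 0
    if current_part:
        chunks.append(current_part)
    return chunks
```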

View File

@ -1,25 +0,0 @@
# Tools
This module implements built-in tools used in Agent Assistants and Workflows within Dify. You could define and display your own tools in this module, without modifying the frontend logic. This decoupling allows for easier horizontal scaling of Dify's capabilities.
## Feature Introduction
The tools provided for Agents and Workflows are currently divided into two categories:
- `Built-in Tools` are internally implemented within our product and are hardcoded for use in Agents and Workflows.
- `Api-Based Tools` leverage third-party APIs for implementation. You don't need to code to integrate these -- simply provide interface definitions in formats like `OpenAPI` , `Swagger`, or the `OpenAI-plugin` on the front-end.
### Built-in Tool Providers
![Alt text](docs/images/index/image.png)
### API Tool Providers
![Alt text](docs/images/index/image-1.png)
## Tool Integration
To enable developers to build flexible and powerful tools, we provide two guides:
### [Quick Integration 👈🏻](./docs/en_US/tool_scale_out.md)
Quick integration aims at quickly getting you up to speed with tool integration by walking over an example Google Search tool.
### [Advanced Integration 👈🏻](./docs/en_US/advanced_scale_out.md)
Advanced integration will offer a deeper dive into the module interfaces, and explain how to implement more complex capabilities, such as generating images, combining multiple tools, and managing the flow of parameters, images, and files between different tools.

View File

@ -1,27 +0,0 @@
# Tools
This module provides the invocation and authentication interfaces for the built-in tools used by Agents and Workflows, and gives Dify a unified definition of tool provider information and credential form rules.
- On one hand, it decouples tools from business code, making it easier for developers to scale capabilities horizontally.
- On the other hand, providers and tools only need to be defined on the backend to be displayed directly on the frontend, with no changes to frontend logic.
## Feature Introduction
The tools provided for Agents and Workflows currently fall into two categories:
- `Built-in Tools` are implemented inside Dify and are hardcoded for use by Agents and Workflows.
- `Api-Based Tools` are implemented by calling third-party APIs; an `Api-Based Tool` requires no extra coding, only an interface definition such as `OpenAPI`, `Swagger`, or `OpenAI plugin`.
### Built-in Tool Providers
![Alt text](docs/images/index/image.png)
### API Tool Providers
![Alt text](docs/images/index/image-1.png)
## Tool Integration
To enable more flexible and powerful capabilities, Tools provides a set of interfaces that help developers quickly build the tools they need. As a getting-started guide, this document introduces tool integration in two parts: [Quick Integration](./docs/zh_Hans/tool_scale_out.md) and [Advanced Integration](./docs/zh_Hans/advanced_scale_out.md).
### [Quick Integration 👈🏻](./docs/zh_Hans/tool_scale_out.md)
Quick Integration helps you complete a tool integration in 10 to 20 minutes, but it only covers simple capabilities; for more complex features, see Advanced Integration below.
### [Advanced Integration 👈🏻](./docs/zh_Hans/advanced_scale_out.md)
Advanced Integration explains how to implement more complex configurations, including image-to-image generation, combining multiple tools, and passing parameters, images, and files between tools.

View File

@ -1,31 +0,0 @@
# Tools
This module implements the built-in tools used by Dify's Agent Assistants and Workflows. It lets you define and display your own tools without modifying the frontend logic, and this decoupling makes it easier to scale Dify's capabilities horizontally.
## Feature Introduction
The tools provided for Agents and Workflows currently fall into two categories.
- `Built-in Tools` are implemented inside Dify and are hardcoded for use by Agents and Workflows.
- `Api-Based Tools` are implemented on top of third-party APIs. No coding is required to integrate them; simply provide an interface definition in a format such as `OpenAPI`, `Swagger`, or `OpenAI-plugin` on the frontend.
### Built-in Tool Providers
![Alt text](docs/images/index/image.png)
### API Tool Providers
![Alt text](docs/images/index/image-1.png)
## Tool Integration
To help developers build flexible and powerful tools, two guides are provided.
### [Quick Integration 👈🏻](./docs/ja_JP/tool_scale_out.md)
Quick Integration aims to get you up to speed with the basics of tool integration quickly, walking through an example Google Search tool.
### [Advanced Integration 👈🏻](./docs/ja_JP/advanced_scale_out.md)
Advanced Integration dives deeper into the module interfaces and explains how to implement more complex capabilities, such as image generation, combining multiple tools, and managing the flow of parameters, images, and files between different tools.

View File

@ -179,6 +179,18 @@ class ApiTool(Tool):
for content_type in self.api_bundle.openapi["requestBody"]["content"]:
headers["Content-Type"] = content_type
body_schema = self.api_bundle.openapi["requestBody"]["content"][content_type]["schema"]
# handle ref schema
if "$ref" in body_schema:
ref_path = body_schema["$ref"].split("/")
ref_name = ref_path[-1]
if (
"components" in self.api_bundle.openapi
and "schemas" in self.api_bundle.openapi["components"]
):
if ref_name in self.api_bundle.openapi["components"]["schemas"]:
body_schema = self.api_bundle.openapi["components"]["schemas"][ref_name]
required = body_schema.get("required", [])
properties = body_schema.get("properties", {})
for name, property in properties.items():
@ -186,6 +198,8 @@ class ApiTool(Tool):
if property.get("format") == "binary":
f = parameters[name]
files.append((name, (f.filename, download(f), f.mime_type)))
elif "$ref" in property:
body[name] = parameters[name]
else:
# convert type
body[name] = self._convert_body_property_type(property, parameters[name])
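The new `$ref` branch resolves request-body schemas that point into `components/schemas` instead of being inlined. A hedged sketch of the fragment shape it handles and of the resolution step (the schema name is invented):

```python
# Invented OpenAPI fragment of the kind the new branch resolves
openapi = {
    "components": {
        "schemas": {
            "CreateUserRequest": {
                "required": ["name"],
                "properties": {"name": {"type": "string"}},
            }
        }
    },
    "requestBody": {
        "content": {
            "application/json": {"schema": {"$ref": "#/components/schemas/CreateUserRequest"}}
        }
    },
}

body_schema = openapi["requestBody"]["content"]["application/json"]["schema"]
if "$ref" in body_schema:
    ref_name = body_schema["$ref"].split("/")[-1]
    body_schema = openapi["components"]["schemas"].get(ref_name, body_schema)
# body_schema now carries the resolved "required" and "properties" keys
```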

View File

@ -1,278 +0,0 @@
# Advanced Tool Integration
Before starting with this advanced guide, please make sure you have a basic understanding of the tool integration process in Dify. Check out [Quick Integration](./tool_scale_out.md) for a quick runthrough.
## Tool Interface
We have defined a series of helper methods in the `Tool` class to help developers quickly build more complex tools.
### Message Return
Dify supports various message types such as `text`, `link`, `json`, `image`, and `file BLOB`. You can return different types of messages to the LLM and users through the following interfaces.
Please note, some parameters in the following interfaces will be introduced in later sections.
#### Image URL
You only need to pass the URL of the image, and Dify will automatically download the image and return it to the user.
```python
def create_image_message(self, image: str, save_as: str = '') -> ToolInvokeMessage:
"""
create an image message
:param image: the url of the image
:return: the image message
"""
```
#### Link
If you need to return a link, you can use the following interface.
```python
def create_link_message(self, link: str, save_as: str = '') -> ToolInvokeMessage:
"""
create a link message
:param link: the url of the link
:return: the link message
"""
```
#### Text
If you need to return a text message, you can use the following interface.
```python
def create_text_message(self, text: str, save_as: str = '') -> ToolInvokeMessage:
"""
create a text message
:param text: the text of the message
:return: the text message
"""
```
#### File BLOB
If you need to return the raw data of a file, such as images, audio, video, PPT, Word, Excel, etc., you can use the following interface.
- `blob` The raw data of the file, of bytes type
- `meta` The metadata of the file, if you know the type of the file, it is best to pass a `mime_type`, otherwise Dify will use `application/octet-stream` as the default type
```python
def create_blob_message(self, blob: bytes, meta: dict = None, save_as: str = '') -> ToolInvokeMessage:
"""
create a blob message
:param blob: the blob
:return: the blob message
"""
```
#### JSON
If you need to return a formatted JSON, you can use the following interface. This is commonly used for data transmission between nodes in a workflow, of course, in agent mode, most LLM are also able to read and understand JSON.
- `object` A Python dictionary object will be automatically serialized into JSON
```python
def create_json_message(self, object: dict) -> ToolInvokeMessage:
"""
create a json message
"""
```
### Shortcut Tools
In large model applications, we have two common needs:
- First, summarize a long text in advance, and then pass the summary content to the LLM to prevent the original text from being too long for the LLM to handle
- The content obtained by the tool is a link, and the web page information needs to be crawled before it can be returned to the LLM
To help developers quickly implement these two needs, we provide the following two shortcut tools.
#### Text Summary Tool
This tool takes in a user_id and the text to be summarized, and returns the summarized text. Dify uses the default model of the current workspace to summarize the long text.
```python
def summary(self, user_id: str, content: str) -> str:
"""
summary the content
:param user_id: the user id
:param content: the content
:return: the summary
"""
```
#### Web Page Crawling Tool
This tool takes in a web page link to be crawled and a user_agent (which can be empty), and returns a string containing the information of the web page. The `user_agent` is an optional parameter that can be used to identify the tool; if not passed, Dify uses the default `user_agent`.
```python
def get_url(self, url: str, user_agent: str = None) -> str:
    """
    get url from the crawled result
    """
```
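Combined with the summary helper above, a link-handling tool might look roughly like this inside `_invoke` (a sketch; the user agent string is arbitrary):
```python
# inside _invoke: fetch the page, then return only a short summary to the LLM
page_text = self.get_url(url, user_agent='Mozilla/5.0 (compatible; MyDifyTool/1.0)')
return self.create_text_message(self.summary(user_id=user_id, content=page_text))
```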
### Variable Pool
We have introduced a variable pool in `Tool` to store variables, files, etc. generated while the tool runs. These variables can then be used by other tools during the same run.
Next, we will use `DallE3` and `Vectorizer.AI` as examples to introduce how to use the variable pool.
- `DallE3` is an image generation tool that can generate images based on text. Here, we will let `DallE3` generate a logo for a coffee shop
- `Vectorizer.AI` is a vector image conversion tool that can convert images into vector graphics, so images can be enlarged indefinitely without losing quality. Here, we will convert the PNG icon generated by `DallE3` into a vector image so that it can actually be used by designers.
#### DallE3
First, we use DallE3. After creating the image, we save the image to the variable pool. The code is as follows:
```python
from typing import Any, Dict, List, Union

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

from base64 import b64decode
from openai import OpenAI


class DallE3Tool(BuiltinTool):
    def _invoke(self,
                user_id: str,
                tool_parameters: Dict[str, Any],
                ) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
        """
        invoke tools
        """
        client = OpenAI(
            api_key=self.runtime.credentials['openai_api_key'],
        )

        # prompt
        prompt = tool_parameters.get('prompt', '')
        if not prompt:
            return self.create_text_message('Please input prompt')

        # call openai dalle3
        response = client.images.generate(
            prompt=prompt, model='dall-e-3',
            size='1024x1024', n=1, style='vivid', quality='standard',
            response_format='b64_json'
        )

        result = []
        for image in response.data:
            # Save all images to the variable pool through the save_as parameter. The variable name is
            # self.VARIABLE_KEY.IMAGE.value. If new images are generated later, they will overwrite the previous images.
            result.append(self.create_blob_message(blob=b64decode(image.b64_json),
                                                   meta={'mime_type': 'image/png'},
                                                   save_as=self.VARIABLE_KEY.IMAGE.value))

        return result
```
Note that we used `self.VARIABLE_KEY.IMAGE.value` as the variable name of the image. We defined this `KEY` so that developers' tools can cooperate with each other. You are free to use it or not; passing a custom KEY is also acceptable.
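For instance, the call from the example above could just as well save the image under a custom key (the name below is arbitrary):
```python
result.append(self.create_blob_message(blob=b64decode(image.b64_json),
                                       meta={'mime_type': 'image/png'},
                                       save_as='coffee_shop_logo'))  # custom variable name
```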
#### Vectorizer.AI
Next, we use Vectorizer.AI to convert the PNG icon generated by DallE3 into a vector image. Let's go through the functions we defined here. The code is as follows:
```python
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.errors import ToolProviderCredentialValidationError

from typing import Any, Dict, List, Union

from httpx import post
from base64 import b64decode


class VectorizerTool(BuiltinTool):
    def _invoke(self, user_id: str, tool_parameters: Dict[str, Any]) \
            -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
        """
        Tool invocation, the image variable name needs to be passed in from here, so that we can get the image from the variable pool
        """

    def get_runtime_parameters(self) -> List[ToolParameter]:
        """
        Override the tool parameter list, we can dynamically generate the parameter list based on the actual situation in the current variable pool, so that the LLM can generate the form based on the parameter list
        """

    def is_tool_available(self) -> bool:
        """
        Whether the current tool is available, if there is no image in the current variable pool, then we don't need to display this tool, just return False here
        """
```
Next, let's implement these three functions
```python
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.errors import ToolProviderCredentialValidationError

from typing import Any, Dict, List, Union

from httpx import post
from base64 import b64decode


class VectorizerTool(BuiltinTool):
    def _invoke(self, user_id: str, tool_parameters: Dict[str, Any]) \
            -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
        """
        invoke tools
        """
        api_key_name = self.runtime.credentials.get('api_key_name', None)
        api_key_value = self.runtime.credentials.get('api_key_value', None)

        if not api_key_name or not api_key_value:
            raise ToolProviderCredentialValidationError('Please input api key name and value')

        # Get image_id, the definition of image_id can be found in get_runtime_parameters
        image_id = tool_parameters.get('image_id', '')
        if not image_id:
            return self.create_text_message('Please input image id')

        # Get the image generated by DallE from the variable pool
        image_binary = self.get_variable_file(self.VARIABLE_KEY.IMAGE)
        if not image_binary:
            return self.create_text_message('Image not found, please request user to generate image firstly.')

        # Generate vector image
        response = post(
            'https://vectorizer.ai/api/v1/vectorize',
            files={'image': image_binary},
            data={'mode': 'test'},
            auth=(api_key_name, api_key_value),
            timeout=30
        )

        if response.status_code != 200:
            raise Exception(response.text)

        return [
            self.create_text_message('the vectorized svg is saved as an image.'),
            self.create_blob_message(blob=response.content,
                                     meta={'mime_type': 'image/svg+xml'})
        ]

    def get_runtime_parameters(self) -> List[ToolParameter]:
        """
        override the runtime parameters
        """
        # Here, we override the tool parameter list, define the image_id, and set its option list to all images
        # in the current variable pool. The configuration here is consistent with the configuration in yaml.
        return [
            ToolParameter.get_simple_instance(
                name='image_id',
                llm_description=f'the image id that you want to vectorize, \
                    and the image id should be specified in \
                    {[i.name for i in self.list_default_image_variables()]}',
                type=ToolParameter.ToolParameterType.SELECT,
                required=True,
                options=[i.name for i in self.list_default_image_variables()]
            )
        ]

    def is_tool_available(self) -> bool:
        # Only when there are images in the variable pool does the LLM need to use this tool
        return len(self.list_default_image_variables()) > 0
```
It's worth noting that we didn't actually use `image_id` here. We assumed that there must be an image in the default variable pool when this tool is called, so we used `image_binary = self.get_variable_file(self.VARIABLE_KEY.IMAGE)` directly to get the image. When the model's capabilities are weak, we recommend that developers do the same; it effectively improves fault tolerance and prevents the model from passing incorrect parameters.


@@ -1,248 +0,0 @@
# Quick Tool Integration
Here, we will use GoogleSearch as an example to demonstrate how to quickly integrate a tool.
## 1. Prepare the Tool Provider yaml
### Introduction
This yaml declares a new tool provider, and includes information like the provider's name, icon, author, and other details that are fetched by the frontend for display.
### Example
We need to create a `google` module (folder) under `core/tools/provider/builtin`, and create `google.yaml`. The name must be consistent with the module name.
Subsequently, all operations related to this tool will be carried out under this module.
```yaml
identity: # Basic information of the tool provider
  author: Dify # Author
  name: google # Name, unique, no duplication with other providers
  label: # Label for frontend display
    en_US: Google # English label
    zh_Hans: Google # Chinese label
  description: # Description for frontend display
    en_US: Google # English description
    zh_Hans: Google # Chinese description
  icon: icon.svg # Icon, needs to be placed in the _assets folder of the current module
  tags:
    - search
```
- The `identity` field is mandatory; it contains the basic information of the tool provider, including author, name, label, description, icon, etc.
- The icon needs to be placed in the `_assets` folder of the current module; you can refer to [here](../../provider/builtin/google/_assets/icon.svg).
- The `tags` field is optional; it is used to classify the provider, and the frontend can filter providers by tag. All currently supported tags are listed below:
```python
from enum import Enum


class ToolLabelEnum(Enum):
    SEARCH = 'search'
    IMAGE = 'image'
    VIDEOS = 'videos'
    WEATHER = 'weather'
    FINANCE = 'finance'
    DESIGN = 'design'
    TRAVEL = 'travel'
    SOCIAL = 'social'
    NEWS = 'news'
    MEDICAL = 'medical'
    PRODUCTIVITY = 'productivity'
    EDUCATION = 'education'
    BUSINESS = 'business'
    ENTERTAINMENT = 'entertainment'
    UTILITIES = 'utilities'
    OTHER = 'other'
```
## 2. Prepare Provider Credentials
Google, as a third-party tool, uses the API provided by SerpApi, which requires an API key. This means the tool needs a credential before it can be used. For tools like `wikipedia`, there is no need to fill in a credential field; you can refer to [here](../../provider/builtin/wikipedia/wikipedia.yaml).
After configuring the credential field, the effect is as follows:
```yaml
identity:
  author: Dify
  name: google
  label:
    en_US: Google
    zh_Hans: Google
  description:
    en_US: Google
    zh_Hans: Google
  icon: icon.svg
credentials_for_provider: # Credential field
  serpapi_api_key: # Credential field name
    type: secret-input # Credential field type
    required: true # Required or not
    label: # Credential field label
      en_US: SerpApi API key # English label
      zh_Hans: SerpApi API key # Chinese label
    placeholder: # Credential field placeholder
      en_US: Please input your SerpApi API key # English placeholder
      zh_Hans: 请输入你的 SerpApi API key # Chinese placeholder
    help: # Credential field help text
      en_US: Get your SerpApi API key from SerpApi # English help text
      zh_Hans: 从 SerpApi 获取您的 SerpApi API key # Chinese help text
    url: https://serpapi.com/manage-api-key # Credential field help link
```
- `type`: Credential field type, currently either `secret-input`, `text-input`, or `select`, corresponding to a password input box, a text input box, and a drop-down box, respectively. If set to `secret-input`, the frontend will mask the input content, and the backend will encrypt it.
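Whatever the user enters here is later exposed to the tool via `self.runtime.credentials`, keyed by the field name, for example:
```python
# inside the tool's _invoke: read the credential configured above
api_key = self.runtime.credentials['serpapi_api_key']
```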
## 3. Prepare Tool yaml
A provider can have multiple tools, and each tool needs a yaml file to describe it. This file contains the tool's basic information, parameters, output, and so on.
Still taking GoogleSearch as an example, we need to create a `tools` module under the `google` module and create `tools/google_search.yaml` with the following content.
```yaml
identity: # Basic information of the tool
  name: google_search # Tool name, unique, no duplication with other tools
  author: Dify # Author
  label: # Label for frontend display
    en_US: GoogleSearch # English label
    zh_Hans: 谷歌搜索 # Chinese label
description: # Description for frontend display
  human: # Introduction for frontend display, supports multiple languages
    en_US: A tool for performing a Google SERP search and extracting snippets and webpages. Input should be a search query.
    zh_Hans: 一个用于执行 Google SERP 搜索并提取片段和网页的工具。输入应该是一个搜索查询。
  llm: A tool for performing a Google SERP search and extracting snippets and webpages. Input should be a search query. # Introduction passed to the LLM. To help the LLM better understand this tool, we suggest writing as much detail about it as possible here so the LLM can understand and use it.
parameters: # Parameter list
  - name: query # Parameter name
    type: string # Parameter type
    required: true # Required or not
    label: # Parameter label
      en_US: Query string # English label
      zh_Hans: 查询语句 # Chinese label
    human_description: # Introduction for frontend display, supports multiple languages
      en_US: used for searching
      zh_Hans: 用于搜索网页内容
    llm_description: key words for searching # Introduction passed to the LLM. To help the LLM better understand this parameter, we suggest writing as much detail about it as possible here.
    form: llm # Form type, llm means this parameter needs to be inferred by the Agent, and the frontend will not display it
  - name: result_type
    type: select # Parameter type
    required: true
    options: # Drop-down box options
      - value: text
        label:
          en_US: text
          zh_Hans: 文本
      - value: link
        label:
          en_US: link
          zh_Hans: 链接
    default: link
    label:
      en_US: Result type
      zh_Hans: 结果类型
    human_description:
      en_US: used for selecting the result type, text or link
      zh_Hans: 用于选择结果类型,使用文本还是链接进行展示
    form: form # Form type, form means this parameter needs to be filled in by the user on the frontend before the conversation starts
```
- The `identity` field is mandatory; it contains the basic information of the tool, including name, author, label, description, etc.
- `parameters` Parameter list
  - `name` (Mandatory) Parameter name, must be unique and not duplicate other parameters.
  - `type` (Mandatory) Parameter type, currently supports five types: `string`, `number`, `boolean`, `select`, and `secret-input`, corresponding to string, number, boolean, drop-down box, and encrypted input box, respectively. For sensitive information, we recommend using the `secret-input` type.
  - `label` (Mandatory) Parameter label, for frontend display
  - `form` (Mandatory) Form type, currently supports two types: `llm` and `form`.
    - In an agent app, `llm` indicates that the parameter is inferred by the LLM itself, while `form` indicates that the parameter can be pre-set for the tool.
    - In a workflow app, both `llm` and `form` need to be filled out by the front end, but the parameters of `llm` will be used as input variables for the tool node.
  - `required` Indicates whether the parameter is required or not
    - In `llm` mode, if the parameter is required, the Agent is required to infer this parameter
    - In `form` mode, if the parameter is required, the user is required to fill in this parameter on the frontend before the conversation starts
  - `options` Parameter options
    - In `llm` mode, Dify will pass all options to the LLM, and the LLM can infer based on these options
    - In `form` mode, when `type` is `select`, the frontend will display these options
  - `default` Default value
  - `min` Minimum value, can be set when the parameter type is `number`.
  - `max` Maximum value, can be set when the parameter type is `number`.
  - `placeholder` The prompt text for input boxes. It can be set when the form type is `form` and the parameter type is `string`, `number`, or `secret-input`. It supports multiple languages.
  - `human_description` Introduction for frontend display, supports multiple languages
  - `llm_description` Introduction passed to the LLM. To help the LLM better understand this parameter, we suggest writing as much detail about it as possible here so the LLM can understand and use it.
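At runtime, both kinds of parameters arrive in the `tool_parameters` dict of `_invoke`. For the yaml above, that means roughly:
```python
# inside _invoke: 'query' is inferred by the LLM, 'result_type' is chosen by the user on the frontend
query = tool_parameters['query']
result_type = tool_parameters.get('result_type', 'link')  # 'link' matches the default declared in yaml
```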
## 4. Add Tool Logic
After completing the tool configuration, we can start writing the tool code that defines how it is invoked.
Create `google_search.py` under the `google/tools` module with the following content.
```python
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage

from typing import Any, Dict, List, Union


class GoogleSearchTool(BuiltinTool):
    def _invoke(self,
                user_id: str,
                tool_parameters: Dict[str, Any],
                ) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
        """
        invoke tools
        """
        query = tool_parameters['query']
        result_type = tool_parameters['result_type']
        api_key = self.runtime.credentials['serpapi_api_key']

        # Search with serpapi (SerpAPI is a small wrapper around the SerpApi HTTP API; its definition is omitted here)
        result = SerpAPI(api_key).run(query, result_type=result_type)

        if result_type == 'text':
            return self.create_text_message(text=result)
        return self.create_link_message(link=result)
```
### Parameters
The overall logic of the tool lives in the `_invoke` method. This method accepts two parameters, `user_id` and `tool_parameters`, which represent the user ID and the tool parameters, respectively.
### Return Data
When the tool returns, you can choose to return one message or multiple messages. Here we return a single message: `create_text_message` and `create_link_message` create a text message or a link message, respectively. If you want to return multiple messages, build a list, for example `[self.create_text_message('msg1'), self.create_text_message('msg2')]`.
## 5. Add Provider Code
Finally, we need to create a provider class under the provider module to implement the provider's credential verification logic. If the credential verification fails, it will throw a `ToolProviderCredentialValidationError` exception.
Create `google.py` under the `google` module with the following content.
```python
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.google.tools.google_search import GoogleSearchTool

from typing import Any, Dict


class GoogleProvider(BuiltinToolProviderController):
    def _validate_credentials(self, credentials: Dict[str, Any]) -> None:
        try:
            # 1. Here you need to instantiate a GoogleSearchTool with GoogleSearchTool(); it will automatically load
            #    the yaml configuration of GoogleSearchTool, but at this point it has no credential information inside
            # 2. Then you need to use the fork_tool_runtime method to pass the current credential information to GoogleSearchTool
            # 3. Finally, invoke it; the parameters need to be passed according to the parameter rules configured in the yaml of GoogleSearchTool
            GoogleSearchTool().fork_tool_runtime(
                meta={
                    "credentials": credentials,
                }
            ).invoke(
                user_id='',
                tool_parameters={
                    "query": "test",
                    "result_type": "link"
                },
            )
        except Exception as e:
            raise ToolProviderCredentialValidationError(str(e))
```
## Completion
After the above steps are completed, we can see this tool on the frontend, and it can be used in the Agent.
Of course, because google_search needs a credential, you also need to enter your credentials on the frontend before using it.
![Alt text](../images/index/image-2.png)



@@ -1,283 +0,0 @@
# 高度なツール統合
このガイドを始める前に、Difyのツール統合プロセスの基本を理解していることを確認してください。簡単な概要については[クイック統合](./tool_scale_out.md)をご覧ください。
## ツールインターフェース
より複雑なツールを迅速に構築するのを支援するため、`Tool`クラスに一連のヘルパーメソッドを定義しています。
### メッセージの返却
Difyは`テキスト``リンク``画像``ファイルBLOB``JSON`などの様々なメッセージタイプをサポートしています。以下のインターフェースを通じて、異なるタイプのメッセージをLLMとユーザーに返すことができます。
注意:以下のインターフェースの一部のパラメータについては、後のセクションで説明します。
#### 画像URL
画像のURLを渡すだけで、Difyが自動的に画像をダウンロードしてユーザーに返します。
```python
def create_image_message(self, image: str, save_as: str = '') -> ToolInvokeMessage:
"""
create an image message
:param image: the url of the image
:param save_as: save as
:return: the image message
"""
```
#### リンク
リンクを返す必要がある場合は、以下のインターフェースを使用できます。
```python
def create_link_message(self, link: str, save_as: str = '') -> ToolInvokeMessage:
"""
create a link message
:param link: the url of the link
:param save_as: save as
:return: the link message
"""
```
#### テキスト
テキストメッセージを返す必要がある場合は、以下のインターフェースを使用できます。
```python
def create_text_message(self, text: str, save_as: str = '') -> ToolInvokeMessage:
"""
create a text message
:param text: the text of the message
:param save_as: save as
:return: the text message
"""
```
#### ファイルBLOB
画像、音声、動画、PPT、Word、Excelなどのファイルの生データを返す必要がある場合は、以下のインターフェースを使用できます。
- `blob` ファイルの生データbytes型
- `meta` ファイルのメタデータ。ファイルの種類が分かっている場合は、`mime_type`を渡すことをお勧めします。そうでない場合、Difyはデフォルトタイプとして`application/octet-stream`を使用します。
```python
def create_blob_message(self, blob: bytes, meta: dict = None, save_as: str = '') -> ToolInvokeMessage:
"""
create a blob message
:param blob: the blob
:param meta: meta
:param save_as: save as
:return: the blob message
"""
```
#### JSON
フォーマットされたJSONを返す必要がある場合は、以下のインターフェースを使用できます。これは通常、ワークフロー内のード間のデータ伝送に使用されますが、エージェントモードでは、ほとんどの大規模言語モデルもJSONを読み取り、理解することができます。
- `object` Pythonの辞書オブジェクトで、自動的にJSONにシリアライズされます。
```python
def create_json_message(self, object: dict) -> ToolInvokeMessage:
"""
create a json message
"""
```
### ショートカットツール
大規模モデルアプリケーションでは、以下の2つの一般的なニーズがあります
- まず長いテキストを事前に要約し、その要約内容をLLMに渡すことで、元のテキストが長すぎてLLMが処理できない問題を防ぐ
- ツールが取得したコンテンツがリンクである場合、Webページ情報をクロールしてからLLMに返す必要がある
開発者がこれら2つのニーズを迅速に実装できるよう、以下の2つのショートカットツールを提供しています。
#### テキスト要約ツール
このツールはuser_idと要約するテキストを入力として受け取り、要約されたテキストを返します。Difyは現在のワークスペースのデフォルトモデルを使用して長文を要約します。
```python
def summary(self, user_id: str, content: str) -> str:
"""
summary the content
:param user_id: the user id
:param content: the content
:return: the summary
"""
```
#### Webページクローリングツール
このツールはクロールするWebページのリンクとユーザーエージェント空でも可を入力として受け取り、そのWebページの情報を含む文字列を返します。`user_agent`はオプションのパラメータで、ツールを識別するために使用できます。渡さない場合、Difyはデフォルトの`user_agent`を使用します。
```python
def get_url(self, url: str, user_agent: str = None) -> str:
"""
get url from the crawled result
"""
```
### 変数プール
`Tool`内に変数プールを導入し、ツールの実行中に生成された変数やファイルなどを保存します。これらの変数は、ツールの実行中に他のツールが使用することができます。
次に、`DallE3``Vectorizer.AI`を例に、変数プールの使用方法を紹介します。
- `DallE3`は画像生成ツールで、テキストに基づいて画像を生成できます。ここでは、`DallE3`にカフェのロゴを生成させます。
- `Vectorizer.AI`はベクター画像変換ツールで、画像をベクター画像に変換できるため、画像を無限に拡大しても品質が損なわれません。ここでは、`DallE3`が生成したPNGアイコンをベクター画像に変換し、デザイナーが実際に使用できるようにします。
#### DallE3
まず、DallE3を使用します。画像を作成した後、その画像を変数プールに保存します。コードは以下の通りです
```python
from typing import Any, Dict, List, Union
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
from base64 import b64decode
from openai import OpenAI
class DallE3Tool(BuiltinTool):
def _invoke(self,
user_id: str,
tool_parameters: Dict[str, Any],
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
"""
invoke tools
"""
client = OpenAI(
api_key=self.runtime.credentials['openai_api_key'],
)
# prompt
prompt = tool_parameters.get('prompt', '')
if not prompt:
return self.create_text_message('Please input prompt')
# call openapi dalle3
response = client.images.generate(
prompt=prompt, model='dall-e-3',
size='1024x1024', n=1, style='vivid', quality='standard',
response_format='b64_json'
)
result = []
for image in response.data:
# Save all images to the variable pool through the save_as parameter. The variable name is self.VARIABLE_KEY.IMAGE.value. If new images are generated later, they will overwrite the previous images.
result.append(self.create_blob_message(blob=b64decode(image.b64_json),
meta={ 'mime_type': 'image/png' },
save_as=self.VARIABLE_KEY.IMAGE.value))
return result
```
ここでは画像の変数名として`self.VARIABLE_KEY.IMAGE.value`を使用していることに注意してください。開発者のツールが互いに連携できるよう、この`KEY`を定義しました。自由に使用することも、この`KEY`を使用しないこともできます。カスタムのKEYを渡すこともできます。
#### Vectorizer.AI
次に、Vectorizer.AIを使用して、DallE3が生成したPNGアイコンをベクター画像に変換します。ここで定義した関数を見てみましょう。コードは以下の通りです
```python
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.errors import ToolProviderCredentialValidationError
from typing import Any, Dict, List, Union
from httpx import post
from base64 import b64decode
class VectorizerTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: Dict[str, Any])
-> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
"""
Tool invocation, the image variable name needs to be passed in from here, so that we can get the image from the variable pool
"""
def get_runtime_parameters(self) -> List[ToolParameter]:
"""
Override the tool parameter list, we can dynamically generate the parameter list based on the actual situation in the current variable pool, so that the LLM can generate the form based on the parameter list
"""
def is_tool_available(self) -> bool:
"""
Whether the current tool is available, if there is no image in the current variable pool, then we don't need to display this tool, just return False here
"""
```
次に、これら3つの関数を実装します
```python
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.errors import ToolProviderCredentialValidationError
from typing import Any, Dict, List, Union
from httpx import post
from base64 import b64decode
class VectorizerTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: Dict[str, Any])
-> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
"""
invoke tools
"""
api_key_name = self.runtime.credentials.get('api_key_name', None)
api_key_value = self.runtime.credentials.get('api_key_value', None)
if not api_key_name or not api_key_value:
raise ToolProviderCredentialValidationError('Please input api key name and value')
# Get image_id, the definition of image_id can be found in get_runtime_parameters
image_id = tool_parameters.get('image_id', '')
if not image_id:
return self.create_text_message('Please input image id')
# Get the image generated by DallE from the variable pool
image_binary = self.get_variable_file(self.VARIABLE_KEY.IMAGE)
if not image_binary:
return self.create_text_message('Image not found, please request user to generate image firstly.')
# Generate vector image
response = post(
'https://vectorizer.ai/api/v1/vectorize',
files={ 'image': image_binary },
data={ 'mode': 'test' },
auth=(api_key_name, api_key_value),
timeout=30
)
if response.status_code != 200:
raise Exception(response.text)
return [
self.create_text_message('the vectorized svg is saved as an image.'),
self.create_blob_message(blob=response.content,
meta={'mime_type': 'image/svg+xml'})
]
def get_runtime_parameters(self) -> List[ToolParameter]:
"""
override the runtime parameters
"""
# Here, we override the tool parameter list, define the image_id, and set its option list to all images in the current variable pool. The configuration here is consistent with the configuration in yaml.
return [
ToolParameter.get_simple_instance(
name='image_id',
llm_description=f'the image id that you want to vectorize, \
and the image id should be specified in \
{[i.name for i in self.list_default_image_variables()]}',
type=ToolParameter.ToolParameterType.SELECT,
required=True,
options=[i.name for i in self.list_default_image_variables()]
)
]
def is_tool_available(self) -> bool:
# Only when there are images in the variable pool, the LLM needs to use this tool
return len(self.list_default_image_variables()) > 0
```
ここで注目すべきは、実際には`image_id`を使用していないことです。このツールを呼び出す際には、デフォルトの変数プールに必ず画像があると仮定し、直接`image_binary = self.get_variable_file(self.VARIABLE_KEY.IMAGE)`を使用して画像を取得しています。モデルの能力が弱い場合、開発者にもこの方法を推奨します。これにより、エラー許容度を効果的に向上させ、モデルが誤ったパラメータを渡すのを防ぐことができます。


@@ -1,240 +0,0 @@
# ツールの迅速な統合
ここでは、GoogleSearchを例にツールを迅速に統合する方法を紹介します。
## 1. ツールプロバイダーのyamlを準備する
### 概要
このyamlファイルには、プロバイダー名、アイコン、作者などの詳細情報が含まれ、フロントエンドでの柔軟な表示を可能にします。
### 例
`core/tools/provider/builtin`の下に`google`モジュール(フォルダ)を作成し、`google.yaml`を作成します。名前はモジュール名と一致している必要があります。
以降、このツールに関するすべての操作はこのモジュール内で行います。
```yaml
identity: # ツールプロバイダーの基本情報
author: Dify # 作者
name: google # 名前(一意、他のプロバイダーと重複不可)
label: # フロントエンド表示用のラベル
en_US: Google # 英語ラベル
zh_Hans: Google # 中国語ラベル
description: # フロントエンド表示用の説明
en_US: Google # 英語説明
zh_Hans: Google # 中国語説明
icon: icon.svg # アイコン現在のモジュールの_assetsフォルダに配置
tags: # タグ(フロントエンド表示用)
- search
```
- `identity`フィールドは必須で、ツールプロバイダーの基本情報(作者、名前、ラベル、説明、アイコンなど)が含まれます。
- アイコンは現在のモジュールの`_assets`フォルダに配置する必要があります。[こちら](../../provider/builtin/google/_assets/icon.svg)を参照してください。
- タグはフロントエンドでの表示に使用され、ユーザーがこのツールプロバイダーを素早く見つけるのに役立ちます。現在サポートされているすべてのタグは以下の通りです:
```python
class ToolLabelEnum(Enum):
SEARCH = 'search'
IMAGE = 'image'
VIDEOS = 'videos'
WEATHER = 'weather'
FINANCE = 'finance'
DESIGN = 'design'
TRAVEL = 'travel'
SOCIAL = 'social'
NEWS = 'news'
MEDICAL = 'medical'
PRODUCTIVITY = 'productivity'
EDUCATION = 'education'
BUSINESS = 'business'
ENTERTAINMENT = 'entertainment'
UTILITIES = 'utilities'
OTHER = 'other'
```
## 2. プロバイダーの認証情報を準備する
GoogleはSerpApiが提供するAPIを使用するサードパーティツールであり、SerpApiを使用するにはAPI Keyが必要です。つまり、このツールを使用するには認証情報が必要です。一方、`wikipedia`のようなツールでは認証情報フィールドを記入する必要はありません。[こちら](../../provider/builtin/wikipedia/wikipedia.yaml)を参照してください。
認証情報フィールドを設定すると、以下のようになります:
```yaml
identity:
author: Dify
name: google
label:
en_US: Google
zh_Hans: Google
description:
en_US: Google
zh_Hans: Google
icon: icon.svg
credentials_for_provider: # 認証情報フィールド
serpapi_api_key: # 認証情報フィールド名
type: secret-input # 認証情報フィールドタイプ
required: true # 必須かどうか
label: # 認証情報フィールドラベル
en_US: SerpApi API key # 英語ラベル
zh_Hans: SerpApi API key # 中国語ラベル
placeholder: # 認証情報フィールドプレースホルダー
en_US: Please input your SerpApi API key # 英語プレースホルダー
zh_Hans: 请输入你的 SerpApi API key # 中国語プレースホルダー
help: # 認証情報フィールドヘルプテキスト
en_US: Get your SerpApi API key from SerpApi # 英語ヘルプテキスト
zh_Hans: 从 SerpApi 获取您的 SerpApi API key # 中国語ヘルプテキスト
url: https://serpapi.com/manage-api-key # 認証情報フィールドヘルプリンク
```
- `type`:認証情報フィールドタイプ。現在、`secret-input`、`text-input`、`select`の3種類をサポートしており、それぞれパスワード入力ボックス、テキスト入力ボックス、ドロップダウンボックスに対応します。`secret-input`の場合、フロントエンドで入力内容が隠され、バックエンドで入力内容が暗号化されます。
## 3. ツールのyamlを準備する
1つのプロバイダーの下に複数のツールを持つことができ、各ツールにはyamlファイルが必要です。このファイルにはツールの基本情報、パラメータ、出力などが含まれます。
引き続きGoogleSearchを例に、`google`モジュールの下に`tools`モジュールを作成し、`tools/google_search.yaml`を作成します。内容は以下の通りです:
```yaml
identity: # ツールの基本情報
name: google_search # ツール名(一意、他のツールと重複不可)
author: Dify # 作者
label: # フロントエンド表示用のラベル
en_US: GoogleSearch # 英語ラベル
zh_Hans: 谷歌搜索 # 中国語ラベル
description: # フロントエンド表示用の説明
human: # フロントエンド表示用の紹介(多言語対応)
en_US: A tool for performing a Google SERP search and extracting snippets and webpages. Input should be a search query.
zh_Hans: 一个用于执行 Google SERP 搜索并提取片段和网页的工具。输入应该是一个搜索查询。
llm: A tool for performing a Google SERP search and extracting snippets and webpages. Input should be a search query. # LLMに渡す紹介文。LLMがこのツールをより理解できるよう、できるだけ詳細な情報を記述することをお勧めします。
parameters: # パラメータリスト
- name: query # パラメータ名
type: string # パラメータタイプ
required: true # 必須かどうか
label: # パラメータラベル
en_US: Query string # 英語ラベル
zh_Hans: 查询语句 # 中国語ラベル
human_description: # フロントエンド表示用の紹介(多言語対応)
en_US: used for searching
zh_Hans: 用于搜索网页内容
llm_description: key words for searching # LLMに渡す紹介文。LLMがこのパラメータをより理解できるよう、できるだけ詳細な情報を記述することをお勧めします。
form: llm # フォームタイプ。llmはこのパラメータがAgentによって推論される必要があることを意味し、フロントエンドではこのパラメータは表示されません。
- name: result_type
type: select # パラメータタイプ
required: true
options: # ドロップダウンボックスのオプション
- value: text
label:
en_US: text
zh_Hans: 文本
- value: link
label:
en_US: link
zh_Hans: 链接
default: link
label:
en_US: Result type
zh_Hans: 结果类型
human_description:
en_US: used for selecting the result type, text or link
zh_Hans: 用于选择结果类型,使用文本还是链接进行展示
form: form # フォームタイプ。formはこのパラメータが対話開始前にフロントエンドでユーザーによって入力される必要があることを意味します。
```
- `identity`フィールドは必須で、ツールの基本情報(名前、作者、ラベル、説明など)が含まれます。
- `parameters` パラメータリスト
- `name`(必須)パラメータ名。一意で、他のパラメータと重複しないようにしてください。
- `type`(必須)パラメータタイプ。現在、`string`、`number`、`boolean`、`select`、`secret-input`の5種類をサポートしており、それぞれ文字列、数値、ブール値、ドロップダウンボックス、暗号化入力ボックスに対応します。機密情報には`secret-input`タイプの使用をお勧めします。
- `label`(必須)パラメータラベル。フロントエンド表示用です。
- `form`(必須)フォームタイプ。現在、`llm`と`form`の2種類をサポートしています。
- エージェントアプリケーションでは、`llm`はこのパラメータがLLM自身によって推論されることを示し、`form`はこのツールを使用するために事前に設定できるパラメータであることを示します。
- ワークフローアプリケーションでは、`llm`と`form`の両方がフロントエンドで入力する必要がありますが、`llm`のパラメータはツールノードの入力変数として使用されます。
- `required` パラメータが必須かどうかを示します。
- `llm`モードでは、パラメータが必須の場合、Agentはこのパラメータを推論する必要があります。
- `form`モードでは、パラメータが必須の場合、ユーザーは対話開始前にフロントエンドでこのパラメータを入力する必要があります。
- `options` パラメータオプション
- `llm`モードでは、DifyはすべてのオプションをLLMに渡し、LLMはこれらのオプションに基づいて推論できます。
- `form`モードで、`type`が`select`の場合、フロントエンドはこれらのオプションを表示します。
- `default` デフォルト値
- `min` 最小値。パラメータタイプが`number`の場合に設定できます。
- `max` 最大値。パラメータタイプが`number`の場合に設定できます。
- `human_description` フロントエンド表示用の紹介。多言語対応です。
- `placeholder` 入力ボックスのプロンプトテキスト。フォームタイプが`form`で、パラメータタイプが`string`、`number`、`secret-input`の場合に設定できます。多言語対応です。
- `llm_description` LLMに渡す紹介文。LLMがこのパラメータをより理解できるよう、できるだけ詳細な情報を記述することをお勧めします。
## 4. ツールコードを準備する
ツールの設定が完了したら、ツールのロジックを実装するコードを作成します。
`google/tools`モジュールの下に`google_search.py`を作成し、内容は以下の通りです:
```python
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage
from typing import Any, Dict, List, Union
class GoogleSearchTool(BuiltinTool):
def _invoke(self,
user_id: str,
tool_parameters: Dict[str, Any],
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
"""
ツールを呼び出す
"""
query = tool_parameters['query']
result_type = tool_parameters['result_type']
api_key = self.runtime.credentials['serpapi_api_key']
result = SerpAPI(api_key).run(query, result_type=result_type)
if result_type == 'text':
return self.create_text_message(text=result)
return self.create_link_message(link=result)
```
### パラメータ
ツールの全体的なロジックは`_invoke`メソッドにあります。このメソッドは2つのパラメータ`user_id`とtool_parameters`を受け取り、それぞれユーザーIDとツールパラメータを表します。
### 戻り値
ツールの戻り値として、1つのメッセージまたは複数のメッセージを選択できます。ここでは1つのメッセージを返しています。`create_text_message``create_link_message`を使用して、テキストメッセージまたはリンクメッセージを作成できます。複数のメッセージを返す場合は、リストを構築できます(例:`[self.create_text_message('msg1'), self.create_text_message('msg2')]`)。
## 5. プロバイダーコードを準備する
最後に、プロバイダーモジュールの下にプロバイダークラスを作成し、プロバイダーの認証情報検証ロジックを実装する必要があります。認証情報の検証が失敗した場合、`ToolProviderCredentialValidationError`例外が発生します。
`google`モジュールの下に`google.py`を作成し、内容は以下の通りです:
```python
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.google.tools.google_search import GoogleSearchTool
from typing import Any, Dict
class GoogleProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: Dict[str, Any]) -> None:
try:
# 1. ここでGoogleSearchTool()を使ってGoogleSearchToolをインスタンス化する必要があります。これによりGoogleSearchToolのyaml設定が自動的に読み込まれますが、この時点では認証情報は含まれていません
# 2. 次に、fork_tool_runtimeメソッドを使用して、現在の認証情報をGoogleSearchToolに渡す必要があります
# 3. 最後に、invokeを呼び出します。パラメータはGoogleSearchToolのyamlで設定されたパラメータルールに従って渡す必要があります
GoogleSearchTool().fork_tool_runtime(
meta={
"credentials": credentials,
}
).invoke(
user_id='',
tool_parameters={
"query": "test",
"result_type": "link"
},
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))
```
## 完了
以上のステップが完了すると、このツールをフロントエンドで確認し、Agentで使用することができるようになります。
もちろん、google_searchには認証情報が必要なため、使用する前にフロントエンドで認証情報を入力する必要があります。
![Alt text](../images/index/image-2.png)


@@ -1,283 +0,0 @@
# 高级接入Tool
在开始高级接入之前,请确保你已经阅读过[快速接入](./tool_scale_out.md)并对Dify的工具接入流程有了基本的了解。
## 工具接口
我们在`Tool`类中定义了一系列快捷方法,用于帮助开发者快速构较为复杂的工具
### 消息返回
Dify支持`文本` `链接` `图片` `文件BLOB` `JSON` 等多种消息类型你可以通过以下几个接口返回不同类型的消息给LLM和用户。
注意,在下面的接口中的部分参数将在后面的章节中介绍。
#### 图片URL
只需要传递图片的URL即可Dify会自动下载图片并返回给用户。
```python
def create_image_message(self, image: str, save_as: str = '') -> ToolInvokeMessage:
"""
create an image message
:param image: the url of the image
:param save_as: save as
:return: the image message
"""
```
#### 链接
如果你需要返回一个链接,可以使用以下接口。
```python
def create_link_message(self, link: str, save_as: str = '') -> ToolInvokeMessage:
"""
create a link message
:param link: the url of the link
:param save_as: save as
:return: the link message
"""
```
#### 文本
如果你需要返回一个文本消息,可以使用以下接口。
```python
def create_text_message(self, text: str, save_as: str = '') -> ToolInvokeMessage:
"""
create a text message
:param text: the text of the message
:param save_as: save as
:return: the text message
"""
```
#### 文件BLOB
如果你需要返回文件的原始数据如图片、音频、视频、PPT、Word、Excel等可以使用以下接口。
- `blob` 文件的原始数据bytes类型
- `meta` 文件的元数据,如果你知道该文件的类型,最好传递一个`mime_type`否则Dify将使用`application/octet-stream`作为默认类型
```python
def create_blob_message(self, blob: bytes, meta: dict = None, save_as: str = '') -> ToolInvokeMessage:
"""
create a blob message
:param blob: the blob
:param meta: meta
:param save_as: save as
:return: the blob message
"""
```
#### JSON
如果你需要返回一个格式化的JSON可以使用以下接口。这通常用于workflow中的节点间的数据传递当然agent模式中大部分大模型也都能够阅读和理解JSON。
- `object` 一个Python的字典对象会被自动序列化为JSON
```python
def create_json_message(self, object: dict) -> ToolInvokeMessage:
"""
create a json message
"""
```
### 快捷工具
在大模型应用中,我们有两种常见的需求:
- 先将很长的文本进行提前总结然后再将总结内容传递给LLM以防止原文本过长导致LLM无法处理
- 工具获取到的内容是一个链接需要爬取网页信息后再返回给LLM
为了帮助开发者快速实现这两种需求,我们提供了以下两个快捷工具。
#### 文本总结工具
该工具需要传入user_id和需要进行总结的文本返回一个总结后的文本Dify会使用当前工作空间的默认模型对长文本进行总结。
```python
def summary(self, user_id: str, content: str) -> str:
"""
summary the content
:param user_id: the user id
:param content: the content
:return: the summary
"""
```
#### 网页爬取工具
该工具需要传入需要爬取的网页链接和一个user_agent可为空返回一个包含该网页信息的字符串其中`user_agent`是可选参数可以用来识别工具如果不传递Dify将使用默认的`user_agent`
```python
def get_url(self, url: str, user_agent: str = None) -> str:
"""
get url from the crawled result
"""
```
### 变量池
我们在`Tool`中引入了一个变量池,用于存储工具运行过程中产生的变量、文件等,这些变量可以在工具运行过程中被其他工具使用。
下面,我们以`DallE3``Vectorizer.AI`为例,介绍如何使用变量池。
- `DallE3`是一个图片生成工具,它可以根据文本生成图片,在这里,我们将让`DallE3`生成一个咖啡厅的Logo
- `Vectorizer.AI`是一个矢量图转换工具,它可以将图片转换为矢量图,使得图片可以无限放大而不失真,在这里,我们将`DallE3`生成的PNG图标转换为矢量图从而可以真正被设计师使用。
#### DallE3
首先我们使用DallE3在创建完图片以后我们将图片保存到变量池中代码如下
```python
from typing import Any, Dict, List, Union
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
from base64 import b64decode
from openai import OpenAI
class DallE3Tool(BuiltinTool):
def _invoke(self,
user_id: str,
tool_parameters: Dict[str, Any],
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
"""
invoke tools
"""
client = OpenAI(
api_key=self.runtime.credentials['openai_api_key'],
)
# prompt
prompt = tool_parameters.get('prompt', '')
if not prompt:
return self.create_text_message('Please input prompt')
# call openapi dalle3
response = client.images.generate(
prompt=prompt, model='dall-e-3',
size='1024x1024', n=1, style='vivid', quality='standard',
response_format='b64_json'
)
result = []
for image in response.data:
# 将所有图片通过save_as参数保存到变量池中变量名为self.VARIABLE_KEY.IMAGE.value如果如果后续有新的图片生成那么将会覆盖之前的图片
result.append(self.create_blob_message(blob=b64decode(image.b64_json),
meta={ 'mime_type': 'image/png' },
save_as=self.VARIABLE_KEY.IMAGE.value))
return result
```
我们可以注意到这里我们使用了`self.VARIABLE_KEY.IMAGE.value`作为图片的变量名,为了便于开发者们的工具能够互相配合,我们定义了这个`KEY`,大家可以自由使用,也可以不使用这个`KEY`传递一个自定义的KEY也是可以的。
#### Vectorizer.AI
接下来我们使用Vectorizer.AI将DallE3生成的PNG图标转换为矢量图我们先来过一遍我们在这里定义的函数代码如下
```python
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.errors import ToolProviderCredentialValidationError
from typing import Any, Dict, List, Union
from httpx import post
from base64 import b64decode
class VectorizerTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: Dict[str, Any]) \
-> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
"""
工具调用,图片变量名需要从这里传递进来,从而我们就可以从变量池中获取到图片
"""
def get_runtime_parameters(self) -> List[ToolParameter]:
"""
重写工具参数列表我们可以根据当前变量池里的实际情况来动态生成参数列表从而LLM可以根据参数列表来生成表单
"""
def is_tool_available(self) -> bool:
"""
当前工具是否可用如果当前变量池中没有图片那么我们就不需要展示这个工具这里返回False即可
"""
```
接下来我们来实现这三个函数
```python
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.errors import ToolProviderCredentialValidationError
from typing import Any, Dict, List, Union
from httpx import post
from base64 import b64decode
class VectorizerTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: Dict[str, Any]) \
-> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
"""
invoke tools
"""
api_key_name = self.runtime.credentials.get('api_key_name', None)
api_key_value = self.runtime.credentials.get('api_key_value', None)
if not api_key_name or not api_key_value:
raise ToolProviderCredentialValidationError('Please input api key name and value')
# 获取image_idimage_id的定义可以在get_runtime_parameters中找到
image_id = tool_parameters.get('image_id', '')
if not image_id:
return self.create_text_message('Please input image id')
# 从变量池中获取到之前DallE生成的图片
image_binary = self.get_variable_file(self.VARIABLE_KEY.IMAGE)
if not image_binary:
return self.create_text_message('Image not found, please request user to generate image firstly.')
# 生成矢量图
response = post(
'https://vectorizer.ai/api/v1/vectorize',
files={ 'image': image_binary },
data={ 'mode': 'test' },
auth=(api_key_name, api_key_value),
timeout=30
)
if response.status_code != 200:
raise Exception(response.text)
return [
self.create_text_message('the vectorized svg is saved as an image.'),
self.create_blob_message(blob=response.content,
meta={'mime_type': 'image/svg+xml'})
]
def get_runtime_parameters(self) -> List[ToolParameter]:
"""
override the runtime parameters
"""
# 这里我们重写了工具参数列表定义了image_id并设置了它的选项列表为当前变量池中的所有图片这里的配置与yaml中的配置是一致的
return [
ToolParameter.get_simple_instance(
name='image_id',
llm_description=f'the image id that you want to vectorize, \
and the image id should be specified in \
{[i.name for i in self.list_default_image_variables()]}',
type=ToolParameter.ToolParameterType.SELECT,
required=True,
options=[i.name for i in self.list_default_image_variables()]
)
]
def is_tool_available(self) -> bool:
# 只有当变量池中有图片时LLM才需要使用这个工具
return len(self.list_default_image_variables()) > 0
```
可以注意到的是,我们这里其实并没有使用到`image_id`,我们已经假设了调用这个工具的时候一定有一张图片在默认的变量池中,所以直接使用了`image_binary = self.get_variable_file(self.VARIABLE_KEY.IMAGE)`来获取图片,在模型能力较弱的情况下,我们建议开发者们也这样做,可以有效提升容错率,避免模型传递错误的参数。


@@ -1,237 +0,0 @@
# 快速接入Tool
这里我们以GoogleSearch为例介绍如何快速接入一个工具。
## 1. 准备工具供应商yaml
### 介绍
这个yaml将包含工具供应商的信息包括供应商名称、图标、作者等详细信息以帮助前端灵活展示。
### 示例
我们需要在 `core/tools/provider/builtin`下创建一个`google`模块(文件夹),并创建`google.yaml`,名称必须与模块名称一致。
后续,我们关于这个工具的所有操作都将在这个模块下进行。
```yaml
identity: # 工具供应商的基本信息
author: Dify # 作者
name: google # 名称,唯一,不允许和其他供应商重名
label: # 标签,用于前端展示
en_US: Google # 英文标签
zh_Hans: Google # 中文标签
description: # 描述,用于前端展示
en_US: Google # 英文描述
zh_Hans: Google # 中文描述
icon: icon.svg # 图标需要放置在当前模块的_assets文件夹下
tags: # 标签,用于前端展示
- search
```
- `identity` 字段是必须的,它包含了工具供应商的基本信息,包括作者、名称、标签、描述、图标等
- 图标需要放置在当前模块的`_assets`文件夹下,可以参考[这里](../../provider/builtin/google/_assets/icon.svg)。
- 标签用于前端展示,可以帮助用户快速找到这个工具供应商,下面列出了目前所支持的所有标签
```python
class ToolLabelEnum(Enum):
SEARCH = 'search'
IMAGE = 'image'
VIDEOS = 'videos'
WEATHER = 'weather'
FINANCE = 'finance'
DESIGN = 'design'
TRAVEL = 'travel'
SOCIAL = 'social'
NEWS = 'news'
MEDICAL = 'medical'
PRODUCTIVITY = 'productivity'
EDUCATION = 'education'
BUSINESS = 'business'
ENTERTAINMENT = 'entertainment'
UTILITIES = 'utilities'
OTHER = 'other'
```
## 2. 准备供应商凭据
Google作为一个第三方工具使用了SerpApi提供的API而SerpApi需要一个API Key才能使用那么就意味着这个工具需要一个凭据才可以使用而像`wikipedia`这样的工具,就不需要填写凭据字段,可以参考[这里](../../provider/builtin/wikipedia/wikipedia.yaml)。
配置好凭据字段后效果如下:
```yaml
identity:
author: Dify
name: google
label:
en_US: Google
zh_Hans: Google
description:
en_US: Google
zh_Hans: Google
icon: icon.svg
credentials_for_provider: # 凭据字段
serpapi_api_key: # 凭据字段名称
type: secret-input # 凭据字段类型
required: true # 是否必填
label: # 凭据字段标签
en_US: SerpApi API key # 英文标签
zh_Hans: SerpApi API key # 中文标签
placeholder: # 凭据字段占位符
en_US: Please input your SerpApi API key # 英文占位符
zh_Hans: 请输入你的 SerpApi API key # 中文占位符
help: # 凭据字段帮助文本
en_US: Get your SerpApi API key from SerpApi # 英文帮助文本
zh_Hans: 从 SerpApi 获取您的 SerpApi API key # 中文帮助文本
url: https://serpapi.com/manage-api-key # 凭据字段帮助链接
```
- `type`:凭据字段类型,目前支持`secret-input`、`text-input`、`select` 三种类型,分别对应密码输入框、文本输入框、下拉框,如果为`secret-input`,则会在前端隐藏输入内容,并且后端会对输入内容进行加密。
## 3. 准备工具yaml
一个供应商底下可以有多个工具每个工具都需要一个yaml文件来描述这个文件包含了工具的基本信息、参数、输出等。
仍然以GoogleSearch为例我们需要在`google`模块下创建一个`tools`模块,并创建`tools/google_search.yaml`,内容如下。
```yaml
identity: # 工具的基本信息
name: google_search # 工具名称,唯一,不允许和其他工具重名
author: Dify # 作者
label: # 标签,用于前端展示
en_US: GoogleSearch # 英文标签
zh_Hans: 谷歌搜索 # 中文标签
description: # 描述,用于前端展示
human: # 用于前端展示的介绍,支持多语言
en_US: A tool for performing a Google SERP search and extracting snippets and webpages.Input should be a search query.
zh_Hans: 一个用于执行 Google SERP 搜索并提取片段和网页的工具。输入应该是一个搜索查询。
llm: A tool for performing a Google SERP search and extracting snippets and webpages.Input should be a search query. # 传递给LLM的介绍为了使得LLM更好理解这个工具我们建议在这里写上关于这个工具尽可能详细的信息让LLM能够理解并使用这个工具
parameters: # 参数列表
- name: query # 参数名称
type: string # 参数类型
required: true # 是否必填
label: # 参数标签
en_US: Query string # 英文标签
zh_Hans: 查询语句 # 中文标签
human_description: # 用于前端展示的介绍,支持多语言
en_US: used for searching
zh_Hans: 用于搜索网页内容
llm_description: key words for searching # 传递给LLM的介绍同上为了使得LLM更好理解这个参数我们建议在这里写上关于这个参数尽可能详细的信息让LLM能够理解这个参数
form: llm # 表单类型llm表示这个参数需要由Agent自行推理出来前端将不会展示这个参数
- name: result_type
type: select # 参数类型
required: true
options: # 下拉框选项
- value: text
label:
en_US: text
zh_Hans: 文本
- value: link
label:
en_US: link
zh_Hans: 链接
default: link
label:
en_US: Result type
zh_Hans: 结果类型
human_description:
en_US: used for selecting the result type, text or link
zh_Hans: 用于选择结果类型,使用文本还是链接进行展示
form: form # 表单类型form表示这个参数需要由用户在对话开始前在前端填写
```
- `identity` 字段是必须的,它包含了工具的基本信息,包括名称、作者、标签、描述等
- `parameters` 参数列表
- `name` (必填)参数名称,唯一,不允许和其他参数重名
- `type` (必填)参数类型,目前支持`string`、`number`、`boolean`、`select`、`secret-input` 五种类型,分别对应字符串、数字、布尔值、下拉框、加密输入框,对于敏感信息,我们建议使用`secret-input`类型
- `label`(必填)参数标签,用于前端展示
- `form` (必填)表单类型,目前支持`llm`、`form`两种类型
- 在Agent应用中`llm`表示该参数LLM自行推理`form`表示要使用该工具可提前设定的参数
- 在workflow应用中`llm`和`form`均需要前端填写,但`llm`的参数会做为工具节点的输入变量
- `required` 是否必填
- 在`llm`模式下如果参数为必填则会要求Agent必须要推理出这个参数
- 在`form`模式下,如果参数为必填,则会要求用户在对话开始前在前端填写这个参数
- `options` 参数选项
- 在`llm`模式下Dify会将所有选项传递给LLMLLM可以根据这些选项进行推理
- 在`form`模式下,`type`为`select`时,前端会展示这些选项
- `default` 默认值
- `min` 最小值,当参数类型为`number`时可以设定
- `max` 最大值,当参数类型为`number`时可以设定
- `human_description` 用于前端展示的介绍,支持多语言
- `placeholder` 字段输入框的提示文字,在表单类型为`form`,参数类型为`string`、`number`、`secret-input`时,可以设定,支持多语言
- `llm_description` 传递给LLM的介绍为了使得LLM更好理解这个参数我们建议在这里写上关于这个参数尽可能详细的信息让LLM能够理解这个参数
## 4. 准备工具代码
当完成工具的配置以后,我们就可以开始编写工具代码了,主要用于实现工具的逻辑。
在`google/tools`模块下创建`google_search.py`,内容如下。
```python
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage
from typing import Any, Dict, List, Union
class GoogleSearchTool(BuiltinTool):
def _invoke(self,
user_id: str,
tool_parameters: Dict[str, Any],
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
"""
invoke tools
"""
query = tool_parameters['query']
result_type = tool_parameters['result_type']
api_key = self.runtime.credentials['serpapi_api_key']
result = SerpAPI(api_key).run(query, result_type=result_type)
if result_type == 'text':
return self.create_text_message(text=result)
return self.create_link_message(link=result)
```
### 参数
工具的整体逻辑都在`_invoke`方法中,这个方法接收两个参数:`user_id`和`tool_parameters`分别表示用户ID和工具参数
### 返回数据
在工具返回时,你可以选择返回一条消息或者多个消息,这里我们返回一条消息,使用`create_text_message`和`create_link_message`可以创建一条文本消息或者一条链接消息。如需返回多条消息,可以使用列表构建,例如`[self.create_text_message('msg1'), self.create_text_message('msg2')]`
## 5. 准备供应商代码
最后,我们需要在供应商模块下创建一个供应商类,用于实现供应商的凭据验证逻辑,如果凭据验证失败,将会抛出`ToolProviderCredentialValidationError`异常。
在`google`模块下创建`google.py`,内容如下。
```python
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.google.tools.google_search import GoogleSearchTool
from typing import Any, Dict
class GoogleProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: Dict[str, Any]) -> None:
try:
# 1. 此处需要使用GoogleSearchTool()实例化一个GoogleSearchTool它会自动加载GoogleSearchTool的yaml配置但是此时它内部没有凭据信息
# 2. 随后需要使用fork_tool_runtime方法将当前的凭据信息传递给GoogleSearchTool
# 3. 最后invoke即可参数需要根据GoogleSearchTool的yaml中配置的参数规则进行传递
GoogleSearchTool().fork_tool_runtime(
meta={
"credentials": credentials,
}
).invoke(
user_id='',
tool_parameters={
"query": "test",
"result_type": "link"
},
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))
```
## 完成
当上述步骤完成以后我们就可以在前端看到这个工具了并且可以在Agent中使用这个工具。
当然因为google_search需要一个凭据在使用之前还需要在前端配置它的凭据。
![Alt text](../images/index/image-2.png)


@@ -63,11 +63,18 @@ class ToolFileManager:
conversation_id: Optional[str],
file_binary: bytes,
mimetype: str,
filename: Optional[str] = None,
) -> ToolFile:
extension = guess_extension(mimetype) or ".bin"
unique_name = uuid4().hex
filename = f"{unique_name}{extension}"
filepath = f"tools/{tenant_id}/{filename}"
unique_filename = f"{unique_name}{extension}"
# default just as before
present_filename = unique_filename
if filename is not None:
has_extension = len(filename.split(".")) > 1
# Add extension flexibly
present_filename = filename if has_extension else f"{filename}{extension}"
filepath = f"tools/{tenant_id}/{unique_filename}"
storage.save(filepath, file_binary)
tool_file = ToolFile(
@@ -76,7 +83,7 @@ class ToolFileManager:
conversation_id=conversation_id,
file_key=filepath,
mimetype=mimetype,
name=filename,
name=present_filename,
size=len(file_binary),
)
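In short, the change keeps the uuid-based name for the storage path but lets an explicit `filename` drive the displayed name, appending the guessed extension only when the caller didn't supply one. A standalone sketch of just that resolution logic (illustrative, not the actual class):
```python
from mimetypes import guess_extension
from typing import Optional, Tuple
from uuid import uuid4


def resolve_names(mimetype: str, filename: Optional[str] = None) -> Tuple[str, str]:
    """Return (storage_filename, display_filename) following the logic above."""
    extension = guess_extension(mimetype) or ".bin"
    unique_filename = f"{uuid4().hex}{extension}"  # always used for the storage path
    if filename is None:
        return unique_filename, unique_filename
    has_extension = len(filename.split(".")) > 1
    return unique_filename, filename if has_extension else f"{filename}{extension}"
```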


@@ -765,17 +765,22 @@ class ToolManager:
@classmethod
def generate_builtin_tool_icon_url(cls, provider_id: str) -> str:
return (
dify_config.CONSOLE_API_URL
+ "/console/api/workspaces/current/tool-provider/builtin/"
+ provider_id
+ "/icon"
return str(
URL(dify_config.CONSOLE_API_URL or "/")
/ "console"
/ "api"
/ "workspaces"
/ "current"
/ "tool-provider"
/ "builtin"
/ provider_id
/ "icon"
)
@classmethod
def generate_plugin_tool_icon_url(cls, tenant_id: str, filename: str) -> str:
return str(
URL(dify_config.CONSOLE_API_URL)
URL(dify_config.CONSOLE_API_URL or "/")
/ "console"
/ "api"
/ "workspaces"


@@ -59,6 +59,8 @@ class ToolFileMessageTransformer:
meta = message.meta or {}
mimetype = meta.get("mime_type", "application/octet-stream")
# get filename from meta
filename = meta.get("file_name", None)
# if message is str, encode it to bytes
if not isinstance(message.message, ToolInvokeMessage.BlobMessage):
@@ -72,6 +74,7 @@ class ToolFileMessageTransformer:
conversation_id=conversation_id,
file_binary=message.message.blob,
mimetype=mimetype,
filename=filename,
)
url = cls.get_tool_file_url(tool_file_id=file.id, extension=guess_extension(file.mimetype))
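On the tool side, this means a blob message can now carry a human-readable name simply by adding a `file_name` key to `meta`; a sketch (the values are illustrative):
```python
# inside a tool's _invoke: the transformer above reads 'file_name' from meta
return self.create_blob_message(
    blob=report_bytes,  # raw file data produced by the tool
    meta={'mime_type': 'application/pdf', 'file_name': 'quarterly-report'},
)
```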

Some files were not shown because too many files have changed in this diff.