Compare commits

..

424 Commits

Author SHA1 Message Date
967c9081cc Resolve conflicts: Force the main code to overwrite everything in deploy/dev 2026-01-13 19:46:59 +08:00
a22cc5bc5e chore: Bump Dify version to 1.11.3 (#30903) 2026-01-13 17:49:13 +08:00
yyh 1fbdf6b465 refactor(web): setup status caching (#30798) 2026-01-13 16:59:49 +08:00
491e1fd6a4 chore: case insensitive email (#29978)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
2026-01-13 15:42:44 +08:00
0e33dfb5c2 fix: In the LLM model in dify, when a message is added, the first cli… (#29540)
Co-authored-by: 青枕 <qingzhen.ww@alibaba-inc.com>
2026-01-13 15:42:32 +08:00
lif ea708e7a32 fix(web): add null check for SSE stream bufferObj to prevent TypeError (#30131)
Signed-off-by: majiayu000 <1835304752@qq.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-13 15:40:43 +08:00
c09e29c3f8 chore: rename the migration file (#30893) 2026-01-13 15:26:41 +08:00
2d53ba8671 fix: fix object value is optional should skip validate (#30894) 2026-01-13 15:21:06 +08:00
9be863fefa fix: missing content if assistant message with tool_calls (#30083)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-01-13 12:46:33 +08:00
8f43629cd8 fix(amplitude): update sessionReplaySampleRate default value to 0.5 (#30880)
Co-authored-by: CodingOnStar <hanxujiang@dify.ai>
2026-01-13 12:26:50 +08:00
9ee71902c1 fix: fix formatNumber accuracy (#30877) 2026-01-13 11:51:15 +08:00
3f2a461b22 feat: summary index (#30878) 2026-01-13 10:45:12 +08:00
51e2e4a728 feat: summary index 2026-01-13 10:42:24 +08:00
a012c87445 fix: entrypoint.sh overrides NEXT_PUBLIC_TEXT_GENERATION_TIMEOUT_MS when TEXT_GENERATION_TIMEOUT_MS is unset (#30864) (#30865) 2026-01-13 10:12:51 +08:00
450578d4c0 feat(ops): set root span kind for AliyunTrace to enable service-level metrics aggregation (#30728)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-01-13 10:12:00 +08:00
837237aa6d fix: use node factory for single-step workflow nodes (#30859) 2026-01-13 10:11:18 +08:00
495ad848be feat: implement Summary Index feature. (#30862) 2026-01-13 09:48:14 +08:00
16fa798f21 Merge branch 'deploy/dev' into feat/knowledge-summary-index 2026-01-12 23:44:10 +08:00
b32c93df6f Update api/core/workflow/nodes/knowledge_index/knowledge_index_node.py
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-01-12 21:59:19 +08:00
76da8b4ff3 Merge remote-tracking branch 'origin/deploy/dev' 2026-01-12 17:09:25 +08:00
25bfc1cc3b feat: implement Summary Index feature. 2026-01-12 16:52:21 +08:00
b63dfbf654 fix(api): defer streaming response until referenced variables are updated (#30832)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-01-12 16:23:18 +08:00
51ea87ab85 feat: clear free plan workflow run logs (#29494)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2026-01-12 15:57:40 +08:00
00698e41b7 build: limit esbuild, glob, docker base version to avoid cve (#30848) 2026-01-12 15:33:20 +08:00
df938a4543 ci: add HITL test env deployment action (#30846) 2026-01-12 15:07:53 +08:00
b34d09649b Merge branch 'feat/support-free-try-app' into deploy/dev 2026-01-12 13:49:28 +08:00
a92df530da mrege main 2026-01-12 13:41:27 +08:00
yyh 9161936f41 refactor(web): extract isServer/isClient utility & upgrade Node.js to 22.12.0 (#30803)
Co-authored-by: Stephen Zhou <38493346+hyoban@users.noreply.github.com>
2026-01-12 12:57:43 +08:00
f9a21b56ab feat: add block-no-verify hook for Claude Code (#30839) 2026-01-12 12:56:05 +08:00
220e1df847 docs(web): add corepack recommendation (#30837)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-01-12 12:44:30 +08:00
8cfdde594c chore(deps-dev): bump tos from 2.7.2 to 2.9.0 in /api (#30834)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-01-12 12:44:21 +08:00
31a8fd810c chore(deps-dev): bump @storybook/react from 9.1.13 to 9.1.17 in /web (#30833)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-01-12 12:44:11 +08:00
939a1b91a0 Merge branch 'feat/ee-workspace-permission-control' into deploy/dev 2026-01-12 12:07:29 +08:00
e8a1c99626 feat: add permission check before owner transfer workspace 2026-01-12 12:04:53 +08:00
898393a0f8 feat: add invite permission check for workspace invite members feature 2026-01-12 11:52:48 +08:00
431936beb9 chore: handle callback warning 2026-01-12 11:33:18 +08:00
163540bf4a chore: handle refetch after created 2026-01-12 11:30:03 +08:00
221130b448 chore: remove old i18n 2026-01-12 10:55:02 +08:00
9fad97ec9b fix: drop useless pyrefly in ci (#30826)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
2026-01-12 09:45:49 +08:00
0c2729d9b3 fix: fix refresh token deadlock (#30828) 2026-01-12 09:35:31 +08:00
049925c5d8 Merge branch 'main' into deploy/dev 2026-01-11 21:54:32 +08:00
a2e03b811e fix: Broken import in .storybook/preview.tsx (#30812) 2026-01-10 19:49:23 +08:00
1e10bf525c refactor(models): Refine MessageAgentThought SQLAlchemy typing (#27749)
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-01-10 17:17:45 +09:00
8b1af36d94 feat(web): migrate PWA to Serwist (#30808) 2026-01-10 17:16:18 +09:00
yyh c56d9e3f69 Merge remote-tracking branch 'origin/main' into deploy/dev 2026-01-09 19:41:17 +08:00
b1eb265fa5 fix: try app not call conversations and sessions 2026-01-09 16:48:03 +08:00
0711dd4159 feat: enhance start node object value check (#30732) 2026-01-09 16:13:17 +08:00
ae0a26f5b6 revert: "fix: fix assign value stand as default (#30651)" (#30717)
The original fix seems correct on its own. However, for chatflows with multiple answer nodes, the `message_replace` command only preserves the output of the last executed answer node.
2026-01-09 16:08:24 +08:00
c2a0950660 fix: button ui problem 2026-01-09 15:34:48 +08:00
bfe98009fd chore: fix dataset problems 2026-01-09 14:26:18 +08:00
ea1704d211 fix: try basic detail errors 2026-01-09 14:14:15 +08:00
a49ab7258d mr trial 2026-01-09 12:16:41 +08:00
425a0f9095 fix trial get 2026-01-09 12:15:40 +08:00
3d050f449c fix trial get 2026-01-09 12:13:29 +08:00
905a5b348d fix trial get 2026-01-09 12:13:20 +08:00
1a1e825685 fix trial get 2026-01-09 12:02:00 +08:00
758b289b6f fix trial get 2026-01-09 12:01:39 +08:00
11be198fc8 Merge branch 'feat/app-trial' into deploy/dev 2026-01-09 11:39:08 +08:00
3e082e6976 fix: migration 2026-01-09 11:38:50 +08:00
cf990cdace Merge branch 'feat/app-trial' into deploy/dev 2026-01-09 11:36:22 +08:00
ce309bd008 mr main 2026-01-09 11:33:10 +08:00
3ed0937734 merge 2026-01-08 18:27:47 +08:00
04912fa775 Merge branch 'feat/llm-support-tools' into deploy/dev 2026-01-08 14:45:08 +08:00
0ff1b61232 Merge branch 'main' into deploy/dev 2026-01-08 14:41:17 +08:00
4d3d8b35d9 Merge branch 'main' into feat/llm-node-support-tools 2026-01-08 14:28:13 +08:00
c323028179 feat: llm node support tools 2026-01-08 14:27:37 +08:00
977406f703 Merge branch 'feat/credit-pool' into deploy/dev 2026-01-08 12:12:10 +08:00
80535ed20e fix 2026-01-08 12:11:50 +08:00
f4f391ada0 fix linter 2026-01-08 11:11:01 +08:00
6d9d2a1079 add rowcount check 2026-01-08 11:10:49 +08:00
4db40646cb add rowcount check 2026-01-08 11:09:52 +08:00
71bd20cf73 add rowcount check 2026-01-08 11:08:34 +08:00
dbc8ffccbd add rowcount check 2026-01-08 11:07:09 +08:00
c3d0c80ca1 Merge branch 'feat/credit-pool' into deploy/dev 2026-01-08 11:03:21 +08:00
75ccba6e52 add rowcount check 2026-01-08 11:01:23 +08:00
8a38489e53 Merge branch 'feat/credit-pool' of github.com:langgenius/dify into feat/credit-pool 2026-01-08 10:45:47 +08:00
c301052789 add rowcount check 2026-01-08 10:45:28 +08:00
38f7c77fa4 Update api/configs/feature/hosted_service/__init__.py
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-01-08 10:40:53 +08:00
44762a38c2 Update api/configs/feature/hosted_service/__init__.py
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-01-08 10:40:37 +08:00
7938a2bc5b Merge remote-tracking branch 'origin/main' into feat/credit-pool 2026-01-08 10:09:09 +08:00
cfaae1b8a1 Merge branch 'fix/web-app-login-code-encrypt' into deploy/dev 2026-01-07 20:30:03 +08:00
cbc4348727 fix: encrypt password when login at web app login page 2026-01-07 20:29:35 +08:00
c888e3c46e Merge branch 'fix/web-app-login-code-encrypt' into deploy/dev 2026-01-07 20:26:43 +08:00
4fa7bf0128 Merge branch 'feat/agent-node-v2' into deploy/dev 2026-01-07 17:51:08 +08:00
eec57e84e4 Merge branch 'main' into feat/agent-node-v2 2026-01-07 17:34:23 +08:00
dac9c1953a fix: encrypt email login code when login webapp 2026-01-07 17:20:46 +08:00
70149ea05e Merge branch 'main' into feat/llm-node-support-tools 2026-01-07 16:29:47 +08:00
1d93f41fcf feat: llm node support tools 2026-01-07 16:28:41 +08:00
1ecf581ba1 chore: add model name in detail 2026-01-07 15:05:51 +08:00
1584a78fc9 chore: add model name in detail 2026-01-07 15:05:18 +08:00
5a5ee0db26 Merge branch 'feat/agent-node-v2' into deploy/dev 2026-01-06 16:49:35 +08:00
cef7fd484b chore: add trace metadata and streaming icon 2026-01-06 16:30:33 +08:00
3ef37b24bb Merge branch 'feat/model-total-credits' into deploy/dev 2026-01-06 11:08:53 +08:00
88c3286bd4 fix: correct capitalization of 'TONGYI' in model name mapping 2026-01-06 11:05:03 +08:00
04f40303fd Merge branch 'main' into feat/llm-node-support-tools 2026-01-04 18:04:42 +08:00
ececc5ec2c feat: llm node support tools 2026-01-04 18:03:47 +08:00
a29da52562 feat: conditionally render components based on cloud edition status in model provider page and credential panel 2026-01-04 15:12:14 +08:00
3533ed0fdd Merge branch 'main' into feat/model-total-credits 2026-01-04 12:03:37 +08:00
e34513cbc4 fix: correct translation errors in Japanese and Chinese JSON files, and improve type safety in quota panel component 2026-01-04 12:03:02 +08:00
c976c1cb2d feat: add new OpenAI icons (OpenaiBlue, OpenaiTeal, OpenaiViolet) and update existing icons for consistency 2026-01-04 11:43:29 +08:00
9358786c7f Merge branch 'main' into feat/model-total-credits 2026-01-04 11:27:12 +08:00
1629e22d98 update for merge 2026-01-04 11:15:48 +08:00
8fb50d52dc Merge branch 'feat/agent-node-v2' into deploy/dev 2026-01-04 11:10:38 +08:00
dc8a618b6a feat: add think start end tag 2026-01-04 11:09:43 +08:00
d9a0d6caa8 Merge branch 'main' into feat/model-total-credits 2026-01-04 10:59:46 +08:00
f3e7fea628 feat: add tool call time 2026-01-04 10:29:02 +08:00
7c78d42627 fix(web): enable JSON_OBJECT type support in console UI (#30412)
Co-authored-by: zhsama <torvalds@linux.do>
2025-12-31 13:42:38 +08:00
13346f2874 Merge branch 'main' into deploy/dev 2025-12-30 16:45:34 +08:00
d330aaf57a refactor: comment out conditional checks in AmplitudeProvider and GA components
- Commented out the checks for Amplitude and Google Analytics to allow for easier testing and integration.
- This change does not affect the functionality but prepares the components for further enhancements.
2025-12-30 16:28:14 +08:00
9758e36f7c Merge branch 'feat/add-oauth_new_user' into deploy/dev 2025-12-30 15:11:10 +08:00
5ad435ef32 add oauthuser flag for frontend when use oauth login 2025-12-30 15:09:35 +08:00
6ed630db15 Merge branch 'feat/utm-amp' into deploy/dev 2025-12-30 14:52:07 +08:00
256c5231f2 feat: enhance user registration tracking with UTM parameters
- Updated event tracking for user registration success to differentiate between registrations with and without UTM parameters in both OAuth and email flows.
- Adjusted tracking event names accordingly to improve analytics accuracy.
2025-12-30 14:47:25 +08:00
344e0a3318 Merge branch 'main' into feat/utm-amp 2025-12-30 14:42:33 +08:00
59772c2493 feat: integrate Google Analytics event tracking and update CSP for script sources
- Added types for Google Analytics gtag and implemented event tracking in user registration flows.
- Updated Content Security Policy to allow 'wasm-unsafe-eval' in script sources.
- Refactored GA component to improve nonce handling and script loading strategy.
- Cleaned up UTM info cookies after successful user registration.
2025-12-30 14:39:28 +08:00
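The two tracking commits above (UTM-aware registration events and the GA gtag integration) come down to firing a differently named event depending on whether UTM parameters were present at sign-up. A minimal TypeScript sketch of that idea follows; the event names, payload fields, and helper name are illustrative assumptions, not Dify's actual analytics schema.

// gtag is typed here only as far as this sketch needs; the real declaration in the
// codebase may be broader. All identifiers below are assumptions for illustration.
declare global {
  interface Window {
    gtag?: (command: 'event', eventName: string, params?: Record<string, unknown>) => void
  }
}

export function trackRegistrationSuccess(hasUtmParams: boolean, method: 'oauth' | 'email') {
  // Differentiate registrations that arrived with UTM parameters from organic ones,
  // as the commit messages describe.
  const eventName = hasUtmParams ? 'register_success_with_utm' : 'register_success'
  window.gtag?.('event', eventName, { method })
}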
f828c0e754 chore: handle migrations 2025-12-30 14:13:46 +08:00
960b0707c8 Merge branch 'feat/agent-node-v2' into deploy/dev 2025-12-30 13:43:40 +08:00
03938be789 fix: merge conflict 2025-12-30 11:56:11 +08:00
a6b94f11e5 Merge branch 'main' into deploy/dev 2025-12-30 11:49:47 +08:00
e83635ee5a Merge branch 'main' into feat/llm-node-support-tools 2025-12-30 11:47:54 +08:00
d79372a46d Merge branch 'main' into feat/llm-node-support-tools 2025-12-30 11:47:26 +08:00
bbd11c9e89 feat: llm node support tools 2025-12-30 10:40:01 +08:00
152fd52cd7 [autofix.ci] apply automated fixes 2025-12-30 02:23:25 +00:00
ccabdbc83b Merge branch 'main' into feat/agent-node-v2 2025-12-30 10:20:42 +08:00
56c8221b3f chore: remove frontend changes 2025-12-30 10:19:40 +08:00
d132abcdb4 merge main 2025-12-29 15:55:45 +08:00
d60348572e feat: llm node support tools 2025-12-29 14:55:26 +08:00
3db700db34 Merge branch 'main' into feat/model-total-credits 2025-12-29 14:47:46 +08:00
0982cf6018 Merge branch 'feat/model-total-credits' into deploy/dev 2025-12-25 18:02:14 +08:00
9aee14f4f8 Merge branch 'main' into feat/model-total-credits 2025-12-25 15:13:57 +08:00
10749726eb fix: update translations to replace TONGYI with xAI and adjust related terms in multiple languages 2025-12-25 15:12:43 +08:00
f55faae31b chore: strip reasoning from chatflow answers and persist generation details 2025-12-25 13:59:38 +08:00
cd63bd48bd Merge branch 'feat/model-total-credits' into deploy/dev 2025-12-25 13:59:02 +08:00
9a5a06bb40 feat: add TONGYI model support and update related translations 2025-12-25 13:58:08 +08:00
0cff94d90e Merge branch 'main' into feat/llm-node-support-tools 2025-12-25 13:45:49 +08:00
8408975036 Merge branch 'feat/credit-pool' into deploy/dev 2025-12-25 11:36:27 +08:00
e01fae9151 fix: add tongyi paid 2025-12-25 11:36:06 +08:00
46a8f06b6c Merge branch 'feat/credit-pool' into deploy/dev 2025-12-25 11:35:17 +08:00
0fb97ee9e9 fix: add tongyi paid 2025-12-25 11:34:59 +08:00
b033002d9f add TONGYI icon 2025-12-25 11:15:53 +08:00
130163ca65 mr main 2025-12-25 10:50:35 +08:00
9d3eaefcdd Merge branch 'feat/credit-pool' of github.com:langgenius/dify into feat/credit-pool 2025-12-25 10:43:49 +08:00
5c51d61049 fix migration and add tongyi 2025-12-25 10:43:20 +08:00
ba0a59f998 Merge remote-tracking branch 'origin/main' into feat/credit-pool 2025-12-25 10:38:39 +08:00
7fc25cafb2 feat: basic app add thought field 2025-12-25 10:28:21 +08:00
a0fde9f012 Merge remote-tracking branch 'origin/main' into feat/model-total-credits 2025-12-25 10:22:40 +08:00
a7859de625 feat: llm node support tools 2025-12-24 14:15:55 +08:00
d1b4bb247a Merge branch 'fix/polyfill-toSplice' into deploy/dev 2025-12-23 15:37:23 +08:00
91e23027f1 feat: Add polyfill for Array.prototype.toSpliced method 2025-12-23 15:16:28 +08:00
3b3b16eb00 feat: Add polyfill for Array.prototype.toSpliced method 2025-12-23 14:33:57 +08:00
8cd69899b4 Merge remote-tracking branch 'origin/main' into feat/model-total-credits 2025-12-22 17:47:12 +08:00
710c17ed59 Merge branch 'main' into feat/model-total-credits 2025-12-19 15:34:11 +08:00
8d884aad29 Merge branch 'main' into feat/model-total-credits 2025-12-19 14:56:10 +08:00
7d7cce04eb [autofix.ci] apply automated fixes 2025-12-19 03:53:44 +00:00
88969a609b merge main 2025-12-19 11:50:45 +08:00
00af19c7f2 feat: sandbox retention basic settings 2025-12-18 14:19:33 +08:00
79d7dcaad2 feat: add billing subscription plan api 2025-12-18 12:40:24 +08:00
1a2f37e7c7 feat: add billing subscription plan api 2025-12-18 12:16:59 +08:00
bffd67f3a4 feat: add billing subscription plan api 2025-12-18 11:33:56 +08:00
0bbf8f72b1 feat: add billing subscription plan api 2025-12-18 10:59:32 +08:00
047ea8c143 chore: improve type checking 2025-12-18 10:09:31 +08:00
f54b9b12b0 feat: add process data 2025-12-17 17:34:02 +08:00
cb99b8f04d chore: handle migrations 2025-12-17 15:59:09 +08:00
7c03bcba2b Merge branch 'main' into feat/agent-node-v2 2025-12-17 15:55:27 +08:00
92fa7271ed refactor(llm node): remove unused args 2025-12-17 15:42:23 +08:00
1fcf6e4943 Update 2025_12_16_1817-03ea244985ce_add_type_column_not_null_default_tool.py 2025-12-17 11:12:59 +08:00
d3486cab31 refactor(llm node): tool call tool result entity 2025-12-17 10:30:21 +08:00
f4a7efde3d update migration script. 2025-12-16 18:30:12 +08:00
38d4f0fd96 Merge remote-tracking branch 'origin/deploy/dev' 2025-12-16 18:25:54 +08:00
ec4f885dad update migration script. 2025-12-16 18:19:24 +08:00
3781c2a025 [autofix.ci] apply automated fixes 2025-12-16 08:37:32 +00:00
3782f17dc7 Optimize code. 2025-12-16 16:35:15 +08:00
966b586200 Merge branch 'main' into feat/model-total-credits 2025-12-16 16:32:18 +08:00
29698aeed2 Merge remote-tracking branch 'origin/deploy/dev' 2025-12-16 16:26:19 +08:00
15ff8efb15 merge alembic head 2025-12-16 16:20:04 +08:00
407e1c8276 [autofix.ci] apply automated fixes 2025-12-16 08:14:05 +00:00
03cc3868ef feat: update RAG recommended plugins hook to accept type parameter for better categorization 2025-12-16 16:03:14 +08:00
fdd21ac815 fix: update RAGToolRecommendations to use type parameter for plugin retrieval 2025-12-16 16:02:55 +08:00
44524760ee Merge remote-tracking branch 'origin/main' into feat/model-total-credits 2025-12-16 15:51:35 +08:00
e368825c21 Merge remote-tracking branch 'upstream/main' 2025-12-16 15:50:49 +08:00
dd0a870969 Merge branch 'main' into feat/agent-node-v2 2025-12-16 15:17:29 +08:00
0c4c268003 chore: fix ci issues 2025-12-16 15:14:42 +08:00
5d40064f12 feat: Add "type" field to PipelineRecommendedPlugin model; (#29724) 2025-12-16 14:38:23 +08:00
8dad6b6a6d Add "type" field to PipelineRecommendedPlugin model; Add query param "type" to recommended-plugins api. 2025-12-16 14:34:59 +08:00
fd9f5c0fc8 feat: Add "type" field to PipelineRecommendedPlugin model; Add query param "type" to recommended-plugins api. (#29700) 2025-12-16 10:49:31 +08:00
2f54965a72 Add "type" field to PipelineRecommendedPlugin model; Add query param "type" to recommended-plugins api. 2025-12-16 10:43:45 +08:00
e5359ba136 feat: Add "type" field to PipelineRecommendedPlugin model; Add query param "type" to recommended-plugins api. (#29672) 2025-12-15 16:50:50 +08:00
a1a3fa0283 Add "type" field to PipelineRecommendedPlugin model; Add query param "type" to recommended-plugins api. 2025-12-15 16:44:32 +08:00
ff7344f3d3 Add "type" field to PipelineRecommendedPlugin model; Add query param "type" to recommended-plugins api. 2025-12-15 16:38:44 +08:00
bcd33be22a Add "type" field to PipelineRecommendedPlugin model; Add query param "type" to recommended-plugins api. 2025-12-15 16:33:06 +08:00
c1184789e2 fix seed command 2025-12-15 15:41:51 +08:00
ff57848268 [autofix.ci] apply automated fixes 2025-12-15 07:29:20 +00:00
d223fee9b9 Merge branch 'main' into feat/agent-node-v2 2025-12-15 15:26:48 +08:00
ad18d084f3 feat: add sequence output variable. 2025-12-15 14:59:06 +08:00
9941d1f160 feat: add llm log metadata 2025-12-15 14:18:53 +08:00
dab30117da Merge branch 'fix/customer-tool-SSRF' into deploy/dev 2025-12-15 11:00:47 +08:00
3f7d46358c fix test 2025-12-15 10:14:29 +08:00
890b5d222f fix mypy 2025-12-12 17:30:43 +08:00
fefcb1e959 fix CI 2025-12-12 17:30:43 +08:00
da6fa55eed [autofix.ci] apply automated fixes 2025-12-12 17:30:43 +08:00
ad438740c4 refactor the workflowNodeExecution 2025-12-12 17:30:43 +08:00
231ecc1bfe refactor the repo and service 2025-12-12 17:30:43 +08:00
790ed0845e Merge branch 'fix/customer-tool-SSRF' of github.com:langgenius/dify into fix/customer-tool-SSRF 2025-12-12 17:11:00 +08:00
309875650d use squid for ssrf 2025-12-12 17:10:31 +08:00
fc260fab97 [autofix.ci] apply automated fixes 2025-12-12 09:09:29 +00:00
92fa87e729 use squid for ssrf 2025-12-12 17:07:23 +08:00
13fa56b5b1 feat: add tracing metadata 2025-12-12 16:24:49 +08:00
f6accd8ae2 add internal ip filter when parse tool schema 2025-12-12 11:58:05 +08:00
e66ef9145b add internal ip filter when parse tool schema 2025-12-12 11:37:30 +08:00
64f5e34096 Update api/core/helper/ssrf_proxy.py
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-12-12 11:31:27 +08:00
3f3b9beeff add internal ip filter when parse tool schema 2025-12-12 11:24:25 +08:00
9ce48b4dc4 fix: llm generation variable 2025-12-12 11:08:49 +08:00
da9a28b9e2 Merge branch 'fix/nextjs-security-update' into deploy/dev 2025-12-12 11:05:39 +08:00
cd09d27f11 nextjs security update;
see more: [security-update-2025-12-11](https://nextjs.org/blog/security-update-2025-12-11)
2025-12-12 10:58:21 +08:00
912ca2bcfe add seed fake logs 2025-12-12 10:19:44 +08:00
dd949a23e1 add clean sandbox workflow runs 2025-12-12 10:19:16 +08:00
dcd95632cb Merge branch 'fix/react-cve-2025' into deploy/dev 2025-12-12 09:37:10 +08:00
75113163a2 fix: upgrade react and react-dom to 19.2.3, fix those cve errors:
- CVE-2025-55184 (DoS, High, CVSS 7.5)
- CVE-2025-67779 (DoS, High, CVSS 7.5)
- CVE-2025-55183 (Source Code Exposure, Medium, CVSS 5.3)
2025-12-12 09:31:58 +08:00
66713091e2 Merge branch 'feat/enchance-warn-user-time-when-need-upgrade-plan' into deploy/dev 2025-12-11 15:05:26 +08:00
00cc50c659 fix: add chunk logic 2025-12-11 15:03:30 +08:00
b459af07ae Merge branch 'fix/DoS-in-Annotation-Import' into deploy/dev 2025-12-11 12:37:50 +08:00
c93f7f6f2c fix annotation import dos 2025-12-11 12:36:16 +08:00
a740fe2292 fix annotation import dos 2025-12-11 12:35:30 +08:00
ba2486d961 Merge branch 'fix/CSV-Injection-in-Annotations-Export' into deploy/dev 2025-12-11 11:25:19 +08:00
e471884dcc fix csv injection in annotations export 2025-12-11 11:22:05 +08:00
ff762e9d84 Merge branch 'feat/enchance-warn-user-time-when-need-upgrade-plan' into deploy/dev 2025-12-11 11:22:03 +08:00
e09174d7d7 chore: i18n 2025-12-11 11:21:45 +08:00
016663bf44 Merge branch 'feat/enchance-warn-user-time-when-need-upgrade-plan' into deploy/dev 2025-12-11 11:00:29 +08:00
824618c2ef chore: spcing problem 2025-12-11 10:59:35 +08:00
918c187c50 refactor: optimize clearDataSourceData function and update localFileList condition 2025-12-11 10:58:24 +08:00
0133843164 refactor: remove supportBatchUpload prop from StepOne component 2025-12-11 10:34:00 +08:00
c49127540f Merge branch 'main' into feat/enchance-warn-user-time-when-need-upgrade-plan 2025-12-11 10:33:39 +08:00
fef5d88f59 Merge branch 'main' into deploy/dev 2025-12-11 10:22:38 +08:00
ba887993aa Merge branch 'feat/enchance-warn-user-time-when-need-upgrade-plan' into deploy/dev 2025-12-11 10:04:50 +08:00
42b6e32574 Merge branch 'main' into deploy/dev 2025-12-11 10:03:30 +08:00
1f91a971a8 feat: implement plan upgrade modal for batch upload restrictions and update supportBatchUpload prop to true in data source components 2025-12-10 18:32:45 +08:00
e487e1f622 refactor: remove supportBatchUpload prop from various components 2025-12-10 17:59:45 +08:00
b3010e35cb chore: add plan upgrade modal test 2025-12-10 17:54:34 +08:00
5891731ab2 Merge branch 'main' into feat/enchance-warn-user-time-when-need-upgrade-plan 2025-12-10 17:19:23 +08:00
383971d56f Merge branch 'feat/enchance-warn-user-time-when-need-upgrade-plan' into deploy/dev 2025-12-10 17:03:40 +08:00
840d45d407 feat: add missing tips 2025-12-10 16:56:11 +08:00
ff40efdc26 mrege 2025-12-10 15:42:19 +08:00
yyh c6ae13f67b refactor: replace Modal with PlanUpgradeModal in TriggerEventsLimitModal 2025-12-10 15:38:18 +08:00
55dd5f71d7 fix: can add logic 2025-12-10 15:36:18 +08:00
c05d9bd813 Merge branch 'fix/upload-restrictions' into deploy/dev 2025-12-10 15:33:18 +08:00
0404007982 Merge branch 'main' into fix/upload-restrictions 2025-12-10 15:32:14 +08:00
71c20ef3c8 merge 2025-12-10 15:30:09 +08:00
yyh 7171eaf0b5 refactor: replace Modal with PlanUpgradeModal in TriggerEventsLimitModal 2025-12-10 15:26:51 +08:00
b16f87c9b6 feat: can add segement check 2025-12-10 15:24:15 +08:00
yyh 5bbc626b5e refactor: simpilify trigger events limit modal css and props 2025-12-10 15:10:14 +08:00
abb2b860f2 chore: remove unused changes 2025-12-10 15:04:19 +08:00
f7a9aadc98 chore: i18n 2025-12-10 14:41:37 +08:00
bacc9a7970 feat: not to next if multi sent in sandbox 2025-12-10 14:32:34 +08:00
be94274fbd chore: enchance popup modal 2025-12-10 14:09:56 +08:00
4bc230eb13 Merge branch 'fix/upload-restrictions' into deploy/dev 2025-12-10 13:27:02 +08:00
yyh 4d596ad231 Merge branch 'fix/29390-async-window-open' into deploy/dev 2025-12-10 13:19:03 +08:00
yyh de9fae63b8 fixes #29390: align async window open behavior with inline logic 2025-12-10 13:17:30 +08:00
yyh 67d8e5a7f0 Merge branch 'fix/29390-async-window-open' into deploy/dev 2025-12-10 12:56:02 +08:00
yyh 997ff45e56 fix: harden async window open placeholder logic(#29390) 2025-12-10 12:55:02 +08:00
88508b8631 feat(PageSelector): add isMultipleChoice prop to enhance selection functionality 2025-12-10 12:52:09 +08:00
yyh eb525b697f Merge branch 'fix/async-window-open' into deploy/dev 2025-12-10 12:26:22 +08:00
yyh 8bf9eee91e fix: maintain loading guard during async billing URL fetch
Add await to openAsync call to prevent multiple concurrent requests
when users rapidly click the billing button. This maintains the
loading state until the billing URL is successfully fetched and opened.

Addresses review feedback about regression in loading behavior
2025-12-10 12:23:56 +08:00
yyh bac0513b8b improve: better popup blocker detection and type safety
- Add immediate popup blocker detection with user-friendly error message
- Improve type safety by removing any types
- Simplify logic flow in useAsyncWindowOpen hook

Addresses code review suggestions
2025-12-10 12:19:25 +08:00
yyh e95b7b57c9 fix: prevent popup blocker from blocking async window.open
Add useAsyncWindowOpen hook to handle async URL fetching with placeholder window pattern. This prevents browser popup blockers (especially Safari) from blocking windows opened after async operations.

- Create reusable useAsyncWindowOpen hook with placeholder window pattern
- Fix billing subscription management popup (cloud-plan-item)
- Fix app card explore popup
- Fix app publisher explore popup

Fixes #29389
Ref: #29390
2025-12-10 12:10:47 +08:00
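The three async-window-open commits above describe a placeholder window pattern: open a window synchronously inside the click handler so the popup blocker still sees a user gesture, then point it at the real URL once the async fetch resolves. A minimal TypeScript sketch of that pattern follows; the hook name, signature, and error message are illustrative assumptions, not Dify's exact implementation.

export function useAsyncWindowOpen() {
  // Returns an opener that preserves the user-gesture context: the window is opened
  // synchronously, then redirected after the async URL fetch completes.
  return async function openAsync(fetchUrl: () => Promise<string>): Promise<void> {
    const placeholder = window.open('about:blank', '_blank')
    if (!placeholder) {
      // Even the synchronous open was blocked: report it immediately.
      throw new Error('Popup was blocked by the browser')
    }
    try {
      placeholder.location.href = await fetchUrl()
    }
    catch (error) {
      // Close the empty placeholder so the user is not left with a blank tab.
      placeholder.close()
      throw error
    }
  }
}

Awaiting openAsync in a caller such as the billing button handler also keeps its loading flag set until the URL has been fetched and opened, which is the regression the "maintain loading guard" commit above addresses.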
d4a90c43a4 Merge branch 'main' into fix/upload-restrictions 2025-12-10 11:59:51 +08:00
0d2dda7e77 feat: add modals 2025-12-10 11:45:18 +08:00
a44a03696e add expire_on_commit 2025-12-10 00:48:29 +08:00
33b1aaf182 add expire_on_commit 2025-12-10 00:11:35 +08:00
930c36e757 fix: llm detail store 2025-12-09 20:56:54 +08:00
91dad285fa feat: add support for batch upload across various components 2025-12-09 18:29:45 +08:00
2d2ce5df85 feat: generation stream output. 2025-12-09 16:22:17 +08:00
yyh 545a34fbaf Refactor datasets service toward TanStack Query (#29008)
Co-authored-by: Wu Tianwei <30284043+WTW0313@users.noreply.github.com>
2025-12-09 13:44:45 +08:00
61a6c6dbcf feat: introduce init container to automatically fix storage permissions (#29297) 2025-12-09 13:40:10 +08:00
2b23c43434 feat: add agent package 2025-12-09 11:36:47 +08:00
0fb339ca4f fix: saved message 2025-11-18 11:38:12 +08:00
c1871e67aa chore: hide disabed action in try app 2025-11-18 11:28:13 +08:00
f711f9a317 fix: webapp url 2025-11-18 11:22:58 +08:00
9ff3310cb6 chore: handle suggestion readonly 2025-11-18 11:07:01 +08:00
b6bdcc7052 fix: not auther tool in readonly mode 2025-11-18 11:02:46 +08:00
67b0771081 fix: try app not ok in chat 2025-11-17 18:21:43 +08:00
9a07488da9 mrege 2025-11-17 15:42:56 +08:00
ef043c6906 fix: no app not show problem 2025-11-06 14:53:11 +08:00
ab814e3eac fix: inputs overwrite by curr item 2025-10-27 14:08:32 +08:00
a0e1eeb3f1 chore: reset form 2025-10-27 13:57:16 +08:00
b1ebeb67a7 feat: support new chat 2025-10-27 13:50:36 +08:00
082179f70f fix: try chat has not set converstaion 2025-10-27 13:38:41 +08:00
8786ebdbca feat: support use tempalte in create app 2025-10-27 10:58:57 +08:00
b49a4eab62 feat: add app list context 2025-10-24 18:33:54 +08:00
0a7b59f500 feat: add tool requirements to flow 2025-10-24 17:49:29 +08:00
c264d9152f chore: add advanced models 2025-10-24 17:42:38 +08:00
3bf9d898c0 feat: basic app requirements 2025-10-24 17:29:42 +08:00
a7f2849e74 fix: try chatbot ui 2025-10-24 16:22:01 +08:00
0957ece92f fix: the try app always use the curent conversation 2025-10-24 15:57:33 +08:00
949bf38d3c fix: chat setup ui 2025-10-24 15:30:53 +08:00
7bafb7f959 feat: chat info 2025-10-24 14:54:06 +08:00
9735f55ca4 feat: try app alert and i18n 2025-10-24 14:00:24 +08:00
4c1f9b949b feat: alert info and lodash to lodash-es 2025-10-24 11:24:19 +08:00
0af0c94dde fix: preview not full 2025-10-24 10:52:05 +08:00
8e4f0640cc fix: variable readonly in basic app problem 2025-10-24 10:41:18 +08:00
1f513e3b43 chore: remove debug code 2025-10-23 18:26:38 +08:00
aa0841e2a8 chore: 18n 2025-10-23 18:05:34 +08:00
b6a1562357 fix: handle create can not show 2025-10-23 17:54:45 +08:00
bee0797401 feat: create from try app 2025-10-23 17:45:54 +08:00
e085f39c13 chore: description and category 2025-10-23 17:29:32 +08:00
344844d3e0 chore: handle data is large 2025-10-23 16:53:10 +08:00
6e9f82491d chore: reuse the app detail and right meta 2025-10-23 15:51:59 +08:00
372b1c3db8 chore: change detail icon 2025-10-23 15:28:12 +08:00
58d305dbed chore: tab header jp 2025-10-23 15:25:25 +08:00
0360a0416b feat: integration preview page 2025-10-23 15:23:50 +08:00
72282b6e8f feat: try app layout 2025-10-23 14:58:17 +08:00
8391884c4e chore: tab and close btn 2025-10-23 14:45:08 +08:00
b018f2b0a0 feat: can show app detail modal 2025-10-23 14:17:43 +08:00
754f1a3cfa mr main and rebuild migration 2025-10-23 11:14:24 +08:00
b22c28b099 mr main and rebuild migration 2025-10-23 11:14:17 +08:00
ab56b4a818 merge main 2025-10-23 11:12:13 +08:00
cd9e28dbf4 mr main and rebuild migration 2025-10-23 11:11:53 +08:00
04f9637b6f mr main and rebuild migration 2025-10-23 11:11:35 +08:00
b8a29bfb35 fix linter 2025-10-23 11:02:49 +08:00
5e2b0d7b39 add interface for review app 2025-10-23 11:02:49 +08:00
b483d5fad5 fix 2025-10-23 11:02:48 +08:00
04196288f8 fix 2025-10-23 11:02:48 +08:00
cc349e70b1 fix: get app model without check tenant in trial 2025-10-23 11:02:48 +08:00
50bdbfae69 fix: get app model without check tenant in trial 2025-10-23 11:02:48 +08:00
2f45673694 fix: linter 2025-10-23 11:02:48 +08:00
b5fb55069b add: return id for banner list 2025-10-23 11:02:48 +08:00
7ba9d30775 When there is no content in a certain language, it needs to fallback to English 2025-10-23 11:02:48 +08:00
e69b588bad add: language for banner 2025-10-23 11:02:48 +08:00
aadac22ce4 add: language for banner 2025-10-23 11:02:48 +08:00
d12015c722 fix 2025-10-23 11:02:47 +08:00
2641326432 fix 2025-10-23 11:02:47 +08:00
20109553b9 Separate object attributes before session 2025-10-23 11:02:47 +08:00
0e1444d17c fix: session of db 2025-10-23 11:02:47 +08:00
65d376bdae fix trial where condition 2025-10-23 11:02:47 +08:00
e3c1310afa [autofix.ci] apply automated fixes 2025-10-23 11:02:47 +08:00
38da19a729 fix: add marshal app model to json 2025-10-23 11:02:47 +08:00
91110499dd fix: add marshal app model to json 2025-10-23 11:02:47 +08:00
4dca9a12a8 fix: add marshal app model to json 2025-10-23 11:02:47 +08:00
3e448f0102 fix: add marshal site model to json 2025-10-23 11:02:47 +08:00
ca75a1c9a3 add: trial api and trial table 2025-10-23 11:02:43 +08:00
61ebc756aa feat: workflow preview 2025-10-16 17:38:13 +08:00
4bea38042a feat: text completion form preview 2025-10-16 14:03:30 +08:00
337abc536b fix: update responsive breakpoint and adjust divider visibility in banner component 2025-10-16 13:47:38 +08:00
cc02b78aca feat: different app preview 2025-10-16 11:27:58 +08:00
18f2d24f8e chore: preview input field readonly 2025-10-16 10:42:47 +08:00
0c7b9a462f chore: tools preview readonly 2025-10-16 10:36:36 +08:00
4dd5580854 chore: preview two cols in panel 2025-10-15 18:16:57 +08:00
440bd825d8 feat: can show tools in preview 2025-10-15 17:35:59 +08:00
d2379c38bd chore: handle history panel and completion review crash 2025-10-15 17:35:59 +08:00
cbc55c577b Merge branch 'feat/support-free-try-app' of github.com:langgenius/dify into feat/support-free-try-app 2025-10-15 17:20:20 +08:00
8e962d15d1 feat: improve explore page banner component with enhanced layout and responsive styles 2025-10-15 17:20:00 +08:00
b07c766551 chroe: fix ts problem 2025-10-15 16:00:14 +08:00
9e3dd69277 fix: upload btn not sync right 2025-10-15 15:51:18 +08:00
db9e5665c2 fix: docuemnt and aduio show condition in preview 2025-10-15 15:35:49 +08:00
cad77ce0bf chore: audio config readonly 2025-10-15 15:29:09 +08:00
6f4518ebf7 chore: document readonly 2025-10-15 15:27:18 +08:00
a8f5748dee chore: vision readonly 2025-10-15 15:21:23 +08:00
738d3001be chore: chat input and feature readonly 2025-10-15 15:21:22 +08:00
df4e32aaa0 Merge branch 'feat/support-free-try-app' of github.com:langgenius/dify into feat/support-free-try-app 2025-10-15 14:36:47 +08:00
a25e37a96d feat: implement responsive design and resize handling for explore page banner 2025-10-15 14:36:27 +08:00
f156b46705 chore: user input readonly 2025-10-15 13:48:39 +08:00
3b64e118d0 chore: readonly ui 2025-10-15 11:39:41 +08:00
566cd20849 feat: dataset config support readonly 2025-10-15 11:37:12 +08:00
db5c51ffc5 add default trial models 2025-10-15 10:55:02 +08:00
df76527f29 feat: add pause functionality to explore page banner for improved user interaction 2025-10-15 10:36:09 +08:00
dac13ae604 fix: remove unnecessary console log and add rounded corners to Icon in QuotaPanel 2025-10-15 10:16:51 +08:00
53a80a5dbe feat: enhance explore page banner functionality with state management and animation improvements 2025-10-15 09:55:14 +08:00
1507792a0c Merge branch 'feat/support-free-try-app' of github.com:langgenius/dify into feat/support-free-try-app 2025-10-14 18:54:11 +08:00
00b9bbff75 feat: enhance explore page banner functionality with state management and animation improvements 2025-10-14 18:53:29 +08:00
e1f8b4b387 feat: support show dataset in knowledge 2025-10-14 18:31:42 +08:00
1539d86f7d chore: instruction and vars to readonly 2025-10-14 17:28:49 +08:00
67bb14d3ee chore: update dependencies and improve explore page banner 2025-10-14 15:51:07 +08:00
5653309080 feat: add carousel & new banner of explore page 2025-10-14 15:41:22 +08:00
0f52b34b61 feat: try apps basic app preveiw 2025-10-14 15:38:22 +08:00
75e35857c1 feat: add carousel & new banner of explore page 2025-10-14 14:17:49 +08:00
4f81be70e3 feat: no apps 2025-10-13 18:31:57 +08:00
1d4d627d05 feat: toogle sidebar 2025-10-13 17:36:24 +08:00
2357234f39 chore: sidebar ui 2025-10-13 17:11:51 +08:00
a3f7d8f996 chore: merge main 2025-10-13 16:38:29 +08:00
56f12e70c1 chore: web apps copywritings 2025-10-13 16:18:57 +08:00
b14afda160 chore: app gallary nav 2025-10-13 15:40:13 +08:00
44b4948972 chore: explore card ui and permission 2025-10-13 15:07:25 +08:00
487eac3b91 chore: add banner permission 2025-10-13 11:27:50 +08:00
84b2913cd9 feat: filter title 2025-10-13 11:12:10 +08:00
176d810c8d chore: update category ui 2025-10-13 10:55:49 +08:00
8b9a9d0574 feat: integrate loading state in QuotaPanel and update ModelProviderPage to handle workspace validation 2025-10-11 17:39:58 +08:00
9e66564526 feat: banner placeholder 2025-10-11 15:07:03 +08:00
781a9a56cd feat: explore title change 2025-10-11 14:58:54 +08:00
a49321775c Merge remote-tracking branch 'origin/main' into feat/model-total-credits 2025-10-11 14:02:39 +08:00
0c83f62848 feat: add new LLM icons and update related components for improved model support 2025-10-11 14:02:24 +08:00
1ddd4bc549 mr credit pool 2025-10-11 12:04:18 +08:00
1d16528dff add credit next_credit_reset_date 2025-10-11 11:45:24 +08:00
c0ed353c10 add credit next_credit_reset_date 2025-10-11 11:45:18 +08:00
93be1219eb chore: try app title 2025-10-11 11:00:26 +08:00
3276d6429d chore: handle completion acion 2025-10-11 10:53:24 +08:00
50072a63ae feat: support try agent app 2025-10-11 10:42:55 +08:00
1ab7e1cba8 fix: try chatflow run url problem 2025-10-11 10:11:14 +08:00
b0aef35c63 feat: try chat flow app 2025-10-10 18:24:56 +08:00
ac351b700c chore: some ui 2025-10-10 16:51:49 +08:00
d1e5d30ea9 fix: text generation api url 2025-10-10 16:39:42 +08:00
c73e84d992 feat: can show text completion run result pages 2025-10-10 16:34:10 +08:00
09998612e7 Merge branch 'feat/credit-pool' of github.com:langgenius/dify into feat/credit-pool 2025-10-09 11:09:32 +08:00
f71ad55d58 fix test case 2025-10-09 11:08:52 +08:00
5b81397054 fix test case 2025-10-09 11:08:24 +08:00
e056e0835a [autofix.ci] apply automated fixes (attempt 2/3) 2025-10-09 02:51:51 +00:00
e1819fb7e5 [autofix.ci] apply automated fixes 2025-10-09 02:50:01 +00:00
0360f0b33b fix: create paid provider auto 2025-09-26 14:32:24 +08:00
560fe8a0f6 fix: format 2025-09-26 13:33:32 +08:00
da27d261b0 fix: add paid quota error for init_anthropic 2025-09-26 13:32:57 +08:00
c3e3a18ab4 add paid credit 2025-09-26 13:32:40 +08:00
ab34cea714 add paid credit 2025-09-26 13:32:28 +08:00
db0780cfa8 add:log 2025-09-26 13:31:54 +08:00
2b51fc23d9 add credit pool sys 2025-09-26 13:29:31 +08:00
5f0bd5119a chore: temp 2025-09-24 13:39:52 +08:00
8353352bda chore: try app can use web app run 2025-09-22 15:17:11 +08:00
73845cbec5 feat: text generation 2025-09-19 16:32:11 +08:00
c2f94e9e8a feat: api call the try app and support disable feedback 2025-09-19 11:32:30 +08:00
e54efda36f feat: try app page 2025-09-18 14:54:15 +08:00
d4bd19f6d8 fix: api login detect problems 2025-09-17 17:15:23 +08:00
4decbbbf18 chore: remove useless api 2025-09-17 14:34:59 +08:00
b15867f92e chore: feedback api 2025-09-17 14:12:34 +08:00
a5e5fbc6e0 chore: some api change to new 2025-09-17 14:10:56 +08:00
1b1471b6d8 fix: stop response api 2025-09-17 14:07:15 +08:00
5280bffde2 feat: change api to new 2025-09-17 11:17:12 +08:00
db0fc94b39 chore: change api to support try apps 2025-09-16 18:21:23 +08:00
424 changed files with 22572 additions and 4015 deletions


@ -5,5 +5,18 @@
"typescript-lsp@claude-plugins-official": true,
"pyright-lsp@claude-plugins-official": true,
"ralph-loop@claude-plugins-official": true
},
"hooks": {
"PreToolUse": [
{
"matcher": "Bash",
"hooks": [
{
"type": "command",
"command": "npx -y block-no-verify@1.1.1"
}
]
}
]
}
}


@ -39,12 +39,6 @@ jobs:
- name: Install dependencies
run: uv sync --project api --dev
- name: Run pyrefly check
run: |
cd api
uv add --dev pyrefly
uv run pyrefly check || true
- name: Run dify config tests
run: uv run --project api dev/pytest/pytest_config_tests.py

.github/workflows/deploy-hitl.yml (new vendored file, 29 lines changed)

@ -0,0 +1,29 @@
name: Deploy HITL
on:
workflow_run:
workflows: ["Build and Push API & Web"]
branches:
- "feat/hitl-frontend"
- "feat/hitl-backend"
types:
- completed
jobs:
deploy:
runs-on: ubuntu-latest
if: |
github.event.workflow_run.conclusion == 'success' &&
(
github.event.workflow_run.head_branch == 'feat/hitl-frontend' ||
github.event.workflow_run.head_branch == 'feat/hitl-backend'
)
steps:
- name: Deploy to server
uses: appleboy/ssh-action@v0.1.8
with:
host: ${{ secrets.HITL_SSH_HOST }}
username: ${{ secrets.SSH_USER }}
key: ${{ secrets.SSH_PRIVATE_KEY }}
script: |
${{ vars.SSH_SCRIPT || secrets.SSH_SCRIPT }}

.nvmrc (1 line changed)

@ -1 +0,0 @@
22.11.0


@ -589,6 +589,7 @@ ENABLE_CLEAN_UNUSED_DATASETS_TASK=false
ENABLE_CREATE_TIDB_SERVERLESS_TASK=false
ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK=false
ENABLE_CLEAN_MESSAGES=false
ENABLE_WORKFLOW_RUN_CLEANUP_TASK=false
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK=false
ENABLE_DATASETS_QUEUE_MONITOR=false
ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK=true


@ -1,4 +1,5 @@
import base64
import datetime
import json
import logging
import secrets
@ -34,7 +35,7 @@ from libs.rsa import generate_key_pair
from models import Tenant
from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation, UploadFile
from models.model import App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation, UploadFile
from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
from models.provider import Provider, ProviderModel
from models.provider_ids import DatasourceProviderID, ToolProviderID
@ -45,6 +46,7 @@ from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpi
from services.plugin.data_migration import PluginDataMigration
from services.plugin.plugin_migration import PluginMigration
from services.plugin.plugin_service import PluginService
from services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs import WorkflowRunCleanup
from tasks.remove_app_and_related_data_task import delete_draft_variables_batch
logger = logging.getLogger(__name__)
@ -62,8 +64,10 @@ def reset_password(email, new_password, password_confirm):
if str(new_password).strip() != str(password_confirm).strip():
click.echo(click.style("Passwords do not match.", fg="red"))
return
normalized_email = email.strip().lower()
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
account = session.query(Account).where(Account.email == email).one_or_none()
account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session)
if not account:
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
@ -84,7 +88,7 @@ def reset_password(email, new_password, password_confirm):
base64_password_hashed = base64.b64encode(password_hashed).decode()
account.password = base64_password_hashed
account.password_salt = base64_salt
AccountService.reset_login_error_rate_limit(email)
AccountService.reset_login_error_rate_limit(normalized_email)
click.echo(click.style("Password reset successfully.", fg="green"))
@ -100,20 +104,22 @@ def reset_email(email, new_email, email_confirm):
if str(new_email).strip() != str(email_confirm).strip():
click.echo(click.style("New emails do not match.", fg="red"))
return
normalized_new_email = new_email.strip().lower()
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
account = session.query(Account).where(Account.email == email).one_or_none()
account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session)
if not account:
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
return
try:
email_validate(new_email)
email_validate(normalized_new_email)
except:
click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
return
account.email = new_email
account.email = normalized_new_email
click.echo(click.style("Email updated successfully.", fg="green"))
@ -658,7 +664,7 @@ def create_tenant(email: str, language: str | None = None, name: str | None = No
return
# Create account
email = email.strip()
email = email.strip().lower()
if "@" not in email:
click.echo(click.style("Invalid email address.", fg="red"))
@ -852,6 +858,61 @@ def clear_free_plan_tenant_expired_logs(days: int, batch: int, tenant_ids: list[
click.echo(click.style("Clear free plan tenant expired logs completed.", fg="green"))
@click.command("clean-workflow-runs", help="Clean expired workflow runs and related data for free tenants.")
@click.option("--days", default=30, show_default=True, help="Delete workflow runs created before N days ago.")
@click.option("--batch-size", default=200, show_default=True, help="Batch size for selecting workflow runs.")
@click.option(
"--start-from",
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
default=None,
help="Optional lower bound (inclusive) for created_at; must be paired with --end-before.",
)
@click.option(
"--end-before",
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
default=None,
help="Optional upper bound (exclusive) for created_at; must be paired with --start-from.",
)
@click.option(
"--dry-run",
is_flag=True,
help="Preview cleanup results without deleting any workflow run data.",
)
def clean_workflow_runs(
days: int,
batch_size: int,
start_from: datetime.datetime | None,
end_before: datetime.datetime | None,
dry_run: bool,
):
"""
Clean workflow runs and related workflow data for free tenants.
"""
if (start_from is None) ^ (end_before is None):
raise click.UsageError("--start-from and --end-before must be provided together.")
start_time = datetime.datetime.now(datetime.UTC)
click.echo(click.style(f"Starting workflow run cleanup at {start_time.isoformat()}.", fg="white"))
WorkflowRunCleanup(
days=days,
batch_size=batch_size,
start_from=start_from,
end_before=end_before,
dry_run=dry_run,
).run()
end_time = datetime.datetime.now(datetime.UTC)
elapsed = end_time - start_time
click.echo(
click.style(
f"Workflow run cleanup completed. start={start_time.isoformat()} "
f"end={end_time.isoformat()} duration={elapsed}",
fg="green",
)
)
@click.option("-f", "--force", is_flag=True, help="Skip user confirmation and force the command to execute.")
@click.command("clear-orphaned-file-records", help="Clear orphaned file records.")
def clear_orphaned_file_records(force: bool):


@ -959,6 +959,16 @@ class MailConfig(BaseSettings):
default=None,
)
ENABLE_TRIAL_APP: bool = Field(
description="Enable trial app",
default=False,
)
ENABLE_EXPLORE_BANNER: bool = Field(
description="Enable explore banner",
default=False,
)
class RagEtlConfig(BaseSettings):
"""
@ -1101,6 +1111,10 @@ class CeleryScheduleTasksConfig(BaseSettings):
description="Enable clean messages task",
default=False,
)
ENABLE_WORKFLOW_RUN_CLEANUP_TASK: bool = Field(
description="Enable scheduled workflow run cleanup task",
default=False,
)
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK: bool = Field(
description="Enable mail clean document notify task",
default=False,


@ -107,10 +107,12 @@ from .datasets.rag_pipeline import (
# Import explore controllers
from .explore import (
banner,
installed_app,
parameter,
recommended_app,
saved_message,
trial,
)
# Import tag controllers
@ -145,6 +147,7 @@ __all__ = [
"apikey",
"app",
"audio",
"banner",
"billing",
"bp",
"completion",
@ -198,6 +201,7 @@ __all__ = [
"statistic",
"tags",
"tool_providers",
"trial",
"trigger_providers",
"version",
"website",


@ -15,7 +15,7 @@ from controllers.console.wraps import only_edition_cloud
from core.db.session_factory import session_factory
from extensions.ext_database import db
from libs.token import extract_access_token
from models.model import App, InstalledApp, RecommendedApp
from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp
P = ParamSpec("P")
R = TypeVar("R")
@ -32,6 +32,8 @@ class InsertExploreAppPayload(BaseModel):
language: str = Field(...)
category: str = Field(...)
position: int = Field(...)
can_trial: bool = Field(default=False)
trial_limit: int = Field(default=0)
@field_validator("language")
@classmethod
@ -39,11 +41,33 @@ class InsertExploreAppPayload(BaseModel):
return supported_language(value)
class InsertExploreBannerPayload(BaseModel):
category: str = Field(...)
title: str = Field(...)
description: str = Field(...)
img_src: str = Field(..., alias="img-src")
language: str = Field(default="en-US")
link: str = Field(...)
sort: int = Field(...)
@field_validator("language")
@classmethod
def validate_language(cls, value: str) -> str:
return supported_language(value)
model_config = {"populate_by_name": True}
console_ns.schema_model(
InsertExploreAppPayload.__name__,
InsertExploreAppPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
InsertExploreBannerPayload.__name__,
InsertExploreBannerPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
def admin_required(view: Callable[P, R]):
@wraps(view)
@ -109,6 +133,20 @@ class InsertExploreAppListApi(Resource):
)
db.session.add(recommended_app)
if payload.can_trial:
trial_app = db.session.execute(
select(TrialApp).where(TrialApp.app_id == payload.app_id)
).scalar_one_or_none()
if not trial_app:
db.session.add(
TrialApp(
app_id=payload.app_id,
tenant_id=app.tenant_id,
trial_limit=payload.trial_limit,
)
)
else:
trial_app.trial_limit = payload.trial_limit
app.is_public = True
db.session.commit()
@ -123,6 +161,20 @@ class InsertExploreAppListApi(Resource):
recommended_app.category = payload.category
recommended_app.position = payload.position
if payload.can_trial:
trial_app = db.session.execute(
select(TrialApp).where(TrialApp.app_id == payload.app_id)
).scalar_one_or_none()
if not trial_app:
db.session.add(
TrialApp(
app_id=payload.app_id,
tenant_id=app.tenant_id,
trial_limit=payload.trial_limit,
)
)
else:
trial_app.trial_limit = payload.trial_limit
app.is_public = True
db.session.commit()
@ -168,7 +220,62 @@ class InsertExploreAppApi(Resource):
for installed_app in installed_apps:
session.delete(installed_app)
trial_app = session.execute(
select(TrialApp).where(TrialApp.app_id == recommended_app.app_id)
).scalar_one_or_none()
if trial_app:
session.delete(trial_app)
db.session.delete(recommended_app)
db.session.commit()
return {"result": "success"}, 204
@console_ns.route("/admin/insert-explore-banner")
class InsertExploreBannerApi(Resource):
@console_ns.doc("insert_explore_banner")
@console_ns.doc(description="Insert an explore banner")
@console_ns.expect(console_ns.models[InsertExploreBannerPayload.__name__])
@console_ns.response(201, "Banner inserted successfully")
@only_edition_cloud
@admin_required
def post(self):
payload = InsertExploreBannerPayload.model_validate(console_ns.payload)
content = {
"category": payload.category,
"title": payload.title,
"description": payload.description,
"img-src": payload.img_src,
}
banner = ExporleBanner(
content=content,
link=payload.link,
sort=payload.sort,
language=payload.language,
)
db.session.add(banner)
db.session.commit()
return {"result": "success"}, 201
@console_ns.route("/admin/insert-explore-banner/<uuid:banner_id>")
class DeleteExploreBannerApi(Resource):
@console_ns.doc("delete_explore_banner")
@console_ns.doc(description="Delete an explore banner")
@console_ns.doc(params={"banner_id": "Banner ID to delete"})
@console_ns.response(204, "Banner deleted successfully")
@only_edition_cloud
@admin_required
def delete(self, banner_id):
banner = db.session.execute(select(ExporleBanner).where(ExporleBanner.id == banner_id)).scalar_one_or_none()
if not banner:
raise NotFound(f"Banner '{banner_id}' is not found")
db.session.delete(banner)
db.session.commit()
return {"result": "success"}, 204


@ -272,6 +272,7 @@ class AnnotationExportApi(Resource):
@account_initialization_required
@edit_permission_required
def get(self, app_id):
app_id = str(app_id)
annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id)
response_data = {"data": marshal(annotation_list, annotation_fields)}
@ -359,7 +360,6 @@ class AnnotationBatchImportApi(Resource):
file.seek(0, 2) # Seek to end of file
file_size = file.tell()
file.seek(0) # Reset to beginning
max_size_bytes = dify_config.ANNOTATION_IMPORT_FILE_SIZE_LIMIT * 1024 * 1024
if file_size > max_size_bytes:
abort(


@ -115,3 +115,9 @@ class InvokeRateLimitError(BaseHTTPException):
error_code = "rate_limit_error"
description = "Rate Limit Error"
code = 429
class NeedAddIdsError(BaseHTTPException):
error_code = "need_add_ids"
description = "Need to add ids."
code = 400


@ -202,6 +202,7 @@ message_detail_model = console_ns.model(
"status": fields.String,
"error": fields.String,
"parent_message_id": fields.String,
"generation_detail": fields.Raw,
},
)


@ -23,6 +23,11 @@ def _load_app_model(app_id: str) -> App | None:
return app_model
def _load_app_model_with_trial(app_id: str) -> App | None:
app_model = db.session.query(App).where(App.id == app_id, App.status == "normal").first()
return app_model
def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None):
def decorator(view_func: Callable[P1, R1]):
@wraps(view_func)
@ -62,3 +67,44 @@ def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, li
return decorator
else:
return decorator(view)
def get_app_model_with_trial(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None):
def decorator(view_func: Callable[P, R]):
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs):
if not kwargs.get("app_id"):
raise ValueError("missing app_id in path parameters")
app_id = kwargs.get("app_id")
app_id = str(app_id)
del kwargs["app_id"]
app_model = _load_app_model_with_trial(app_id)
if not app_model:
raise AppNotFoundError()
app_mode = AppMode.value_of(app_model.mode)
if mode is not None:
if isinstance(mode, list):
modes = mode
else:
modes = [mode]
if app_mode not in modes:
mode_values = {m.value for m in modes}
raise AppNotFoundError(f"App mode is not in the supported list: {mode_values}")
kwargs["app_model"] = app_model
return view_func(*args, **kwargs)
return decorated_view
if view is None:
return decorator
else:
return decorator(view)


@ -63,10 +63,9 @@ class ActivateCheckApi(Resource):
args = ActivateCheckQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
workspaceId = args.workspace_id
reg_email = args.email
token = args.token
invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token)
invitation = RegisterService.get_invitation_with_case_fallback(workspaceId, args.email, token)
if invitation:
data = invitation.get("data", {})
tenant = invitation.get("tenant", None)
@ -100,11 +99,12 @@ class ActivateApi(Resource):
def post(self):
args = ActivatePayload.model_validate(console_ns.payload)
invitation = RegisterService.get_invitation_if_token_valid(args.workspace_id, args.email, args.token)
normalized_request_email = args.email.lower() if args.email else None
invitation = RegisterService.get_invitation_with_case_fallback(args.workspace_id, args.email, args.token)
if invitation is None:
raise AlreadyActivateError()
RegisterService.revoke_token(args.workspace_id, args.email, args.token)
RegisterService.revoke_token(args.workspace_id, normalized_request_email, args.token)
account = invitation["account"]
account.name = args.name


@ -1,7 +1,6 @@
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import Session
from configs import dify_config
@ -62,6 +61,7 @@ class EmailRegisterSendEmailApi(Resource):
@email_register_enabled
def post(self):
args = EmailRegisterSendPayload.model_validate(console_ns.payload)
normalized_email = args.email.lower()
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
@ -70,13 +70,12 @@ class EmailRegisterSendEmailApi(Resource):
if args.language in languages:
language = args.language
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email):
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
raise AccountInFreezeError()
with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
token = None
token = AccountService.send_email_register_email(email=args.email, account=account, language=language)
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
token = AccountService.send_email_register_email(email=normalized_email, account=account, language=language)
return {"result": "success", "data": token}
@ -88,9 +87,9 @@ class EmailRegisterCheckApi(Resource):
def post(self):
args = EmailRegisterValidityPayload.model_validate(console_ns.payload)
user_email = args.email
user_email = args.email.lower()
is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(args.email)
is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(user_email)
if is_email_register_error_rate_limit:
raise EmailRegisterLimitError()
@ -98,11 +97,14 @@ class EmailRegisterCheckApi(Resource):
if token_data is None:
raise InvalidTokenError()
if user_email != token_data.get("email"):
token_email = token_data.get("email")
normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
if user_email != normalized_token_email:
raise InvalidEmailError()
if args.code != token_data.get("code"):
AccountService.add_email_register_error_rate_limit(args.email)
AccountService.add_email_register_error_rate_limit(user_email)
raise EmailCodeError()
# Verified, revoke the first token
@ -113,8 +115,8 @@ class EmailRegisterCheckApi(Resource):
user_email, code=args.code, additional_data={"phase": "register"}
)
AccountService.reset_email_register_error_rate_limit(args.email)
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
AccountService.reset_email_register_error_rate_limit(user_email)
return {"is_valid": True, "email": normalized_token_email, "token": new_token}
@console_ns.route("/email-register")
@ -141,22 +143,23 @@ class EmailRegisterResetApi(Resource):
AccountService.revoke_email_register_token(args.token)
email = register_data.get("email", "")
normalized_email = email.lower()
with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
if account:
raise EmailAlreadyInUseError()
else:
account = self._create_new_account(email, args.password_confirm)
account = self._create_new_account(normalized_email, args.password_confirm)
if not account:
raise AccountNotFoundError()
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(email)
AccountService.reset_login_error_rate_limit(normalized_email)
return {"result": "success", "data": token_pair.model_dump()}
def _create_new_account(self, email, password) -> Account | None:
def _create_new_account(self, email: str, password: str) -> Account | None:
# Create new account if allowed
account = None
try:

View File

@ -4,7 +4,6 @@ import secrets
from flask import request
from flask_restx import Resource, fields
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import Session
from controllers.console import console_ns
@ -21,7 +20,6 @@ from events.tenant_event import tenant_was_created
from extensions.ext_database import db
from libs.helper import EmailStr, extract_remote_ip
from libs.password import hash_password, valid_password
from models import Account
from services.account_service import AccountService, TenantService
from services.feature_service import FeatureService
@ -76,6 +74,7 @@ class ForgotPasswordSendEmailApi(Resource):
@email_password_login_enabled
def post(self):
args = ForgotPasswordSendPayload.model_validate(console_ns.payload)
normalized_email = args.email.lower()
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
@ -87,11 +86,11 @@ class ForgotPasswordSendEmailApi(Resource):
language = "en-US"
with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
token = AccountService.send_reset_password_email(
account=account,
email=args.email,
email=normalized_email,
language=language,
is_allow_register=FeatureService.get_system_features().is_allow_register,
)
@ -122,9 +121,9 @@ class ForgotPasswordCheckApi(Resource):
def post(self):
args = ForgotPasswordCheckPayload.model_validate(console_ns.payload)
user_email = args.email
user_email = args.email.lower()
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args.email)
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(user_email)
if is_forgot_password_error_rate_limit:
raise EmailPasswordResetLimitError()
@ -132,11 +131,16 @@ class ForgotPasswordCheckApi(Resource):
if token_data is None:
raise InvalidTokenError()
if user_email != token_data.get("email"):
token_email = token_data.get("email")
if not isinstance(token_email, str):
raise InvalidEmailError()
normalized_token_email = token_email.lower()
if user_email != normalized_token_email:
raise InvalidEmailError()
if args.code != token_data.get("code"):
AccountService.add_forgot_password_error_rate_limit(args.email)
AccountService.add_forgot_password_error_rate_limit(user_email)
raise EmailCodeError()
# Verified, revoke the first token
@ -144,11 +148,11 @@ class ForgotPasswordCheckApi(Resource):
# Refresh token data by generating a new token
_, new_token = AccountService.generate_reset_password_token(
user_email, code=args.code, additional_data={"phase": "reset"}
token_email, code=args.code, additional_data={"phase": "reset"}
)
AccountService.reset_forgot_password_error_rate_limit(args.email)
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
AccountService.reset_forgot_password_error_rate_limit(user_email)
return {"is_valid": True, "email": normalized_token_email, "token": new_token}
@console_ns.route("/forgot-password/resets")
@ -187,9 +191,8 @@ class ForgotPasswordResetApi(Resource):
password_hashed = hash_password(args.new_password, salt)
email = reset_data.get("email", "")
with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
if account:
self._update_existing_account(account, password_hashed, salt, session)

View File

@ -90,32 +90,38 @@ class LoginApi(Resource):
def post(self):
"""Authenticate user and login."""
args = LoginPayload.model_validate(console_ns.payload)
request_email = args.email
normalized_email = request_email.lower()
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email):
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
raise AccountInFreezeError()
is_login_error_rate_limit = AccountService.is_login_error_rate_limit(args.email)
is_login_error_rate_limit = AccountService.is_login_error_rate_limit(normalized_email)
if is_login_error_rate_limit:
raise EmailPasswordLoginLimitError()
invite_token = args.invite_token
invitation_data: dict[str, Any] | None = None
if args.invite_token:
invitation_data = RegisterService.get_invitation_if_token_valid(None, args.email, args.invite_token)
if invite_token:
invitation_data = RegisterService.get_invitation_with_case_fallback(None, request_email, invite_token)
if invitation_data is None:
invite_token = None
try:
if invitation_data:
data = invitation_data.get("data", {})
invitee_email = data.get("email") if data else None
if invitee_email != args.email:
invitee_email_normalized = invitee_email.lower() if isinstance(invitee_email, str) else invitee_email
if invitee_email_normalized != normalized_email:
raise InvalidEmailError()
account = AccountService.authenticate(args.email, args.password, args.invite_token)
else:
account = AccountService.authenticate(args.email, args.password)
account = _authenticate_account_with_case_fallback(
request_email, normalized_email, args.password, invite_token
)
except services.errors.account.AccountLoginError:
raise AccountBannedError()
except services.errors.account.AccountPasswordError:
AccountService.add_login_error_rate_limit(args.email)
raise AuthenticationFailedError()
except services.errors.account.AccountPasswordError as exc:
AccountService.add_login_error_rate_limit(normalized_email)
raise AuthenticationFailedError() from exc
# SELF_HOSTED only have one workspace
tenants = TenantService.get_join_tenants(account)
if len(tenants) == 0:
@ -130,7 +136,7 @@ class LoginApi(Resource):
}
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(args.email)
AccountService.reset_login_error_rate_limit(normalized_email)
# Create response with cookies instead of returning tokens in body
response = make_response({"result": "success"})
@ -170,18 +176,19 @@ class ResetPasswordSendEmailApi(Resource):
@console_ns.expect(console_ns.models[EmailPayload.__name__])
def post(self):
args = EmailPayload.model_validate(console_ns.payload)
normalized_email = args.email.lower()
if args.language is not None and args.language == "zh-Hans":
language = "zh-Hans"
else:
language = "en-US"
try:
account = AccountService.get_user_through_email(args.email)
account = _get_account_with_case_fallback(args.email)
except AccountRegisterError:
raise AccountInFreezeError()
token = AccountService.send_reset_password_email(
email=args.email,
email=normalized_email,
account=account,
language=language,
is_allow_register=FeatureService.get_system_features().is_allow_register,
@ -196,6 +203,7 @@ class EmailCodeLoginSendEmailApi(Resource):
@console_ns.expect(console_ns.models[EmailPayload.__name__])
def post(self):
args = EmailPayload.model_validate(console_ns.payload)
normalized_email = args.email.lower()
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
@ -206,13 +214,13 @@ class EmailCodeLoginSendEmailApi(Resource):
else:
language = "en-US"
try:
account = AccountService.get_user_through_email(args.email)
account = _get_account_with_case_fallback(args.email)
except AccountRegisterError:
raise AccountInFreezeError()
if account is None:
if FeatureService.get_system_features().is_allow_register:
token = AccountService.send_email_code_login_email(email=args.email, language=language)
token = AccountService.send_email_code_login_email(email=normalized_email, language=language)
else:
raise AccountNotFound()
else:
@ -229,14 +237,17 @@ class EmailCodeLoginApi(Resource):
def post(self):
args = EmailCodeLoginPayload.model_validate(console_ns.payload)
user_email = args.email
original_email = args.email
user_email = original_email.lower()
language = args.language
token_data = AccountService.get_email_code_login_data(args.token)
if token_data is None:
raise InvalidTokenError()
if token_data["email"] != args.email:
token_email = token_data.get("email")
normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
if normalized_token_email != user_email:
raise InvalidEmailError()
if token_data["code"] != args.code:
@ -244,7 +255,7 @@ class EmailCodeLoginApi(Resource):
AccountService.revoke_email_code_login_token(args.token)
try:
account = AccountService.get_user_through_email(user_email)
account = _get_account_with_case_fallback(original_email)
except AccountRegisterError:
raise AccountInFreezeError()
if account:
@ -275,7 +286,7 @@ class EmailCodeLoginApi(Resource):
except WorkspacesLimitExceededError:
raise WorkspacesLimitExceeded()
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(args.email)
AccountService.reset_login_error_rate_limit(user_email)
# Create response with cookies instead of returning tokens in body
response = make_response({"result": "success"})
@ -309,3 +320,22 @@ class RefreshTokenApi(Resource):
return response
except Exception as e:
return {"result": "fail", "message": str(e)}, 401
def _get_account_with_case_fallback(email: str):
account = AccountService.get_user_through_email(email)
if account or email == email.lower():
return account
return AccountService.get_user_through_email(email.lower())
def _authenticate_account_with_case_fallback(
original_email: str, normalized_email: str, password: str, invite_token: str | None
):
try:
return AccountService.authenticate(original_email, password, invite_token)
except services.errors.account.AccountPasswordError:
if original_email == normalized_email:
raise
return AccountService.authenticate(normalized_email, password, invite_token)
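
A minimal standalone sketch of the same case-fallback pattern used by the two helpers above, with a generic lookup callable standing in for AccountService (names here are illustrative):

from collections.abc import Callable

def find_with_case_fallback(email: str, lookup: Callable[[str], object | None]) -> object | None:
    # try the exact spelling first; retry with the lowercased address only when the
    # first lookup misses and the submitted spelling is not already lowercase
    record = lookup(email)
    if record is not None or email == email.lower():
        return record
    return lookup(email.lower())

# usage: find_with_case_fallback("User@Example.com", accounts_by_email.get)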

View File

@ -3,7 +3,6 @@ import logging
import httpx
from flask import current_app, redirect, request
from flask_restx import Resource
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import Unauthorized
@ -118,7 +117,10 @@ class OAuthCallback(Resource):
invitation = RegisterService.get_invitation_by_token(token=invite_token)
if invitation:
invitation_email = invitation.get("email", None)
if invitation_email != user_info.email:
invitation_email_normalized = (
invitation_email.lower() if isinstance(invitation_email, str) else invitation_email
)
if invitation_email_normalized != user_info.email.lower():
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Invalid invitation token.")
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin/invite-settings?invite_token={invite_token}")
@ -159,10 +161,7 @@ class OAuthCallback(Resource):
ip_address=extract_remote_ip(request),
)
base_url = dify_config.CONSOLE_WEB_URL
query_char = "&" if "?" in base_url else "?"
target_url = f"{base_url}{query_char}oauth_new_user={str(oauth_new_user).lower()}"
response = redirect(target_url)
response = redirect(f"{dify_config.CONSOLE_WEB_URL}?oauth_new_user={str(oauth_new_user).lower()}")
set_access_token_to_cookie(request, response, token_pair.access_token)
set_refresh_token_to_cookie(request, response, token_pair.refresh_token)
@ -175,7 +174,7 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) ->
if not account:
with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=user_info.email)).scalar_one_or_none()
account = AccountService.get_account_by_email_with_case_fallback(user_info.email, session=session)
return account
@ -197,9 +196,10 @@ def _generate_account(provider: str, user_info: OAuthUserInfo) -> tuple[Account,
tenant_was_created.send(new_tenant)
if not account:
normalized_email = user_info.email.lower()
oauth_new_user = True
if not FeatureService.get_system_features().is_allow_register:
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(user_info.email):
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
raise AccountRegisterError(
description=(
"This email account has been deleted within the past "
@ -210,7 +210,11 @@ def _generate_account(provider: str, user_info: OAuthUserInfo) -> tuple[Account,
raise AccountRegisterError(description=("Invalid email or password"))
account_name = user_info.name or "Dify"
account = RegisterService.register(
email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider
email=normalized_email,
name=account_name,
password=None,
open_id=user_info.id,
provider=provider,
)
# Set interface language

View File

@ -146,6 +146,7 @@ class DatasetUpdatePayload(BaseModel):
embedding_model: str | None = None
embedding_model_provider: str | None = None
retrieval_model: dict[str, Any] | None = None
summary_index_setting: dict[str, Any] | None = None
partial_member_list: list[dict[str, str]] | None = None
external_retrieval_model: dict[str, Any] | None = None
external_knowledge_id: str | None = None

View File

@ -39,9 +39,10 @@ from fields.document_fields import (
from libs.datetime_utils import naive_utc_now
from libs.login import current_account_with_tenant, login_required
from models import DatasetProcessRule, Document, DocumentSegment, UploadFile
from models.dataset import DocumentPipelineExecutionLog
from models.dataset import DocumentPipelineExecutionLog, DocumentSegmentSummary
from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel
from tasks.generate_summary_index_task import generate_summary_index_task
from ..app.error import (
ProviderModelCurrentlyNotSupportError,
@ -104,6 +105,10 @@ class DocumentRenamePayload(BaseModel):
name: str
class GenerateSummaryPayload(BaseModel):
document_list: list[str]
register_schema_models(
console_ns,
KnowledgeConfig,
@ -111,6 +116,7 @@ register_schema_models(
RetrievalModel,
DocumentRetryPayload,
DocumentRenamePayload,
GenerateSummaryPayload,
)
@ -295,6 +301,97 @@ class DatasetDocumentListApi(Resource):
paginated_documents = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
documents = paginated_documents.items
# Check if dataset has summary index enabled
has_summary_index = (
dataset.summary_index_setting
and dataset.summary_index_setting.get("enable") is True
)
# Filter documents that need summary calculation
documents_need_summary = [doc for doc in documents if doc.need_summary is True]
document_ids_need_summary = [str(doc.id) for doc in documents_need_summary]
# Calculate summary_index_status for documents that need summary (only if dataset summary index is enabled)
summary_status_map = {}
if has_summary_index and document_ids_need_summary:
# Get all segments for these documents (excluding qa_model and re_segment)
segments = (
db.session.query(DocumentSegment.id, DocumentSegment.document_id)
.where(
DocumentSegment.document_id.in_(document_ids_need_summary),
DocumentSegment.status != "re_segment",
DocumentSegment.tenant_id == current_tenant_id,
)
.all()
)
# Group segments by document_id
document_segments_map = {}
for segment in segments:
doc_id = str(segment.document_id)
if doc_id not in document_segments_map:
document_segments_map[doc_id] = []
document_segments_map[doc_id].append(segment.id)
# Get all summary records for these segments
all_segment_ids = [seg.id for seg in segments]
summaries = {}
if all_segment_ids:
summary_records = (
db.session.query(DocumentSegmentSummary)
.where(
DocumentSegmentSummary.chunk_id.in_(all_segment_ids),
DocumentSegmentSummary.dataset_id == dataset_id,
DocumentSegmentSummary.enabled == True, # Only count enabled summaries
)
.all()
)
summaries = {summary.chunk_id: summary.status for summary in summary_records}
# Calculate summary_index_status for each document
for doc_id in document_ids_need_summary:
segment_ids = document_segments_map.get(doc_id, [])
if not segment_ids:
# No segments, status is "GENERATING" (waiting to generate)
summary_status_map[doc_id] = "GENERATING"
continue
# Count summary statuses for this document's segments
status_counts = {"completed": 0, "generating": 0, "error": 0, "not_started": 0}
for segment_id in segment_ids:
status = summaries.get(segment_id, "not_started")
if status in status_counts:
status_counts[status] += 1
else:
status_counts["not_started"] += 1
total_segments = len(segment_ids)
completed_count = status_counts["completed"]
generating_count = status_counts["generating"]
error_count = status_counts["error"]
# Determine overall status (only three states: GENERATING, COMPLETED, ERROR)
if completed_count == total_segments:
summary_status_map[doc_id] = "COMPLETED"
elif error_count > 0:
# Has errors (even if some are completed or generating)
summary_status_map[doc_id] = "ERROR"
elif generating_count > 0 or status_counts["not_started"] > 0:
# Still generating or not started
summary_status_map[doc_id] = "GENERATING"
else:
# Default to generating
summary_status_map[doc_id] = "GENERATING"
# Add summary_index_status to each document
for document in documents:
if has_summary_index and document.need_summary is True:
document.summary_index_status = summary_status_map.get(str(document.id), "GENERATING")
else:
# Return null if summary index is not enabled or document doesn't need summary
document.summary_index_status = None
if fetch:
for document in documents:
completed_segments = (
@ -393,6 +490,7 @@ class DatasetDocumentListApi(Resource):
return {"result": "success"}, 204
@console_ns.route("/datasets/init")
class DatasetInitApi(Resource):
@console_ns.doc("init_dataset")
@ -780,6 +878,7 @@ class DocumentApi(DocumentResource):
"display_status": document.display_status,
"doc_form": document.doc_form,
"doc_language": document.doc_language,
"need_summary": document.need_summary if document.need_summary is not None else False,
}
else:
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
@ -815,6 +914,7 @@ class DocumentApi(DocumentResource):
"display_status": document.display_status,
"doc_form": document.doc_form,
"doc_language": document.doc_language,
"need_summary": document.need_summary if document.need_summary is not None else False,
}
return response, 200
@ -1182,3 +1282,211 @@ class DocumentPipelineExecutionLogApi(DocumentResource):
"input_data": log.input_data,
"datasource_node_id": log.datasource_node_id,
}, 200
@console_ns.route("/datasets/<uuid:dataset_id>/documents/generate-summary")
class DocumentGenerateSummaryApi(Resource):
@console_ns.doc("generate_summary_for_documents")
@console_ns.doc(description="Generate summary index for documents")
@console_ns.doc(params={"dataset_id": "Dataset ID"})
@console_ns.expect(console_ns.models[GenerateSummaryPayload.__name__])
@console_ns.response(200, "Summary generation started successfully")
@console_ns.response(400, "Invalid request or dataset configuration")
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Dataset not found")
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id):
"""
Generate summary index for specified documents.
This endpoint checks if the dataset configuration supports summary generation
(indexing_technique must be 'high_quality' and summary_index_setting.enable must be true),
then asynchronously generates summary indexes for the provided documents.
"""
current_user, _ = current_account_with_tenant()
dataset_id = str(dataset_id)
# Get dataset
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# Check permissions
if not current_user.is_dataset_editor:
raise Forbidden()
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# Validate request payload
payload = GenerateSummaryPayload.model_validate(console_ns.payload or {})
document_list = payload.document_list
if not document_list:
raise ValueError("document_list cannot be empty.")
# Check if dataset configuration supports summary generation
if dataset.indexing_technique != "high_quality":
raise ValueError(
f"Summary generation is only available for 'high_quality' indexing technique. "
f"Current indexing technique: {dataset.indexing_technique}"
)
summary_index_setting = dataset.summary_index_setting
if not summary_index_setting or not summary_index_setting.get("enable"):
raise ValueError(
"Summary index is not enabled for this dataset. "
"Please enable it in the dataset settings."
)
# Verify all documents exist and belong to the dataset
documents = (
db.session.query(Document)
.filter(
Document.id.in_(document_list),
Document.dataset_id == dataset_id,
)
.all()
)
if len(documents) != len(document_list):
found_ids = {doc.id for doc in documents}
missing_ids = set(document_list) - found_ids
raise NotFound(f"Some documents not found: {list(missing_ids)}")
# Dispatch async tasks for each document
for document in documents:
# Skip qa_model documents as they don't generate summaries
if document.doc_form == "qa_model":
logger.info(
f"Skipping summary generation for qa_model document {document.id}"
)
continue
# Dispatch async task
generate_summary_index_task.delay(dataset_id, document.id)
logger.info(
f"Dispatched summary generation task for document {document.id} in dataset {dataset_id}"
)
return {"result": "success"}, 200
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/summary-status")
class DocumentSummaryStatusApi(DocumentResource):
@console_ns.doc("get_document_summary_status")
@console_ns.doc(description="Get summary index generation status for a document")
@console_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
@console_ns.response(200, "Summary status retrieved successfully")
@console_ns.response(404, "Document not found")
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id, document_id):
"""
Get summary index generation status for a document.
Returns:
- total_segments: Total number of segments in the document
- summary_status: Dictionary with status counts
- completed: Number of summaries completed
- generating: Number of summaries being generated
- error: Number of summaries with errors
- not_started: Number of segments without summary records
- summaries: List of summary records with status and content preview
"""
current_user, _ = current_account_with_tenant()
dataset_id = str(dataset_id)
document_id = str(document_id)
# Get document
document = self.get_document(dataset_id, document_id)
# Get dataset
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# Check permissions
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# Get all segments for this document
segments = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.document_id == document_id,
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
)
.all()
)
total_segments = len(segments)
# Get all summary records for these segments
segment_ids = [segment.id for segment in segments]
summaries = []
if segment_ids:
summaries = (
db.session.query(DocumentSegmentSummary)
.filter(
DocumentSegmentSummary.document_id == document_id,
DocumentSegmentSummary.dataset_id == dataset_id,
DocumentSegmentSummary.chunk_id.in_(segment_ids),
DocumentSegmentSummary.enabled == True, # Only return enabled summaries
)
.all()
)
# Create a mapping of chunk_id to summary
summary_map = {summary.chunk_id: summary for summary in summaries}
# Count statuses
status_counts = {
"completed": 0,
"generating": 0,
"error": 0,
"not_started": 0,
}
summary_list = []
for segment in segments:
summary = summary_map.get(segment.id)
if summary:
status = summary.status
status_counts[status] = status_counts.get(status, 0) + 1
summary_list.append({
"segment_id": segment.id,
"segment_position": segment.position,
"status": summary.status,
"summary_preview": summary.summary_content[:100] + "..." if summary.summary_content and len(summary.summary_content) > 100 else summary.summary_content,
"error": summary.error,
"created_at": int(summary.created_at.timestamp()) if summary.created_at else None,
"updated_at": int(summary.updated_at.timestamp()) if summary.updated_at else None,
})
else:
status_counts["not_started"] += 1
summary_list.append({
"segment_id": segment.id,
"segment_position": segment.position,
"status": "not_started",
"summary_preview": None,
"error": None,
"created_at": None,
"updated_at": None,
})
return {
"total_segments": total_segments,
"summary_status": status_counts,
"summaries": summary_list,
}, 200
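
For reference, a compact pure-function restatement of the document-level status reduction used by the document list API earlier in this file (statuses come from DocumentSegmentSummary.status and missing records count as "not_started"; this is an illustrative sketch, not the shipped helper):

def reduce_summary_index_status(statuses: list[str]) -> str:
    # no segments yet: the document is still waiting for summaries to be generated
    if not statuses:
        return "GENERATING"
    if all(status == "completed" for status in statuses):
        return "COMPLETED"
    # any error wins, even if other segments are completed or still generating
    if any(status == "error" for status in statuses):
        return "ERROR"
    # remaining mixes of generating / not_started stay in the generating state
    return "GENERATING"

# e.g. reduce_summary_index_status(["completed", "generating"]) == "GENERATING"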

View File

@ -32,7 +32,7 @@ from extensions.ext_redis import redis_client
from fields.segment_fields import child_chunk_fields, segment_fields
from libs.helper import escape_like_pattern
from libs.login import current_account_with_tenant, login_required
from models.dataset import ChildChunk, DocumentSegment
from models.dataset import ChildChunk, DocumentSegment, DocumentSegmentSummary
from models.model import UploadFile
from services.dataset_service import DatasetService, DocumentService, SegmentService
from services.entities.knowledge_entities.knowledge_entities import ChildChunkUpdateArgs, SegmentUpdateArgs
@ -41,6 +41,23 @@ from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingS
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
def _get_segment_with_summary(segment, dataset_id):
"""Helper function to marshal segment and add summary information."""
segment_dict = marshal(segment, segment_fields)
# Query summary for this segment (only enabled summaries)
summary = (
db.session.query(DocumentSegmentSummary)
.where(
DocumentSegmentSummary.chunk_id == segment.id,
DocumentSegmentSummary.dataset_id == dataset_id,
DocumentSegmentSummary.enabled == True, # Only return enabled summaries
)
.first()
)
segment_dict["summary"] = summary.summary_content if summary else None
return segment_dict
class SegmentListQuery(BaseModel):
limit: int = Field(default=20, ge=1, le=100)
status: list[str] = Field(default_factory=list)
@ -63,6 +80,7 @@ class SegmentUpdatePayload(BaseModel):
keywords: list[str] | None = None
regenerate_child_chunks: bool = False
attachment_ids: list[str] | None = None
summary: str | None = None # Summary content for summary index
class BatchImportPayload(BaseModel):
@ -180,8 +198,34 @@ class DatasetDocumentSegmentListApi(Resource):
segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
# Query summaries for all segments in this page (batch query for efficiency)
segment_ids = [segment.id for segment in segments.items]
summaries = {}
if segment_ids:
summary_records = (
db.session.query(DocumentSegmentSummary)
.where(
DocumentSegmentSummary.chunk_id.in_(segment_ids),
DocumentSegmentSummary.dataset_id == dataset_id,
)
.all()
)
# Only include enabled summaries
summaries = {
summary.chunk_id: summary.summary_content
for summary in summary_records
if summary.enabled is True
}
# Add summary to each segment
segments_with_summary = []
for segment in segments.items:
segment_dict = marshal(segment, segment_fields)
segment_dict["summary"] = summaries.get(segment.id)
segments_with_summary.append(segment_dict)
response = {
"data": marshal(segments.items, segment_fields),
"data": segments_with_summary,
"limit": limit,
"total": segments.total,
"total_pages": segments.pages,
@ -327,7 +371,7 @@ class DatasetDocumentSegmentAddApi(Resource):
payload_dict = payload.model_dump(exclude_none=True)
SegmentService.segment_create_args_validate(payload_dict, document)
segment = SegmentService.create_segment(payload_dict, document, dataset)
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
return {"data": _get_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>")
@ -389,10 +433,12 @@ class DatasetDocumentSegmentUpdateApi(Resource):
payload = SegmentUpdatePayload.model_validate(console_ns.payload or {})
payload_dict = payload.model_dump(exclude_none=True)
SegmentService.segment_create_args_validate(payload_dict, document)
# Update segment (summary update with change detection is handled in SegmentService.update_segment)
segment = SegmentService.update_segment(
SegmentUpdateArgs.model_validate(payload.model_dump(exclude_none=True)), segment, document, dataset
)
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
return {"data": _get_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200
@setup_required
@login_required

View File

@ -1,4 +1,4 @@
from flask_restx import Resource
from flask_restx import Resource, fields
from controllers.common.schema import register_schema_model
from libs.login import login_required
@ -10,17 +10,56 @@ from ..wraps import (
cloud_edition_billing_rate_limit_check,
setup_required,
)
from fields.hit_testing_fields import (
child_chunk_fields,
document_fields,
files_fields,
hit_testing_record_fields,
segment_fields,
)
register_schema_model(console_ns, HitTestingPayload)
def _get_or_create_model(model_name: str, field_def):
"""Get or create a flask_restx model to avoid dict type issues in Swagger."""
existing = console_ns.models.get(model_name)
if existing is None:
existing = console_ns.model(model_name, field_def)
return existing
# Register models for flask_restx to avoid dict type issues in Swagger
document_model = _get_or_create_model("HitTestingDocument", document_fields)
segment_fields_copy = segment_fields.copy()
segment_fields_copy["document"] = fields.Nested(document_model)
segment_model = _get_or_create_model("HitTestingSegment", segment_fields_copy)
child_chunk_model = _get_or_create_model("HitTestingChildChunk", child_chunk_fields)
files_model = _get_or_create_model("HitTestingFile", files_fields)
hit_testing_record_fields_copy = hit_testing_record_fields.copy()
hit_testing_record_fields_copy["segment"] = fields.Nested(segment_model)
hit_testing_record_fields_copy["child_chunks"] = fields.List(fields.Nested(child_chunk_model))
hit_testing_record_fields_copy["files"] = fields.List(fields.Nested(files_model))
hit_testing_record_model = _get_or_create_model("HitTestingRecord", hit_testing_record_fields_copy)
# Response model for hit testing API
hit_testing_response_fields = {
"query": fields.String,
"records": fields.List(fields.Nested(hit_testing_record_model)),
}
hit_testing_response_model = _get_or_create_model("HitTestingResponse", hit_testing_response_fields)
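
Illustrative shape of the payload documented by hit_testing_response_model above (values are placeholders; only the keys set explicitly in the copied field dicts are shown):

example_hit_testing_response = {
    "query": "example query",             # fields.String on the response model
    "records": [
        {
            "segment": {"document": {}},   # nested HitTestingDocument model
            "child_chunks": [],            # list of HitTestingChildChunk
            "files": [],                   # list of HitTestingFile
        }
    ],
}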
@console_ns.route("/datasets/<uuid:dataset_id>/hit-testing")
class HitTestingApi(Resource, DatasetsHitTestingBase):
@console_ns.doc("test_dataset_retrieval")
@console_ns.doc(description="Test dataset knowledge retrieval")
@console_ns.doc(params={"dataset_id": "Dataset ID"})
@console_ns.expect(console_ns.models[HitTestingPayload.__name__])
@console_ns.response(200, "Hit testing completed successfully")
@console_ns.response(200, "Hit testing completed successfully", model=hit_testing_response_model)
@console_ns.response(404, "Dataset not found")
@console_ns.response(400, "Invalid parameters")
@setup_required

View File

@ -0,0 +1,43 @@
from flask import request
from flask_restx import Resource
from controllers.console import api
from controllers.console.explore.wraps import explore_banner_enabled
from extensions.ext_database import db
from models.model import ExporleBanner
class BannerApi(Resource):
"""Resource for banner list."""
@explore_banner_enabled
def get(self):
"""Get banner list."""
language = request.args.get("language", "en-US")
# Build base query for enabled banners
base_query = db.session.query(ExporleBanner).where(ExporleBanner.status == "enabled")
# Try to get banners in the requested language
banners = base_query.where(ExporleBanner.language == language).order_by(ExporleBanner.sort).all()
# Fallback to en-US if no banners found and language is not en-US
if not banners and language != "en-US":
banners = base_query.where(ExporleBanner.language == "en-US").order_by(ExporleBanner.sort).all()
# Convert banners to serializable format
result = []
for banner in banners:
banner_data = {
"id": banner.id,
"content": banner.content, # Already parsed as JSON by SQLAlchemy
"link": banner.link,
"sort": banner.sort,
"status": banner.status,
"created_at": banner.created_at.isoformat() if banner.created_at else None,
}
result.append(banner_data)
return result
api.add_resource(BannerApi, "/explore/banners")
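
A minimal standalone sketch of the language fallback applied in BannerApi.get above, with a callable standing in for the ExporleBanner query (illustrative only):

from collections.abc import Callable, Sequence

def banners_for_language(fetch: Callable[[str], Sequence[dict]], language: str, default: str = "en-US") -> Sequence[dict]:
    # query the requested language first, then fall back to the default locale
    banners = fetch(language)
    if not banners and language != default:
        banners = fetch(default)
    return banners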

View File

@ -29,3 +29,25 @@ class AppAccessDeniedError(BaseHTTPException):
error_code = "access_denied"
description = "App access denied."
code = 403
class TrialAppNotAllowed(BaseHTTPException):
"""*403* `Trial App Not Allowed`
Raise if the app is not available for trial.
"""
error_code = "trial_app_not_allowed"
code = 403
description = "the app is not allowed to be trial."
class TrialAppLimitExceeded(BaseHTTPException):
"""*403* `Trial App Limit Exceeded`
Raise if the user has exceeded the trial app limit.
"""
error_code = "trial_app_limit_exceeded"
code = 403
description = "The user has exceeded the trial app limit."

View File

@ -29,6 +29,7 @@ recommended_app_fields = {
"category": fields.String,
"position": fields.Integer,
"is_listed": fields.Boolean,
"can_trial": fields.Boolean,
}
recommended_app_list_fields = {

View File

@ -0,0 +1,512 @@
import logging
from typing import Any, cast
from flask import request
from flask_restx import Resource, marshal, marshal_with, reqparse
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
from controllers.common.fields import Parameters as ParametersResponse
from controllers.common.fields import Site as SiteResponse
from controllers.console import api
from controllers.console.app.error import (
AppUnavailableError,
AudioTooLargeError,
CompletionRequestError,
ConversationCompletedError,
NeedAddIdsError,
NoAudioUploadedError,
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderNotSupportSpeechToTextError,
ProviderQuotaExceededError,
UnsupportedAudioTypeError,
)
from controllers.console.app.wraps import get_app_model_with_trial
from controllers.console.explore.error import (
AppSuggestedQuestionsAfterAnswerDisabledError,
NotChatAppError,
NotCompletionAppError,
NotWorkflowAppError,
)
from controllers.console.explore.wraps import TrialAppResource, trial_feature_enable
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import (
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from core.workflow.graph_engine.manager import GraphEngineManager
from extensions.ext_database import db
from fields.app_fields import app_detail_fields_with_site
from fields.dataset_fields import dataset_fields
from fields.workflow_fields import workflow_fields
from libs import helper
from libs.helper import uuid_value
from libs.login import current_user
from models import Account
from models.account import TenantStatus
from models.model import AppMode, Site
from models.workflow import Workflow
from services.app_generate_service import AppGenerateService
from services.app_service import AppService
from services.audio_service import AudioService
from services.dataset_service import DatasetService
from services.errors.audio import (
AudioTooLargeServiceError,
NoAudioUploadedServiceError,
ProviderNotSupportSpeechToTextServiceError,
UnsupportedAudioTypeServiceError,
)
from services.errors.conversation import ConversationNotExistsError
from services.errors.llm import InvokeRateLimitError
from services.errors.message import (
MessageNotExistsError,
SuggestedQuestionsAfterAnswerDisabledError,
)
from services.message_service import MessageService
from services.recommended_app_service import RecommendedAppService
logger = logging.getLogger(__name__)
class TrialAppWorkflowRunApi(TrialAppResource):
def post(self, trial_app):
"""
Run workflow
"""
app_model = trial_app
if not app_model:
raise NotWorkflowAppError()
app_mode = AppMode.value_of(app_model.mode)
if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("files", type=list, required=False, location="json")
args = parser.parse_args()
assert current_user is not None
try:
app_id = app_model.id
user_id = current_user.id
response = AppGenerateService.generate(
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
)
RecommendedAppService.add_trial_app_record(app_id, user_id)
return helper.compact_generate_response(response)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except InvokeRateLimitError as ex:
raise InvokeRateLimitHttpError(ex.description)
except ValueError as e:
raise e
except Exception:
logger.exception("internal server error.")
raise InternalServerError()
class TrialAppWorkflowTaskStopApi(TrialAppResource):
def post(self, trial_app, task_id: str):
"""
Stop workflow task
"""
app_model = trial_app
if not app_model:
raise NotWorkflowAppError()
app_mode = AppMode.value_of(app_model.mode)
if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError()
assert current_user is not None
# Stop using both mechanisms for backward compatibility
# Legacy stop flag mechanism (without user check)
AppQueueManager.set_stop_flag_no_user_check(task_id)
# New graph engine command channel mechanism
GraphEngineManager.send_stop_command(task_id)
return {"result": "success"}
class TrialChatApi(TrialAppResource):
@trial_feature_enable
def post(self, trial_app):
app_model = trial_app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("query", type=str, required=True, location="json")
parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument("conversation_id", type=uuid_value, location="json")
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
args = parser.parse_args()
args["auto_generate_name"] = False
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
# Get IDs before they might be detached from session
app_id = app_model.id
user_id = current_user.id
response = AppGenerateService.generate(
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
)
RecommendedAppService.add_trial_app_record(app_id, user_id)
return helper.compact_generate_response(response)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except services.errors.conversation.ConversationCompletedError:
raise ConversationCompletedError()
except services.errors.app_model_config.AppModelConfigBrokenError:
logger.exception("App model config broken.")
raise AppUnavailableError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except InvokeRateLimitError as ex:
raise InvokeRateLimitHttpError(ex.description)
except ValueError as e:
raise e
except Exception:
logger.exception("internal server error.")
raise InternalServerError()
class TrialMessageSuggestedQuestionApi(TrialAppResource):
@trial_feature_enable
def get(self, trial_app, message_id):
app_model = trial_app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
message_id = str(message_id)
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
questions = MessageService.get_suggested_questions_after_answer(
app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE
)
except MessageNotExistsError:
raise NotFound("Message not found")
except ConversationNotExistsError:
raise NotFound("Conversation not found")
except SuggestedQuestionsAfterAnswerDisabledError:
raise AppSuggestedQuestionsAfterAnswerDisabledError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except Exception:
logger.exception("internal server error.")
raise InternalServerError()
return {"data": questions}
class TrialChatAudioApi(TrialAppResource):
@trial_feature_enable
def post(self, trial_app):
app_model = trial_app
file = request.files["file"]
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
# Get IDs before they might be detached from session
app_id = app_model.id
user_id = current_user.id
response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=None)
RecommendedAppService.add_trial_app_record(app_id, user_id)
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
logger.exception("App model config broken.")
raise AppUnavailableError()
except NoAudioUploadedServiceError:
raise NoAudioUploadedError()
except AudioTooLargeServiceError as e:
raise AudioTooLargeError(str(e))
except UnsupportedAudioTypeServiceError:
raise UnsupportedAudioTypeError()
except ProviderNotSupportSpeechToTextServiceError:
raise ProviderNotSupportSpeechToTextError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception:
logger.exception("internal server error.")
raise InternalServerError()
class TrialChatTextApi(TrialAppResource):
@trial_feature_enable
def post(self, trial_app):
app_model = trial_app
try:
parser = reqparse.RequestParser()
parser.add_argument("message_id", type=str, required=False, location="json")
parser.add_argument("voice", type=str, location="json")
parser.add_argument("text", type=str, location="json")
parser.add_argument("streaming", type=bool, location="json")
args = parser.parse_args()
message_id = args.get("message_id", None)
text = args.get("text", None)
voice = args.get("voice", None)
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
# Get IDs before they might be detached from session
app_id = app_model.id
user_id = current_user.id
response = AudioService.transcript_tts(app_model=app_model, text=text, voice=voice, message_id=message_id)
RecommendedAppService.add_trial_app_record(app_id, user_id)
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
logger.exception("App model config broken.")
raise AppUnavailableError()
except NoAudioUploadedServiceError:
raise NoAudioUploadedError()
except AudioTooLargeServiceError as e:
raise AudioTooLargeError(str(e))
except UnsupportedAudioTypeServiceError:
raise UnsupportedAudioTypeError()
except ProviderNotSupportSpeechToTextServiceError:
raise ProviderNotSupportSpeechToTextError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception:
logger.exception("internal server error.")
raise InternalServerError()
class TrialCompletionApi(TrialAppResource):
@trial_feature_enable
def post(self, trial_app):
app_model = trial_app
if app_model.mode != "completion":
raise NotCompletionAppError()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("query", type=str, location="json", default="")
parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
args = parser.parse_args()
streaming = args["response_mode"] == "streaming"
args["auto_generate_name"] = False
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
# Get IDs before they might be detached from session
app_id = app_model.id
user_id = current_user.id
response = AppGenerateService.generate(
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=streaming
)
RecommendedAppService.add_trial_app_record(app_id, user_id)
return helper.compact_generate_response(response)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except services.errors.conversation.ConversationCompletedError:
raise ConversationCompletedError()
except services.errors.app_model_config.AppModelConfigBrokenError:
logger.exception("App model config broken.")
raise AppUnavailableError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception:
logger.exception("internal server error.")
raise InternalServerError()
class TrialSitApi(Resource):
"""Resource for trial app sites."""
@trial_feature_enable
@get_app_model_with_trial
def get(self, app_model):
"""Retrieve app site info.
Returns the site configuration for the application including theme, icons, and text.
"""
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
if not site:
raise Forbidden()
assert app_model.tenant
if app_model.tenant.status == TenantStatus.ARCHIVE:
raise Forbidden()
return SiteResponse.model_validate(site).model_dump(mode="json")
class TrialAppParameterApi(Resource):
"""Resource for app variables."""
@trial_feature_enable
@get_app_model_with_trial
def get(self, app_model):
"""Retrieve app parameters."""
if app_model is None:
raise AppUnavailableError()
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
workflow = app_model.workflow
if workflow is None:
raise AppUnavailableError()
features_dict = workflow.features_dict
user_input_form = workflow.user_input_form(to_old_structure=True)
else:
app_model_config = app_model.app_model_config
if app_model_config is None:
raise AppUnavailableError()
features_dict = app_model_config.to_dict()
user_input_form = features_dict.get("user_input_form", [])
parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
return ParametersResponse.model_validate(parameters).model_dump(mode="json")
class AppApi(Resource):
@trial_feature_enable
@get_app_model_with_trial
@marshal_with(app_detail_fields_with_site)
def get(self, app_model):
"""Get app detail"""
app_service = AppService()
app_model = app_service.get_app(app_model)
return app_model
class AppWorkflowApi(Resource):
@trial_feature_enable
@get_app_model_with_trial
@marshal_with(workflow_fields)
def get(self, app_model):
"""Get workflow detail"""
if not app_model.workflow_id:
raise AppUnavailableError()
workflow = (
db.session.query(Workflow)
.where(
Workflow.id == app_model.workflow_id,
)
.first()
)
return workflow
class DatasetListApi(Resource):
@trial_feature_enable
@get_app_model_with_trial
def get(self, app_model):
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
ids = request.args.getlist("ids")
tenant_id = app_model.tenant_id
if ids:
datasets, total = DatasetService.get_datasets_by_ids(ids, tenant_id)
else:
raise NeedAddIdsError()
data = cast(list[dict[str, Any]], marshal(datasets, dataset_fields))
response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
return response
api.add_resource(TrialChatApi, "/trial-apps/<uuid:app_id>/chat-messages", endpoint="trial_app_chat_completion")
api.add_resource(
TrialMessageSuggestedQuestionApi,
"/trial-apps/<uuid:app_id>/messages/<uuid:message_id>/suggested-questions",
endpoint="trial_app_suggested_question",
)
api.add_resource(TrialChatAudioApi, "/trial-apps/<uuid:app_id>/audio-to-text", endpoint="trial_app_audio")
api.add_resource(TrialChatTextApi, "/trial-apps/<uuid:app_id>/text-to-audio", endpoint="trial_app_text")
api.add_resource(TrialCompletionApi, "/trial-apps/<uuid:app_id>/completion-messages", endpoint="trial_app_completion")
api.add_resource(TrialSitApi, "/trial-apps/<uuid:app_id>/site")
api.add_resource(TrialAppParameterApi, "/trial-apps/<uuid:app_id>/parameters", endpoint="trial_app_parameters")
api.add_resource(AppApi, "/trial-apps/<uuid:app_id>", endpoint="trial_app")
api.add_resource(TrialAppWorkflowRunApi, "/trial-apps/<uuid:app_id>/workflows/run", endpoint="trial_app_workflow_run")
api.add_resource(TrialAppWorkflowTaskStopApi, "/trial-apps/<uuid:app_id>/workflows/tasks/<string:task_id>/stop")
api.add_resource(AppWorkflowApi, "/trial-apps/<uuid:app_id>/workflows", endpoint="trial_app_workflow")
api.add_resource(DatasetListApi, "/trial-apps/<uuid:app_id>/datasets", endpoint="trial_app_datasets")

View File

@ -2,14 +2,15 @@ from collections.abc import Callable
from functools import wraps
from typing import Concatenate, ParamSpec, TypeVar
from flask import abort
from flask_restx import Resource
from werkzeug.exceptions import NotFound
from controllers.console.explore.error import AppAccessDeniedError
from controllers.console.explore.error import AppAccessDeniedError, TrialAppLimitExceeded, TrialAppNotAllowed
from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db
from libs.login import current_account_with_tenant, login_required
from models import InstalledApp
from models import AccountTrialAppRecord, App, InstalledApp, TrialApp
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
@ -71,6 +72,61 @@ def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] |
return decorator
def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None):
def decorator(view: Callable[Concatenate[App, P], R]):
@wraps(view)
def decorated(app_id: str, *args: P.args, **kwargs: P.kwargs):
current_user, _ = current_account_with_tenant()
trial_app = db.session.query(TrialApp).where(TrialApp.app_id == str(app_id)).first()
if trial_app is None:
raise TrialAppNotAllowed()
app = trial_app.app
if app is None:
raise TrialAppNotAllowed()
account_trial_app_record = (
db.session.query(AccountTrialAppRecord)
.where(AccountTrialAppRecord.account_id == current_user.id, AccountTrialAppRecord.app_id == app_id)
.first()
)
if account_trial_app_record:
if account_trial_app_record.count >= trial_app.trial_limit:
raise TrialAppLimitExceeded()
return view(app, *args, **kwargs)
return decorated
if view:
return decorator(view)
return decorator
def trial_feature_enable(view: Callable[..., R]) -> Callable[..., R]:
@wraps(view)
def decorated(*args, **kwargs):
features = FeatureService.get_system_features()
if not features.enable_trial_app:
abort(403, "Trial app feature is not enabled.")
return view(*args, **kwargs)
return decorated
def explore_banner_enabled(view: Callable[..., R]) -> Callable[..., R]:
@wraps(view)
def decorated(*args, **kwargs):
features = FeatureService.get_system_features()
if not features.enable_explore_banner:
abort(403, "Explore banner feature is not enabled.")
return view(*args, **kwargs)
return decorated
class InstalledAppResource(Resource):
# must be reversed if there are multiple decorators
@ -80,3 +136,13 @@ class InstalledAppResource(Resource):
account_initialization_required,
login_required,
]
class TrialAppResource(Resource):
# must be reversed if there are multiple decorators
method_decorators = [
trial_app_required,
account_initialization_required,
login_required,
]

View File

@ -84,10 +84,11 @@ class SetupApi(Resource):
raise NotInitValidateError()
args = SetupRequestPayload.model_validate(console_ns.payload)
normalized_email = args.email.lower()
# setup
RegisterService.setup(
email=args.email,
email=normalized_email,
name=args.name,
password=args.password,
ip_address=extract_remote_ip(request),

View File

@ -41,7 +41,7 @@ from fields.member_fields import account_fields
from libs.datetime_utils import naive_utc_now
from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone
from libs.login import current_account_with_tenant, login_required
from models import Account, AccountIntegrate, InvitationCode
from models import AccountIntegrate, InvitationCode
from services.account_service import AccountService
from services.billing_service import BillingService
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
@ -536,7 +536,8 @@ class ChangeEmailSendEmailApi(Resource):
else:
language = "en-US"
account = None
user_email = args.email
user_email = None
email_for_sending = args.email.lower()
if args.phase is not None and args.phase == "new_email":
if args.token is None:
raise InvalidTokenError()
@ -546,16 +547,24 @@ class ChangeEmailSendEmailApi(Resource):
raise InvalidTokenError()
user_email = reset_data.get("email", "")
if user_email != current_user.email:
if user_email.lower() != current_user.email.lower():
raise InvalidEmailError()
user_email = current_user.email
else:
with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
if account is None:
raise AccountNotFound()
email_for_sending = account.email
user_email = account.email
token = AccountService.send_change_email_email(
account=account, email=args.email, old_email=user_email, language=language, phase=args.phase
account=account,
email=email_for_sending,
old_email=user_email,
language=language,
phase=args.phase,
)
return {"result": "success", "data": token}
@ -571,9 +580,9 @@ class ChangeEmailCheckApi(Resource):
payload = console_ns.payload or {}
args = ChangeEmailValidityPayload.model_validate(payload)
user_email = args.email
user_email = args.email.lower()
is_change_email_error_rate_limit = AccountService.is_change_email_error_rate_limit(args.email)
is_change_email_error_rate_limit = AccountService.is_change_email_error_rate_limit(user_email)
if is_change_email_error_rate_limit:
raise EmailChangeLimitError()
@ -581,11 +590,13 @@ class ChangeEmailCheckApi(Resource):
if token_data is None:
raise InvalidTokenError()
if user_email != token_data.get("email"):
token_email = token_data.get("email")
normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
if user_email != normalized_token_email:
raise InvalidEmailError()
if args.code != token_data.get("code"):
AccountService.add_change_email_error_rate_limit(args.email)
AccountService.add_change_email_error_rate_limit(user_email)
raise EmailCodeError()
# Verified, revoke the first token
@ -596,8 +607,8 @@ class ChangeEmailCheckApi(Resource):
user_email, code=args.code, old_email=token_data.get("old_email"), additional_data={}
)
AccountService.reset_change_email_error_rate_limit(args.email)
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
AccountService.reset_change_email_error_rate_limit(user_email)
return {"is_valid": True, "email": normalized_token_email, "token": new_token}
@console_ns.route("/account/change-email/reset")
@ -611,11 +622,12 @@ class ChangeEmailResetApi(Resource):
def post(self):
payload = console_ns.payload or {}
args = ChangeEmailResetPayload.model_validate(payload)
normalized_new_email = args.new_email.lower()
if AccountService.is_account_in_freeze(args.new_email):
if AccountService.is_account_in_freeze(normalized_new_email):
raise AccountInFreezeError()
if not AccountService.check_email_unique(args.new_email):
if not AccountService.check_email_unique(normalized_new_email):
raise EmailAlreadyInUseError()
reset_data = AccountService.get_change_email_data(args.token)
@ -626,13 +638,13 @@ class ChangeEmailResetApi(Resource):
old_email = reset_data.get("old_email", "")
current_user, _ = current_account_with_tenant()
if current_user.email != old_email:
if current_user.email.lower() != old_email.lower():
raise AccountNotFound()
updated_account = AccountService.update_account_email(current_user, email=args.new_email)
updated_account = AccountService.update_account_email(current_user, email=normalized_new_email)
AccountService.send_change_email_completed_notify_email(
email=args.new_email,
email=normalized_new_email,
)
return updated_account
@ -645,8 +657,9 @@ class CheckEmailUnique(Resource):
def post(self):
payload = console_ns.payload or {}
args = CheckEmailUniquePayload.model_validate(payload)
if AccountService.is_account_in_freeze(args.email):
normalized_email = args.email.lower()
if AccountService.is_account_in_freeze(normalized_email):
raise AccountInFreezeError()
if not AccountService.check_email_unique(args.email):
if not AccountService.check_email_unique(normalized_email):
raise EmailAlreadyInUseError()
return {"result": "success"}


@ -116,26 +116,31 @@ class MemberInviteEmailApi(Resource):
raise WorkspaceMembersLimitExceeded()
for invitee_email in invitee_emails:
normalized_invitee_email = invitee_email.lower()
try:
if not inviter.current_tenant:
raise ValueError("No current tenant")
token = RegisterService.invite_new_member(
inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter
tenant=inviter.current_tenant,
email=invitee_email,
language=interface_language,
role=invitee_role,
inviter=inviter,
)
encoded_invitee_email = parse.quote(invitee_email)
encoded_invitee_email = parse.quote(normalized_invitee_email)
invitation_results.append(
{
"status": "success",
"email": invitee_email,
"email": normalized_invitee_email,
"url": f"{console_web_url}/activate?email={encoded_invitee_email}&token={token}",
}
)
except AccountAlreadyInTenantError:
invitation_results.append(
{"status": "success", "email": invitee_email, "url": f"{console_web_url}/signin"}
{"status": "success", "email": normalized_invitee_email, "url": f"{console_web_url}/signin"}
)
except Exception as e:
invitation_results.append({"status": "failed", "email": invitee_email, "message": str(e)})
invitation_results.append({"status": "failed", "email": normalized_invitee_email, "message": str(e)})
return {
"result": "success",


@ -358,14 +358,12 @@ def annotation_import_rate_limit(view: Callable[P, R]):
def decorated(*args: P.args, **kwargs: P.kwargs):
_, current_tenant_id = current_account_with_tenant()
current_time = int(time.time() * 1000)
# Check per-minute rate limit
minute_key = f"annotation_import_rate_limit:{current_tenant_id}:1min"
redis_client.zadd(minute_key, {current_time: current_time})
redis_client.zremrangebyscore(minute_key, 0, current_time - 60000)
minute_count = redis_client.zcard(minute_key)
redis_client.expire(minute_key, 120) # 2 minutes TTL
if minute_count > dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE:
abort(
429,
@ -379,7 +377,6 @@ def annotation_import_rate_limit(view: Callable[P, R]):
redis_client.zremrangebyscore(hour_key, 0, current_time - 3600000)
hour_count = redis_client.zcard(hour_key)
redis_client.expire(hour_key, 7200) # 2 hours TTL
if hour_count > dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR:
abort(
429,


@ -4,7 +4,6 @@ import secrets
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import Session
from controllers.common.schema import register_schema_models
@ -22,7 +21,7 @@ from controllers.web import web_ns
from extensions.ext_database import db
from libs.helper import EmailStr, extract_remote_ip
from libs.password import hash_password, valid_password
from models import Account
from models.account import Account
from services.account_service import AccountService
@ -70,6 +69,9 @@ class ForgotPasswordSendEmailApi(Resource):
def post(self):
payload = ForgotPasswordSendPayload.model_validate(web_ns.payload or {})
request_email = payload.email
normalized_email = request_email.lower()
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError()
@ -80,12 +82,12 @@ class ForgotPasswordSendEmailApi(Resource):
language = "en-US"
with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=payload.email)).scalar_one_or_none()
account = AccountService.get_account_by_email_with_case_fallback(request_email, session=session)
token = None
if account is None:
raise AuthenticationFailedError()
else:
token = AccountService.send_reset_password_email(account=account, email=payload.email, language=language)
token = AccountService.send_reset_password_email(account=account, email=normalized_email, language=language)
return {"result": "success", "data": token}
@ -104,9 +106,9 @@ class ForgotPasswordCheckApi(Resource):
def post(self):
payload = ForgotPasswordCheckPayload.model_validate(web_ns.payload or {})
user_email = payload.email
user_email = payload.email.lower()
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(payload.email)
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(user_email)
if is_forgot_password_error_rate_limit:
raise EmailPasswordResetLimitError()
@ -114,11 +116,16 @@ class ForgotPasswordCheckApi(Resource):
if token_data is None:
raise InvalidTokenError()
if user_email != token_data.get("email"):
token_email = token_data.get("email")
if not isinstance(token_email, str):
raise InvalidEmailError()
normalized_token_email = token_email.lower()
if user_email != normalized_token_email:
raise InvalidEmailError()
if payload.code != token_data.get("code"):
AccountService.add_forgot_password_error_rate_limit(payload.email)
AccountService.add_forgot_password_error_rate_limit(user_email)
raise EmailCodeError()
# Verified, revoke the first token
@ -126,11 +133,11 @@ class ForgotPasswordCheckApi(Resource):
# Refresh token data by generating a new token
_, new_token = AccountService.generate_reset_password_token(
user_email, code=payload.code, additional_data={"phase": "reset"}
token_email, code=payload.code, additional_data={"phase": "reset"}
)
AccountService.reset_forgot_password_error_rate_limit(payload.email)
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
AccountService.reset_forgot_password_error_rate_limit(user_email)
return {"is_valid": True, "email": normalized_token_email, "token": new_token}
@web_ns.route("/forgot-password/resets")
@ -174,7 +181,7 @@ class ForgotPasswordResetApi(Resource):
email = reset_data.get("email", "")
with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
if account:
self._update_existing_account(account, password_hashed, salt, session)


@ -197,25 +197,29 @@ class EmailCodeLoginApi(Resource):
)
args = parser.parse_args()
user_email = args["email"]
user_email = args["email"].lower()
token_data = WebAppAuthService.get_email_code_login_data(args["token"])
if token_data is None:
raise InvalidTokenError()
if token_data["email"] != args["email"]:
token_email = token_data.get("email")
if not isinstance(token_email, str):
raise InvalidEmailError()
normalized_token_email = token_email.lower()
if normalized_token_email != user_email:
raise InvalidEmailError()
if token_data["code"] != args["code"]:
raise EmailCodeError()
WebAppAuthService.revoke_email_code_login_token(args["token"])
account = WebAppAuthService.get_user_through_email(user_email)
account = WebAppAuthService.get_user_through_email(token_email)
if not account:
raise AuthenticationFailedError()
token = WebAppAuthService.login(account=account)
AccountService.reset_login_error_rate_limit(args["email"])
AccountService.reset_login_error_rate_limit(user_email)
response = make_response({"result": "success", "data": {"access_token": token}})
# set_access_token_to_cookie(request, response, token, samesite="None", httponly=False)
return response


@ -0,0 +1,380 @@
import logging
from collections.abc import Generator
from copy import deepcopy
from typing import Any
from core.agent.base_agent_runner import BaseAgentRunner
from core.agent.entities import AgentEntity, AgentLog, AgentResult
from core.agent.patterns.strategy_factory import StrategyFactory
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
from core.file import file_manager
from core.model_runtime.entities import (
AssistantPromptMessage,
LLMResult,
LLMResultChunk,
LLMUsage,
PromptMessage,
PromptMessageContentType,
SystemPromptMessage,
TextPromptMessageContent,
UserPromptMessage,
)
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
from models.model import Message
logger = logging.getLogger(__name__)
class AgentAppRunner(BaseAgentRunner):
def _create_tool_invoke_hook(self, message: Message):
"""
Create a tool invoke hook that uses ToolEngine.agent_invoke.
This hook handles file creation and returns proper meta information.
"""
# Get trace manager from app generate entity
trace_manager = self.application_generate_entity.trace_manager
def tool_invoke_hook(
tool: Tool, tool_args: dict[str, Any], tool_name: str
) -> tuple[str, list[str], ToolInvokeMeta]:
"""Hook that uses agent_invoke for proper file and meta handling."""
tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
tool=tool,
tool_parameters=tool_args,
user_id=self.user_id,
tenant_id=self.tenant_id,
message=message,
invoke_from=self.application_generate_entity.invoke_from,
agent_tool_callback=self.agent_callback,
trace_manager=trace_manager,
app_id=self.application_generate_entity.app_config.app_id,
message_id=message.id,
conversation_id=self.conversation.id,
)
# Publish files and track IDs
for message_file_id in message_files:
self.queue_manager.publish(
QueueMessageFileEvent(message_file_id=message_file_id),
PublishFrom.APPLICATION_MANAGER,
)
self._current_message_file_ids.append(message_file_id)
return tool_invoke_response, message_files, tool_invoke_meta
return tool_invoke_hook
def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]:
"""
Run Agent application
"""
self.query = query
app_generate_entity = self.application_generate_entity
app_config = self.app_config
assert app_config is not None, "app_config is required"
assert app_config.agent is not None, "app_config.agent is required"
# convert tools into ModelRuntime Tool format
tool_instances, _ = self._init_prompt_tools()
assert app_config.agent
# Create tool invoke hook for agent_invoke
tool_invoke_hook = self._create_tool_invoke_hook(message)
# Get instruction for ReAct strategy
instruction = self.app_config.prompt_template.simple_prompt_template or ""
# Use factory to create appropriate strategy
strategy = StrategyFactory.create_strategy(
model_features=self.model_features,
model_instance=self.model_instance,
tools=list(tool_instances.values()),
files=list(self.files),
max_iterations=app_config.agent.max_iteration,
context=self.build_execution_context(),
agent_strategy=self.config.strategy,
tool_invoke_hook=tool_invoke_hook,
instruction=instruction,
)
# Initialize state variables
current_agent_thought_id = None
has_published_thought = False
current_tool_name: str | None = None
self._current_message_file_ids: list[str] = []
# organize prompt messages
prompt_messages = self._organize_prompt_messages()
# Run strategy
generator = strategy.run(
prompt_messages=prompt_messages,
model_parameters=app_generate_entity.model_conf.parameters,
stop=app_generate_entity.model_conf.stop,
stream=True,
)
# Consume generator and collect result
result: AgentResult | None = None
try:
while True:
try:
output = next(generator)
except StopIteration as e:
# Generator finished, get the return value
result = e.value
break
if isinstance(output, LLMResultChunk):
# Handle LLM chunk
if current_agent_thought_id and not has_published_thought:
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
PublishFrom.APPLICATION_MANAGER,
)
has_published_thought = True
yield output
elif isinstance(output, AgentLog):
# Handle Agent Log using log_type for type-safe dispatch
if output.status == AgentLog.LogStatus.START:
if output.log_type == AgentLog.LogType.ROUND:
# Start of a new round
message_file_ids: list[str] = []
current_agent_thought_id = self.create_agent_thought(
message_id=message.id,
message="",
tool_name="",
tool_input="",
messages_ids=message_file_ids,
)
has_published_thought = False
elif output.log_type == AgentLog.LogType.TOOL_CALL:
if current_agent_thought_id is None:
continue
# Tool call start - extract data from structured fields
current_tool_name = output.data.get("tool_name", "")
tool_input = output.data.get("tool_args", {})
self.save_agent_thought(
agent_thought_id=current_agent_thought_id,
tool_name=current_tool_name,
tool_input=tool_input,
thought=None,
observation=None,
tool_invoke_meta=None,
answer=None,
messages_ids=[],
)
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
PublishFrom.APPLICATION_MANAGER,
)
elif output.status == AgentLog.LogStatus.SUCCESS:
if output.log_type == AgentLog.LogType.THOUGHT:
if current_agent_thought_id is None:
continue
thought_text = output.data.get("thought")
self.save_agent_thought(
agent_thought_id=current_agent_thought_id,
tool_name=None,
tool_input=None,
thought=thought_text,
observation=None,
tool_invoke_meta=None,
answer=None,
messages_ids=[],
)
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
PublishFrom.APPLICATION_MANAGER,
)
elif output.log_type == AgentLog.LogType.TOOL_CALL:
if current_agent_thought_id is None:
continue
# Tool call finished
tool_output = output.data.get("output")
# Get meta from strategy output (now properly populated)
tool_meta = output.data.get("meta")
# Wrap tool_meta with tool_name as key (required by agent_service)
if tool_meta and current_tool_name:
tool_meta = {current_tool_name: tool_meta}
self.save_agent_thought(
agent_thought_id=current_agent_thought_id,
tool_name=None,
tool_input=None,
thought=None,
observation=tool_output,
tool_invoke_meta=tool_meta,
answer=None,
messages_ids=self._current_message_file_ids,
)
# Clear message file ids after saving
self._current_message_file_ids = []
current_tool_name = None
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
PublishFrom.APPLICATION_MANAGER,
)
elif output.log_type == AgentLog.LogType.ROUND:
if current_agent_thought_id is None:
continue
# Round finished - save LLM usage and answer
llm_usage = output.metadata.get(AgentLog.LogMetadata.LLM_USAGE)
llm_result = output.data.get("llm_result")
final_answer = output.data.get("final_answer")
self.save_agent_thought(
agent_thought_id=current_agent_thought_id,
tool_name=None,
tool_input=None,
thought=llm_result,
observation=None,
tool_invoke_meta=None,
answer=final_answer,
messages_ids=[],
llm_usage=llm_usage,
)
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
PublishFrom.APPLICATION_MANAGER,
)
except Exception:
# Re-raise any other exceptions
raise
# Process final result
if isinstance(result, AgentResult):
final_answer = result.text
usage = result.usage or LLMUsage.empty_usage()
# Publish end event
self.queue_manager.publish(
QueueMessageEndEvent(
llm_result=LLMResult(
model=self.model_instance.model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=final_answer),
usage=usage,
system_fingerprint="",
)
),
PublishFrom.APPLICATION_MANAGER,
)
def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
Initialize system message
"""
if not prompt_template:
return prompt_messages or []
prompt_messages = prompt_messages or []
if prompt_messages and isinstance(prompt_messages[0], SystemPromptMessage):
prompt_messages[0] = SystemPromptMessage(content=prompt_template)
return prompt_messages
if not prompt_messages:
return [SystemPromptMessage(content=prompt_template)]
prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))
return prompt_messages
def _organize_user_query(self, query: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
Organize user query
"""
if self.files:
# get image detail config
image_detail_config = (
self.application_generate_entity.file_upload_config.image_config.detail
if (
self.application_generate_entity.file_upload_config
and self.application_generate_entity.file_upload_config.image_config
)
else None
)
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
for file in self.files:
prompt_message_contents.append(
file_manager.to_prompt_message_content(
file,
image_detail_config=image_detail_config,
)
)
prompt_message_contents.append(TextPromptMessageContent(data=query))
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else:
prompt_messages.append(UserPromptMessage(content=query))
return prompt_messages
def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
As of now, GPT supports both function calling and vision at the first iteration.
We need to remove the image messages from the prompt messages at the first iteration.
"""
prompt_messages = deepcopy(prompt_messages)
for prompt_message in prompt_messages:
if isinstance(prompt_message, UserPromptMessage):
if isinstance(prompt_message.content, list):
prompt_message.content = "\n".join(
[
content.data
if content.type == PromptMessageContentType.TEXT
else "[image]"
if content.type == PromptMessageContentType.IMAGE
else "[file]"
for content in prompt_message.content
]
)
return prompt_messages
def _organize_prompt_messages(self):
# For ReAct strategy, use the agent prompt template
if self.config.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT and self.config.prompt:
prompt_template = self.config.prompt.first_prompt
else:
prompt_template = self.app_config.prompt_template.simple_prompt_template or ""
self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
query_prompt_messages = self._organize_user_query(self.query or "", [])
self.history_prompt_messages = AgentHistoryPromptTransform(
model_config=self.model_config,
prompt_messages=[*query_prompt_messages, *self._current_thoughts],
history_messages=self.history_prompt_messages,
memory=self.memory,
).get_prompt()
prompt_messages = [*self.history_prompt_messages, *query_prompt_messages, *self._current_thoughts]
if len(self._current_thoughts) != 0:
# clear messages after the first iteration
prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
return prompt_messages


@ -1,11 +1,12 @@
import json
import logging
import uuid
from decimal import Decimal
from typing import Union, cast
from sqlalchemy import select
from core.agent.entities import AgentEntity, AgentToolEntity
from core.agent.entities import AgentEntity, AgentToolEntity, ExecutionContext
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager
@ -41,6 +42,7 @@ from core.tools.tool_manager import ToolManager
from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool
from extensions.ext_database import db
from factories import file_factory
from models.enums import CreatorUserRole
from models.model import Conversation, Message, MessageAgentThought, MessageFile
logger = logging.getLogger(__name__)
@ -114,9 +116,20 @@ class BaseAgentRunner(AppRunner):
features = model_schema.features if model_schema and model_schema.features else []
self.stream_tool_call = ModelFeature.STREAM_TOOL_CALL in features
self.files = application_generate_entity.files if ModelFeature.VISION in features else []
self.model_features = features
self.query: str | None = ""
self._current_thoughts: list[PromptMessage] = []
def build_execution_context(self) -> ExecutionContext:
"""Build execution context."""
return ExecutionContext(
user_id=self.user_id,
app_id=self.app_config.app_id,
conversation_id=self.conversation.id,
message_id=self.message.id,
tenant_id=self.tenant_id,
)
def _repack_app_generate_entity(
self, app_generate_entity: AgentChatAppGenerateEntity
) -> AgentChatAppGenerateEntity:
@ -289,6 +302,7 @@ class BaseAgentRunner(AppRunner):
thought = MessageAgentThought(
message_id=message_id,
message_chain_id=None,
tool_process_data=None,
thought="",
tool=tool_name,
tool_labels_str="{}",
@ -296,20 +310,20 @@ class BaseAgentRunner(AppRunner):
tool_input=tool_input,
message=message,
message_token=0,
message_unit_price=0,
message_price_unit=0,
message_unit_price=Decimal(0),
message_price_unit=Decimal("0.001"),
message_files=json.dumps(messages_ids) if messages_ids else "",
answer="",
observation="",
answer_token=0,
answer_unit_price=0,
answer_price_unit=0,
answer_unit_price=Decimal(0),
answer_price_unit=Decimal("0.001"),
tokens=0,
total_price=0,
total_price=Decimal(0),
position=self.agent_thought_count + 1,
currency="USD",
latency=0,
created_by_role="account",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=self.user_id,
)
@ -342,7 +356,8 @@ class BaseAgentRunner(AppRunner):
raise ValueError("agent thought not found")
if thought:
agent_thought.thought += thought
existing_thought = agent_thought.thought or ""
agent_thought.thought = f"{existing_thought}{thought}"
if tool_name:
agent_thought.tool = tool_name
@ -440,21 +455,30 @@ class BaseAgentRunner(AppRunner):
agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
if agent_thoughts:
for agent_thought in agent_thoughts:
tools = agent_thought.tool
if tools:
tools = tools.split(";")
tool_names_raw = agent_thought.tool
if tool_names_raw:
tool_names = tool_names_raw.split(";")
tool_calls: list[AssistantPromptMessage.ToolCall] = []
tool_call_response: list[ToolPromptMessage] = []
try:
tool_inputs = json.loads(agent_thought.tool_input)
except Exception:
tool_inputs = {tool: {} for tool in tools}
try:
tool_responses = json.loads(agent_thought.observation)
except Exception:
tool_responses = dict.fromkeys(tools, agent_thought.observation)
tool_input_payload = agent_thought.tool_input
if tool_input_payload:
try:
tool_inputs = json.loads(tool_input_payload)
except Exception:
tool_inputs = {tool: {} for tool in tool_names}
else:
tool_inputs = {tool: {} for tool in tool_names}
for tool in tools:
observation_payload = agent_thought.observation
if observation_payload:
try:
tool_responses = json.loads(observation_payload)
except Exception:
tool_responses = dict.fromkeys(tool_names, observation_payload)
else:
tool_responses = dict.fromkeys(tool_names, observation_payload)
for tool in tool_names:
# generate a uuid for tool call
tool_call_id = str(uuid.uuid4())
tool_calls.append(
@ -484,7 +508,7 @@ class BaseAgentRunner(AppRunner):
*tool_call_response,
]
)
if not tools:
if not tool_names_raw:
result.append(AssistantPromptMessage(content=agent_thought.thought))
else:
if message.answer:


@ -1,437 +0,0 @@
import json
import logging
from abc import ABC, abstractmethod
from collections.abc import Generator, Mapping, Sequence
from typing import Any
from core.agent.base_agent_runner import BaseAgentRunner
from core.agent.entities import AgentScratchpadUnit
from core.agent.output_parser.cot_output_parser import CotAgentOutputParser
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
PromptMessageTool,
ToolPromptMessage,
UserPromptMessage,
)
from core.ops.ops_trace_manager import TraceQueueManager
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
from core.workflow.nodes.agent.exc import AgentMaxIterationError
from models.model import Message
logger = logging.getLogger(__name__)
class CotAgentRunner(BaseAgentRunner, ABC):
_is_first_iteration = True
_ignore_observation_providers = ["wenxin"]
_historic_prompt_messages: list[PromptMessage]
_agent_scratchpad: list[AgentScratchpadUnit]
_instruction: str
_query: str
_prompt_messages_tools: Sequence[PromptMessageTool]
def run(
self,
message: Message,
query: str,
inputs: Mapping[str, str],
) -> Generator:
"""
Run Cot agent application
"""
app_generate_entity = self.application_generate_entity
self._repack_app_generate_entity(app_generate_entity)
self._init_react_state(query)
trace_manager = app_generate_entity.trace_manager
# check model mode
if "Observation" not in app_generate_entity.model_conf.stop:
if app_generate_entity.model_conf.provider not in self._ignore_observation_providers:
app_generate_entity.model_conf.stop.append("Observation")
app_config = self.app_config
assert app_config.agent
# init instruction
inputs = inputs or {}
instruction = app_config.prompt_template.simple_prompt_template or ""
self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)
iteration_step = 1
max_iteration_steps = min(app_config.agent.max_iteration, 99) + 1
# convert tools into ModelRuntime Tool format
tool_instances, prompt_messages_tools = self._init_prompt_tools()
self._prompt_messages_tools = prompt_messages_tools
function_call_state = True
llm_usage: dict[str, LLMUsage | None] = {"usage": None}
final_answer = ""
prompt_messages: list = [] # Initialize prompt_messages
agent_thought_id = "" # Initialize agent_thought_id
def increase_usage(final_llm_usage_dict: dict[str, LLMUsage | None], usage: LLMUsage):
if not final_llm_usage_dict["usage"]:
final_llm_usage_dict["usage"] = usage
else:
llm_usage = final_llm_usage_dict["usage"]
llm_usage.prompt_tokens += usage.prompt_tokens
llm_usage.completion_tokens += usage.completion_tokens
llm_usage.total_tokens += usage.total_tokens
llm_usage.prompt_price += usage.prompt_price
llm_usage.completion_price += usage.completion_price
llm_usage.total_price += usage.total_price
model_instance = self.model_instance
while function_call_state and iteration_step <= max_iteration_steps:
# continue to run until there is not any tool call
function_call_state = False
if iteration_step == max_iteration_steps:
# the last iteration, remove all tools
self._prompt_messages_tools = []
message_file_ids: list[str] = []
agent_thought_id = self.create_agent_thought(
message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
)
if iteration_step > 1:
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
)
# recalc llm max tokens
prompt_messages = self._organize_prompt_messages()
self.recalc_llm_max_tokens(self.model_config, prompt_messages)
# invoke model
chunks = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=app_generate_entity.model_conf.parameters,
tools=[],
stop=app_generate_entity.model_conf.stop,
stream=True,
user=self.user_id,
callbacks=[],
)
usage_dict: dict[str, LLMUsage | None] = {}
react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)
scratchpad = AgentScratchpadUnit(
agent_response="",
thought="",
action_str="",
observation="",
action=None,
)
# publish agent thought if it's first iteration
if iteration_step == 1:
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
)
for chunk in react_chunks:
if isinstance(chunk, AgentScratchpadUnit.Action):
action = chunk
# detect action
assert scratchpad.agent_response is not None
scratchpad.agent_response += json.dumps(chunk.model_dump())
scratchpad.action_str = json.dumps(chunk.model_dump())
scratchpad.action = action
else:
assert scratchpad.agent_response is not None
scratchpad.agent_response += chunk
assert scratchpad.thought is not None
scratchpad.thought += chunk
yield LLMResultChunk(
model=self.model_config.model,
prompt_messages=prompt_messages,
system_fingerprint="",
delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None),
)
assert scratchpad.thought is not None
scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
self._agent_scratchpad.append(scratchpad)
# Check if max iteration is reached and model still wants to call tools
if iteration_step == max_iteration_steps and scratchpad.action:
if scratchpad.action.action_name.lower() != "final answer":
raise AgentMaxIterationError(app_config.agent.max_iteration)
# get llm usage
if "usage" in usage_dict:
if usage_dict["usage"] is not None:
increase_usage(llm_usage, usage_dict["usage"])
else:
usage_dict["usage"] = LLMUsage.empty_usage()
self.save_agent_thought(
agent_thought_id=agent_thought_id,
tool_name=(scratchpad.action.action_name if scratchpad.action and not scratchpad.is_final() else ""),
tool_input={scratchpad.action.action_name: scratchpad.action.action_input} if scratchpad.action else {},
tool_invoke_meta={},
thought=scratchpad.thought or "",
observation="",
answer=scratchpad.agent_response or "",
messages_ids=[],
llm_usage=usage_dict["usage"],
)
if not scratchpad.is_final():
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
)
if not scratchpad.action:
# failed to extract action, return final answer directly
final_answer = ""
else:
if scratchpad.action.action_name.lower() == "final answer":
# action is final answer, return final answer directly
try:
if isinstance(scratchpad.action.action_input, dict):
final_answer = json.dumps(scratchpad.action.action_input, ensure_ascii=False)
elif isinstance(scratchpad.action.action_input, str):
final_answer = scratchpad.action.action_input
else:
final_answer = f"{scratchpad.action.action_input}"
except TypeError:
final_answer = f"{scratchpad.action.action_input}"
else:
function_call_state = True
# action is tool call, invoke tool
tool_invoke_response, tool_invoke_meta = self._handle_invoke_action(
action=scratchpad.action,
tool_instances=tool_instances,
message_file_ids=message_file_ids,
trace_manager=trace_manager,
)
scratchpad.observation = tool_invoke_response
scratchpad.agent_response = tool_invoke_response
self.save_agent_thought(
agent_thought_id=agent_thought_id,
tool_name=scratchpad.action.action_name,
tool_input={scratchpad.action.action_name: scratchpad.action.action_input},
thought=scratchpad.thought or "",
observation={scratchpad.action.action_name: tool_invoke_response},
tool_invoke_meta={scratchpad.action.action_name: tool_invoke_meta.to_dict()},
answer=scratchpad.agent_response,
messages_ids=message_file_ids,
llm_usage=usage_dict["usage"],
)
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
)
# update prompt tool message
for prompt_tool in self._prompt_messages_tools:
self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)
iteration_step += 1
yield LLMResultChunk(
model=model_instance.model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=0, message=AssistantPromptMessage(content=final_answer), usage=llm_usage["usage"]
),
system_fingerprint="",
)
# save agent thought
self.save_agent_thought(
agent_thought_id=agent_thought_id,
tool_name="",
tool_input={},
tool_invoke_meta={},
thought=final_answer,
observation={},
answer=final_answer,
messages_ids=[],
)
# publish end event
self.queue_manager.publish(
QueueMessageEndEvent(
llm_result=LLMResult(
model=model_instance.model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=final_answer),
usage=llm_usage["usage"] or LLMUsage.empty_usage(),
system_fingerprint="",
)
),
PublishFrom.APPLICATION_MANAGER,
)
def _handle_invoke_action(
self,
action: AgentScratchpadUnit.Action,
tool_instances: Mapping[str, Tool],
message_file_ids: list[str],
trace_manager: TraceQueueManager | None = None,
) -> tuple[str, ToolInvokeMeta]:
"""
handle invoke action
:param action: action
:param tool_instances: tool instances
:param message_file_ids: message file ids
:param trace_manager: trace manager
:return: observation, meta
"""
# action is tool call, invoke tool
tool_call_name = action.action_name
tool_call_args = action.action_input
tool_instance = tool_instances.get(tool_call_name)
if not tool_instance:
answer = f"there is not a tool named {tool_call_name}"
return answer, ToolInvokeMeta.error_instance(answer)
if isinstance(tool_call_args, str):
try:
tool_call_args = json.loads(tool_call_args)
except json.JSONDecodeError:
pass
# invoke tool
tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
tool=tool_instance,
tool_parameters=tool_call_args,
user_id=self.user_id,
tenant_id=self.tenant_id,
message=self.message,
invoke_from=self.application_generate_entity.invoke_from,
agent_tool_callback=self.agent_callback,
trace_manager=trace_manager,
)
# publish files
for message_file_id in message_files:
# publish message file
self.queue_manager.publish(
QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
)
# add message file ids
message_file_ids.append(message_file_id)
return tool_invoke_response, tool_invoke_meta
def _convert_dict_to_action(self, action: dict) -> AgentScratchpadUnit.Action:
"""
convert dict to action
"""
return AgentScratchpadUnit.Action(action_name=action["action"], action_input=action["action_input"])
def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: Mapping[str, Any]) -> str:
"""
fill in inputs from external data tools
"""
for key, value in inputs.items():
try:
instruction = instruction.replace(f"{{{{{key}}}}}", str(value))
except Exception:
continue
return instruction
def _init_react_state(self, query):
"""
init agent scratchpad
"""
self._query = query
self._agent_scratchpad = []
self._historic_prompt_messages = self._organize_historic_prompt_messages()
@abstractmethod
def _organize_prompt_messages(self) -> list[PromptMessage]:
"""
organize prompt messages
"""
def _format_assistant_message(self, agent_scratchpad: list[AgentScratchpadUnit]) -> str:
"""
format assistant message
"""
message = ""
for scratchpad in agent_scratchpad:
if scratchpad.is_final():
message += f"Final Answer: {scratchpad.agent_response}"
else:
message += f"Thought: {scratchpad.thought}\n\n"
if scratchpad.action_str:
message += f"Action: {scratchpad.action_str}\n\n"
if scratchpad.observation:
message += f"Observation: {scratchpad.observation}\n\n"
return message
def _organize_historic_prompt_messages(
self, current_session_messages: list[PromptMessage] | None = None
) -> list[PromptMessage]:
"""
organize historic prompt messages
"""
result: list[PromptMessage] = []
scratchpads: list[AgentScratchpadUnit] = []
current_scratchpad: AgentScratchpadUnit | None = None
for message in self.history_prompt_messages:
if isinstance(message, AssistantPromptMessage):
if not current_scratchpad:
assert isinstance(message.content, str)
current_scratchpad = AgentScratchpadUnit(
agent_response=message.content,
thought=message.content or "I am thinking about how to help you",
action_str="",
action=None,
observation=None,
)
scratchpads.append(current_scratchpad)
if message.tool_calls:
try:
current_scratchpad.action = AgentScratchpadUnit.Action(
action_name=message.tool_calls[0].function.name,
action_input=json.loads(message.tool_calls[0].function.arguments),
)
current_scratchpad.action_str = json.dumps(current_scratchpad.action.to_dict())
except Exception:
logger.exception("Failed to parse tool call from assistant message")
elif isinstance(message, ToolPromptMessage):
if current_scratchpad:
assert isinstance(message.content, str)
current_scratchpad.observation = message.content
else:
raise NotImplementedError("expected str type")
elif isinstance(message, UserPromptMessage):
if scratchpads:
result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))
scratchpads = []
current_scratchpad = None
result.append(message)
if scratchpads:
result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))
historic_prompts = AgentHistoryPromptTransform(
model_config=self.model_config,
prompt_messages=current_session_messages or [],
history_messages=result,
memory=self.memory,
).get_prompt()
return historic_prompts


@ -1,118 +0,0 @@
import json
from core.agent.cot_agent_runner import CotAgentRunner
from core.file import file_manager
from core.model_runtime.entities import (
AssistantPromptMessage,
PromptMessage,
SystemPromptMessage,
TextPromptMessageContent,
UserPromptMessage,
)
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from core.model_runtime.utils.encoders import jsonable_encoder
class CotChatAgentRunner(CotAgentRunner):
def _organize_system_prompt(self) -> SystemPromptMessage:
"""
Organize system prompt
"""
assert self.app_config.agent
assert self.app_config.agent.prompt
prompt_entity = self.app_config.agent.prompt
if not prompt_entity:
raise ValueError("Agent prompt configuration is not set")
first_prompt = prompt_entity.first_prompt
system_prompt = (
first_prompt.replace("{{instruction}}", self._instruction)
.replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools)))
.replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools]))
)
return SystemPromptMessage(content=system_prompt)
def _organize_user_query(self, query, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
Organize user query
"""
if self.files:
# get image detail config
image_detail_config = (
self.application_generate_entity.file_upload_config.image_config.detail
if (
self.application_generate_entity.file_upload_config
and self.application_generate_entity.file_upload_config.image_config
)
else None
)
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
for file in self.files:
prompt_message_contents.append(
file_manager.to_prompt_message_content(
file,
image_detail_config=image_detail_config,
)
)
prompt_message_contents.append(TextPromptMessageContent(data=query))
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else:
prompt_messages.append(UserPromptMessage(content=query))
return prompt_messages
def _organize_prompt_messages(self) -> list[PromptMessage]:
"""
Organize
"""
# organize system prompt
system_message = self._organize_system_prompt()
# organize current assistant messages
agent_scratchpad = self._agent_scratchpad
if not agent_scratchpad:
assistant_messages = []
else:
assistant_message = AssistantPromptMessage(content="")
assistant_message.content = "" # FIXME: type check tell mypy that assistant_message.content is str
for unit in agent_scratchpad:
if unit.is_final():
assert isinstance(assistant_message.content, str)
assistant_message.content += f"Final Answer: {unit.agent_response}"
else:
assert isinstance(assistant_message.content, str)
assistant_message.content += f"Thought: {unit.thought}\n\n"
if unit.action_str:
assistant_message.content += f"Action: {unit.action_str}\n\n"
if unit.observation:
assistant_message.content += f"Observation: {unit.observation}\n\n"
assistant_messages = [assistant_message]
# query messages
query_messages = self._organize_user_query(self._query, [])
if assistant_messages:
# organize historic prompt messages
historic_messages = self._organize_historic_prompt_messages(
[system_message, *query_messages, *assistant_messages, UserPromptMessage(content="continue")]
)
messages = [
system_message,
*historic_messages,
*query_messages,
*assistant_messages,
UserPromptMessage(content="continue"),
]
else:
# organize historic prompt messages
historic_messages = self._organize_historic_prompt_messages([system_message, *query_messages])
messages = [system_message, *historic_messages, *query_messages]
# join all messages
return messages


@ -1,87 +0,0 @@
import json
from core.agent.cot_agent_runner import CotAgentRunner
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
TextPromptMessageContent,
UserPromptMessage,
)
from core.model_runtime.utils.encoders import jsonable_encoder
class CotCompletionAgentRunner(CotAgentRunner):
def _organize_instruction_prompt(self) -> str:
"""
Organize instruction prompt
"""
if self.app_config.agent is None:
raise ValueError("Agent configuration is not set")
prompt_entity = self.app_config.agent.prompt
if prompt_entity is None:
raise ValueError("prompt entity is not set")
first_prompt = prompt_entity.first_prompt
system_prompt = (
first_prompt.replace("{{instruction}}", self._instruction)
.replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools)))
.replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools]))
)
return system_prompt
def _organize_historic_prompt(self, current_session_messages: list[PromptMessage] | None = None) -> str:
"""
Organize historic prompt
"""
historic_prompt_messages = self._organize_historic_prompt_messages(current_session_messages)
historic_prompt = ""
for message in historic_prompt_messages:
if isinstance(message, UserPromptMessage):
historic_prompt += f"Question: {message.content}\n\n"
elif isinstance(message, AssistantPromptMessage):
if isinstance(message.content, str):
historic_prompt += message.content + "\n\n"
elif isinstance(message.content, list):
for content in message.content:
if not isinstance(content, TextPromptMessageContent):
continue
historic_prompt += content.data
return historic_prompt
def _organize_prompt_messages(self) -> list[PromptMessage]:
"""
Organize prompt messages
"""
# organize system prompt
system_prompt = self._organize_instruction_prompt()
# organize historic prompt messages
historic_prompt = self._organize_historic_prompt()
# organize current assistant messages
agent_scratchpad = self._agent_scratchpad
assistant_prompt = ""
for unit in agent_scratchpad or []:
if unit.is_final():
assistant_prompt += f"Final Answer: {unit.agent_response}"
else:
assistant_prompt += f"Thought: {unit.thought}\n\n"
if unit.action_str:
assistant_prompt += f"Action: {unit.action_str}\n\n"
if unit.observation:
assistant_prompt += f"Observation: {unit.observation}\n\n"
# query messages
query_prompt = f"Question: {self._query}"
# join all messages
prompt = (
system_prompt.replace("{{historic_messages}}", historic_prompt)
.replace("{{agent_scratchpad}}", assistant_prompt)
.replace("{{query}}", query_prompt)
)
return [UserPromptMessage(content=prompt)]


@ -1,3 +1,5 @@
import uuid
from collections.abc import Mapping
from enum import StrEnum
from typing import Any, Union
@ -92,3 +94,96 @@ class AgentInvokeMessage(ToolInvokeMessage):
"""
pass
class ExecutionContext(BaseModel):
"""Execution context containing trace and audit information.
This context carries all the IDs and metadata that are not part of
the core business logic but needed for tracing, auditing, and
correlation purposes.
"""
user_id: str | None = None
app_id: str | None = None
conversation_id: str | None = None
message_id: str | None = None
tenant_id: str | None = None
@classmethod
def create_minimal(cls, user_id: str | None = None) -> "ExecutionContext":
"""Create a minimal context with only essential fields."""
return cls(user_id=user_id)
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary for passing to legacy code."""
return {
"user_id": self.user_id,
"app_id": self.app_id,
"conversation_id": self.conversation_id,
"message_id": self.message_id,
"tenant_id": self.tenant_id,
}
def with_updates(self, **kwargs) -> "ExecutionContext":
"""Create a new context with updated fields."""
data = self.to_dict()
data.update(kwargs)
return ExecutionContext(
user_id=data.get("user_id"),
app_id=data.get("app_id"),
conversation_id=data.get("conversation_id"),
message_id=data.get("message_id"),
tenant_id=data.get("tenant_id"),
)
class AgentLog(BaseModel):
"""
Agent Log.
"""
class LogType(StrEnum):
"""Type of agent log entry."""
ROUND = "round" # A complete iteration round
THOUGHT = "thought" # LLM thinking/reasoning
TOOL_CALL = "tool_call" # Tool invocation
class LogMetadata(StrEnum):
STARTED_AT = "started_at"
FINISHED_AT = "finished_at"
ELAPSED_TIME = "elapsed_time"
TOTAL_PRICE = "total_price"
TOTAL_TOKENS = "total_tokens"
PROVIDER = "provider"
CURRENCY = "currency"
LLM_USAGE = "llm_usage"
ICON = "icon"
ICON_DARK = "icon_dark"
class LogStatus(StrEnum):
START = "start"
ERROR = "error"
SUCCESS = "success"
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="The id of the log")
label: str = Field(..., description="The label of the log")
log_type: LogType = Field(..., description="The type of the log")
parent_id: str | None = Field(default=None, description="Leave empty for root log")
error: str | None = Field(default=None, description="The error message")
status: LogStatus = Field(..., description="The status of the log")
data: Mapping[str, Any] = Field(..., description="Detailed log data")
metadata: Mapping[LogMetadata, Any] = Field(default={}, description="The metadata of the log")
class AgentResult(BaseModel):
"""
Agent execution result.
"""
text: str = Field(default="", description="The generated text")
files: list[Any] = Field(default_factory=list, description="Files produced during execution")
usage: Any | None = Field(default=None, description="LLM usage statistics")
finish_reason: str | None = Field(default=None, description="Reason for completion")


@ -188,7 +188,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
),
)
assistant_message = AssistantPromptMessage(content="", tool_calls=[])
assistant_message = AssistantPromptMessage(content=response, tool_calls=[])
if tool_calls:
assistant_message.tool_calls = [
AssistantPromptMessage.ToolCall(
@ -200,8 +200,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):
)
for tool_call in tool_calls
]
else:
assistant_message.content = response
self._current_thoughts.append(assistant_message)


@ -0,0 +1,55 @@
# Agent Patterns
A unified agent pattern module that powers both Agent V2 workflow nodes and agent applications. Strategies share a common execution contract while adapting to model capabilities and tool availability.
## Overview
The module applies a strategy pattern around LLM/tool orchestration. `StrategyFactory` auto-selects the best implementation based on model features or an explicit agent strategy, and each strategy streams logs and usage consistently.
## Key Features
- **Dual strategies**
- `FunctionCallStrategy`: uses native LLM function/tool calling when the model exposes `TOOL_CALL`, `MULTI_TOOL_CALL`, or `STREAM_TOOL_CALL`.
- `ReActStrategy`: ReAct (reasoning + acting) flow driven by `CotAgentOutputParser`, used when function calling is unavailable or explicitly requested.
- **Explicit or auto selection**
- `StrategyFactory.create_strategy` prefers an explicit `AgentEntity.Strategy` (FUNCTION_CALLING or CHAIN_OF_THOUGHT).
- Otherwise it falls back to function calling when tool-call features exist, or ReAct when they do not.
- **Unified execution contract**
- `AgentPattern.run` yields streaming `AgentLog` entries and `LLMResultChunk` data, returning an `AgentResult` with text, files, usage, and `finish_reason`.
- Iterations are configurable and hard-capped at 99 rounds; the last round forces a final answer by withholding tools.
- **Tool handling and hooks**
- Tools convert to `PromptMessageTool` objects before invocation.
- Optional `tool_invoke_hook` lets callers override tool execution (e.g., agent apps) while workflow runs use `ToolEngine.generic_invoke`.
- Tool outputs support text, links, JSON, variables, blobs, retriever resources, and file attachments; `target=="self"` files are reloaded into model context, others are returned as outputs.
- **File-aware arguments**
- Tool args accept `[File: <id>]` or `[Files: <id1, id2>]` placeholders that resolve to `File` objects before invocation, enabling models to reference uploaded files safely.
- **ReAct prompt shaping**
- System prompts replace `{{instruction}}`, `{{tools}}`, and `{{tool_names}}` placeholders.
- Adds `Observation` to stop sequences and appends scratchpad text so the model sees prior Thought/Action/Observation history.
- **Observability and accounting**
- Standardized `AgentLog` entries for rounds, model thoughts, and tool calls, including usage aggregation (`LLMUsage`) across streaming and non-streaming paths.
## Architecture
```
agent/patterns/
├── base.py # Shared utilities: logging, usage, tool invocation, file handling
├── function_call.py # Native function-calling loop with tool execution
├── react.py # ReAct loop with CoT parsing and scratchpad wiring
└── strategy_factory.py # Strategy selection by model features or explicit override
```
## Usage
- For auto-selection:
- Call `StrategyFactory.create_strategy(model_features, model_instance, context, tools, files, ...)` and run the returned strategy with prompt messages and model params.
- For explicit behavior:
- Pass `agent_strategy=AgentEntity.Strategy.FUNCTION_CALLING` to force native calls (falls back to ReAct if unsupported), or `CHAIN_OF_THOUGHT` to force ReAct.
- Both strategies stream chunks and logs; collect the generator output until it returns an `AgentResult`.
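
A minimal sketch of this flow, assuming the caller already has a `ModelInstance`, its feature list, and instantiated tools (the `my_*` names below are placeholders, and additional arguments such as `max_iterations` or `instruction` may be required depending on the strategy):

```python
from core.agent.entities import AgentLog, AgentResult, ExecutionContext
from core.agent.patterns.strategy_factory import StrategyFactory
from core.model_runtime.entities import LLMResultChunk, UserPromptMessage

# Placeholders: the caller supplies a real ModelInstance, its ModelFeature list,
# and instantiated Tool objects.
strategy = StrategyFactory.create_strategy(
    model_features=my_model_features,
    model_instance=my_model_instance,
    tools=my_tools,
    files=[],
    context=ExecutionContext.create_minimal(user_id="user-1"),
)

generator = strategy.run(
    prompt_messages=[UserPromptMessage(content="What's the weather in Paris?")],
    model_parameters={"temperature": 0.7},
    stream=True,
)

# Consume streamed chunks and logs until the generator returns an AgentResult.
result: AgentResult | None = None
while True:
    try:
        output = next(generator)
    except StopIteration as stop:
        result = stop.value
        break
    if isinstance(output, LLMResultChunk):
        pass  # forward streamed text to the client
    elif isinstance(output, AgentLog):
        pass  # record round / thought / tool_call entries

assert result is not None
print(result.text, result.usage)
```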
## Integration Points
- **Model runtime**: delegates to `ModelInstance.invoke_llm` for both streaming and non-streaming calls.
- **Tool system**: defaults to `ToolEngine.generic_invoke`, with `tool_invoke_hook` for custom callers (a caller-side sketch follows this list).
- **Files**: flows through `File` objects for tool inputs/outputs and model-context attachments.
- **Execution context**: `ExecutionContext` fields (user/app/conversation/message) propagate to tool invocations and logging.
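
For custom callers, a hedged sketch of a hook matching the `ToolInvokeHook` contract (`(tool, tool_args, tool_name) -> (response, message_file_ids, meta)`); the body below is illustrative only, since a real hook would perform the actual invocation (for example, agent apps delegate to `ToolEngine.agent_invoke`):

```python
from typing import Any

from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ToolInvokeMeta


def my_tool_invoke_hook(
    tool: Tool, tool_args: dict[str, Any], tool_name: str
) -> tuple[str, list[str], ToolInvokeMeta]:
    # Illustrative only: echo the call instead of invoking the tool, return no
    # file IDs, and use an error-instance meta as a stand-in for real metadata.
    response = f"called {tool_name} with {tool_args}"
    message_file_ids: list[str] = []
    meta = ToolInvokeMeta.error_instance("sketch: tool was not actually invoked")
    return response, message_file_ids, meta

# Passed to the factory as:
# StrategyFactory.create_strategy(..., tool_invoke_hook=my_tool_invoke_hook)
```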


@ -0,0 +1,19 @@
"""Agent patterns module.
This module provides different strategies for agent execution:
- FunctionCallStrategy: Uses native function/tool calling
- ReActStrategy: Uses ReAct (Reasoning + Acting) approach
- StrategyFactory: Factory for creating strategies based on model features
"""
from .base import AgentPattern
from .function_call import FunctionCallStrategy
from .react import ReActStrategy
from .strategy_factory import StrategyFactory
__all__ = [
"AgentPattern",
"FunctionCallStrategy",
"ReActStrategy",
"StrategyFactory",
]


@ -0,0 +1,474 @@
"""Base class for agent strategies."""
from __future__ import annotations
import json
import re
import time
from abc import ABC, abstractmethod
from collections.abc import Callable, Generator
from typing import TYPE_CHECKING, Any
from core.agent.entities import AgentLog, AgentResult, ExecutionContext
from core.file import File
from core.model_manager import ModelInstance
from core.model_runtime.entities import (
AssistantPromptMessage,
LLMResult,
LLMResultChunk,
LLMResultChunkDelta,
PromptMessage,
PromptMessageTool,
)
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.message_entities import TextPromptMessageContent
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolInvokeMeta
if TYPE_CHECKING:
from core.tools.__base.tool import Tool
# Type alias for tool invoke hook
# Returns: (response_content, message_file_ids, tool_invoke_meta)
ToolInvokeHook = Callable[["Tool", dict[str, Any], str], tuple[str, list[str], ToolInvokeMeta]]
class AgentPattern(ABC):
"""Base class for agent execution strategies."""
def __init__(
self,
model_instance: ModelInstance,
tools: list[Tool],
context: ExecutionContext,
max_iterations: int = 10,
workflow_call_depth: int = 0,
files: list[File] = [],
tool_invoke_hook: ToolInvokeHook | None = None,
):
"""Initialize the agent strategy."""
self.model_instance = model_instance
self.tools = tools
self.context = context
self.max_iterations = min(max_iterations, 99) # Cap at 99 iterations
self.workflow_call_depth = workflow_call_depth
self.files: list[File] = files
self.tool_invoke_hook = tool_invoke_hook
@abstractmethod
def run(
self,
prompt_messages: list[PromptMessage],
model_parameters: dict[str, Any],
stop: list[str] = [],
stream: bool = True,
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
"""Execute the agent strategy."""
pass
def _accumulate_usage(self, total_usage: dict[str, Any], delta_usage: LLMUsage) -> None:
"""Accumulate LLM usage statistics."""
if not total_usage.get("usage"):
# Create a copy to avoid modifying the original
total_usage["usage"] = LLMUsage(
prompt_tokens=delta_usage.prompt_tokens,
prompt_unit_price=delta_usage.prompt_unit_price,
prompt_price_unit=delta_usage.prompt_price_unit,
prompt_price=delta_usage.prompt_price,
completion_tokens=delta_usage.completion_tokens,
completion_unit_price=delta_usage.completion_unit_price,
completion_price_unit=delta_usage.completion_price_unit,
completion_price=delta_usage.completion_price,
total_tokens=delta_usage.total_tokens,
total_price=delta_usage.total_price,
currency=delta_usage.currency,
latency=delta_usage.latency,
)
else:
current: LLMUsage = total_usage["usage"]
current.prompt_tokens += delta_usage.prompt_tokens
current.completion_tokens += delta_usage.completion_tokens
current.total_tokens += delta_usage.total_tokens
current.prompt_price += delta_usage.prompt_price
current.completion_price += delta_usage.completion_price
current.total_price += delta_usage.total_price
def _extract_content(self, content: Any) -> str:
"""Extract text content from message content."""
if isinstance(content, list):
# Content items are PromptMessageContentUnionTypes
text_parts = []
for c in content:
# Check if it's a TextPromptMessageContent (which has data attribute)
if isinstance(c, TextPromptMessageContent):
text_parts.append(c.data)
return "".join(text_parts)
return str(content)
def _has_tool_calls(self, chunk: LLMResultChunk) -> bool:
"""Check if chunk contains tool calls."""
# LLMResultChunk always has delta attribute
return bool(chunk.delta.message and chunk.delta.message.tool_calls)
def _has_tool_calls_result(self, result: LLMResult) -> bool:
"""Check if result contains tool calls (non-streaming)."""
# LLMResult always has message attribute
return bool(result.message and result.message.tool_calls)
def _extract_tool_calls(self, chunk: LLMResultChunk) -> list[tuple[str, str, dict[str, Any]]]:
"""Extract tool calls from streaming chunk."""
tool_calls: list[tuple[str, str, dict[str, Any]]] = []
if chunk.delta.message and chunk.delta.message.tool_calls:
for tool_call in chunk.delta.message.tool_calls:
if tool_call.function:
try:
args = json.loads(tool_call.function.arguments) if tool_call.function.arguments else {}
except json.JSONDecodeError:
args = {}
tool_calls.append((tool_call.id or "", tool_call.function.name, args))
return tool_calls
def _extract_tool_calls_result(self, result: LLMResult) -> list[tuple[str, str, dict[str, Any]]]:
"""Extract tool calls from non-streaming result."""
tool_calls = []
if result.message and result.message.tool_calls:
for tool_call in result.message.tool_calls:
if tool_call.function:
try:
args = json.loads(tool_call.function.arguments) if tool_call.function.arguments else {}
except json.JSONDecodeError:
args = {}
tool_calls.append((tool_call.id or "", tool_call.function.name, args))
return tool_calls
def _extract_text_from_message(self, message: PromptMessage) -> str:
"""Extract text content from a prompt message."""
# PromptMessage always has content attribute
content = message.content
if isinstance(content, str):
return content
elif isinstance(content, list):
# Extract text from content list
text_parts = []
for item in content:
if isinstance(item, TextPromptMessageContent):
text_parts.append(item.data)
return " ".join(text_parts)
return ""
def _get_tool_metadata(self, tool_instance: Tool) -> dict[AgentLog.LogMetadata, Any]:
"""Get metadata for a tool including provider and icon info."""
from core.tools.tool_manager import ToolManager
metadata: dict[AgentLog.LogMetadata, Any] = {}
if tool_instance.entity and tool_instance.entity.identity:
identity = tool_instance.entity.identity
if identity.provider:
metadata[AgentLog.LogMetadata.PROVIDER] = identity.provider
# Get icon using ToolManager for proper URL generation
tenant_id = self.context.tenant_id
if tenant_id and identity.provider:
try:
provider_type = tool_instance.tool_provider_type()
icon = ToolManager.get_tool_icon(tenant_id, provider_type, identity.provider)
if isinstance(icon, str):
metadata[AgentLog.LogMetadata.ICON] = icon
elif isinstance(icon, dict):
# Handle icon dict with background/content or light/dark variants
metadata[AgentLog.LogMetadata.ICON] = icon
except Exception:
# Fallback to identity.icon if ToolManager fails
if identity.icon:
metadata[AgentLog.LogMetadata.ICON] = identity.icon
elif identity.icon:
metadata[AgentLog.LogMetadata.ICON] = identity.icon
return metadata
def _create_log(
self,
label: str,
log_type: AgentLog.LogType,
status: AgentLog.LogStatus,
data: dict[str, Any] | None = None,
parent_id: str | None = None,
extra_metadata: dict[AgentLog.LogMetadata, Any] | None = None,
) -> AgentLog:
"""Create a new AgentLog with standard metadata."""
metadata: dict[AgentLog.LogMetadata, Any] = {
AgentLog.LogMetadata.STARTED_AT: time.perf_counter(),
}
if extra_metadata:
metadata.update(extra_metadata)
return AgentLog(
label=label,
log_type=log_type,
status=status,
data=data or {},
parent_id=parent_id,
metadata=metadata,
)
def _finish_log(
self,
log: AgentLog,
data: dict[str, Any] | None = None,
usage: LLMUsage | None = None,
) -> AgentLog:
"""Finish an AgentLog by updating its status and metadata."""
log.status = AgentLog.LogStatus.SUCCESS
if data is not None:
log.data = data
# Calculate elapsed time
started_at = log.metadata.get(AgentLog.LogMetadata.STARTED_AT, time.perf_counter())
finished_at = time.perf_counter()
# Update metadata
log.metadata = {
**log.metadata,
AgentLog.LogMetadata.FINISHED_AT: finished_at,
# Calculate elapsed time in seconds
AgentLog.LogMetadata.ELAPSED_TIME: round(finished_at - started_at, 4),
}
# Add usage information if provided
if usage:
log.metadata.update(
{
AgentLog.LogMetadata.TOTAL_PRICE: usage.total_price,
AgentLog.LogMetadata.CURRENCY: usage.currency,
AgentLog.LogMetadata.TOTAL_TOKENS: usage.total_tokens,
AgentLog.LogMetadata.LLM_USAGE: usage,
}
)
return log
def _replace_file_references(self, tool_args: dict[str, Any]) -> dict[str, Any]:
"""
Replace file references in tool arguments with actual File objects.
Args:
tool_args: Dictionary of tool arguments
Returns:
Updated tool arguments with file references replaced
"""
# Process each argument in the dictionary
processed_args: dict[str, Any] = {}
for key, value in tool_args.items():
processed_args[key] = self._process_file_reference(value)
return processed_args
def _process_file_reference(self, data: Any) -> Any:
"""
Recursively process data to replace file references.
Supports both single file [File: file_id] and multiple files [Files: file_id1, file_id2, ...].
Args:
data: The data to process (can be dict, list, str, or other types)
Returns:
Processed data with file references replaced
"""
single_file_pattern = re.compile(r"^\[File:\s*([^\]]+)\]$")
multiple_files_pattern = re.compile(r"^\[Files:\s*([^\]]+)\]$")
if isinstance(data, dict):
# Process dictionary recursively
return {key: self._process_file_reference(value) for key, value in data.items()}
elif isinstance(data, list):
# Process list recursively
return [self._process_file_reference(item) for item in data]
elif isinstance(data, str):
# Check for single file pattern [File: file_id]
single_match = single_file_pattern.match(data.strip())
if single_match:
file_id = single_match.group(1).strip()
# Find the file in self.files
for file in self.files:
if file.id and str(file.id) == file_id:
return file
# If file not found, return original value
return data
# Check for multiple files pattern [Files: file_id1, file_id2, ...]
multiple_match = multiple_files_pattern.match(data.strip())
if multiple_match:
file_ids_str = multiple_match.group(1).strip()
# Split by comma and strip whitespace
file_ids = [fid.strip() for fid in file_ids_str.split(",")]
# Find all matching files
matched_files: list[File] = []
for file_id in file_ids:
for file in self.files:
if file.id and str(file.id) == file_id:
matched_files.append(file)
break
# Return list of files if any were found, otherwise return original
return matched_files or data
return data
else:
# Return other types as-is
return data
def _create_text_chunk(self, text: str, prompt_messages: list[PromptMessage]) -> LLMResultChunk:
"""Create a text chunk for streaming."""
return LLMResultChunk(
model=self.model_instance.model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content=text),
usage=None,
),
system_fingerprint="",
)
def _invoke_tool(
self,
tool_instance: Tool,
tool_args: dict[str, Any],
tool_name: str,
) -> tuple[str, list[File], ToolInvokeMeta | None]:
"""
Invoke a tool and collect its response.
Args:
tool_instance: The tool instance to invoke
tool_args: Tool arguments
tool_name: Name of the tool
Returns:
Tuple of (response_content, tool_files, tool_invoke_meta)
"""
# Process tool_args to replace file references with actual File objects
tool_args = self._replace_file_references(tool_args)
# If a tool invoke hook is set, use it instead of generic_invoke
if self.tool_invoke_hook:
response_content, _, tool_invoke_meta = self.tool_invoke_hook(tool_instance, tool_args, tool_name)
# Note: message_file_ids are stored in DB, we don't convert them to File objects here
# The caller (AgentAppRunner) handles file publishing
return response_content, [], tool_invoke_meta
# Default: use generic_invoke for workflow scenarios
# Import here to avoid circular import
from core.tools.tool_engine import DifyWorkflowCallbackHandler, ToolEngine
tool_response = ToolEngine().generic_invoke(
tool=tool_instance,
tool_parameters=tool_args,
user_id=self.context.user_id or "",
workflow_tool_callback=DifyWorkflowCallbackHandler(),
workflow_call_depth=self.workflow_call_depth,
app_id=self.context.app_id,
conversation_id=self.context.conversation_id,
message_id=self.context.message_id,
)
# Collect response and files
response_content = ""
tool_files: list[File] = []
for response in tool_response:
if response.type == ToolInvokeMessage.MessageType.TEXT:
assert isinstance(response.message, ToolInvokeMessage.TextMessage)
response_content += response.message.text
elif response.type == ToolInvokeMessage.MessageType.LINK:
# Handle link messages
if isinstance(response.message, ToolInvokeMessage.TextMessage):
response_content += f"[Link: {response.message.text}]"
elif response.type == ToolInvokeMessage.MessageType.IMAGE:
# Handle image URL messages
if isinstance(response.message, ToolInvokeMessage.TextMessage):
response_content += f"[Image: {response.message.text}]"
elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK:
# Handle image link messages
if isinstance(response.message, ToolInvokeMessage.TextMessage):
response_content += f"[Image: {response.message.text}]"
elif response.type == ToolInvokeMessage.MessageType.BINARY_LINK:
# Handle binary file link messages
if isinstance(response.message, ToolInvokeMessage.TextMessage):
filename = response.meta.get("filename", "file") if response.meta else "file"
response_content += f"[File: {filename} - {response.message.text}]"
elif response.type == ToolInvokeMessage.MessageType.JSON:
# Handle JSON messages
if isinstance(response.message, ToolInvokeMessage.JsonMessage):
response_content += json.dumps(response.message.json_object, ensure_ascii=False, indent=2)
elif response.type == ToolInvokeMessage.MessageType.BLOB:
# Handle blob messages - convert to text representation
if isinstance(response.message, ToolInvokeMessage.BlobMessage):
mime_type = (
response.meta.get("mime_type", "application/octet-stream")
if response.meta
else "application/octet-stream"
)
size = len(response.message.blob)
response_content += f"[Binary data: {mime_type}, size: {size} bytes]"
elif response.type == ToolInvokeMessage.MessageType.VARIABLE:
# Handle variable messages
if isinstance(response.message, ToolInvokeMessage.VariableMessage):
var_name = response.message.variable_name
var_value = response.message.variable_value
if isinstance(var_value, str):
response_content += var_value
else:
response_content += f"[Variable {var_name}: {json.dumps(var_value, ensure_ascii=False)}]"
elif response.type == ToolInvokeMessage.MessageType.BLOB_CHUNK:
# Handle blob chunk messages - these are parts of a larger blob
if isinstance(response.message, ToolInvokeMessage.BlobChunkMessage):
response_content += f"[Blob chunk {response.message.sequence}: {len(response.message.blob)} bytes]"
elif response.type == ToolInvokeMessage.MessageType.RETRIEVER_RESOURCES:
# Handle retriever resources messages
if isinstance(response.message, ToolInvokeMessage.RetrieverResourceMessage):
response_content += response.message.context
elif response.type == ToolInvokeMessage.MessageType.FILE:
# Extract file from meta
if response.meta and "file" in response.meta:
file = response.meta["file"]
if isinstance(file, File):
# Check if file is for model or tool output
if response.meta.get("target") == "self":
# File is for model - add to files for next prompt
self.files.append(file)
response_content += f"File '{file.filename}' has been loaded into your context."
else:
# File is tool output
tool_files.append(file)
return response_content, tool_files, None
def _find_tool_by_name(self, tool_name: str) -> Tool | None:
"""Find a tool instance by its name."""
for tool in self.tools:
if tool.entity.identity.name == tool_name:
return tool
return None
def _convert_tools_to_prompt_format(self) -> list[PromptMessageTool]:
"""Convert tools to prompt message format."""
prompt_tools: list[PromptMessageTool] = []
for tool in self.tools:
prompt_tools.append(tool.to_prompt_message_tool())
return prompt_tools
def _update_usage_with_empty(self, llm_usage: dict[str, Any]) -> None:
"""Initialize usage tracking with empty usage if not set."""
if "usage" not in llm_usage or llm_usage["usage"] is None:
llm_usage["usage"] = LLMUsage.empty_usage()

View File

@ -0,0 +1,299 @@
"""Function Call strategy implementation."""
import json
from collections.abc import Generator
from typing import Any, Union
from core.agent.entities import AgentLog, AgentResult
from core.file import File
from core.model_runtime.entities import (
AssistantPromptMessage,
LLMResult,
LLMResultChunk,
LLMResultChunkDelta,
LLMUsage,
PromptMessage,
PromptMessageTool,
ToolPromptMessage,
)
from core.tools.entities.tool_entities import ToolInvokeMeta
from .base import AgentPattern
class FunctionCallStrategy(AgentPattern):
"""Function Call strategy using model's native tool calling capability."""
def run(
self,
prompt_messages: list[PromptMessage],
model_parameters: dict[str, Any],
stop: list[str] = [],
stream: bool = True,
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
"""Execute the function call agent strategy."""
# Convert tools to prompt format
prompt_tools: list[PromptMessageTool] = self._convert_tools_to_prompt_format()
# Initialize tracking
iteration_step: int = 1
max_iterations: int = self.max_iterations + 1
function_call_state: bool = True
total_usage: dict[str, LLMUsage | None] = {"usage": None}
messages: list[PromptMessage] = list(prompt_messages) # Create mutable copy
final_text: str = ""
finish_reason: str | None = None
output_files: list[File] = [] # Track files produced by tools
while function_call_state and iteration_step <= max_iterations:
function_call_state = False
round_log = self._create_log(
label=f"ROUND {iteration_step}",
log_type=AgentLog.LogType.ROUND,
status=AgentLog.LogStatus.START,
data={},
)
yield round_log
# On last iteration, remove tools to force final answer
current_tools: list[PromptMessageTool] = [] if iteration_step == max_iterations else prompt_tools
model_log = self._create_log(
label=f"{self.model_instance.model} Thought",
log_type=AgentLog.LogType.THOUGHT,
status=AgentLog.LogStatus.START,
data={},
parent_id=round_log.id,
extra_metadata={
AgentLog.LogMetadata.PROVIDER: self.model_instance.provider,
},
)
yield model_log
# Track usage for this round only
round_usage: dict[str, LLMUsage | None] = {"usage": None}
# Invoke model
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = self.model_instance.invoke_llm(
prompt_messages=messages,
model_parameters=model_parameters,
tools=current_tools,
stop=stop,
stream=stream,
user=self.context.user_id,
callbacks=[],
)
# Process response
tool_calls, response_content, chunk_finish_reason = yield from self._handle_chunks(
chunks, round_usage, model_log
)
messages.append(self._create_assistant_message(response_content, tool_calls))
# Accumulate to total usage
round_usage_value = round_usage.get("usage")
if round_usage_value:
self._accumulate_usage(total_usage, round_usage_value)
# Update final text if no tool calls (this is likely the final answer)
if not tool_calls:
final_text = response_content
# Update finish reason
if chunk_finish_reason:
finish_reason = chunk_finish_reason
# Process tool calls
tool_outputs: dict[str, str] = {}
if tool_calls:
function_call_state = True
# Execute tools
for tool_call_id, tool_name, tool_args in tool_calls:
tool_response, tool_files, _ = yield from self._handle_tool_call(
tool_name, tool_args, tool_call_id, messages, round_log
)
tool_outputs[tool_name] = tool_response
# Track files produced by tools
output_files.extend(tool_files)
yield self._finish_log(
round_log,
data={
"llm_result": response_content,
"tool_calls": [
{"name": tc[1], "args": tc[2], "output": tool_outputs.get(tc[1], "")} for tc in tool_calls
]
if tool_calls
else [],
"final_answer": final_text if not function_call_state else None,
},
usage=round_usage.get("usage"),
)
iteration_step += 1
# Return final result
from core.agent.entities import AgentResult
return AgentResult(
text=final_text,
files=output_files,
usage=total_usage.get("usage") or LLMUsage.empty_usage(),
finish_reason=finish_reason,
)
def _handle_chunks(
self,
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult],
llm_usage: dict[str, LLMUsage | None],
start_log: AgentLog,
) -> Generator[
LLMResultChunk | AgentLog,
None,
tuple[list[tuple[str, str, dict[str, Any]]], str, str | None],
]:
"""Handle LLM response chunks and extract tool calls and content.
Returns a tuple of (tool_calls, response_content, finish_reason).
"""
tool_calls: list[tuple[str, str, dict[str, Any]]] = []
response_content: str = ""
finish_reason: str | None = None
if isinstance(chunks, Generator):
# Streaming response
for chunk in chunks:
# Extract tool calls
if self._has_tool_calls(chunk):
tool_calls.extend(self._extract_tool_calls(chunk))
# Extract content
if chunk.delta.message and chunk.delta.message.content:
response_content += self._extract_content(chunk.delta.message.content)
# Track usage
if chunk.delta.usage:
self._accumulate_usage(llm_usage, chunk.delta.usage)
# Capture finish reason
if chunk.delta.finish_reason:
finish_reason = chunk.delta.finish_reason
yield chunk
else:
# Non-streaming response
result: LLMResult = chunks
if self._has_tool_calls_result(result):
tool_calls.extend(self._extract_tool_calls_result(result))
if result.message and result.message.content:
response_content += self._extract_content(result.message.content)
if result.usage:
self._accumulate_usage(llm_usage, result.usage)
# Convert to streaming format
yield LLMResultChunk(
model=result.model,
prompt_messages=result.prompt_messages,
delta=LLMResultChunkDelta(index=0, message=result.message, usage=result.usage),
)
yield self._finish_log(
start_log,
data={
"result": response_content,
},
usage=llm_usage.get("usage"),
)
return tool_calls, response_content, finish_reason
def _create_assistant_message(
self, content: str, tool_calls: list[tuple[str, str, dict[str, Any]]] | None = None
) -> AssistantPromptMessage:
"""Create assistant message with tool calls."""
if tool_calls is None:
return AssistantPromptMessage(content=content)
return AssistantPromptMessage(
content=content or "",
tool_calls=[
AssistantPromptMessage.ToolCall(
id=tc[0],
type="function",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tc[1], arguments=json.dumps(tc[2])),
)
for tc in tool_calls
],
)
def _handle_tool_call(
self,
tool_name: str,
tool_args: dict[str, Any],
tool_call_id: str,
messages: list[PromptMessage],
round_log: AgentLog,
) -> Generator[AgentLog, None, tuple[str, list[File], ToolInvokeMeta | None]]:
"""Handle a single tool call and return response with files and meta."""
# Find tool
tool_instance = self._find_tool_by_name(tool_name)
if not tool_instance:
raise ValueError(f"Tool {tool_name} not found")
# Get tool metadata (provider, icon, etc.)
tool_metadata = self._get_tool_metadata(tool_instance)
# Create tool call log
tool_call_log = self._create_log(
label=f"CALL {tool_name}",
log_type=AgentLog.LogType.TOOL_CALL,
status=AgentLog.LogStatus.START,
data={
"tool_call_id": tool_call_id,
"tool_name": tool_name,
"tool_args": tool_args,
},
parent_id=round_log.id,
extra_metadata=tool_metadata,
)
yield tool_call_log
# Invoke tool using base class method with error handling
try:
response_content, tool_files, tool_invoke_meta = self._invoke_tool(tool_instance, tool_args, tool_name)
yield self._finish_log(
tool_call_log,
data={
**tool_call_log.data,
"output": response_content,
"files": len(tool_files),
"meta": tool_invoke_meta.to_dict() if tool_invoke_meta else None,
},
)
final_content = response_content or "Tool executed successfully"
# Add tool response to messages
messages.append(
ToolPromptMessage(
content=final_content,
tool_call_id=tool_call_id,
name=tool_name,
)
)
return response_content, tool_files, tool_invoke_meta
except Exception as e:
# Tool invocation failed, yield error log
error_message = str(e)
tool_call_log.status = AgentLog.LogStatus.ERROR
tool_call_log.error = error_message
tool_call_log.data = {
**tool_call_log.data,
"error": error_message,
}
yield tool_call_log
# Add error message to conversation
error_content = f"Tool execution failed: {error_message}"
messages.append(
ToolPromptMessage(
content=error_content,
tool_call_id=tool_call_id,
name=tool_name,
)
)
return error_content, [], None
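
The strategies stream `LLMResultChunk`/`AgentLog` items and carry the final `AgentResult` as the generator's return value. A minimal sketch of a caller capturing that return value (the `drive` helper is illustrative, not part of this changeset):

```python
from core.agent.entities import AgentLog, AgentResult


def drive(pattern, prompt_messages, model_parameters) -> AgentResult:
    """Hypothetical driver: forward streamed items and capture run()'s return value."""
    gen = pattern.run(prompt_messages=prompt_messages, model_parameters=model_parameters, stream=True)
    try:
        while True:
            item = next(gen)
            if isinstance(item, AgentLog):
                pass  # publish to the agent log queue
            else:
                pass  # forward the LLMResultChunk to the client
    except StopIteration as stop:
        return stop.value  # the AgentResult returned at the end of run()
```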

View File

@ -0,0 +1,418 @@
"""ReAct strategy implementation."""
from __future__ import annotations
import json
from collections.abc import Generator
from typing import TYPE_CHECKING, Any, Union
from core.agent.entities import AgentLog, AgentResult, AgentScratchpadUnit, ExecutionContext
from core.agent.output_parser.cot_output_parser import CotAgentOutputParser
from core.file import File
from core.model_manager import ModelInstance
from core.model_runtime.entities import (
AssistantPromptMessage,
LLMResult,
LLMResultChunk,
LLMResultChunkDelta,
LLMUsage,
PromptMessage,
SystemPromptMessage,
)
from .base import AgentPattern, ToolInvokeHook
if TYPE_CHECKING:
from core.tools.__base.tool import Tool
class ReActStrategy(AgentPattern):
"""ReAct strategy using reasoning and acting approach."""
def __init__(
self,
model_instance: ModelInstance,
tools: list[Tool],
context: ExecutionContext,
max_iterations: int = 10,
workflow_call_depth: int = 0,
files: list[File] | None = None,
tool_invoke_hook: ToolInvokeHook | None = None,
instruction: str = "",
):
"""Initialize the ReAct strategy with instruction support."""
super().__init__(
model_instance=model_instance,
tools=tools,
context=context,
max_iterations=max_iterations,
workflow_call_depth=workflow_call_depth,
files=files,
tool_invoke_hook=tool_invoke_hook,
)
self.instruction = instruction
def run(
self,
prompt_messages: list[PromptMessage],
model_parameters: dict[str, Any],
stop: list[str] = [],
stream: bool = True,
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
"""Execute the ReAct agent strategy."""
# Initialize tracking
agent_scratchpad: list[AgentScratchpadUnit] = []
iteration_step: int = 1
max_iterations: int = self.max_iterations + 1
react_state: bool = True
total_usage: dict[str, Any] = {"usage": None}
output_files: list[File] = [] # Track files produced by tools
final_text: str = ""
finish_reason: str | None = None
# Add "Observation" to stop sequences
if "Observation" not in stop:
stop = stop.copy()
stop.append("Observation")
while react_state and iteration_step <= max_iterations:
react_state = False
round_log = self._create_log(
label=f"ROUND {iteration_step}",
log_type=AgentLog.LogType.ROUND,
status=AgentLog.LogStatus.START,
data={},
)
yield round_log
# Build prompt with/without tools based on iteration
include_tools = iteration_step < max_iterations
current_messages = self._build_prompt_with_react_format(
prompt_messages, agent_scratchpad, include_tools, self.instruction
)
model_log = self._create_log(
label=f"{self.model_instance.model} Thought",
log_type=AgentLog.LogType.THOUGHT,
status=AgentLog.LogStatus.START,
data={},
parent_id=round_log.id,
extra_metadata={
AgentLog.LogMetadata.PROVIDER: self.model_instance.provider,
},
)
yield model_log
# Track usage for this round only
round_usage: dict[str, Any] = {"usage": None}
# Use current messages directly (files are handled by base class if needed)
messages_to_use = current_messages
# Invoke model
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = self.model_instance.invoke_llm(
prompt_messages=messages_to_use,
model_parameters=model_parameters,
stop=stop,
stream=stream,
user=self.context.user_id or "",
callbacks=[],
)
# Process response
scratchpad, chunk_finish_reason = yield from self._handle_chunks(
chunks, round_usage, model_log, current_messages
)
agent_scratchpad.append(scratchpad)
# Accumulate to total usage
round_usage_value = round_usage.get("usage")
if round_usage_value:
self._accumulate_usage(total_usage, round_usage_value)
# Update finish reason
if chunk_finish_reason:
finish_reason = chunk_finish_reason
# Check if we have an action to execute
if scratchpad.action and scratchpad.action.action_name.lower() != "final answer":
react_state = True
# Execute tool
observation, tool_files = yield from self._handle_tool_call(
scratchpad.action, current_messages, round_log
)
scratchpad.observation = observation
# Track files produced by tools
output_files.extend(tool_files)
# Add observation to scratchpad for display
yield self._create_text_chunk(f"\nObservation: {observation}\n", current_messages)
else:
# Extract final answer
if scratchpad.action and scratchpad.action.action_input:
final_answer = scratchpad.action.action_input
if isinstance(final_answer, dict):
final_answer = json.dumps(final_answer, ensure_ascii=False)
final_text = str(final_answer)
elif scratchpad.thought:
# If no action but we have thought, use thought as final answer
final_text = scratchpad.thought
yield self._finish_log(
round_log,
data={
"thought": scratchpad.thought,
"action": scratchpad.action_str if scratchpad.action else None,
"observation": scratchpad.observation or None,
"final_answer": final_text if not react_state else None,
},
usage=round_usage.get("usage"),
)
iteration_step += 1
# Return final result
from core.agent.entities import AgentResult
return AgentResult(
text=final_text, files=output_files, usage=total_usage.get("usage") or LLMUsage.empty_usage(), finish_reason=finish_reason
)
def _build_prompt_with_react_format(
self,
original_messages: list[PromptMessage],
agent_scratchpad: list[AgentScratchpadUnit],
include_tools: bool = True,
instruction: str = "",
) -> list[PromptMessage]:
"""Build prompt messages with ReAct format."""
# Copy messages to avoid modifying original
messages = list(original_messages)
# Find and update the system prompt that should already exist
system_prompt_found = False
for i, msg in enumerate(messages):
if isinstance(msg, SystemPromptMessage):
system_prompt_found = True
# The system prompt from frontend already has the template, just replace placeholders
# Format tools
tools_str = ""
tool_names = []
if include_tools and self.tools:
# Convert tools to prompt message tools format
prompt_tools = [tool.to_prompt_message_tool() for tool in self.tools]
tool_names = [tool.name for tool in prompt_tools]
# Format tools as JSON for comprehensive information
from core.model_runtime.utils.encoders import jsonable_encoder
tools_str = json.dumps(jsonable_encoder(prompt_tools), indent=2)
tool_names_str = ", ".join(f'"{name}"' for name in tool_names)
else:
tools_str = "No tools available"
tool_names_str = ""
# Replace placeholders in the existing system prompt
updated_content = msg.content
assert isinstance(updated_content, str)
updated_content = updated_content.replace("{{instruction}}", instruction)
updated_content = updated_content.replace("{{tools}}", tools_str)
updated_content = updated_content.replace("{{tool_names}}", tool_names_str)
# Create new SystemPromptMessage with updated content
messages[i] = SystemPromptMessage(content=updated_content)
break
# If no system prompt found, that's unexpected but add scratchpad anyway
if not system_prompt_found:
# This shouldn't happen if frontend is working correctly
pass
# Format agent scratchpad
scratchpad_str = ""
if agent_scratchpad:
scratchpad_parts: list[str] = []
for unit in agent_scratchpad:
if unit.thought:
scratchpad_parts.append(f"Thought: {unit.thought}")
if unit.action_str:
scratchpad_parts.append(f"Action:\n```\n{unit.action_str}\n```")
if unit.observation:
scratchpad_parts.append(f"Observation: {unit.observation}")
scratchpad_str = "\n".join(scratchpad_parts)
# If there's a scratchpad, append it to the last message
if scratchpad_str:
messages.append(AssistantPromptMessage(content=scratchpad_str))
return messages
def _handle_chunks(
self,
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult],
llm_usage: dict[str, Any],
model_log: AgentLog,
current_messages: list[PromptMessage],
) -> Generator[
LLMResultChunk | AgentLog,
None,
tuple[AgentScratchpadUnit, str | None],
]:
"""Handle LLM response chunks and extract action/thought.
Returns a tuple of (scratchpad_unit, finish_reason).
"""
usage_dict: dict[str, Any] = {}
# Convert non-streaming to streaming format if needed
if isinstance(chunks, LLMResult):
# Create a generator from the LLMResult
def result_to_chunks() -> Generator[LLMResultChunk, None, None]:
yield LLMResultChunk(
model=chunks.model,
prompt_messages=chunks.prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=chunks.message,
usage=chunks.usage,
finish_reason=None, # LLMResult doesn't have finish_reason, only streaming chunks do
),
system_fingerprint=chunks.system_fingerprint or "",
)
streaming_chunks = result_to_chunks()
else:
streaming_chunks = chunks
react_chunks = CotAgentOutputParser.handle_react_stream_output(streaming_chunks, usage_dict)
# Initialize scratchpad unit
scratchpad = AgentScratchpadUnit(
agent_response="",
thought="",
action_str="",
observation="",
action=None,
)
finish_reason: str | None = None
# Process chunks
for chunk in react_chunks:
if isinstance(chunk, AgentScratchpadUnit.Action):
# Action detected
action_str = json.dumps(chunk.model_dump())
scratchpad.agent_response = (scratchpad.agent_response or "") + action_str
scratchpad.action_str = action_str
scratchpad.action = chunk
yield self._create_text_chunk(json.dumps(chunk.model_dump()), current_messages)
else:
# Text chunk
chunk_text = str(chunk)
scratchpad.agent_response = (scratchpad.agent_response or "") + chunk_text
scratchpad.thought = (scratchpad.thought or "") + chunk_text
yield self._create_text_chunk(chunk_text, current_messages)
# Update usage
if usage_dict.get("usage"):
if llm_usage.get("usage"):
self._accumulate_usage(llm_usage, usage_dict["usage"])
else:
llm_usage["usage"] = usage_dict["usage"]
# Clean up thought
scratchpad.thought = (scratchpad.thought or "").strip() or "I am thinking about how to help you"
# Finish model log
yield self._finish_log(
model_log,
data={
"thought": scratchpad.thought,
"action": scratchpad.action_str if scratchpad.action else None,
},
usage=llm_usage.get("usage"),
)
return scratchpad, finish_reason
def _handle_tool_call(
self,
action: AgentScratchpadUnit.Action,
prompt_messages: list[PromptMessage],
round_log: AgentLog,
) -> Generator[AgentLog, None, tuple[str, list[File]]]:
"""Handle tool call and return observation with files."""
tool_name = action.action_name
tool_args: dict[str, Any] | str = action.action_input
# Find tool instance first to get metadata
tool_instance = self._find_tool_by_name(tool_name)
tool_metadata = self._get_tool_metadata(tool_instance) if tool_instance else {}
# Start tool log with tool metadata
tool_log = self._create_log(
label=f"CALL {tool_name}",
log_type=AgentLog.LogType.TOOL_CALL,
status=AgentLog.LogStatus.START,
data={
"tool_name": tool_name,
"tool_args": tool_args,
},
parent_id=round_log.id,
extra_metadata=tool_metadata,
)
yield tool_log
if not tool_instance:
# Finish tool log with error
yield self._finish_log(
tool_log,
data={
**tool_log.data,
"error": f"Tool {tool_name} not found",
},
)
return f"Tool {tool_name} not found", []
# Ensure tool_args is a dict
tool_args_dict: dict[str, Any]
if isinstance(tool_args, str):
try:
tool_args_dict = json.loads(tool_args)
except json.JSONDecodeError:
tool_args_dict = {"input": tool_args}
elif not isinstance(tool_args, dict):
tool_args_dict = {"input": str(tool_args)}
else:
tool_args_dict = tool_args
# Invoke tool using base class method with error handling
try:
response_content, tool_files, tool_invoke_meta = self._invoke_tool(tool_instance, tool_args_dict, tool_name)
# Finish tool log
yield self._finish_log(
tool_log,
data={
**tool_log.data,
"output": response_content,
"files": len(tool_files),
"meta": tool_invoke_meta.to_dict() if tool_invoke_meta else None,
},
)
return response_content or "Tool executed successfully", tool_files
except Exception as e:
# Tool invocation failed, yield error log
error_message = str(e)
tool_log.status = AgentLog.LogStatus.ERROR
tool_log.error = error_message
tool_log.data = {
**tool_log.data,
"error": error_message,
}
yield tool_log
return f"Tool execution failed: {error_message}", []

View File

@ -0,0 +1,107 @@
"""Strategy factory for creating agent strategies."""
from __future__ import annotations
from typing import TYPE_CHECKING
from core.agent.entities import AgentEntity, ExecutionContext
from core.file.models import File
from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelFeature
from .base import AgentPattern, ToolInvokeHook
from .function_call import FunctionCallStrategy
from .react import ReActStrategy
if TYPE_CHECKING:
from core.tools.__base.tool import Tool
class StrategyFactory:
"""Factory for creating agent strategies based on model features."""
# Tool calling related features
TOOL_CALL_FEATURES = {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL}
@staticmethod
def create_strategy(
model_features: list[ModelFeature],
model_instance: ModelInstance,
context: ExecutionContext,
tools: list[Tool],
files: list[File],
max_iterations: int = 10,
workflow_call_depth: int = 0,
agent_strategy: AgentEntity.Strategy | None = None,
tool_invoke_hook: ToolInvokeHook | None = None,
instruction: str = "",
) -> AgentPattern:
"""
Create an appropriate strategy based on model features.
Args:
model_features: List of model features/capabilities
model_instance: Model instance to use
context: Execution context containing trace/audit information
tools: Available tools
files: Available files
max_iterations: Maximum iterations for the strategy
workflow_call_depth: Depth of workflow calls
agent_strategy: Optional explicit strategy override
tool_invoke_hook: Optional hook for custom tool invocation (e.g., agent_invoke)
instruction: Optional instruction for ReAct strategy
Returns:
AgentStrategy instance
"""
# If explicit strategy is provided and it's Function Calling, try to use it if supported
if agent_strategy == AgentEntity.Strategy.FUNCTION_CALLING:
if set(model_features) & StrategyFactory.TOOL_CALL_FEATURES:
return FunctionCallStrategy(
model_instance=model_instance,
context=context,
tools=tools,
files=files,
max_iterations=max_iterations,
workflow_call_depth=workflow_call_depth,
tool_invoke_hook=tool_invoke_hook,
)
# Fallback to ReAct if FC is requested but not supported
# If explicit strategy is Chain of Thought (ReAct)
if agent_strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
return ReActStrategy(
model_instance=model_instance,
context=context,
tools=tools,
files=files,
max_iterations=max_iterations,
workflow_call_depth=workflow_call_depth,
tool_invoke_hook=tool_invoke_hook,
instruction=instruction,
)
# Default auto-selection logic
if set(model_features) & StrategyFactory.TOOL_CALL_FEATURES:
# Model supports native function calling
return FunctionCallStrategy(
model_instance=model_instance,
context=context,
tools=tools,
files=files,
max_iterations=max_iterations,
workflow_call_depth=workflow_call_depth,
tool_invoke_hook=tool_invoke_hook,
)
else:
# Use ReAct strategy for models without function calling
return ReActStrategy(
model_instance=model_instance,
context=context,
tools=tools,
files=files,
max_iterations=max_iterations,
workflow_call_depth=workflow_call_depth,
tool_invoke_hook=tool_invoke_hook,
instruction=instruction,
)
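
A hypothetical call site for the factory; `model_schema`, `model_instance`, `execution_context`, `tools` and `prompt_messages` are assumed to come from the surrounding app runner and are not defined here.

```python
strategy = StrategyFactory.create_strategy(
    model_features=model_schema.features or [],
    model_instance=model_instance,
    context=execution_context,
    tools=tools,
    files=[],
    max_iterations=5,
    agent_strategy=None,  # let the factory auto-select based on model features
    instruction="You are a helpful assistant.",
)
chunks = strategy.run(prompt_messages=prompt_messages, model_parameters={"temperature": 0.1})
```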

View File

@ -1,4 +1,3 @@
import json
from collections.abc import Sequence
from enum import StrEnum, auto
from typing import Any, Literal
@ -121,7 +120,7 @@ class VariableEntity(BaseModel):
allowed_file_types: Sequence[FileType] | None = Field(default_factory=list)
allowed_file_extensions: Sequence[str] | None = Field(default_factory=list)
allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list)
json_schema: str | None = Field(default=None)
json_schema: dict | None = Field(default=None)
@field_validator("description", mode="before")
@classmethod
@ -135,17 +134,11 @@ class VariableEntity(BaseModel):
@field_validator("json_schema")
@classmethod
def validate_json_schema(cls, schema: str | None) -> str | None:
def validate_json_schema(cls, schema: dict | None) -> dict | None:
if schema is None:
return None
try:
json_schema = json.loads(schema)
except json.JSONDecodeError:
raise ValueError(f"invalid json_schema value {schema}")
try:
Draft7Validator.check_schema(json_schema)
Draft7Validator.check_schema(schema)
except SchemaError as e:
raise ValueError(f"Invalid JSON schema: {e.message}")
return schema
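
With this change `json_schema` is supplied as a dict rather than a JSON string and is checked directly with `Draft7Validator.check_schema` from the `jsonschema` package; a minimal illustration:

```python
from jsonschema import Draft7Validator

schema = {
    "type": "object",
    "properties": {"city": {"type": "string"}},
    "required": ["city"],
}
Draft7Validator.check_schema(schema)  # raises jsonschema.exceptions.SchemaError if the schema itself is invalid
```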

View File

@ -26,7 +26,6 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
@classmethod
def get_app_config(cls, app_model: App, workflow: Workflow) -> AdvancedChatAppConfig:
features_dict = workflow.features_dict
app_mode = AppMode.value_of(app_model.mode)
app_config = AdvancedChatAppConfig(
tenant_id=app_model.tenant_id,

View File

@ -4,6 +4,7 @@ import re
import time
from collections.abc import Callable, Generator, Mapping
from contextlib import contextmanager
from dataclasses import dataclass, field
from threading import Thread
from typing import Any, Union
@ -19,6 +20,7 @@ from core.app.entities.app_invoke_entities import (
InvokeFrom,
)
from core.app.entities.queue_entities import (
ChunkType,
MessageQueueMessage,
QueueAdvancedChatMessageEndEvent,
QueueAgentLogEvent,
@ -70,13 +72,122 @@ from core.workflow.runtime import GraphRuntimeState
from core.workflow.system_variable import SystemVariable
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models import Account, Conversation, EndUser, Message, MessageFile
from models import Account, Conversation, EndUser, LLMGenerationDetail, Message, MessageFile
from models.enums import CreatorUserRole
from models.workflow import Workflow
logger = logging.getLogger(__name__)
@dataclass
class StreamEventBuffer:
"""
Buffer for recording stream events in order to reconstruct the generation sequence.
Records the exact order of text chunks, thoughts, and tool calls as they stream.
"""
# Accumulated reasoning content (each thought block is a separate element)
reasoning_content: list[str] = field(default_factory=list)
# Current reasoning buffer (accumulates until we see a different event type)
_current_reasoning: str = ""
# Tool calls with their details
tool_calls: list[dict] = field(default_factory=list)
# Tool call ID to index mapping for updating results
_tool_call_id_map: dict[str, int] = field(default_factory=dict)
# Sequence of events in stream order
sequence: list[dict] = field(default_factory=list)
# Current position in answer text
_content_position: int = 0
# Track last event type to detect transitions
_last_event_type: str | None = None
def _flush_current_reasoning(self) -> None:
"""Flush accumulated reasoning to the list and add to sequence."""
if self._current_reasoning.strip():
self.reasoning_content.append(self._current_reasoning.strip())
self.sequence.append({"type": "reasoning", "index": len(self.reasoning_content) - 1})
self._current_reasoning = ""
def record_text_chunk(self, text: str) -> None:
"""Record a text chunk event."""
if not text:
return
# Flush any pending reasoning first
if self._last_event_type == "thought":
self._flush_current_reasoning()
text_len = len(text)
start_pos = self._content_position
# If last event was also content, extend it; otherwise create new
if self.sequence and self.sequence[-1].get("type") == "content":
self.sequence[-1]["end"] = start_pos + text_len
else:
self.sequence.append({"type": "content", "start": start_pos, "end": start_pos + text_len})
self._content_position += text_len
self._last_event_type = "content"
def record_thought_chunk(self, text: str) -> None:
"""Record a thought/reasoning chunk event."""
if not text:
return
# Accumulate thought content
self._current_reasoning += text
self._last_event_type = "thought"
def record_tool_call(self, tool_call_id: str, tool_name: str, tool_arguments: str) -> None:
"""Record a tool call event."""
if not tool_call_id:
return
# Flush any pending reasoning first
if self._last_event_type == "thought":
self._flush_current_reasoning()
# Check if this tool call already exists (we might get multiple chunks)
if tool_call_id in self._tool_call_id_map:
idx = self._tool_call_id_map[tool_call_id]
# Update arguments if provided
if tool_arguments:
self.tool_calls[idx]["arguments"] = tool_arguments
else:
# New tool call
tool_call = {
"id": tool_call_id or "",
"name": tool_name or "",
"arguments": tool_arguments or "",
"result": "",
"elapsed_time": None,
}
self.tool_calls.append(tool_call)
idx = len(self.tool_calls) - 1
self._tool_call_id_map[tool_call_id] = idx
self.sequence.append({"type": "tool_call", "index": idx})
self._last_event_type = "tool_call"
def record_tool_result(self, tool_call_id: str, result: str, tool_elapsed_time: float | None = None) -> None:
"""Record a tool result event (update existing tool call)."""
if not tool_call_id:
return
if tool_call_id in self._tool_call_id_map:
idx = self._tool_call_id_map[tool_call_id]
self.tool_calls[idx]["result"] = result
self.tool_calls[idx]["elapsed_time"] = tool_elapsed_time
def finalize(self) -> None:
"""Finalize the buffer, flushing any pending data."""
if self._last_event_type == "thought":
self._flush_current_reasoning()
def has_data(self) -> bool:
"""Check if there's any meaningful data recorded."""
return bool(self.reasoning_content or self.tool_calls or self.sequence)
class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
"""
AdvancedChatAppGenerateTaskPipeline generates the streamed output and manages task state for an advanced-chat application.
@ -144,6 +255,8 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
self._workflow_run_id: str = ""
self._draft_var_saver_factory = draft_var_saver_factory
self._graph_runtime_state: GraphRuntimeState | None = None
# Stream event buffer for recording generation sequence
self._stream_buffer = StreamEventBuffer()
self._seed_graph_runtime_state_from_queue_manager()
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
@ -358,25 +471,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
if node_finish_resp:
yield node_finish_resp
# For ANSWER nodes, check if we need to send a message_replace event
# Only send if the final output differs from the accumulated task_state.answer
# This happens when variables were updated by variable_assigner during workflow execution
if event.node_type == NodeType.ANSWER and event.outputs:
final_answer = event.outputs.get("answer")
if final_answer is not None and final_answer != self._task_state.answer:
logger.info(
"ANSWER node final output '%s' differs from accumulated answer '%s', sending message_replace event",
final_answer,
self._task_state.answer,
)
# Update the task state answer
self._task_state.answer = str(final_answer)
# Send message_replace event to update the UI
yield self._message_cycle_manager.message_replace_to_stream_response(
answer=str(final_answer),
reason="variable_update",
)
def _handle_node_failed_events(
self,
event: Union[QueueNodeFailedEvent, QueueNodeExceptionEvent],
@ -402,7 +496,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
queue_message: Union[WorkflowQueueMessage, MessageQueueMessage] | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle text chunk events."""
"""Handle text chunk events and record to stream buffer for sequence reconstruction."""
delta_text = event.text
if delta_text is None:
return
@ -424,9 +518,52 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
if tts_publisher and queue_message:
tts_publisher.publish(queue_message)
self._task_state.answer += delta_text
tool_call = event.tool_call
tool_result = event.tool_result
tool_payload = tool_call or tool_result
tool_call_id = tool_payload.id if tool_payload and tool_payload.id else ""
tool_name = tool_payload.name if tool_payload and tool_payload.name else ""
tool_arguments = tool_call.arguments if tool_call and tool_call.arguments else ""
tool_files = tool_result.files if tool_result else []
tool_elapsed_time = tool_result.elapsed_time if tool_result else None
tool_icon = tool_payload.icon if tool_payload else None
tool_icon_dark = tool_payload.icon_dark if tool_payload else None
# Record stream event based on chunk type
chunk_type = event.chunk_type or ChunkType.TEXT
match chunk_type:
case ChunkType.TEXT:
self._stream_buffer.record_text_chunk(delta_text)
self._task_state.answer += delta_text
case ChunkType.THOUGHT:
# Reasoning should not be part of final answer text
self._stream_buffer.record_thought_chunk(delta_text)
case ChunkType.TOOL_CALL:
self._stream_buffer.record_tool_call(
tool_call_id=tool_call_id,
tool_name=tool_name,
tool_arguments=tool_arguments,
)
case ChunkType.TOOL_RESULT:
self._stream_buffer.record_tool_result(
tool_call_id=tool_call_id,
result=delta_text,
tool_elapsed_time=tool_elapsed_time,
)
self._task_state.answer += delta_text
case _:
pass
yield self._message_cycle_manager.message_to_stream_response(
answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
answer=delta_text,
message_id=self._message_id,
from_variable_selector=event.from_variable_selector,
chunk_type=event.chunk_type.value if event.chunk_type else None,
tool_call_id=tool_call_id or None,
tool_name=tool_name or None,
tool_arguments=tool_arguments or None,
tool_files=tool_files,
tool_elapsed_time=tool_elapsed_time,
tool_icon=tool_icon,
tool_icon_dark=tool_icon_dark,
)
def _handle_iteration_start_event(
@ -794,6 +931,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
# If there are assistant files, remove markdown image links from answer
answer_text = self._task_state.answer
answer_text = self._strip_think_blocks(answer_text)
if self._recorded_files:
# Remove markdown image links since we're storing files separately
answer_text = re.sub(r"!\[.*?\]\(.*?\)", "", answer_text).strip()
@ -845,6 +983,54 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
]
session.add_all(message_files)
# Save generation detail (reasoning/tool calls/sequence) from stream buffer
self._save_generation_detail(session=session, message=message)
@staticmethod
def _strip_think_blocks(text: str) -> str:
"""Remove <think>...</think> blocks (including their content) from text."""
if not text or "<think" not in text.lower():
return text
clean_text = re.sub(r"<think[^>]*>.*?</think>", "", text, flags=re.IGNORECASE | re.DOTALL)
clean_text = re.sub(r"\n\s*\n", "\n\n", clean_text).strip()
return clean_text
def _save_generation_detail(self, *, session: Session, message: Message) -> None:
"""
Save LLM generation detail for Chatflow using stream event buffer.
The buffer records the exact order of events as they streamed,
allowing accurate reconstruction of the generation sequence.
"""
# Finalize the stream buffer to flush any pending data
self._stream_buffer.finalize()
# Only save if there's meaningful data
if not self._stream_buffer.has_data():
return
reasoning_content = self._stream_buffer.reasoning_content
tool_calls = self._stream_buffer.tool_calls
sequence = self._stream_buffer.sequence
# Check if generation detail already exists for this message
existing = session.query(LLMGenerationDetail).filter_by(message_id=message.id).first()
if existing:
existing.reasoning_content = json.dumps(reasoning_content) if reasoning_content else None
existing.tool_calls = json.dumps(tool_calls) if tool_calls else None
existing.sequence = json.dumps(sequence) if sequence else None
else:
generation_detail = LLMGenerationDetail(
tenant_id=self._application_generate_entity.app_config.tenant_id,
app_id=self._application_generate_entity.app_config.app_id,
message_id=message.id,
reasoning_content=json.dumps(reasoning_content) if reasoning_content else None,
tool_calls=json.dumps(tool_calls) if tool_calls else None,
sequence=json.dumps(sequence) if sequence else None,
)
session.add(generation_detail)
def _seed_graph_runtime_state_from_queue_manager(self) -> None:
"""Bootstrap the cached runtime state from the queue manager when present."""
candidate = self._base_task_pipeline.queue_manager.graph_runtime_state
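
A minimal sketch of how `StreamEventBuffer` records one round; the payload values are made up, and in the pipeline the same calls are driven from the queue text-chunk events handled above.

```python
buf = StreamEventBuffer()
buf.record_thought_chunk("The user wants the weather, so call the tool first.")
buf.record_tool_call(tool_call_id="call-1", tool_name="get_weather", tool_arguments='{"city": "Berlin"}')
buf.record_tool_result(tool_call_id="call-1", result="12°C, cloudy", tool_elapsed_time=0.8)
buf.record_text_chunk("It is currently 12°C and cloudy in Berlin.")
buf.finalize()
# buf.sequence now holds, in stream order: a "reasoning" segment, a "tool_call" segment,
# and a "content" segment covering the answer text positions; buf.reasoning_content and
# buf.tool_calls hold the corresponding payloads consumed by _save_generation_detail.
```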

View File

@ -3,10 +3,8 @@ from typing import cast
from sqlalchemy import select
from core.agent.cot_chat_agent_runner import CotChatAgentRunner
from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner
from core.agent.agent_app_runner import AgentAppRunner
from core.agent.entities import AgentEntity
from core.agent.fc_agent_runner import FunctionCallAgentRunner
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner
@ -14,8 +12,7 @@ from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity
from core.app.entities.queue_entities import QueueAnnotationReplyEvent
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.moderation.base import ModerationError
from extensions.ext_database import db
@ -194,22 +191,7 @@ class AgentChatAppRunner(AppRunner):
raise ValueError("Message not found")
db.session.close()
runner_cls: type[FunctionCallAgentRunner] | type[CotChatAgentRunner] | type[CotCompletionAgentRunner]
# start agent runner
if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
# check LLM mode
if model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT:
runner_cls = CotChatAgentRunner
elif model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.COMPLETION:
runner_cls = CotCompletionAgentRunner
else:
raise ValueError(f"Invalid LLM mode: {model_schema.model_properties.get(ModelPropertyKey.MODE)}")
elif agent_entity.strategy == AgentEntity.Strategy.FUNCTION_CALLING:
runner_cls = FunctionCallAgentRunner
else:
raise ValueError(f"Invalid agent strategy: {agent_entity.strategy}")
runner = runner_cls(
runner = AgentAppRunner(
tenant_id=app_config.tenant_id,
application_generate_entity=application_generate_entity,
conversation=conversation_result,

View File

@ -1,4 +1,3 @@
import json
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Union, final
@ -76,12 +75,24 @@ class BaseAppGenerator:
user_inputs = {**user_inputs, **files_inputs, **file_list_inputs}
# Check if all files are converted to File
if any(filter(lambda v: isinstance(v, dict), user_inputs.values())):
raise ValueError("Invalid input type")
if any(
filter(lambda v: isinstance(v, dict), filter(lambda item: isinstance(item, list), user_inputs.values()))
):
raise ValueError("Invalid input type")
invalid_dict_keys = [
k
for k, v in user_inputs.items()
if isinstance(v, dict)
and entity_dictionary[k].type not in {VariableEntityType.FILE, VariableEntityType.JSON_OBJECT}
]
if invalid_dict_keys:
raise ValueError(f"Invalid input type for {invalid_dict_keys}")
invalid_list_dict_keys = [
k
for k, v in user_inputs.items()
if isinstance(v, list)
and any(isinstance(item, dict) for item in v)
and entity_dictionary[k].type != VariableEntityType.FILE_LIST
]
if invalid_list_dict_keys:
raise ValueError(f"Invalid input type for {invalid_list_dict_keys}")
return user_inputs
@ -178,12 +189,8 @@ class BaseAppGenerator:
elif value == 0:
value = False
case VariableEntityType.JSON_OBJECT:
if not isinstance(value, str):
raise ValueError(f"{variable_entity.variable} in input form must be a string")
try:
json.loads(value)
except json.JSONDecodeError:
raise ValueError(f"{variable_entity.variable} in input form must be a valid JSON object")
if value and not isinstance(value, dict):
raise ValueError(f"{variable_entity.variable} in input form must be a dict")
case _:
raise AssertionError("this statement should be unreachable.")
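
The effect of the input-form change above, with a made-up variable name: a `JSON_OBJECT` value must now arrive as a dict instead of a JSON string.

```python
user_inputs = {
    "profile": {"name": "Ada", "age": 36},  # accepted for a JSON_OBJECT variable
    # "profile": '{"name": "Ada", "age": 36}',  # previously required JSON string, now rejected ("must be a dict")
}
```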

View File

@ -671,7 +671,7 @@ class WorkflowResponseConverter:
task_id=task_id,
data=AgentLogStreamResponse.Data(
node_execution_id=event.node_execution_id,
id=event.id,
message_id=event.id,
parent_id=event.parent_id,
label=event.label,
error=event.error,

View File

@ -13,6 +13,7 @@ from core.app.apps.common.workflow_response_converter import WorkflowResponseCon
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import (
AppQueueEvent,
ChunkType,
MessageQueueMessage,
QueueAgentLogEvent,
QueueErrorEvent,
@ -483,11 +484,33 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
if delta_text is None:
return
tool_call = event.tool_call
tool_result = event.tool_result
tool_payload = tool_call or tool_result
tool_call_id = tool_payload.id if tool_payload and tool_payload.id else None
tool_name = tool_payload.name if tool_payload and tool_payload.name else None
tool_arguments = tool_call.arguments if tool_call else None
tool_elapsed_time = tool_result.elapsed_time if tool_result else None
tool_files = tool_result.files if tool_result else []
tool_icon = tool_payload.icon if tool_payload else None
tool_icon_dark = tool_payload.icon_dark if tool_payload else None
# only publish tts message at text chunk streaming
if tts_publisher and queue_message:
tts_publisher.publish(queue_message)
yield self._text_chunk_to_stream_response(delta_text, from_variable_selector=event.from_variable_selector)
yield self._text_chunk_to_stream_response(
text=delta_text,
from_variable_selector=event.from_variable_selector,
chunk_type=event.chunk_type,
tool_call_id=tool_call_id,
tool_name=tool_name,
tool_arguments=tool_arguments,
tool_files=tool_files,
tool_elapsed_time=tool_elapsed_time,
tool_icon=tool_icon,
tool_icon_dark=tool_icon_dark,
)
def _handle_agent_log_event(self, event: QueueAgentLogEvent, **kwargs) -> Generator[StreamResponse, None, None]:
"""Handle agent log events."""
@ -650,16 +673,61 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
session.add(workflow_app_log)
def _text_chunk_to_stream_response(
self, text: str, from_variable_selector: list[str] | None = None
self,
text: str,
from_variable_selector: list[str] | None = None,
chunk_type: ChunkType | None = None,
tool_call_id: str | None = None,
tool_name: str | None = None,
tool_arguments: str | None = None,
tool_files: list[str] | None = None,
tool_error: str | None = None,
tool_elapsed_time: float | None = None,
tool_icon: str | dict | None = None,
tool_icon_dark: str | dict | None = None,
) -> TextChunkStreamResponse:
"""
Build a TextChunkStreamResponse for a streamed text chunk, attaching tool-call or
tool-result details when the chunk type requires them.
"""
from core.app.entities.task_entities import ChunkType as ResponseChunkType
response_chunk_type = ResponseChunkType(chunk_type.value) if chunk_type else ResponseChunkType.TEXT
data = TextChunkStreamResponse.Data(
text=text,
from_variable_selector=from_variable_selector,
chunk_type=response_chunk_type,
)
if response_chunk_type == ResponseChunkType.TOOL_CALL:
data = data.model_copy(
update={
"tool_call_id": tool_call_id,
"tool_name": tool_name,
"tool_arguments": tool_arguments,
"tool_icon": tool_icon,
"tool_icon_dark": tool_icon_dark,
}
)
elif response_chunk_type == ResponseChunkType.TOOL_RESULT:
data = data.model_copy(
update={
"tool_call_id": tool_call_id,
"tool_name": tool_name,
"tool_arguments": tool_arguments,
"tool_files": tool_files,
"tool_error": tool_error,
"tool_elapsed_time": tool_elapsed_time,
"tool_icon": tool_icon,
"tool_icon_dark": tool_icon_dark,
}
)
response = TextChunkStreamResponse(
task_id=self._application_generate_entity.task_id,
data=TextChunkStreamResponse.Data(text=text, from_variable_selector=from_variable_selector),
data=data,
)
return response

View File

@ -463,12 +463,20 @@ class WorkflowBasedAppRunner:
)
)
elif isinstance(event, NodeRunStreamChunkEvent):
from core.app.entities.queue_entities import ChunkType as QueueChunkType
if event.is_final and not event.chunk:
return
self._publish_event(
QueueTextChunkEvent(
text=event.chunk,
from_variable_selector=list(event.selector),
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
chunk_type=QueueChunkType(event.chunk_type.value),
tool_call=event.tool_call,
tool_result=event.tool_result,
)
)
elif isinstance(event, NodeRunRetrieverResourceEvent):

View File

@ -9,7 +9,6 @@ from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAp
from core.entities.provider_configuration import ProviderModelBundle
from core.file import File, FileUploadConfig
from core.model_runtime.entities.model_entities import AIModelEntity
from models.enums import CreatorUserRole
if TYPE_CHECKING:
from core.ops.ops_trace_manager import TraceQueueManager
@ -81,11 +80,6 @@ class InvokeFrom(StrEnum):
return "dev"
def to_creator_user_role(self) -> CreatorUserRole:
if self in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}:
return CreatorUserRole.ACCOUNT
return CreatorUserRole.END_USER
class ModelConfigWithCredentialsEntity(BaseModel):
"""

View File

@ -0,0 +1,70 @@
"""
LLM Generation Detail entities.
Defines the structure for storing and transmitting LLM generation details
including reasoning content, tool calls, and their sequence.
"""
from typing import Literal
from pydantic import BaseModel, Field
class ContentSegment(BaseModel):
"""Represents a content segment in the generation sequence."""
type: Literal["content"] = "content"
start: int = Field(..., description="Start position in the text")
end: int = Field(..., description="End position in the text")
class ReasoningSegment(BaseModel):
"""Represents a reasoning segment in the generation sequence."""
type: Literal["reasoning"] = "reasoning"
index: int = Field(..., description="Index into reasoning_content array")
class ToolCallSegment(BaseModel):
"""Represents a tool call segment in the generation sequence."""
type: Literal["tool_call"] = "tool_call"
index: int = Field(..., description="Index into tool_calls array")
SequenceSegment = ContentSegment | ReasoningSegment | ToolCallSegment
class ToolCallDetail(BaseModel):
"""Represents a tool call with its arguments and result."""
id: str = Field(default="", description="Unique identifier for the tool call")
name: str = Field(..., description="Name of the tool")
arguments: str = Field(default="", description="JSON string of tool arguments")
result: str = Field(default="", description="Result from the tool execution")
elapsed_time: float | None = Field(default=None, description="Elapsed time in seconds")
class LLMGenerationDetailData(BaseModel):
"""
Domain model for LLM generation detail.
Contains the structured data for reasoning content, tool calls,
and their display sequence.
"""
reasoning_content: list[str] = Field(default_factory=list, description="List of reasoning segments")
tool_calls: list[ToolCallDetail] = Field(default_factory=list, description="List of tool call details")
sequence: list[SequenceSegment] = Field(default_factory=list, description="Display order of segments")
def is_empty(self) -> bool:
"""Check if there's any meaningful generation detail."""
return not self.reasoning_content and not self.tool_calls
def to_response_dict(self) -> dict:
"""Convert to dictionary for API response."""
return {
"reasoning_content": self.reasoning_content,
"tool_calls": [tc.model_dump() for tc in self.tool_calls],
"sequence": [seg.model_dump() for seg in self.sequence],
}

View File

@ -7,7 +7,7 @@ from pydantic import BaseModel, ConfigDict, Field
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities import AgentNodeStrategyInit
from core.workflow.entities import AgentNodeStrategyInit, ToolCall, ToolResult
from core.workflow.enums import WorkflowNodeExecutionMetadataKey
from core.workflow.nodes import NodeType
@ -177,6 +177,17 @@ class QueueLoopCompletedEvent(AppQueueEvent):
error: str | None = None
class ChunkType(StrEnum):
"""Stream chunk type for LLM-related events."""
TEXT = "text" # Normal text streaming
TOOL_CALL = "tool_call" # Tool call arguments streaming
TOOL_RESULT = "tool_result" # Tool execution result
THOUGHT = "thought" # Agent thinking process (ReAct)
THOUGHT_START = "thought_start" # Agent thought start
THOUGHT_END = "thought_end" # Agent thought end
class QueueTextChunkEvent(AppQueueEvent):
"""
QueueTextChunkEvent entity
@ -191,6 +202,16 @@ class QueueTextChunkEvent(AppQueueEvent):
in_loop_id: str | None = None
"""loop id if node is in loop"""
# Extended fields for Agent/Tool streaming
chunk_type: ChunkType = ChunkType.TEXT
"""type of the chunk"""
# Tool streaming payloads
tool_call: ToolCall | None = None
"""structured tool call info"""
tool_result: ToolResult | None = None
"""structured tool result info"""
class QueueAgentMessageEvent(AppQueueEvent):
"""

View File

@ -113,6 +113,38 @@ class MessageStreamResponse(StreamResponse):
answer: str
from_variable_selector: list[str] | None = None
# Extended fields for Agent/Tool streaming (imported at runtime to avoid circular import)
chunk_type: str | None = None
"""type of the chunk: text, tool_call, tool_result, thought"""
# Tool call fields (when chunk_type == "tool_call")
tool_call_id: str | None = None
"""unique identifier for this tool call"""
tool_name: str | None = None
"""name of the tool being called"""
tool_arguments: str | None = None
"""accumulated tool arguments JSON"""
# Tool result fields (when chunk_type == "tool_result")
tool_files: list[str] | None = None
"""file IDs produced by tool"""
tool_error: str | None = None
"""error message if tool failed"""
tool_elapsed_time: float | None = None
"""elapsed time spent executing the tool"""
tool_icon: str | dict | None = None
"""icon of the tool"""
tool_icon_dark: str | dict | None = None
"""dark theme icon of the tool"""
def model_dump(self, *args, **kwargs) -> dict[str, object]:
kwargs.setdefault("exclude_none", True)
return super().model_dump(*args, **kwargs)
def model_dump_json(self, *args, **kwargs) -> str:
kwargs.setdefault("exclude_none", True)
return super().model_dump_json(*args, **kwargs)
class MessageAudioStreamResponse(StreamResponse):
"""
@ -582,6 +614,17 @@ class LoopNodeCompletedStreamResponse(StreamResponse):
data: Data
class ChunkType(StrEnum):
"""Stream chunk type for LLM-related events."""
TEXT = "text" # Normal text streaming
TOOL_CALL = "tool_call" # Tool call arguments streaming
TOOL_RESULT = "tool_result" # Tool execution result
THOUGHT = "thought" # Agent thinking process (ReAct)
THOUGHT_START = "thought_start" # Agent thought start
THOUGHT_END = "thought_end" # Agent thought end
class TextChunkStreamResponse(StreamResponse):
"""
TextChunkStreamResponse entity
@ -595,6 +638,36 @@ class TextChunkStreamResponse(StreamResponse):
text: str
from_variable_selector: list[str] | None = None
# Extended fields for Agent/Tool streaming
chunk_type: ChunkType = ChunkType.TEXT
"""type of the chunk"""
# Tool call fields (when chunk_type == TOOL_CALL)
tool_call_id: str | None = None
"""unique identifier for this tool call"""
tool_name: str | None = None
"""name of the tool being called"""
tool_arguments: str | None = None
"""accumulated tool arguments JSON"""
# Tool result fields (when chunk_type == TOOL_RESULT)
tool_files: list[str] | None = None
"""file IDs produced by tool"""
tool_error: str | None = None
"""error message if tool failed"""
# Tool elapsed time fields (when chunk_type == TOOL_RESULT)
tool_elapsed_time: float | None = None
"""elapsed time spent executing the tool"""
def model_dump(self, *args, **kwargs) -> dict[str, object]:
kwargs.setdefault("exclude_none", True)
return super().model_dump(*args, **kwargs)
def model_dump_json(self, *args, **kwargs) -> str:
kwargs.setdefault("exclude_none", True)
return super().model_dump_json(*args, **kwargs)
event: StreamEvent = StreamEvent.TEXT_CHUNK
data: Data
@ -743,7 +816,7 @@ class AgentLogStreamResponse(StreamResponse):
"""
node_execution_id: str
id: str
message_id: str
label: str
parent_id: str | None = None
error: str | None = None

View File

@ -1,4 +1,5 @@
import logging
import re
import time
from collections.abc import Generator
from threading import Thread
@ -58,7 +59,7 @@ from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from events.message_event import message_was_created
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.model import AppMode, Conversation, Message, MessageAgentThought
from models.model import AppMode, Conversation, LLMGenerationDetail, Message, MessageAgentThought
logger = logging.getLogger(__name__)
@ -68,6 +69,8 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
EasyUIBasedGenerateTaskPipeline is a class that generate stream output and state management for Application.
"""
_THINK_PATTERN = re.compile(r"<think[^>]*>(.*?)</think>", re.IGNORECASE | re.DOTALL)
_task_state: EasyUITaskState
_application_generate_entity: Union[ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity]
@ -409,11 +412,136 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
)
)
# Save LLM generation detail if there's reasoning_content
self._save_generation_detail(session=session, message=message, llm_result=llm_result)
message_was_created.send(
message,
application_generate_entity=self._application_generate_entity,
)
def _save_generation_detail(self, *, session: Session, message: Message, llm_result: LLMResult) -> None:
"""
Save LLM generation detail for Completion/Chat/Agent-Chat applications.
For Agent-Chat, also merges MessageAgentThought records.
"""
import json
reasoning_list: list[str] = []
tool_calls_list: list[dict] = []
sequence: list[dict] = []
answer = message.answer or ""
# Check if this is Agent-Chat mode by looking for agent thoughts
agent_thoughts = (
session.query(MessageAgentThought)
.filter_by(message_id=message.id)
.order_by(MessageAgentThought.position.asc())
.all()
)
if agent_thoughts:
# Agent-Chat mode: merge MessageAgentThought records
content_pos = 0
cleaned_answer_parts: list[str] = []
for thought in agent_thoughts:
# Add thought/reasoning
if thought.thought:
reasoning_text = thought.thought
if "<think" in reasoning_text.lower():
clean_text, extracted_reasoning = self._split_reasoning_from_answer(reasoning_text)
if extracted_reasoning:
reasoning_text = extracted_reasoning
thought.thought = clean_text or extracted_reasoning
reasoning_list.append(reasoning_text)
sequence.append({"type": "reasoning", "index": len(reasoning_list) - 1})
# Add tool calls
if thought.tool:
tool_calls_list.append(
{
"name": thought.tool,
"arguments": thought.tool_input or "",
"result": thought.observation or "",
}
)
sequence.append({"type": "tool_call", "index": len(tool_calls_list) - 1})
# Add answer content if present
if thought.answer:
content_text = thought.answer
if "<think" in content_text.lower():
clean_answer, extracted_reasoning = self._split_reasoning_from_answer(content_text)
if extracted_reasoning:
reasoning_list.append(extracted_reasoning)
sequence.append({"type": "reasoning", "index": len(reasoning_list) - 1})
content_text = clean_answer
thought.answer = clean_answer or content_text
if content_text:
start = content_pos
end = content_pos + len(content_text)
sequence.append({"type": "content", "start": start, "end": end})
content_pos = end
cleaned_answer_parts.append(content_text)
if cleaned_answer_parts:
merged_answer = "".join(cleaned_answer_parts)
message.answer = merged_answer
llm_result.message.content = merged_answer
else:
# Completion/Chat mode: use reasoning_content from llm_result
reasoning_content = llm_result.reasoning_content
if not reasoning_content and answer:
# Extract reasoning from <think> blocks and clean the final answer
clean_answer, reasoning_content = self._split_reasoning_from_answer(answer)
if reasoning_content:
answer = clean_answer
llm_result.message.content = clean_answer
llm_result.reasoning_content = reasoning_content
message.answer = clean_answer
if reasoning_content:
reasoning_list = [reasoning_content]
# Content comes first, then reasoning
if answer:
sequence.append({"type": "content", "start": 0, "end": len(answer)})
sequence.append({"type": "reasoning", "index": 0})
# Only save if there's meaningful generation detail
if not reasoning_list and not tool_calls_list:
return
# Check if generation detail already exists
existing = session.query(LLMGenerationDetail).filter_by(message_id=message.id).first()
if existing:
existing.reasoning_content = json.dumps(reasoning_list) if reasoning_list else None
existing.tool_calls = json.dumps(tool_calls_list) if tool_calls_list else None
existing.sequence = json.dumps(sequence) if sequence else None
else:
generation_detail = LLMGenerationDetail(
tenant_id=self._application_generate_entity.app_config.tenant_id,
app_id=self._application_generate_entity.app_config.app_id,
message_id=message.id,
reasoning_content=json.dumps(reasoning_list) if reasoning_list else None,
tool_calls=json.dumps(tool_calls_list) if tool_calls_list else None,
sequence=json.dumps(sequence) if sequence else None,
)
session.add(generation_detail)
@classmethod
def _split_reasoning_from_answer(cls, text: str) -> tuple[str, str]:
"""
Extract reasoning segments from <think> blocks and return (clean_text, reasoning).
"""
matches = cls._THINK_PATTERN.findall(text)
reasoning_content = "\n".join(match.strip() for match in matches) if matches else ""
clean_text = cls._THINK_PATTERN.sub("", text)
clean_text = re.sub(r"\n\s*\n", "\n\n", clean_text).strip()
return clean_text, reasoning_content or ""
def _handle_stop(self, event: QueueStopEvent):
"""
Handle stop.

View File

@ -232,15 +232,31 @@ class MessageCycleManager:
answer: str,
message_id: str,
from_variable_selector: list[str] | None = None,
chunk_type: str | None = None,
tool_call_id: str | None = None,
tool_name: str | None = None,
tool_arguments: str | None = None,
tool_files: list[str] | None = None,
tool_error: str | None = None,
tool_elapsed_time: float | None = None,
tool_icon: str | dict | None = None,
tool_icon_dark: str | dict | None = None,
event_type: StreamEvent | None = None,
) -> MessageStreamResponse:
"""
Message to stream response.
:param answer: answer
:param message_id: message id
:param from_variable_selector: from variable selector
:param chunk_type: type of the chunk (text, function_call, tool_result, thought)
:param tool_call_id: unique identifier for this tool call
:param tool_name: name of the tool being called
:param tool_arguments: accumulated tool arguments JSON
:param tool_files: file IDs produced by tool
:param tool_error: error message if tool failed
:return:
"""
return MessageStreamResponse(
response = MessageStreamResponse(
task_id=self._application_generate_entity.task_id,
id=message_id,
answer=answer,
@ -248,6 +264,35 @@ class MessageCycleManager:
event=event_type or StreamEvent.MESSAGE,
)
if chunk_type:
response = response.model_copy(update={"chunk_type": chunk_type})
if chunk_type == "tool_call":
response = response.model_copy(
update={
"tool_call_id": tool_call_id,
"tool_name": tool_name,
"tool_arguments": tool_arguments,
"tool_icon": tool_icon,
"tool_icon_dark": tool_icon_dark,
}
)
elif chunk_type == "tool_result":
response = response.model_copy(
update={
"tool_call_id": tool_call_id,
"tool_name": tool_name,
"tool_arguments": tool_arguments,
"tool_files": tool_files,
"tool_error": tool_error,
"tool_elapsed_time": tool_elapsed_time,
"tool_icon": tool_icon,
"tool_icon_dark": tool_icon_dark,
}
)
return response
def message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse:
"""
Message replace to stream response.

View File

@ -5,7 +5,6 @@ from sqlalchemy import select
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.models.document import Document
@ -37,7 +36,9 @@ class DatasetIndexToolCallbackHandler:
content=query,
source="app",
source_app_id=self._app_id,
created_by_role=self._invoke_from.to_creator_user_role(),
created_by_role=(
"account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user"
),
created_by=self._user_id,
)
@ -88,6 +89,8 @@ class DatasetIndexToolCallbackHandler:
# TODO(-LAN-): Improve type check
def return_retriever_resource_info(self, resource: Sequence[RetrievalSourceMetadata]):
"""Handle return_retriever_resource_info."""
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
self._queue_manager.publish(
QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER
)

View File

@ -3,6 +3,7 @@ from pydantic import BaseModel, Field, field_validator
class PreviewDetail(BaseModel):
content: str
summary: str | None = None
child_chunks: list[str] | None = None

View File

@ -311,14 +311,18 @@ class IndexingRunner:
qa_preview_texts: list[QAPreviewDetail] = []
total_segments = 0
# doc_form represents the segmentation method (general, parent-child, QA)
index_type = doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
# one extract_setting is one source document
for extract_setting in extract_settings:
# extract
processing_rule = DatasetProcessRule(
mode=tmp_processing_rule["mode"], rules=json.dumps(tmp_processing_rule["rules"])
)
# Extract document content
text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"])
# Cleaning and segmentation
documents = index_processor.transform(
text_docs,
current_user=None,
@ -361,6 +365,12 @@ class IndexingRunner:
if doc_form and doc_form == "qa_model":
return IndexingEstimate(total_segments=total_segments * 20, qa_preview=qa_preview_texts, preview=[])
# Generate summary preview
summary_index_setting = tmp_processing_rule["summary_index_setting"] if "summary_index_setting" in tmp_processing_rule else None
if summary_index_setting and summary_index_setting.get('enable') and preview_texts:
preview_texts = index_processor.generate_summary_preview(tenant_id, preview_texts, summary_index_setting)
return IndexingEstimate(total_segments=total_segments, preview=preview_texts)
def _extract(

View File

@ -434,3 +434,6 @@ INSTRUCTION_GENERATE_TEMPLATE_PROMPT = """The output of this prompt is not as ex
You should edit the prompt according to the IDEAL OUTPUT."""
INSTRUCTION_GENERATE_TEMPLATE_CODE = """Please fix the errors in the {{#error_message#}}."""
DEFAULT_GENERATOR_SUMMARY_PROMPT = """
You are a helpful assistant that summarizes long pieces of text into concise summaries. Given the following text, generate a brief summary that captures the main points and key information. The summary should be clear, concise, and written in complete sentences. """

View File

@ -251,10 +251,7 @@ class AssistantPromptMessage(PromptMessage):
:return: True if prompt message is empty, False otherwise
"""
if not super().is_empty() and not self.tool_calls:
return False
return True
return super().is_empty() and not self.tool_calls
class SystemPromptMessage(PromptMessage):

View File

@ -1,6 +1,7 @@
import logging
from collections.abc import Sequence
from opentelemetry.trace import SpanKind
from sqlalchemy.orm import sessionmaker
from core.ops.aliyun_trace.data_exporter.traceclient import (
@ -151,6 +152,7 @@ class AliyunDataTrace(BaseTraceInstance):
),
status=status,
links=trace_metadata.links,
span_kind=SpanKind.SERVER,
)
self.trace_client.add_span(message_span)
@ -456,6 +458,7 @@ class AliyunDataTrace(BaseTraceInstance):
),
status=status,
links=trace_metadata.links,
span_kind=SpanKind.SERVER,
)
self.trace_client.add_span(message_span)
@ -475,6 +478,7 @@ class AliyunDataTrace(BaseTraceInstance):
),
status=status,
links=trace_metadata.links,
span_kind=SpanKind.SERVER if message_span_id is None else SpanKind.INTERNAL,
)
self.trace_client.add_span(workflow_span)

View File

@ -166,7 +166,7 @@ class SpanBuilder:
attributes=span_data.attributes,
events=span_data.events,
links=span_data.links,
kind=trace_api.SpanKind.INTERNAL,
kind=span_data.span_kind,
status=span_data.status,
start_time=span_data.start_time,
end_time=span_data.end_time,

View File

@ -4,7 +4,7 @@ from typing import Any
from opentelemetry import trace as trace_api
from opentelemetry.sdk.trace import Event
from opentelemetry.trace import Status, StatusCode
from opentelemetry.trace import SpanKind, Status, StatusCode
from pydantic import BaseModel, Field
@ -34,3 +34,4 @@ class SpanData(BaseModel):
status: Status = Field(default=Status(StatusCode.UNSET), description="The status of the span.")
start_time: int | None = Field(..., description="The start time of the span in nanoseconds.")
end_time: int | None = Field(..., description="The end time of the span in nanoseconds.")
span_kind: SpanKind = Field(default=SpanKind.INTERNAL, description="The OpenTelemetry SpanKind for this span.")

View File

@ -392,6 +392,69 @@ class RetrievalService:
records = []
include_segment_ids = set()
segment_child_map = {}
segment_file_map = {}
segment_summary_map = {} # Map segment_id to summary content
summary_segment_ids = set() # Track segments retrieved via summary
with Session(bind=db.engine, expire_on_commit=False) as session:
# Process documents
for document in documents:
segment_id = None
attachment_info = None
child_chunk = None
document_id = document.metadata.get("document_id")
if document_id not in dataset_documents:
continue
dataset_document = dataset_documents[document_id]
if not dataset_document:
continue
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
# Handle parent-child documents
if document.metadata.get("doc_type") == DocType.IMAGE:
attachment_info_dict = cls.get_segment_attachment_info(
dataset_document.dataset_id,
dataset_document.tenant_id,
document.metadata.get("doc_id") or "",
session,
)
if attachment_info_dict:
attachment_info = attachment_info_dict["attachment_info"]
segment_id = attachment_info_dict["segment_id"]
else:
# Check if this is a summary document
is_summary = document.metadata.get("is_summary", False)
if is_summary:
# For summary documents, find the original chunk via original_chunk_id
original_chunk_id = document.metadata.get("original_chunk_id")
if not original_chunk_id:
continue
segment_id = original_chunk_id
# Track that this segment was retrieved via summary
summary_segment_ids.add(segment_id)
else:
# For normal documents, find by child chunk index_node_id
child_index_node_id = document.metadata.get("doc_id")
child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id)
child_chunk = session.scalar(child_chunk_stmt)
if not child_chunk:
continue
segment_id = child_chunk.segment_id
if not segment_id:
continue
segment = (
session.query(DocumentSegment)
.where(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.id == segment_id,
)
.first()
)
valid_dataset_documents = {}
image_doc_ids: list[Any] = []
@ -507,7 +570,47 @@ class RetrievalService:
max_score = max(
max_score, file_document.metadata.get("score", 0.0) if file_document else 0.0
)
segment = session.scalar(document_segment_stmt)
if segment:
segment_file_map[segment.id] = [attachment_info]
else:
# Check if this is a summary document
is_summary = document.metadata.get("is_summary", False)
if is_summary:
# For summary documents, find the original chunk via original_chunk_id
original_chunk_id = document.metadata.get("original_chunk_id")
if not original_chunk_id:
continue
# Track that this segment was retrieved via summary
summary_segment_ids.add(original_chunk_id)
document_segment_stmt = select(DocumentSegment).where(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.id == original_chunk_id,
)
segment = session.scalar(document_segment_stmt)
else:
# For normal documents, find by index_node_id
index_node_id = document.metadata.get("doc_id")
if not index_node_id:
continue
document_segment_stmt = select(DocumentSegment).where(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.index_node_id == index_node_id,
)
segment = session.scalar(document_segment_stmt)
if not segment:
continue
if segment.id not in include_segment_ids:
include_segment_ids.add(segment.id)
record = {
"segment": segment,
"score": document.metadata.get("score"), # type: ignore
}
map_detail = {
"max_score": max_score,
"child_chunks": child_chunk_details,
@ -542,6 +645,23 @@ class RetrievalService:
if record["segment"].id in attachment_map:
record["files"] = attachment_map[record["segment"].id] # type: ignore[assignment]
# Batch query summaries for segments retrieved via summary (only enabled summaries)
if summary_segment_ids:
from models.dataset import DocumentSegmentSummary
summaries = (
session.query(DocumentSegmentSummary)
.filter(
DocumentSegmentSummary.chunk_id.in_(summary_segment_ids),
DocumentSegmentSummary.status == "completed",
DocumentSegmentSummary.enabled == True, # Only retrieve enabled summaries
)
.all()
)
for summary in summaries:
if summary.summary_content:
segment_summary_map[summary.chunk_id] = summary.summary_content
result: list[RetrievalSegments] = []
for record in records:
# Extract segment
@ -576,9 +696,16 @@ class RetrievalService:
else None
)
# Extract summary if this segment was retrieved via summary
summary_content = segment_summary_map.get(segment.id)
# Create RetrievalSegments object
retrieval_segment = RetrievalSegments(
segment=segment, child_chunks=child_chunks_list, score=score, files=files
segment=segment,
child_chunks=child_chunks_list,
score=score,
files=files,
summary=summary_content
)
result.append(retrieval_segment)

View File

@ -20,3 +20,4 @@ class RetrievalSegments(BaseModel):
child_chunks: list[RetrievalChildChunk] | None = None
score: float | None = None
files: list[dict[str, str | int]] | None = None
summary: str | None = None # Summary content if retrieved via summary index

View File

@ -13,6 +13,7 @@ from urllib.parse import unquote, urlparse
import httpx
from configs import dify_config
from core.entities.knowledge_entities import PreviewDetail
from core.helper import ssrf_proxy
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.index_processor.constant.doc_type import DocType
@ -45,6 +46,15 @@ class BaseIndexProcessor(ABC):
def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]:
raise NotImplementedError
@abstractmethod
def generate_summary_preview(self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict) -> list[PreviewDetail]:
"""
For each segment in preview_texts, generate a summary using LLM and attach it to the segment.
The summary can be stored in a new attribute, e.g., summary.
This method should be implemented by subclasses.
"""
raise NotImplementedError
@abstractmethod
def load(
self,

View File

@ -1,9 +1,13 @@
"""Paragraph index processor."""
import logging
import uuid
from collections.abc import Mapping
from typing import Any
logger = logging.getLogger(__name__)
from core.entities.knowledge_entities import PreviewDetail
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.retrieval_service import RetrievalService
@ -17,12 +21,19 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import AttachmentDocument, Document, MultimodalGeneralStructureChunk
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.text_processing_utils import remove_leading_symbols
from extensions.ext_database import db
from libs import helper
from models.account import Account
from models.dataset import Dataset, DatasetProcessRule
from models.dataset import Dataset, DatasetProcessRule, DocumentSegment
from models.dataset import Document as DatasetDocument
from services.account_service import AccountService
from services.entities.knowledge_entities.knowledge_entities import Rule
from services.summary_index_service import SummaryIndexService
from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT
from core.model_runtime.entities.message_entities import UserPromptMessage
from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager
from core.model_manager import ModelInstance
class ParagraphIndexProcessor(BaseIndexProcessor):
@ -108,6 +119,29 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
keyword.add_texts(documents)
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
# Note: Summary indexes are now disabled (not deleted) when segments are disabled.
# This method is called for actual deletion scenarios (e.g., when segment is deleted).
# For disable operations, disable_summaries_for_segments is called directly in the task.
# Only delete summaries if explicitly requested (e.g., when segment is actually deleted)
delete_summaries = kwargs.get("delete_summaries", False)
if delete_summaries:
if node_ids:
# Find segments by index_node_id
segments = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.index_node_id.in_(node_ids),
)
.all()
)
segment_ids = [segment.id for segment in segments]
if segment_ids:
SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids)
else:
# Delete all summaries for the dataset
SummaryIndexService.delete_summaries_for_segments(dataset, None)
if dataset.indexing_technique == "high_quality":
vector = Vector(dataset)
if node_ids:
@ -227,3 +261,70 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
}
else:
raise ValueError("Chunks is not a list")
def generate_summary_preview(self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict) -> list[PreviewDetail]:
"""
For each segment, concurrently call generate_summary to generate a summary
and write it to the summary attribute of PreviewDetail.
"""
import concurrent.futures
from flask import current_app
# Capture Flask app context for worker threads
flask_app = None
try:
flask_app = current_app._get_current_object() # type: ignore
except RuntimeError:
logger.warning("No Flask application context available, summary generation may fail")
def process(preview: PreviewDetail) -> None:
"""Generate summary for a single preview item."""
try:
if flask_app:
# Ensure Flask app context in worker thread
with flask_app.app_context():
summary = self.generate_summary(tenant_id, preview.content, summary_index_setting)
preview.summary = summary
else:
# Fallback: try without app context (may fail)
summary = self.generate_summary(tenant_id, preview.content, summary_index_setting)
preview.summary = summary
except Exception as e:
logger.error(f"Failed to generate summary for preview: {str(e)}")
# Don't fail the entire preview if summary generation fails
preview.summary = None
with concurrent.futures.ThreadPoolExecutor() as executor:
list(executor.map(process, preview_texts))
return preview_texts
@staticmethod
def generate_summary(tenant_id: str, text: str, summary_index_setting: dict = None) -> str:
"""
Generate summary for the given text using ModelInstance.invoke_llm and the default or custom summary prompt.
"""
if not summary_index_setting or not summary_index_setting.get("enable"):
raise ValueError("summary_index_setting is required and must be enabled to generate summary.")
model_name = summary_index_setting.get("model_name")
model_provider_name = summary_index_setting.get("model_provider_name")
summary_prompt = summary_index_setting.get("summary_prompt")
# Import default summary prompt
if not summary_prompt:
summary_prompt = DEFAULT_GENERATOR_SUMMARY_PROMPT
prompt = f"{summary_prompt}\n{text}"
provider_manager = ProviderManager()
provider_model_bundle = provider_manager.get_provider_model_bundle(tenant_id, model_provider_name, ModelType.LLM)
model_instance = ModelInstance(provider_model_bundle, model_name)
prompt_messages = [UserPromptMessage(content=prompt)]
result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters={},
stream=False
)
return getattr(result.message, "content", "")

View File

@ -25,6 +25,7 @@ from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegm
from models.dataset import Document as DatasetDocument
from services.account_service import AccountService
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule
from services.summary_index_service import SummaryIndexService
class ParentChildIndexProcessor(BaseIndexProcessor):
@ -135,6 +136,29 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
# node_ids is segment's node_ids
# Note: Summary indexes are now disabled (not deleted) when segments are disabled.
# This method is called for actual deletion scenarios (e.g., when segment is deleted).
# For disable operations, disable_summaries_for_segments is called directly in the task.
# Only delete summaries if explicitly requested (e.g., when segment is actually deleted)
delete_summaries = kwargs.get("delete_summaries", False)
if delete_summaries:
if node_ids:
# Find segments by index_node_id
segments = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.index_node_id.in_(node_ids),
)
.all()
)
segment_ids = [segment.id for segment in segments]
if segment_ids:
SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids)
else:
# Delete all summaries for the dataset
SummaryIndexService.delete_summaries_for_segments(dataset, None)
if dataset.indexing_technique == "high_quality":
delete_child_chunks = kwargs.get("delete_child_chunks") or False
precomputed_child_node_ids = kwargs.get("precomputed_child_node_ids")

View File

@ -25,9 +25,10 @@ from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.text_processing_utils import remove_leading_symbols
from libs import helper
from models.account import Account
from models.dataset import Dataset
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
from services.entities.knowledge_entities.knowledge_entities import Rule
from services.summary_index_service import SummaryIndexService
logger = logging.getLogger(__name__)
@ -144,6 +145,30 @@ class QAIndexProcessor(BaseIndexProcessor):
vector.create_multimodal(multimodal_documents)
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
# Note: Summary indexes are now disabled (not deleted) when segments are disabled.
# This method is called for actual deletion scenarios (e.g., when segment is deleted).
# For disable operations, disable_summaries_for_segments is called directly in the task.
# Note: qa_model doesn't generate summaries, but we clean them for completeness
# Only delete summaries if explicitly requested (e.g., when segment is actually deleted)
delete_summaries = kwargs.get("delete_summaries", False)
if delete_summaries:
if node_ids:
# Find segments by index_node_id
segments = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.index_node_id.in_(node_ids),
)
.all()
)
segment_ids = [segment.id for segment in segments]
if segment_ids:
SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids)
else:
# Delete all summaries for the dataset
SummaryIndexService.delete_summaries_for_segments(dataset, None)
vector = Vector(dataset)
if node_ids:
vector.delete_by_ids(node_ids)

View File

@ -63,7 +63,6 @@ from libs.json_in_md_parser import parse_and_check_json_markdown
from models import UploadFile
from models.dataset import ChildChunk, Dataset, DatasetMetadata, DatasetQuery, DocumentSegment, SegmentAttachmentBinding
from models.dataset import Document as DatasetDocument
from models.enums import CreatorUserRole
from services.external_knowledge_service import ExternalDatasetService
default_retrieval_model: dict[str, Any] = {
@ -177,13 +176,13 @@ class DatasetRetrieval:
)
all_documents = []
creator_user_role = invoke_from.to_creator_user_role()
user_from = "account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user"
if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
all_documents = self.single_retrieve(
app_id,
tenant_id,
user_id,
creator_user_role,
user_from,
query,
available_datasets,
model_instance,
@ -198,7 +197,7 @@ class DatasetRetrieval:
app_id,
tenant_id,
user_id,
creator_user_role,
user_from,
available_datasets,
query,
retrieve_config.top_k or 0,
@ -335,7 +334,7 @@ class DatasetRetrieval:
app_id: str,
tenant_id: str,
user_id: str,
creator_user_role: CreatorUserRole,
user_from: str,
query: str,
available_datasets: list,
model_instance: ModelInstance,
@ -445,7 +444,7 @@ class DatasetRetrieval:
weights=retrieval_model_config.get("weights", None),
document_ids_filter=document_ids_filter,
)
self._on_query(query, None, [dataset_id], app_id, creator_user_role, user_id)
self._on_query(query, None, [dataset_id], app_id, user_from, user_id)
if results:
thread = threading.Thread(
@ -467,7 +466,7 @@ class DatasetRetrieval:
app_id: str,
tenant_id: str,
user_id: str,
creator_user_role: CreatorUserRole,
user_from: str,
available_datasets: list,
query: str | None,
top_k: int,
@ -585,7 +584,7 @@ class DatasetRetrieval:
if thread_exceptions:
raise thread_exceptions[0]
self._on_query(query, attachment_ids, dataset_ids, app_id, creator_user_role, user_id)
self._on_query(query, attachment_ids, dataset_ids, app_id, user_from, user_id)
if all_documents:
# add thread to call _on_retrieval_end
@ -734,7 +733,7 @@ class DatasetRetrieval:
attachment_ids: list[str] | None,
dataset_ids: list[str],
app_id: str,
creator_user_role: CreatorUserRole,
user_from: str,
user_id: str,
):
"""
@ -756,7 +755,7 @@ class DatasetRetrieval:
content=json.dumps(contents),
source="app",
source_app_id=app_id,
created_by_role=creator_user_role,
created_by_role=user_from,
created_by=user_id,
)
dataset_queries.append(dataset_query)

View File

@ -29,6 +29,7 @@ from models import (
Account,
CreatorUserRole,
EndUser,
LLMGenerationDetail,
WorkflowNodeExecutionModel,
WorkflowNodeExecutionTriggeredFrom,
)
@ -457,6 +458,113 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
session.merge(db_model)
session.flush()
# Save LLMGenerationDetail for LLM nodes with successful execution
if (
domain_model.node_type == NodeType.LLM
and domain_model.status == WorkflowNodeExecutionStatus.SUCCEEDED
and domain_model.outputs is not None
):
self._save_llm_generation_detail(session, domain_model)
def _save_llm_generation_detail(self, session, execution: WorkflowNodeExecution) -> None:
"""
Save LLM generation detail for LLM nodes.
Extracts reasoning_content, tool_calls, and sequence from outputs and metadata.
"""
outputs = execution.outputs or {}
metadata = execution.metadata or {}
reasoning_list = self._extract_reasoning(outputs)
tool_calls_list = self._extract_tool_calls(metadata.get(WorkflowNodeExecutionMetadataKey.AGENT_LOG))
if not reasoning_list and not tool_calls_list:
return
sequence = self._build_generation_sequence(outputs.get("text", ""), reasoning_list, tool_calls_list)
self._upsert_generation_detail(session, execution, reasoning_list, tool_calls_list, sequence)
def _extract_reasoning(self, outputs: Mapping[str, Any]) -> list[str]:
"""Extract reasoning_content as a clean list of non-empty strings."""
reasoning_content = outputs.get("reasoning_content")
if isinstance(reasoning_content, str):
trimmed = reasoning_content.strip()
return [trimmed] if trimmed else []
if isinstance(reasoning_content, list):
return [item.strip() for item in reasoning_content if isinstance(item, str) and item.strip()]
return []
def _extract_tool_calls(self, agent_log: Any) -> list[dict[str, str]]:
"""Extract tool call records from agent logs."""
if not agent_log or not isinstance(agent_log, list):
return []
tool_calls: list[dict[str, str]] = []
for log in agent_log:
log_data = log.data if hasattr(log, "data") else (log.get("data", {}) if isinstance(log, dict) else {})
tool_name = log_data.get("tool_name")
if tool_name and str(tool_name).strip():
tool_calls.append(
{
"id": log_data.get("tool_call_id", ""),
"name": tool_name,
"arguments": json.dumps(log_data.get("tool_args", {})),
"result": str(log_data.get("output", "")),
}
)
return tool_calls
def _build_generation_sequence(
self, text: str, reasoning_list: list[str], tool_calls_list: list[dict[str, str]]
) -> list[dict[str, Any]]:
"""Build a simple content/reasoning/tool_call sequence."""
sequence: list[dict[str, Any]] = []
if text:
sequence.append({"type": "content", "start": 0, "end": len(text)})
for index in range(len(reasoning_list)):
sequence.append({"type": "reasoning", "index": index})
for index in range(len(tool_calls_list)):
sequence.append({"type": "tool_call", "index": index})
return sequence
def _upsert_generation_detail(
self,
session,
execution: WorkflowNodeExecution,
reasoning_list: list[str],
tool_calls_list: list[dict[str, str]],
sequence: list[dict[str, Any]],
) -> None:
"""Insert or update LLMGenerationDetail with serialized fields."""
existing = (
session.query(LLMGenerationDetail)
.filter_by(
workflow_run_id=execution.workflow_execution_id,
node_id=execution.node_id,
)
.first()
)
reasoning_json = json.dumps(reasoning_list) if reasoning_list else None
tool_calls_json = json.dumps(tool_calls_list) if tool_calls_list else None
sequence_json = json.dumps(sequence) if sequence else None
if existing:
existing.reasoning_content = reasoning_json
existing.tool_calls = tool_calls_json
existing.sequence = sequence_json
return
generation_detail = LLMGenerationDetail(
tenant_id=self._tenant_id,
app_id=self._app_id,
workflow_run_id=execution.workflow_execution_id,
node_id=execution.node_id,
reasoning_content=reasoning_json,
tool_calls=tool_calls_json,
sequence=sequence_json,
)
session.add(generation_detail)
def get_db_models_by_workflow_run(
self,
workflow_run_id: str,

View File

@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from models.model import File
from core.model_runtime.entities.message_entities import PromptMessageTool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.tool_entities import (
ToolEntity,
@ -154,6 +155,60 @@ class Tool(ABC):
return parameters
def to_prompt_message_tool(self) -> PromptMessageTool:
message_tool = PromptMessageTool(
name=self.entity.identity.name,
description=self.entity.description.llm if self.entity.description else "",
parameters={
"type": "object",
"properties": {},
"required": [],
},
)
parameters = self.get_merged_runtime_parameters()
for parameter in parameters:
if parameter.form != ToolParameter.ToolParameterForm.LLM:
continue
parameter_type = parameter.type.as_normal_type()
if parameter.type in {
ToolParameter.ToolParameterType.SYSTEM_FILES,
ToolParameter.ToolParameterType.FILE,
ToolParameter.ToolParameterType.FILES,
}:
# Determine the description based on parameter type
if parameter.type == ToolParameter.ToolParameterType.FILE:
file_format_desc = " Input the file id with format: [File: file_id]."
else:
file_format_desc = "Input the file id with format: [Files: file_id1, file_id2, ...]. "
message_tool.parameters["properties"][parameter.name] = {
"type": "string",
"description": (parameter.llm_description or "") + file_format_desc,
}
continue
enum = []
if parameter.type == ToolParameter.ToolParameterType.SELECT:
enum = [option.value for option in parameter.options] if parameter.options else []
message_tool.parameters["properties"][parameter.name] = (
{
"type": parameter_type,
"description": parameter.llm_description or "",
}
if parameter.input_schema is None
else parameter.input_schema
)
if len(enum) > 0:
message_tool.parameters["properties"][parameter.name]["enum"] = enum
if parameter.required:
message_tool.parameters["required"].append(parameter.name)
return message_tool
def create_image_message(
self,
image: str,

View File

@ -1,11 +1,16 @@
from .agent import AgentNodeStrategyInit
from .graph_init_params import GraphInitParams
from .tool_entities import ToolCall, ToolCallResult, ToolResult, ToolResultStatus
from .workflow_execution import WorkflowExecution
from .workflow_node_execution import WorkflowNodeExecution
__all__ = [
"AgentNodeStrategyInit",
"GraphInitParams",
"ToolCall",
"ToolCallResult",
"ToolResult",
"ToolResultStatus",
"WorkflowExecution",
"WorkflowNodeExecution",
]

View File

@ -0,0 +1,39 @@
from enum import StrEnum
from pydantic import BaseModel, Field
from core.file import File
class ToolResultStatus(StrEnum):
SUCCESS = "success"
ERROR = "error"
class ToolCall(BaseModel):
id: str | None = Field(default=None, description="Unique identifier for this tool call")
name: str | None = Field(default=None, description="Name of the tool being called")
arguments: str | None = Field(default=None, description="Accumulated tool arguments JSON")
icon: str | dict | None = Field(default=None, description="Icon of the tool")
icon_dark: str | dict | None = Field(default=None, description="Dark theme icon of the tool")
class ToolResult(BaseModel):
id: str | None = Field(default=None, description="Identifier of the tool call this result belongs to")
name: str | None = Field(default=None, description="Name of the tool")
output: str | None = Field(default=None, description="Tool output text, error or success message")
files: list[str] = Field(default_factory=list, description="File produced by tool")
status: ToolResultStatus | None = Field(default=ToolResultStatus.SUCCESS, description="Tool execution status")
elapsed_time: float | None = Field(default=None, description="Elapsed seconds spent executing the tool")
icon: str | dict | None = Field(default=None, description="Icon of the tool")
icon_dark: str | dict | None = Field(default=None, description="Dark theme icon of the tool")
class ToolCallResult(BaseModel):
id: str | None = Field(default=None, description="Identifier for the tool call")
name: str | None = Field(default=None, description="Name of the tool")
arguments: str | None = Field(default=None, description="Accumulated tool arguments JSON")
output: str | None = Field(default=None, description="Tool output text, error or success message")
files: list[File] = Field(default_factory=list, description="File produced by tool")
status: ToolResultStatus = Field(default=ToolResultStatus.SUCCESS, description="Tool execution status")
elapsed_time: float | None = Field(default=None, description="Elapsed seconds spent executing the tool")

View File

@ -211,6 +211,10 @@ class WorkflowExecutionStatus(StrEnum):
def is_ended(self) -> bool:
return self in _END_STATE
@classmethod
def ended_values(cls) -> list[str]:
return [status.value for status in _END_STATE]
_END_STATE = frozenset(
[
@ -247,6 +251,8 @@ class WorkflowNodeExecutionMetadataKey(StrEnum):
ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field
LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output
DATASOURCE_INFO = "datasource_info"
LLM_CONTENT_SEQUENCE = "llm_content_sequence"
LLM_TRACE = "llm_trace"
COMPLETED_REASON = "completed_reason" # completed reason for loop node

View File

@ -16,7 +16,13 @@ from pydantic import BaseModel, Field
from core.workflow.enums import NodeExecutionType, NodeState
from core.workflow.graph import Graph
from core.workflow.graph_events import NodeRunStreamChunkEvent, NodeRunSucceededEvent
from core.workflow.graph_events import (
ChunkType,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
ToolCall,
ToolResult,
)
from core.workflow.nodes.base.template import TextSegment, VariableSegment
from core.workflow.runtime import VariablePool
@ -321,11 +327,24 @@ class ResponseStreamCoordinator:
selector: Sequence[str],
chunk: str,
is_final: bool = False,
chunk_type: ChunkType = ChunkType.TEXT,
tool_call: ToolCall | None = None,
tool_result: ToolResult | None = None,
) -> NodeRunStreamChunkEvent:
"""Create a stream chunk event with consistent structure.
For selectors with special prefixes (sys, env, conversation), we use the
active response node's information since these are not actual node IDs.
Args:
node_id: The node ID to attribute the event to
execution_id: The execution ID for this node
selector: The variable selector
chunk: The chunk content
is_final: Whether this is the final chunk
chunk_type: The semantic type of the chunk being streamed
tool_call: Structured data for tool_call chunks
tool_result: Structured data for tool_result chunks
"""
# Check if this is a special selector that doesn't correspond to a node
if selector and selector[0] not in self._graph.nodes and self._active_session:
@ -338,6 +357,9 @@ class ResponseStreamCoordinator:
selector=selector,
chunk=chunk,
is_final=is_final,
chunk_type=chunk_type,
tool_call=tool_call,
tool_result=tool_result,
)
# Standard case: selector refers to an actual node
@ -349,6 +371,9 @@ class ResponseStreamCoordinator:
selector=selector,
chunk=chunk,
is_final=is_final,
chunk_type=chunk_type,
tool_call=tool_call,
tool_result=tool_result,
)
def _process_variable_segment(self, segment: VariableSegment) -> tuple[Sequence[NodeRunStreamChunkEvent], bool]:
@ -356,6 +381,8 @@ class ResponseStreamCoordinator:
Handles both regular node selectors and special system selectors (sys, env, conversation).
For special selectors, we attribute the output to the active response node.
For object-type variables, automatically streams all child fields that have stream events.
"""
events: list[NodeRunStreamChunkEvent] = []
source_selector_prefix = segment.selector[0] if segment.selector else ""
@ -364,60 +391,81 @@ class ResponseStreamCoordinator:
# Determine which node to attribute the output to
# For special selectors (sys, env, conversation), use the active response node
# For regular selectors, use the source node
if self._active_session and source_selector_prefix not in self._graph.nodes:
# Special selector - use active response node
output_node_id = self._active_session.node_id
else:
# Regular node selector
output_node_id = source_selector_prefix
active_session = self._active_session
special_selector = bool(active_session and source_selector_prefix not in self._graph.nodes)
output_node_id = active_session.node_id if special_selector and active_session else source_selector_prefix
execution_id = self._get_or_create_execution_id(output_node_id)
# Stream all available chunks
while self._has_unread_stream(segment.selector):
if event := self._pop_stream_chunk(segment.selector):
# For special selectors, we need to update the event to use
# the active response node's information
if self._active_session and source_selector_prefix not in self._graph.nodes:
response_node = self._graph.nodes[self._active_session.node_id]
# Create a new event with the response node's information
# but keep the original selector
updated_event = NodeRunStreamChunkEvent(
id=execution_id,
node_id=response_node.id,
node_type=response_node.node_type,
selector=event.selector, # Keep original selector
chunk=event.chunk,
is_final=event.is_final,
)
events.append(updated_event)
else:
# Regular node selector - use event as is
events.append(event)
# Check if there's a direct stream for this selector
has_direct_stream = (
tuple(segment.selector) in self._stream_buffers or tuple(segment.selector) in self._closed_streams
)
# Check if this is the last chunk by looking ahead
stream_closed = self._is_stream_closed(segment.selector)
# Check if stream is closed to determine if segment is complete
if stream_closed:
is_complete = True
stream_targets = [segment.selector] if has_direct_stream else sorted(self._find_child_streams(segment.selector))
elif value := self._variable_pool.get(segment.selector):
# Process scalar value
is_last_segment = bool(
self._active_session and self._active_session.index == len(self._active_session.template.segments) - 1
)
events.append(
self._create_stream_chunk_event(
node_id=output_node_id,
execution_id=execution_id,
selector=segment.selector,
chunk=value.markdown,
is_final=is_last_segment,
if stream_targets:
all_complete = True
for target_selector in stream_targets:
while self._has_unread_stream(target_selector):
if event := self._pop_stream_chunk(target_selector):
events.append(
self._rewrite_stream_event(
event=event,
output_node_id=output_node_id,
execution_id=execution_id,
special_selector=bool(special_selector),
)
)
if not self._is_stream_closed(target_selector):
all_complete = False
is_complete = all_complete
# Fallback: check if scalar value exists in variable pool
if not is_complete and not has_direct_stream:
if value := self._variable_pool.get(segment.selector):
# Process scalar value
is_last_segment = bool(
self._active_session
and self._active_session.index == len(self._active_session.template.segments) - 1
)
)
is_complete = True
events.append(
self._create_stream_chunk_event(
node_id=output_node_id,
execution_id=execution_id,
selector=segment.selector,
chunk=value.markdown,
is_final=is_last_segment,
)
)
is_complete = True
return events, is_complete
def _rewrite_stream_event(
self,
event: NodeRunStreamChunkEvent,
output_node_id: str,
execution_id: str,
special_selector: bool,
) -> NodeRunStreamChunkEvent:
"""Rewrite event to attribute to active response node when selector is special."""
if not special_selector:
return event
return self._create_stream_chunk_event(
node_id=output_node_id,
execution_id=execution_id,
selector=event.selector,
chunk=event.chunk,
is_final=event.is_final,
chunk_type=event.chunk_type,
tool_call=event.tool_call,
tool_result=event.tool_result,
)
def _process_text_segment(self, segment: TextSegment) -> Sequence[NodeRunStreamChunkEvent]:
"""Process a text segment. Returns (events, is_complete)."""
assert self._active_session is not None
@ -513,6 +561,36 @@ class ResponseStreamCoordinator:
# ============= Internal Stream Management Methods =============
def _find_child_streams(self, parent_selector: Sequence[str]) -> list[tuple[str, ...]]:
"""Find all child stream selectors that are descendants of the parent selector.
For example, if parent_selector is ['llm', 'generation'], this will find:
- ['llm', 'generation', 'content']
- ['llm', 'generation', 'tool_calls']
- ['llm', 'generation', 'tool_results']
- ['llm', 'generation', 'thought']
Args:
parent_selector: The parent selector to search for children
Returns:
List of child selector tuples found in stream buffers or closed streams
"""
parent_key = tuple(parent_selector)
parent_len = len(parent_key)
child_streams: set[tuple[str, ...]] = set()
# Search in both active buffers and closed streams
all_selectors = set(self._stream_buffers.keys()) | self._closed_streams
for selector_key in all_selectors:
# Check if this selector is a direct child of the parent
# Direct child means: len(child) == len(parent) + 1 and child starts with parent
if len(selector_key) == parent_len + 1 and selector_key[:parent_len] == parent_key:
child_streams.add(selector_key)
return sorted(child_streams)
def _append_stream_chunk(self, selector: Sequence[str], event: NodeRunStreamChunkEvent) -> None:
"""
Append a stream chunk to the internal buffer.

View File

@ -36,6 +36,7 @@ from .loop import (
# Node events
from .node import (
ChunkType,
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunPauseRequestedEvent,
@ -44,10 +45,13 @@ from .node import (
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
ToolCall,
ToolResult,
)
__all__ = [
"BaseGraphEvent",
"ChunkType",
"GraphEngineEvent",
"GraphNodeEventBase",
"GraphRunAbortedEvent",
@ -73,4 +77,6 @@ __all__ = [
"NodeRunStartedEvent",
"NodeRunStreamChunkEvent",
"NodeRunSucceededEvent",
"ToolCall",
"ToolResult",
]

View File

@ -1,10 +1,11 @@
from collections.abc import Sequence
from datetime import datetime
from enum import StrEnum
from pydantic import Field
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities import AgentNodeStrategyInit
from core.workflow.entities import AgentNodeStrategyInit, ToolCall, ToolResult
from core.workflow.entities.pause_reason import PauseReason
from .base import GraphNodeEventBase
@ -21,13 +22,39 @@ class NodeRunStartedEvent(GraphNodeEventBase):
provider_id: str = ""
class ChunkType(StrEnum):
"""Stream chunk type for LLM-related events."""
TEXT = "text" # Normal text streaming
TOOL_CALL = "tool_call" # Tool call arguments streaming
TOOL_RESULT = "tool_result" # Tool execution result
THOUGHT = "thought" # Agent thinking process (ReAct)
THOUGHT_START = "thought_start" # Agent thought start
THOUGHT_END = "thought_end" # Agent thought end
class NodeRunStreamChunkEvent(GraphNodeEventBase):
# Spec-compliant fields
"""Stream chunk event for workflow node execution."""
# Base fields
selector: Sequence[str] = Field(
..., description="selector identifying the output location (e.g., ['nodeA', 'text'])"
)
chunk: str = Field(..., description="the actual chunk content")
is_final: bool = Field(default=False, description="indicates if this is the last chunk")
chunk_type: ChunkType = Field(default=ChunkType.TEXT, description="type of the chunk")
# Tool call fields (when chunk_type == TOOL_CALL)
tool_call: ToolCall | None = Field(
default=None,
description="structured payload for tool_call chunks",
)
# Tool result fields (when chunk_type == TOOL_RESULT)
tool_result: ToolResult | None = Field(
default=None,
description="structured payload for tool_result chunks",
)
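A hedged construction sketch for the extended event; the field set mirrors the dispatch code later in this diff, while the import paths and the NodeType value are assumptions:

from core.workflow.graph_events import ChunkType, NodeRunStreamChunkEvent
from core.workflow.nodes import NodeType

event = NodeRunStreamChunkEvent(
    id="node-exec-1",              # node execution id, as passed by the dispatch handlers
    node_id="llm_1",
    node_type=NodeType.LLM,
    selector=["llm_1", "text"],
    chunk="Hello, ",
    is_final=False,
    chunk_type=ChunkType.TEXT,     # tool_call / tool_result stay None for plain text chunks
)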
class NodeRunRetrieverResourceEvent(GraphNodeEventBase):

View File

@ -13,16 +13,21 @@ from .loop import (
LoopSucceededEvent,
)
from .node import (
ChunkType,
ModelInvokeCompletedEvent,
PauseRequestedEvent,
RunRetrieverResourceEvent,
RunRetryEvent,
StreamChunkEvent,
StreamCompletedEvent,
ThoughtChunkEvent,
ToolCallChunkEvent,
ToolResultChunkEvent,
)
__all__ = [
"AgentLogEvent",
"ChunkType",
"IterationFailedEvent",
"IterationNextEvent",
"IterationStartedEvent",
@ -39,4 +44,7 @@ __all__ = [
"RunRetryEvent",
"StreamChunkEvent",
"StreamCompletedEvent",
"ThoughtChunkEvent",
"ToolCallChunkEvent",
"ToolResultChunkEvent",
]

View File

@ -1,11 +1,13 @@
from collections.abc import Sequence
from datetime import datetime
from enum import StrEnum
from pydantic import Field
from core.file import File
from core.model_runtime.entities.llm_entities import LLMUsage
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities import ToolCall, ToolResult
from core.workflow.entities.pause_reason import PauseReason
from core.workflow.node_events import NodeRunResult
@ -32,13 +34,60 @@ class RunRetryEvent(NodeEventBase):
start_at: datetime = Field(..., description="Retry start time")
class ChunkType(StrEnum):
"""Stream chunk type for LLM-related events."""
TEXT = "text" # Normal text streaming
TOOL_CALL = "tool_call" # Tool call arguments streaming
TOOL_RESULT = "tool_result" # Tool execution result
THOUGHT = "thought" # Agent thinking process (ReAct)
THOUGHT_START = "thought_start" # Agent thought start
THOUGHT_END = "thought_end" # Agent thought end
class StreamChunkEvent(NodeEventBase):
# Spec-compliant fields
"""Base stream chunk event - normal text streaming output."""
selector: Sequence[str] = Field(
..., description="selector identifying the output location (e.g., ['nodeA', 'text'])"
)
chunk: str = Field(..., description="the actual chunk content")
is_final: bool = Field(default=False, description="indicates if this is the last chunk")
chunk_type: ChunkType = Field(default=ChunkType.TEXT, description="type of the chunk")
tool_call: ToolCall | None = Field(default=None, description="structured payload for tool_call chunks")
tool_result: ToolResult | None = Field(default=None, description="structured payload for tool_result chunks")
class ToolCallChunkEvent(StreamChunkEvent):
"""Tool call streaming event - tool call arguments streaming output."""
chunk_type: ChunkType = Field(default=ChunkType.TOOL_CALL, frozen=True)
tool_call: ToolCall | None = Field(default=None, description="structured tool call payload")
class ToolResultChunkEvent(StreamChunkEvent):
"""Tool result event - tool execution result."""
chunk_type: ChunkType = Field(default=ChunkType.TOOL_RESULT, frozen=True)
tool_result: ToolResult | None = Field(default=None, description="structured tool result payload")
class ThoughtStartChunkEvent(StreamChunkEvent):
"""Agent thought start streaming event - Agent thinking process (ReAct)."""
chunk_type: ChunkType = Field(default=ChunkType.THOUGHT_START, frozen=True)
class ThoughtEndChunkEvent(StreamChunkEvent):
"""Agent thought end streaming event - Agent thinking process (ReAct)."""
chunk_type: ChunkType = Field(default=ChunkType.THOUGHT_END, frozen=True)
class ThoughtChunkEvent(StreamChunkEvent):
"""Agent thought streaming event - Agent thinking process (ReAct)."""
chunk_type: ChunkType = Field(default=ChunkType.THOUGHT, frozen=True)
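A hedged sketch of emitting the new node-level events (assumes NodeEventBase itself requires no additional fields; selector names follow the examples elsewhere in this diff):

from core.workflow.entities import ToolResult
from core.workflow.node_events import ThoughtChunkEvent, ToolResultChunkEvent

thought = ThoughtChunkEvent(
    selector=["llm_1", "generation", "thought"],
    chunk="Looking up the forecast...",
)
result = ToolResultChunkEvent(
    selector=["llm_1", "generation", "tool_results"],
    chunk="22°C and sunny",
    is_final=True,
    tool_result=ToolResult(),  # empty payload; the dispatch code fills status/files defaults
)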
class StreamCompletedEvent(NodeEventBase):

View File

@ -48,6 +48,9 @@ from core.workflow.node_events import (
RunRetrieverResourceEvent,
StreamChunkEvent,
StreamCompletedEvent,
ThoughtChunkEvent,
ToolCallChunkEvent,
ToolResultChunkEvent,
)
from core.workflow.runtime import GraphRuntimeState
from libs.datetime_utils import naive_utc_now
@ -564,6 +567,8 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: StreamChunkEvent) -> NodeRunStreamChunkEvent:
from core.workflow.graph_events import ChunkType
return NodeRunStreamChunkEvent(
id=self.execution_id,
node_id=self._node_id,
@ -571,6 +576,60 @@ class Node(Generic[NodeDataT]):
selector=event.selector,
chunk=event.chunk,
is_final=event.is_final,
chunk_type=ChunkType(event.chunk_type.value),
tool_call=event.tool_call,
tool_result=event.tool_result,
)
@_dispatch.register
def _(self, event: ToolCallChunkEvent) -> NodeRunStreamChunkEvent:
from core.workflow.graph_events import ChunkType
return NodeRunStreamChunkEvent(
id=self._node_execution_id,
node_id=self._node_id,
node_type=self.node_type,
selector=event.selector,
chunk=event.chunk,
is_final=event.is_final,
chunk_type=ChunkType.TOOL_CALL,
tool_call=event.tool_call,
)
@_dispatch.register
def _(self, event: ToolResultChunkEvent) -> NodeRunStreamChunkEvent:
from core.workflow.entities import ToolResult, ToolResultStatus
from core.workflow.graph_events import ChunkType
tool_result = event.tool_result or ToolResult()
status: ToolResultStatus = tool_result.status or ToolResultStatus.SUCCESS
tool_result = tool_result.model_copy(
update={"status": status, "files": tool_result.files or []},
)
return NodeRunStreamChunkEvent(
id=self._node_execution_id,
node_id=self._node_id,
node_type=self.node_type,
selector=event.selector,
chunk=event.chunk,
is_final=event.is_final,
chunk_type=ChunkType.TOOL_RESULT,
tool_result=tool_result,
)
@_dispatch.register
def _(self, event: ThoughtChunkEvent) -> NodeRunStreamChunkEvent:
from core.workflow.graph_events import ChunkType
return NodeRunStreamChunkEvent(
id=self._node_execution_id,
node_id=self._node_id,
node_type=self.node_type,
selector=event.selector,
chunk=event.chunk,
is_final=event.is_final,
chunk_type=ChunkType.THOUGHT,
)
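The `@_dispatch.register` handlers above follow a dispatch-by-event-type pattern; a minimal self-contained illustration using functools.singledispatchmethod (names here are illustrative, not Dify's):

from functools import singledispatchmethod

class Converter:
    @singledispatchmethod
    def dispatch(self, event) -> str:
        raise NotImplementedError(type(event))

    @dispatch.register
    def _(self, event: int) -> str:
        # Handler selected when the event argument is an int
        return f"int chunk: {event}"

    @dispatch.register
    def _(self, event: str) -> str:
        # Handler selected when the event argument is a str
        return f"text chunk: {event}"

print(Converter().dispatch(42))    # int chunk: 42
print(Converter().dispatch("hi"))  # text chunk: hi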
@_dispatch.register

View File

@ -62,6 +62,21 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
inputs = {"variable_selector": variable_selector}
process_data = {"documents": value if isinstance(value, list) else [value]}
# Ensure storage_key is loaded for File objects
files_to_check = value if isinstance(value, list) else [value]
files_needing_storage_key = [
f for f in files_to_check
if isinstance(f, File) and not f.storage_key and f.related_id
]
if files_needing_storage_key:
from factories.file_factory import StorageKeyLoader
from extensions.ext_database import db
from sqlalchemy.orm import Session
with Session(bind=db.engine) as session:
storage_key_loader = StorageKeyLoader(session, tenant_id=self.tenant_id)
storage_key_loader.load_storage_keys(files_needing_storage_key)
try:
if isinstance(value, list):
extracted_text_list = list(map(_extract_text_from_file, value))
@ -415,6 +430,15 @@ def _download_file_content(file: File) -> bytes:
response.raise_for_status()
return response.content
else:
# Check if storage_key is set
if not file.storage_key:
raise FileDownloadError(f"File storage_key is missing for file: {file.filename}")
# Check if file exists before downloading
from extensions.ext_storage import storage
if not storage.exists(file.storage_key):
raise FileDownloadError(f"File not found in storage: {file.storage_key}")
return file_manager.download(file)
except Exception as e:
raise FileDownloadError(f"Error downloading file: {str(e)}") from e

View File

@ -158,3 +158,5 @@ class KnowledgeIndexNodeData(BaseNodeData):
type: str = "knowledge-index"
chunk_structure: str
index_chunk_variable_selector: list[str]
indexing_technique: str | None = None
summary_index_setting: dict | None = None
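A hedged example of a node data payload using the two new optional fields (values are illustrative; the setting keys mirror dataset_summary_index_fields later in this diff, and the title field is assumed to come from BaseNodeData):

node_data = KnowledgeIndexNodeData(
    title="Knowledge Index",
    chunk_structure="parent_child_index",
    index_chunk_variable_selector=["extractor", "chunks"],
    indexing_technique="high_quality",
    summary_index_setting={
        "enable": True,
        "model_name": "gpt-4o-mini",
        "model_provider_name": "openai",
        "summary_prompt": "Summarize this chunk in one sentence.",
    },
)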

View File

@ -1,9 +1,11 @@
import concurrent.futures
import datetime
import logging
import time
from collections.abc import Mapping
from typing import Any
from flask import current_app
from sqlalchemy import func, select
from core.app.entities.app_invoke_entities import InvokeFrom
@ -16,7 +18,9 @@ from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.template import Template
from core.workflow.runtime import VariablePool
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
from models.dataset import Dataset, Document, DocumentSegment, DocumentSegmentSummary
from services.summary_index_service import SummaryIndexService
from tasks.generate_summary_index_task import generate_summary_index_task
from .entities import KnowledgeIndexNodeData
from .exc import (
@ -67,7 +71,18 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
# index knowledge
try:
if is_preview:
outputs = self._get_preview_output(node_data.chunk_structure, chunks)
# Preview mode: generate summaries for the chunks on the fly without saving them to the database.
# Use indexing_technique and summary_index_setting from node_data (workflow graph config),
# falling back to the dataset settings when they are not set on the node.
indexing_technique = node_data.indexing_technique or dataset.indexing_technique
summary_index_setting = node_data.summary_index_setting or dataset.summary_index_setting
outputs = self._get_preview_output_with_summaries(
node_data.chunk_structure, chunks, dataset=dataset,
indexing_technique=indexing_technique,
summary_index_setting=summary_index_setting
)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=variables,
@ -163,6 +178,9 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
db.session.commit()
# Generate summary index if enabled
self._handle_summary_index_generation(dataset, document, variable_pool)
return {
"dataset_id": ds_id_value,
"dataset_name": dataset_name_value,
@ -173,9 +191,269 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
"display_status": "completed",
}
def _get_preview_output(self, chunk_structure: str, chunks: Any) -> Mapping[str, Any]:
def _handle_summary_index_generation(
self,
dataset: Dataset,
document: Document,
variable_pool: VariablePool,
) -> None:
"""
Handle summary index generation based on mode (debug/preview or production).
Args:
dataset: Dataset containing the document
document: Document to generate summaries for
variable_pool: Variable pool to check invoke_from
"""
# Only generate summary index for high_quality indexing technique
if dataset.indexing_technique != "high_quality":
return
# Check if summary index is enabled
summary_index_setting = dataset.summary_index_setting
if not summary_index_setting or not summary_index_setting.get("enable"):
return
# Skip qa_model documents
if document.doc_form == "qa_model":
return
# Determine if in preview/debug mode
invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM])
is_preview = invoke_from and invoke_from.value == InvokeFrom.DEBUGGER
# Determine if only parent chunks should be processed
only_parent_chunks = dataset.chunk_structure == "parent_child_index"
if is_preview:
try:
# Query segments that need summary generation
query = db.session.query(DocumentSegment).filter_by(
dataset_id=dataset.id,
document_id=document.id,
status="completed",
enabled=True,
)
segments = query.all()
if not segments:
logger.info(f"No segments found for document {document.id}")
return
# Filter segments based on mode
segments_to_process = []
for segment in segments:
# Skip if summary already exists
existing_summary = (
db.session.query(DocumentSegmentSummary)
.filter_by(chunk_id=segment.id, dataset_id=dataset.id, status="completed")
.first()
)
if existing_summary:
continue
# For parent-child mode, all segments are parent chunks, so process all
segments_to_process.append(segment)
if not segments_to_process:
logger.info(f"No segments need summary generation for document {document.id}")
return
# Use ThreadPoolExecutor for concurrent generation
flask_app = current_app._get_current_object() # type: ignore
max_workers = min(10, len(segments_to_process)) # Limit to 10 workers
def process_segment(segment: DocumentSegment) -> None:
"""Process a single segment in a thread with Flask app context."""
with flask_app.app_context():
try:
SummaryIndexService.generate_and_vectorize_summary(
segment, dataset, summary_index_setting
)
except Exception as e:
logger.error(f"Failed to generate summary for segment {segment.id}: {str(e)}")
# Continue processing other segments
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = [
executor.submit(process_segment, segment) for segment in segments_to_process
]
# Wait for all tasks to complete
concurrent.futures.wait(futures, timeout=300)
logger.info(
f"Successfully generated summary index for {len(segments_to_process)} segments "
f"in document {document.id}"
)
except Exception as e:
logger.exception(f"Failed to generate summary index for document {document.id}: {str(e)}")
# Don't fail the entire indexing process if summary generation fails
else:
# Production mode: asynchronous generation
logger.info(f"Queuing summary index generation task for document {document.id} (production mode)")
try:
generate_summary_index_task.delay(dataset.id, document.id, None)
logger.info(f"Summary index generation task queued for document {document.id}")
except Exception as e:
logger.exception(f"Failed to queue summary index generation task for document {document.id}: {str(e)}")
# Don't fail the entire indexing process if task queuing fails
def _get_preview_output_with_summaries(
self, chunk_structure: str, chunks: Any, dataset: Dataset,
indexing_technique: str | None = None,
summary_index_setting: dict | None = None
) -> Mapping[str, Any]:
"""
Generate preview output with summaries for chunks in preview mode.
This method generates summaries on the fly without saving them to the database.
Args:
chunk_structure: Chunk structure type
chunks: Chunks to generate preview for
dataset: Dataset object (for tenant_id)
indexing_technique: Indexing technique from node config or dataset
summary_index_setting: Summary index setting from node config or dataset
"""
index_processor = IndexProcessorFactory(chunk_structure).init_index_processor()
return index_processor.format_preview(chunks)
preview_output = index_processor.format_preview(chunks)
# Check if summary index is enabled
if indexing_technique != "high_quality":
return preview_output
if not summary_index_setting or not summary_index_setting.get("enable"):
return preview_output
# Generate summaries for chunks
if "preview" in preview_output and isinstance(preview_output["preview"], list):
chunk_count = len(preview_output["preview"])
logger.info(
f"Generating summaries for {chunk_count} chunks in preview mode "
f"(dataset: {dataset.id})"
)
# Use ParagraphIndexProcessor's generate_summary method
from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor
# Get Flask app for application context in worker threads
flask_app = None
try:
flask_app = current_app._get_current_object() # type: ignore
except RuntimeError:
logger.warning("No Flask application context available, summary generation may fail")
def generate_summary_for_chunk(preview_item: dict) -> None:
"""Generate summary for a single chunk."""
if "content" in preview_item:
try:
# Set Flask application context in worker thread
if flask_app:
with flask_app.app_context():
summary = ParagraphIndexProcessor.generate_summary(
tenant_id=dataset.tenant_id,
text=preview_item["content"],
summary_index_setting=summary_index_setting,
)
if summary:
preview_item["summary"] = summary
else:
# Fallback: try without app context (may fail)
summary = ParagraphIndexProcessor.generate_summary(
tenant_id=dataset.tenant_id,
text=preview_item["content"],
summary_index_setting=summary_index_setting,
)
if summary:
preview_item["summary"] = summary
except Exception as e:
logger.error(f"Failed to generate summary for chunk: {str(e)}")
# Don't fail the entire preview if summary generation fails
# Generate summaries concurrently using ThreadPoolExecutor
# Set a reasonable timeout to prevent hanging (60 seconds per chunk, max 5 minutes total)
timeout_seconds = min(300, 60 * len(preview_output["preview"]))
with concurrent.futures.ThreadPoolExecutor(max_workers=min(10, len(preview_output["preview"]))) as executor:
futures = [
executor.submit(generate_summary_for_chunk, preview_item)
for preview_item in preview_output["preview"]
]
# Wait for all tasks to complete with timeout
done, not_done = concurrent.futures.wait(futures, timeout=timeout_seconds)
# Cancel tasks that didn't complete in time
if not_done:
logger.warning(
f"Summary generation timeout: {len(not_done)} chunks did not complete within {timeout_seconds}s. "
"Cancelling remaining tasks..."
)
for future in not_done:
future.cancel()
# Wait a bit for cancellation to take effect
concurrent.futures.wait(not_done, timeout=5)
completed_count = sum(1 for item in preview_output["preview"] if item.get("summary") is not None)
logger.info(
f"Completed summary generation for preview chunks: {completed_count}/{len(preview_output['preview'])} succeeded"
)
return preview_output
def _get_preview_output(
self, chunk_structure: str, chunks: Any, dataset: Dataset | None = None, variable_pool: VariablePool | None = None
) -> Mapping[str, Any]:
index_processor = IndexProcessorFactory(chunk_structure).init_index_processor()
preview_output = index_processor.format_preview(chunks)
# If dataset is provided, try to enrich preview with summaries
if dataset and variable_pool:
document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
if document_id:
document = db.session.query(Document).filter_by(id=document_id.value).first()
if document:
# Query summaries for this document
summaries = (
db.session.query(DocumentSegmentSummary)
.filter_by(
dataset_id=dataset.id,
document_id=document.id,
status="completed",
enabled=True,
)
.all()
)
if summaries:
# Create a map of segment content to summary for matching
# Use content matching as chunks in preview might not be indexed yet
summary_by_content = {}
for summary in summaries:
segment = (
db.session.query(DocumentSegment)
.filter_by(id=summary.chunk_id, dataset_id=dataset.id)
.first()
)
if segment:
# Normalize content for matching (strip whitespace)
normalized_content = segment.content.strip()
summary_by_content[normalized_content] = summary.summary_content
# Enrich preview with summaries by content matching
if "preview" in preview_output and isinstance(preview_output["preview"], list):
matched_count = 0
for preview_item in preview_output["preview"]:
if "content" in preview_item:
# Normalize content for matching
normalized_chunk_content = preview_item["content"].strip()
if normalized_chunk_content in summary_by_content:
preview_item["summary"] = summary_by_content[normalized_chunk_content]
matched_count += 1
if matched_count > 0:
logger.info(
f"Enriched preview with {matched_count} existing summaries "
f"(dataset: {dataset.id}, document: {document.id})"
)
return preview_output
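For reference, a hedged sketch of the preview payload shape these helpers produce and enrich (key names taken from the code above; text is illustrative):

preview_output = {
    "preview": [
        {
            "content": "Solar panels convert sunlight into electricity...",
            "summary": "Explains how solar panels generate electricity.",
        },
        # No "summary" key when generation failed, timed out, or no stored summary matched.
        {"content": "Inverters transform the DC output into AC..."},
    ]
}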
@classmethod
def version(cls) -> str:

View File

@ -268,7 +268,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
usage = self._merge_usage(usage, metadata_usage)
all_documents = []
dataset_retrieval = DatasetRetrieval()
creator_user_role = self.user_from.to_creator_user_role()
if str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE and query:
# fetch model config
if node_data.single_retrieval_config is None:
@ -293,7 +292,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
tenant_id=self.tenant_id,
user_id=self.user_id,
app_id=self.app_id,
creator_user_role=creator_user_role,
user_from=self.user_from.value,
query=query,
model_config=model_config,
model_instance=model_instance,
@ -335,7 +334,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
app_id=self.app_id,
tenant_id=self.tenant_id,
user_id=self.user_id,
creator_user_role=creator_user_role,
user_from=self.user_from.value,
available_datasets=available_datasets,
query=query,
top_k=node_data.multiple_retrieval_config.top_k,

View File

@ -3,6 +3,7 @@ from .entities import (
LLMNodeCompletionModelPromptTemplate,
LLMNodeData,
ModelConfig,
ToolMetadata,
VisionConfig,
)
from .node import LLMNode
@ -13,5 +14,6 @@ __all__ = [
"LLMNodeCompletionModelPromptTemplate",
"LLMNodeData",
"ModelConfig",
"ToolMetadata",
"VisionConfig",
]

View File

@ -1,10 +1,17 @@
import re
from collections.abc import Mapping, Sequence
from typing import Any, Literal
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, ConfigDict, Field, field_validator
from core.agent.entities import AgentLog, AgentResult
from core.file import File
from core.model_runtime.entities import ImagePromptMessageContent, LLMMode
from core.model_runtime.entities.llm_entities import LLMUsage
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.tools.entities.tool_entities import ToolProviderType
from core.workflow.entities import ToolCall, ToolCallResult
from core.workflow.node_events import AgentLogEvent
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.base.entities import VariableSelector
@ -58,6 +65,268 @@ class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
jinja2_text: str | None = None
class ToolMetadata(BaseModel):
"""
Tool metadata for LLM node with tool support.
Defines the essential fields needed for tool configuration,
particularly the 'type' field to identify tool provider type.
"""
# Core fields
enabled: bool = True
type: ToolProviderType = Field(..., description="Tool provider type: builtin, api, mcp, workflow")
provider_name: str = Field(..., description="Tool provider name/identifier")
tool_name: str = Field(..., description="Tool name")
# Optional fields
plugin_unique_identifier: str | None = Field(None, description="Plugin unique identifier for plugin tools")
credential_id: str | None = Field(None, description="Credential ID for tools requiring authentication")
# Configuration fields
parameters: dict[str, Any] = Field(default_factory=dict, description="Tool parameters")
settings: dict[str, Any] = Field(default_factory=dict, description="Tool settings configuration")
extra: dict[str, Any] = Field(default_factory=dict, description="Extra tool configuration like custom description")
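A hedged construction example (assumes ToolProviderType is a string-backed enum so the plain value coerces; provider and parameter values are illustrative):

weather_tool = ToolMetadata(
    type="builtin",
    provider_name="weather",
    tool_name="get_forecast",
    parameters={"units": "metric"},
    extra={"description": "Fetch a short-range forecast for a city."},
)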
class ModelTraceSegment(BaseModel):
"""Model invocation trace segment with token usage and output."""
text: str | None = Field(None, description="Model output text content")
reasoning: str | None = Field(None, description="Reasoning/thought content from model")
tool_calls: list[ToolCall] = Field(default_factory=list, description="Tool calls made by the model")
class ToolTraceSegment(BaseModel):
"""Tool invocation trace segment with call details and result."""
id: str | None = Field(default=None, description="Unique identifier for this tool call")
name: str | None = Field(default=None, description="Name of the tool being called")
arguments: str | None = Field(default=None, description="Accumulated tool arguments JSON")
output: str | None = Field(default=None, description="Tool call result")
class LLMTraceSegment(BaseModel):
"""
Streaming trace segment for LLM tool-enabled runs.
Represents alternating model and tool invocations in sequence:
model -> tool -> model -> tool -> ...
Each segment records its execution duration.
"""
type: Literal["model", "tool"]
duration: float = Field(..., description="Execution duration in seconds")
usage: LLMUsage | None = Field(default=None, description="Token usage statistics for this model call")
output: ModelTraceSegment | ToolTraceSegment = Field(..., description="Output of the segment")
# Common metadata for both model and tool segments
provider: str | None = Field(default=None, description="Model or tool provider identifier")
name: str | None = Field(default=None, description="Name of the model or tool")
icon: str | None = Field(default=None, description="Icon for the provider")
icon_dark: str | None = Field(default=None, description="Dark theme icon for the provider")
error: str | None = Field(default=None, description="Error message if segment failed")
status: Literal["success", "error"] | None = Field(default=None, description="Tool execution status")
class LLMGenerationData(BaseModel):
"""Generation data from LLM invocation with tools.
For multi-turn tool calls like: thought1 -> text1 -> tool_call1 -> thought2 -> text2 -> tool_call2
- reasoning_contents: [thought1, thought2, ...] - one element per turn
- tool_calls: [{id, name, arguments, result}, ...] - all tool calls with results
"""
text: str = Field(..., description="Accumulated text content from all turns")
reasoning_contents: list[str] = Field(default_factory=list, description="Reasoning content per turn")
tool_calls: list[ToolCallResult] = Field(default_factory=list, description="Tool calls with results")
sequence: list[dict[str, Any]] = Field(default_factory=list, description="Ordered segments for rendering")
usage: LLMUsage = Field(..., description="LLM usage statistics")
finish_reason: str | None = Field(None, description="Finish reason from LLM")
files: list[File] = Field(default_factory=list, description="Generated files")
trace: list[LLMTraceSegment] = Field(default_factory=list, description="Streaming trace in emitted order")
class ThinkTagStreamParser:
"""Lightweight state machine to split streaming chunks by <think> tags."""
_START_PATTERN = re.compile(r"<think(?:\s[^>]*)?>", re.IGNORECASE)
_END_PATTERN = re.compile(r"</think>", re.IGNORECASE)
_START_PREFIX = "<think"
_END_PREFIX = "</think"
def __init__(self):
self._buffer = ""
self._in_think = False
@staticmethod
def _suffix_prefix_len(text: str, prefix: str) -> int:
"""Return length of the longest suffix of `text` that is a prefix of `prefix`."""
max_len = min(len(text), len(prefix) - 1)
for i in range(max_len, 0, -1):
if text[-i:].lower() == prefix[:i].lower():
return i
return 0
def process(self, chunk: str) -> list[tuple[str, str]]:
"""
Split incoming chunk into ('thought' | 'text', content) tuples.
Content excludes the <think> tags themselves and handles split tags across chunks.
"""
parts: list[tuple[str, str]] = []
self._buffer += chunk
while self._buffer:
if self._in_think:
end_match = self._END_PATTERN.search(self._buffer)
if end_match:
thought_text = self._buffer[: end_match.start()]
if thought_text:
parts.append(("thought", thought_text))
parts.append(("thought_end", ""))
self._buffer = self._buffer[end_match.end() :]
self._in_think = False
continue
hold_len = self._suffix_prefix_len(self._buffer, self._END_PREFIX)
emit = self._buffer[: len(self._buffer) - hold_len]
if emit:
parts.append(("thought", emit))
self._buffer = self._buffer[-hold_len:] if hold_len > 0 else ""
break
start_match = self._START_PATTERN.search(self._buffer)
if start_match:
prefix = self._buffer[: start_match.start()]
if prefix:
parts.append(("text", prefix))
self._buffer = self._buffer[start_match.end() :]
parts.append(("thought_start", ""))
self._in_think = True
continue
hold_len = self._suffix_prefix_len(self._buffer, self._START_PREFIX)
emit = self._buffer[: len(self._buffer) - hold_len]
if emit:
parts.append(("text", emit))
self._buffer = self._buffer[-hold_len:] if hold_len > 0 else ""
break
cleaned_parts: list[tuple[str, str]] = []
for kind, content in parts:
# Extra safeguard: strip any stray tags that slipped through.
content = self._START_PATTERN.sub("", content)
content = self._END_PATTERN.sub("", content)
if content or kind in {"thought_start", "thought_end"}:
cleaned_parts.append((kind, content))
return cleaned_parts
def flush(self) -> list[tuple[str, str]]:
"""Flush remaining buffer when the stream ends."""
if not self._buffer:
return []
kind = "thought" if self._in_think else "text"
content = self._buffer
# Drop dangling partial tags instead of emitting them
if content.lower().startswith(self._START_PREFIX) or content.lower().startswith(self._END_PREFIX):
content = ""
self._buffer = ""
if not content and not self._in_think:
return []
# Strip any complete tags that might still be present.
content = self._START_PATTERN.sub("", content)
content = self._END_PATTERN.sub("", content)
result: list[tuple[str, str]] = []
if content:
result.append((kind, content))
if self._in_think:
result.append(("thought_end", ""))
self._in_think = False
return result
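A quick usage sketch of the parser above, with the opening tag split across two chunks (outputs follow the state machine as written; content is illustrative):

parser = ThinkTagStreamParser()
parser.process("<thi")
# -> []  (partial "<think" prefix is held in the buffer)
parser.process("nk>Let me reason.</think>Answer")
# -> [("thought_start", ""), ("thought", "Let me reason."), ("thought_end", ""), ("text", "Answer")]
parser.flush()
# -> []  (buffer already empty)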
class StreamBuffers(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
think_parser: ThinkTagStreamParser = Field(default_factory=ThinkTagStreamParser)
pending_thought: list[str] = Field(default_factory=list)
pending_content: list[str] = Field(default_factory=list)
pending_tool_calls: list[ToolCall] = Field(default_factory=list)
current_turn_reasoning: list[str] = Field(default_factory=list)
reasoning_per_turn: list[str] = Field(default_factory=list)
class TraceState(BaseModel):
trace_segments: list[LLMTraceSegment] = Field(default_factory=list)
tool_trace_map: dict[str, LLMTraceSegment] = Field(default_factory=dict)
tool_call_index_map: dict[str, int] = Field(default_factory=dict)
model_segment_start_time: float | None = Field(default=None, description="Start time for current model segment")
pending_usage: LLMUsage | None = Field(default=None, description="Pending usage for current model segment")
class AggregatedResult(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
text: str = ""
files: list[File] = Field(default_factory=list)
usage: LLMUsage = Field(default_factory=LLMUsage.empty_usage)
finish_reason: str | None = None
class AgentContext(BaseModel):
agent_logs: list[AgentLogEvent] = Field(default_factory=list)
agent_result: AgentResult | None = None
class ToolOutputState(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
stream: StreamBuffers = Field(default_factory=StreamBuffers)
trace: TraceState = Field(default_factory=TraceState)
aggregate: AggregatedResult = Field(default_factory=AggregatedResult)
agent: AgentContext = Field(default_factory=AgentContext)
class ToolLogPayload(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
tool_name: str = ""
tool_call_id: str = ""
tool_args: dict[str, Any] = Field(default_factory=dict)
tool_output: Any = None
tool_error: Any = None
files: list[Any] = Field(default_factory=list)
meta: dict[str, Any] = Field(default_factory=dict)
@classmethod
def from_log(cls, log: AgentLog) -> "ToolLogPayload":
data = log.data or {}
return cls(
tool_name=data.get("tool_name", ""),
tool_call_id=data.get("tool_call_id", ""),
tool_args=data.get("tool_args") or {},
tool_output=data.get("output"),
tool_error=data.get("error"),
files=data.get("files") or [],
meta=data.get("meta") or {},
)
@classmethod
def from_mapping(cls, data: Mapping[str, Any]) -> "ToolLogPayload":
return cls(
tool_name=data.get("tool_name", ""),
tool_call_id=data.get("tool_call_id", ""),
tool_args=data.get("tool_args") or {},
tool_output=data.get("output"),
tool_error=data.get("error"),
files=data.get("files") or [],
meta=data.get("meta") or {},
)
class LLMNodeData(BaseNodeData):
model: ModelConfig
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
@ -86,6 +355,10 @@ class LLMNodeData(BaseNodeData):
),
)
# Tool support
tools: Sequence[ToolMetadata] = Field(default_factory=list)
max_iterations: int | None = Field(default=None, description="Maximum number of iterations for the LLM node")
@field_validator("prompt_config", mode="before")
@classmethod
def convert_none_prompt_config(cls, v: Any):

File diff suppressed because it is too large

View File

@ -1,4 +1,3 @@
import json
from typing import Any
from jsonschema import Draft7Validator, ValidationError
@ -43,25 +42,22 @@ class StartNode(Node[StartNodeData]):
if value is None and variable.required:
raise ValueError(f"{key} is required in input form")
# If no value provided, skip further processing for this key
if not value:
continue
if not isinstance(value, dict):
raise ValueError(f"JSON object for '{key}' must be an object")
# Overwrite with normalized dict to ensure downstream consistency
node_inputs[key] = value
# If schema exists, then validate against it
schema = variable.json_schema
if not schema:
continue
if not value:
continue
try:
json_schema = json.loads(schema)
except json.JSONDecodeError as e:
raise ValueError(f"{schema} must be a valid JSON object")
try:
json_value = json.loads(value)
except json.JSONDecodeError as e:
raise ValueError(f"{value} must be a valid JSON object")
try:
Draft7Validator(json_schema).validate(json_value)
Draft7Validator(schema).validate(value)
except ValidationError as e:
raise ValueError(f"JSON object for '{key}' does not match schema: {e.message}")
node_inputs[key] = json_value

View File

@ -33,6 +33,15 @@ class VariableAssignerNode(Node[VariableAssignerData]):
graph_runtime_state=graph_runtime_state,
)
def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool:
"""
Check if this Variable Assigner node blocks the output of specific variables.
Returns True if this node updates any of the requested conversation variables.
"""
assigned_selector = tuple(self.node_data.assigned_variable_selector)
return assigned_selector in variable_selectors
@classmethod
def version(cls) -> str:
return "1"

View File

@ -19,6 +19,7 @@ from core.workflow.graph_engine.protocols.command_channel import CommandChannel
from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent
from core.workflow.nodes import NodeType
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
@ -136,13 +137,11 @@ class WorkflowEntry:
:param user_inputs: user inputs
:return:
"""
node_config = workflow.get_node_config_by_id(node_id)
node_config = dict(workflow.get_node_config_by_id(node_id))
node_config_data = node_config.get("data", {})
# Get node class
# Get node type
node_type = NodeType(node_config_data.get("type"))
node_version = node_config_data.get("version", "1")
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
# init graph init params and runtime state
graph_init_params = GraphInitParams(
@ -158,12 +157,12 @@ class WorkflowEntry:
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
# init workflow run state
node = node_cls(
id=str(uuid.uuid4()),
config=node_config,
node_factory = DifyNodeFactory(
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
node = node_factory.create_node(node_config)
node_cls = type(node)
try:
# variable selector to variable mapping

View File

@ -102,6 +102,8 @@ def init_app(app: DifyApp) -> Celery:
imports = [
"tasks.async_workflow_tasks", # trigger workers
"tasks.trigger_processing_tasks", # async trigger processing
"tasks.generate_summary_index_task", # summary index generation
"tasks.regenerate_summary_index_task", # summary index regeneration
]
day = dify_config.CELERY_BEAT_SCHEDULER_TIME
@ -163,6 +165,13 @@ def init_app(app: DifyApp) -> Celery:
"task": "schedule.clean_workflow_runlogs_precise.clean_workflow_runlogs_precise",
"schedule": crontab(minute="0", hour="2"),
}
if dify_config.ENABLE_WORKFLOW_RUN_CLEANUP_TASK:
# For SaaS deployments only
imports.append("schedule.clean_workflow_runs_task")
beat_schedule["clean_workflow_runs_task"] = {
"task": "schedule.clean_workflow_runs_task.clean_workflow_runs_task",
"schedule": crontab(minute="0", hour="0"),
}
if dify_config.ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK:
imports.append("schedule.workflow_schedule_task")
beat_schedule["workflow_schedule_task"] = {

View File

@ -4,6 +4,7 @@ from dify_app import DifyApp
def init_app(app: DifyApp):
from commands import (
add_qdrant_index,
clean_workflow_runs,
cleanup_orphaned_draft_variables,
clear_free_plan_tenant_expired_logs,
clear_orphaned_file_records,
@ -56,6 +57,7 @@ def init_app(app: DifyApp):
setup_datasource_oauth_client,
transform_datasource_credentials,
install_rag_pipeline_plugins,
clean_workflow_runs,
]
for cmd in cmds_to_register:
app.cli.add_command(cmd)

View File

@ -169,6 +169,7 @@ class MessageDetail(ResponseModel):
status: str
error: str | None = None
parent_message_id: str | None = None
generation_detail: JSONValue | None = Field(default=None, validation_alias="generation_detail_dict")
@field_validator("inputs", mode="before")
@classmethod

View File

@ -39,6 +39,14 @@ dataset_retrieval_model_fields = {
"score_threshold_enabled": fields.Boolean,
"score_threshold": fields.Float,
}
dataset_summary_index_fields = {
"enable": fields.Boolean,
"model_name": fields.String,
"model_provider_name": fields.String,
"summary_prompt": fields.String,
}
external_retrieval_model_fields = {
"top_k": fields.Integer,
"score_threshold": fields.Float,
@ -83,6 +91,7 @@ dataset_detail_fields = {
"embedding_model_provider": fields.String,
"embedding_available": fields.Boolean,
"retrieval_model_dict": fields.Nested(dataset_retrieval_model_fields),
"summary_index_setting": fields.Nested(dataset_summary_index_fields),
"tags": fields.List(fields.Nested(tag_fields)),
"doc_form": fields.String,
"external_knowledge_info": fields.Nested(external_knowledge_info_fields),

Some files were not shown because too many files have changed in this diff