Compare commits

..

148 Commits

Author SHA1 Message Date
yyh
92d44c0caa Run pnpm lint:fix 2025-12-17 17:01:21 +08:00
yyh
0478e4cde5 Refactor clear button implementation in type selector 2025-12-17 16:26:23 +08:00
yyh
0f816b52c2 test(web): add AppTypeSelector tests
- Add Jest/RTL coverage for AppTypeSelector, AppTypeLabel, AppTypeIcon\n- Improve clear control accessibility by using a button with aria-label
2025-12-17 16:20:14 +08:00
4fce99379e test(api): add a test for detect_file_encodings (#29778)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-12-17 14:33:30 +08:00
8d1e36540a fix: detect_file_encodings TypeError: tuple indices must be integers or slices, not str (#29595)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2025-12-17 13:58:05 +08:00
1d1351393a feat: update RAG recommended plugins hook to accept type parameter (#29735) 2025-12-17 13:48:23 +08:00
44f8915e30 feat: Add Aliyun SLS (Simple Log Service) integration for workflow execution logging (#28986)
Co-authored-by: hieheihei <270985384@qq.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
2025-12-17 13:43:54 +08:00
94a5fd3617 chore: tests for webapp run batch (#29767) 2025-12-17 13:36:50 +08:00
5bb1346da8 chore: tests form add annotation (#29770)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-12-17 13:36:40 +08:00
a93eecaeee feat: Add "type" field to PipelineRecommendedPlugin model; Add query param "type" to recommended-plugins api. (#29736)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Jyong <76649700+JohnJyong@users.noreply.github.com>
2025-12-17 11:26:08 +08:00
86131d4bd8 feat: add datasource_parameters handling for API requests (#29757)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-12-17 10:37:55 +08:00
581b62cf01 feat: add automated tests for pipeline setting (#29478)
Co-authored-by: CodingOnStar <hanxujiang@dify.ai>
2025-12-17 10:26:58 +08:00
yyh
91714ee413 chore(web): add some jest tests (#29754) 2025-12-17 10:21:32 +08:00
yyh
232149e63f chore: add tests for config string and dataset card item (#29743) 2025-12-17 10:19:10 +08:00
4a1ddea431 ci: show missing lines in coverage report summary (#29717) 2025-12-17 10:18:41 +08:00
5539bf8788 fix: add Slovenian and Tunisian Arabic translations across multiple language files (#29759) 2025-12-17 10:18:10 +08:00
dda7eb03c9 feat: _truncate_json_primitives support file (#29760) 2025-12-17 08:10:43 +09:00
c2f2be6b08 fix: oxlint no unused expressions (#29675)
Co-authored-by: daniel <daniel@example.com>
2025-12-16 18:00:04 +08:00
b7649f61f8 fix: Login secret text transmission (#29659)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
2025-12-16 16:55:51 +08:00
ae4a9040df Feat/update notion preview (#29345)
Co-authored-by: twwu <twwu@dify.ai>
2025-12-16 16:43:45 +08:00
d2b63df7a1 chore: tests for components in config (#29739) 2025-12-16 16:39:04 +08:00
0749e6e090 test: Stabilize sharded Redis broadcast multi-subscriber test (#29733)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-12-16 16:35:55 +08:00
yyh
4589157963 test: Add comprehensive Jest test for AppCard component (#29667) 2025-12-16 15:44:51 +08:00
37d4dbeb96 feat: Remove TLS 1.1 from default NGINX protocols (#29728) 2025-12-16 15:39:42 +08:00
yyh
c036a12999 test: add comprehensive unit tests for APIKeyInfoPanel component (#29719) 2025-12-16 15:07:30 +08:00
47cd94ec3e chore: tests for billings (#29720) 2025-12-16 15:06:53 +08:00
e5cf0d0bf6 chore: Disable Swagger UI by default in docker samples (#29723) 2025-12-16 15:01:51 +08:00
yyh
240e1d155a test: add comprehensive tests for CustomizeModal component (#29709) 2025-12-16 14:21:05 +08:00
a915b8a584 revert: "security/fix-swagger-info-leak-m02" (#29721) 2025-12-16 14:19:33 +08:00
yyh
4553e4c12f test: add comprehensive Jest tests for CustomPage and WorkflowOnboardingModal components (#29714) 2025-12-16 14:18:09 +08:00
7695f9151c chore: webhook with bin file should guess mimetype (#29704)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Maries <xh001x@hotmail.com>
2025-12-16 13:34:27 +08:00
bdccbb6e86 feat: add GraphEngine layer node execution hooks (#28583) 2025-12-16 13:26:31 +08:00
c904c58c43 test: add unit tests for DocumentPicker, PreviewDocumentPicker, and R… (#29695)
Co-authored-by: CodingOnStar <hanxujiang@dify.ai>
2025-12-16 13:06:50 +08:00
yyh
cb5162f37a test: add comprehensive Jest test for CreateAppTemplateDialog component (#29713) 2025-12-16 12:57:51 +08:00
yyh
eeb5129a17 refactor: create shared react-i18next mock to reduce duplication (#29711) 2025-12-16 12:45:17 +08:00
4cc6652424 feat: VECTOR_STORE supports seekdb (#29658) 2025-12-16 12:35:04 +09:00
yyh
a232da564a test: try to use Anthropic Skills to add tests for web/app/components/apps/ (#29607)
Signed-off-by: yyh <yuanyouhuilyz@gmail.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-12-16 10:42:34 +08:00
yyh
7fc501915e test: add comprehensive frontend tests for billing plan assets (#29615)
Signed-off-by: yyh <yuanyouhuilyz@gmail.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-12-15 21:21:34 +08:00
yyh
103a5e0122 test: enhance workflow-log component tests with comprehensive coverage (#29616)
Signed-off-by: yyh <yuanyouhuilyz@gmail.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-12-15 21:21:14 +08:00
23f75a1185 chore: some tests for configuration components (#29653)
Co-authored-by: yyh <92089059+lyzno1@users.noreply.github.com>
2025-12-15 21:18:58 +08:00
yyh
7fb68b62b8 test: enhance DebugWithMultipleModel component test coverage (#29657) 2025-12-15 21:17:44 +08:00
dd58d4a38d fix: update chat wrapper components to use min-h instead of h for better responsiveness (#29687) 2025-12-15 21:15:55 +08:00
4bf6c4dafa chore: add online drive metadata source enum (#29674) 2025-12-15 21:13:23 +08:00
187450b875 chore: skip upload_file_id is missing (#29666) 2025-12-15 21:09:53 +08:00
09982a1c95 fix: webhook node output file as file variable (#29621) 2025-12-15 19:55:59 +08:00
a8f3061b3c fix: all upload files are disabled if upload file feature disabled (#29681) 2025-12-15 18:02:34 +08:00
bd7b1fc6fb fix: csv injection in annotations export (#29462)
Co-authored-by: hj24 <huangjian@dify.ai>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-12-15 17:14:05 +08:00
2bf44057e9 fix: ssrf, add internal ip filter when parse tool schema (#29548)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Yeuoly <45712896+Yeuoly@users.noreply.github.com>
2025-12-15 16:28:25 +08:00
2b3c55d95a chore: some billing test (#29648)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Coding On Star <447357187@qq.com>
2025-12-15 16:13:14 +08:00
a951f46a09 chore: tests for annotation (#29664) 2025-12-15 15:38:04 +08:00
d942adf3b2 feat: Enhance Amplitude tracking across various components (#29662)
Co-authored-by: CodingOnStar <hanxujiang@dify.ai>
2025-12-15 15:25:10 +08:00
724cd57dbf fix: dos in annotation import (#29470)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-12-15 15:22:04 +08:00
714b443077 fix: correct i18n SSO translations and fix validation/type issues (#29564)
Signed-off-by: NeatGuyCoding <15627489+NeatGuyCoding@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-12-15 15:58:33 +09:00
80c74cf725 test: Consolidate API CI test runner (#29440)
Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-12-15 13:20:31 +08:00
Sai
1e47ffb50c fix: does not save segment vector when there is no attachment_ids (#29520)
Co-authored-by: sai <>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-12-15 12:32:52 +08:00
323e0c4d30 chore(i18n): translate i18n files and update type definitions (#29651)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2025-12-15 12:02:28 +08:00
094f417b32 refactor: admin api using session factory (#29628) 2025-12-15 12:01:41 +08:00
acef56d7fd fix: delete knowledge pipeline but pipeline and workflow don't delete (#29591)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-12-15 12:00:03 +08:00
8f3fd9a728 perf: commit once (#29590) 2025-12-15 11:40:26 +08:00
1a18012f98 fix(api): use json_repair for conversation title parsing (#29649) 2025-12-15 11:29:28 +08:00
355a2356d4 security/fix-swagger-info-leak-m02 (#29283)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2025-12-15 11:24:06 +08:00
7fead6a9da Add file upload enabled check and new i18n message (#28946) 2025-12-15 11:18:05 +08:00
63624dece1 fix(workflow): tool plugin output_schema array type not selectable in subsequent nodes (#29035)
Co-authored-by: Claude <noreply@anthropic.com>
2025-12-15 11:17:15 +08:00
02122907e5 fix(api): Populate Missing Attributes For Arize Phoenix Integration (#29526) 2025-12-15 11:15:18 +08:00
yyh
7ee7155fd5 test: add comprehensive Jest tests for ConfirmModal component (#29627)
Signed-off-by: yyh <yuanyouhuilyz@gmail.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-12-15 11:14:38 +08:00
yyh
d01f2f7436 chore: add AGENTS.md for frontend (#29647) 2025-12-15 10:49:39 +08:00
9d683fd34d chore(deps): bump @hookform/resolvers from 3.10.0 to 5.2.2 in /web (#29442)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2025-12-15 10:25:28 +08:00
569c593240 feat: Add InterSystems IRIS vector database support (#29480)
Co-authored-by: Tomo Okuyama <tomo.okuyama@intersystems.com>
2025-12-15 10:20:43 +08:00
bb157c93a3 fixes: #28300 Change the Citations banner in dark mode to fully opaque (#28673)
Co-authored-by: yyh <yuanyouhuilyz@gmail.com>
Co-authored-by: yyh <92089059+lyzno1@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2025-12-15 10:17:43 +08:00
b7bdd5920b fix: add secondary text color to plugin task headers (#29529) 2025-12-15 10:13:59 +08:00
470650d1d7 fix: fix delete_account_task not check billing enabled (#29577) 2025-12-15 10:12:35 +08:00
yyh
59137f1d05 fix: show uninstalled plugin nodes in workflow checklist (#29630) 2025-12-15 10:11:23 +08:00
yyh
b8d54d745e fix(ci): use setup-python to avoid 504 errors and use project oxlint config (#29613) 2025-12-15 10:10:51 +08:00
916df2d0f7 fix: fix mime type is none (#29579) 2025-12-15 10:09:11 +08:00
yyh
61199663e7 chore: add anthropic skills for frontend testing (#29608)
Signed-off-by: yyh <yuanyouhuilyz@gmail.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-12-15 10:05:48 +08:00
3db27c3158 fix(workflow): agent prompt editor canvas not covering full text height (#29623) 2025-12-14 15:53:39 +08:00
3653f54bea fix: validate page_size limit in plugin list and tasks endpoints (#29611) 2025-12-13 22:52:51 +09:00
886ce981cf feat(i18n): add Tunisian Arabic (ar-TN) translation (#29306)
Signed-off-by: yyh <yuanyouhuilyz@gmail.com>
Co-authored-by: yyh <yuanyouhuilyz@gmail.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-12-13 10:55:04 +08:00
f4c7f98a01 fix: remove unnecessary error log when trigger endpoint returns 404 (#29587)
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2025-12-12 17:53:40 +08:00
a8613f0233 fix: bump wandb to 0.23.1 urllib3 to 2.6.0 (#29481) 2025-12-12 16:55:19 +08:00
yyh
086ee4c19d test(web): add comprehensive tests for workflow-log component (#29562)
Co-authored-by: Coding On Star <447357187@qq.com>
2025-12-12 16:48:15 +08:00
336bcfbae2 chore: test for app card and no data (#29570)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-12-12 16:01:58 +08:00
e244856ef1 chore: add test case for download components (#29569)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-12-12 14:56:08 +08:00
2058186f22 chore: Bump version references to 1.11.1 (#29568) 2025-12-12 14:42:25 +08:00
bece2f101c fix: return None from retrieve_tokens when access_token is empty (#29516) 2025-12-12 13:49:11 +08:00
04d09c2d77 fix: hit-test failed when attachment id is not exist (#29563) 2025-12-12 13:45:00 +08:00
db42f467c8 fix: docx extractor external image failed (#29558) 2025-12-12 13:41:51 +08:00
ac40309850 perf: optimize save_document_with_dataset_id (#29550) 2025-12-12 13:14:45 +08:00
12e39365fa perf(core/rag): optimize Excel extractor performance and memory usage (#29551)
Co-authored-by: 01393547 <nieronghua@sf-express.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-12-12 12:15:03 +08:00
d48300d08c fix: remove validate=True to fix flask-restx AttributeError (#29552)
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2025-12-12 12:01:20 +08:00
761f8c8043 fix: set response content type with charset in helper (#29534) 2025-12-12 11:50:35 +08:00
05f63c88c6 feat: integrate Amplitude API key into layout and provider components (#29546)
Co-authored-by: CodingOnStar <hanxujiang@dify.ai>
2025-12-12 11:49:12 +08:00
8daf9ce98d test(trigger): add container integration tests for trigger (#29527)
Signed-off-by: Stream <Stream_2@qq.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-12-12 11:31:34 +08:00
61ee1b9094 fix: truncate auto-populated description to prevent 400-char limit error (#28681) 2025-12-12 11:19:53 +08:00
87c4b4c576 fix: nextjs security update (#29545) 2025-12-12 11:05:48 +08:00
193c8e2362 fix: fix available_credentials is empty (#29521) 2025-12-12 09:51:55 +08:00
4d57460356 fix: upgrade react and react-dom to 19.2.3,fix cve errors (#29532) 2025-12-12 09:38:32 +08:00
063b39ada5 fix: conversation rename payload validation (#29510) 2025-12-11 18:05:41 +09:00
6419ce02c7 test: add testcase for config prompt components (#29491)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-12-11 16:56:20 +08:00
yyh
1a877bb4d0 chore: add .nvmrc for Node 22 alignment (#29495) 2025-12-11 16:29:30 +08:00
281e9d4f51 fix: chat api in explore page reject blank conversation id (#29500) 2025-12-11 16:26:42 +08:00
a195b410d1 chore(i18n): translate i18n files and update type definitions (#29499)
Co-authored-by: iamjoel <2120155+iamjoel@users.noreply.github.com>
Co-authored-by: yyh <92089059+lyzno1@users.noreply.github.com>
2025-12-11 16:08:52 +08:00
91e5db3e83 chore: Advance the timing of the dataset payment prompt (#29497)
Co-authored-by: yyh <yuanyouhuilyz@gmail.com>
Co-authored-by: twwu <twwu@dify.ai>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-12-11 15:49:42 +08:00
f20a2d1586 chore: add placeholder for Amplitude API key in .env.example (#29489)
Co-authored-by: CodingOnStar <hanxujiang@dify.ai>
2025-12-11 15:21:52 +08:00
6e802a343e perf: remove the n+1 query (#29483) 2025-12-11 15:18:27 +08:00
yyh
a30cbe3c95 test: add debug-with-multiple-model spec (#29490) 2025-12-11 15:05:37 +08:00
7344adf65e feat: add Amplitude API key to Docker entrypoint script (#29477)
Co-authored-by: CodingOnStar <hanxujiang@dify.ai>
2025-12-11 14:44:12 +08:00
fcadee9413 fix: flask db downgrade not work (#29465) 2025-12-11 14:30:09 +08:00
69a22af1c9 fix: optimize database query when retrieval knowledge in App (#29467) 2025-12-11 13:50:46 +08:00
aac6f44562 fix: workflow end node validate error (#29473)
Co-authored-by: Novice <novice12185727@gmail.com>
2025-12-11 13:47:37 +08:00
2e1efd62e1 revert: "fix(ops): add streaming metrics and LLM span for agent-chat traces" (#29469) 2025-12-11 12:53:37 +08:00
1847609926 fix: failed to delete model (#29456) 2025-12-11 12:05:44 +08:00
91f6d25dae fix: knowledge dataset description field validation error #29404 (#29405)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-12-11 11:17:08 +08:00
yyh
acdbcdb6f8 chore: update packageManager version in package.json to pnpm@10.25.0 (#29407) 2025-12-11 09:51:30 +08:00
a9627ba60a chore(deps): bump types-shapely from 2.0.0.20250404 to 2.1.0.20250917 in /api (#29441)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-12-11 09:49:19 +08:00
266d1c70ac fix: fix custom model credentials display as plaintext (#29425) 2025-12-11 09:48:45 +08:00
d152d63e7d chore: update remove_leading_symbols pattern, keep 【 (#29419)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-12-11 09:47:39 +08:00
b4afc7e435 fix: Can not blank conversation ID validation in chat payloads (#29436)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-12-11 09:47:10 +08:00
2d496e7e08 ci: enforce semantic pull request titles (#29438)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-12-11 09:45:55 +08:00
693877e5e4 Fix: Prevent binary content from being stored in process_data for HTTP nodes (#27532)
Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-12-11 02:52:40 +08:00
8cab3e5a1e minor fix: get_tools wrong condition (#27253)
Signed-off-by: NeatGuyCoding <15627489+NeatGuyCoding@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
2025-12-11 02:33:14 +08:00
18082752a0 fix knowledge pipeline run multimodal document failed (#29431) 2025-12-10 20:42:51 +08:00
813a734f27 chore: bump dify release to 1.11.0 (#29355) 2025-12-10 19:54:25 +08:00
94244ed8f6 fix: handle potential undefined values in query_attachment_selector across multiple components (#29429) 2025-12-10 19:30:21 +08:00
yyh
ec3a52f012 Fix immediate window open defaults and error handling (#29417)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-12-10 19:12:14 +08:00
ea063a1139 fix(i18n): remove unused credentialSelector translations from dataset-pipeline files (#29423) 2025-12-10 19:04:34 +08:00
784008997b fix parent-child check when child chunk is not exist (#29426) 2025-12-10 18:45:43 +08:00
0c2a354115 Using SonarJS to analyze components' complexity (#29412)
Co-authored-by: CodingOnStar <hanxujiang@dify.ai>
Co-authored-by: 姜涵煦 <hanxujiang@jianghanxudeMacBook-Pro.local>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-12-10 17:25:54 +08:00
yyh
e477e6c928 fix: harden async window open placeholder logic (#29393) 2025-12-10 16:46:48 +08:00
bafd093fa9 fix: Add dataset file upload restrictions (#29397)
Co-authored-by: kurokobo <kuro664@gmail.com>
Co-authored-by: yyh <92089059+lyzno1@users.noreply.github.com>
2025-12-10 16:41:05 +08:00
88b20bc6d0 fix dataset multimodal field not update (#29403) 2025-12-10 15:18:38 +08:00
12d019cd31 fix: improve compatibility of @headlessui/react with happy-dom by ensuring HTMLElement.prototype.focus is writable (#29399)
Co-authored-by: CodingOnStar <hanxujiang@dify.ai>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-12-10 14:40:48 +08:00
b49e2646ff fix: session unbound during parent-child retrieval (#29396) 2025-12-10 14:08:55 +08:00
e8720de9ad chore(i18n): translate i18n files and update type definitions (#29395)
Co-authored-by: zhsama <33454514+zhsama@users.noreply.github.com>
2025-12-10 13:52:54 +08:00
0867c1800b refactor: simplify plugin task handling and improve UI feedback (#26293) 2025-12-10 13:34:05 +08:00
681c06186e add @testing-library/user-event and create tests for external-knowledge-base/ (#29323)
Co-authored-by: CodingOnStar <hanxujiang@dify.ai>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-12-10 12:46:52 +08:00
yyh
f722fdfa6d fix: prevent popup blocker from blocking async window.open (#29391) 2025-12-10 12:46:01 +08:00
c033030d8c fix: 'list' object has no attribute 'find' (#29384)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-12-10 12:45:53 +08:00
51330c0ee6 fix(App.deleted_tools): incorrect compare between UUID and map with string-typed key. (#29340)
Co-authored-by: 01393547 <nieronghua@sf-express.com>
Co-authored-by: Yeuoly <45712896+Yeuoly@users.noreply.github.com>
2025-12-10 10:47:45 +08:00
7df360a292 fix: workflow log missing trigger icon (#29379) 2025-12-10 10:15:21 +08:00
e205182e1f fix: Parent instance <DocumentSegment at 0x7955b5572c90> is not bound… (#29377) 2025-12-10 10:01:45 +08:00
4a88c8fd19 chore: set is_multimodal db define default = false (#29362) 2025-12-10 09:44:47 +08:00
znn
1b9165624f adding llm_usage and error_type (#26546) 2025-12-10 09:19:13 +08:00
56f8bdd724 Update refactor issue template and remove tracker (#29357) 2025-12-09 22:03:21 +08:00
efa1b452da feat: Add startup parameters for language-specific Weaviate tokenizer (#29347)
Co-authored-by: Jing <jingguo92@gmail.com>
2025-12-09 21:00:19 +08:00
bcbc07e99c Add MCP backend codeowners (#29354) 2025-12-09 20:45:57 +08:00
d79d0a47a7 chore: not set empty tool config to default value (#29338) 2025-12-09 17:14:04 +08:00
f5d676f3f1 fix: agent app add tool hasn't add default params config (#29330) 2025-12-09 16:17:27 +08:00
558 changed files with 48486 additions and 3224 deletions

View File

@ -0,0 +1,205 @@
# Test Generation Checklist
Use this checklist when generating or reviewing tests for Dify frontend components.
## Pre-Generation
- [ ] Read the component source code completely
- [ ] Identify component type (component, hook, utility, page)
- [ ] Run `pnpm analyze-component <path>` if available
- [ ] Note complexity score and features detected
- [ ] Check for existing tests in the same directory
- [ ] **Identify ALL files in the directory** that need testing (not just index)
## Testing Strategy
### ⚠️ Incremental Workflow (CRITICAL for Multi-File)
- [ ] **NEVER generate all tests at once** - process one file at a time
- [ ] Order files by complexity: utilities → hooks → simple → complex → integration
- [ ] Create a todo list to track progress before starting
- [ ] For EACH file: write → run test → verify pass → then next
- [ ] **DO NOT proceed** to next file until current one passes
### Path-Level Coverage
- [ ] **Test ALL files** in the assigned directory/path
- [ ] List all components, hooks, utilities that need coverage
- [ ] Decide: single spec file (integration) or multiple spec files (unit)
### Complexity Assessment
- [ ] Run `pnpm analyze-component <path>` for complexity score
- [ ] **Complexity > 50**: Consider refactoring before testing
- [ ] **500+ lines**: Consider splitting before testing
- [ ] **30-50 complexity**: Use multiple describe blocks, organized structure
### Integration vs Mocking
- [ ] **DO NOT mock base components** (`Loading`, `Button`, `Tooltip`, etc.)
- [ ] Import real project components instead of mocking
- [ ] Only mock: API calls, complex context providers, third-party libs with side effects
- [ ] Prefer integration testing when using single spec file
## Required Test Sections
### All Components MUST Have
- [ ] **Rendering tests** - Component renders without crashing
- [ ] **Props tests** - Required props, optional props, default values
- [ ] **Edge cases** - null, undefined, empty values, boundaries
### Conditional Sections (Add When Feature Present)
| Feature | Add Tests For |
|---------|---------------|
| `useState` | Initial state, transitions, cleanup |
| `useEffect` | Execution, dependencies, cleanup |
| Event handlers | onClick, onChange, onSubmit, keyboard |
| API calls | Loading, success, error states |
| Routing | Navigation, params, query strings |
| `useCallback`/`useMemo` | Referential equality |
| Context | Provider values, consumer behavior |
| Forms | Validation, submission, error display |
## Code Quality Checklist
### Structure
- [ ] Uses `describe` blocks to group related tests
- [ ] Test names follow `should <behavior> when <condition>` pattern
- [ ] AAA pattern (Arrange-Act-Assert) is clear
- [ ] Comments explain complex test scenarios
### Mocks
- [ ] **DO NOT mock base components** (`@/app/components/base/*`)
- [ ] `jest.clearAllMocks()` in `beforeEach` (not `afterEach`)
- [ ] Shared mock state reset in `beforeEach`
- [ ] i18n uses shared mock (auto-loaded); only override locally for custom translations
- [ ] Router mocks match actual Next.js API
- [ ] Mocks reflect actual component conditional behavior
- [ ] Only mock: API services, complex context providers, third-party libs
### Queries
- [ ] Prefer semantic queries (`getByRole`, `getByLabelText`)
- [ ] Use `queryBy*` for absence assertions
- [ ] Use `findBy*` for async elements
- [ ] `getByTestId` only as last resort
### Async
- [ ] All async tests use `async/await`
- [ ] `waitFor` wraps async assertions
- [ ] Fake timers properly setup/teardown
- [ ] No floating promises
### TypeScript
- [ ] No `any` types without justification
- [ ] Mock data uses actual types from source
- [ ] Factory functions have proper return types
## Coverage Goals (Per File)
For the current file being tested:
- [ ] 100% function coverage
- [ ] 100% statement coverage
- [ ] >95% branch coverage
- [ ] >95% line coverage
## Post-Generation (Per File)
**Run these checks after EACH test file, not just at the end:**
- [ ] Run `pnpm test -- path/to/file.spec.tsx` - **MUST PASS before next file**
- [ ] Fix any failures immediately
- [ ] Mark file as complete in todo list
- [ ] Only then proceed to next file
### After All Files Complete
- [ ] Run full directory test: `pnpm test -- path/to/directory/`
- [ ] Check coverage report: `pnpm test -- --coverage`
- [ ] Run `pnpm lint:fix` on all test files
- [ ] Run `pnpm type-check:tsgo`
## Common Issues to Watch
### False Positives
```typescript
// ❌ Mock doesn't match actual behavior
jest.mock('./Component', () => () => <div>Mocked</div>)
// ✅ Mock matches actual conditional logic
jest.mock('./Component', () => ({ isOpen }: any) =>
isOpen ? <div>Content</div> : null
)
```
### State Leakage
```typescript
// ❌ Shared state not reset
let mockState = false
jest.mock('./useHook', () => () => mockState)
// ✅ Reset in beforeEach
beforeEach(() => {
mockState = false
})
```
### Async Race Conditions
```typescript
// ❌ Not awaited
it('loads data', () => {
render(<Component />)
expect(screen.getByText('Data')).toBeInTheDocument()
})
// ✅ Properly awaited
it('loads data', async () => {
render(<Component />)
await waitFor(() => {
expect(screen.getByText('Data')).toBeInTheDocument()
})
})
```
### Missing Edge Cases
Always test these scenarios:
- `null` / `undefined` inputs
- Empty strings / arrays / objects
- Boundary values (0, -1, MAX_INT)
- Error states
- Loading states
- Disabled states
## Quick Commands
```bash
# Run specific test
pnpm test -- path/to/file.spec.tsx
# Run with coverage
pnpm test -- --coverage path/to/file.spec.tsx
# Watch mode
pnpm test -- --watch path/to/file.spec.tsx
# Update snapshots (use sparingly)
pnpm test -- -u path/to/file.spec.tsx
# Analyze component
pnpm analyze-component path/to/component.tsx
# Review existing test
pnpm analyze-component path/to/component.tsx --review
```

View File

@ -0,0 +1,321 @@
---
name: Dify Frontend Testing
description: Generate Jest + React Testing Library tests for Dify frontend components, hooks, and utilities. Triggers on testing, spec files, coverage, Jest, RTL, unit tests, integration tests, or write/review test requests.
---
# Dify Frontend Testing Skill
This skill enables Claude to generate high-quality, comprehensive frontend tests for the Dify project following established conventions and best practices.
> **⚠️ Authoritative Source**: This skill is derived from `web/testing/testing.md`. When in doubt, always refer to that document as the canonical specification.
## When to Apply This Skill
Apply this skill when the user:
- Asks to **write tests** for a component, hook, or utility
- Asks to **review existing tests** for completeness
- Mentions **Jest**, **React Testing Library**, **RTL**, or **spec files**
- Requests **test coverage** improvement
- Uses `pnpm analyze-component` output as context
- Mentions **testing**, **unit tests**, or **integration tests** for frontend code
- Wants to understand **testing patterns** in the Dify codebase
**Do NOT apply** when:
- User is asking about backend/API tests (Python/pytest)
- User is asking about E2E tests (Playwright/Cypress)
- User is only asking conceptual questions without code context
## Quick Reference
### Tech Stack
| Tool | Version | Purpose |
|------|---------|---------|
| Jest | 29.7 | Test runner |
| React Testing Library | 16.0 | Component testing |
| happy-dom | - | Test environment |
| nock | 14.0 | HTTP mocking |
| TypeScript | 5.x | Type safety |
### Key Commands
```bash
# Run all tests
pnpm test
# Watch mode
pnpm test -- --watch
# Run specific file
pnpm test -- path/to/file.spec.tsx
# Generate coverage report
pnpm test -- --coverage
# Analyze component complexity
pnpm analyze-component <path>
# Review existing test
pnpm analyze-component <path> --review
```
### File Naming
- Test files: `ComponentName.spec.tsx` (same directory as component)
- Integration tests: `web/__tests__/` directory
## Test Structure Template
```typescript
import { render, screen, fireEvent, waitFor } from '@testing-library/react'
import Component from './index'
// ✅ Import real project components (DO NOT mock these)
// import Loading from '@/app/components/base/loading'
// import { ChildComponent } from './child-component'
// ✅ Mock external dependencies only
jest.mock('@/service/api')
jest.mock('next/navigation', () => ({
useRouter: () => ({ push: jest.fn() }),
usePathname: () => '/test',
}))
// Shared state for mocks (if needed)
let mockSharedState = false
describe('ComponentName', () => {
beforeEach(() => {
jest.clearAllMocks() // ✅ Reset mocks BEFORE each test
mockSharedState = false // ✅ Reset shared state
})
// Rendering tests (REQUIRED)
describe('Rendering', () => {
it('should render without crashing', () => {
// Arrange
const props = { title: 'Test' }
// Act
render(<Component {...props} />)
// Assert
expect(screen.getByText('Test')).toBeInTheDocument()
})
})
// Props tests (REQUIRED)
describe('Props', () => {
it('should apply custom className', () => {
render(<Component className="custom" />)
expect(screen.getByRole('button')).toHaveClass('custom')
})
})
// User Interactions
describe('User Interactions', () => {
it('should handle click events', () => {
const handleClick = jest.fn()
render(<Component onClick={handleClick} />)
fireEvent.click(screen.getByRole('button'))
expect(handleClick).toHaveBeenCalledTimes(1)
})
})
// Edge Cases (REQUIRED)
describe('Edge Cases', () => {
it('should handle null data', () => {
render(<Component data={null} />)
expect(screen.getByText(/no data/i)).toBeInTheDocument()
})
it('should handle empty array', () => {
render(<Component items={[]} />)
expect(screen.getByText(/empty/i)).toBeInTheDocument()
})
})
})
```
## Testing Workflow (CRITICAL)
### ⚠️ Incremental Approach Required
**NEVER generate all test files at once.** For complex components or multi-file directories:
1. **Analyze & Plan**: List all files, order by complexity (simple → complex)
1. **Process ONE at a time**: Write test → Run test → Fix if needed → Next
1. **Verify before proceeding**: Do NOT continue to next file until current passes
```
For each file:
┌────────────────────────────────────────┐
│ 1. Write test │
│ 2. Run: pnpm test -- <file>.spec.tsx │
│ 3. PASS? → Mark complete, next file │
│ FAIL? → Fix first, then continue │
└────────────────────────────────────────┘
```
### Complexity-Based Order
Process in this order for multi-file testing:
1. 🟢 Utility functions (simplest)
1. 🟢 Custom hooks
1. 🟡 Simple components (presentational)
1. 🟡 Medium components (state, effects)
1. 🔴 Complex components (API, routing)
1. 🔴 Integration tests (index files - last)
### When to Refactor First
- **Complexity > 50**: Break into smaller pieces before testing
- **500+ lines**: Consider splitting before testing
- **Many dependencies**: Extract logic into hooks first
> 📖 See `guides/workflow.md` for complete workflow details and todo list format.
## Testing Strategy
### Path-Level Testing (Directory Testing)
When assigned to test a directory/path, test **ALL content** within that path:
- Test all components, hooks, utilities in the directory (not just `index` file)
- Use incremental approach: one file at a time, verify each before proceeding
- Goal: 100% coverage of ALL files in the directory
### Integration Testing First
**Prefer integration testing** when writing tests for a directory:
-**Import real project components** directly (including base components and siblings)
-**Only mock**: API services (`@/service/*`), `next/navigation`, complex context providers
-**DO NOT mock** base components (`@/app/components/base/*`)
-**DO NOT mock** sibling/child components in the same directory
> See [Test Structure Template](#test-structure-template) for correct import/mock patterns.
## Core Principles
### 1. AAA Pattern (Arrange-Act-Assert)
Every test should clearly separate:
- **Arrange**: Setup test data and render component
- **Act**: Perform user actions
- **Assert**: Verify expected outcomes
### 2. Black-Box Testing
- Test observable behavior, not implementation details
- Use semantic queries (getByRole, getByLabelText)
- Avoid testing internal state directly
- **Prefer pattern matching over hardcoded strings** in assertions:
```typescript
// ❌ Avoid: hardcoded text assertions
expect(screen.getByText('Loading...')).toBeInTheDocument()
// ✅ Better: role-based queries
expect(screen.getByRole('status')).toBeInTheDocument()
// ✅ Better: pattern matching
expect(screen.getByText(/loading/i)).toBeInTheDocument()
```
### 3. Single Behavior Per Test
Each test verifies ONE user-observable behavior:
```typescript
// ✅ Good: One behavior
it('should disable button when loading', () => {
render(<Button loading />)
expect(screen.getByRole('button')).toBeDisabled()
})
// ❌ Bad: Multiple behaviors
it('should handle loading state', () => {
render(<Button loading />)
expect(screen.getByRole('button')).toBeDisabled()
expect(screen.getByText('Loading...')).toBeInTheDocument()
expect(screen.getByRole('button')).toHaveClass('loading')
})
```
### 4. Semantic Naming
Use `should <behavior> when <condition>`:
```typescript
it('should show error message when validation fails')
it('should call onSubmit when form is valid')
it('should disable input when isReadOnly is true')
```
## Required Test Scenarios
### Always Required (All Components)
1. **Rendering**: Component renders without crashing
1. **Props**: Required props, optional props, default values
1. **Edge Cases**: null, undefined, empty values, boundary conditions
### Conditional (When Present)
| Feature | Test Focus |
|---------|-----------|
| `useState` | Initial state, transitions, cleanup |
| `useEffect` | Execution, dependencies, cleanup |
| Event handlers | All onClick, onChange, onSubmit, keyboard |
| API calls | Loading, success, error states |
| Routing | Navigation, params, query strings |
| `useCallback`/`useMemo` | Referential equality |
| Context | Provider values, consumer behavior |
| Forms | Validation, submission, error display |
## Coverage Goals (Per File)
For each test file generated, aim for:
-**100%** function coverage
-**100%** statement coverage
-**>95%** branch coverage
-**>95%** line coverage
> **Note**: For multi-file directories, process one file at a time with full coverage each. See `guides/workflow.md`.
## Detailed Guides
For more detailed information, refer to:
- `guides/workflow.md` - **Incremental testing workflow** (MUST READ for multi-file testing)
- `guides/mocking.md` - Mock patterns and best practices
- `guides/async-testing.md` - Async operations and API calls
- `guides/domain-components.md` - Workflow, Dataset, Configuration testing
- `guides/common-patterns.md` - Frequently used testing patterns
## Authoritative References
### Primary Specification (MUST follow)
- **`web/testing/testing.md`** - The canonical testing specification. This skill is derived from this document.
### Reference Examples in Codebase
- `web/utils/classnames.spec.ts` - Utility function tests
- `web/app/components/base/button/index.spec.tsx` - Component tests
- `web/__mocks__/provider-context.ts` - Mock factory example
### Project Configuration
- `web/jest.config.ts` - Jest configuration
- `web/jest.setup.ts` - Test environment setup
- `web/testing/analyze-component.js` - Component analysis tool
- `web/__mocks__/react-i18next.ts` - Shared i18n mock (auto-loaded by Jest, no explicit mock needed; override locally only for custom translations)

View File

@ -0,0 +1,345 @@
# Async Testing Guide
## Core Async Patterns
### 1. waitFor - Wait for Condition
```typescript
import { render, screen, waitFor } from '@testing-library/react'
it('should load and display data', async () => {
render(<DataComponent />)
// Wait for element to appear
await waitFor(() => {
expect(screen.getByText('Loaded Data')).toBeInTheDocument()
})
})
it('should hide loading spinner after load', async () => {
render(<DataComponent />)
// Wait for element to disappear
await waitFor(() => {
expect(screen.queryByText('Loading...')).not.toBeInTheDocument()
})
})
```
### 2. findBy\* - Async Queries
```typescript
it('should show user name after fetch', async () => {
render(<UserProfile />)
// findBy returns a promise, auto-waits up to 1000ms
const userName = await screen.findByText('John Doe')
expect(userName).toBeInTheDocument()
// findByRole with options
const button = await screen.findByRole('button', { name: /submit/i })
expect(button).toBeEnabled()
})
```
### 3. userEvent for Async Interactions
```typescript
import userEvent from '@testing-library/user-event'
it('should submit form', async () => {
const user = userEvent.setup()
const onSubmit = jest.fn()
render(<Form onSubmit={onSubmit} />)
// userEvent methods are async
await user.type(screen.getByLabelText('Email'), 'test@example.com')
await user.click(screen.getByRole('button', { name: /submit/i }))
await waitFor(() => {
expect(onSubmit).toHaveBeenCalledWith({ email: 'test@example.com' })
})
})
```
## Fake Timers
### When to Use Fake Timers
- Testing components with `setTimeout`/`setInterval`
- Testing debounce/throttle behavior
- Testing animations or delayed transitions
- Testing polling or retry logic
### Basic Fake Timer Setup
```typescript
describe('Debounced Search', () => {
beforeEach(() => {
jest.useFakeTimers()
})
afterEach(() => {
jest.useRealTimers()
})
it('should debounce search input', async () => {
const onSearch = jest.fn()
render(<SearchInput onSearch={onSearch} debounceMs={300} />)
// Type in the input
fireEvent.change(screen.getByRole('textbox'), { target: { value: 'query' } })
// Search not called immediately
expect(onSearch).not.toHaveBeenCalled()
// Advance timers
jest.advanceTimersByTime(300)
// Now search is called
expect(onSearch).toHaveBeenCalledWith('query')
})
})
```
### Fake Timers with Async Code
```typescript
it('should retry on failure', async () => {
jest.useFakeTimers()
const fetchData = jest.fn()
.mockRejectedValueOnce(new Error('Network error'))
.mockResolvedValueOnce({ data: 'success' })
render(<RetryComponent fetchData={fetchData} retryDelayMs={1000} />)
// First call fails
await waitFor(() => {
expect(fetchData).toHaveBeenCalledTimes(1)
})
// Advance timer for retry
jest.advanceTimersByTime(1000)
// Second call succeeds
await waitFor(() => {
expect(fetchData).toHaveBeenCalledTimes(2)
expect(screen.getByText('success')).toBeInTheDocument()
})
jest.useRealTimers()
})
```
### Common Fake Timer Utilities
```typescript
// Run all pending timers
jest.runAllTimers()
// Run only pending timers (not new ones created during execution)
jest.runOnlyPendingTimers()
// Advance by specific time
jest.advanceTimersByTime(1000)
// Get current fake time
jest.now()
// Clear all timers
jest.clearAllTimers()
```
## API Testing Patterns
### Loading → Success → Error States
```typescript
describe('DataFetcher', () => {
beforeEach(() => {
jest.clearAllMocks()
})
it('should show loading state', () => {
mockedApi.fetchData.mockImplementation(() => new Promise(() => {})) // Never resolves
render(<DataFetcher />)
expect(screen.getByTestId('loading-spinner')).toBeInTheDocument()
})
it('should show data on success', async () => {
mockedApi.fetchData.mockResolvedValue({ items: ['Item 1', 'Item 2'] })
render(<DataFetcher />)
// Use findBy* for multiple async elements (better error messages than waitFor with multiple assertions)
const item1 = await screen.findByText('Item 1')
const item2 = await screen.findByText('Item 2')
expect(item1).toBeInTheDocument()
expect(item2).toBeInTheDocument()
expect(screen.queryByTestId('loading-spinner')).not.toBeInTheDocument()
})
it('should show error on failure', async () => {
mockedApi.fetchData.mockRejectedValue(new Error('Failed to fetch'))
render(<DataFetcher />)
await waitFor(() => {
expect(screen.getByText(/failed to fetch/i)).toBeInTheDocument()
})
})
it('should retry on error', async () => {
mockedApi.fetchData.mockRejectedValue(new Error('Network error'))
render(<DataFetcher />)
await waitFor(() => {
expect(screen.getByRole('button', { name: /retry/i })).toBeInTheDocument()
})
mockedApi.fetchData.mockResolvedValue({ items: ['Item 1'] })
fireEvent.click(screen.getByRole('button', { name: /retry/i }))
await waitFor(() => {
expect(screen.getByText('Item 1')).toBeInTheDocument()
})
})
})
```
### Testing Mutations
```typescript
it('should submit form and show success', async () => {
const user = userEvent.setup()
mockedApi.createItem.mockResolvedValue({ id: '1', name: 'New Item' })
render(<CreateItemForm />)
await user.type(screen.getByLabelText('Name'), 'New Item')
await user.click(screen.getByRole('button', { name: /create/i }))
// Button should be disabled during submission
expect(screen.getByRole('button', { name: /creating/i })).toBeDisabled()
await waitFor(() => {
expect(screen.getByText(/created successfully/i)).toBeInTheDocument()
})
expect(mockedApi.createItem).toHaveBeenCalledWith({ name: 'New Item' })
})
```
## useEffect Testing
### Testing Effect Execution
```typescript
it('should fetch data on mount', async () => {
const fetchData = jest.fn().mockResolvedValue({ data: 'test' })
render(<ComponentWithEffect fetchData={fetchData} />)
await waitFor(() => {
expect(fetchData).toHaveBeenCalledTimes(1)
})
})
```
### Testing Effect Dependencies
```typescript
it('should refetch when id changes', async () => {
const fetchData = jest.fn().mockResolvedValue({ data: 'test' })
const { rerender } = render(<ComponentWithEffect id="1" fetchData={fetchData} />)
await waitFor(() => {
expect(fetchData).toHaveBeenCalledWith('1')
})
rerender(<ComponentWithEffect id="2" fetchData={fetchData} />)
await waitFor(() => {
expect(fetchData).toHaveBeenCalledWith('2')
expect(fetchData).toHaveBeenCalledTimes(2)
})
})
```
### Testing Effect Cleanup
```typescript
it('should cleanup subscription on unmount', () => {
const subscribe = jest.fn()
const unsubscribe = jest.fn()
subscribe.mockReturnValue(unsubscribe)
const { unmount } = render(<SubscriptionComponent subscribe={subscribe} />)
expect(subscribe).toHaveBeenCalledTimes(1)
unmount()
expect(unsubscribe).toHaveBeenCalledTimes(1)
})
```
## Common Async Pitfalls
### ❌ Don't: Forget to await
```typescript
// Bad - test may pass even if assertion fails
it('should load data', () => {
render(<Component />)
waitFor(() => {
expect(screen.getByText('Data')).toBeInTheDocument()
})
})
// Good - properly awaited
it('should load data', async () => {
render(<Component />)
await waitFor(() => {
expect(screen.getByText('Data')).toBeInTheDocument()
})
})
```
### ❌ Don't: Use multiple assertions in single waitFor
```typescript
// Bad - if first assertion fails, won't know about second
await waitFor(() => {
expect(screen.getByText('Title')).toBeInTheDocument()
expect(screen.getByText('Description')).toBeInTheDocument()
})
// Good - separate waitFor or use findBy
const title = await screen.findByText('Title')
const description = await screen.findByText('Description')
expect(title).toBeInTheDocument()
expect(description).toBeInTheDocument()
```
### ❌ Don't: Mix fake timers with real async
```typescript
// Bad - fake timers don't work well with real Promises
jest.useFakeTimers()
await waitFor(() => {
expect(screen.getByText('Data')).toBeInTheDocument()
}) // May timeout!
// Good - use runAllTimers or advanceTimersByTime
jest.useFakeTimers()
render(<Component />)
jest.runAllTimers()
expect(screen.getByText('Data')).toBeInTheDocument()
```

View File

@ -0,0 +1,449 @@
# Common Testing Patterns
## Query Priority
Use queries in this order (most to least preferred):
```typescript
// 1. getByRole - Most recommended (accessibility)
screen.getByRole('button', { name: /submit/i })
screen.getByRole('textbox', { name: /email/i })
screen.getByRole('heading', { level: 1 })
// 2. getByLabelText - Form fields
screen.getByLabelText('Email address')
screen.getByLabelText(/password/i)
// 3. getByPlaceholderText - When no label
screen.getByPlaceholderText('Search...')
// 4. getByText - Non-interactive elements
screen.getByText('Welcome to Dify')
screen.getByText(/loading/i)
// 5. getByDisplayValue - Current input value
screen.getByDisplayValue('current value')
// 6. getByAltText - Images
screen.getByAltText('Company logo')
// 7. getByTitle - Tooltip elements
screen.getByTitle('Close')
// 8. getByTestId - Last resort only!
screen.getByTestId('custom-element')
```
## Event Handling Patterns
### Click Events
```typescript
// Basic click
fireEvent.click(screen.getByRole('button'))
// With userEvent (preferred for realistic interaction)
const user = userEvent.setup()
await user.click(screen.getByRole('button'))
// Double click
await user.dblClick(screen.getByRole('button'))
// Right click
await user.pointer({ keys: '[MouseRight]', target: screen.getByRole('button') })
```
### Form Input
```typescript
const user = userEvent.setup()
// Type in input
await user.type(screen.getByRole('textbox'), 'Hello World')
// Clear and type
await user.clear(screen.getByRole('textbox'))
await user.type(screen.getByRole('textbox'), 'New value')
// Select option
await user.selectOptions(screen.getByRole('combobox'), 'option-value')
// Check checkbox
await user.click(screen.getByRole('checkbox'))
// Upload file
const file = new File(['content'], 'test.pdf', { type: 'application/pdf' })
await user.upload(screen.getByLabelText(/upload/i), file)
```
### Keyboard Events
```typescript
const user = userEvent.setup()
// Press Enter
await user.keyboard('{Enter}')
// Press Escape
await user.keyboard('{Escape}')
// Keyboard shortcut
await user.keyboard('{Control>}a{/Control}') // Ctrl+A
// Tab navigation
await user.tab()
// Arrow keys
await user.keyboard('{ArrowDown}')
await user.keyboard('{ArrowUp}')
```
## Component State Testing
### Testing State Transitions
```typescript
describe('Counter', () => {
it('should increment count', async () => {
const user = userEvent.setup()
render(<Counter initialCount={0} />)
// Initial state
expect(screen.getByText('Count: 0')).toBeInTheDocument()
// Trigger transition
await user.click(screen.getByRole('button', { name: /increment/i }))
// New state
expect(screen.getByText('Count: 1')).toBeInTheDocument()
})
})
```
### Testing Controlled Components
```typescript
describe('ControlledInput', () => {
it('should call onChange with new value', async () => {
const user = userEvent.setup()
const handleChange = jest.fn()
render(<ControlledInput value="" onChange={handleChange} />)
await user.type(screen.getByRole('textbox'), 'a')
expect(handleChange).toHaveBeenCalledWith('a')
})
it('should display controlled value', () => {
render(<ControlledInput value="controlled" onChange={jest.fn()} />)
expect(screen.getByRole('textbox')).toHaveValue('controlled')
})
})
```
## Conditional Rendering Testing
```typescript
describe('ConditionalComponent', () => {
it('should show loading state', () => {
render(<DataDisplay isLoading={true} data={null} />)
expect(screen.getByText(/loading/i)).toBeInTheDocument()
expect(screen.queryByTestId('data-content')).not.toBeInTheDocument()
})
it('should show error state', () => {
render(<DataDisplay isLoading={false} data={null} error="Failed to load" />)
expect(screen.getByText(/failed to load/i)).toBeInTheDocument()
})
it('should show data when loaded', () => {
render(<DataDisplay isLoading={false} data={{ name: 'Test' }} />)
expect(screen.getByText('Test')).toBeInTheDocument()
})
it('should show empty state when no data', () => {
render(<DataDisplay isLoading={false} data={[]} />)
expect(screen.getByText(/no data/i)).toBeInTheDocument()
})
})
```
## List Rendering Testing
```typescript
describe('ItemList', () => {
const items = [
{ id: '1', name: 'Item 1' },
{ id: '2', name: 'Item 2' },
{ id: '3', name: 'Item 3' },
]
it('should render all items', () => {
render(<ItemList items={items} />)
expect(screen.getAllByRole('listitem')).toHaveLength(3)
items.forEach(item => {
expect(screen.getByText(item.name)).toBeInTheDocument()
})
})
it('should handle item selection', async () => {
const user = userEvent.setup()
const onSelect = jest.fn()
render(<ItemList items={items} onSelect={onSelect} />)
await user.click(screen.getByText('Item 2'))
expect(onSelect).toHaveBeenCalledWith(items[1])
})
it('should handle empty list', () => {
render(<ItemList items={[]} />)
expect(screen.getByText(/no items/i)).toBeInTheDocument()
})
})
```
## Modal/Dialog Testing
```typescript
describe('Modal', () => {
it('should not render when closed', () => {
render(<Modal isOpen={false} onClose={jest.fn()} />)
expect(screen.queryByRole('dialog')).not.toBeInTheDocument()
})
it('should render when open', () => {
render(<Modal isOpen={true} onClose={jest.fn()} />)
expect(screen.getByRole('dialog')).toBeInTheDocument()
})
it('should call onClose when clicking overlay', async () => {
const user = userEvent.setup()
const handleClose = jest.fn()
render(<Modal isOpen={true} onClose={handleClose} />)
await user.click(screen.getByTestId('modal-overlay'))
expect(handleClose).toHaveBeenCalled()
})
it('should call onClose when pressing Escape', async () => {
const user = userEvent.setup()
const handleClose = jest.fn()
render(<Modal isOpen={true} onClose={handleClose} />)
await user.keyboard('{Escape}')
expect(handleClose).toHaveBeenCalled()
})
it('should trap focus inside modal', async () => {
const user = userEvent.setup()
render(
<Modal isOpen={true} onClose={jest.fn()}>
<button>First</button>
<button>Second</button>
</Modal>
)
// Focus should cycle within modal
await user.tab()
expect(screen.getByText('First')).toHaveFocus()
await user.tab()
expect(screen.getByText('Second')).toHaveFocus()
await user.tab()
expect(screen.getByText('First')).toHaveFocus() // Cycles back
})
})
```
## Form Testing
```typescript
describe('LoginForm', () => {
it('should submit valid form', async () => {
const user = userEvent.setup()
const onSubmit = jest.fn()
render(<LoginForm onSubmit={onSubmit} />)
await user.type(screen.getByLabelText(/email/i), 'test@example.com')
await user.type(screen.getByLabelText(/password/i), 'password123')
await user.click(screen.getByRole('button', { name: /sign in/i }))
expect(onSubmit).toHaveBeenCalledWith({
email: 'test@example.com',
password: 'password123',
})
})
it('should show validation errors', async () => {
const user = userEvent.setup()
render(<LoginForm onSubmit={jest.fn()} />)
// Submit empty form
await user.click(screen.getByRole('button', { name: /sign in/i }))
expect(screen.getByText(/email is required/i)).toBeInTheDocument()
expect(screen.getByText(/password is required/i)).toBeInTheDocument()
})
it('should validate email format', async () => {
const user = userEvent.setup()
render(<LoginForm onSubmit={jest.fn()} />)
await user.type(screen.getByLabelText(/email/i), 'invalid-email')
await user.click(screen.getByRole('button', { name: /sign in/i }))
expect(screen.getByText(/invalid email/i)).toBeInTheDocument()
})
it('should disable submit button while submitting', async () => {
const user = userEvent.setup()
const onSubmit = jest.fn(() => new Promise(resolve => setTimeout(resolve, 100)))
render(<LoginForm onSubmit={onSubmit} />)
await user.type(screen.getByLabelText(/email/i), 'test@example.com')
await user.type(screen.getByLabelText(/password/i), 'password123')
await user.click(screen.getByRole('button', { name: /sign in/i }))
expect(screen.getByRole('button', { name: /signing in/i })).toBeDisabled()
await waitFor(() => {
expect(screen.getByRole('button', { name: /sign in/i })).toBeEnabled()
})
})
})
```
## Data-Driven Tests with test.each
```typescript
describe('StatusBadge', () => {
test.each([
['success', 'bg-green-500'],
['warning', 'bg-yellow-500'],
['error', 'bg-red-500'],
['info', 'bg-blue-500'],
])('should apply correct class for %s status', (status, expectedClass) => {
render(<StatusBadge status={status} />)
expect(screen.getByTestId('status-badge')).toHaveClass(expectedClass)
})
test.each([
{ input: null, expected: 'Unknown' },
{ input: undefined, expected: 'Unknown' },
{ input: '', expected: 'Unknown' },
{ input: 'invalid', expected: 'Unknown' },
])('should show "Unknown" for invalid input: $input', ({ input, expected }) => {
render(<StatusBadge status={input} />)
expect(screen.getByText(expected)).toBeInTheDocument()
})
})
```
## Debugging Tips
```typescript
// Print entire DOM
screen.debug()
// Print specific element
screen.debug(screen.getByRole('button'))
// Log testing playground URL
screen.logTestingPlaygroundURL()
// Pretty print DOM
import { prettyDOM } from '@testing-library/react'
console.log(prettyDOM(screen.getByRole('dialog')))
// Check available roles
import { getRoles } from '@testing-library/react'
console.log(getRoles(container))
```
## Common Mistakes to Avoid
### ❌ Don't Use Implementation Details
```typescript
// Bad - testing implementation
expect(component.state.isOpen).toBe(true)
expect(wrapper.find('.internal-class').length).toBe(1)
// Good - testing behavior
expect(screen.getByRole('dialog')).toBeInTheDocument()
```
### ❌ Don't Forget Cleanup
```typescript
// Bad - may leak state between tests
it('test 1', () => {
render(<Component />)
})
// Good - cleanup is automatic with RTL, but reset mocks
beforeEach(() => {
jest.clearAllMocks()
})
```
### ❌ Don't Use Exact String Matching (Prefer Black-Box Assertions)
```typescript
// ❌ Bad - hardcoded strings are brittle
expect(screen.getByText('Submit Form')).toBeInTheDocument()
expect(screen.getByText('Loading...')).toBeInTheDocument()
// ✅ Good - role-based queries (most semantic)
expect(screen.getByRole('button', { name: /submit/i })).toBeInTheDocument()
expect(screen.getByRole('status')).toBeInTheDocument()
// ✅ Good - pattern matching (flexible)
expect(screen.getByText(/submit/i)).toBeInTheDocument()
expect(screen.getByText(/loading/i)).toBeInTheDocument()
// ✅ Good - test behavior, not exact UI text
expect(screen.getByRole('button')).toBeDisabled()
expect(screen.getByRole('alert')).toBeInTheDocument()
```
**Why prefer black-box assertions?**
- Text content may change (i18n, copy updates)
- Role-based queries test accessibility
- Pattern matching is resilient to minor changes
- Tests focus on behavior, not implementation details
### ❌ Don't Assert on Absence Without Query
```typescript
// Bad - throws if not found
expect(screen.getByText('Error')).not.toBeInTheDocument() // Error!
// Good - use queryBy for absence assertions
expect(screen.queryByText('Error')).not.toBeInTheDocument()
```

View File

@ -0,0 +1,523 @@
# Domain-Specific Component Testing
This guide covers testing patterns for Dify's domain-specific components.
## Workflow Components (`workflow/`)
Workflow components handle node configuration, data flow, and graph operations.
### Key Test Areas
1. **Node Configuration**
1. **Data Validation**
1. **Variable Passing**
1. **Edge Connections**
1. **Error Handling**
### Example: Node Configuration Panel
```typescript
import { render, screen, fireEvent, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import NodeConfigPanel from './node-config-panel'
import { createMockNode, createMockWorkflowContext } from '@/__mocks__/workflow'
// Mock workflow context
jest.mock('@/app/components/workflow/hooks', () => ({
useWorkflowStore: () => mockWorkflowStore,
useNodesInteractions: () => mockNodesInteractions,
}))
let mockWorkflowStore = {
nodes: [],
edges: [],
updateNode: jest.fn(),
}
let mockNodesInteractions = {
handleNodeSelect: jest.fn(),
handleNodeDelete: jest.fn(),
}
describe('NodeConfigPanel', () => {
beforeEach(() => {
jest.clearAllMocks()
mockWorkflowStore = {
nodes: [],
edges: [],
updateNode: jest.fn(),
}
})
describe('Node Configuration', () => {
it('should render node type selector', () => {
const node = createMockNode({ type: 'llm' })
render(<NodeConfigPanel node={node} />)
expect(screen.getByLabelText(/model/i)).toBeInTheDocument()
})
it('should update node config on change', async () => {
const user = userEvent.setup()
const node = createMockNode({ type: 'llm' })
render(<NodeConfigPanel node={node} />)
await user.selectOptions(screen.getByLabelText(/model/i), 'gpt-4')
expect(mockWorkflowStore.updateNode).toHaveBeenCalledWith(
node.id,
expect.objectContaining({ model: 'gpt-4' })
)
})
})
describe('Data Validation', () => {
it('should show error for invalid input', async () => {
const user = userEvent.setup()
const node = createMockNode({ type: 'code' })
render(<NodeConfigPanel node={node} />)
// Enter invalid code
const codeInput = screen.getByLabelText(/code/i)
await user.clear(codeInput)
await user.type(codeInput, 'invalid syntax {{{')
await waitFor(() => {
expect(screen.getByText(/syntax error/i)).toBeInTheDocument()
})
})
it('should validate required fields', async () => {
const node = createMockNode({ type: 'http', data: { url: '' } })
render(<NodeConfigPanel node={node} />)
fireEvent.click(screen.getByRole('button', { name: /save/i }))
await waitFor(() => {
expect(screen.getByText(/url is required/i)).toBeInTheDocument()
})
})
})
describe('Variable Passing', () => {
it('should display available variables from upstream nodes', () => {
const upstreamNode = createMockNode({
id: 'node-1',
type: 'start',
data: { outputs: [{ name: 'user_input', type: 'string' }] },
})
const currentNode = createMockNode({
id: 'node-2',
type: 'llm',
})
mockWorkflowStore.nodes = [upstreamNode, currentNode]
mockWorkflowStore.edges = [{ source: 'node-1', target: 'node-2' }]
render(<NodeConfigPanel node={currentNode} />)
// Variable selector should show upstream variables
fireEvent.click(screen.getByRole('button', { name: /add variable/i }))
expect(screen.getByText('user_input')).toBeInTheDocument()
})
it('should insert variable into prompt template', async () => {
const user = userEvent.setup()
const node = createMockNode({ type: 'llm' })
render(<NodeConfigPanel node={node} />)
// Click variable button
await user.click(screen.getByRole('button', { name: /insert variable/i }))
await user.click(screen.getByText('user_input'))
const promptInput = screen.getByLabelText(/prompt/i)
expect(promptInput).toHaveValue(expect.stringContaining('{{user_input}}'))
})
})
})
```
## Dataset Components (`dataset/`)
Dataset components handle file uploads, data display, and search/filter operations.
### Key Test Areas
1. **File Upload**
1. **File Type Validation**
1. **Pagination**
1. **Search & Filtering**
1. **Data Format Handling**
### Example: Document Uploader
```typescript
import { render, screen, fireEvent, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import DocumentUploader from './document-uploader'
jest.mock('@/service/datasets', () => ({
uploadDocument: jest.fn(),
parseDocument: jest.fn(),
}))
import * as datasetService from '@/service/datasets'
const mockedService = datasetService as jest.Mocked<typeof datasetService>
describe('DocumentUploader', () => {
beforeEach(() => {
jest.clearAllMocks()
})
describe('File Upload', () => {
it('should accept valid file types', async () => {
const user = userEvent.setup()
const onUpload = jest.fn()
mockedService.uploadDocument.mockResolvedValue({ id: 'doc-1' })
render(<DocumentUploader onUpload={onUpload} />)
const file = new File(['content'], 'test.pdf', { type: 'application/pdf' })
const input = screen.getByLabelText(/upload/i)
await user.upload(input, file)
await waitFor(() => {
expect(mockedService.uploadDocument).toHaveBeenCalledWith(
expect.any(FormData)
)
})
})
it('should reject invalid file types', async () => {
const user = userEvent.setup()
render(<DocumentUploader />)
const file = new File(['content'], 'test.exe', { type: 'application/x-msdownload' })
const input = screen.getByLabelText(/upload/i)
await user.upload(input, file)
expect(screen.getByText(/unsupported file type/i)).toBeInTheDocument()
expect(mockedService.uploadDocument).not.toHaveBeenCalled()
})
it('should show upload progress', async () => {
const user = userEvent.setup()
// Mock upload with progress
mockedService.uploadDocument.mockImplementation(() => {
return new Promise((resolve) => {
setTimeout(() => resolve({ id: 'doc-1' }), 100)
})
})
render(<DocumentUploader />)
const file = new File(['content'], 'test.pdf', { type: 'application/pdf' })
await user.upload(screen.getByLabelText(/upload/i), file)
expect(screen.getByRole('progressbar')).toBeInTheDocument()
await waitFor(() => {
expect(screen.queryByRole('progressbar')).not.toBeInTheDocument()
})
})
})
describe('Error Handling', () => {
it('should handle upload failure', async () => {
const user = userEvent.setup()
mockedService.uploadDocument.mockRejectedValue(new Error('Upload failed'))
render(<DocumentUploader />)
const file = new File(['content'], 'test.pdf', { type: 'application/pdf' })
await user.upload(screen.getByLabelText(/upload/i), file)
await waitFor(() => {
expect(screen.getByText(/upload failed/i)).toBeInTheDocument()
})
})
it('should allow retry after failure', async () => {
const user = userEvent.setup()
mockedService.uploadDocument
.mockRejectedValueOnce(new Error('Network error'))
.mockResolvedValueOnce({ id: 'doc-1' })
render(<DocumentUploader />)
const file = new File(['content'], 'test.pdf', { type: 'application/pdf' })
await user.upload(screen.getByLabelText(/upload/i), file)
await waitFor(() => {
expect(screen.getByRole('button', { name: /retry/i })).toBeInTheDocument()
})
await user.click(screen.getByRole('button', { name: /retry/i }))
await waitFor(() => {
expect(screen.getByText(/uploaded successfully/i)).toBeInTheDocument()
})
})
})
})
```
### Example: Document List with Pagination
```typescript
describe('DocumentList', () => {
describe('Pagination', () => {
it('should load first page on mount', async () => {
mockedService.getDocuments.mockResolvedValue({
data: [{ id: '1', name: 'Doc 1' }],
total: 50,
page: 1,
pageSize: 10,
})
render(<DocumentList datasetId="ds-1" />)
await waitFor(() => {
expect(screen.getByText('Doc 1')).toBeInTheDocument()
})
expect(mockedService.getDocuments).toHaveBeenCalledWith('ds-1', { page: 1 })
})
it('should navigate to next page', async () => {
const user = userEvent.setup()
mockedService.getDocuments.mockResolvedValue({
data: [{ id: '1', name: 'Doc 1' }],
total: 50,
page: 1,
pageSize: 10,
})
render(<DocumentList datasetId="ds-1" />)
await waitFor(() => {
expect(screen.getByText('Doc 1')).toBeInTheDocument()
})
mockedService.getDocuments.mockResolvedValue({
data: [{ id: '11', name: 'Doc 11' }],
total: 50,
page: 2,
pageSize: 10,
})
await user.click(screen.getByRole('button', { name: /next/i }))
await waitFor(() => {
expect(screen.getByText('Doc 11')).toBeInTheDocument()
})
})
})
describe('Search & Filtering', () => {
it('should filter by search query', async () => {
const user = userEvent.setup()
jest.useFakeTimers()
render(<DocumentList datasetId="ds-1" />)
await user.type(screen.getByPlaceholderText(/search/i), 'test query')
// Debounce
jest.advanceTimersByTime(300)
await waitFor(() => {
expect(mockedService.getDocuments).toHaveBeenCalledWith(
'ds-1',
expect.objectContaining({ search: 'test query' })
)
})
jest.useRealTimers()
})
})
})
```
## Configuration Components (`app/configuration/`, `config/`)
Configuration components handle forms, validation, and data persistence.
### Key Test Areas
1. **Form Validation**
1. **Save/Reset**
1. **Required vs Optional Fields**
1. **Configuration Persistence**
1. **Error Feedback**
### Example: App Configuration Form
```typescript
import { render, screen, fireEvent, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import AppConfigForm from './app-config-form'
jest.mock('@/service/apps', () => ({
updateAppConfig: jest.fn(),
getAppConfig: jest.fn(),
}))
import * as appService from '@/service/apps'
const mockedService = appService as jest.Mocked<typeof appService>
describe('AppConfigForm', () => {
const defaultConfig = {
name: 'My App',
description: '',
icon: 'default',
openingStatement: '',
}
beforeEach(() => {
jest.clearAllMocks()
mockedService.getAppConfig.mockResolvedValue(defaultConfig)
})
describe('Form Validation', () => {
it('should require app name', async () => {
const user = userEvent.setup()
render(<AppConfigForm appId="app-1" />)
await waitFor(() => {
expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
})
// Clear name field
await user.clear(screen.getByLabelText(/name/i))
await user.click(screen.getByRole('button', { name: /save/i }))
expect(screen.getByText(/name is required/i)).toBeInTheDocument()
expect(mockedService.updateAppConfig).not.toHaveBeenCalled()
})
it('should validate name length', async () => {
const user = userEvent.setup()
render(<AppConfigForm appId="app-1" />)
await waitFor(() => {
expect(screen.getByLabelText(/name/i)).toBeInTheDocument()
})
// Enter very long name
await user.clear(screen.getByLabelText(/name/i))
await user.type(screen.getByLabelText(/name/i), 'a'.repeat(101))
expect(screen.getByText(/name must be less than 100 characters/i)).toBeInTheDocument()
})
it('should allow empty optional fields', async () => {
const user = userEvent.setup()
mockedService.updateAppConfig.mockResolvedValue({ success: true })
render(<AppConfigForm appId="app-1" />)
await waitFor(() => {
expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
})
// Leave description empty (optional)
await user.click(screen.getByRole('button', { name: /save/i }))
await waitFor(() => {
expect(mockedService.updateAppConfig).toHaveBeenCalled()
})
})
})
describe('Save/Reset Functionality', () => {
it('should save configuration', async () => {
const user = userEvent.setup()
mockedService.updateAppConfig.mockResolvedValue({ success: true })
render(<AppConfigForm appId="app-1" />)
await waitFor(() => {
expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
})
await user.clear(screen.getByLabelText(/name/i))
await user.type(screen.getByLabelText(/name/i), 'Updated App')
await user.click(screen.getByRole('button', { name: /save/i }))
await waitFor(() => {
expect(mockedService.updateAppConfig).toHaveBeenCalledWith(
'app-1',
expect.objectContaining({ name: 'Updated App' })
)
})
expect(screen.getByText(/saved successfully/i)).toBeInTheDocument()
})
it('should reset to default values', async () => {
const user = userEvent.setup()
render(<AppConfigForm appId="app-1" />)
await waitFor(() => {
expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
})
// Make changes
await user.clear(screen.getByLabelText(/name/i))
await user.type(screen.getByLabelText(/name/i), 'Changed Name')
// Reset
await user.click(screen.getByRole('button', { name: /reset/i }))
expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
})
it('should show unsaved changes warning', async () => {
const user = userEvent.setup()
render(<AppConfigForm appId="app-1" />)
await waitFor(() => {
expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
})
// Make changes
await user.type(screen.getByLabelText(/name/i), ' Updated')
expect(screen.getByText(/unsaved changes/i)).toBeInTheDocument()
})
})
describe('Error Handling', () => {
it('should show error on save failure', async () => {
const user = userEvent.setup()
mockedService.updateAppConfig.mockRejectedValue(new Error('Server error'))
render(<AppConfigForm appId="app-1" />)
await waitFor(() => {
expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
})
await user.click(screen.getByRole('button', { name: /save/i }))
await waitFor(() => {
expect(screen.getByText(/failed to save/i)).toBeInTheDocument()
})
})
})
})
```

View File

@ -0,0 +1,363 @@
# Mocking Guide for Dify Frontend Tests
## ⚠️ Important: What NOT to Mock
### DO NOT Mock Base Components
**Never mock components from `@/app/components/base/`** such as:
- `Loading`, `Spinner`
- `Button`, `Input`, `Select`
- `Tooltip`, `Modal`, `Dropdown`
- `Icon`, `Badge`, `Tag`
**Why?**
- Base components will have their own dedicated tests
- Mocking them creates false positives (tests pass but real integration fails)
- Using real components tests actual integration behavior
```typescript
// ❌ WRONG: Don't mock base components
jest.mock('@/app/components/base/loading', () => () => <div>Loading</div>)
jest.mock('@/app/components/base/button', () => ({ children }: any) => <button>{children}</button>)
// ✅ CORRECT: Import and use real base components
import Loading from '@/app/components/base/loading'
import Button from '@/app/components/base/button'
// They will render normally in tests
```
### What TO Mock
Only mock these categories:
1. **API services** (`@/service/*`) - Network calls
1. **Complex context providers** - When setup is too difficult
1. **Third-party libraries with side effects** - `next/navigation`, external SDKs
1. **i18n** - Always mock to return keys
## Mock Placement
| Location | Purpose |
|----------|---------|
| `web/__mocks__/` | Reusable mocks shared across multiple test files |
| Test file | Test-specific mocks, inline with `jest.mock()` |
## Essential Mocks
### 1. i18n (Auto-loaded via Shared Mock)
A shared mock is available at `web/__mocks__/react-i18next.ts` and is auto-loaded by Jest.
**No explicit mock needed** for most tests - it returns translation keys as-is.
For tests requiring custom translations, override the mock:
```typescript
jest.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string) => {
const translations: Record<string, string> = {
'my.custom.key': 'Custom translation',
}
return translations[key] || key
},
}),
}))
```
### 2. Next.js Router
```typescript
const mockPush = jest.fn()
const mockReplace = jest.fn()
jest.mock('next/navigation', () => ({
useRouter: () => ({
push: mockPush,
replace: mockReplace,
back: jest.fn(),
prefetch: jest.fn(),
}),
usePathname: () => '/current-path',
useSearchParams: () => new URLSearchParams('?key=value'),
}))
describe('Component', () => {
beforeEach(() => {
jest.clearAllMocks()
})
it('should navigate on click', () => {
render(<Component />)
fireEvent.click(screen.getByRole('button'))
expect(mockPush).toHaveBeenCalledWith('/expected-path')
})
})
```
### 3. Portal Components (with Shared State)
```typescript
// ⚠️ Important: Use shared state for components that depend on each other
let mockPortalOpenState = false
jest.mock('@/app/components/base/portal-to-follow-elem', () => ({
PortalToFollowElem: ({ children, open, ...props }: any) => {
mockPortalOpenState = open || false // Update shared state
return <div data-testid="portal" data-open={open}>{children}</div>
},
PortalToFollowElemContent: ({ children }: any) => {
// ✅ Matches actual: returns null when portal is closed
if (!mockPortalOpenState) return null
return <div data-testid="portal-content">{children}</div>
},
PortalToFollowElemTrigger: ({ children }: any) => (
<div data-testid="portal-trigger">{children}</div>
),
}))
describe('Component', () => {
beforeEach(() => {
jest.clearAllMocks()
mockPortalOpenState = false // ✅ Reset shared state
})
})
```
### 4. API Service Mocks
```typescript
import * as api from '@/service/api'
jest.mock('@/service/api')
const mockedApi = api as jest.Mocked<typeof api>
describe('Component', () => {
beforeEach(() => {
jest.clearAllMocks()
// Setup default mock implementation
mockedApi.fetchData.mockResolvedValue({ data: [] })
})
it('should show data on success', async () => {
mockedApi.fetchData.mockResolvedValue({ data: [{ id: 1 }] })
render(<Component />)
await waitFor(() => {
expect(screen.getByText('1')).toBeInTheDocument()
})
})
it('should show error on failure', async () => {
mockedApi.fetchData.mockRejectedValue(new Error('Network error'))
render(<Component />)
await waitFor(() => {
expect(screen.getByText(/error/i)).toBeInTheDocument()
})
})
})
```
### 5. HTTP Mocking with Nock
```typescript
import nock from 'nock'
const GITHUB_HOST = 'https://api.github.com'
const GITHUB_PATH = '/repos/owner/repo'
const mockGithubApi = (status: number, body: Record<string, unknown>, delayMs = 0) => {
return nock(GITHUB_HOST)
.get(GITHUB_PATH)
.delay(delayMs)
.reply(status, body)
}
describe('GithubComponent', () => {
afterEach(() => {
nock.cleanAll()
})
it('should display repo info', async () => {
mockGithubApi(200, { name: 'dify', stars: 1000 })
render(<GithubComponent />)
await waitFor(() => {
expect(screen.getByText('dify')).toBeInTheDocument()
})
})
it('should handle API error', async () => {
mockGithubApi(500, { message: 'Server error' })
render(<GithubComponent />)
await waitFor(() => {
expect(screen.getByText(/error/i)).toBeInTheDocument()
})
})
})
```
### 6. Context Providers
```typescript
import { ProviderContext } from '@/context/provider-context'
import { createMockProviderContextValue, createMockPlan } from '@/__mocks__/provider-context'
describe('Component with Context', () => {
it('should render for free plan', () => {
const mockContext = createMockPlan('sandbox')
render(
<ProviderContext.Provider value={mockContext}>
<Component />
</ProviderContext.Provider>
)
expect(screen.getByText('Upgrade')).toBeInTheDocument()
})
it('should render for pro plan', () => {
const mockContext = createMockPlan('professional')
render(
<ProviderContext.Provider value={mockContext}>
<Component />
</ProviderContext.Provider>
)
expect(screen.queryByText('Upgrade')).not.toBeInTheDocument()
})
})
```
### 7. SWR / React Query
```typescript
// SWR
jest.mock('swr', () => ({
__esModule: true,
default: jest.fn(),
}))
import useSWR from 'swr'
const mockedUseSWR = useSWR as jest.Mock
describe('Component with SWR', () => {
it('should show loading state', () => {
mockedUseSWR.mockReturnValue({
data: undefined,
error: undefined,
isLoading: true,
})
render(<Component />)
expect(screen.getByText(/loading/i)).toBeInTheDocument()
})
})
// React Query
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
const createTestQueryClient = () => new QueryClient({
defaultOptions: {
queries: { retry: false },
mutations: { retry: false },
},
})
const renderWithQueryClient = (ui: React.ReactElement) => {
const queryClient = createTestQueryClient()
return render(
<QueryClientProvider client={queryClient}>
{ui}
</QueryClientProvider>
)
}
```
## Mock Best Practices
### ✅ DO
1. **Use real base components** - Import from `@/app/components/base/` directly
1. **Use real project components** - Prefer importing over mocking
1. **Reset mocks in `beforeEach`**, not `afterEach`
1. **Match actual component behavior** in mocks (when mocking is necessary)
1. **Use factory functions** for complex mock data
1. **Import actual types** for type safety
1. **Reset shared mock state** in `beforeEach`
### ❌ DON'T
1. **Don't mock base components** (`Loading`, `Button`, `Tooltip`, etc.)
1. Don't mock components you can import directly
1. Don't create overly simplified mocks that miss conditional logic
1. Don't forget to clean up nock after each test
1. Don't use `any` types in mocks without necessity
### Mock Decision Tree
```
Need to use a component in test?
├─ Is it from @/app/components/base/*?
│ └─ YES → Import real component, DO NOT mock
├─ Is it a project component?
│ └─ YES → Prefer importing real component
│ Only mock if setup is extremely complex
├─ Is it an API service (@/service/*)?
│ └─ YES → Mock it
├─ Is it a third-party lib with side effects?
│ └─ YES → Mock it (next/navigation, external SDKs)
└─ Is it i18n?
└─ YES → Uses shared mock (auto-loaded). Override only for custom translations
```
## Factory Function Pattern
```typescript
// __mocks__/data-factories.ts
import type { User, Project } from '@/types'
export const createMockUser = (overrides: Partial<User> = {}): User => ({
id: 'user-1',
name: 'Test User',
email: 'test@example.com',
role: 'member',
createdAt: new Date().toISOString(),
...overrides,
})
export const createMockProject = (overrides: Partial<Project> = {}): Project => ({
id: 'project-1',
name: 'Test Project',
description: 'A test project',
owner: createMockUser(),
members: [],
createdAt: new Date().toISOString(),
...overrides,
})
// Usage in tests
it('should display project owner', () => {
const project = createMockProject({
owner: createMockUser({ name: 'John Doe' }),
})
render(<ProjectCard project={project} />)
expect(screen.getByText('John Doe')).toBeInTheDocument()
})
```

View File

@ -0,0 +1,269 @@
# Testing Workflow Guide
This guide defines the workflow for generating tests, especially for complex components or directories with multiple files.
## Scope Clarification
This guide addresses **multi-file workflow** (how to process multiple test files). For coverage requirements within a single test file, see `web/testing/testing.md` § Coverage Goals.
| Scope | Rule |
|-------|------|
| **Single file** | Complete coverage in one generation (100% function, >95% branch) |
| **Multi-file directory** | Process one file at a time, verify each before proceeding |
## ⚠️ Critical Rule: Incremental Approach for Multi-File Testing
When testing a **directory with multiple files**, **NEVER generate all test files at once.** Use an incremental, verify-as-you-go approach.
### Why Incremental?
| Batch Approach (❌) | Incremental Approach (✅) |
|---------------------|---------------------------|
| Generate 5+ tests at once | Generate 1 test at a time |
| Run tests only at the end | Run test immediately after each file |
| Multiple failures compound | Single point of failure, easy to debug |
| Hard to identify root cause | Clear cause-effect relationship |
| Mock issues affect many files | Mock issues caught early |
| Messy git history | Clean, atomic commits possible |
## Single File Workflow
When testing a **single component, hook, or utility**:
```
1. Read source code completely
2. Run `pnpm analyze-component <path>` (if available)
3. Check complexity score and features detected
4. Write the test file
5. Run test: `pnpm test -- <file>.spec.tsx`
6. Fix any failures
7. Verify coverage meets goals (100% function, >95% branch)
```
## Directory/Multi-File Workflow (MUST FOLLOW)
When testing a **directory or multiple files**, follow this strict workflow:
### Step 1: Analyze and Plan
1. **List all files** that need tests in the directory
1. **Categorize by complexity**:
- 🟢 **Simple**: Utility functions, simple hooks, presentational components
- 🟡 **Medium**: Components with state, effects, or event handlers
- 🔴 **Complex**: Components with API calls, routing, or many dependencies
1. **Order by dependency**: Test dependencies before dependents
1. **Create a todo list** to track progress
### Step 2: Determine Processing Order
Process files in this recommended order:
```
1. Utility functions (simplest, no React)
2. Custom hooks (isolated logic)
3. Simple presentational components (few/no props)
4. Medium complexity components (state, effects)
5. Complex components (API, routing, many deps)
6. Container/index components (integration tests - last)
```
**Rationale**:
- Simpler files help establish mock patterns
- Hooks used by components should be tested first
- Integration tests (index files) depend on child components working
### Step 3: Process Each File Incrementally
**For EACH file in the ordered list:**
```
┌─────────────────────────────────────────────┐
│ 1. Write test file │
│ 2. Run: pnpm test -- <file>.spec.tsx │
│ 3. If FAIL → Fix immediately, re-run │
│ 4. If PASS → Mark complete in todo list │
│ 5. ONLY THEN proceed to next file │
└─────────────────────────────────────────────┘
```
**DO NOT proceed to the next file until the current one passes.**
### Step 4: Final Verification
After all individual tests pass:
```bash
# Run all tests in the directory together
pnpm test -- path/to/directory/
# Check coverage
pnpm test -- --coverage path/to/directory/
```
## Component Complexity Guidelines
Use `pnpm analyze-component <path>` to assess complexity before testing.
### 🔴 Very Complex Components (Complexity > 50)
**Consider refactoring BEFORE testing:**
- Break component into smaller, testable pieces
- Extract complex logic into custom hooks
- Separate container and presentational layers
**If testing as-is:**
- Use integration tests for complex workflows
- Use `test.each()` for data-driven testing
- Multiple `describe` blocks for organization
- Consider testing major sections separately
### 🟡 Medium Complexity (Complexity 30-50)
- Group related tests in `describe` blocks
- Test integration scenarios between internal parts
- Focus on state transitions and side effects
- Use helper functions to reduce test complexity
### 🟢 Simple Components (Complexity < 30)
- Standard test structure
- Focus on props, rendering, and edge cases
- Usually straightforward to test
### 📏 Large Files (500+ lines)
Regardless of complexity score:
- **Strongly consider refactoring** before testing
- If testing as-is, test major sections separately
- Create helper functions for test setup
- May need multiple test files
## Todo List Format
When testing multiple files, use a todo list like this:
```
Testing: path/to/directory/
Ordered by complexity (simple → complex):
☐ utils/helper.ts [utility, simple]
☐ hooks/use-custom-hook.ts [hook, simple]
☐ empty-state.tsx [component, simple]
☐ item-card.tsx [component, medium]
☐ list.tsx [component, complex]
☐ index.tsx [integration]
Progress: 0/6 complete
```
Update status as you complete each:
- ☐ → ⏳ (in progress)
- ⏳ → ✅ (complete and verified)
- ⏳ → ❌ (blocked, needs attention)
## When to Stop and Verify
**Always run tests after:**
- Completing a test file
- Making changes to fix a failure
- Modifying shared mocks
- Updating test utilities or helpers
**Signs you should pause:**
- More than 2 consecutive test failures
- Mock-related errors appearing
- Unclear why a test is failing
- Test passing but coverage unexpectedly low
## Common Pitfalls to Avoid
### ❌ Don't: Generate Everything First
```
# BAD: Writing all files then testing
Write component-a.spec.tsx
Write component-b.spec.tsx
Write component-c.spec.tsx
Write component-d.spec.tsx
Run pnpm test ← Multiple failures, hard to debug
```
### ✅ Do: Verify Each Step
```
# GOOD: Incremental with verification
Write component-a.spec.tsx
Run pnpm test -- component-a.spec.tsx ✅
Write component-b.spec.tsx
Run pnpm test -- component-b.spec.tsx ✅
...continue...
```
### ❌ Don't: Skip Verification for "Simple" Components
Even simple components can have:
- Import errors
- Missing mock setup
- Incorrect assumptions about props
**Always verify, regardless of perceived simplicity.**
### ❌ Don't: Continue When Tests Fail
Failing tests compound:
- A mock issue in file A affects files B, C, D
- Fixing A later requires revisiting all dependent tests
- Time wasted on debugging cascading failures
**Fix failures immediately before proceeding.**
## Integration with Claude's Todo Feature
When using Claude for multi-file testing:
1. **Ask Claude to create a todo list** before starting
1. **Request one file at a time** or ensure Claude processes incrementally
1. **Verify each test passes** before asking for the next
1. **Mark todos complete** as you progress
Example prompt:
```
Test all components in `path/to/directory/`.
First, analyze the directory and create a todo list ordered by complexity.
Then, process ONE file at a time, waiting for my confirmation that tests pass
before proceeding to the next.
```
## Summary Checklist
Before starting multi-file testing:
- [ ] Listed all files needing tests
- [ ] Ordered by complexity (simple → complex)
- [ ] Created todo list for tracking
- [ ] Understand dependencies between files
During testing:
- [ ] Processing ONE file at a time
- [ ] Running tests after EACH file
- [ ] Fixing failures BEFORE proceeding
- [ ] Updating todo list progress
After completion:
- [ ] All individual tests pass
- [ ] Full directory test run passes
- [ ] Coverage goals met
- [ ] Todo list shows all complete

View File

@ -0,0 +1,296 @@
/**
* Test Template for React Components
*
* WHY THIS STRUCTURE?
* - Organized sections make tests easy to navigate and maintain
* - Mocks at top ensure consistent test isolation
* - Factory functions reduce duplication and improve readability
* - describe blocks group related scenarios for better debugging
*
* INSTRUCTIONS:
* 1. Replace `ComponentName` with your component name
* 2. Update import path
* 3. Add/remove test sections based on component features (use analyze-component)
* 4. Follow AAA pattern: Arrange → Act → Assert
*
* RUN FIRST: pnpm analyze-component <path> to identify required test scenarios
*/
import { render, screen, fireEvent, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
// import ComponentName from './index'
// ============================================================================
// Mocks
// ============================================================================
// WHY: Mocks must be hoisted to top of file (Jest requirement).
// They run BEFORE imports, so keep them before component imports.
// i18n (automatically mocked)
// WHY: Shared mock at web/__mocks__/react-i18next.ts is auto-loaded by Jest
// No explicit mock needed - it returns translation keys as-is
// Override only if custom translations are required:
// jest.mock('react-i18next', () => ({
// useTranslation: () => ({
// t: (key: string) => {
// const customTranslations: Record<string, string> = {
// 'my.custom.key': 'Custom Translation',
// }
// return customTranslations[key] || key
// },
// }),
// }))
// Router (if component uses useRouter, usePathname, useSearchParams)
// WHY: Isolates tests from Next.js routing, enables testing navigation behavior
// const mockPush = jest.fn()
// jest.mock('next/navigation', () => ({
// useRouter: () => ({ push: mockPush }),
// usePathname: () => '/test-path',
// }))
// API services (if component fetches data)
// WHY: Prevents real network calls, enables testing all states (loading/success/error)
// jest.mock('@/service/api')
// import * as api from '@/service/api'
// const mockedApi = api as jest.Mocked<typeof api>
// Shared mock state (for portal/dropdown components)
// WHY: Portal components like PortalToFollowElem need shared state between
// parent and child mocks to correctly simulate open/close behavior
// let mockOpenState = false
// ============================================================================
// Test Data Factories
// ============================================================================
// WHY FACTORIES?
// - Avoid hard-coded test data scattered across tests
// - Easy to create variations with overrides
// - Type-safe when using actual types from source
// - Single source of truth for default test values
// const createMockProps = (overrides = {}) => ({
// // Default props that make component render successfully
// ...overrides,
// })
// const createMockItem = (overrides = {}) => ({
// id: 'item-1',
// name: 'Test Item',
// ...overrides,
// })
// ============================================================================
// Test Helpers
// ============================================================================
// const renderComponent = (props = {}) => {
// return render(<ComponentName {...createMockProps(props)} />)
// }
// ============================================================================
// Tests
// ============================================================================
describe('ComponentName', () => {
// WHY beforeEach with clearAllMocks?
// - Ensures each test starts with clean slate
// - Prevents mock call history from leaking between tests
// - MUST be beforeEach (not afterEach) to reset BEFORE assertions like toHaveBeenCalledTimes
beforeEach(() => {
jest.clearAllMocks()
// Reset shared mock state if used (CRITICAL for portal/dropdown tests)
// mockOpenState = false
})
// --------------------------------------------------------------------------
// Rendering Tests (REQUIRED - Every component MUST have these)
// --------------------------------------------------------------------------
// WHY: Catches import errors, missing providers, and basic render issues
describe('Rendering', () => {
it('should render without crashing', () => {
// Arrange - Setup data and mocks
// const props = createMockProps()
// Act - Render the component
// render(<ComponentName {...props} />)
// Assert - Verify expected output
// Prefer getByRole for accessibility; it's what users "see"
// expect(screen.getByRole('...')).toBeInTheDocument()
})
it('should render with default props', () => {
// WHY: Verifies component works without optional props
// render(<ComponentName />)
// expect(screen.getByText('...')).toBeInTheDocument()
})
})
// --------------------------------------------------------------------------
// Props Tests (REQUIRED - Every component MUST test prop behavior)
// --------------------------------------------------------------------------
// WHY: Props are the component's API contract. Test them thoroughly.
describe('Props', () => {
it('should apply custom className', () => {
// WHY: Common pattern in Dify - components should merge custom classes
// render(<ComponentName className="custom-class" />)
// expect(screen.getByTestId('component')).toHaveClass('custom-class')
})
it('should use default values for optional props', () => {
// WHY: Verifies TypeScript defaults work at runtime
// render(<ComponentName />)
// expect(screen.getByRole('...')).toHaveAttribute('...', 'default-value')
})
})
// --------------------------------------------------------------------------
// User Interactions (if component has event handlers - on*, handle*)
// --------------------------------------------------------------------------
// WHY: Event handlers are core functionality. Test from user's perspective.
describe('User Interactions', () => {
it('should call onClick when clicked', async () => {
// WHY userEvent over fireEvent?
// - userEvent simulates real user behavior (focus, hover, then click)
// - fireEvent is lower-level, doesn't trigger all browser events
// const user = userEvent.setup()
// const handleClick = jest.fn()
// render(<ComponentName onClick={handleClick} />)
//
// await user.click(screen.getByRole('button'))
//
// expect(handleClick).toHaveBeenCalledTimes(1)
})
it('should call onChange when value changes', async () => {
// const user = userEvent.setup()
// const handleChange = jest.fn()
// render(<ComponentName onChange={handleChange} />)
//
// await user.type(screen.getByRole('textbox'), 'new value')
//
// expect(handleChange).toHaveBeenCalled()
})
})
// --------------------------------------------------------------------------
// State Management (if component uses useState/useReducer)
// --------------------------------------------------------------------------
// WHY: Test state through observable UI changes, not internal state values
describe('State Management', () => {
it('should update state on interaction', async () => {
// WHY test via UI, not state?
// - State is implementation detail; UI is what users see
// - If UI works correctly, state must be correct
// const user = userEvent.setup()
// render(<ComponentName />)
//
// // Initial state - verify what user sees
// expect(screen.getByText('Initial')).toBeInTheDocument()
//
// // Trigger state change via user action
// await user.click(screen.getByRole('button'))
//
// // New state - verify UI updated
// expect(screen.getByText('Updated')).toBeInTheDocument()
})
})
// --------------------------------------------------------------------------
// Async Operations (if component fetches data - useSWR, useQuery, fetch)
// --------------------------------------------------------------------------
// WHY: Async operations have 3 states users experience: loading, success, error
describe('Async Operations', () => {
it('should show loading state', () => {
// WHY never-resolving promise?
// - Keeps component in loading state for assertion
// - Alternative: use fake timers
// mockedApi.fetchData.mockImplementation(() => new Promise(() => {}))
// render(<ComponentName />)
//
// expect(screen.getByText(/loading/i)).toBeInTheDocument()
})
it('should show data on success', async () => {
// WHY waitFor?
// - Component updates asynchronously after fetch resolves
// - waitFor retries assertion until it passes or times out
// mockedApi.fetchData.mockResolvedValue({ items: ['Item 1'] })
// render(<ComponentName />)
//
// await waitFor(() => {
// expect(screen.getByText('Item 1')).toBeInTheDocument()
// })
})
it('should show error on failure', async () => {
// mockedApi.fetchData.mockRejectedValue(new Error('Network error'))
// render(<ComponentName />)
//
// await waitFor(() => {
// expect(screen.getByText(/error/i)).toBeInTheDocument()
// })
})
})
// --------------------------------------------------------------------------
// Edge Cases (REQUIRED - Every component MUST handle edge cases)
// --------------------------------------------------------------------------
// WHY: Real-world data is messy. Components must handle:
// - Null/undefined from API failures or optional fields
// - Empty arrays/strings from user clearing data
// - Boundary values (0, MAX_INT, special characters)
describe('Edge Cases', () => {
it('should handle null value', () => {
// WHY test null specifically?
// - API might return null for missing data
// - Prevents "Cannot read property of null" in production
// render(<ComponentName value={null} />)
// expect(screen.getByText(/no data/i)).toBeInTheDocument()
})
it('should handle undefined value', () => {
// WHY test undefined separately from null?
// - TypeScript treats them differently
// - Optional props are undefined, not null
// render(<ComponentName value={undefined} />)
// expect(screen.getByText(/no data/i)).toBeInTheDocument()
})
it('should handle empty array', () => {
// WHY: Empty state often needs special UI (e.g., "No items yet")
// render(<ComponentName items={[]} />)
// expect(screen.getByText(/empty/i)).toBeInTheDocument()
})
it('should handle empty string', () => {
// WHY: Empty strings are truthy in JS but visually empty
// render(<ComponentName text="" />)
// expect(screen.getByText(/placeholder/i)).toBeInTheDocument()
})
})
// --------------------------------------------------------------------------
// Accessibility (optional but recommended for Dify's enterprise users)
// --------------------------------------------------------------------------
// WHY: Dify has enterprise customers who may require accessibility compliance
describe('Accessibility', () => {
it('should have accessible name', () => {
// WHY getByRole with name?
// - Tests that screen readers can identify the element
// - Enforces proper labeling practices
// render(<ComponentName label="Test Label" />)
// expect(screen.getByRole('button', { name: /test label/i })).toBeInTheDocument()
})
it('should support keyboard navigation', async () => {
// WHY: Some users can't use a mouse
// const user = userEvent.setup()
// render(<ComponentName />)
//
// await user.tab()
// expect(screen.getByRole('button')).toHaveFocus()
})
})
})

View File

@ -0,0 +1,207 @@
/**
* Test Template for Custom Hooks
*
* Instructions:
* 1. Replace `useHookName` with your hook name
* 2. Update import path
* 3. Add/remove test sections based on hook features
*/
import { renderHook, act, waitFor } from '@testing-library/react'
// import { useHookName } from './use-hook-name'
// ============================================================================
// Mocks
// ============================================================================
// API services (if hook fetches data)
// jest.mock('@/service/api')
// import * as api from '@/service/api'
// const mockedApi = api as jest.Mocked<typeof api>
// ============================================================================
// Test Helpers
// ============================================================================
// Wrapper for hooks that need context
// const createWrapper = (contextValue = {}) => {
// return ({ children }: { children: React.ReactNode }) => (
// <SomeContext.Provider value={contextValue}>
// {children}
// </SomeContext.Provider>
// )
// }
// ============================================================================
// Tests
// ============================================================================
describe('useHookName', () => {
beforeEach(() => {
jest.clearAllMocks()
})
// --------------------------------------------------------------------------
// Initial State
// --------------------------------------------------------------------------
describe('Initial State', () => {
it('should return initial state', () => {
// const { result } = renderHook(() => useHookName())
//
// expect(result.current.value).toBe(initialValue)
// expect(result.current.isLoading).toBe(false)
})
it('should accept initial value from props', () => {
// const { result } = renderHook(() => useHookName({ initialValue: 'custom' }))
//
// expect(result.current.value).toBe('custom')
})
})
// --------------------------------------------------------------------------
// State Updates
// --------------------------------------------------------------------------
describe('State Updates', () => {
it('should update value when setValue is called', () => {
// const { result } = renderHook(() => useHookName())
//
// act(() => {
// result.current.setValue('new value')
// })
//
// expect(result.current.value).toBe('new value')
})
it('should reset to initial value', () => {
// const { result } = renderHook(() => useHookName({ initialValue: 'initial' }))
//
// act(() => {
// result.current.setValue('changed')
// })
// expect(result.current.value).toBe('changed')
//
// act(() => {
// result.current.reset()
// })
// expect(result.current.value).toBe('initial')
})
})
// --------------------------------------------------------------------------
// Async Operations
// --------------------------------------------------------------------------
describe('Async Operations', () => {
it('should fetch data on mount', async () => {
// mockedApi.fetchData.mockResolvedValue({ data: 'test' })
//
// const { result } = renderHook(() => useHookName())
//
// // Initially loading
// expect(result.current.isLoading).toBe(true)
//
// // Wait for data
// await waitFor(() => {
// expect(result.current.isLoading).toBe(false)
// })
//
// expect(result.current.data).toEqual({ data: 'test' })
})
it('should handle fetch error', async () => {
// mockedApi.fetchData.mockRejectedValue(new Error('Network error'))
//
// const { result } = renderHook(() => useHookName())
//
// await waitFor(() => {
// expect(result.current.error).toBeTruthy()
// })
//
// expect(result.current.error?.message).toBe('Network error')
})
it('should refetch when dependency changes', async () => {
// mockedApi.fetchData.mockResolvedValue({ data: 'test' })
//
// const { result, rerender } = renderHook(
// ({ id }) => useHookName(id),
// { initialProps: { id: '1' } }
// )
//
// await waitFor(() => {
// expect(mockedApi.fetchData).toHaveBeenCalledWith('1')
// })
//
// rerender({ id: '2' })
//
// await waitFor(() => {
// expect(mockedApi.fetchData).toHaveBeenCalledWith('2')
// })
})
})
// --------------------------------------------------------------------------
// Side Effects
// --------------------------------------------------------------------------
describe('Side Effects', () => {
it('should call callback when value changes', () => {
// const callback = jest.fn()
// const { result } = renderHook(() => useHookName({ onChange: callback }))
//
// act(() => {
// result.current.setValue('new value')
// })
//
// expect(callback).toHaveBeenCalledWith('new value')
})
it('should cleanup on unmount', () => {
// const cleanup = jest.fn()
// jest.spyOn(window, 'addEventListener')
// jest.spyOn(window, 'removeEventListener')
//
// const { unmount } = renderHook(() => useHookName())
//
// expect(window.addEventListener).toHaveBeenCalled()
//
// unmount()
//
// expect(window.removeEventListener).toHaveBeenCalled()
})
})
// --------------------------------------------------------------------------
// Edge Cases
// --------------------------------------------------------------------------
describe('Edge Cases', () => {
it('should handle null input', () => {
// const { result } = renderHook(() => useHookName(null))
//
// expect(result.current.value).toBeNull()
})
it('should handle rapid updates', () => {
// const { result } = renderHook(() => useHookName())
//
// act(() => {
// result.current.setValue('1')
// result.current.setValue('2')
// result.current.setValue('3')
// })
//
// expect(result.current.value).toBe('3')
})
})
// --------------------------------------------------------------------------
// With Context (if hook uses context)
// --------------------------------------------------------------------------
describe('With Context', () => {
it('should use context value', () => {
// const wrapper = createWrapper({ someValue: 'context-value' })
// const { result } = renderHook(() => useHookName(), { wrapper })
//
// expect(result.current.contextValue).toBe('context-value')
})
})
})

View File

@ -0,0 +1,154 @@
/**
* Test Template for Utility Functions
*
* Instructions:
* 1. Replace `utilityFunction` with your function name
* 2. Update import path
* 3. Use test.each for data-driven tests
*/
// import { utilityFunction } from './utility'
// ============================================================================
// Tests
// ============================================================================
describe('utilityFunction', () => {
// --------------------------------------------------------------------------
// Basic Functionality
// --------------------------------------------------------------------------
describe('Basic Functionality', () => {
it('should return expected result for valid input', () => {
// expect(utilityFunction('input')).toBe('expected-output')
})
it('should handle multiple arguments', () => {
// expect(utilityFunction('a', 'b', 'c')).toBe('abc')
})
})
// --------------------------------------------------------------------------
// Data-Driven Tests
// --------------------------------------------------------------------------
describe('Input/Output Mapping', () => {
test.each([
// [input, expected]
['input1', 'output1'],
['input2', 'output2'],
['input3', 'output3'],
])('should return %s for input %s', (input, expected) => {
// expect(utilityFunction(input)).toBe(expected)
})
})
// --------------------------------------------------------------------------
// Edge Cases
// --------------------------------------------------------------------------
describe('Edge Cases', () => {
it('should handle empty string', () => {
// expect(utilityFunction('')).toBe('')
})
it('should handle null', () => {
// expect(utilityFunction(null)).toBe(null)
// or
// expect(() => utilityFunction(null)).toThrow()
})
it('should handle undefined', () => {
// expect(utilityFunction(undefined)).toBe(undefined)
// or
// expect(() => utilityFunction(undefined)).toThrow()
})
it('should handle empty array', () => {
// expect(utilityFunction([])).toEqual([])
})
it('should handle empty object', () => {
// expect(utilityFunction({})).toEqual({})
})
})
// --------------------------------------------------------------------------
// Boundary Conditions
// --------------------------------------------------------------------------
describe('Boundary Conditions', () => {
it('should handle minimum value', () => {
// expect(utilityFunction(0)).toBe(0)
})
it('should handle maximum value', () => {
// expect(utilityFunction(Number.MAX_SAFE_INTEGER)).toBe(...)
})
it('should handle negative numbers', () => {
// expect(utilityFunction(-1)).toBe(...)
})
})
// --------------------------------------------------------------------------
// Type Coercion (if applicable)
// --------------------------------------------------------------------------
describe('Type Handling', () => {
it('should handle numeric string', () => {
// expect(utilityFunction('123')).toBe(123)
})
it('should handle boolean', () => {
// expect(utilityFunction(true)).toBe(...)
})
})
// --------------------------------------------------------------------------
// Error Cases
// --------------------------------------------------------------------------
describe('Error Handling', () => {
it('should throw for invalid input', () => {
// expect(() => utilityFunction('invalid')).toThrow('Error message')
})
it('should throw with specific error type', () => {
// expect(() => utilityFunction('invalid')).toThrow(ValidationError)
})
})
// --------------------------------------------------------------------------
// Complex Objects (if applicable)
// --------------------------------------------------------------------------
describe('Object Handling', () => {
it('should preserve object structure', () => {
// const input = { a: 1, b: 2 }
// expect(utilityFunction(input)).toEqual({ a: 1, b: 2 })
})
it('should handle nested objects', () => {
// const input = { nested: { deep: 'value' } }
// expect(utilityFunction(input)).toEqual({ nested: { deep: 'transformed' } })
})
it('should not mutate input', () => {
// const input = { a: 1 }
// const inputCopy = { ...input }
// utilityFunction(input)
// expect(input).toEqual(inputCopy)
})
})
// --------------------------------------------------------------------------
// Array Handling (if applicable)
// --------------------------------------------------------------------------
describe('Array Handling', () => {
it('should process all elements', () => {
// expect(utilityFunction([1, 2, 3])).toEqual([2, 4, 6])
})
it('should handle single element array', () => {
// expect(utilityFunction([1])).toEqual([2])
})
it('should preserve order', () => {
// expect(utilityFunction(['c', 'a', 'b'])).toEqual(['c', 'a', 'b'])
})
})
})

5
.coveragerc Normal file
View File

@ -0,0 +1,5 @@
[run]
omit =
api/tests/*
api/migrations/*
api/core/rag/datasource/vdb/*

8
.github/CODEOWNERS vendored
View File

@ -9,6 +9,14 @@
# Backend (default owner, more specific rules below will override)
api/ @QuantumGhost
# Backend - MCP
api/core/mcp/ @Nov1c444
api/core/entities/mcp_provider.py @Nov1c444
api/services/tools/mcp_tools_manage_service.py @Nov1c444
api/controllers/mcp/ @Nov1c444
api/controllers/console/app/mcp_server.py @Nov1c444
api/tests/**/*mcp* @Nov1c444
# Backend - Workflow - Engine (Core graph execution engine)
api/core/workflow/graph_engine/ @laipz8200 @QuantumGhost
api/core/workflow/runtime/ @laipz8200 @QuantumGhost

View File

@ -1,8 +1,6 @@
name: "✨ Refactor"
description: Refactor existing code for improved readability and maintainability.
title: "[Chore/Refactor] "
labels:
- refactor
name: "✨ Refactor or Chore"
description: Refactor existing code or perform maintenance chores to improve readability and reliability.
title: "[Refactor/Chore] "
body:
- type: checkboxes
attributes:
@ -11,7 +9,7 @@ body:
options:
- label: I have read the [Contributing Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) and [Language Policy](https://github.com/langgenius/dify/issues/1542).
required: true
- label: This is only for refactoring, if you would like to ask a question, please head to [Discussions](https://github.com/langgenius/dify/discussions/categories/general).
- label: This is only for refactors or chores; if you would like to ask a question, please head to [Discussions](https://github.com/langgenius/dify/discussions/categories/general).
required: true
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
required: true
@ -25,14 +23,14 @@ body:
id: description
attributes:
label: Description
placeholder: "Describe the refactor you are proposing."
placeholder: "Describe the refactor or chore you are proposing."
validations:
required: true
- type: textarea
id: motivation
attributes:
label: Motivation
placeholder: "Explain why this refactor is necessary."
placeholder: "Explain why this refactor or chore is necessary."
validations:
required: false
- type: textarea

View File

@ -1,13 +0,0 @@
name: "👾 Tracker"
description: For inner usages, please do not use this template.
title: "[Tracker] "
labels:
- tracker
body:
- type: textarea
id: content
attributes:
label: Blockers
placeholder: "- [ ] ..."
validations:
required: true

View File

@ -1,12 +0,0 @@
# Copilot Instructions
GitHub Copilot must follow the unified frontend testing requirements documented in `web/testing/testing.md`.
Key reminders:
- Generate tests using the mandated tech stack, naming, and code style (AAA pattern, `fireEvent`, descriptive test names, cleans up mocks).
- Cover rendering, prop combinations, and edge cases by default; extend coverage for hooks, routing, async flows, and domain-specific components when applicable.
- Target >95% line and branch coverage and 100% function/statement coverage.
- Apply the project's mocking conventions for i18n, toast notifications, and Next.js utilities.
Any suggestions from Copilot that conflict with `web/testing/testing.md` should be revised before acceptance.

View File

@ -71,18 +71,18 @@ jobs:
run: |
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
- name: Run Workflow
run: uv run --project api bash dev/pytest/pytest_workflow.sh
- name: Run Tool
run: uv run --project api bash dev/pytest/pytest_tools.sh
- name: Run TestContainers
run: uv run --project api bash dev/pytest/pytest_testcontainers.sh
- name: Run Unit tests
- name: Run API Tests
env:
STORAGE_TYPE: opendal
OPENDAL_SCHEME: fs
OPENDAL_FS_ROOT: /tmp/dify-storage
run: |
uv run --project api bash dev/pytest/pytest_unit_tests.sh
uv run --project api pytest \
--timeout "${PYTEST_TIMEOUT:-180}" \
api/tests/integration_tests/workflow \
api/tests/integration_tests/tools \
api/tests/test_containers_integration_tests \
api/tests/unit_tests
- name: Coverage Summary
run: |
@ -93,5 +93,12 @@ jobs:
# Create a detailed coverage summary
echo "### Test Coverage Summary :test_tube:" >> $GITHUB_STEP_SUMMARY
echo "Total Coverage: ${TOTAL_COVERAGE}%" >> $GITHUB_STEP_SUMMARY
uv run --project api coverage report --format=markdown >> $GITHUB_STEP_SUMMARY
{
echo ""
echo "<details><summary>File-level coverage (click to expand)</summary>"
echo ""
echo '```'
uv run --project api coverage report -m
echo '```'
echo "</details>"
} >> $GITHUB_STEP_SUMMARY

View File

@ -13,11 +13,12 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
# Use uv to ensure we have the same ruff version in CI and locally.
- uses: astral-sh/setup-uv@v6
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- uses: astral-sh/setup-uv@v6
- run: |
cd api
uv sync --dev
@ -35,10 +36,11 @@ jobs:
- name: ast-grep
run: |
uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all
uvx --from ast-grep-cli sg --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all
uvx --from ast-grep-cli sg -p '$A = db.Column($$$B)' -r '$A = mapped_column($$$B)' -l py --update-all
uvx --from ast-grep-cli sg -p '$A : $T = db.Column($$$B)' -r '$A : $T = mapped_column($$$B)' -l py --update-all
# ast-grep exits 1 if no matches are found; allow idempotent runs.
uvx --from ast-grep-cli ast-grep --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all || true
uvx --from ast-grep-cli ast-grep --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all || true
uvx --from ast-grep-cli ast-grep -p '$A = db.Column($$$B)' -r '$A = mapped_column($$$B)' -l py --update-all || true
uvx --from ast-grep-cli ast-grep -p '$A : $T = db.Column($$$B)' -r '$A : $T = mapped_column($$$B)' -l py --update-all || true
# Convert Optional[T] to T | None (ignoring quoted types)
cat > /tmp/optional-rule.yml << 'EOF'
id: convert-optional-to-union
@ -56,14 +58,15 @@ jobs:
pattern: $T
fix: $T | None
EOF
uvx --from ast-grep-cli sg scan --inline-rules "$(cat /tmp/optional-rule.yml)" --update-all
uvx --from ast-grep-cli ast-grep scan . --inline-rules "$(cat /tmp/optional-rule.yml)" --update-all
# Fix forward references that were incorrectly converted (Python doesn't support "Type" | None syntax)
find . -name "*.py" -type f -exec sed -i.bak -E 's/"([^"]+)" \| None/Optional["\1"]/g; s/'"'"'([^'"'"']+)'"'"' \| None/Optional['"'"'\1'"'"']/g' {} \;
find . -name "*.py.bak" -type f -delete
# mdformat breaks YAML front matter in markdown files. Add --exclude for directories containing YAML front matter.
- name: mdformat
run: |
uvx mdformat .
uvx --python 3.13 mdformat . --exclude ".claude/skills/**"
- name: Install pnpm
uses: pnpm/action-setup@v4
@ -84,7 +87,6 @@ jobs:
- name: oxlint
working-directory: ./web
run: |
pnpx oxlint --fix
run: pnpm exec oxlint --config .oxlintrc.json --fix .
- uses: autofix-ci/action@635ffb0c9798bd160680f18fd73371e355b85f27

View File

@ -0,0 +1,21 @@
name: Semantic Pull Request
on:
pull_request:
types:
- opened
- edited
- reopened
- synchronize
jobs:
lint:
name: Validate PR title
permissions:
pull-requests: read
runs-on: ubuntu-latest
steps:
- name: Check title
uses: amannn/action-semantic-pull-request@v6.1.1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

1
.gitignore vendored
View File

@ -189,6 +189,7 @@ docker/volumes/matrixone/*
docker/volumes/mysql/*
docker/volumes/seekdb/*
!docker/volumes/oceanbase/init.d
docker/volumes/iris/*
docker/nginx/conf.d/default.conf
docker/nginx/ssl/*

1
.nvmrc Normal file
View File

@ -0,0 +1 @@
22.11.0

View File

@ -1,5 +0,0 @@
# Windsurf Testing Rules
- Use `web/testing/testing.md` as the single source of truth for frontend automated testing.
- Honor every requirement in that document when generating or accepting tests.
- When proposing or saving tests, re-read that document and follow every requirement.

View File

@ -543,6 +543,25 @@ APP_MAX_EXECUTION_TIME=1200
APP_DEFAULT_ACTIVE_REQUESTS=0
APP_MAX_ACTIVE_REQUESTS=0
# Aliyun SLS Logstore Configuration
# Aliyun Access Key ID
ALIYUN_SLS_ACCESS_KEY_ID=
# Aliyun Access Key Secret
ALIYUN_SLS_ACCESS_KEY_SECRET=
# Aliyun SLS Endpoint (e.g., cn-hangzhou.log.aliyuncs.com)
ALIYUN_SLS_ENDPOINT=
# Aliyun SLS Region (e.g., cn-hangzhou)
ALIYUN_SLS_REGION=
# Aliyun SLS Project Name
ALIYUN_SLS_PROJECT_NAME=
# Number of days to retain workflow run logs (default: 365 days 3650 for permanent storage)
ALIYUN_SLS_LOGSTORE_TTL=365
# Enable dual-write to both SLS LogStore and SQL database (default: false)
LOGSTORE_DUAL_WRITE_ENABLED=false
# Enable dual-read fallback to SQL database when LogStore returns no results (default: true)
# Useful for migration scenarios where historical data exists only in SQL database
LOGSTORE_DUAL_READ_ENABLED=true
# Celery beat configuration
CELERY_BEAT_SCHEDULER_TIME=1
@ -660,3 +679,14 @@ SINGLE_CHUNK_ATTACHMENT_LIMIT=10
ATTACHMENT_IMAGE_FILE_SIZE_LIMIT=2
ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT=60
IMAGE_FILE_BATCH_LIMIT=10
# Maximum allowed CSV file size for annotation import in megabytes
ANNOTATION_IMPORT_FILE_SIZE_LIMIT=2
#Maximum number of annotation records allowed in a single import
ANNOTATION_IMPORT_MAX_RECORDS=10000
# Minimum number of annotation records required in a single import
ANNOTATION_IMPORT_MIN_RECORDS=1
ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE=5
ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR=20
# Maximum number of concurrent annotation import tasks per tenant
ANNOTATION_IMPORT_MAX_CONCURRENT=5

View File

@ -75,6 +75,7 @@ def initialize_extensions(app: DifyApp):
ext_import_modules,
ext_logging,
ext_login,
ext_logstore,
ext_mail,
ext_migrate,
ext_orjson,
@ -83,6 +84,7 @@ def initialize_extensions(app: DifyApp):
ext_redis,
ext_request_logging,
ext_sentry,
ext_session_factory,
ext_set_secretkey,
ext_storage,
ext_timezone,
@ -104,6 +106,7 @@ def initialize_extensions(app: DifyApp):
ext_migrate,
ext_redis,
ext_storage,
ext_logstore, # Initialize logstore after storage, before celery
ext_celery,
ext_login,
ext_mail,
@ -114,6 +117,7 @@ def initialize_extensions(app: DifyApp):
ext_commands,
ext_otel,
ext_request_logging,
ext_session_factory,
]
for ext in extensions:
short_name = ext.__name__.split(".")[-1]

View File

@ -380,6 +380,37 @@ class FileUploadConfig(BaseSettings):
default=60,
)
# Annotation Import Security Configurations
ANNOTATION_IMPORT_FILE_SIZE_LIMIT: NonNegativeInt = Field(
description="Maximum allowed CSV file size for annotation import in megabytes",
default=2,
)
ANNOTATION_IMPORT_MAX_RECORDS: PositiveInt = Field(
description="Maximum number of annotation records allowed in a single import",
default=10000,
)
ANNOTATION_IMPORT_MIN_RECORDS: PositiveInt = Field(
description="Minimum number of annotation records required in a single import",
default=1,
)
ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE: PositiveInt = Field(
description="Maximum number of annotation import requests per minute per tenant",
default=5,
)
ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR: PositiveInt = Field(
description="Maximum number of annotation import requests per hour per tenant",
default=20,
)
ANNOTATION_IMPORT_MAX_CONCURRENT: PositiveInt = Field(
description="Maximum number of concurrent annotation import tasks per tenant",
default=2,
)
inner_UPLOAD_FILE_EXTENSION_BLACKLIST: str = Field(
description=(
"Comma-separated list of file extensions that are blocked from upload. "

View File

@ -26,6 +26,7 @@ from .vdb.clickzetta_config import ClickzettaConfig
from .vdb.couchbase_config import CouchbaseConfig
from .vdb.elasticsearch_config import ElasticsearchConfig
from .vdb.huawei_cloud_config import HuaweiCloudConfig
from .vdb.iris_config import IrisVectorConfig
from .vdb.lindorm_config import LindormConfig
from .vdb.matrixone_config import MatrixoneConfig
from .vdb.milvus_config import MilvusConfig
@ -106,7 +107,7 @@ class KeywordStoreConfig(BaseSettings):
class DatabaseConfig(BaseSettings):
# Database type selector
DB_TYPE: Literal["postgresql", "mysql", "oceanbase"] = Field(
DB_TYPE: Literal["postgresql", "mysql", "oceanbase", "seekdb"] = Field(
description="Database type to use. OceanBase is MySQL-compatible.",
default="postgresql",
)
@ -336,6 +337,7 @@ class MiddlewareConfig(
ChromaConfig,
ClickzettaConfig,
HuaweiCloudConfig,
IrisVectorConfig,
MilvusConfig,
AlibabaCloudMySQLConfig,
MyScaleConfig,

View File

@ -0,0 +1,91 @@
"""Configuration for InterSystems IRIS vector database."""
from pydantic import Field, PositiveInt, model_validator
from pydantic_settings import BaseSettings
class IrisVectorConfig(BaseSettings):
"""Configuration settings for IRIS vector database connection and pooling."""
IRIS_HOST: str | None = Field(
description="Hostname or IP address of the IRIS server.",
default="localhost",
)
IRIS_SUPER_SERVER_PORT: PositiveInt | None = Field(
description="Port number for IRIS connection.",
default=1972,
)
IRIS_USER: str | None = Field(
description="Username for IRIS authentication.",
default="_SYSTEM",
)
IRIS_PASSWORD: str | None = Field(
description="Password for IRIS authentication.",
default="Dify@1234",
)
IRIS_SCHEMA: str | None = Field(
description="Schema name for IRIS tables.",
default="dify",
)
IRIS_DATABASE: str | None = Field(
description="Database namespace for IRIS connection.",
default="USER",
)
IRIS_CONNECTION_URL: str | None = Field(
description="Full connection URL for IRIS (overrides individual fields if provided).",
default=None,
)
IRIS_MIN_CONNECTION: PositiveInt = Field(
description="Minimum number of connections in the pool.",
default=1,
)
IRIS_MAX_CONNECTION: PositiveInt = Field(
description="Maximum number of connections in the pool.",
default=3,
)
IRIS_TEXT_INDEX: bool = Field(
description="Enable full-text search index using %iFind.Index.Basic.",
default=True,
)
IRIS_TEXT_INDEX_LANGUAGE: str = Field(
description="Language for full-text search index (e.g., 'en', 'ja', 'zh', 'de').",
default="en",
)
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
"""Validate IRIS configuration values.
Args:
values: Configuration dictionary
Returns:
Validated configuration dictionary
Raises:
ValueError: If required fields are missing or pool settings are invalid
"""
# Only validate required fields if IRIS is being used as the vector store
# This allows the config to be loaded even when IRIS is not in use
# vector_store = os.environ.get("VECTOR_STORE", "")
# We rely on Pydantic defaults for required fields if they are missing from env.
# Strict existence check is removed to allow defaults to work.
min_conn = values.get("IRIS_MIN_CONNECTION", 1)
max_conn = values.get("IRIS_MAX_CONNECTION", 3)
if min_conn > max_conn:
raise ValueError("IRIS_MIN_CONNECTION must be less than or equal to IRIS_MAX_CONNECTION")
return values

View File

@ -20,6 +20,7 @@ language_timezone_mapping = {
"sl-SI": "Europe/Ljubljana",
"th-TH": "Asia/Bangkok",
"id-ID": "Asia/Jakarta",
"ar-TN": "Africa/Tunis",
}
languages = list(language_timezone_mapping.keys())

View File

@ -6,19 +6,20 @@ from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound, Unauthorized
P = ParamSpec("P")
R = TypeVar("R")
from configs import dify_config
from constants.languages import supported_language
from controllers.console import console_ns
from controllers.console.wraps import only_edition_cloud
from core.db.session_factory import session_factory
from extensions.ext_database import db
from libs.token import extract_access_token
from models.model import App, InstalledApp, RecommendedApp
P = ParamSpec("P")
R = TypeVar("R")
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
@ -90,7 +91,7 @@ class InsertExploreAppListApi(Resource):
privacy_policy = site.privacy_policy or payload.privacy_policy or ""
custom_disclaimer = site.custom_disclaimer or payload.custom_disclaimer or ""
with Session(db.engine) as session:
with session_factory.create_session() as session:
recommended_app = session.execute(
select(RecommendedApp).where(RecommendedApp.app_id == payload.app_id)
).scalar_one_or_none()
@ -138,7 +139,7 @@ class InsertExploreAppApi(Resource):
@only_edition_cloud
@admin_required
def delete(self, app_id):
with Session(db.engine) as session:
with session_factory.create_session() as session:
recommended_app = session.execute(
select(RecommendedApp).where(RecommendedApp.app_id == str(app_id))
).scalar_one_or_none()
@ -146,13 +147,13 @@ class InsertExploreAppApi(Resource):
if not recommended_app:
return {"result": "success"}, 204
with Session(db.engine) as session:
with session_factory.create_session() as session:
app = session.execute(select(App).where(App.id == recommended_app.app_id)).scalar_one_or_none()
if app:
app.is_public = False
with Session(db.engine) as session:
with session_factory.create_session() as session:
installed_apps = (
session.execute(
select(InstalledApp).where(

View File

@ -1,6 +1,6 @@
from typing import Any, Literal
from flask import request
from flask import abort, make_response, request
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field, field_validator
@ -8,6 +8,8 @@ from controllers.common.errors import NoFileUploadedError, TooManyFilesError
from controllers.console import console_ns
from controllers.console.wraps import (
account_initialization_required,
annotation_import_concurrency_limit,
annotation_import_rate_limit,
cloud_edition_billing_resource_check,
edit_permission_required,
setup_required,
@ -257,7 +259,7 @@ class AnnotationApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/annotations/export")
class AnnotationExportApi(Resource):
@console_ns.doc("export_annotations")
@console_ns.doc(description="Export all annotations for an app")
@console_ns.doc(description="Export all annotations for an app with CSV injection protection")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.response(
200,
@ -272,8 +274,14 @@ class AnnotationExportApi(Resource):
def get(self, app_id):
app_id = str(app_id)
annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id)
response = {"data": marshal(annotation_list, annotation_fields)}
return response, 200
response_data = {"data": marshal(annotation_list, annotation_fields)}
# Create response with secure headers for CSV export
response = make_response(response_data, 200)
response.headers["Content-Type"] = "application/json; charset=utf-8"
response.headers["X-Content-Type-Options"] = "nosniff"
return response
@console_ns.route("/apps/<uuid:app_id>/annotations/<uuid:annotation_id>")
@ -314,18 +322,25 @@ class AnnotationUpdateDeleteApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/annotations/batch-import")
class AnnotationBatchImportApi(Resource):
@console_ns.doc("batch_import_annotations")
@console_ns.doc(description="Batch import annotations from CSV file")
@console_ns.doc(description="Batch import annotations from CSV file with rate limiting and security checks")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.response(200, "Batch import started successfully")
@console_ns.response(403, "Insufficient permissions")
@console_ns.response(400, "No file uploaded or too many files")
@console_ns.response(413, "File too large")
@console_ns.response(429, "Too many requests or concurrent imports")
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@annotation_import_rate_limit
@annotation_import_concurrency_limit
@edit_permission_required
def post(self, app_id):
from configs import dify_config
app_id = str(app_id)
# check file
if "file" not in request.files:
raise NoFileUploadedError()
@ -335,9 +350,27 @@ class AnnotationBatchImportApi(Resource):
# get file from request
file = request.files["file"]
# check file type
if not file.filename or not file.filename.lower().endswith(".csv"):
raise ValueError("Invalid file type. Only CSV files are allowed")
# Check file size before processing
file.seek(0, 2) # Seek to end of file
file_size = file.tell()
file.seek(0) # Reset to beginning
max_size_bytes = dify_config.ANNOTATION_IMPORT_FILE_SIZE_LIMIT * 1024 * 1024
if file_size > max_size_bytes:
abort(
413,
f"File size exceeds maximum limit of {dify_config.ANNOTATION_IMPORT_FILE_SIZE_LIMIT}MB. "
f"Please reduce the file size and try again.",
)
if file_size == 0:
raise ValueError("The uploaded file is empty")
return AppAnnotationService.batch_import_app_annotations(app_id, file)

View File

@ -114,7 +114,7 @@ class AppTriggersApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/trigger-enable")
class AppTriggerEnableApi(Resource):
@console_ns.expect(console_ns.models[ParserEnable.__name__], validate=True)
@console_ns.expect(console_ns.models[ParserEnable.__name__])
@setup_required
@login_required
@account_initialization_required

View File

@ -22,7 +22,12 @@ from controllers.console.error import (
NotAllowedCreateWorkspace,
WorkspacesLimitExceeded,
)
from controllers.console.wraps import email_password_login_enabled, setup_required
from controllers.console.wraps import (
decrypt_code_field,
decrypt_password_field,
email_password_login_enabled,
setup_required,
)
from events.tenant_event import tenant_was_created
from libs.helper import EmailStr, extract_remote_ip
from libs.login import current_account_with_tenant
@ -79,6 +84,7 @@ class LoginApi(Resource):
@setup_required
@email_password_login_enabled
@console_ns.expect(console_ns.models[LoginPayload.__name__])
@decrypt_password_field
def post(self):
"""Authenticate user and login."""
args = LoginPayload.model_validate(console_ns.payload)
@ -218,6 +224,7 @@ class EmailCodeLoginSendEmailApi(Resource):
class EmailCodeLoginApi(Resource):
@setup_required
@console_ns.expect(console_ns.models[EmailCodeLoginPayload.__name__])
@decrypt_code_field
def post(self):
args = EmailCodeLoginPayload.model_validate(console_ns.payload)

View File

@ -140,6 +140,18 @@ class DataSourceNotionListApi(Resource):
credential_id = request.args.get("credential_id", default=None, type=str)
if not credential_id:
raise ValueError("Credential id is required.")
# Get datasource_parameters from query string (optional, for GitHub and other datasources)
datasource_parameters_str = request.args.get("datasource_parameters", default=None, type=str)
datasource_parameters = {}
if datasource_parameters_str:
try:
datasource_parameters = json.loads(datasource_parameters_str)
if not isinstance(datasource_parameters, dict):
raise ValueError("datasource_parameters must be a JSON object.")
except json.JSONDecodeError:
raise ValueError("Invalid datasource_parameters JSON format.")
datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_datasource_credentials(
tenant_id=current_tenant_id,
@ -187,7 +199,7 @@ class DataSourceNotionListApi(Resource):
online_document_result: Generator[OnlineDocumentPagesMessage, None, None] = (
datasource_runtime.get_online_document_pages(
user_id=current_user.id,
datasource_parameters={},
datasource_parameters=datasource_parameters,
provider_type=datasource_runtime.datasource_provider_type(),
)
)
@ -218,14 +230,14 @@ class DataSourceNotionListApi(Resource):
@console_ns.route(
"/notion/workspaces/<uuid:workspace_id>/pages/<uuid:page_id>/<string:page_type>/preview",
"/notion/pages/<uuid:page_id>/<string:page_type>/preview",
"/datasets/notion-indexing-estimate",
)
class DataSourceNotionApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, workspace_id, page_id, page_type):
def get(self, page_id, page_type):
_, current_tenant_id = current_account_with_tenant()
credential_id = request.args.get("credential_id", default=None, type=str)
@ -239,11 +251,10 @@ class DataSourceNotionApi(Resource):
plugin_id="langgenius/notion_datasource",
)
workspace_id = str(workspace_id)
page_id = str(page_id)
extractor = NotionExtractor(
notion_workspace_id=workspace_id,
notion_workspace_id="",
notion_obj_id=page_id,
notion_page_type=page_type,
notion_access_token=credential.get("integration_secret"),

View File

@ -223,6 +223,7 @@ def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool
VectorType.COUCHBASE,
VectorType.OPENGAUSS,
VectorType.OCEANBASE,
VectorType.SEEKDB,
VectorType.TABLESTORE,
VectorType.HUAWEI_CLOUD,
VectorType.TENCENT,
@ -230,6 +231,7 @@ def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool
VectorType.CLICKZETTA,
VectorType.BAIDU,
VectorType.ALIBABACLOUD_MYSQL,
VectorType.IRIS,
}
semantic_methods = {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
@ -422,7 +424,6 @@ class DatasetApi(Resource):
raise NotFound("Dataset not found.")
payload = DatasetUpdatePayload.model_validate(console_ns.payload or {})
payload_data = payload.model_dump(exclude_unset=True)
current_user, current_tenant_id = current_account_with_tenant()
# check embedding model setting
if (
@ -434,6 +435,7 @@ class DatasetApi(Resource):
dataset.tenant_id, payload.embedding_model_provider, payload.embedding_model
)
payload.is_multimodal = is_multimodal
payload_data = payload.model_dump(exclude_unset=True)
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
DatasetPermissionService.check_permission(
current_user, dataset, payload.permission, payload.partial_member_list

View File

@ -26,7 +26,7 @@ console_ns.schema_model(Parser.__name__, Parser.model_json_schema(ref_template=D
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview")
class DataSourceContentPreviewApi(Resource):
@console_ns.expect(console_ns.models[Parser.__name__], validate=True)
@console_ns.expect(console_ns.models[Parser.__name__])
@setup_required
@login_required
@account_initialization_required

View File

@ -4,7 +4,7 @@ from typing import Any, Literal, cast
from uuid import UUID
from flask import abort, request
from flask_restx import Resource, marshal_with # type: ignore
from flask_restx import Resource, marshal_with, reqparse # type: ignore
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
@ -975,6 +975,11 @@ class RagPipelineRecommendedPluginApi(Resource):
@login_required
@account_initialization_required
def get(self):
parser = reqparse.RequestParser()
parser.add_argument("type", type=str, location="args", required=False, default="all")
args = parser.parse_args()
type = args["type"]
rag_pipeline_service = RagPipelineService()
recommended_plugins = rag_pipeline_service.get_recommended_plugins()
recommended_plugins = rag_pipeline_service.get_recommended_plugins(type)
return recommended_plugins

View File

@ -2,7 +2,7 @@ import logging
from typing import Any, Literal
from uuid import UUID
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import InternalServerError, NotFound
import services
@ -52,10 +52,24 @@ class ChatMessagePayload(BaseModel):
inputs: dict[str, Any]
query: str
files: list[dict[str, Any]] | None = None
conversation_id: UUID | None = None
parent_message_id: UUID | None = None
conversation_id: str | None = None
parent_message_id: str | None = None
retriever_from: str = Field(default="explore_app")
@field_validator("conversation_id", "parent_message_id", mode="before")
@classmethod
def normalize_uuid(cls, value: str | UUID | None) -> str | None:
"""
Accept blank IDs and validate UUID format when provided.
"""
if not value:
return None
try:
return helper.uuid_value(value)
except ValueError as exc:
raise ValueError("must be a valid UUID") from exc
register_schema_models(console_ns, CompletionMessagePayload, ChatMessagePayload)

View File

@ -3,7 +3,7 @@ from uuid import UUID
from flask import request
from flask_restx import marshal_with
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_validator
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
@ -30,9 +30,16 @@ class ConversationListQuery(BaseModel):
class ConversationRenamePayload(BaseModel):
name: str
name: str | None = None
auto_generate: bool = False
@model_validator(mode="after")
def validate_name_requirement(self):
if not self.auto_generate:
if self.name is None or not self.name.strip():
raise ValueError("name is required when auto_generate is false")
return self
register_schema_models(console_ns, ConversationListQuery, ConversationRenamePayload)

View File

@ -230,7 +230,7 @@ class ModelProviderModelApi(Resource):
return {"result": "success"}, 200
@console_ns.expect(console_ns.models[ParserDeleteModels.__name__], validate=True)
@console_ns.expect(console_ns.models[ParserDeleteModels.__name__])
@setup_required
@login_required
@is_admin_or_owner_required
@ -282,9 +282,10 @@ class ModelProviderModelCredentialApi(Resource):
tenant_id=tenant_id, provider_name=provider
)
else:
model_type = args.model_type
# Normalize model_type to the origin value stored in DB (e.g., "text-generation" for LLM)
normalized_model_type = args.model_type.to_origin_model_type()
available_credentials = model_provider_service.provider_manager.get_provider_model_available_credentials(
tenant_id=tenant_id, provider_name=provider, model_type=model_type, model_name=args.model
tenant_id=tenant_id, provider_name=provider, model_type=normalized_model_type, model_name=args.model
)
return jsonable_encoder(

View File

@ -46,8 +46,8 @@ class PluginDebuggingKeyApi(Resource):
class ParserList(BaseModel):
page: int = Field(default=1)
page_size: int = Field(default=256)
page: int = Field(default=1, ge=1, description="Page number")
page_size: int = Field(default=256, ge=1, le=256, description="Page size (1-256)")
reg(ParserList)
@ -106,8 +106,8 @@ class ParserPluginIdentifierQuery(BaseModel):
class ParserTasks(BaseModel):
page: int
page_size: int
page: int = Field(default=1, ge=1, description="Page number")
page_size: int = Field(default=256, ge=1, le=256, description="Page size (1-256)")
class ParserMarketplaceUpgrade(BaseModel):

View File

@ -9,10 +9,12 @@ from typing import ParamSpec, TypeVar
from flask import abort, request
from configs import dify_config
from controllers.console.auth.error import AuthenticationFailedError, EmailCodeError
from controllers.console.workspace.error import AccountNotInitializedError
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.encryption import FieldEncryption
from libs.login import current_account_with_tenant
from models.account import AccountStatus
from models.dataset import RateLimitLog
@ -25,6 +27,14 @@ from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogo
P = ParamSpec("P")
R = TypeVar("R")
# Field names for decryption
FIELD_NAME_PASSWORD = "password"
FIELD_NAME_CODE = "code"
# Error messages for decryption failures
ERROR_MSG_INVALID_ENCRYPTED_DATA = "Invalid encrypted data"
ERROR_MSG_INVALID_ENCRYPTED_CODE = "Invalid encrypted code"
def account_initialization_required(view: Callable[P, R]):
@wraps(view)
@ -331,3 +341,163 @@ def is_admin_or_owner_required(f: Callable[P, R]):
return f(*args, **kwargs)
return decorated_function
def annotation_import_rate_limit(view: Callable[P, R]):
"""
Rate limiting decorator for annotation import operations.
Implements sliding window rate limiting with two tiers:
- Short-term: Configurable requests per minute (default: 5)
- Long-term: Configurable requests per hour (default: 20)
Uses Redis ZSET for distributed rate limiting across multiple instances.
"""
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
_, current_tenant_id = current_account_with_tenant()
current_time = int(time.time() * 1000)
# Check per-minute rate limit
minute_key = f"annotation_import_rate_limit:{current_tenant_id}:1min"
redis_client.zadd(minute_key, {current_time: current_time})
redis_client.zremrangebyscore(minute_key, 0, current_time - 60000)
minute_count = redis_client.zcard(minute_key)
redis_client.expire(minute_key, 120) # 2 minutes TTL
if minute_count > dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE:
abort(
429,
f"Too many annotation import requests. Maximum {dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE} "
f"requests per minute allowed. Please try again later.",
)
# Check per-hour rate limit
hour_key = f"annotation_import_rate_limit:{current_tenant_id}:1hour"
redis_client.zadd(hour_key, {current_time: current_time})
redis_client.zremrangebyscore(hour_key, 0, current_time - 3600000)
hour_count = redis_client.zcard(hour_key)
redis_client.expire(hour_key, 7200) # 2 hours TTL
if hour_count > dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR:
abort(
429,
f"Too many annotation import requests. Maximum {dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR} "
f"requests per hour allowed. Please try again later.",
)
return view(*args, **kwargs)
return decorated
def annotation_import_concurrency_limit(view: Callable[P, R]):
"""
Concurrency control decorator for annotation import operations.
Limits the number of concurrent import tasks per tenant to prevent
resource exhaustion and ensure fair resource allocation.
Uses Redis ZSET to track active import jobs with automatic cleanup
of stale entries (jobs older than 2 minutes).
"""
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
_, current_tenant_id = current_account_with_tenant()
current_time = int(time.time() * 1000)
active_jobs_key = f"annotation_import_active:{current_tenant_id}"
# Clean up stale entries (jobs that should have completed or timed out)
stale_threshold = current_time - 120000 # 2 minutes ago
redis_client.zremrangebyscore(active_jobs_key, 0, stale_threshold)
# Check current active job count
active_count = redis_client.zcard(active_jobs_key)
if active_count >= dify_config.ANNOTATION_IMPORT_MAX_CONCURRENT:
abort(
429,
f"Too many concurrent import tasks. Maximum {dify_config.ANNOTATION_IMPORT_MAX_CONCURRENT} "
f"concurrent imports allowed per workspace. Please wait for existing imports to complete.",
)
# Allow the request to proceed
# The actual job registration will happen in the service layer
return view(*args, **kwargs)
return decorated
def _decrypt_field(field_name: str, error_class: type[Exception], error_message: str) -> None:
"""
Helper to decode a Base64 encoded field in the request payload.
Args:
field_name: Name of the field to decode
error_class: Exception class to raise on decoding failure
error_message: Error message to include in the exception
"""
if not request or not request.is_json:
return
# Get the payload dict - it's cached and mutable
payload = request.get_json()
if not payload or field_name not in payload:
return
encoded_value = payload[field_name]
decoded_value = FieldEncryption.decrypt_field(encoded_value)
# If decoding failed, raise error immediately
if decoded_value is None:
raise error_class(error_message)
# Update payload dict in-place with decoded value
# Since payload is a mutable dict and get_json() returns the cached reference,
# modifying it will affect all subsequent accesses including console_ns.payload
payload[field_name] = decoded_value
def decrypt_password_field(view: Callable[P, R]):
"""
Decorator to decrypt password field in request payload.
Automatically decrypts the 'password' field if encryption is enabled.
If decryption fails, raises AuthenticationFailedError.
Usage:
@decrypt_password_field
def post(self):
args = LoginPayload.model_validate(console_ns.payload)
# args.password is now decrypted
"""
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
_decrypt_field(FIELD_NAME_PASSWORD, AuthenticationFailedError, ERROR_MSG_INVALID_ENCRYPTED_DATA)
return view(*args, **kwargs)
return decorated
def decrypt_code_field(view: Callable[P, R]):
"""
Decorator to decrypt verification code field in request payload.
Automatically decrypts the 'code' field if encryption is enabled.
If decryption fails, raises EmailCodeError.
Usage:
@decrypt_code_field
def post(self):
args = EmailCodeLoginPayload.model_validate(console_ns.payload)
# args.code is now decrypted
"""
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
_decrypt_field(FIELD_NAME_CODE, EmailCodeError, ERROR_MSG_INVALID_ENCRYPTED_CODE)
return view(*args, **kwargs)
return decorated

View File

@ -4,7 +4,7 @@ from uuid import UUID
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
import services
@ -52,11 +52,26 @@ class ChatRequestPayload(BaseModel):
query: str
files: list[dict[str, Any]] | None = None
response_mode: Literal["blocking", "streaming"] | None = None
conversation_id: UUID | None = None
conversation_id: str | None = Field(default=None, description="Conversation UUID")
retriever_from: str = Field(default="dev")
auto_generate_name: bool = Field(default=True, description="Auto generate conversation name")
workflow_id: str | None = Field(default=None, description="Workflow ID for advanced chat")
@field_validator("conversation_id", mode="before")
@classmethod
def normalize_conversation_id(cls, value: str | UUID | None) -> str | None:
"""Allow missing or blank conversation IDs; enforce UUID format when provided."""
if isinstance(value, str):
value = value.strip()
if not value:
return None
try:
return helper.uuid_value(value)
except ValueError as exc:
raise ValueError("conversation_id must be a valid UUID") from exc
register_schema_models(service_api_ns, CompletionRequestPayload, ChatRequestPayload)

View File

@ -4,7 +4,7 @@ from uuid import UUID
from flask import request
from flask_restx import Resource
from flask_restx._http import HTTPStatus
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_validator
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, NotFound
@ -37,9 +37,16 @@ class ConversationListQuery(BaseModel):
class ConversationRenamePayload(BaseModel):
name: str = Field(description="New conversation name")
name: str | None = Field(default=None, description="New conversation name (required if auto_generate is false)")
auto_generate: bool = Field(default=False, description="Auto-generate conversation name")
@model_validator(mode="after")
def validate_name_requirement(self):
if not self.auto_generate:
if self.name is None or not self.name.strip():
raise ValueError("name is required when auto_generate is false")
return self
class ConversationVariablesQuery(BaseModel):
last_id: UUID | None = Field(default=None, description="Last variable ID for pagination")

View File

@ -33,7 +33,7 @@ def trigger_endpoint(endpoint_id: str):
if response:
break
if not response:
logger.error("Endpoint not found for {endpoint_id}")
logger.info("Endpoint not found for %s", endpoint_id)
return jsonify({"error": "Endpoint not found"}), 404
return response
except ValueError as e:

View File

@ -62,8 +62,7 @@ from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.enums import WorkflowExecutionStatus
from core.workflow.nodes import NodeType
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
@ -73,7 +72,7 @@ from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models import Account, Conversation, EndUser, Message, MessageFile
from models.enums import CreatorUserRole
from models.workflow import Workflow, WorkflowNodeExecutionModel
from models.workflow import Workflow
logger = logging.getLogger(__name__)
@ -581,7 +580,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
with self._database_session() as session:
# Save message
self._save_message(session=session, graph_runtime_state=resolved_state, trace_manager=trace_manager)
self._save_message(session=session, graph_runtime_state=resolved_state)
yield workflow_finish_resp
elif event.stopped_by in (
@ -591,7 +590,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
# When hitting input-moderation or annotation-reply, the workflow will not start
with self._database_session() as session:
# Save message
self._save_message(session=session, trace_manager=trace_manager)
self._save_message(session=session)
yield self._message_end_to_stream_response()
@ -600,7 +599,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
event: QueueAdvancedChatMessageEndEvent,
*,
graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle advanced chat message end events."""
@ -618,7 +616,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
# Save message
with self._database_session() as session:
self._save_message(session=session, graph_runtime_state=resolved_state, trace_manager=trace_manager)
self._save_message(session=session, graph_runtime_state=resolved_state)
yield self._message_end_to_stream_response()
@ -772,13 +770,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
if self._conversation_name_generate_thread:
logger.debug("Conversation name generation running as daemon thread")
def _save_message(
self,
*,
session: Session,
graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
):
def _save_message(self, *, session: Session, graph_runtime_state: GraphRuntimeState | None = None):
message = self._get_message(session=session)
# If there are assistant files, remove markdown image links from answer
@ -817,14 +809,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
metadata = self._task_state.metadata.model_dump()
message.message_metadata = json.dumps(jsonable_encoder(metadata))
# Extract model provider and model_id from workflow node executions for tracing
if message.workflow_run_id:
model_info = self._extract_model_info_from_workflow(session, message.workflow_run_id)
if model_info:
message.model_provider = model_info.get("provider")
message.model_id = model_info.get("model")
message_files = [
MessageFile(
message_id=message.id,
@ -842,68 +826,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
]
session.add_all(message_files)
# Trigger MESSAGE_TRACE for tracing integrations
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.MESSAGE_TRACE, conversation_id=self._conversation_id, message_id=self._message_id
)
)
def _extract_model_info_from_workflow(self, session: Session, workflow_run_id: str) -> dict[str, str] | None:
"""
Extract model provider and model_id from workflow node executions.
Returns dict with 'provider' and 'model' keys, or None if not found.
"""
try:
# Query workflow node executions for LLM or Agent nodes
stmt = (
select(WorkflowNodeExecutionModel)
.where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id)
.where(WorkflowNodeExecutionModel.node_type.in_(["llm", "agent"]))
.order_by(WorkflowNodeExecutionModel.created_at.desc())
.limit(1)
)
node_execution = session.scalar(stmt)
if not node_execution:
return None
# Try to extract from execution_metadata for agent nodes
if node_execution.execution_metadata:
try:
metadata = json.loads(node_execution.execution_metadata)
agent_log = metadata.get("agent_log", [])
# Look for the first agent thought with provider info
for log_entry in agent_log:
entry_metadata = log_entry.get("metadata", {})
provider_str = entry_metadata.get("provider")
if provider_str:
# Parse format like "langgenius/deepseek/deepseek"
parts = provider_str.split("/")
if len(parts) >= 3:
return {"provider": parts[1], "model": parts[2]}
elif len(parts) == 2:
return {"provider": parts[0], "model": parts[1]}
except (json.JSONDecodeError, KeyError, AttributeError) as e:
logger.debug("Failed to parse execution_metadata: %s", e)
# Try to extract from process_data for llm nodes
if node_execution.process_data:
try:
process_data = json.loads(node_execution.process_data)
provider = process_data.get("model_provider")
model = process_data.get("model_name")
if provider and model:
return {"provider": provider, "model": model}
except (json.JSONDecodeError, KeyError) as e:
logger.debug("Failed to parse process_data: %s", e)
return None
except Exception as e:
logger.warning("Failed to extract model info from workflow: %s", e)
return None
def _seed_graph_runtime_state_from_queue_manager(self) -> None:
"""Bootstrap the cached runtime state from the queue manager when present."""
candidate = self._base_task_pipeline.queue_manager.graph_runtime_state

View File

@ -40,9 +40,6 @@ class EasyUITaskState(TaskState):
"""
llm_result: LLMResult
first_token_time: float | None = None
last_token_time: float | None = None
is_streaming_response: bool = False
class WorkflowTaskState(TaskState):

View File

@ -332,12 +332,6 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
if not self._task_state.llm_result.prompt_messages:
self._task_state.llm_result.prompt_messages = chunk.prompt_messages
# Track streaming response times
if self._task_state.first_token_time is None:
self._task_state.first_token_time = time.perf_counter()
self._task_state.is_streaming_response = True
self._task_state.last_token_time = time.perf_counter()
# handle output moderation chunk
should_direct_answer = self._handle_output_moderation_chunk(cast(str, delta_text))
if should_direct_answer:
@ -404,18 +398,6 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
message.total_price = usage.total_price
message.currency = usage.currency
self._task_state.llm_result.usage.latency = message.provider_response_latency
# Add streaming metrics to usage if available
if self._task_state.is_streaming_response and self._task_state.first_token_time:
start_time = self.start_at
first_token_time = self._task_state.first_token_time
last_token_time = self._task_state.last_token_time or first_token_time
usage.time_to_first_token = round(first_token_time - start_time, 3)
usage.time_to_generate = round(last_token_time - first_token_time, 3)
# Update metadata with the complete usage info
self._task_state.metadata.usage = usage
message.message_metadata = self._task_state.metadata.model_dump_json()
if trace_manager:

0
api/core/db/__init__.py Normal file
View File

View File

@ -0,0 +1,38 @@
from sqlalchemy import Engine
from sqlalchemy.orm import Session, sessionmaker
_session_maker: sessionmaker | None = None
def configure_session_factory(engine: Engine, expire_on_commit: bool = False):
"""Configure the global session factory"""
global _session_maker
_session_maker = sessionmaker(bind=engine, expire_on_commit=expire_on_commit)
def get_session_maker() -> sessionmaker:
if _session_maker is None:
raise RuntimeError("Session factory not configured. Call configure_session_factory() first.")
return _session_maker
def create_session() -> Session:
return get_session_maker()()
# Class wrapper for convenience
class SessionFactory:
@staticmethod
def configure(engine: Engine, expire_on_commit: bool = False):
configure_session_factory(engine, expire_on_commit)
@staticmethod
def get_session_maker() -> sessionmaker:
return get_session_maker()
@staticmethod
def create_session() -> Session:
return create_session()
session_factory = SessionFactory()

View File

@ -1,4 +1,4 @@
from pydantic import BaseModel
from pydantic import BaseModel, Field, field_validator
class PreviewDetail(BaseModel):
@ -20,9 +20,17 @@ class IndexingEstimate(BaseModel):
class PipelineDataset(BaseModel):
id: str
name: str
description: str
description: str = Field(default="", description="knowledge dataset description")
chunk_structure: str
@field_validator("description", mode="before")
@classmethod
def normalize_description(cls, value: str | None) -> str:
"""Coerce None to empty string so description is always a string."""
if value is None:
return ""
return value
class PipelineDocument(BaseModel):
id: str

View File

@ -213,12 +213,23 @@ class MCPProviderEntity(BaseModel):
return None
def retrieve_tokens(self) -> OAuthTokens | None:
"""OAuth tokens if available"""
"""Retrieve OAuth tokens if authentication is complete.
Returns:
OAuthTokens if the provider has been authenticated, None otherwise.
"""
if not self.credentials:
return None
credentials = self.decrypt_credentials()
access_token = credentials.get("access_token", "")
# Return None if access_token is empty to avoid generating invalid "Authorization: Bearer " header.
# Note: We don't check for whitespace-only strings here because:
# 1. OAuth servers don't return whitespace-only access tokens in practice
# 2. Even if they did, the server would return 401, triggering the OAuth flow correctly
if not access_token:
return None
return OAuthTokens(
access_token=credentials.get("access_token", ""),
access_token=access_token,
token_type=credentials.get("token_type", DEFAULT_TOKEN_TYPE),
expires_in=int(credentials.get("expires_in", str(DEFAULT_EXPIRES_IN)) or DEFAULT_EXPIRES_IN),
refresh_token=credentials.get("refresh_token", ""),

View File

@ -0,0 +1,89 @@
"""CSV sanitization utilities to prevent formula injection attacks."""
from typing import Any
class CSVSanitizer:
"""
Sanitizer for CSV export to prevent formula injection attacks.
This class provides methods to sanitize data before CSV export by escaping
characters that could be interpreted as formulas by spreadsheet applications
(Excel, LibreOffice, Google Sheets).
Formula injection occurs when user-controlled data starting with special
characters (=, +, -, @, tab, carriage return) is exported to CSV and opened
in a spreadsheet application, potentially executing malicious commands.
"""
# Characters that can start a formula in Excel/LibreOffice/Google Sheets
FORMULA_CHARS = frozenset({"=", "+", "-", "@", "\t", "\r"})
@classmethod
def sanitize_value(cls, value: Any) -> str:
"""
Sanitize a value for safe CSV export.
Prefixes formula-initiating characters with a single quote to prevent
Excel/LibreOffice/Google Sheets from treating them as formulas.
Args:
value: The value to sanitize (will be converted to string)
Returns:
Sanitized string safe for CSV export
Examples:
>>> CSVSanitizer.sanitize_value("=1+1")
"'=1+1"
>>> CSVSanitizer.sanitize_value("Hello World")
"Hello World"
>>> CSVSanitizer.sanitize_value(None)
""
"""
if value is None:
return ""
# Convert to string
str_value = str(value)
# If empty, return as is
if not str_value:
return ""
# Check if first character is a formula initiator
if str_value[0] in cls.FORMULA_CHARS:
# Prefix with single quote to escape
return f"'{str_value}"
return str_value
@classmethod
def sanitize_dict(cls, data: dict[str, Any], fields_to_sanitize: list[str] | None = None) -> dict[str, Any]:
"""
Sanitize specified fields in a dictionary.
Args:
data: Dictionary containing data to sanitize
fields_to_sanitize: List of field names to sanitize.
If None, sanitizes all string fields.
Returns:
Dictionary with sanitized values (creates a shallow copy)
Examples:
>>> data = {"question": "=1+1", "answer": "+calc", "id": "123"}
>>> CSVSanitizer.sanitize_dict(data, ["question", "answer"])
{"question": "'=1+1", "answer": "'+calc", "id": "123"}
"""
sanitized = data.copy()
if fields_to_sanitize is None:
# Sanitize all string fields
fields_to_sanitize = [k for k, v in data.items() if isinstance(v, str)]
for field in fields_to_sanitize:
if field in sanitized:
sanitized[field] = cls.sanitize_value(sanitized[field])
return sanitized

View File

@ -9,6 +9,7 @@ import httpx
from configs import dify_config
from core.helper.http_client_pooling import get_pooled_http_client
from core.tools.errors import ToolSSRFError
logger = logging.getLogger(__name__)
@ -93,6 +94,18 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
while retries <= max_retries:
try:
response = client.request(method=method, url=url, **kwargs)
# Check for SSRF protection by Squid proxy
if response.status_code in (401, 403):
# Check if this is a Squid SSRF rejection
server_header = response.headers.get("server", "").lower()
via_header = response.headers.get("via", "").lower()
# Squid typically identifies itself in Server or Via headers
if "squid" in server_header or "squid" in via_header:
raise ToolSSRFError(
f"Access to '{url}' was blocked by SSRF protection. "
f"The URL may point to a private or local network address. "
)
if response.status_code not in STATUS_FORCELIST:
return response

View File

@ -72,15 +72,22 @@ class LLMGenerator:
prompt_messages=list(prompts), model_parameters={"max_tokens": 500, "temperature": 1}, stream=False
)
answer = cast(str, response.message.content)
cleaned_answer = re.sub(r"^.*(\{.*\}).*$", r"\1", answer, flags=re.DOTALL)
if cleaned_answer is None:
if answer is None:
return ""
try:
result_dict = json.loads(cleaned_answer)
answer = result_dict["Your Output"]
result_dict = json.loads(answer)
except json.JSONDecodeError:
logger.exception("Failed to generate name after answer, use query instead")
result_dict = json_repair.loads(answer)
if not isinstance(result_dict, dict):
answer = query
else:
output = result_dict.get("Your Output")
if isinstance(output, str) and output.strip():
answer = output.strip()
else:
answer = query
name = answer.strip()
if len(name) > 75:
@ -554,11 +561,16 @@ class LLMGenerator:
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
)
generated_raw = cast(str, response.message.content)
generated_raw = response.message.get_text_content()
first_brace = generated_raw.find("{")
last_brace = generated_raw.rfind("}")
return {**json.loads(generated_raw[first_brace : last_brace + 1])}
if first_brace == -1 or last_brace == -1 or last_brace < first_brace:
raise ValueError(f"Could not find a valid JSON object in response: {generated_raw}")
json_str = generated_raw[first_brace : last_brace + 1]
data = json_repair.loads(json_str)
if not isinstance(data, dict):
raise TypeError(f"Expected a JSON object, but got {type(data).__name__}")
return data
except InvokeError as e:
error = str(e)
return {"error": f"Failed to generate code. Error: {error}"}

View File

@ -6,7 +6,13 @@ from datetime import datetime, timedelta
from typing import Any, Union, cast
from urllib.parse import urlparse
from openinference.semconv.trace import OpenInferenceMimeTypeValues, OpenInferenceSpanKindValues, SpanAttributes
from openinference.semconv.trace import (
MessageAttributes,
OpenInferenceMimeTypeValues,
OpenInferenceSpanKindValues,
SpanAttributes,
ToolCallAttributes,
)
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GrpcOTLPSpanExporter
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HttpOTLPSpanExporter
from opentelemetry.sdk import trace as trace_sdk
@ -95,14 +101,14 @@ def setup_tracer(arize_phoenix_config: ArizeConfig | PhoenixConfig) -> tuple[tra
def datetime_to_nanos(dt: datetime | None) -> int:
"""Convert datetime to nanoseconds since epoch. If None, use current time."""
"""Convert datetime to nanoseconds since epoch for Arize/Phoenix."""
if dt is None:
dt = datetime.now()
return int(dt.timestamp() * 1_000_000_000)
def error_to_string(error: Exception | str | None) -> str:
"""Convert an error to a string with traceback information."""
"""Convert an error to a string with traceback information for Arize/Phoenix."""
error_message = "Empty Stack Trace"
if error:
if isinstance(error, Exception):
@ -114,7 +120,7 @@ def error_to_string(error: Exception | str | None) -> str:
def set_span_status(current_span: Span, error: Exception | str | None = None):
"""Set the status of the current span based on the presence of an error."""
"""Set the status of the current span based on the presence of an error for Arize/Phoenix."""
if error:
error_string = error_to_string(error)
current_span.set_status(Status(StatusCode.ERROR, error_string))
@ -138,10 +144,17 @@ def set_span_status(current_span: Span, error: Exception | str | None = None):
def safe_json_dumps(obj: Any) -> str:
"""A convenience wrapper around `json.dumps` that ensures that any object can be safely encoded."""
"""A convenience wrapper to ensure that any object can be safely encoded for Arize/Phoenix."""
return json.dumps(obj, default=str, ensure_ascii=False)
def wrap_span_metadata(metadata, **kwargs):
"""Add common metatada to all trace entity types for Arize/Phoenix."""
metadata["created_from"] = "Dify"
metadata.update(kwargs)
return metadata
class ArizePhoenixDataTrace(BaseTraceInstance):
def __init__(
self,
@ -183,16 +196,27 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
raise
def workflow_trace(self, trace_info: WorkflowTraceInfo):
workflow_metadata = {
"workflow_run_id": trace_info.workflow_run_id or "",
"message_id": trace_info.message_id or "",
"workflow_app_log_id": trace_info.workflow_app_log_id or "",
"status": trace_info.workflow_run_status or "",
"status_message": trace_info.error or "",
"level": "ERROR" if trace_info.error else "DEFAULT",
"total_tokens": trace_info.total_tokens or 0,
}
workflow_metadata.update(trace_info.metadata)
file_list = trace_info.file_list if isinstance(trace_info.file_list, list) else []
metadata = wrap_span_metadata(
trace_info.metadata,
trace_id=trace_info.trace_id or "",
message_id=trace_info.message_id or "",
status=trace_info.workflow_run_status or "",
status_message=trace_info.error or "",
level="ERROR" if trace_info.error else "DEFAULT",
trace_entity_type="workflow",
conversation_id=trace_info.conversation_id or "",
workflow_app_log_id=trace_info.workflow_app_log_id or "",
workflow_id=trace_info.workflow_id or "",
tenant_id=trace_info.tenant_id or "",
workflow_run_id=trace_info.workflow_run_id or "",
workflow_run_elapsed_time=trace_info.workflow_run_elapsed_time or 0,
workflow_run_version=trace_info.workflow_run_version or "",
total_tokens=trace_info.total_tokens or 0,
file_list=safe_json_dumps(file_list),
query=trace_info.query or "",
)
dify_trace_id = trace_info.trace_id or trace_info.message_id or trace_info.workflow_run_id
self.ensure_root_span(dify_trace_id)
@ -201,10 +225,12 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
workflow_span = self.tracer.start_span(
name=TraceTaskName.WORKFLOW_TRACE.value,
attributes={
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.workflow_run_inputs, ensure_ascii=False),
SpanAttributes.OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False),
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
SpanAttributes.METADATA: json.dumps(workflow_metadata, ensure_ascii=False),
SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.workflow_run_inputs),
SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
SpanAttributes.OUTPUT_VALUE: safe_json_dumps(trace_info.workflow_run_outputs),
SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
SpanAttributes.METADATA: safe_json_dumps(metadata),
SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
},
start_time=datetime_to_nanos(trace_info.start_time),
@ -257,6 +283,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
"app_id": app_id,
"app_name": node_execution.title,
"status": node_execution.status,
"status_message": node_execution.error or "",
"level": "ERROR" if node_execution.status == "failed" else "DEFAULT",
}
)
@ -290,11 +317,11 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
node_span = self.tracer.start_span(
name=node_execution.node_type,
attributes={
SpanAttributes.OPENINFERENCE_SPAN_KIND: span_kind.value,
SpanAttributes.INPUT_VALUE: safe_json_dumps(inputs_value),
SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
SpanAttributes.OUTPUT_VALUE: safe_json_dumps(outputs_value),
SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
SpanAttributes.OPENINFERENCE_SPAN_KIND: span_kind.value,
SpanAttributes.METADATA: safe_json_dumps(node_metadata),
SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
},
@ -339,30 +366,37 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
def message_trace(self, trace_info: MessageTraceInfo):
if trace_info.message_data is None:
logger.warning("[Arize/Phoenix] Message data is None, skipping message trace.")
return
file_list = cast(list[str], trace_info.file_list) or []
file_list = trace_info.file_list if isinstance(trace_info.file_list, list) else []
message_file_data: MessageFile | None = trace_info.message_file_data
if message_file_data is not None:
file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
file_list.append(file_url)
message_metadata = {
"message_id": trace_info.message_id or "",
"conversation_mode": str(trace_info.conversation_mode or ""),
"user_id": trace_info.message_data.from_account_id or "",
"file_list": json.dumps(file_list),
"status": trace_info.message_data.status or "",
"status_message": trace_info.error or "",
"level": "ERROR" if trace_info.error else "DEFAULT",
"total_tokens": trace_info.total_tokens or 0,
"prompt_tokens": trace_info.message_tokens or 0,
"completion_tokens": trace_info.answer_tokens or 0,
"ls_provider": trace_info.message_data.model_provider or "",
"ls_model_name": trace_info.message_data.model_id or "",
}
message_metadata.update(trace_info.metadata)
metadata = wrap_span_metadata(
trace_info.metadata,
trace_id=trace_info.trace_id or "",
message_id=trace_info.message_id or "",
status=trace_info.message_data.status or "",
status_message=trace_info.error or "",
level="ERROR" if trace_info.error else "DEFAULT",
trace_entity_type="message",
conversation_model=trace_info.conversation_model or "",
message_tokens=trace_info.message_tokens or 0,
answer_tokens=trace_info.answer_tokens or 0,
total_tokens=trace_info.total_tokens or 0,
conversation_mode=trace_info.conversation_mode or "",
gen_ai_server_time_to_first_token=trace_info.gen_ai_server_time_to_first_token or 0,
llm_streaming_time_to_generate=trace_info.llm_streaming_time_to_generate or 0,
is_streaming_request=trace_info.is_streaming_request or False,
user_id=trace_info.message_data.from_account_id or "",
file_list=safe_json_dumps(file_list),
model_provider=trace_info.message_data.model_provider or "",
model_id=trace_info.message_data.model_id or "",
)
# Add end user data if available
if trace_info.message_data.from_end_user_id:
@ -370,14 +404,16 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
db.session.query(EndUser).where(EndUser.id == trace_info.message_data.from_end_user_id).first()
)
if end_user_data is not None:
message_metadata["end_user_id"] = end_user_data.session_id
metadata["end_user_id"] = end_user_data.session_id
attributes = {
SpanAttributes.INPUT_VALUE: trace_info.message_data.query,
SpanAttributes.OUTPUT_VALUE: trace_info.message_data.answer,
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
SpanAttributes.METADATA: json.dumps(message_metadata, ensure_ascii=False),
SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id,
SpanAttributes.INPUT_VALUE: trace_info.message_data.query,
SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.TEXT.value,
SpanAttributes.OUTPUT_VALUE: trace_info.message_data.answer,
SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.TEXT.value,
SpanAttributes.METADATA: safe_json_dumps(metadata),
SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id or "",
}
dify_trace_id = trace_info.trace_id or trace_info.message_id
@ -393,8 +429,10 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
try:
# Convert outputs to string based on type
outputs_mime_type = OpenInferenceMimeTypeValues.TEXT.value
if isinstance(trace_info.outputs, dict | list):
outputs_str = json.dumps(trace_info.outputs, ensure_ascii=False)
outputs_str = safe_json_dumps(trace_info.outputs)
outputs_mime_type = OpenInferenceMimeTypeValues.JSON.value
elif isinstance(trace_info.outputs, str):
outputs_str = trace_info.outputs
else:
@ -402,10 +440,12 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
llm_attributes = {
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.LLM.value,
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.inputs),
SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
SpanAttributes.OUTPUT_VALUE: outputs_str,
SpanAttributes.METADATA: json.dumps(message_metadata, ensure_ascii=False),
SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id,
SpanAttributes.OUTPUT_MIME_TYPE: outputs_mime_type,
SpanAttributes.METADATA: safe_json_dumps(metadata),
SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id or "",
}
llm_attributes.update(self._construct_llm_attributes(trace_info.inputs))
if trace_info.total_tokens is not None and trace_info.total_tokens > 0:
@ -449,16 +489,20 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
def moderation_trace(self, trace_info: ModerationTraceInfo):
if trace_info.message_data is None:
logger.warning("[Arize/Phoenix] Message data is None, skipping moderation trace.")
return
metadata = {
"message_id": trace_info.message_id,
"tool_name": "moderation",
"status": trace_info.message_data.status,
"status_message": trace_info.message_data.error or "",
"level": "ERROR" if trace_info.message_data.error else "DEFAULT",
}
metadata.update(trace_info.metadata)
metadata = wrap_span_metadata(
trace_info.metadata,
trace_id=trace_info.trace_id or "",
message_id=trace_info.message_id or "",
status=trace_info.message_data.status or "",
status_message=trace_info.message_data.error or "",
level="ERROR" if trace_info.message_data.error else "DEFAULT",
trace_entity_type="moderation",
model_provider=trace_info.message_data.model_provider or "",
model_id=trace_info.message_data.model_id or "",
)
dify_trace_id = trace_info.trace_id or trace_info.message_id
self.ensure_root_span(dify_trace_id)
@ -467,18 +511,19 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
span = self.tracer.start_span(
name=TraceTaskName.MODERATION_TRACE.value,
attributes={
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
SpanAttributes.OUTPUT_VALUE: json.dumps(
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.TOOL.value,
SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.inputs),
SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
SpanAttributes.OUTPUT_VALUE: safe_json_dumps(
{
"action": trace_info.action,
"flagged": trace_info.flagged,
"action": trace_info.action,
"preset_response": trace_info.preset_response,
"inputs": trace_info.inputs,
},
ensure_ascii=False,
"query": trace_info.query,
}
),
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
SpanAttributes.METADATA: safe_json_dumps(metadata),
},
start_time=datetime_to_nanos(trace_info.start_time),
context=root_span_context,
@ -494,22 +539,28 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
if trace_info.message_data is None:
logger.warning("[Arize/Phoenix] Message data is None, skipping suggested question trace.")
return
start_time = trace_info.start_time or trace_info.message_data.created_at
end_time = trace_info.end_time or trace_info.message_data.updated_at
metadata = {
"message_id": trace_info.message_id,
"tool_name": "suggested_question",
"status": trace_info.status,
"status_message": trace_info.error or "",
"level": "ERROR" if trace_info.error else "DEFAULT",
"total_tokens": trace_info.total_tokens,
"ls_provider": trace_info.model_provider or "",
"ls_model_name": trace_info.model_id or "",
}
metadata.update(trace_info.metadata)
metadata = wrap_span_metadata(
trace_info.metadata,
trace_id=trace_info.trace_id or "",
message_id=trace_info.message_id or "",
status=trace_info.status or "",
status_message=trace_info.status_message or "",
level=trace_info.level or "",
trace_entity_type="suggested_question",
total_tokens=trace_info.total_tokens or 0,
from_account_id=trace_info.from_account_id or "",
agent_based=trace_info.agent_based or False,
from_source=trace_info.from_source or "",
model_provider=trace_info.model_provider or "",
model_id=trace_info.model_id or "",
workflow_run_id=trace_info.workflow_run_id or "",
)
dify_trace_id = trace_info.trace_id or trace_info.message_id
self.ensure_root_span(dify_trace_id)
@ -518,10 +569,12 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
span = self.tracer.start_span(
name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
attributes={
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
SpanAttributes.OUTPUT_VALUE: json.dumps(trace_info.suggested_question, ensure_ascii=False),
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.TOOL.value,
SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.inputs),
SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
SpanAttributes.OUTPUT_VALUE: safe_json_dumps(trace_info.suggested_question),
SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
SpanAttributes.METADATA: safe_json_dumps(metadata),
},
start_time=datetime_to_nanos(start_time),
context=root_span_context,
@ -537,21 +590,23 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
if trace_info.message_data is None:
logger.warning("[Arize/Phoenix] Message data is None, skipping dataset retrieval trace.")
return
start_time = trace_info.start_time or trace_info.message_data.created_at
end_time = trace_info.end_time or trace_info.message_data.updated_at
metadata = {
"message_id": trace_info.message_id,
"tool_name": "dataset_retrieval",
"status": trace_info.message_data.status,
"status_message": trace_info.message_data.error or "",
"level": "ERROR" if trace_info.message_data.error else "DEFAULT",
"ls_provider": trace_info.message_data.model_provider or "",
"ls_model_name": trace_info.message_data.model_id or "",
}
metadata.update(trace_info.metadata)
metadata = wrap_span_metadata(
trace_info.metadata,
trace_id=trace_info.trace_id or "",
message_id=trace_info.message_id or "",
status=trace_info.message_data.status or "",
status_message=trace_info.error or "",
level="ERROR" if trace_info.error else "DEFAULT",
trace_entity_type="dataset_retrieval",
model_provider=trace_info.message_data.model_provider or "",
model_id=trace_info.message_data.model_id or "",
)
dify_trace_id = trace_info.trace_id or trace_info.message_id
self.ensure_root_span(dify_trace_id)
@ -560,20 +615,20 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
span = self.tracer.start_span(
name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
attributes={
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
SpanAttributes.OUTPUT_VALUE: json.dumps({"documents": trace_info.documents}, ensure_ascii=False),
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.RETRIEVER.value,
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
"start_time": start_time.isoformat() if start_time else "",
"end_time": end_time.isoformat() if end_time else "",
SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.inputs),
SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
SpanAttributes.OUTPUT_VALUE: safe_json_dumps({"documents": trace_info.documents}),
SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
SpanAttributes.METADATA: safe_json_dumps(metadata),
},
start_time=datetime_to_nanos(start_time),
context=root_span_context,
)
try:
if trace_info.message_data.error:
set_span_status(span, trace_info.message_data.error)
if trace_info.error:
set_span_status(span, trace_info.error)
else:
set_span_status(span)
finally:
@ -584,30 +639,34 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
logger.warning("[Arize/Phoenix] Message data is None, skipping tool trace.")
return
metadata = {
"message_id": trace_info.message_id,
"tool_config": json.dumps(trace_info.tool_config, ensure_ascii=False),
}
metadata = wrap_span_metadata(
trace_info.metadata,
trace_id=trace_info.trace_id or "",
message_id=trace_info.message_id or "",
status=trace_info.message_data.status or "",
status_message=trace_info.error or "",
level="ERROR" if trace_info.error else "DEFAULT",
trace_entity_type="tool",
tool_config=safe_json_dumps(trace_info.tool_config),
time_cost=trace_info.time_cost or 0,
file_url=trace_info.file_url or "",
)
dify_trace_id = trace_info.trace_id or trace_info.message_id
self.ensure_root_span(dify_trace_id)
root_span_context = self.propagator.extract(carrier=self.carrier)
tool_params_str = (
json.dumps(trace_info.tool_parameters, ensure_ascii=False)
if isinstance(trace_info.tool_parameters, dict)
else str(trace_info.tool_parameters)
)
span = self.tracer.start_span(
name=trace_info.tool_name,
attributes={
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.tool_inputs, ensure_ascii=False),
SpanAttributes.OUTPUT_VALUE: trace_info.tool_outputs,
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.TOOL.value,
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.tool_inputs),
SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
SpanAttributes.OUTPUT_VALUE: trace_info.tool_outputs,
SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.TEXT.value,
SpanAttributes.METADATA: safe_json_dumps(metadata),
SpanAttributes.TOOL_NAME: trace_info.tool_name,
SpanAttributes.TOOL_PARAMETERS: tool_params_str,
SpanAttributes.TOOL_PARAMETERS: safe_json_dumps(trace_info.tool_parameters),
},
start_time=datetime_to_nanos(trace_info.start_time),
context=root_span_context,
@ -623,16 +682,22 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
if trace_info.message_data is None:
logger.warning("[Arize/Phoenix] Message data is None, skipping generate name trace.")
return
metadata = {
"project_name": self.project,
"message_id": trace_info.message_id,
"status": trace_info.message_data.status,
"status_message": trace_info.message_data.error or "",
"level": "ERROR" if trace_info.message_data.error else "DEFAULT",
}
metadata.update(trace_info.metadata)
metadata = wrap_span_metadata(
trace_info.metadata,
trace_id=trace_info.trace_id or "",
message_id=trace_info.message_id or "",
status=trace_info.message_data.status or "",
status_message=trace_info.message_data.error or "",
level="ERROR" if trace_info.message_data.error else "DEFAULT",
trace_entity_type="generate_name",
model_provider=trace_info.message_data.model_provider or "",
model_id=trace_info.message_data.model_id or "",
conversation_id=trace_info.conversation_id or "",
tenant_id=trace_info.tenant_id,
)
dify_trace_id = trace_info.trace_id or trace_info.message_id or trace_info.conversation_id
self.ensure_root_span(dify_trace_id)
@ -641,13 +706,13 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
span = self.tracer.start_span(
name=TraceTaskName.GENERATE_NAME_TRACE.value,
attributes={
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
SpanAttributes.OUTPUT_VALUE: json.dumps(trace_info.outputs, ensure_ascii=False),
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id,
"start_time": trace_info.start_time.isoformat() if trace_info.start_time else "",
"end_time": trace_info.end_time.isoformat() if trace_info.end_time else "",
SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.inputs),
SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
SpanAttributes.OUTPUT_VALUE: safe_json_dumps(trace_info.outputs),
SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
SpanAttributes.METADATA: safe_json_dumps(metadata),
SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
},
start_time=datetime_to_nanos(trace_info.start_time),
context=root_span_context,
@ -688,32 +753,85 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
raise ValueError(f"[Arize/Phoenix] API check failed: {str(e)}")
def get_project_url(self):
"""Build a redirect URL that forwards the user to the correct project for Arize/Phoenix."""
try:
if self.arize_phoenix_config.endpoint == "https://otlp.arize.com":
return "https://app.arize.com/"
else:
return f"{self.arize_phoenix_config.endpoint}/projects/"
project_name = self.arize_phoenix_config.project
endpoint = self.arize_phoenix_config.endpoint.rstrip("/")
# Arize
if isinstance(self.arize_phoenix_config, ArizeConfig):
return f"https://app.arize.com/?redirect_project_name={project_name}"
# Phoenix
return f"{endpoint}/projects/?redirect_project_name={project_name}"
except Exception as e:
logger.info("[Arize/Phoenix] Get run url failed: %s", str(e), exc_info=True)
raise ValueError(f"[Arize/Phoenix] Get run url failed: {str(e)}")
logger.info("[Arize/Phoenix] Failed to construct project URL: %s", str(e), exc_info=True)
raise ValueError(f"[Arize/Phoenix] Failed to construct project URL: {str(e)}")
def _construct_llm_attributes(self, prompts: dict | list | str | None) -> dict[str, str]:
"""Helper method to construct LLM attributes with passed prompts."""
attributes = {}
"""Construct LLM attributes with passed prompts for Arize/Phoenix."""
attributes: dict[str, str] = {}
def set_attribute(path: str, value: object) -> None:
"""Store an attribute safely as a string."""
if value is None:
return
try:
if isinstance(value, (dict, list)):
value = safe_json_dumps(value)
attributes[path] = str(value)
except Exception:
attributes[path] = str(value)
def set_message_attribute(message_index: int, key: str, value: object) -> None:
path = f"{SpanAttributes.LLM_INPUT_MESSAGES}.{message_index}.{key}"
set_attribute(path, value)
def set_tool_call_attributes(message_index: int, tool_index: int, tool_call: dict | object | None) -> None:
"""Extract and assign tool call details safely."""
if not tool_call:
return
def safe_get(obj, key, default=None):
if isinstance(obj, dict):
return obj.get(key, default)
return getattr(obj, key, default)
function_obj = safe_get(tool_call, "function", {})
function_name = safe_get(function_obj, "name", "")
function_args = safe_get(function_obj, "arguments", {})
call_id = safe_get(tool_call, "id", "")
base_path = (
f"{SpanAttributes.LLM_INPUT_MESSAGES}."
f"{message_index}.{MessageAttributes.MESSAGE_TOOL_CALLS}.{tool_index}"
)
set_attribute(f"{base_path}.{ToolCallAttributes.TOOL_CALL_FUNCTION_NAME}", function_name)
set_attribute(f"{base_path}.{ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUMENTS_JSON}", function_args)
set_attribute(f"{base_path}.{ToolCallAttributes.TOOL_CALL_ID}", call_id)
# Handle list of messages
if isinstance(prompts, list):
for i, msg in enumerate(prompts):
if isinstance(msg, dict):
attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.content"] = msg.get("text", "")
attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.role"] = msg.get("role", "user")
# todo: handle assistant and tool role messages, as they don't always
# have a text field, but may have a tool_calls field instead
# e.g. 'tool_calls': [{'id': '98af3a29-b066-45a5-b4b1-46c74ddafc58',
# 'type': 'function', 'function': {'name': 'current_time', 'arguments': '{}'}}]}
elif isinstance(prompts, dict):
attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = json.dumps(prompts)
attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user"
elif isinstance(prompts, str):
attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = prompts
attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user"
for message_index, message in enumerate(prompts):
if not isinstance(message, dict):
continue
role = message.get("role", "user")
content = message.get("text") or message.get("content") or ""
set_message_attribute(message_index, MessageAttributes.MESSAGE_ROLE, role)
set_message_attribute(message_index, MessageAttributes.MESSAGE_CONTENT, content)
tool_calls = message.get("tool_calls") or []
if isinstance(tool_calls, list):
for tool_index, tool_call in enumerate(tool_calls):
set_tool_call_attributes(message_index, tool_index, tool_call)
# Handle single dict or plain string prompt
elif isinstance(prompts, (dict, str)):
set_message_attribute(0, MessageAttributes.MESSAGE_CONTENT, prompts)
set_message_attribute(0, MessageAttributes.MESSAGE_ROLE, "user")
return attributes

View File

@ -222,59 +222,6 @@ class TencentSpanBuilder:
links=links,
)
@staticmethod
def build_message_llm_span(
trace_info: MessageTraceInfo, trace_id: int, parent_span_id: int, user_id: str
) -> SpanData:
"""Build LLM span for message traces with detailed LLM attributes."""
status = Status(StatusCode.OK)
if trace_info.error:
status = Status(StatusCode.ERROR, trace_info.error)
# Extract model information from `metadata`` or `message_data`
trace_metadata = trace_info.metadata or {}
message_data = trace_info.message_data or {}
model_provider = trace_metadata.get("ls_provider") or (
message_data.get("model_provider", "") if isinstance(message_data, dict) else ""
)
model_name = trace_metadata.get("ls_model_name") or (
message_data.get("model_id", "") if isinstance(message_data, dict) else ""
)
inputs_str = str(trace_info.inputs or "")
outputs_str = str(trace_info.outputs or "")
attributes = {
GEN_AI_SESSION_ID: trace_metadata.get("conversation_id", ""),
GEN_AI_USER_ID: str(user_id),
GEN_AI_SPAN_KIND: GenAISpanKind.GENERATION.value,
GEN_AI_FRAMEWORK: "dify",
GEN_AI_MODEL_NAME: str(model_name),
GEN_AI_PROVIDER: str(model_provider),
GEN_AI_USAGE_INPUT_TOKENS: str(trace_info.message_tokens or 0),
GEN_AI_USAGE_OUTPUT_TOKENS: str(trace_info.answer_tokens or 0),
GEN_AI_USAGE_TOTAL_TOKENS: str(trace_info.total_tokens or 0),
GEN_AI_PROMPT: inputs_str,
GEN_AI_COMPLETION: outputs_str,
INPUT_VALUE: inputs_str,
OUTPUT_VALUE: outputs_str,
}
if trace_info.is_streaming_request:
attributes[GEN_AI_IS_STREAMING_REQUEST] = "true"
return SpanData(
trace_id=trace_id,
parent_span_id=parent_span_id,
span_id=TencentTraceUtils.convert_to_span_id(trace_info.message_id, "llm"),
name="GENERATION",
start_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.start_time),
end_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.end_time),
attributes=attributes,
status=status,
)
@staticmethod
def build_tool_span(trace_info: ToolTraceInfo, trace_id: int, parent_span_id: int) -> SpanData:
"""Build tool span."""

View File

@ -107,12 +107,8 @@ class TencentDataTrace(BaseTraceInstance):
links.append(TencentTraceUtils.create_link(trace_info.trace_id))
message_span = TencentSpanBuilder.build_message_span(trace_info, trace_id, str(user_id), links)
self.trace_client.add_span(message_span)
# Add LLM child span with detailed attributes
parent_span_id = TencentTraceUtils.convert_to_span_id(trace_info.message_id, "message")
llm_span = TencentSpanBuilder.build_message_llm_span(trace_info, trace_id, parent_span_id, str(user_id))
self.trace_client.add_span(llm_span)
self.trace_client.add_span(message_span)
self._record_message_llm_metrics(trace_info)

View File

@ -371,7 +371,7 @@ class RetrievalService:
include_segment_ids = set()
segment_child_map = {}
segment_file_map = {}
with Session(db.engine) as session:
with Session(bind=db.engine, expire_on_commit=False) as session:
# Process documents
for document in documents:
segment_id = None
@ -395,7 +395,7 @@ class RetrievalService:
session,
)
if attachment_info_dict:
attachment_info = attachment_info_dict["attchment_info"]
attachment_info = attachment_info_dict["attachment_info"]
segment_id = attachment_info_dict["segment_id"]
else:
child_index_node_id = document.metadata.get("doc_id")
@ -417,13 +417,6 @@ class RetrievalService:
DocumentSegment.status == "completed",
DocumentSegment.id == segment_id,
)
.options(
load_only(
DocumentSegment.id,
DocumentSegment.content,
DocumentSegment.answer,
)
)
.first()
)
@ -458,12 +451,21 @@ class RetrievalService:
"position": child_chunk.position,
"score": document.metadata.get("score", 0.0),
}
segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
segment_child_map[segment.id]["max_score"] = max(
segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
)
if segment.id in segment_child_map:
segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
segment_child_map[segment.id]["max_score"] = max(
segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
)
else:
segment_child_map[segment.id] = {
"max_score": document.metadata.get("score", 0.0),
"child_chunks": [child_chunk_detail],
}
if attachment_info:
segment_file_map[segment.id].append(attachment_info)
if segment.id in segment_file_map:
segment_file_map[segment.id].append(attachment_info)
else:
segment_file_map[segment.id] = [attachment_info]
else:
# Handle normal documents
segment = None
@ -475,7 +477,7 @@ class RetrievalService:
session,
)
if attachment_info_dict:
attachment_info = attachment_info_dict["attchment_info"]
attachment_info = attachment_info_dict["attachment_info"]
segment_id = attachment_info_dict["segment_id"]
document_segment_stmt = select(DocumentSegment).where(
DocumentSegment.dataset_id == dataset_document.dataset_id,
@ -483,7 +485,7 @@ class RetrievalService:
DocumentSegment.status == "completed",
DocumentSegment.id == segment_id,
)
segment = db.session.scalar(document_segment_stmt)
segment = session.scalar(document_segment_stmt)
if segment:
segment_file_map[segment.id] = [attachment_info]
else:
@ -496,7 +498,7 @@ class RetrievalService:
DocumentSegment.status == "completed",
DocumentSegment.index_node_id == index_node_id,
)
segment = db.session.scalar(document_segment_stmt)
segment = session.scalar(document_segment_stmt)
if not segment:
continue
@ -684,7 +686,7 @@ class RetrievalService:
.first()
)
if attachment_binding:
attchment_info = {
attachment_info = {
"id": upload_file.id,
"name": upload_file.name,
"extension": "." + upload_file.extension,
@ -692,5 +694,5 @@ class RetrievalService:
"source_url": sign_upload_file(upload_file.id, upload_file.extension),
"size": upload_file.size,
}
return {"attchment_info": attchment_info, "segment_id": attachment_binding.segment_id}
return {"attachment_info": attachment_info, "segment_id": attachment_binding.segment_id}
return None

View File

@ -0,0 +1,407 @@
"""InterSystems IRIS vector database implementation for Dify.
This module provides vector storage and retrieval using IRIS native VECTOR type
with HNSW indexing for efficient similarity search.
"""
from __future__ import annotations
import json
import logging
import threading
import uuid
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any
from configs import dify_config
from configs.middleware.vdb.iris_config import IrisVectorConfig
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset
if TYPE_CHECKING:
import iris
else:
try:
import iris
except ImportError:
iris = None # type: ignore[assignment]
logger = logging.getLogger(__name__)
# Singleton connection pool to minimize IRIS license usage
_pool_lock = threading.Lock()
_pool_instance: IrisConnectionPool | None = None
def get_iris_pool(config: IrisVectorConfig) -> IrisConnectionPool:
"""Get or create the global IRIS connection pool (singleton pattern)."""
global _pool_instance # pylint: disable=global-statement
with _pool_lock:
if _pool_instance is None:
logger.info("Initializing IRIS connection pool")
_pool_instance = IrisConnectionPool(config)
return _pool_instance
class IrisConnectionPool:
"""Thread-safe connection pool for IRIS database."""
def __init__(self, config: IrisVectorConfig) -> None:
self.config = config
self._pool: list[Any] = []
self._lock = threading.Lock()
self._min_size = config.IRIS_MIN_CONNECTION
self._max_size = config.IRIS_MAX_CONNECTION
self._in_use = 0
self._schemas_initialized: set[str] = set() # Cache for initialized schemas
self._initialize_pool()
def _initialize_pool(self) -> None:
for _ in range(self._min_size):
self._pool.append(self._create_connection())
def _create_connection(self) -> Any:
return iris.connect(
hostname=self.config.IRIS_HOST,
port=self.config.IRIS_SUPER_SERVER_PORT,
namespace=self.config.IRIS_DATABASE,
username=self.config.IRIS_USER,
password=self.config.IRIS_PASSWORD,
)
def get_connection(self) -> Any:
"""Get a connection from pool or create new if available."""
with self._lock:
if self._pool:
conn = self._pool.pop()
self._in_use += 1
return conn
if self._in_use < self._max_size:
conn = self._create_connection()
self._in_use += 1
return conn
raise RuntimeError("Connection pool exhausted")
def return_connection(self, conn: Any) -> None:
"""Return connection to pool after validating it."""
if not conn:
return
# Validate connection health
is_valid = False
try:
cursor = conn.cursor()
cursor.execute("SELECT 1")
cursor.close()
is_valid = True
except (OSError, RuntimeError) as e:
logger.debug("Connection validation failed: %s", e)
try:
conn.close()
except (OSError, RuntimeError):
pass
with self._lock:
self._pool.append(conn if is_valid else self._create_connection())
self._in_use -= 1
def ensure_schema_exists(self, schema: str) -> None:
"""Ensure schema exists in IRIS database.
This method is idempotent and thread-safe. It uses a memory cache to avoid
redundant database queries for already-verified schemas.
Args:
schema: Schema name to ensure exists
Raises:
Exception: If schema creation fails
"""
# Fast path: check cache first (no lock needed for read-only set lookup)
if schema in self._schemas_initialized:
return
# Slow path: acquire lock and check again (double-checked locking)
with self._lock:
if schema in self._schemas_initialized:
return
# Get a connection to check/create schema
conn = self._pool[0] if self._pool else self._create_connection()
cursor = conn.cursor()
try:
# Check if schema exists using INFORMATION_SCHEMA
check_sql = """
SELECT COUNT(*) FROM INFORMATION_SCHEMA.SCHEMATA
WHERE SCHEMA_NAME = ?
"""
cursor.execute(check_sql, (schema,)) # Must be tuple or list
exists = cursor.fetchone()[0] > 0
if not exists:
# Schema doesn't exist, create it
cursor.execute(f"CREATE SCHEMA {schema}")
conn.commit()
logger.info("Created schema: %s", schema)
else:
logger.debug("Schema already exists: %s", schema)
# Add to cache to skip future checks
self._schemas_initialized.add(schema)
except Exception as e:
conn.rollback()
logger.exception("Failed to ensure schema %s exists", schema)
raise
finally:
cursor.close()
def close_all(self) -> None:
"""Close all connections (application shutdown only)."""
with self._lock:
for conn in self._pool:
try:
conn.close()
except (OSError, RuntimeError):
pass
self._pool.clear()
self._in_use = 0
self._schemas_initialized.clear()
class IrisVector(BaseVector):
"""IRIS vector database implementation using native VECTOR type and HNSW indexing."""
def __init__(self, collection_name: str, config: IrisVectorConfig) -> None:
super().__init__(collection_name)
self.config = config
self.table_name = f"embedding_{collection_name}".upper()
self.schema = config.IRIS_SCHEMA or "dify"
self.pool = get_iris_pool(config)
def get_type(self) -> str:
return VectorType.IRIS
@contextmanager
def _get_cursor(self):
"""Context manager for database cursor with connection pooling."""
conn = self.pool.get_connection()
cursor = conn.cursor()
try:
yield cursor
conn.commit()
except Exception:
conn.rollback()
raise
finally:
cursor.close()
self.pool.return_connection(conn)
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs) -> list[str]:
dimension = len(embeddings[0])
self._create_collection(dimension)
return self.add_texts(texts, embeddings)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **_kwargs) -> list[str]:
"""Add documents with embeddings to the collection."""
added_ids = []
with self._get_cursor() as cursor:
for i, doc in enumerate(documents):
doc_id = doc.metadata.get("doc_id", str(uuid.uuid4())) if doc.metadata else str(uuid.uuid4())
metadata = json.dumps(doc.metadata) if doc.metadata else "{}"
embedding_str = json.dumps(embeddings[i])
sql = f"INSERT INTO {self.schema}.{self.table_name} (id, text, meta, embedding) VALUES (?, ?, ?, ?)"
cursor.execute(sql, (doc_id, doc.page_content, metadata, embedding_str))
added_ids.append(doc_id)
return added_ids
def text_exists(self, id: str) -> bool: # pylint: disable=redefined-builtin
try:
with self._get_cursor() as cursor:
sql = f"SELECT 1 FROM {self.schema}.{self.table_name} WHERE id = ?"
cursor.execute(sql, (id,))
return cursor.fetchone() is not None
except (OSError, RuntimeError, ValueError):
return False
def delete_by_ids(self, ids: list[str]) -> None:
if not ids:
return
with self._get_cursor() as cursor:
placeholders = ",".join(["?" for _ in ids])
sql = f"DELETE FROM {self.schema}.{self.table_name} WHERE id IN ({placeholders})"
cursor.execute(sql, ids)
def delete_by_metadata_field(self, key: str, value: str) -> None:
"""Delete documents by metadata field (JSON LIKE pattern matching)."""
with self._get_cursor() as cursor:
pattern = f'%"{key}": "{value}"%'
sql = f"DELETE FROM {self.schema}.{self.table_name} WHERE meta LIKE ?"
cursor.execute(sql, (pattern,))
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
"""Search similar documents using VECTOR_COSINE with HNSW index."""
top_k = kwargs.get("top_k", 4)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
embedding_str = json.dumps(query_vector)
with self._get_cursor() as cursor:
sql = f"""
SELECT TOP {top_k} id, text, meta, VECTOR_COSINE(embedding, ?) as score
FROM {self.schema}.{self.table_name}
ORDER BY score DESC
"""
cursor.execute(sql, (embedding_str,))
docs = []
for row in cursor.fetchall():
if len(row) >= 4:
text, meta_str, score = row[1], row[2], float(row[3])
if score >= score_threshold:
metadata = json.loads(meta_str) if meta_str else {}
metadata["score"] = score
docs.append(Document(page_content=text, metadata=metadata))
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
"""Search documents by full-text using iFind index or fallback to LIKE search."""
top_k = kwargs.get("top_k", 5)
with self._get_cursor() as cursor:
if self.config.IRIS_TEXT_INDEX:
# Use iFind full-text search with index
text_index_name = f"idx_{self.table_name}_text"
sql = f"""
SELECT TOP {top_k} id, text, meta
FROM {self.schema}.{self.table_name}
WHERE %ID %FIND search_index({text_index_name}, ?)
"""
cursor.execute(sql, (query,))
else:
# Fallback to LIKE search (inefficient for large datasets)
query_pattern = f"%{query}%"
sql = f"""
SELECT TOP {top_k} id, text, meta
FROM {self.schema}.{self.table_name}
WHERE text LIKE ?
"""
cursor.execute(sql, (query_pattern,))
docs = []
for row in cursor.fetchall():
if len(row) >= 3:
metadata = json.loads(row[2]) if row[2] else {}
docs.append(Document(page_content=row[1], metadata=metadata))
if not docs:
logger.info("Full-text search for '%s' returned no results", query)
return docs
def delete(self) -> None:
"""Delete the entire collection (drop table - permanent)."""
with self._get_cursor() as cursor:
sql = f"DROP TABLE {self.schema}.{self.table_name}"
cursor.execute(sql)
def _create_collection(self, dimension: int) -> None:
"""Create table with VECTOR column and HNSW index.
Uses Redis lock to prevent concurrent creation attempts across multiple
API server instances (api, worker, worker_beat).
"""
cache_key = f"vector_indexing_{self._collection_name}"
lock_name = f"{cache_key}_lock"
with redis_client.lock(lock_name, timeout=20): # pylint: disable=not-context-manager
if redis_client.get(cache_key):
return
# Ensure schema exists (idempotent, cached after first call)
self.pool.ensure_schema_exists(self.schema)
with self._get_cursor() as cursor:
# Create table with VECTOR column
sql = f"""
CREATE TABLE {self.schema}.{self.table_name} (
id VARCHAR(255) PRIMARY KEY,
text CLOB,
meta CLOB,
embedding VECTOR(DOUBLE, {dimension})
)
"""
logger.info("Creating table: %s.%s", self.schema, self.table_name)
cursor.execute(sql)
# Create HNSW index for vector similarity search
index_name = f"idx_{self.table_name}_embedding"
sql_index = (
f"CREATE INDEX {index_name} ON {self.schema}.{self.table_name} "
"(embedding) AS HNSW(Distance='Cosine')"
)
logger.info("Creating HNSW index: %s", index_name)
cursor.execute(sql_index)
logger.info("HNSW index created successfully: %s", index_name)
# Create full-text search index if enabled
logger.info(
"IRIS_TEXT_INDEX config value: %s (type: %s)",
self.config.IRIS_TEXT_INDEX,
type(self.config.IRIS_TEXT_INDEX),
)
if self.config.IRIS_TEXT_INDEX:
text_index_name = f"idx_{self.table_name}_text"
language = self.config.IRIS_TEXT_INDEX_LANGUAGE
# Fixed: Removed extra parentheses and corrected syntax
sql_text_index = f"""
CREATE INDEX {text_index_name} ON {self.schema}.{self.table_name} (text)
AS %iFind.Index.Basic
(LANGUAGE = '{language}', LOWER = 1, INDEXOPTION = 0)
"""
logger.info("Creating text index: %s with language: %s", text_index_name, language)
logger.info("SQL for text index: %s", sql_text_index)
cursor.execute(sql_text_index)
logger.info("Text index created successfully: %s", text_index_name)
else:
logger.warning("Text index creation skipped - IRIS_TEXT_INDEX is disabled")
redis_client.set(cache_key, 1, ex=3600)
class IrisVectorFactory(AbstractVectorFactory):
"""Factory for creating IrisVector instances."""
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> IrisVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = self.gen_index_struct_dict(VectorType.IRIS, collection_name)
dataset.index_struct = json.dumps(index_struct_dict)
return IrisVector(
collection_name=collection_name,
config=IrisVectorConfig(
IRIS_HOST=dify_config.IRIS_HOST,
IRIS_SUPER_SERVER_PORT=dify_config.IRIS_SUPER_SERVER_PORT,
IRIS_USER=dify_config.IRIS_USER,
IRIS_PASSWORD=dify_config.IRIS_PASSWORD,
IRIS_DATABASE=dify_config.IRIS_DATABASE,
IRIS_SCHEMA=dify_config.IRIS_SCHEMA,
IRIS_CONNECTION_URL=dify_config.IRIS_CONNECTION_URL,
IRIS_MIN_CONNECTION=dify_config.IRIS_MIN_CONNECTION,
IRIS_MAX_CONNECTION=dify_config.IRIS_MAX_CONNECTION,
IRIS_TEXT_INDEX=dify_config.IRIS_TEXT_INDEX,
IRIS_TEXT_INDEX_LANGUAGE=dify_config.IRIS_TEXT_INDEX_LANGUAGE,
),
)

View File

@ -163,7 +163,7 @@ class Vector:
from core.rag.datasource.vdb.lindorm.lindorm_vector import LindormVectorStoreFactory
return LindormVectorStoreFactory
case VectorType.OCEANBASE:
case VectorType.OCEANBASE | VectorType.SEEKDB:
from core.rag.datasource.vdb.oceanbase.oceanbase_vector import OceanBaseVectorFactory
return OceanBaseVectorFactory
@ -187,6 +187,10 @@ class Vector:
from core.rag.datasource.vdb.clickzetta.clickzetta_vector import ClickzettaVectorFactory
return ClickzettaVectorFactory
case VectorType.IRIS:
from core.rag.datasource.vdb.iris.iris_vector import IrisVectorFactory
return IrisVectorFactory
case _:
raise ValueError(f"Vector store {vector_type} is not supported.")

View File

@ -27,8 +27,10 @@ class VectorType(StrEnum):
UPSTASH = "upstash"
TIDB_ON_QDRANT = "tidb_on_qdrant"
OCEANBASE = "oceanbase"
SEEKDB = "seekdb"
OPENGAUSS = "opengauss"
TABLESTORE = "tablestore"
HUAWEI_CLOUD = "huawei_cloud"
MATRIXONE = "matrixone"
CLICKZETTA = "clickzetta"
IRIS = "iris"

View File

@ -10,7 +10,7 @@ class NotionInfo(BaseModel):
"""
credential_id: str | None = None
notion_workspace_id: str
notion_workspace_id: str | None = ""
notion_obj_id: str
notion_page_type: str
document: Document | None = None

View File

@ -1,7 +1,7 @@
"""Abstract interface for document loader implementations."""
import os
from typing import cast
from typing import TypedDict
import pandas as pd
from openpyxl import load_workbook
@ -10,6 +10,12 @@ from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
class Candidate(TypedDict):
idx: int
count: int
map: dict[int, str]
class ExcelExtractor(BaseExtractor):
"""Load Excel files.
@ -30,32 +36,38 @@ class ExcelExtractor(BaseExtractor):
file_extension = os.path.splitext(self._file_path)[-1].lower()
if file_extension == ".xlsx":
wb = load_workbook(self._file_path, data_only=True)
for sheet_name in wb.sheetnames:
sheet = wb[sheet_name]
data = sheet.values
cols = next(data, None)
if cols is None:
continue
df = pd.DataFrame(data, columns=cols)
df.dropna(how="all", inplace=True)
for index, row in df.iterrows():
page_content = []
for col_index, (k, v) in enumerate(row.items()):
if pd.notna(v):
cell = sheet.cell(
row=cast(int, index) + 2, column=col_index + 1
) # +2 to account for header and 1-based index
if cell.hyperlink:
value = f"[{v}]({cell.hyperlink.target})"
page_content.append(f'"{k}":"{value}"')
else:
page_content.append(f'"{k}":"{v}"')
documents.append(
Document(page_content=";".join(page_content), metadata={"source": self._file_path})
)
wb = load_workbook(self._file_path, read_only=True, data_only=True)
try:
for sheet_name in wb.sheetnames:
sheet = wb[sheet_name]
header_row_idx, column_map, max_col_idx = self._find_header_and_columns(sheet)
if not column_map:
continue
start_row = header_row_idx + 1
for row in sheet.iter_rows(min_row=start_row, max_col=max_col_idx, values_only=False):
if all(cell.value is None for cell in row):
continue
page_content = []
for col_idx, cell in enumerate(row):
value = cell.value
if col_idx in column_map:
col_name = column_map[col_idx]
if hasattr(cell, "hyperlink") and cell.hyperlink:
target = getattr(cell.hyperlink, "target", None)
if target:
value = f"[{value}]({target})"
if value is None:
value = ""
elif not isinstance(value, str):
value = str(value)
value = value.strip().replace('"', '\\"')
page_content.append(f'"{col_name}":"{value}"')
if page_content:
documents.append(
Document(page_content=";".join(page_content), metadata={"source": self._file_path})
)
finally:
wb.close()
elif file_extension == ".xls":
excel_file = pd.ExcelFile(self._file_path, engine="xlrd")
@ -63,9 +75,9 @@ class ExcelExtractor(BaseExtractor):
df = excel_file.parse(sheet_name=excel_sheet_name)
df.dropna(how="all", inplace=True)
for _, row in df.iterrows():
for _, series_row in df.iterrows():
page_content = []
for k, v in row.items():
for k, v in series_row.items():
if pd.notna(v):
page_content.append(f'"{k}":"{v}"')
documents.append(
@ -75,3 +87,61 @@ class ExcelExtractor(BaseExtractor):
raise ValueError(f"Unsupported file extension: {file_extension}")
return documents
def _find_header_and_columns(self, sheet, scan_rows=10) -> tuple[int, dict[int, str], int]:
"""
Scan first N rows to find the most likely header row.
Returns:
header_row_idx: 1-based index of the header row
column_map: Dict mapping 0-based column index to column name
max_col_idx: 1-based index of the last valid column (for iter_rows boundary)
"""
# Store potential candidates: (row_index, non_empty_count, column_map)
candidates: list[Candidate] = []
# Limit scan to avoid performance issues on huge files
# We iterate manually to control the read scope
for current_row_idx, row in enumerate(sheet.iter_rows(min_row=1, max_row=scan_rows, values_only=True), start=1):
# Filter out empty cells and build a temp map for this row
# col_idx is 0-based
row_map = {}
for col_idx, cell_value in enumerate(row):
if cell_value is not None and str(cell_value).strip():
row_map[col_idx] = str(cell_value).strip().replace('"', '\\"')
if not row_map:
continue
non_empty_count = len(row_map)
# Header selection heuristic (implemented):
# - Prefer the first row with at least 2 non-empty columns.
# - Fallback: choose the row with the most non-empty columns
# (tie-breaker: smaller row index).
candidates.append({"idx": current_row_idx, "count": non_empty_count, "map": row_map})
if not candidates:
return 0, {}, 0
# Choose the best candidate header row.
best_candidate: Candidate | None = None
# Strategy: prefer the first row with >= 2 non-empty columns; otherwise fallback.
for cand in candidates:
if cand["count"] >= 2:
best_candidate = cand
break
# Fallback: if no row has >= 2 columns, or all have 1, just take the one with max columns
if not best_candidate:
# Sort by count desc, then index asc
candidates.sort(key=lambda x: (-x["count"], x["idx"]))
best_candidate = candidates[0]
# Determine max_col_idx (1-based for openpyxl)
# It is the index of the last valid column in our map + 1
max_col_idx = max(best_candidate["map"].keys()) + 1
return best_candidate["idx"], best_candidate["map"], max_col_idx

View File

@ -166,7 +166,7 @@ class ExtractProcessor:
elif extract_setting.datasource_type == DatasourceType.NOTION:
assert extract_setting.notion_info is not None, "notion_info is required"
extractor = NotionExtractor(
notion_workspace_id=extract_setting.notion_info.notion_workspace_id,
notion_workspace_id=extract_setting.notion_info.notion_workspace_id or "",
notion_obj_id=extract_setting.notion_info.notion_obj_id,
notion_page_type=extract_setting.notion_info.notion_page_type,
document_model=extract_setting.notion_info.document,

View File

@ -45,6 +45,6 @@ def detect_file_encodings(file_path: str, timeout: int = 5, sample_size: int = 1
except concurrent.futures.TimeoutError:
raise TimeoutError(f"Timeout reached while detecting encoding for {file_path}")
if all(encoding["encoding"] is None for encoding in encodings):
if all(encoding.encoding is None for encoding in encodings):
raise RuntimeError(f"Could not detect encoding for {file_path}")
return [FileEncoding(**enc) for enc in encodings if enc["encoding"] is not None]
return [enc for enc in encodings if enc.encoding is not None]

View File

@ -84,22 +84,45 @@ class WordExtractor(BaseExtractor):
image_count = 0
image_map = {}
for rel in doc.part.rels.values():
for r_id, rel in doc.part.rels.items():
if "image" in rel.target_ref:
image_count += 1
if rel.is_external:
url = rel.target_ref
response = ssrf_proxy.get(url)
if not self._is_valid_url(url):
continue
try:
response = ssrf_proxy.get(url)
except Exception as e:
logger.warning("Failed to download image from URL: %s: %s", url, str(e))
continue
if response.status_code == 200:
image_ext = mimetypes.guess_extension(response.headers["Content-Type"])
image_ext = mimetypes.guess_extension(response.headers.get("Content-Type", ""))
if image_ext is None:
continue
file_uuid = str(uuid.uuid4())
file_key = "image_files/" + self.tenant_id + "/" + file_uuid + "." + image_ext
file_key = "image_files/" + self.tenant_id + "/" + file_uuid + image_ext
mime_type, _ = mimetypes.guess_type(file_key)
storage.save(file_key, response.content)
else:
continue
# save file to db
upload_file = UploadFile(
tenant_id=self.tenant_id,
storage_type=dify_config.STORAGE_TYPE,
key=file_key,
name=file_key,
size=0,
extension=str(image_ext),
mime_type=mime_type or "",
created_by=self.user_id,
created_by_role=CreatorUserRole.ACCOUNT,
created_at=naive_utc_now(),
used=True,
used_by=self.user_id,
used_at=naive_utc_now(),
)
db.session.add(upload_file)
# Use r_id as key for external images since target_part is undefined
image_map[r_id] = f"![image]({dify_config.FILES_URL}/files/{upload_file.id}/file-preview)"
else:
image_ext = rel.target_ref.split(".")[-1]
if image_ext is None:
@ -110,27 +133,28 @@ class WordExtractor(BaseExtractor):
mime_type, _ = mimetypes.guess_type(file_key)
storage.save(file_key, rel.target_part.blob)
# save file to db
upload_file = UploadFile(
tenant_id=self.tenant_id,
storage_type=dify_config.STORAGE_TYPE,
key=file_key,
name=file_key,
size=0,
extension=str(image_ext),
mime_type=mime_type or "",
created_by=self.user_id,
created_by_role=CreatorUserRole.ACCOUNT,
created_at=naive_utc_now(),
used=True,
used_by=self.user_id,
used_at=naive_utc_now(),
)
db.session.add(upload_file)
db.session.commit()
image_map[rel.target_part] = f"![image]({dify_config.FILES_URL}/files/{upload_file.id}/file-preview)"
# save file to db
upload_file = UploadFile(
tenant_id=self.tenant_id,
storage_type=dify_config.STORAGE_TYPE,
key=file_key,
name=file_key,
size=0,
extension=str(image_ext),
mime_type=mime_type or "",
created_by=self.user_id,
created_by_role=CreatorUserRole.ACCOUNT,
created_at=naive_utc_now(),
used=True,
used_by=self.user_id,
used_at=naive_utc_now(),
)
db.session.add(upload_file)
# Use target_part as key for internal images
image_map[rel.target_part] = (
f"![image]({dify_config.FILES_URL}/files/{upload_file.id}/file-preview)"
)
db.session.commit()
return image_map
def _table_to_markdown(self, table, image_map):
@ -186,11 +210,17 @@ class WordExtractor(BaseExtractor):
image_id = blip.get("{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed")
if not image_id:
continue
image_part = paragraph.part.rels[image_id].target_part
if image_part in image_map:
image_link = image_map[image_part]
paragraph_content.append(image_link)
rel = paragraph.part.rels.get(image_id)
if rel is None:
continue
# For external images, use image_id as key; for internal, use target_part
if rel.is_external:
if image_id in image_map:
paragraph_content.append(image_map[image_id])
else:
image_part = rel.target_part
if image_part in image_map:
paragraph_content.append(image_map[image_part])
else:
paragraph_content.append(run.text)
return "".join(paragraph_content).strip()
@ -227,6 +257,18 @@ class WordExtractor(BaseExtractor):
def parse_paragraph(paragraph):
paragraph_content = []
def append_image_link(image_id, has_drawing):
"""Helper to append image link from image_map based on relationship type."""
rel = doc.part.rels[image_id]
if rel.is_external:
if image_id in image_map and not has_drawing:
paragraph_content.append(image_map[image_id])
else:
image_part = rel.target_part
if image_part in image_map and not has_drawing:
paragraph_content.append(image_map[image_part])
for run in paragraph.runs:
if hasattr(run.element, "tag") and isinstance(run.element.tag, str) and run.element.tag.endswith("r"):
# Process drawing type images
@ -243,10 +285,18 @@ class WordExtractor(BaseExtractor):
"{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed"
)
if embed_id:
image_part = doc.part.related_parts.get(embed_id)
if image_part in image_map:
has_drawing = True
paragraph_content.append(image_map[image_part])
rel = doc.part.rels.get(embed_id)
if rel is not None and rel.is_external:
# External image: use embed_id as key
if embed_id in image_map:
has_drawing = True
paragraph_content.append(image_map[embed_id])
else:
# Internal image: use target_part as key
image_part = doc.part.related_parts.get(embed_id)
if image_part in image_map:
has_drawing = True
paragraph_content.append(image_map[image_part])
# Process pict type images
shape_elements = run.element.findall(
".//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}pict"
@ -261,9 +311,7 @@ class WordExtractor(BaseExtractor):
"{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id"
)
if image_id and image_id in doc.part.rels:
image_part = doc.part.rels[image_id].target_part
if image_part in image_map and not has_drawing:
paragraph_content.append(image_map[image_part])
append_image_link(image_id, has_drawing)
# Find imagedata element in VML
image_data = shape.find(".//{urn:schemas-microsoft-com:vml}imagedata")
if image_data is not None:
@ -271,9 +319,7 @@ class WordExtractor(BaseExtractor):
"{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id"
)
if image_id and image_id in doc.part.rels:
image_part = doc.part.rels[image_id].target_part
if image_part in image_map and not has_drawing:
paragraph_content.append(image_map[image_part])
append_image_link(image_id, has_drawing)
if run.text.strip():
paragraph_content.append(run.text.strip())
return "".join(paragraph_content) if paragraph_content else ""

View File

@ -15,3 +15,4 @@ class MetadataDataSource(StrEnum):
notion_import = "notion"
local_file = "file_upload"
online_document = "online_document"
online_drive = "online_drive"

View File

@ -209,7 +209,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
if dataset.indexing_technique == "high_quality":
vector = Vector(dataset)
vector.create(documents)
if all_multimodal_documents:
if all_multimodal_documents and dataset.is_multimodal:
vector.create_multimodal(all_multimodal_documents)
elif dataset.indexing_technique == "economy":
keyword = Keyword(dataset)

View File

@ -312,7 +312,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
vector = Vector(dataset)
if all_child_documents:
vector.create(all_child_documents)
if all_multimodal_documents:
if all_multimodal_documents and dataset.is_multimodal:
vector.create_multimodal(all_multimodal_documents)
def format_preview(self, chunks: Any) -> Mapping[str, Any]:

View File

@ -266,7 +266,7 @@ class DatasetRetrieval:
).all()
if attachments_with_bindings:
for _, upload_file in attachments_with_bindings:
attchment_info = File(
attachment_info = File(
id=upload_file.id,
filename=upload_file.name,
extension="." + upload_file.extension,
@ -280,7 +280,7 @@ class DatasetRetrieval:
storage_key=upload_file.key,
url=sign_upload_file(upload_file.id, upload_file.extension),
)
context_files.append(attchment_info)
context_files.append(attachment_info)
if show_retrieve_source:
for record in records:
segment = record.segment
@ -592,111 +592,116 @@ class DatasetRetrieval:
"""Handle retrieval end."""
with flask_app.app_context():
dify_documents = [document for document in documents if document.provider == "dify"]
segment_ids = []
segment_index_node_ids = []
if not dify_documents:
self._send_trace_task(message_id, documents, timer)
return
with Session(db.engine) as session:
for document in dify_documents:
if document.metadata is not None:
dataset_document_stmt = select(DatasetDocument).where(
DatasetDocument.id == document.metadata["document_id"]
)
dataset_document = session.scalar(dataset_document_stmt)
if dataset_document:
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
segment_id = None
if (
"doc_type" not in document.metadata
or document.metadata.get("doc_type") == DocType.TEXT
):
child_chunk_stmt = select(ChildChunk).where(
ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id,
)
child_chunk = session.scalar(child_chunk_stmt)
if child_chunk:
segment_id = child_chunk.segment_id
elif (
"doc_type" in document.metadata
and document.metadata.get("doc_type") == DocType.IMAGE
):
attachment_info_dict = RetrievalService.get_segment_attachment_info(
dataset_document.dataset_id,
dataset_document.tenant_id,
document.metadata.get("doc_id") or "",
session,
)
if attachment_info_dict:
segment_id = attachment_info_dict["segment_id"]
# Collect all document_ids and batch fetch DatasetDocuments
document_ids = {
doc.metadata["document_id"]
for doc in dify_documents
if doc.metadata and "document_id" in doc.metadata
}
if not document_ids:
self._send_trace_task(message_id, documents, timer)
return
dataset_docs_stmt = select(DatasetDocument).where(DatasetDocument.id.in_(document_ids))
dataset_docs = session.scalars(dataset_docs_stmt).all()
dataset_doc_map = {str(doc.id): doc for doc in dataset_docs}
# Categorize documents by type and collect necessary IDs
parent_child_text_docs: list[tuple[Document, DatasetDocument]] = []
parent_child_image_docs: list[tuple[Document, DatasetDocument]] = []
normal_text_docs: list[tuple[Document, DatasetDocument]] = []
normal_image_docs: list[tuple[Document, DatasetDocument]] = []
for doc in dify_documents:
if not doc.metadata or "document_id" not in doc.metadata:
continue
dataset_doc = dataset_doc_map.get(doc.metadata["document_id"])
if not dataset_doc:
continue
is_image = doc.metadata.get("doc_type") == DocType.IMAGE
is_parent_child = dataset_doc.doc_form == IndexStructureType.PARENT_CHILD_INDEX
if is_parent_child:
if is_image:
parent_child_image_docs.append((doc, dataset_doc))
else:
parent_child_text_docs.append((doc, dataset_doc))
else:
if is_image:
normal_image_docs.append((doc, dataset_doc))
else:
normal_text_docs.append((doc, dataset_doc))
segment_ids_to_update: set[str] = set()
# Process PARENT_CHILD_INDEX text documents - batch fetch ChildChunks
if parent_child_text_docs:
index_node_ids = [doc.metadata["doc_id"] for doc, _ in parent_child_text_docs if doc.metadata]
if index_node_ids:
child_chunks_stmt = select(ChildChunk).where(ChildChunk.index_node_id.in_(index_node_ids))
child_chunks = session.scalars(child_chunks_stmt).all()
child_chunk_map = {chunk.index_node_id: chunk.segment_id for chunk in child_chunks}
for doc, _ in parent_child_text_docs:
if doc.metadata:
segment_id = child_chunk_map.get(doc.metadata["doc_id"])
if segment_id:
if segment_id not in segment_ids:
segment_ids.append(segment_id)
_ = (
session.query(DocumentSegment)
.where(DocumentSegment.id == segment_id)
.update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
synchronize_session=False,
)
)
else:
query = None
if (
"doc_type" not in document.metadata
or document.metadata.get("doc_type") == DocType.TEXT
):
if document.metadata["doc_id"] not in segment_index_node_ids:
segment = (
session.query(DocumentSegment)
.where(DocumentSegment.index_node_id == document.metadata["doc_id"])
.first()
)
if segment:
segment_index_node_ids.append(document.metadata["doc_id"])
segment_ids.append(segment.id)
query = session.query(DocumentSegment).where(
DocumentSegment.id == segment.id
)
elif (
"doc_type" in document.metadata
and document.metadata.get("doc_type") == DocType.IMAGE
):
attachment_info_dict = RetrievalService.get_segment_attachment_info(
dataset_document.dataset_id,
dataset_document.tenant_id,
document.metadata.get("doc_id") or "",
session,
)
if attachment_info_dict:
segment_id = attachment_info_dict["segment_id"]
if segment_id not in segment_ids:
segment_ids.append(segment_id)
query = session.query(DocumentSegment).where(DocumentSegment.id == segment_id)
if query:
# if 'dataset_id' in document.metadata:
if "dataset_id" in document.metadata:
query = query.where(
DocumentSegment.dataset_id == document.metadata["dataset_id"]
)
segment_ids_to_update.add(str(segment_id))
# add hit count to document segment
query.update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
synchronize_session=False,
)
# Process non-PARENT_CHILD_INDEX text documents - batch fetch DocumentSegments
if normal_text_docs:
index_node_ids = [doc.metadata["doc_id"] for doc, _ in normal_text_docs if doc.metadata]
if index_node_ids:
segments_stmt = select(DocumentSegment).where(DocumentSegment.index_node_id.in_(index_node_ids))
segments = session.scalars(segments_stmt).all()
segment_map = {seg.index_node_id: seg.id for seg in segments}
for doc, _ in normal_text_docs:
if doc.metadata:
segment_id = segment_map.get(doc.metadata["doc_id"])
if segment_id:
segment_ids_to_update.add(str(segment_id))
db.session.commit()
# Process IMAGE documents - batch fetch SegmentAttachmentBindings
all_image_docs = parent_child_image_docs + normal_image_docs
if all_image_docs:
attachment_ids = [
doc.metadata["doc_id"]
for doc, _ in all_image_docs
if doc.metadata and doc.metadata.get("doc_id")
]
if attachment_ids:
bindings_stmt = select(SegmentAttachmentBinding).where(
SegmentAttachmentBinding.attachment_id.in_(attachment_ids)
)
bindings = session.scalars(bindings_stmt).all()
segment_ids_to_update.update(str(binding.segment_id) for binding in bindings)
# get tracing instance
trace_manager: TraceQueueManager | None = (
self.application_generate_entity.trace_manager if self.application_generate_entity else None
)
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.DATASET_RETRIEVAL_TRACE, message_id=message_id, documents=documents, timer=timer
# Batch update hit_count for all segments
if segment_ids_to_update:
session.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids_to_update)).update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
synchronize_session=False,
)
session.commit()
self._send_trace_task(message_id, documents, timer)
def _send_trace_task(self, message_id: str | None, documents: list[Document], timer: dict | None):
"""Send trace task if trace manager is available."""
trace_manager: TraceQueueManager | None = (
self.application_generate_entity.trace_manager if self.application_generate_entity else None
)
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.DATASET_RETRIEVAL_TRACE, message_id=message_id, documents=documents, timer=timer
)
)
def _on_query(
self,

View File

@ -29,6 +29,10 @@ class ToolApiSchemaError(ValueError):
pass
class ToolSSRFError(ValueError):
pass
class ToolCredentialPolicyViolationError(ValueError):
pass

View File

@ -101,6 +101,8 @@ class ToolFileMessageTransformer:
meta = message.meta or {}
mimetype = meta.get("mime_type", "application/octet-stream")
if not mimetype:
mimetype = "application/octet-stream"
# get filename from meta
filename = meta.get("filename", None)
# if message is str, encode it to bytes

View File

@ -425,7 +425,7 @@ class ApiBasedToolSchemaParser:
except ToolApiSchemaError as e:
openapi_error = e
# openai parse error, fallback to swagger
# openapi parse error, fallback to swagger
try:
converted_swagger = ApiBasedToolSchemaParser.parse_swagger_to_openapi(
loaded_content, extra_info=extra_info, warning=warning
@ -436,7 +436,6 @@ class ApiBasedToolSchemaParser:
), schema_type
except ToolApiSchemaError as e:
swagger_error = e
# swagger parse error, fallback to openai plugin
try:
openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle(

View File

@ -13,5 +13,5 @@ def remove_leading_symbols(text: str) -> str:
"""
# Match Unicode ranges for punctuation and symbols
# FIXME this pattern is confused quick fix for #11868 maybe refactor it later
pattern = r"^[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F\"#$%&'()*+,./:;<=>?@^_`~]+"
pattern = r'^[\[\]\u2000-\u2025\u2027-\u206F\u2E00-\u2E7F\u3000-\u300F\u3011-\u303F"#$%&\'()*+,./:;<=>?@^_`~]+'
return re.sub(pattern, "", text)

View File

@ -221,7 +221,7 @@ class WorkflowToolProviderController(ToolProviderController):
session.query(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.app_id == self.provider_id,
WorkflowToolProvider.id == self.provider_id,
)
.first()
)

View File

@ -140,6 +140,10 @@ class GraphEngine:
pause_handler = PauseCommandHandler()
self._command_processor.register_handler(PauseCommand, pause_handler)
# === Extensibility ===
# Layers allow plugins to extend engine functionality
self._layers: list[GraphEngineLayer] = []
# === Worker Pool Setup ===
# Capture Flask app context for worker threads
flask_app: Flask | None = None
@ -158,6 +162,7 @@ class GraphEngine:
ready_queue=self._ready_queue,
event_queue=self._event_queue,
graph=self._graph,
layers=self._layers,
flask_app=flask_app,
context_vars=context_vars,
min_workers=self._min_workers,
@ -196,10 +201,6 @@ class GraphEngine:
event_emitter=self._event_manager,
)
# === Extensibility ===
# Layers allow plugins to extend engine functionality
self._layers: list[GraphEngineLayer] = []
# === Validation ===
# Ensure all nodes share the same GraphRuntimeState instance
self._validate_graph_state_consistency()

View File

@ -8,9 +8,11 @@ with middleware-like components that can observe events and interact with execut
from .base import GraphEngineLayer
from .debug_logging import DebugLoggingLayer
from .execution_limits import ExecutionLimitsLayer
from .observability import ObservabilityLayer
__all__ = [
"DebugLoggingLayer",
"ExecutionLimitsLayer",
"GraphEngineLayer",
"ObservabilityLayer",
]

View File

@ -9,6 +9,7 @@ from abc import ABC, abstractmethod
from core.workflow.graph_engine.protocols.command_channel import CommandChannel
from core.workflow.graph_events import GraphEngineEvent
from core.workflow.nodes.base.node import Node
from core.workflow.runtime import ReadOnlyGraphRuntimeState
@ -83,3 +84,29 @@ class GraphEngineLayer(ABC):
error: The exception that caused execution to fail, or None if successful
"""
pass
def on_node_run_start(self, node: Node) -> None: # noqa: B027
"""
Called immediately before a node begins execution.
Layers can override to inject behavior (e.g., start spans) prior to node execution.
The node's execution ID is available via `node._node_execution_id` and will be
consistent with all events emitted by this node execution.
Args:
node: The node instance about to be executed
"""
pass
def on_node_run_end(self, node: Node, error: Exception | None) -> None: # noqa: B027
"""
Called after a node finishes execution.
The node's execution ID is available via `node._node_execution_id` and matches
the `id` field in all events emitted by this node execution.
Args:
node: The node instance that just finished execution
error: Exception instance if the node failed, otherwise None
"""
pass

View File

@ -0,0 +1,61 @@
"""
Node-level OpenTelemetry parser interfaces and defaults.
"""
import json
from typing import Protocol
from opentelemetry.trace import Span
from opentelemetry.trace.status import Status, StatusCode
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.tool.entities import ToolNodeData
class NodeOTelParser(Protocol):
"""Parser interface for node-specific OpenTelemetry enrichment."""
def parse(self, *, node: Node, span: "Span", error: Exception | None) -> None: ...
class DefaultNodeOTelParser:
"""Fallback parser used when no node-specific parser is registered."""
def parse(self, *, node: Node, span: "Span", error: Exception | None) -> None:
span.set_attribute("node.id", node.id)
if node.execution_id:
span.set_attribute("node.execution_id", node.execution_id)
if hasattr(node, "node_type") and node.node_type:
span.set_attribute("node.type", node.node_type.value)
if error:
span.record_exception(error)
span.set_status(Status(StatusCode.ERROR, str(error)))
else:
span.set_status(Status(StatusCode.OK))
class ToolNodeOTelParser:
"""Parser for tool nodes that captures tool-specific metadata."""
def __init__(self) -> None:
self._delegate = DefaultNodeOTelParser()
def parse(self, *, node: Node, span: "Span", error: Exception | None) -> None:
self._delegate.parse(node=node, span=span, error=error)
tool_data = getattr(node, "_node_data", None)
if not isinstance(tool_data, ToolNodeData):
return
span.set_attribute("tool.provider.id", tool_data.provider_id)
span.set_attribute("tool.provider.type", tool_data.provider_type.value)
span.set_attribute("tool.provider.name", tool_data.provider_name)
span.set_attribute("tool.name", tool_data.tool_name)
span.set_attribute("tool.label", tool_data.tool_label)
if tool_data.plugin_unique_identifier:
span.set_attribute("tool.plugin.id", tool_data.plugin_unique_identifier)
if tool_data.credential_id:
span.set_attribute("tool.credential.id", tool_data.credential_id)
if tool_data.tool_configurations:
span.set_attribute("tool.config", json.dumps(tool_data.tool_configurations, ensure_ascii=False))

View File

@ -0,0 +1,169 @@
"""
Observability layer for GraphEngine.
This layer creates OpenTelemetry spans for node execution, enabling distributed
tracing of workflow execution. It establishes OTel context during node execution
so that automatic instrumentation (HTTP requests, DB queries, etc.) automatically
associates with the node span.
"""
import logging
from dataclasses import dataclass
from typing import cast, final
from opentelemetry import context as context_api
from opentelemetry.trace import Span, SpanKind, Tracer, get_tracer, set_span_in_context
from typing_extensions import override
from configs import dify_config
from core.workflow.enums import NodeType
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_engine.layers.node_parsers import (
DefaultNodeOTelParser,
NodeOTelParser,
ToolNodeOTelParser,
)
from core.workflow.nodes.base.node import Node
from extensions.otel.runtime import is_instrument_flag_enabled
logger = logging.getLogger(__name__)
@dataclass(slots=True)
class _NodeSpanContext:
span: "Span"
token: object
@final
class ObservabilityLayer(GraphEngineLayer):
"""
Layer that creates OpenTelemetry spans for node execution.
This layer:
- Creates a span when a node starts execution
- Establishes OTel context so automatic instrumentation associates with the span
- Sets complete attributes and status when node execution ends
"""
def __init__(self) -> None:
super().__init__()
self._node_contexts: dict[str, _NodeSpanContext] = {}
self._parsers: dict[NodeType, NodeOTelParser] = {}
self._default_parser: NodeOTelParser = cast(NodeOTelParser, DefaultNodeOTelParser())
self._is_disabled: bool = False
self._tracer: Tracer | None = None
self._build_parser_registry()
self._init_tracer()
def _init_tracer(self) -> None:
"""Initialize OpenTelemetry tracer in constructor."""
if not (dify_config.ENABLE_OTEL or is_instrument_flag_enabled()):
self._is_disabled = True
return
try:
self._tracer = get_tracer(__name__)
except Exception as e:
logger.warning("Failed to get OpenTelemetry tracer: %s", e)
self._is_disabled = True
def _build_parser_registry(self) -> None:
"""Initialize parser registry for node types."""
self._parsers = {
NodeType.TOOL: ToolNodeOTelParser(),
}
def _get_parser(self, node: Node) -> NodeOTelParser:
node_type = getattr(node, "node_type", None)
if isinstance(node_type, NodeType):
return self._parsers.get(node_type, self._default_parser)
return self._default_parser
@override
def on_graph_start(self) -> None:
"""Called when graph execution starts."""
self._node_contexts.clear()
@override
def on_node_run_start(self, node: Node) -> None:
"""
Called when a node starts execution.
Creates a span and establishes OTel context for automatic instrumentation.
"""
if self._is_disabled:
return
try:
if not self._tracer:
return
execution_id = node.execution_id
if not execution_id:
return
parent_context = context_api.get_current()
span = self._tracer.start_span(
f"{node.title}",
kind=SpanKind.INTERNAL,
context=parent_context,
)
new_context = set_span_in_context(span)
token = context_api.attach(new_context)
self._node_contexts[execution_id] = _NodeSpanContext(span=span, token=token)
except Exception as e:
logger.warning("Failed to create OpenTelemetry span for node %s: %s", node.id, e)
@override
def on_node_run_end(self, node: Node, error: Exception | None) -> None:
"""
Called when a node finishes execution.
Sets complete attributes, records exceptions, and ends the span.
"""
if self._is_disabled:
return
try:
execution_id = node.execution_id
if not execution_id:
return
node_context = self._node_contexts.get(execution_id)
if not node_context:
return
span = node_context.span
parser = self._get_parser(node)
try:
parser.parse(node=node, span=span, error=error)
span.end()
finally:
token = node_context.token
if token is not None:
try:
context_api.detach(token)
except Exception:
logger.warning("Failed to detach OpenTelemetry token: %s", token)
self._node_contexts.pop(execution_id, None)
except Exception as e:
logger.warning("Failed to end OpenTelemetry span for node %s: %s", node.id, e)
@override
def on_event(self, event) -> None:
"""Not used in this layer."""
pass
@override
def on_graph_end(self, error: Exception | None) -> None:
"""Called when graph execution ends."""
if self._node_contexts:
logger.warning(
"ObservabilityLayer: %d node spans were not properly ended",
len(self._node_contexts),
)
self._node_contexts.clear()

View File

@ -9,6 +9,7 @@ import contextvars
import queue
import threading
import time
from collections.abc import Sequence
from datetime import datetime
from typing import final
from uuid import uuid4
@ -17,6 +18,7 @@ from flask import Flask
from typing_extensions import override
from core.workflow.graph import Graph
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent
from core.workflow.nodes.base.node import Node
from libs.flask_utils import preserve_flask_contexts
@ -39,6 +41,7 @@ class Worker(threading.Thread):
ready_queue: ReadyQueue,
event_queue: queue.Queue[GraphNodeEventBase],
graph: Graph,
layers: Sequence[GraphEngineLayer],
worker_id: int = 0,
flask_app: Flask | None = None,
context_vars: contextvars.Context | None = None,
@ -50,6 +53,7 @@ class Worker(threading.Thread):
ready_queue: Ready queue containing node IDs ready for execution
event_queue: Queue for pushing execution events
graph: Graph containing nodes to execute
layers: Graph engine layers for node execution hooks
worker_id: Unique identifier for this worker
flask_app: Optional Flask application for context preservation
context_vars: Optional context variables to preserve in worker thread
@ -63,6 +67,7 @@ class Worker(threading.Thread):
self._context_vars = context_vars
self._stop_event = threading.Event()
self._last_task_time = time.time()
self._layers = layers if layers is not None else []
def stop(self) -> None:
"""Signal the worker to stop processing."""
@ -122,20 +127,51 @@ class Worker(threading.Thread):
Args:
node: The node instance to execute
"""
# Execute the node with preserved context if Flask app is provided
node.ensure_execution_id()
error: Exception | None = None
if self._flask_app and self._context_vars:
with preserve_flask_contexts(
flask_app=self._flask_app,
context_vars=self._context_vars,
):
# Execute the node
self._invoke_node_run_start_hooks(node)
try:
node_events = node.run()
for event in node_events:
self._event_queue.put(event)
except Exception as exc:
error = exc
raise
finally:
self._invoke_node_run_end_hooks(node, error)
else:
self._invoke_node_run_start_hooks(node)
try:
node_events = node.run()
for event in node_events:
# Forward event to dispatcher immediately for streaming
self._event_queue.put(event)
else:
# Execute without context preservation
node_events = node.run()
for event in node_events:
# Forward event to dispatcher immediately for streaming
self._event_queue.put(event)
except Exception as exc:
error = exc
raise
finally:
self._invoke_node_run_end_hooks(node, error)
def _invoke_node_run_start_hooks(self, node: Node) -> None:
"""Invoke on_node_run_start hooks for all layers."""
for layer in self._layers:
try:
layer.on_node_run_start(node)
except Exception:
# Silently ignore layer errors to prevent disrupting node execution
continue
def _invoke_node_run_end_hooks(self, node: Node, error: Exception | None) -> None:
"""Invoke on_node_run_end hooks for all layers."""
for layer in self._layers:
try:
layer.on_node_run_end(node, error)
except Exception:
# Silently ignore layer errors to prevent disrupting node execution
continue

View File

@ -14,6 +14,7 @@ from configs import dify_config
from core.workflow.graph import Graph
from core.workflow.graph_events import GraphNodeEventBase
from ..layers.base import GraphEngineLayer
from ..ready_queue import ReadyQueue
from ..worker import Worker
@ -39,6 +40,7 @@ class WorkerPool:
ready_queue: ReadyQueue,
event_queue: queue.Queue[GraphNodeEventBase],
graph: Graph,
layers: list[GraphEngineLayer],
flask_app: "Flask | None" = None,
context_vars: "Context | None" = None,
min_workers: int | None = None,
@ -53,6 +55,7 @@ class WorkerPool:
ready_queue: Ready queue for nodes ready for execution
event_queue: Queue for worker events
graph: The workflow graph
layers: Graph engine layers for node execution hooks
flask_app: Optional Flask app for context preservation
context_vars: Optional context variables
min_workers: Minimum number of workers
@ -65,6 +68,7 @@ class WorkerPool:
self._graph = graph
self._flask_app = flask_app
self._context_vars = context_vars
self._layers = layers
# Scaling parameters with defaults
self._min_workers = min_workers or dify_config.GRAPH_ENGINE_MIN_WORKERS
@ -144,6 +148,7 @@ class WorkerPool:
ready_queue=self._ready_queue,
event_queue=self._event_queue,
graph=self._graph,
layers=self._layers,
worker_id=worker_id,
flask_app=self._flask_app,
context_vars=self._context_vars,

View File

@ -59,7 +59,7 @@ class OutputVariableEntity(BaseModel):
"""
variable: str
value_type: OutputVariableType
value_type: OutputVariableType = OutputVariableType.ANY
value_selector: Sequence[str]
@field_validator("value_type", mode="before")

View File

@ -244,6 +244,15 @@ class Node(Generic[NodeDataT]):
def graph_init_params(self) -> "GraphInitParams":
return self._graph_init_params
@property
def execution_id(self) -> str:
return self._node_execution_id
def ensure_execution_id(self) -> str:
if not self._node_execution_id:
self._node_execution_id = str(uuid4())
return self._node_execution_id
def _hydrate_node_data(self, data: Mapping[str, Any]) -> NodeDataT:
return cast(NodeDataT, self._node_data_type.model_validate(data))
@ -256,14 +265,12 @@ class Node(Generic[NodeDataT]):
raise NotImplementedError
def run(self) -> Generator[GraphNodeEventBase, None, None]:
# Generate a single node execution ID to use for all events
if not self._node_execution_id:
self._node_execution_id = str(uuid4())
execution_id = self.ensure_execution_id()
self._start_at = naive_utc_now()
# Create and push start event with required fields
start_event = NodeRunStartedEvent(
id=self._node_execution_id,
id=execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.title,
@ -321,7 +328,7 @@ class Node(Generic[NodeDataT]):
if isinstance(event, NodeEventBase): # pyright: ignore[reportUnnecessaryIsInstance]
yield self._dispatch(event)
elif isinstance(event, GraphNodeEventBase) and not event.in_iteration_id and not event.in_loop_id: # pyright: ignore[reportUnnecessaryIsInstance]
event.id = self._node_execution_id
event.id = self.execution_id
yield event
else:
yield event
@ -333,7 +340,7 @@ class Node(Generic[NodeDataT]):
error_type="WorkflowNodeError",
)
yield NodeRunFailedEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
start_at=self._start_at,
@ -512,7 +519,7 @@ class Node(Generic[NodeDataT]):
match result.status:
case WorkflowNodeExecutionStatus.FAILED:
return NodeRunFailedEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self.id,
node_type=self.node_type,
start_at=self._start_at,
@ -521,7 +528,7 @@ class Node(Generic[NodeDataT]):
)
case WorkflowNodeExecutionStatus.SUCCEEDED:
return NodeRunSucceededEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self.id,
node_type=self.node_type,
start_at=self._start_at,
@ -537,7 +544,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: StreamChunkEvent) -> NodeRunStreamChunkEvent:
return NodeRunStreamChunkEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
selector=event.selector,
@ -550,7 +557,7 @@ class Node(Generic[NodeDataT]):
match event.node_run_result.status:
case WorkflowNodeExecutionStatus.SUCCEEDED:
return NodeRunSucceededEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
start_at=self._start_at,
@ -558,7 +565,7 @@ class Node(Generic[NodeDataT]):
)
case WorkflowNodeExecutionStatus.FAILED:
return NodeRunFailedEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
start_at=self._start_at,
@ -573,7 +580,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: PauseRequestedEvent) -> NodeRunPauseRequestedEvent:
return NodeRunPauseRequestedEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.PAUSED),
@ -583,7 +590,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: AgentLogEvent) -> NodeRunAgentLogEvent:
return NodeRunAgentLogEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
message_id=event.message_id,
@ -599,7 +606,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: LoopStartedEvent) -> NodeRunLoopStartedEvent:
return NodeRunLoopStartedEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.node_data.title,
@ -612,7 +619,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: LoopNextEvent) -> NodeRunLoopNextEvent:
return NodeRunLoopNextEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.node_data.title,
@ -623,7 +630,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: LoopSucceededEvent) -> NodeRunLoopSucceededEvent:
return NodeRunLoopSucceededEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.node_data.title,
@ -637,7 +644,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: LoopFailedEvent) -> NodeRunLoopFailedEvent:
return NodeRunLoopFailedEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.node_data.title,
@ -652,7 +659,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: IterationStartedEvent) -> NodeRunIterationStartedEvent:
return NodeRunIterationStartedEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.node_data.title,
@ -665,7 +672,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: IterationNextEvent) -> NodeRunIterationNextEvent:
return NodeRunIterationNextEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.node_data.title,
@ -676,7 +683,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: IterationSucceededEvent) -> NodeRunIterationSucceededEvent:
return NodeRunIterationSucceededEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.node_data.title,
@ -690,7 +697,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: IterationFailedEvent) -> NodeRunIterationFailedEvent:
return NodeRunIterationFailedEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.node_data.title,
@ -705,7 +712,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: RunRetrieverResourceEvent) -> NodeRunRetrieverResourceEvent:
return NodeRunRetrieverResourceEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
retriever_resources=event.retriever_resources,

View File

@ -412,16 +412,20 @@ class Executor:
body_string += f"--{boundary}\r\n"
body_string += f'Content-Disposition: form-data; name="{key}"\r\n\r\n'
# decode content safely
try:
body_string += content.decode("utf-8")
except UnicodeDecodeError:
body_string += content.decode("utf-8", errors="replace")
body_string += "\r\n"
# Do not decode binary content; use a placeholder with file metadata instead.
# Includes filename, size, and MIME type for better logging context.
body_string += (
f"<file_content_binary: '{file_entry[1][0] or 'unknown'}', "
f"type='{file_entry[1][2] if len(file_entry[1]) > 2 else 'unknown'}', "
f"size={len(content)} bytes>\r\n"
)
body_string += f"--{boundary}--\r\n"
elif self.node_data.body:
if self.content:
# If content is bytes, do not decode it; show a placeholder with size.
# Provides content size information for binary data without exposing the raw bytes.
if isinstance(self.content, bytes):
body_string = self.content.decode("utf-8", errors="replace")
body_string = f"<binary_content: size={len(self.content)} bytes>"
else:
body_string = self.content
elif self.data and self.node_data.body.type == "x-www-form-urlencoded":

View File

@ -334,6 +334,7 @@ class LLMNode(Node[LLMNodeData]):
inputs=node_inputs,
process_data=process_data,
error_type=type(e).__name__,
llm_usage=usage,
)
)
except Exception as e:
@ -344,6 +345,8 @@ class LLMNode(Node[LLMNodeData]):
error=str(e),
inputs=node_inputs,
process_data=process_data,
error_type=type(e).__name__,
llm_usage=usage,
)
)
@ -694,7 +697,7 @@ class LLMNode(Node[LLMNodeData]):
).all()
if attachments_with_bindings:
for _, upload_file in attachments_with_bindings:
attchment_info = File(
attachment_info = File(
id=upload_file.id,
filename=upload_file.name,
extension="." + upload_file.extension,
@ -708,7 +711,7 @@ class LLMNode(Node[LLMNodeData]):
storage_key=upload_file.key,
url=sign_upload_file(upload_file.id, upload_file.extension),
)
context_files.append(attchment_info)
context_files.append(attachment_info)
yield RunRetrieverResourceEvent(
retriever_resources=original_retriever_resource,
context=context_str.strip(),

View File

@ -221,6 +221,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,
error=str(e),
error_type=type(e).__name__,
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,

View File

@ -1,14 +1,22 @@
import logging
from collections.abc import Mapping
from typing import Any
from core.file import FileTransferMethod
from core.variables.types import SegmentType
from core.variables.variables import FileVariable
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import NodeExecutionType, NodeType
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
from factories import file_factory
from factories.variable_factory import build_segment_with_type
from .entities import ContentType, WebhookData
logger = logging.getLogger(__name__)
class TriggerWebhookNode(Node[WebhookData]):
node_type = NodeType.TRIGGER_WEBHOOK
@ -60,6 +68,34 @@ class TriggerWebhookNode(Node[WebhookData]):
outputs=outputs,
)
def generate_file_var(self, param_name: str, file: dict):
related_id = file.get("related_id")
transfer_method_value = file.get("transfer_method")
if transfer_method_value:
transfer_method = FileTransferMethod.value_of(transfer_method_value)
match transfer_method:
case FileTransferMethod.LOCAL_FILE | FileTransferMethod.REMOTE_URL:
file["upload_file_id"] = related_id
case FileTransferMethod.TOOL_FILE:
file["tool_file_id"] = related_id
case FileTransferMethod.DATASOURCE_FILE:
file["datasource_file_id"] = related_id
try:
file_obj = file_factory.build_from_mapping(
mapping=file,
tenant_id=self.tenant_id,
)
file_segment = build_segment_with_type(SegmentType.FILE, file_obj)
return FileVariable(name=param_name, value=file_segment.value, selector=[self.id, param_name])
except ValueError:
logger.error(
"Failed to build FileVariable for webhook file parameter %s",
param_name,
exc_info=True,
)
return None
def _extract_configured_outputs(self, webhook_inputs: dict[str, Any]) -> dict[str, Any]:
"""Extract outputs based on node configuration from webhook inputs."""
outputs = {}
@ -107,18 +143,33 @@ class TriggerWebhookNode(Node[WebhookData]):
outputs[param_name] = str(webhook_data.get("body", {}).get("raw", ""))
continue
elif self.node_data.content_type == ContentType.BINARY:
outputs[param_name] = webhook_data.get("body", {}).get("raw", b"")
raw_data: dict = webhook_data.get("body", {}).get("raw", {})
file_var = self.generate_file_var(param_name, raw_data)
if file_var:
outputs[param_name] = file_var
else:
outputs[param_name] = raw_data
continue
if param_type == "file":
# Get File object (already processed by webhook controller)
file_obj = webhook_data.get("files", {}).get(param_name)
outputs[param_name] = file_obj
files = webhook_data.get("files", {})
if files and isinstance(files, dict):
file = files.get(param_name)
if file and isinstance(file, dict):
file_var = self.generate_file_var(param_name, file)
if file_var:
outputs[param_name] = file_var
else:
outputs[param_name] = files
else:
outputs[param_name] = files
else:
outputs[param_name] = files
else:
# Get regular body parameter
outputs[param_name] = webhook_data.get("body", {}).get(param_name)
# Include raw webhook data for debugging/advanced use
outputs["_webhook_raw"] = webhook_data
return outputs

View File

@ -14,7 +14,7 @@ from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer
from core.workflow.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer, ObservabilityLayer
from core.workflow.graph_engine.protocols.command_channel import CommandChannel
from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent
from core.workflow.nodes import NodeType
@ -23,6 +23,7 @@ from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
from extensions.otel.runtime import is_instrument_flag_enabled
from factories import file_factory
from models.enums import UserFrom
from models.workflow import Workflow
@ -98,6 +99,10 @@ class WorkflowEntry:
)
self.graph_engine.layer(limits_layer)
# Add observability layer when OTel is enabled
if dify_config.ENABLE_OTEL or is_instrument_flag_enabled():
self.graph_engine.layer(ObservabilityLayer())
def run(self) -> Generator[GraphEngineEvent, None, None]:
graph_engine = self.graph_engine

View File

@ -15,4 +15,5 @@ def handle(sender: Dataset, **kwargs):
dataset.index_struct,
dataset.collection_binding_id,
dataset.doc_form,
dataset.pipeline_id,
)

View File

@ -9,11 +9,21 @@ FILES_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, HEADER_NAME_CSRF_TOKEN)
EXPOSED_HEADERS: tuple[str, ...] = ("X-Version", "X-Env", "X-Trace-Id")
def init_app(app: DifyApp):
# register blueprint routers
def _apply_cors_once(bp, /, **cors_kwargs):
"""Make CORS idempotent so blueprints can be reused across multiple app instances."""
if getattr(bp, "_dify_cors_applied", False):
return
from flask_cors import CORS
CORS(bp, **cors_kwargs)
bp._dify_cors_applied = True
def init_app(app: DifyApp):
# register blueprint routers
from controllers.console import bp as console_app_bp
from controllers.files import bp as files_bp
from controllers.inner_api import bp as inner_api_bp
@ -22,7 +32,7 @@ def init_app(app: DifyApp):
from controllers.trigger import bp as trigger_bp
from controllers.web import bp as web_bp
CORS(
_apply_cors_once(
service_api_bp,
allow_headers=list(SERVICE_API_HEADERS),
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
@ -30,7 +40,7 @@ def init_app(app: DifyApp):
)
app.register_blueprint(service_api_bp)
CORS(
_apply_cors_once(
web_bp,
resources={r"/*": {"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS}},
supports_credentials=True,
@ -40,7 +50,7 @@ def init_app(app: DifyApp):
)
app.register_blueprint(web_bp)
CORS(
_apply_cors_once(
console_app_bp,
resources={r"/*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}},
supports_credentials=True,
@ -50,7 +60,7 @@ def init_app(app: DifyApp):
)
app.register_blueprint(console_app_bp)
CORS(
_apply_cors_once(
files_bp,
allow_headers=list(FILES_HEADERS),
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
@ -62,7 +72,7 @@ def init_app(app: DifyApp):
app.register_blueprint(mcp_bp)
# Register trigger blueprint with CORS for webhook calls
CORS(
_apply_cors_once(
trigger_bp,
allow_headers=["Content-Type", "Authorization", "X-App-Code"],
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH", "HEAD"],

View File

@ -0,0 +1,74 @@
"""
Logstore extension for Dify application.
This extension initializes the logstore (Aliyun SLS) on application startup,
creating necessary projects, logstores, and indexes if they don't exist.
"""
import logging
import os
from dotenv import load_dotenv
from dify_app import DifyApp
logger = logging.getLogger(__name__)
def is_enabled() -> bool:
"""
Check if logstore extension is enabled.
Returns:
True if all required Aliyun SLS environment variables are set, False otherwise
"""
# Load environment variables from .env file
load_dotenv()
required_vars = [
"ALIYUN_SLS_ACCESS_KEY_ID",
"ALIYUN_SLS_ACCESS_KEY_SECRET",
"ALIYUN_SLS_ENDPOINT",
"ALIYUN_SLS_REGION",
"ALIYUN_SLS_PROJECT_NAME",
]
all_set = all(os.environ.get(var) for var in required_vars)
if not all_set:
logger.info("Logstore extension disabled: required Aliyun SLS environment variables not set")
return all_set
def init_app(app: DifyApp):
"""
Initialize logstore on application startup.
This function:
1. Creates Aliyun SLS project if it doesn't exist
2. Creates logstores (workflow_execution, workflow_node_execution) if they don't exist
3. Creates indexes with field configurations based on PostgreSQL table structures
This operation is idempotent and only executes once during application startup.
Args:
app: The Dify application instance
"""
try:
from extensions.logstore.aliyun_logstore import AliyunLogStore
logger.info("Initializing logstore...")
# Create logstore client and initialize project/logstores/indexes
logstore_client = AliyunLogStore()
logstore_client.init_project_logstore()
# Attach to app for potential later use
app.extensions["logstore"] = logstore_client
logger.info("Logstore initialized successfully")
except Exception:
logger.exception("Failed to initialize logstore")
# Don't raise - allow application to continue even if logstore init fails
# This ensures that the application can still run if logstore is misconfigured

View File

@ -0,0 +1,7 @@
from core.db.session_factory import configure_session_factory
from extensions.ext_database import db
def init_app(app):
with app.app_context():
configure_session_factory(db.engine)

View File

View File

@ -0,0 +1,890 @@
import logging
import os
import threading
import time
from collections.abc import Sequence
from typing import Any
import sqlalchemy as sa
from aliyun.log import ( # type: ignore[import-untyped]
GetLogsRequest,
IndexConfig,
IndexKeyConfig,
IndexLineConfig,
LogClient,
LogItem,
PutLogsRequest,
)
from aliyun.log.auth import AUTH_VERSION_4 # type: ignore[import-untyped]
from aliyun.log.logexception import LogException # type: ignore[import-untyped]
from dotenv import load_dotenv
from sqlalchemy.orm import DeclarativeBase
from configs import dify_config
from extensions.logstore.aliyun_logstore_pg import AliyunLogStorePG
logger = logging.getLogger(__name__)
class AliyunLogStore:
"""
Singleton class for Aliyun SLS LogStore operations.
Ensures only one instance exists to prevent multiple PG connection pools.
"""
_instance: "AliyunLogStore | None" = None
_initialized: bool = False
# Track delayed PG connection for newly created projects
_pg_connection_timer: threading.Timer | None = None
_pg_connection_delay: int = 90 # delay seconds
# Default tokenizer for text/json fields and full-text index
# Common delimiters: comma, space, quotes, punctuation, operators, brackets, special chars
DEFAULT_TOKEN_LIST = [
",",
" ",
'"',
'"',
";",
"=",
"(",
")",
"[",
"]",
"{",
"}",
"?",
"@",
"&",
"<",
">",
"/",
":",
"\n",
"\t",
]
def __new__(cls) -> "AliyunLogStore":
"""Implement singleton pattern."""
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
project_des = "dify"
workflow_execution_logstore = "workflow_execution"
workflow_node_execution_logstore = "workflow_node_execution"
@staticmethod
def _sqlalchemy_type_to_logstore_type(column: Any) -> str:
"""
Map SQLAlchemy column type to Aliyun LogStore index type.
Args:
column: SQLAlchemy column object
Returns:
LogStore index type: 'text', 'long', 'double', or 'json'
"""
column_type = column.type
# Integer types -> long
if isinstance(column_type, (sa.Integer, sa.BigInteger, sa.SmallInteger)):
return "long"
# Float types -> double
if isinstance(column_type, (sa.Float, sa.Numeric)):
return "double"
# String and Text types -> text
if isinstance(column_type, (sa.String, sa.Text)):
return "text"
# DateTime -> text (stored as ISO format string in logstore)
if isinstance(column_type, sa.DateTime):
return "text"
# Boolean -> long (stored as 0/1)
if isinstance(column_type, sa.Boolean):
return "long"
# JSON -> json
if isinstance(column_type, sa.JSON):
return "json"
# Default to text for unknown types
return "text"
@staticmethod
def _generate_index_keys_from_model(model_class: type[DeclarativeBase]) -> dict[str, IndexKeyConfig]:
"""
Automatically generate LogStore field index configuration from SQLAlchemy model.
This method introspects the SQLAlchemy model's column definitions and creates
corresponding LogStore index configurations. When the PG schema is updated via
Flask-Migrate, this method will automatically pick up the new fields on next startup.
Args:
model_class: SQLAlchemy model class (e.g., WorkflowRun, WorkflowNodeExecutionModel)
Returns:
Dictionary mapping field names to IndexKeyConfig objects
"""
index_keys = {}
# Iterate over all mapped columns in the model
if hasattr(model_class, "__mapper__"):
for column_name, column_property in model_class.__mapper__.columns.items():
# Skip relationship properties and other non-column attributes
if not hasattr(column_property, "type"):
continue
# Map SQLAlchemy type to LogStore type
logstore_type = AliyunLogStore._sqlalchemy_type_to_logstore_type(column_property)
# Create index configuration
# - text fields: case_insensitive for better search, with tokenizer and Chinese support
# - all fields: doc_value=True for analytics
if logstore_type == "text":
index_keys[column_name] = IndexKeyConfig(
index_type="text",
case_sensitive=False,
doc_value=True,
token_list=AliyunLogStore.DEFAULT_TOKEN_LIST,
chinese=True,
)
else:
index_keys[column_name] = IndexKeyConfig(index_type=logstore_type, doc_value=True)
# Add log_version field (not in PG model, but used in logstore for versioning)
index_keys["log_version"] = IndexKeyConfig(index_type="long", doc_value=True)
return index_keys
def __init__(self) -> None:
# Skip initialization if already initialized (singleton pattern)
if self.__class__._initialized:
return
load_dotenv()
self.access_key_id: str = os.environ.get("ALIYUN_SLS_ACCESS_KEY_ID", "")
self.access_key_secret: str = os.environ.get("ALIYUN_SLS_ACCESS_KEY_SECRET", "")
self.endpoint: str = os.environ.get("ALIYUN_SLS_ENDPOINT", "")
self.region: str = os.environ.get("ALIYUN_SLS_REGION", "")
self.project_name: str = os.environ.get("ALIYUN_SLS_PROJECT_NAME", "")
self.logstore_ttl: int = int(os.environ.get("ALIYUN_SLS_LOGSTORE_TTL", 365))
self.log_enabled: bool = os.environ.get("SQLALCHEMY_ECHO", "false").lower() == "true"
self.pg_mode_enabled: bool = os.environ.get("LOGSTORE_PG_MODE_ENABLED", "true").lower() == "true"
# Initialize SDK client
self.client = LogClient(
self.endpoint, self.access_key_id, self.access_key_secret, auth_version=AUTH_VERSION_4, region=self.region
)
# Append Dify identification to the existing user agent
original_user_agent = self.client._user_agent # pyright: ignore[reportPrivateUsage]
dify_version = dify_config.project.version
enhanced_user_agent = f"Dify,Dify-{dify_version},{original_user_agent}"
self.client.set_user_agent(enhanced_user_agent)
# PG client will be initialized in init_project_logstore
self._pg_client: AliyunLogStorePG | None = None
self._use_pg_protocol: bool = False
self.__class__._initialized = True
@property
def supports_pg_protocol(self) -> bool:
"""Check if PG protocol is supported and enabled."""
return self._use_pg_protocol
def _attempt_pg_connection_init(self) -> bool:
"""
Attempt to initialize PG connection.
This method tries to establish PG connection and performs necessary checks.
It's used both for immediate connection (existing projects) and delayed connection (new projects).
Returns:
True if PG connection was successfully established, False otherwise.
"""
if not self.pg_mode_enabled or not self._pg_client:
return False
try:
self._use_pg_protocol = self._pg_client.init_connection()
if self._use_pg_protocol:
logger.info("Successfully connected to project %s using PG protocol", self.project_name)
# Check if scan_index is enabled for all logstores
self._check_and_disable_pg_if_scan_index_disabled()
return True
else:
logger.info("PG connection failed for project %s. Will use SDK mode.", self.project_name)
return False
except Exception as e:
logger.warning(
"Failed to establish PG connection for project %s: %s. Will use SDK mode.",
self.project_name,
str(e),
)
self._use_pg_protocol = False
return False
def _delayed_pg_connection_init(self) -> None:
"""
Delayed initialization of PG connection for newly created projects.
This method is called by a background timer 3 minutes after project creation.
"""
# Double check conditions in case state changed
if self._use_pg_protocol:
return
logger.info(
"Attempting delayed PG connection for newly created project %s ...",
self.project_name,
)
self._attempt_pg_connection_init()
self.__class__._pg_connection_timer = None
def init_project_logstore(self):
"""
Initialize project, logstore, index, and PG connection.
This method should be called once during application startup to ensure
all required resources exist and connections are established.
"""
# Step 1: Ensure project and logstore exist
project_is_new = False
if not self.is_project_exist():
self.create_project()
project_is_new = True
self.create_logstore_if_not_exist()
# Step 2: Initialize PG client and connection (if enabled)
if not self.pg_mode_enabled:
logger.info("PG mode is disabled. Will use SDK mode.")
return
# Create PG client if not already created
if self._pg_client is None:
logger.info("Initializing PG client for project %s...", self.project_name)
self._pg_client = AliyunLogStorePG(
self.access_key_id, self.access_key_secret, self.endpoint, self.project_name
)
# Step 3: Establish PG connection based on project status
if project_is_new:
# For newly created projects, schedule delayed PG connection
self._use_pg_protocol = False
logger.info(
"Project %s is newly created. Will use SDK mode and schedule PG connection attempt in %d seconds.",
self.project_name,
self.__class__._pg_connection_delay,
)
if self.__class__._pg_connection_timer is not None:
self.__class__._pg_connection_timer.cancel()
self.__class__._pg_connection_timer = threading.Timer(
self.__class__._pg_connection_delay,
self._delayed_pg_connection_init,
)
self.__class__._pg_connection_timer.daemon = True # Don't block app shutdown
self.__class__._pg_connection_timer.start()
else:
# For existing projects, attempt PG connection immediately
logger.info("Project %s already exists. Attempting PG connection...", self.project_name)
self._attempt_pg_connection_init()
def _check_and_disable_pg_if_scan_index_disabled(self) -> None:
"""
Check if scan_index is enabled for all logstores.
If any logstore has scan_index=false, disable PG protocol.
This is necessary because PG protocol requires scan_index to be enabled.
"""
logstore_name_list = [
AliyunLogStore.workflow_execution_logstore,
AliyunLogStore.workflow_node_execution_logstore,
]
for logstore_name in logstore_name_list:
existing_config = self.get_existing_index_config(logstore_name)
if existing_config and not existing_config.scan_index:
logger.info(
"Logstore %s has scan_index=false, USE SDK mode for read/write operations. "
"PG protocol requires scan_index to be enabled.",
logstore_name,
)
self._use_pg_protocol = False
# Close PG connection if it was initialized
if self._pg_client:
self._pg_client.close()
self._pg_client = None
return
def is_project_exist(self) -> bool:
try:
self.client.get_project(self.project_name)
return True
except Exception as e:
if e.args[0] == "ProjectNotExist":
return False
else:
raise e
def create_project(self):
try:
self.client.create_project(self.project_name, AliyunLogStore.project_des)
logger.info("Project %s created successfully", self.project_name)
except LogException as e:
logger.exception(
"Failed to create project %s: errorCode=%s, errorMessage=%s, requestId=%s",
self.project_name,
e.get_error_code(),
e.get_error_message(),
e.get_request_id(),
)
raise
def is_logstore_exist(self, logstore_name: str) -> bool:
try:
_ = self.client.get_logstore(self.project_name, logstore_name)
return True
except Exception as e:
if e.args[0] == "LogStoreNotExist":
return False
else:
raise e
def create_logstore_if_not_exist(self) -> None:
logstore_name_list = [
AliyunLogStore.workflow_execution_logstore,
AliyunLogStore.workflow_node_execution_logstore,
]
for logstore_name in logstore_name_list:
if not self.is_logstore_exist(logstore_name):
try:
self.client.create_logstore(
project_name=self.project_name, logstore_name=logstore_name, ttl=self.logstore_ttl
)
logger.info("logstore %s created successfully", logstore_name)
except LogException as e:
logger.exception(
"Failed to create logstore %s: errorCode=%s, errorMessage=%s, requestId=%s",
logstore_name,
e.get_error_code(),
e.get_error_message(),
e.get_request_id(),
)
raise
# Ensure index contains all Dify-required fields
# This intelligently merges with existing config, preserving custom indexes
self.ensure_index_config(logstore_name)
def is_index_exist(self, logstore_name: str) -> bool:
try:
_ = self.client.get_index_config(self.project_name, logstore_name)
return True
except Exception as e:
if e.args[0] == "IndexConfigNotExist":
return False
else:
raise e
def get_existing_index_config(self, logstore_name: str) -> IndexConfig | None:
"""
Get existing index configuration from logstore.
Args:
logstore_name: Name of the logstore
Returns:
IndexConfig object if index exists, None otherwise
"""
try:
response = self.client.get_index_config(self.project_name, logstore_name)
return response.get_index_config()
except Exception as e:
if e.args[0] == "IndexConfigNotExist":
return None
else:
logger.exception("Failed to get index config for logstore %s", logstore_name)
raise e
def _get_workflow_execution_index_keys(self) -> dict[str, IndexKeyConfig]:
"""
Get field index configuration for workflow_execution logstore.
This method automatically generates index configuration from the WorkflowRun SQLAlchemy model.
When the PG schema is updated via Flask-Migrate, the index configuration will be automatically
updated on next application startup.
"""
from models.workflow import WorkflowRun
index_keys = self._generate_index_keys_from_model(WorkflowRun)
# Add custom fields that are in logstore but not in PG model
# These fields are added by the repository layer
index_keys["error_message"] = IndexKeyConfig(
index_type="text",
case_sensitive=False,
doc_value=True,
token_list=self.DEFAULT_TOKEN_LIST,
chinese=True,
) # Maps to 'error' in PG
index_keys["started_at"] = IndexKeyConfig(
index_type="text",
case_sensitive=False,
doc_value=True,
token_list=self.DEFAULT_TOKEN_LIST,
chinese=True,
) # Maps to 'created_at' in PG
logger.info("Generated %d index keys for workflow_execution from WorkflowRun model", len(index_keys))
return index_keys
def _get_workflow_node_execution_index_keys(self) -> dict[str, IndexKeyConfig]:
"""
Get field index configuration for workflow_node_execution logstore.
This method automatically generates index configuration from the WorkflowNodeExecutionModel.
When the PG schema is updated via Flask-Migrate, the index configuration will be automatically
updated on next application startup.
"""
from models.workflow import WorkflowNodeExecutionModel
index_keys = self._generate_index_keys_from_model(WorkflowNodeExecutionModel)
logger.debug(
"Generated %d index keys for workflow_node_execution from WorkflowNodeExecutionModel", len(index_keys)
)
return index_keys
def _get_index_config(self, logstore_name: str) -> IndexConfig:
"""
Get index configuration for the specified logstore.
Args:
logstore_name: Name of the logstore
Returns:
IndexConfig object with line and field indexes
"""
# Create full-text index (line config) with tokenizer
line_config = IndexLineConfig(token_list=self.DEFAULT_TOKEN_LIST, case_sensitive=False, chinese=True)
# Get field index configuration based on logstore name
field_keys = {}
if logstore_name == AliyunLogStore.workflow_execution_logstore:
field_keys = self._get_workflow_execution_index_keys()
elif logstore_name == AliyunLogStore.workflow_node_execution_logstore:
field_keys = self._get_workflow_node_execution_index_keys()
# key_config_list should be a dict, not a list
# Create index config with both line and field indexes
return IndexConfig(line_config=line_config, key_config_list=field_keys, scan_index=True)
def create_index(self, logstore_name: str) -> None:
"""
Create index for the specified logstore with both full-text and field indexes.
Field indexes are automatically generated from the corresponding SQLAlchemy model.
"""
index_config = self._get_index_config(logstore_name)
try:
self.client.create_index(self.project_name, logstore_name, index_config)
logger.info(
"index for %s created successfully with %d field indexes",
logstore_name,
len(index_config.key_config_list or {}),
)
except LogException as e:
logger.exception(
"Failed to create index for logstore %s: errorCode=%s, errorMessage=%s, requestId=%s",
logstore_name,
e.get_error_code(),
e.get_error_message(),
e.get_request_id(),
)
raise
def _merge_index_configs(
self, existing_config: IndexConfig, required_keys: dict[str, IndexKeyConfig], logstore_name: str
) -> tuple[IndexConfig, bool]:
"""
Intelligently merge existing index config with Dify's required field indexes.
This method:
1. Preserves all existing field indexes in logstore (including custom fields)
2. Adds missing Dify-required fields
3. Updates fields where type doesn't match (with json/text compatibility)
4. Corrects case mismatches (e.g., if Dify needs 'status' but logstore has 'Status')
Type compatibility rules:
- json and text types are considered compatible (users can manually choose either)
- All other type mismatches will be corrected to match Dify requirements
Note: Logstore is case-sensitive and doesn't allow duplicate fields with different cases.
Case mismatch means: existing field name differs from required name only in case.
Args:
existing_config: Current index configuration from logstore
required_keys: Dify's required field index configurations
logstore_name: Name of the logstore (for logging)
Returns:
Tuple of (merged_config, needs_update)
"""
# key_config_list is already a dict in the SDK
# Make a copy to avoid modifying the original
existing_keys = dict(existing_config.key_config_list) if existing_config.key_config_list else {}
# Track changes
needs_update = False
case_corrections = [] # Fields that need case correction (e.g., 'Status' -> 'status')
missing_fields = []
type_mismatches = []
# First pass: Check for and resolve case mismatches with required fields
# Note: Logstore itself doesn't allow duplicate fields with different cases,
# so we only need to check if the existing case matches the required case
for required_name in required_keys:
lower_name = required_name.lower()
# Find key that matches case-insensitively but not exactly
wrong_case_key = None
for existing_key in existing_keys:
if existing_key.lower() == lower_name and existing_key != required_name:
wrong_case_key = existing_key
break
if wrong_case_key:
# Field exists but with wrong case (e.g., 'Status' when we need 'status')
# Remove the wrong-case key, will be added back with correct case later
case_corrections.append((wrong_case_key, required_name))
del existing_keys[wrong_case_key]
needs_update = True
# Second pass: Check each required field
for required_name, required_config in required_keys.items():
# Check for exact match (case-sensitive)
if required_name in existing_keys:
existing_type = existing_keys[required_name].index_type
required_type = required_config.index_type
# Check if type matches
# Special case: json and text are interchangeable for JSON content fields
# Allow users to manually configure text instead of json (or vice versa) without forcing updates
is_compatible = existing_type == required_type or ({existing_type, required_type} == {"json", "text"})
if not is_compatible:
type_mismatches.append((required_name, existing_type, required_type))
# Update with correct type
existing_keys[required_name] = required_config
needs_update = True
# else: field exists with compatible type, no action needed
else:
# Field doesn't exist (may have been removed in first pass due to case conflict)
missing_fields.append(required_name)
existing_keys[required_name] = required_config
needs_update = True
# Log changes
if missing_fields:
logger.info(
"Logstore %s: Adding %d missing Dify-required fields: %s",
logstore_name,
len(missing_fields),
", ".join(missing_fields[:10]) + ("..." if len(missing_fields) > 10 else ""),
)
if type_mismatches:
logger.info(
"Logstore %s: Fixing %d type mismatches: %s",
logstore_name,
len(type_mismatches),
", ".join([f"{name}({old}->{new})" for name, old, new in type_mismatches[:5]])
+ ("..." if len(type_mismatches) > 5 else ""),
)
if case_corrections:
logger.info(
"Logstore %s: Correcting %d field name cases: %s",
logstore_name,
len(case_corrections),
", ".join([f"'{old}' -> '{new}'" for old, new in case_corrections[:5]])
+ ("..." if len(case_corrections) > 5 else ""),
)
# Create merged config
# key_config_list should be a dict, not a list
# Preserve the original scan_index value - don't force it to True
merged_config = IndexConfig(
line_config=existing_config.line_config
or IndexLineConfig(token_list=self.DEFAULT_TOKEN_LIST, case_sensitive=False, chinese=True),
key_config_list=existing_keys,
scan_index=existing_config.scan_index,
)
return merged_config, needs_update
def ensure_index_config(self, logstore_name: str) -> None:
"""
Ensure index configuration includes all Dify-required fields.
This method intelligently manages index configuration:
1. If index doesn't exist, create it with Dify's required fields
2. If index exists:
- Check if all Dify-required fields are present
- Check if field types match requirements
- Only update if fields are missing or types are incorrect
- Preserve any additional custom index configurations
This approach allows users to add their own custom indexes without being overwritten.
"""
# Get Dify's required field indexes
required_keys = {}
if logstore_name == AliyunLogStore.workflow_execution_logstore:
required_keys = self._get_workflow_execution_index_keys()
elif logstore_name == AliyunLogStore.workflow_node_execution_logstore:
required_keys = self._get_workflow_node_execution_index_keys()
# Check if index exists
existing_config = self.get_existing_index_config(logstore_name)
if existing_config is None:
# Index doesn't exist, create it
logger.info(
"Logstore %s: Index doesn't exist, creating with %d required fields",
logstore_name,
len(required_keys),
)
self.create_index(logstore_name)
else:
merged_config, needs_update = self._merge_index_configs(existing_config, required_keys, logstore_name)
if needs_update:
logger.info("Logstore %s: Updating index to include Dify-required fields", logstore_name)
try:
self.client.update_index(self.project_name, logstore_name, merged_config)
logger.info(
"Logstore %s: Index updated successfully, now has %d total field indexes",
logstore_name,
len(merged_config.key_config_list or {}),
)
except LogException as e:
logger.exception(
"Failed to update index for logstore %s: errorCode=%s, errorMessage=%s, requestId=%s",
logstore_name,
e.get_error_code(),
e.get_error_message(),
e.get_request_id(),
)
raise
else:
logger.info(
"Logstore %s: Index already contains all %d Dify-required fields with correct types, "
"no update needed",
logstore_name,
len(required_keys),
)
def put_log(self, logstore: str, contents: Sequence[tuple[str, str]]) -> None:
# Route to PG or SDK based on protocol availability
if self._use_pg_protocol and self._pg_client:
self._pg_client.put_log(logstore, contents, self.log_enabled)
else:
log_item = LogItem(contents=contents)
request = PutLogsRequest(project=self.project_name, logstore=logstore, logitems=[log_item])
if self.log_enabled:
logger.info(
"[LogStore-SDK] PUT_LOG | logstore=%s | project=%s | items_count=%d",
logstore,
self.project_name,
len(contents),
)
try:
self.client.put_logs(request)
except LogException as e:
logger.exception(
"Failed to put logs to logstore %s: errorCode=%s, errorMessage=%s, requestId=%s",
logstore,
e.get_error_code(),
e.get_error_message(),
e.get_request_id(),
)
raise
def get_logs(
self,
logstore: str,
from_time: int,
to_time: int,
topic: str = "",
query: str = "",
line: int = 100,
offset: int = 0,
reverse: bool = True,
) -> list[dict]:
request = GetLogsRequest(
project=self.project_name,
logstore=logstore,
fromTime=from_time,
toTime=to_time,
topic=topic,
query=query,
line=line,
offset=offset,
reverse=reverse,
)
# Log query info if SQLALCHEMY_ECHO is enabled
if self.log_enabled:
logger.info(
"[LogStore] GET_LOGS | logstore=%s | project=%s | query=%s | "
"from_time=%d | to_time=%d | line=%d | offset=%d | reverse=%s",
logstore,
self.project_name,
query,
from_time,
to_time,
line,
offset,
reverse,
)
try:
response = self.client.get_logs(request)
result = []
logs = response.get_logs() if response else []
for log in logs:
result.append(log.get_contents())
# Log result count if SQLALCHEMY_ECHO is enabled
if self.log_enabled:
logger.info(
"[LogStore] GET_LOGS RESULT | logstore=%s | returned_count=%d",
logstore,
len(result),
)
return result
except LogException as e:
logger.exception(
"Failed to get logs from logstore %s with query '%s': errorCode=%s, errorMessage=%s, requestId=%s",
logstore,
query,
e.get_error_code(),
e.get_error_message(),
e.get_request_id(),
)
raise
def execute_sql(
self,
sql: str,
logstore: str | None = None,
query: str = "*",
from_time: int | None = None,
to_time: int | None = None,
power_sql: bool = False,
) -> list[dict]:
"""
Execute SQL query for aggregation and analysis.
Args:
sql: SQL query string (SELECT statement)
logstore: Name of the logstore (required)
query: Search/filter query for SDK mode (default: "*" for all logs).
Only used in SDK mode. PG mode ignores this parameter.
from_time: Start time (Unix timestamp) - only used in SDK mode
to_time: End time (Unix timestamp) - only used in SDK mode
power_sql: Whether to use enhanced SQL mode (default: False)
Returns:
List of result rows as dictionaries
Note:
- PG mode: Only executes the SQL directly
- SDK mode: Combines query and sql as "query | sql"
"""
# Logstore is required
if not logstore:
raise ValueError("logstore parameter is required for execute_sql")
# Route to PG or SDK based on protocol availability
if self._use_pg_protocol and self._pg_client:
# PG mode: execute SQL directly (ignore query parameter)
return self._pg_client.execute_sql(sql, logstore, self.log_enabled)
else:
# SDK mode: combine query and sql as "query | sql"
full_query = f"{query} | {sql}"
# Provide default time range if not specified
if from_time is None:
from_time = 0
if to_time is None:
to_time = int(time.time()) # now
request = GetLogsRequest(
project=self.project_name,
logstore=logstore,
fromTime=from_time,
toTime=to_time,
query=full_query,
)
# Log query info if SQLALCHEMY_ECHO is enabled
if self.log_enabled:
logger.info(
"[LogStore-SDK] EXECUTE_SQL | logstore=%s | project=%s | from_time=%d | to_time=%d | full_query=%s",
logstore,
self.project_name,
from_time,
to_time,
query,
sql,
)
try:
response = self.client.get_logs(request)
result = []
logs = response.get_logs() if response else []
for log in logs:
result.append(log.get_contents())
# Log result count if SQLALCHEMY_ECHO is enabled
if self.log_enabled:
logger.info(
"[LogStore-SDK] EXECUTE_SQL RESULT | logstore=%s | returned_count=%d",
logstore,
len(result),
)
return result
except LogException as e:
logger.exception(
"Failed to execute SQL, logstore %s: errorCode=%s, errorMessage=%s, requestId=%s, full_query=%s",
logstore,
e.get_error_code(),
e.get_error_message(),
e.get_request_id(),
full_query,
)
raise
if __name__ == "__main__":
aliyun_logstore = AliyunLogStore()
# aliyun_logstore.init_project_logstore()
aliyun_logstore.put_log(AliyunLogStore.workflow_execution_logstore, [("key1", "value1")])

View File

@ -0,0 +1,407 @@
import logging
import os
import socket
import time
from collections.abc import Sequence
from contextlib import contextmanager
from typing import Any
import psycopg2
import psycopg2.pool
from psycopg2 import InterfaceError, OperationalError
from configs import dify_config
logger = logging.getLogger(__name__)
class AliyunLogStorePG:
"""
PostgreSQL protocol support for Aliyun SLS LogStore.
Handles PG connection pooling and operations for regions that support PG protocol.
"""
def __init__(self, access_key_id: str, access_key_secret: str, endpoint: str, project_name: str):
"""
Initialize PG connection for SLS.
Args:
access_key_id: Aliyun access key ID
access_key_secret: Aliyun access key secret
endpoint: SLS endpoint
project_name: SLS project name
"""
self._access_key_id = access_key_id
self._access_key_secret = access_key_secret
self._endpoint = endpoint
self.project_name = project_name
self._pg_pool: psycopg2.pool.SimpleConnectionPool | None = None
self._use_pg_protocol = False
def _check_port_connectivity(self, host: str, port: int, timeout: float = 2.0) -> bool:
"""
Check if a TCP port is reachable using socket connection.
This provides a fast check before attempting full database connection,
preventing long waits when connecting to unsupported regions.
Args:
host: Hostname or IP address
port: Port number
timeout: Connection timeout in seconds (default: 2.0)
Returns:
True if port is reachable, False otherwise
"""
try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(timeout)
result = sock.connect_ex((host, port))
sock.close()
return result == 0
except Exception as e:
logger.debug("Port connectivity check failed for %s:%d: %s", host, port, str(e))
return False
def init_connection(self) -> bool:
"""
Initialize PostgreSQL connection pool for SLS PG protocol support.
Attempts to connect to SLS using PostgreSQL protocol. If successful, sets
_use_pg_protocol to True and creates a connection pool. If connection fails
(region doesn't support PG protocol or other errors), returns False.
Returns:
True if PG protocol is supported and initialized, False otherwise
"""
try:
# Extract hostname from endpoint (remove protocol if present)
pg_host = self._endpoint.replace("http://", "").replace("https://", "")
# Get pool configuration
pg_max_connections = int(os.environ.get("ALIYUN_SLS_PG_MAX_CONNECTIONS", 10))
logger.debug(
"Check PG protocol connection to SLS: host=%s, project=%s",
pg_host,
self.project_name,
)
# Fast port connectivity check before attempting full connection
# This prevents long waits when connecting to unsupported regions
if not self._check_port_connectivity(pg_host, 5432, timeout=1.0):
logger.info(
"USE SDK mode for read/write operations, host=%s",
pg_host,
)
return False
# Create connection pool
self._pg_pool = psycopg2.pool.SimpleConnectionPool(
minconn=1,
maxconn=pg_max_connections,
host=pg_host,
port=5432,
database=self.project_name,
user=self._access_key_id,
password=self._access_key_secret,
sslmode="require",
connect_timeout=5,
application_name=f"Dify-{dify_config.project.version}",
)
# Note: Skip test query because SLS PG protocol only supports SELECT/INSERT on actual tables
# Connection pool creation success already indicates connectivity
self._use_pg_protocol = True
logger.info(
"PG protocol initialized successfully for SLS project=%s. Will use PG for read/write operations.",
self.project_name,
)
return True
except Exception as e:
# PG connection failed - fallback to SDK mode
self._use_pg_protocol = False
if self._pg_pool:
try:
self._pg_pool.closeall()
except Exception:
logger.debug("Failed to close PG connection pool during cleanup, ignoring")
self._pg_pool = None
logger.info(
"PG protocol connection failed (region may not support PG protocol): %s. "
"Falling back to SDK mode for read/write operations.",
str(e),
)
return False
def _is_connection_valid(self, conn: Any) -> bool:
"""
Check if a connection is still valid.
Args:
conn: psycopg2 connection object
Returns:
True if connection is valid, False otherwise
"""
try:
# Check if connection is closed
if conn.closed:
return False
# Quick ping test - execute a lightweight query
# For SLS PG protocol, we can't use SELECT 1 without FROM,
# so we just check the connection status
with conn.cursor() as cursor:
cursor.execute("SELECT 1")
cursor.fetchone()
return True
except Exception:
return False
@contextmanager
def _get_connection(self):
"""
Context manager to get a PostgreSQL connection from the pool.
Automatically validates and refreshes stale connections.
Note: Aliyun SLS PG protocol does not support transactions, so we always
use autocommit mode.
Yields:
psycopg2 connection object
Raises:
RuntimeError: If PG pool is not initialized
"""
if not self._pg_pool:
raise RuntimeError("PG connection pool is not initialized")
conn = self._pg_pool.getconn()
try:
# Validate connection and get a fresh one if needed
if not self._is_connection_valid(conn):
logger.debug("Connection is stale, marking as bad and getting a new one")
# Mark connection as bad and get a new one
self._pg_pool.putconn(conn, close=True)
conn = self._pg_pool.getconn()
# Aliyun SLS PG protocol does not support transactions, always use autocommit
conn.autocommit = True
yield conn
finally:
# Return connection to pool (or close if it's bad)
if self._is_connection_valid(conn):
self._pg_pool.putconn(conn)
else:
self._pg_pool.putconn(conn, close=True)
def close(self) -> None:
"""Close the PostgreSQL connection pool."""
if self._pg_pool:
try:
self._pg_pool.closeall()
logger.info("PG connection pool closed")
except Exception:
logger.exception("Failed to close PG connection pool")
def _is_retriable_error(self, error: Exception) -> bool:
"""
Check if an error is retriable (connection-related issues).
Args:
error: Exception to check
Returns:
True if the error is retriable, False otherwise
"""
# Retry on connection-related errors
if isinstance(error, (OperationalError, InterfaceError)):
return True
# Check error message for specific connection issues
error_msg = str(error).lower()
retriable_patterns = [
"connection",
"timeout",
"closed",
"broken pipe",
"reset by peer",
"no route to host",
"network",
]
return any(pattern in error_msg for pattern in retriable_patterns)
def put_log(self, logstore: str, contents: Sequence[tuple[str, str]], log_enabled: bool = False) -> None:
"""
Write log to SLS using PostgreSQL protocol with automatic retry.
Note: SLS PG protocol only supports INSERT (not UPDATE). This uses append-only
writes with log_version field for versioning, same as SDK implementation.
Args:
logstore: Name of the logstore table
contents: List of (field_name, value) tuples
log_enabled: Whether to enable logging
Raises:
psycopg2.Error: If database operation fails after all retries
"""
if not contents:
return
# Extract field names and values from contents
fields = [field_name for field_name, _ in contents]
values = [value for _, value in contents]
# Build INSERT statement with literal values
# Note: Aliyun SLS PG protocol doesn't support parameterized queries,
# so we need to use mogrify to safely create literal values
field_list = ", ".join([f'"{field}"' for field in fields])
if log_enabled:
logger.info(
"[LogStore-PG] PUT_LOG | logstore=%s | project=%s | items_count=%d",
logstore,
self.project_name,
len(contents),
)
# Retry configuration
max_retries = 3
retry_delay = 0.1 # Start with 100ms
for attempt in range(max_retries):
try:
with self._get_connection() as conn:
with conn.cursor() as cursor:
# Use mogrify to safely convert values to SQL literals
placeholders = ", ".join(["%s"] * len(fields))
values_literal = cursor.mogrify(f"({placeholders})", values).decode("utf-8")
insert_sql = f'INSERT INTO "{logstore}" ({field_list}) VALUES {values_literal}'
cursor.execute(insert_sql)
# Success - exit retry loop
return
except psycopg2.Error as e:
# Check if error is retriable
if not self._is_retriable_error(e):
# Not a retriable error (e.g., data validation error), fail immediately
logger.exception(
"Failed to put logs to logstore %s via PG protocol (non-retriable error)",
logstore,
)
raise
# Retriable error - log and retry if we have attempts left
if attempt < max_retries - 1:
logger.warning(
"Failed to put logs to logstore %s via PG protocol (attempt %d/%d): %s. Retrying...",
logstore,
attempt + 1,
max_retries,
str(e),
)
time.sleep(retry_delay)
retry_delay *= 2 # Exponential backoff
else:
# Last attempt failed
logger.exception(
"Failed to put logs to logstore %s via PG protocol after %d attempts",
logstore,
max_retries,
)
raise
def execute_sql(self, sql: str, logstore: str, log_enabled: bool = False) -> list[dict[str, Any]]:
"""
Execute SQL query using PostgreSQL protocol with automatic retry.
Args:
sql: SQL query string
logstore: Name of the logstore (for logging purposes)
log_enabled: Whether to enable logging
Returns:
List of result rows as dictionaries
Raises:
psycopg2.Error: If database operation fails after all retries
"""
if log_enabled:
logger.info(
"[LogStore-PG] EXECUTE_SQL | logstore=%s | project=%s | sql=%s",
logstore,
self.project_name,
sql,
)
# Retry configuration
max_retries = 3
retry_delay = 0.1 # Start with 100ms
for attempt in range(max_retries):
try:
with self._get_connection() as conn:
with conn.cursor() as cursor:
cursor.execute(sql)
# Get column names from cursor description
columns = [desc[0] for desc in cursor.description]
# Fetch all results and convert to list of dicts
result = []
for row in cursor.fetchall():
row_dict = {}
for col, val in zip(columns, row):
row_dict[col] = "" if val is None else str(val)
result.append(row_dict)
if log_enabled:
logger.info(
"[LogStore-PG] EXECUTE_SQL RESULT | logstore=%s | returned_count=%d",
logstore,
len(result),
)
return result
except psycopg2.Error as e:
# Check if error is retriable
if not self._is_retriable_error(e):
# Not a retriable error (e.g., SQL syntax error), fail immediately
logger.exception(
"Failed to execute SQL query on logstore %s via PG protocol (non-retriable error): sql=%s",
logstore,
sql,
)
raise
# Retriable error - log and retry if we have attempts left
if attempt < max_retries - 1:
logger.warning(
"Failed to execute SQL query on logstore %s via PG protocol (attempt %d/%d): %s. Retrying...",
logstore,
attempt + 1,
max_retries,
str(e),
)
time.sleep(retry_delay)
retry_delay *= 2 # Exponential backoff
else:
# Last attempt failed
logger.exception(
"Failed to execute SQL query on logstore %s via PG protocol after %d attempts: sql=%s",
logstore,
max_retries,
sql,
)
raise
# This line should never be reached due to raise above, but makes type checker happy
return []

View File

@ -0,0 +1,365 @@
"""
LogStore implementation of DifyAPIWorkflowNodeExecutionRepository.
This module provides the LogStore-based implementation for service-layer
WorkflowNodeExecutionModel operations using Aliyun SLS LogStore.
"""
import logging
import time
from collections.abc import Sequence
from datetime import datetime
from typing import Any
from sqlalchemy.orm import sessionmaker
from extensions.logstore.aliyun_logstore import AliyunLogStore
from models.workflow import WorkflowNodeExecutionModel
from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository
logger = logging.getLogger(__name__)
def _dict_to_workflow_node_execution_model(data: dict[str, Any]) -> WorkflowNodeExecutionModel:
"""
Convert LogStore result dictionary to WorkflowNodeExecutionModel instance.
Args:
data: Dictionary from LogStore query result
Returns:
WorkflowNodeExecutionModel instance (detached from session)
Note:
The returned model is not attached to any SQLAlchemy session.
Relationship fields (like offload_data) are not loaded from LogStore.
"""
logger.debug("_dict_to_workflow_node_execution_model: data keys=%s", list(data.keys())[:5])
# Create model instance without session
model = WorkflowNodeExecutionModel()
# Map all required fields with validation
# Critical fields - must not be None
model.id = data.get("id") or ""
model.tenant_id = data.get("tenant_id") or ""
model.app_id = data.get("app_id") or ""
model.workflow_id = data.get("workflow_id") or ""
model.triggered_from = data.get("triggered_from") or ""
model.node_id = data.get("node_id") or ""
model.node_type = data.get("node_type") or ""
model.status = data.get("status") or "running" # Default status if missing
model.title = data.get("title") or ""
model.created_by_role = data.get("created_by_role") or ""
model.created_by = data.get("created_by") or ""
# Numeric fields with defaults
model.index = int(data.get("index", 0))
model.elapsed_time = float(data.get("elapsed_time", 0))
# Optional fields
model.workflow_run_id = data.get("workflow_run_id")
model.predecessor_node_id = data.get("predecessor_node_id")
model.node_execution_id = data.get("node_execution_id")
model.inputs = data.get("inputs")
model.process_data = data.get("process_data")
model.outputs = data.get("outputs")
model.error = data.get("error")
model.execution_metadata = data.get("execution_metadata")
# Handle datetime fields
created_at = data.get("created_at")
if created_at:
if isinstance(created_at, str):
model.created_at = datetime.fromisoformat(created_at)
elif isinstance(created_at, (int, float)):
model.created_at = datetime.fromtimestamp(created_at)
else:
model.created_at = created_at
else:
# Provide default created_at if missing
model.created_at = datetime.now()
finished_at = data.get("finished_at")
if finished_at:
if isinstance(finished_at, str):
model.finished_at = datetime.fromisoformat(finished_at)
elif isinstance(finished_at, (int, float)):
model.finished_at = datetime.fromtimestamp(finished_at)
else:
model.finished_at = finished_at
return model
class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRepository):
"""
LogStore implementation of DifyAPIWorkflowNodeExecutionRepository.
Provides service-layer database operations for WorkflowNodeExecutionModel
using LogStore SQL queries with optimized deduplication strategies.
"""
def __init__(self, session_maker: sessionmaker | None = None):
"""
Initialize the repository with LogStore client.
Args:
session_maker: SQLAlchemy sessionmaker (unused, for compatibility with factory pattern)
"""
logger.debug("LogstoreAPIWorkflowNodeExecutionRepository.__init__: initializing")
self.logstore_client = AliyunLogStore()
def get_node_last_execution(
self,
tenant_id: str,
app_id: str,
workflow_id: str,
node_id: str,
) -> WorkflowNodeExecutionModel | None:
"""
Get the most recent execution for a specific node.
Uses query syntax to get raw logs and selects the one with max log_version.
Returns the most recent execution ordered by created_at.
"""
logger.debug(
"get_node_last_execution: tenant_id=%s, app_id=%s, workflow_id=%s, node_id=%s",
tenant_id,
app_id,
workflow_id,
node_id,
)
try:
# Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of each record)
sql_query = f"""
SELECT * FROM (
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_node_execution_logstore}"
WHERE tenant_id = '{tenant_id}'
AND app_id = '{app_id}'
AND workflow_id = '{workflow_id}'
AND node_id = '{node_id}'
AND __time__ > 0
) AS subquery WHERE rn = 1
LIMIT 100
"""
results = self.logstore_client.execute_sql(
sql=sql_query,
logstore=AliyunLogStore.workflow_node_execution_logstore,
)
else:
# Use SDK with LogStore query syntax
query = (
f"tenant_id: {tenant_id} and app_id: {app_id} and workflow_id: {workflow_id} and node_id: {node_id}"
)
from_time = 0
to_time = int(time.time()) # now
results = self.logstore_client.get_logs(
logstore=AliyunLogStore.workflow_node_execution_logstore,
from_time=from_time,
to_time=to_time,
query=query,
line=100,
reverse=False,
)
if not results:
return None
# For SDK mode, group by id and select the one with max log_version for each group
# For PG mode, this is already done by the SQL query
if not self.logstore_client.supports_pg_protocol:
id_to_results: dict[str, list[dict[str, Any]]] = {}
for row in results:
row_id = row.get("id")
if row_id:
if row_id not in id_to_results:
id_to_results[row_id] = []
id_to_results[row_id].append(row)
# For each id, select the row with max log_version
deduplicated_results = []
for rows in id_to_results.values():
if len(rows) > 1:
max_row = max(rows, key=lambda x: int(x.get("log_version", 0)))
else:
max_row = rows[0]
deduplicated_results.append(max_row)
else:
# For PG mode, results are already deduplicated by the SQL query
deduplicated_results = results
# Sort by created_at DESC and return the most recent one
deduplicated_results.sort(
key=lambda x: x.get("created_at", 0) if isinstance(x.get("created_at"), (int, float)) else 0,
reverse=True,
)
if deduplicated_results:
return _dict_to_workflow_node_execution_model(deduplicated_results[0])
return None
except Exception:
logger.exception("Failed to get node last execution from LogStore")
raise
def get_executions_by_workflow_run(
self,
tenant_id: str,
app_id: str,
workflow_run_id: str,
) -> Sequence[WorkflowNodeExecutionModel]:
"""
Get all node executions for a specific workflow run.
Uses query syntax to get raw logs and selects the one with max log_version for each node execution.
Ordered by index DESC for trace visualization.
"""
logger.debug(
"[LogStore] get_executions_by_workflow_run: tenant_id=%s, app_id=%s, workflow_run_id=%s",
tenant_id,
app_id,
workflow_run_id,
)
try:
# Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of each record)
sql_query = f"""
SELECT * FROM (
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_node_execution_logstore}"
WHERE tenant_id = '{tenant_id}'
AND app_id = '{app_id}'
AND workflow_run_id = '{workflow_run_id}'
AND __time__ > 0
) AS subquery WHERE rn = 1
LIMIT 1000
"""
results = self.logstore_client.execute_sql(
sql=sql_query,
logstore=AliyunLogStore.workflow_node_execution_logstore,
)
else:
# Use SDK with LogStore query syntax
query = f"tenant_id: {tenant_id} and app_id: {app_id} and workflow_run_id: {workflow_run_id}"
from_time = 0
to_time = int(time.time()) # now
results = self.logstore_client.get_logs(
logstore=AliyunLogStore.workflow_node_execution_logstore,
from_time=from_time,
to_time=to_time,
query=query,
line=1000, # Get more results for node executions
reverse=False,
)
if not results:
return []
# For SDK mode, group by id and select the one with max log_version for each group
# For PG mode, this is already done by the SQL query
models = []
if not self.logstore_client.supports_pg_protocol:
id_to_results: dict[str, list[dict[str, Any]]] = {}
for row in results:
row_id = row.get("id")
if row_id:
if row_id not in id_to_results:
id_to_results[row_id] = []
id_to_results[row_id].append(row)
# For each id, select the row with max log_version
for rows in id_to_results.values():
if len(rows) > 1:
max_row = max(rows, key=lambda x: int(x.get("log_version", 0)))
else:
max_row = rows[0]
model = _dict_to_workflow_node_execution_model(max_row)
if model and model.id: # Ensure model is valid
models.append(model)
else:
# For PG mode, results are already deduplicated by the SQL query
for row in results:
model = _dict_to_workflow_node_execution_model(row)
if model and model.id: # Ensure model is valid
models.append(model)
# Sort by index DESC for trace visualization
models.sort(key=lambda x: x.index, reverse=True)
return models
except Exception:
logger.exception("Failed to get executions by workflow run from LogStore")
raise
def get_execution_by_id(
self,
execution_id: str,
tenant_id: str | None = None,
) -> WorkflowNodeExecutionModel | None:
"""
Get a workflow node execution by its ID.
Uses query syntax to get raw logs and selects the one with max log_version.
"""
logger.debug("get_execution_by_id: execution_id=%s, tenant_id=%s", execution_id, tenant_id)
try:
# Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of record)
tenant_filter = f"AND tenant_id = '{tenant_id}'" if tenant_id else ""
sql_query = f"""
SELECT * FROM (
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_node_execution_logstore}"
WHERE id = '{execution_id}' {tenant_filter} AND __time__ > 0
) AS subquery WHERE rn = 1
LIMIT 1
"""
results = self.logstore_client.execute_sql(
sql=sql_query,
logstore=AliyunLogStore.workflow_node_execution_logstore,
)
else:
# Use SDK with LogStore query syntax
if tenant_id:
query = f"id: {execution_id} and tenant_id: {tenant_id}"
else:
query = f"id: {execution_id}"
from_time = 0
to_time = int(time.time()) # now
results = self.logstore_client.get_logs(
logstore=AliyunLogStore.workflow_node_execution_logstore,
from_time=from_time,
to_time=to_time,
query=query,
line=100,
reverse=False,
)
if not results:
return None
# For PG mode, result is already the latest version
# For SDK mode, if multiple results, select the one with max log_version
if self.logstore_client.supports_pg_protocol or len(results) == 1:
return _dict_to_workflow_node_execution_model(results[0])
else:
max_result = max(results, key=lambda x: int(x.get("log_version", 0)))
return _dict_to_workflow_node_execution_model(max_result)
except Exception:
logger.exception("Failed to get execution by ID from LogStore: execution_id=%s", execution_id)
raise

View File

@ -0,0 +1,757 @@
"""
LogStore API WorkflowRun Repository Implementation
This module provides the LogStore-based implementation of the APIWorkflowRunRepository
protocol. It handles service-layer WorkflowRun database operations using Aliyun SLS LogStore
with optimized queries for statistics and pagination.
Key Features:
- LogStore SQL queries for aggregation and statistics
- Optimized deduplication using finished_at IS NOT NULL filter
- Window functions only when necessary (running status queries)
- Multi-tenant data isolation and security
"""
import logging
import os
import time
from collections.abc import Sequence
from datetime import datetime
from typing import Any, cast
from sqlalchemy.orm import sessionmaker
from extensions.logstore.aliyun_logstore import AliyunLogStore
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.enums import WorkflowRunTriggeredFrom
from models.workflow import WorkflowRun
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
from repositories.types import (
AverageInteractionStats,
DailyRunsStats,
DailyTerminalsStats,
DailyTokenCostStats,
)
logger = logging.getLogger(__name__)
def _dict_to_workflow_run(data: dict[str, Any]) -> WorkflowRun:
"""
Convert LogStore result dictionary to WorkflowRun instance.
Args:
data: Dictionary from LogStore query result
Returns:
WorkflowRun instance
"""
logger.debug("_dict_to_workflow_run: data keys=%s", list(data.keys())[:5])
# Create model instance without session
model = WorkflowRun()
# Map all required fields with validation
# Critical fields - must not be None
model.id = data.get("id") or ""
model.tenant_id = data.get("tenant_id") or ""
model.app_id = data.get("app_id") or ""
model.workflow_id = data.get("workflow_id") or ""
model.type = data.get("type") or ""
model.triggered_from = data.get("triggered_from") or ""
model.version = data.get("version") or ""
model.status = data.get("status") or "running" # Default status if missing
model.created_by_role = data.get("created_by_role") or ""
model.created_by = data.get("created_by") or ""
# Numeric fields with defaults
model.total_tokens = int(data.get("total_tokens", 0))
model.total_steps = int(data.get("total_steps", 0))
model.exceptions_count = int(data.get("exceptions_count", 0))
# Optional fields
model.graph = data.get("graph")
model.inputs = data.get("inputs")
model.outputs = data.get("outputs")
model.error = data.get("error_message") or data.get("error")
# Handle datetime fields
started_at = data.get("started_at") or data.get("created_at")
if started_at:
if isinstance(started_at, str):
model.created_at = datetime.fromisoformat(started_at)
elif isinstance(started_at, (int, float)):
model.created_at = datetime.fromtimestamp(started_at)
else:
model.created_at = started_at
else:
# Provide default created_at if missing
model.created_at = datetime.now()
finished_at = data.get("finished_at")
if finished_at:
if isinstance(finished_at, str):
model.finished_at = datetime.fromisoformat(finished_at)
elif isinstance(finished_at, (int, float)):
model.finished_at = datetime.fromtimestamp(finished_at)
else:
model.finished_at = finished_at
# Compute elapsed_time from started_at and finished_at
# LogStore doesn't store elapsed_time, it's computed in WorkflowExecution domain entity
if model.finished_at and model.created_at:
model.elapsed_time = (model.finished_at - model.created_at).total_seconds()
else:
model.elapsed_time = float(data.get("elapsed_time", 0))
return model
class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
"""
LogStore implementation of APIWorkflowRunRepository.
Provides service-layer WorkflowRun database operations using LogStore SQL
with optimized query strategies:
- Use finished_at IS NOT NULL for deduplication (10-100x faster)
- Use window functions only when running status is required
- Proper time range filtering for LogStore queries
"""
def __init__(self, session_maker: sessionmaker | None = None):
"""
Initialize the repository with LogStore client.
Args:
session_maker: SQLAlchemy sessionmaker (unused, for compatibility with factory pattern)
"""
logger.debug("LogstoreAPIWorkflowRunRepository.__init__: initializing")
self.logstore_client = AliyunLogStore()
# Control flag for dual-read (fallback to PostgreSQL when LogStore returns no results)
# Set to True to enable fallback for safe migration from PostgreSQL to LogStore
# Set to False for new deployments without legacy data in PostgreSQL
self._enable_dual_read = os.environ.get("LOGSTORE_DUAL_READ_ENABLED", "true").lower() == "true"
def get_paginated_workflow_runs(
self,
tenant_id: str,
app_id: str,
triggered_from: WorkflowRunTriggeredFrom | Sequence[WorkflowRunTriggeredFrom],
limit: int = 20,
last_id: str | None = None,
status: str | None = None,
) -> InfiniteScrollPagination:
"""
Get paginated workflow runs with filtering.
Uses window function for deduplication to support both running and finished states.
Args:
tenant_id: Tenant identifier for multi-tenant isolation
app_id: Application identifier
triggered_from: Filter by trigger source(s)
limit: Maximum number of records to return (default: 20)
last_id: Cursor for pagination - ID of the last record from previous page
status: Optional filter by status
Returns:
InfiniteScrollPagination object
"""
logger.debug(
"get_paginated_workflow_runs: tenant_id=%s, app_id=%s, limit=%d, status=%s",
tenant_id,
app_id,
limit,
status,
)
# Convert triggered_from to list if needed
if isinstance(triggered_from, WorkflowRunTriggeredFrom):
triggered_from_list = [triggered_from]
else:
triggered_from_list = list(triggered_from)
# Build triggered_from filter
triggered_from_filter = " OR ".join([f"triggered_from='{tf.value}'" for tf in triggered_from_list])
# Build status filter
status_filter = f"AND status='{status}'" if status else ""
# Build last_id filter for pagination
# Note: This is simplified. In production, you'd need to track created_at from last record
last_id_filter = ""
if last_id:
# TODO: Implement proper cursor-based pagination with created_at
logger.warning("last_id pagination not fully implemented for LogStore")
# Use window function to get latest log_version of each workflow run
sql = f"""
SELECT * FROM (
SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND ({triggered_from_filter})
{status_filter}
{last_id_filter}
) t
WHERE rn = 1
ORDER BY created_at DESC
LIMIT {limit + 1}
"""
try:
results = self.logstore_client.execute_sql(
sql=sql, query="*", logstore=AliyunLogStore.workflow_execution_logstore, from_time=None, to_time=None
)
# Check if there are more records
has_more = len(results) > limit
if has_more:
results = results[:limit]
# Convert results to WorkflowRun models
workflow_runs = [_dict_to_workflow_run(row) for row in results]
return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more)
except Exception:
logger.exception("Failed to get paginated workflow runs from LogStore")
raise
def get_workflow_run_by_id(
self,
tenant_id: str,
app_id: str,
run_id: str,
) -> WorkflowRun | None:
"""
Get a specific workflow run by ID with tenant and app isolation.
Uses query syntax to get raw logs and selects the one with max log_version in code.
Falls back to PostgreSQL if not found in LogStore (for data consistency during migration).
"""
logger.debug("get_workflow_run_by_id: tenant_id=%s, app_id=%s, run_id=%s", tenant_id, app_id, run_id)
try:
# Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of record)
sql_query = f"""
SELECT * FROM (
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_execution_logstore}"
WHERE id = '{run_id}' AND tenant_id = '{tenant_id}' AND app_id = '{app_id}' AND __time__ > 0
) AS subquery WHERE rn = 1
LIMIT 100
"""
results = self.logstore_client.execute_sql(
sql=sql_query,
logstore=AliyunLogStore.workflow_execution_logstore,
)
else:
# Use SDK with LogStore query syntax
query = f"id: {run_id} and tenant_id: {tenant_id} and app_id: {app_id}"
from_time = 0
to_time = int(time.time()) # now
results = self.logstore_client.get_logs(
logstore=AliyunLogStore.workflow_execution_logstore,
from_time=from_time,
to_time=to_time,
query=query,
line=100,
reverse=False,
)
if not results:
# Fallback to PostgreSQL for records created before LogStore migration
if self._enable_dual_read:
logger.debug(
"WorkflowRun not found in LogStore, falling back to PostgreSQL: "
"run_id=%s, tenant_id=%s, app_id=%s",
run_id,
tenant_id,
app_id,
)
return self._fallback_get_workflow_run_by_id_with_tenant(run_id, tenant_id, app_id)
return None
# For PG mode, results are already deduplicated by the SQL query
# For SDK mode, if multiple results, select the one with max log_version
if self.logstore_client.supports_pg_protocol or len(results) == 1:
return _dict_to_workflow_run(results[0])
else:
max_result = max(results, key=lambda x: int(x.get("log_version", 0)))
return _dict_to_workflow_run(max_result)
except Exception:
logger.exception("Failed to get workflow run by ID from LogStore: run_id=%s", run_id)
# Try PostgreSQL fallback on any error (only if dual-read is enabled)
if self._enable_dual_read:
try:
return self._fallback_get_workflow_run_by_id_with_tenant(run_id, tenant_id, app_id)
except Exception:
logger.exception(
"PostgreSQL fallback also failed: run_id=%s, tenant_id=%s, app_id=%s", run_id, tenant_id, app_id
)
raise
def _fallback_get_workflow_run_by_id_with_tenant(
self, run_id: str, tenant_id: str, app_id: str
) -> WorkflowRun | None:
"""Fallback to PostgreSQL query for records not in LogStore (with tenant isolation)."""
from sqlalchemy import select
from sqlalchemy.orm import Session
from extensions.ext_database import db
with Session(db.engine) as session:
stmt = select(WorkflowRun).where(
WorkflowRun.id == run_id, WorkflowRun.tenant_id == tenant_id, WorkflowRun.app_id == app_id
)
return session.scalar(stmt)
def get_workflow_run_by_id_without_tenant(
self,
run_id: str,
) -> WorkflowRun | None:
"""
Get a specific workflow run by ID without tenant/app context.
Uses query syntax to get raw logs and selects the one with max log_version.
Falls back to PostgreSQL if not found in LogStore (controlled by LOGSTORE_DUAL_READ_ENABLED).
"""
logger.debug("get_workflow_run_by_id_without_tenant: run_id=%s", run_id)
try:
# Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of record)
sql_query = f"""
SELECT * FROM (
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_execution_logstore}"
WHERE id = '{run_id}' AND __time__ > 0
) AS subquery WHERE rn = 1
LIMIT 100
"""
results = self.logstore_client.execute_sql(
sql=sql_query,
logstore=AliyunLogStore.workflow_execution_logstore,
)
else:
# Use SDK with LogStore query syntax
query = f"id: {run_id}"
from_time = 0
to_time = int(time.time()) # now
results = self.logstore_client.get_logs(
logstore=AliyunLogStore.workflow_execution_logstore,
from_time=from_time,
to_time=to_time,
query=query,
line=100,
reverse=False,
)
if not results:
# Fallback to PostgreSQL for records created before LogStore migration
if self._enable_dual_read:
logger.debug("WorkflowRun not found in LogStore, falling back to PostgreSQL: run_id=%s", run_id)
return self._fallback_get_workflow_run_by_id(run_id)
return None
# For PG mode, results are already deduplicated by the SQL query
# For SDK mode, if multiple results, select the one with max log_version
if self.logstore_client.supports_pg_protocol or len(results) == 1:
return _dict_to_workflow_run(results[0])
else:
max_result = max(results, key=lambda x: int(x.get("log_version", 0)))
return _dict_to_workflow_run(max_result)
except Exception:
logger.exception("Failed to get workflow run without tenant: run_id=%s", run_id)
# Try PostgreSQL fallback on any error (only if dual-read is enabled)
if self._enable_dual_read:
try:
return self._fallback_get_workflow_run_by_id(run_id)
except Exception:
logger.exception("PostgreSQL fallback also failed: run_id=%s", run_id)
raise
def _fallback_get_workflow_run_by_id(self, run_id: str) -> WorkflowRun | None:
"""Fallback to PostgreSQL query for records not in LogStore."""
from sqlalchemy import select
from sqlalchemy.orm import Session
from extensions.ext_database import db
with Session(db.engine) as session:
stmt = select(WorkflowRun).where(WorkflowRun.id == run_id)
return session.scalar(stmt)
def get_workflow_runs_count(
self,
tenant_id: str,
app_id: str,
triggered_from: str,
status: str | None = None,
time_range: str | None = None,
) -> dict[str, int]:
"""
Get workflow runs count statistics grouped by status.
Optimization: Use finished_at IS NOT NULL for completed runs (10-50x faster)
"""
logger.debug(
"get_workflow_runs_count: tenant_id=%s, app_id=%s, triggered_from=%s, status=%s",
tenant_id,
app_id,
triggered_from,
status,
)
# Build time range filter
time_filter = ""
if time_range:
# TODO: Parse time_range and convert to from_time/to_time
logger.warning("time_range filter not implemented")
# If status is provided, simple count
if status:
if status == "running":
# Running status requires window function
sql = f"""
SELECT COUNT(*) as count
FROM (
SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
AND status='running'
{time_filter}
) t
WHERE rn = 1
"""
else:
# Finished status uses optimized filter
sql = f"""
SELECT COUNT(DISTINCT id) as count
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
AND status='{status}'
AND finished_at IS NOT NULL
{time_filter}
"""
try:
results = self.logstore_client.execute_sql(
sql=sql, query="*", logstore=AliyunLogStore.workflow_execution_logstore
)
count = results[0]["count"] if results and len(results) > 0 else 0
return {
"total": count,
"running": count if status == "running" else 0,
"succeeded": count if status == "succeeded" else 0,
"failed": count if status == "failed" else 0,
"stopped": count if status == "stopped" else 0,
"partial-succeeded": count if status == "partial-succeeded" else 0,
}
except Exception:
logger.exception("Failed to get workflow runs count")
raise
# No status filter - get counts grouped by status
# Use optimized query for finished runs, separate query for running
try:
# Count finished runs grouped by status
finished_sql = f"""
SELECT status, COUNT(DISTINCT id) as count
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
AND finished_at IS NOT NULL
{time_filter}
GROUP BY status
"""
# Count running runs
running_sql = f"""
SELECT COUNT(*) as count
FROM (
SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
AND status='running'
{time_filter}
) t
WHERE rn = 1
"""
finished_results = self.logstore_client.execute_sql(
sql=finished_sql, query="*", logstore=AliyunLogStore.workflow_execution_logstore
)
running_results = self.logstore_client.execute_sql(
sql=running_sql, query="*", logstore=AliyunLogStore.workflow_execution_logstore
)
# Build response
status_counts = {
"running": 0,
"succeeded": 0,
"failed": 0,
"stopped": 0,
"partial-succeeded": 0,
}
total = 0
for result in finished_results:
status_val = result.get("status")
count = result.get("count", 0)
if status_val in status_counts:
status_counts[status_val] = count
total += count
# Add running count
running_count = running_results[0]["count"] if running_results and len(running_results) > 0 else 0
status_counts["running"] = running_count
total += running_count
return {"total": total} | status_counts
except Exception:
logger.exception("Failed to get workflow runs count")
raise
def get_daily_runs_statistics(
self,
tenant_id: str,
app_id: str,
triggered_from: str,
start_date: datetime | None = None,
end_date: datetime | None = None,
timezone: str = "UTC",
) -> list[DailyRunsStats]:
"""
Get daily runs statistics using optimized query.
Optimization: Use finished_at IS NOT NULL + COUNT(DISTINCT id) (20-100x faster)
"""
logger.debug(
"get_daily_runs_statistics: tenant_id=%s, app_id=%s, triggered_from=%s", tenant_id, app_id, triggered_from
)
# Build time range filter
time_filter = ""
if start_date:
time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))"
if end_date:
time_filter += f" AND __time__ < to_unixtime(from_iso8601_timestamp('{end_date.isoformat()}'))"
# Optimized query: Use finished_at filter to avoid window function
sql = f"""
SELECT DATE(from_unixtime(__time__)) as date, COUNT(DISTINCT id) as runs
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
AND finished_at IS NOT NULL
{time_filter}
GROUP BY date
ORDER BY date
"""
try:
results = self.logstore_client.execute_sql(
sql=sql, query="*", logstore=AliyunLogStore.workflow_execution_logstore
)
response_data = []
for row in results:
response_data.append({"date": str(row.get("date", "")), "runs": row.get("runs", 0)})
return cast(list[DailyRunsStats], response_data)
except Exception:
logger.exception("Failed to get daily runs statistics")
raise
def get_daily_terminals_statistics(
self,
tenant_id: str,
app_id: str,
triggered_from: str,
start_date: datetime | None = None,
end_date: datetime | None = None,
timezone: str = "UTC",
) -> list[DailyTerminalsStats]:
"""
Get daily terminals statistics using optimized query.
Optimization: Use finished_at IS NOT NULL + COUNT(DISTINCT created_by) (20-100x faster)
"""
logger.debug(
"get_daily_terminals_statistics: tenant_id=%s, app_id=%s, triggered_from=%s",
tenant_id,
app_id,
triggered_from,
)
# Build time range filter
time_filter = ""
if start_date:
time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))"
if end_date:
time_filter += f" AND __time__ < to_unixtime(from_iso8601_timestamp('{end_date.isoformat()}'))"
sql = f"""
SELECT DATE(from_unixtime(__time__)) as date, COUNT(DISTINCT created_by) as terminal_count
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
AND finished_at IS NOT NULL
{time_filter}
GROUP BY date
ORDER BY date
"""
try:
results = self.logstore_client.execute_sql(
sql=sql, query="*", logstore=AliyunLogStore.workflow_execution_logstore
)
response_data = []
for row in results:
response_data.append({"date": str(row.get("date", "")), "terminal_count": row.get("terminal_count", 0)})
return cast(list[DailyTerminalsStats], response_data)
except Exception:
logger.exception("Failed to get daily terminals statistics")
raise
def get_daily_token_cost_statistics(
self,
tenant_id: str,
app_id: str,
triggered_from: str,
start_date: datetime | None = None,
end_date: datetime | None = None,
timezone: str = "UTC",
) -> list[DailyTokenCostStats]:
"""
Get daily token cost statistics using optimized query.
Optimization: Use finished_at IS NOT NULL + SUM(total_tokens) (20-100x faster)
"""
logger.debug(
"get_daily_token_cost_statistics: tenant_id=%s, app_id=%s, triggered_from=%s",
tenant_id,
app_id,
triggered_from,
)
# Build time range filter
time_filter = ""
if start_date:
time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))"
if end_date:
time_filter += f" AND __time__ < to_unixtime(from_iso8601_timestamp('{end_date.isoformat()}'))"
sql = f"""
SELECT DATE(from_unixtime(__time__)) as date, SUM(total_tokens) as token_count
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
AND finished_at IS NOT NULL
{time_filter}
GROUP BY date
ORDER BY date
"""
try:
results = self.logstore_client.execute_sql(
sql=sql, query="*", logstore=AliyunLogStore.workflow_execution_logstore
)
response_data = []
for row in results:
response_data.append({"date": str(row.get("date", "")), "token_count": row.get("token_count", 0)})
return cast(list[DailyTokenCostStats], response_data)
except Exception:
logger.exception("Failed to get daily token cost statistics")
raise
def get_average_app_interaction_statistics(
self,
tenant_id: str,
app_id: str,
triggered_from: str,
start_date: datetime | None = None,
end_date: datetime | None = None,
timezone: str = "UTC",
) -> list[AverageInteractionStats]:
"""
Get average app interaction statistics using optimized query.
Optimization: Use finished_at IS NOT NULL + AVG (20-100x faster)
"""
logger.debug(
"get_average_app_interaction_statistics: tenant_id=%s, app_id=%s, triggered_from=%s",
tenant_id,
app_id,
triggered_from,
)
# Build time range filter
time_filter = ""
if start_date:
time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))"
if end_date:
time_filter += f" AND __time__ < to_unixtime(from_iso8601_timestamp('{end_date.isoformat()}'))"
sql = f"""
SELECT
AVG(sub.interactions) AS interactions,
sub.date
FROM (
SELECT
DATE(from_unixtime(__time__)) AS date,
created_by,
COUNT(DISTINCT id) AS interactions
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
AND finished_at IS NOT NULL
{time_filter}
GROUP BY date, created_by
) sub
GROUP BY sub.date
"""
try:
results = self.logstore_client.execute_sql(
sql=sql, query="*", logstore=AliyunLogStore.workflow_execution_logstore
)
response_data = []
for row in results:
response_data.append(
{
"date": str(row.get("date", "")),
"interactions": float(row.get("interactions", 0)),
}
)
return cast(list[AverageInteractionStats], response_data)
except Exception:
logger.exception("Failed to get average app interaction statistics")
raise

View File

@ -0,0 +1,164 @@
import json
import logging
import os
import time
from typing import Union
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from core.workflow.entities import WorkflowExecution
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from extensions.logstore.aliyun_logstore import AliyunLogStore
from libs.helper import extract_tenant_id
from models import (
Account,
CreatorUserRole,
EndUser,
)
from models.enums import WorkflowRunTriggeredFrom
logger = logging.getLogger(__name__)
class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
def __init__(
self,
session_factory: sessionmaker | Engine,
user: Union[Account, EndUser],
app_id: str | None,
triggered_from: WorkflowRunTriggeredFrom | None,
):
"""
Initialize the repository with a SQLAlchemy sessionmaker or engine and context information.
Args:
session_factory: SQLAlchemy sessionmaker or engine for creating sessions
user: Account or EndUser object containing tenant_id, user ID, and role information
app_id: App ID for filtering by application (can be None)
triggered_from: Source of the execution trigger (DEBUGGING or APP_RUN)
"""
logger.debug(
"LogstoreWorkflowExecutionRepository.__init__: app_id=%s, triggered_from=%s", app_id, triggered_from
)
# Initialize LogStore client
# Note: Project/logstore/index initialization is done at app startup via ext_logstore
self.logstore_client = AliyunLogStore()
# Extract tenant_id from user
tenant_id = extract_tenant_id(user)
if not tenant_id:
raise ValueError("User must have a tenant_id or current_tenant_id")
self._tenant_id = tenant_id
# Store app context
self._app_id = app_id
# Extract user context
self._triggered_from = triggered_from
self._creator_user_id = user.id
# Determine user role based on user type
self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER
# Initialize SQL repository for dual-write support
self.sql_repository = SQLAlchemyWorkflowExecutionRepository(session_factory, user, app_id, triggered_from)
# Control flag for dual-write (write to both LogStore and SQL database)
# Set to True to enable dual-write for safe migration, False to use LogStore only
self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "true").lower() == "true"
def _to_logstore_model(self, domain_model: WorkflowExecution) -> list[tuple[str, str]]:
"""
Convert a domain model to a logstore model (List[Tuple[str, str]]).
Args:
domain_model: The domain model to convert
Returns:
The logstore model as a list of key-value tuples
"""
logger.debug(
"_to_logstore_model: id=%s, workflow_id=%s, status=%s",
domain_model.id_,
domain_model.workflow_id,
domain_model.status.value,
)
# Use values from constructor if provided
if not self._triggered_from:
raise ValueError("triggered_from is required in repository constructor")
if not self._creator_user_id:
raise ValueError("created_by is required in repository constructor")
if not self._creator_user_role:
raise ValueError("created_by_role is required in repository constructor")
# Generate log_version as nanosecond timestamp for record versioning
log_version = str(time.time_ns())
logstore_model = [
("id", domain_model.id_),
("log_version", log_version), # Add log_version field for append-only writes
("tenant_id", self._tenant_id),
("app_id", self._app_id or ""),
("workflow_id", domain_model.workflow_id),
(
"triggered_from",
self._triggered_from.value if hasattr(self._triggered_from, "value") else str(self._triggered_from),
),
("type", domain_model.workflow_type.value),
("version", domain_model.workflow_version),
("graph", json.dumps(domain_model.graph, ensure_ascii=False) if domain_model.graph else "{}"),
("inputs", json.dumps(domain_model.inputs, ensure_ascii=False) if domain_model.inputs else "{}"),
("outputs", json.dumps(domain_model.outputs, ensure_ascii=False) if domain_model.outputs else "{}"),
("status", domain_model.status.value),
("error_message", domain_model.error_message or ""),
("total_tokens", str(domain_model.total_tokens)),
("total_steps", str(domain_model.total_steps)),
("exceptions_count", str(domain_model.exceptions_count)),
(
"created_by_role",
self._creator_user_role.value
if hasattr(self._creator_user_role, "value")
else str(self._creator_user_role),
),
("created_by", self._creator_user_id),
("started_at", domain_model.started_at.isoformat() if domain_model.started_at else ""),
("finished_at", domain_model.finished_at.isoformat() if domain_model.finished_at else ""),
]
return logstore_model
def save(self, execution: WorkflowExecution) -> None:
"""
Save or update a WorkflowExecution domain entity to the logstore.
This method serves as a domain-to-logstore adapter that:
1. Converts the domain entity to its logstore representation
2. Persists the logstore model using Aliyun SLS
3. Maintains proper multi-tenancy by including tenant context during conversion
4. Optionally writes to SQL database for dual-write support (controlled by LOGSTORE_DUAL_WRITE_ENABLED)
Args:
execution: The WorkflowExecution domain entity to persist
"""
logger.debug(
"save: id=%s, workflow_id=%s, status=%s", execution.id_, execution.workflow_id, execution.status.value
)
try:
logstore_model = self._to_logstore_model(execution)
self.logstore_client.put_log(AliyunLogStore.workflow_execution_logstore, logstore_model)
logger.debug("Saved workflow execution to logstore: id=%s", execution.id_)
except Exception:
logger.exception("Failed to save workflow execution to logstore: id=%s", execution.id_)
raise
# Dual-write to SQL database if enabled (for safe migration)
if self._enable_dual_write:
try:
self.sql_repository.save(execution)
logger.debug("Dual-write: saved workflow execution to SQL database: id=%s", execution.id_)
except Exception:
logger.exception("Failed to dual-write workflow execution to SQL database: id=%s", execution.id_)
# Don't raise - LogStore write succeeded, SQL is just a backup

Some files were not shown because too many files have changed in this diff Show More