Compare commits

..

372 Commits

Author SHA1 Message Date
887de73c78 Merge branch 'feat/snippet-fe' into deploy/dev 2026-06-01 13:47:09 +08:00
bccc948031 Merge branch 'main' into feat/snippet-fe 2026-06-01 12:56:52 +08:00
3c1bc6ac36 fix(web): checkValid of snippet 2026-06-01 12:55:57 +08:00
00cf21a1b7 Merge branch 'main' into feat/snippet-fe 2026-06-01 10:48:01 +08:00
705d317f57 merge main 2026-05-29 17:04:05 +08:00
e0158fe9fd fix(web): snippet draft sync 2026-05-29 16:47:48 +08:00
48e6902f5f fix(web): create snippet from workflow 2026-05-29 15:15:43 +08:00
efe98b1e52 fix(web): create snippet from workflow 2026-05-29 14:49:00 +08:00
a62c616664 fix(web): show draft in snippet in default 2026-05-29 14:29:39 +08:00
6ed691a2c9 feat: validate snippet graph for forbidden nodes during creation 2026-05-29 14:05:49 +08:00
41ba73835d fix(web): create snippet from workflow 2026-05-29 13:47:14 +08:00
55c1d1d4be Merge branch 'feat/snippet-new' into deploy/dev 2026-05-29 10:56:07 +08:00
59adfffbb4 feat: allow duplicate snippet names and enhance snippet import handling 2026-05-29 10:37:13 +08:00
ed442771f4 Merge branch 'main' into feat/snippet-fe 2026-05-29 10:26:41 +08:00
5730ede96f feat: dev snippet fronted (#36785)
Signed-off-by: dependabot[bot] <support@github.com>
Signed-off-by: EvanYao826 <155432245+EvanYao826@users.noreply.github.com>
Co-authored-by: yyh <92089059+lyzno1@users.noreply.github.com>
Co-authored-by: 盐粒 Yanli <yanli@dify.ai>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Tianle <40735546+Tianlel@users.noreply.github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Yunlu Wen <yunlu.wen@dify.ai>
Co-authored-by: zyssyz123 <916125788@qq.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-authored-by: chariri <w@chariri.moe>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Nian <11332799+Lillian68@users.noreply.github.com>
Co-authored-by: 非法操作 <hjlarry@163.com>
Co-authored-by: Carmen Fernández Ruiz <279459669+zeus1959@users.noreply.github.com>
Co-authored-by: wangxiaolei <fatelei@gmail.com>
Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>
Co-authored-by: L1nSn0w <l1nsn0w@qq.com>
Co-authored-by: Evan <2869018789@qq.com>
Co-authored-by: Escape0707 <tothesong@gmail.com>
Co-authored-by: Jingyi <jingyi.qi@dify.ai>
Co-authored-by: Amr Sherif <140330826+amr-sheriff@users.noreply.github.com>
Co-authored-by: ZHOU ZHICHEN <118870511+zhuiguangzhe2003@users.noreply.github.com>
Co-authored-by: unknown <EI05187@apwx.com>
Co-authored-by: JzoNg <jzongcode@gmail.com>
Co-authored-by: Xiyuan Chen <52963600+GareArc@users.noreply.github.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
2026-05-28 18:01:08 +08:00
107bba0116 feat: dev snippet fronted (#36784)
Signed-off-by: dependabot[bot] <support@github.com>
Signed-off-by: EvanYao826 <155432245+EvanYao826@users.noreply.github.com>
Co-authored-by: yyh <92089059+lyzno1@users.noreply.github.com>
Co-authored-by: 盐粒 Yanli <yanli@dify.ai>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Tianle <40735546+Tianlel@users.noreply.github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Yunlu Wen <yunlu.wen@dify.ai>
Co-authored-by: zyssyz123 <916125788@qq.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-authored-by: chariri <w@chariri.moe>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Nian <11332799+Lillian68@users.noreply.github.com>
Co-authored-by: 非法操作 <hjlarry@163.com>
Co-authored-by: Carmen Fernández Ruiz <279459669+zeus1959@users.noreply.github.com>
Co-authored-by: wangxiaolei <fatelei@gmail.com>
Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>
Co-authored-by: L1nSn0w <l1nsn0w@qq.com>
Co-authored-by: Evan <2869018789@qq.com>
Co-authored-by: Escape0707 <tothesong@gmail.com>
Co-authored-by: Jingyi <jingyi.qi@dify.ai>
Co-authored-by: Amr Sherif <140330826+amr-sheriff@users.noreply.github.com>
Co-authored-by: ZHOU ZHICHEN <118870511+zhuiguangzhe2003@users.noreply.github.com>
Co-authored-by: unknown <EI05187@apwx.com>
Co-authored-by: JzoNg <jzongcode@gmail.com>
Co-authored-by: Xiyuan Chen <52963600+GareArc@users.noreply.github.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
2026-05-28 17:34:14 +08:00
c46a313d78 Merge branch 'main' into feat/snippet-fe 2026-05-28 14:31:54 +08:00
e1fec86a2a Merge branch 'main' into feat/evaluation-fe 2026-05-28 13:11:45 +08:00
6ec893cb0e feat(web): new interactions in snippet detail graph 2026-05-28 13:11:07 +08:00
8be34ee000 feat(web): support create snippet by import DSL 2026-05-28 11:42:57 +08:00
458f669883 Merge branch 'main' into feat/evaluation-fe 2026-05-28 11:04:59 +08:00
94fd4e9c67 Merge branch 'main' into feat/evaluation-fe 2026-05-27 22:34:15 +08:00
f5da3ce499 fix(web): refresh list after snippet created 2026-05-27 22:33:45 +08:00
08f2971f72 fix(web): snippet draft sync 2026-05-27 22:29:49 +08:00
54ac42fbc4 Merge branch 'main' into feat/evaluation-fe 2026-05-27 22:08:12 +08:00
1d1d571213 fix(web): error message of snippet creation 2026-05-27 22:07:06 +08:00
056caa8b2f fix(web): style of tag filter of snippets 2026-05-27 22:00:05 +08:00
1cf6cdb764 fix(web): stop snippet run 2026-05-27 21:39:14 +08:00
ac9083fbf1 fix(web): tooltip of snippet saved 2026-05-27 21:24:07 +08:00
fdfc9ab3d3 fix(web): sys variables not supported in snippet 2026-05-27 21:13:19 +08:00
83cd1a8d7a fix(web): snippet restore 2026-05-27 21:06:36 +08:00
a3dfd670b0 fix(web): sys variables not supported in snippet 2026-05-27 21:03:00 +08:00
facace019b fix(web): merge error fix 2026-05-27 20:58:03 +08:00
fd9543868d Merge branch 'main' into feat/evaluation-fe 2026-05-27 20:33:59 +08:00
0e6cb87f08 fix(workflow): keep pointer position out of reactive store 2026-05-27 17:30:02 +08:00
ef9c607f04 fix(api): raise BadRequest when retrieving parameters for disabled web app 2026-05-27 15:49:32 +08:00
9d082489c9 feat: snippet new (#36684) 2026-05-26 17:24:17 +08:00
81553d8813 feat: add function to store node inputs under aliases in variable pool 2026-05-26 17:21:03 +08:00
c550d6b085 feat: snippet new (#36682) 2026-05-26 16:40:15 +08:00
99167ace74 feat: enhance snippet workflow handling with start node injection 2026-05-26 16:37:00 +08:00
db5d5bfffe Merge branch 'p428' into deploy/dev 2026-05-26 09:52:44 +08:00
0f2cbc2968 fix: when exclude_vector_space should feature.vector_space not has default value 2026-05-26 09:50:08 +08:00
892387ea38 fix merge error 2026-05-26 09:34:04 +08:00
2f351641e4 Merge remote-tracking branch 'myori/main' into p428 2026-05-26 09:32:21 +08:00
36a51dca8b chore: backend feature api exclude_vector_space 2026-05-26 09:30:17 +08:00
93d9423c95 fix: member invite limits with dedup, locking, and accurate new-member counting (#36512) 2026-05-25 17:18:53 +08:00
3cabe9058b fix : features get null vector_space 2026-05-25 17:03:46 +08:00
d925ed2f28 feat: snippet new (#36617) 2026-05-25 16:36:35 +08:00
8cca26010c feat: add support for excluding node IDs in variable listing 2026-05-25 16:31:52 +08:00
d859728dd7 Merge branch 'feat/evaluation-fe' into deploy/dev 2026-05-25 15:45:26 +08:00
89188256e1 Merge branch 'main' into feat/evaluation-fe 2026-05-25 15:44:50 +08:00
bba3a1bcee fix(web): draft sync of snippet 2026-05-25 15:44:27 +08:00
d66bfc7434 Merge branch 'feat/evaluation-fe' into deploy/dev 2026-05-25 14:37:07 +08:00
7c0be7f905 Merge branch 'main' into feat/evaluation-fe 2026-05-25 14:35:41 +08:00
599e3475f2 fix(web): hide snippets tab in block-selector in snippet detail 2026-05-25 14:35:24 +08:00
718fe548e9 fix(web): cancel button in snippet 2026-05-25 14:30:39 +08:00
060ceaffd1 feat(web): snippet info siderbar 2026-05-25 14:18:15 +08:00
00908ca0fb Merge branch 'main' into feat/evaluation-fe 2026-05-25 13:36:05 +08:00
2812d61e24 feat(web): snippet card style 2026-05-25 13:35:59 +08:00
8adcac87a5 fix : features get null vector_space 2026-05-25 12:04:29 +08:00
544d8567c9 Merge branch 'feat/evaluation-fe' into deploy/dev 2026-05-25 11:41:33 +08:00
be1d6520f9 fix(web): icon_info nullable in dataset 2026-05-25 11:41:02 +08:00
eeb1cd19bd Merge branch 'feat/evaluation-fe' into deploy/dev 2026-05-25 11:06:03 +08:00
7fb2e4751f Merge branch 'main' into feat/evaluation-fe 2026-05-25 11:05:26 +08:00
e4620b4b22 feat: new snippet (#36597) 2026-05-25 10:25:13 +08:00
8af1766081 Merge remote-tracking branch 'origin/main' 2026-05-25 10:12:24 +08:00
5441992604 Merge branch 'main' into feat/evaluation-fe 2026-05-25 10:00:09 +08:00
9d0597c22d feat(web): snippet layout update 2026-05-23 11:06:57 +08:00
5d489ab92d chore(web): remove snippet plan guard 2026-05-23 10:28:22 +08:00
930da499d1 feat(web): add snippet 2026-05-23 10:20:10 +08:00
f1527ef7c1 feat(web): create snippet from workflow 2026-05-23 10:01:25 +08:00
20f89b6e90 feat(web): operations of snippet card 2026-05-23 09:12:43 +08:00
05e69b104a fix(web): label of snippet creation modal 2026-05-23 08:59:46 +08:00
f39b1b6731 Merge branch 'main' into feat/evaluation-fe 2026-05-23 08:54:10 +08:00
a7005efab3 chore(web): remove icon info of snippet 2026-05-22 18:27:37 +08:00
f605288429 feat(web): snippets header nav 2026-05-22 18:13:27 +08:00
2bb3b439e0 refactor(web): snippet list page 2026-05-22 18:06:29 +08:00
75daf8e61b feat(web): update filters in apps 2026-05-22 17:43:24 +08:00
bf30b11d0d merge main 2026-05-22 16:12:02 +08:00
20e0b329d3 feat: snippet supports tags. 2026-05-22 13:27:46 +08:00
778e472173 merge main 2026-05-22 10:23:30 +08:00
31e2e5d01b feat: snippet supports tags. 2026-05-22 10:20:57 +08:00
8f9e2a895a feat: lite snippet. 2026-05-21 17:23:25 +08:00
b91de7e54b Merge remote-tracking branch 'origin/main' 2026-05-20 18:44:07 +08:00
2885ba8519 Merge branch 'main' into feat/evaluation-fe 2026-05-20 11:32:08 +08:00
e23c3d1491 chore(web): remove evaluation frontend 2026-05-19 17:39:28 +08:00
888292564b merge main 2026-05-19 13:58:23 +08:00
1a0c8f6173 fix(api): avoid committing inside RAG pipeline DSL service. 2026-05-14 16:54:50 +08:00
d8851a4994 fix(api): avoid committing inside RAG pipeline DSL service. 2026-05-14 11:36:02 +08:00
8a21679ea8 Merge remote-tracking branch 'origin/main' 2026-05-14 11:00:27 +08:00
b8a594def0 fix: add unit tests. 2026-05-13 15:18:02 +08:00
69a77ad9ce Merge remote-tracking branch 'origin/fix/issue-36090' 2026-05-13 14:56:33 +08:00
93728bb39f Merge remote-tracking branch 'origin/main' 2026-05-13 14:54:53 +08:00
c4da7a0bed fix: add unit tests. 2026-05-13 14:54:15 +08:00
05fd412670 Merge branch 'main' into fix/issue-36090 2026-05-13 14:34:17 +08:00
a4821288cc fix: When hit-testing, an empty document dict is returned due to DocumentSegment type modification. 2026-05-13 14:33:33 +08:00
fc0a4a6b56 fix: When hit-testing, an empty document dict is returned due to DocumentSegment type modification. 2026-05-13 14:06:27 +08:00
0a3bb67778 Merge remote-tracking branch 'origin/main' 2026-05-13 10:41:30 +08:00
5e9f419154 fix: For core.tools.signature.sign_upload_file function, the file URL is generated for external preview or download only. 2026-05-09 16:25:44 +08:00
6b84383590 Merge remote-tracking branch 'origin/fix/issue-35910' 2026-05-09 10:57:42 +08:00
d7f99d6458 fix: Using CONSOLE_API_URL to generate an image preview URL causes the image preview to fail on the chunks details page of the knowledge base. 2026-05-09 10:55:59 +08:00
6c80ee8f48 Merge branch 'main' into fix/issue-35910 2026-05-09 10:28:37 +08:00
ea71990388 fix: Using CONSOLE_API_URL to generate an image preview URL causes the image preview to fail on the chunks details page of the knowledge base. 2026-05-09 10:09:54 +08:00
36e8677b1a Merge remote-tracking branch 'origin/fix/issue-35910' 2026-05-08 14:24:37 +08:00
5c31a774ea Merge remote-tracking branch 'origin/main' 2026-05-08 14:22:14 +08:00
9e137e12ab fix: Image rendering in the knowledge base failed. 2026-05-08 14:16:25 +08:00
18e2ecd6c5 Merge branch 'main' into fix/issue-35910 2026-05-08 14:02:27 +08:00
8a23126f29 fix: Image rendering in the knowledge base failed. 2026-05-08 13:55:05 +08:00
6c5f6699d2 Merge remote-tracking branch 'origin/main'
# Conflicts:
#	api/tests/unit_tests/controllers/service_api/dataset/test_hit_testing.py
2026-05-08 11:32:47 +08:00
124b786dfb fix(web): fix app list 2026-05-05 21:34:55 +08:00
dd54ca0cab feat(web): fix template 2026-05-05 20:41:53 +08:00
8a72e46ce8 feat(web): use api to fetch template 2026-05-05 20:27:35 +08:00
f00f8e020f Merge branch 'main' into jzh 2026-05-05 19:55:21 +08:00
aa078a854c fix(web): evaluation detail in workflow log 2026-04-30 17:19:15 +08:00
712aae4d98 fix(web): workflow log crash 2026-04-30 17:06:56 +08:00
bacadc4d35 Merge branch 'main' into jzh 2026-04-30 16:51:33 +08:00
b060e81824 fix(web): test history of batch test 2026-04-30 16:51:16 +08:00
b45f83492e Merge branch 'main' into jzh 2026-04-30 15:31:04 +08:00
d1e1a4a8ab fix(web): correct templates 2026-04-30 15:30:37 +08:00
4519847e81 Merge branch 'main' into jzh 2026-04-30 14:19:21 +08:00
3763efbc7c fix(web): selected workflow name fix 2026-04-30 14:18:32 +08:00
552f202ca8 Merge branch 'main' into jzh 2026-04-30 13:00:41 +08:00
dc76f4082f fix(web): i18n of switch modal 2026-04-30 12:53:52 +08:00
6d01095586 fix(web): only workflow can switch 2026-04-30 12:41:37 +08:00
b914e48a41 fix(web): add toast for workflow switch 2026-04-30 12:30:22 +08:00
da482ec455 fix(web): custom metric display 2026-04-30 12:20:56 +08:00
48c38ace54 fix(web): rag publisher 2026-04-30 11:02:23 +08:00
2b1496c857 Merge branch 'main' into jzh 2026-04-30 08:54:42 +08:00
c15e437ff7 Merge branch 'main' into jzh 2026-04-30 00:07:30 +08:00
0ac0eccce4 fix(web): correct template of dataset evaluation 2026-04-30 00:06:59 +08:00
678327e994 fix(wbe): style of switch button 2026-04-29 22:48:38 +08:00
b0478f4df7 fix(web): workflow switch to evaluation 2026-04-29 22:25:47 +08:00
00319f0e43 fix(web): snippet run logs 2026-04-29 21:25:50 +08:00
55eb894d8e fix(web): scroll of dataset evaluation page 2026-04-29 20:48:39 +08:00
c59a80a41f fix(web): test detail of dataset 2026-04-29 20:38:54 +08:00
24b482893d fix(web): pipeline batch test template 2026-04-29 20:36:07 +08:00
ad58895b25 fix(web): template of dataset evaluation template 2026-04-29 20:30:43 +08:00
25fc518c5d fix(web): style of config 2026-04-29 20:28:41 +08:00
d92722e7ab fix(web): default metrics for dataset 2026-04-29 20:16:42 +08:00
4041fd7e5c fix(web): auto select model in evaluation 2026-04-29 20:10:37 +08:00
06ea73a19b Merge branch 'main' into jzh 2026-04-29 18:40:11 +08:00
7384a3c121 fix(web): template generate 2026-04-29 18:05:34 +08:00
c18c953a7c fix(web): style of batch test 2026-04-29 17:46:29 +08:00
ae2df0c35e fix(web): style of batch test 2026-04-29 17:28:53 +08:00
dacc7fc740 fix(web): style of conditions 2026-04-29 17:15:32 +08:00
9af2c1252c fix(web): remove node 2026-04-29 16:17:31 +08:00
35bfe26a3a fix(web): default metrics 2026-04-29 16:07:42 +08:00
8686362aeb Merge branch 'main' into jzh 2026-04-29 16:06:27 +08:00
f5955489ec Merge branch 'main' into jzh 2026-04-29 15:34:27 +08:00
aaa15770d5 fix(web): metric descriptions 2026-04-29 15:34:00 +08:00
08c01c4f3f chore(web): remove mock data of evaluation 2026-04-29 15:28:10 +08:00
0903c30060 fix(web): metric icon & tooltips 2026-04-29 15:10:51 +08:00
b420298398 fix(web): style of default metric 2026-04-29 14:31:23 +08:00
2607eb8d32 feat(web): default metrics 2026-04-29 13:56:33 +08:00
d8173b1cda feat(web): add reset button 2026-04-29 13:32:46 +08:00
c56f1a8216 fix(web): detail info in snippet evaluation 2026-04-29 13:11:09 +08:00
31e74371ef Merge branch 'main' into jzh 2026-04-29 12:06:29 +08:00
e48f13f173 Merge branch 'main' into jzh 2026-04-28 16:34:20 +08:00
c574363cf6 chore(web): remove unused data 2026-04-28 16:34:00 +08:00
70fd4a5c88 feat(web): support snippet nav 2026-04-28 16:26:39 +08:00
e62a67c719 fix: hit-testing response failed because of Pydantic check. 2026-04-28 16:22:27 +08:00
57c1195253 Merge remote-tracking branch 'origin/main' 2026-04-28 16:19:14 +08:00
42889d23e5 fix(web): snippet card 2026-04-28 16:05:26 +08:00
3a7f09a250 fix(web): snippet publish check 2026-04-28 15:42:31 +08:00
d95d4335bf fix(web): snippet can use input fields 2026-04-28 15:30:35 +08:00
735e88f673 fix(web): snippet usage count 2026-04-28 15:03:03 +08:00
c55105bff3 fix(web): snippet add 2026-04-28 14:58:43 +08:00
77afc805e1 fix(web): snippet draft sync 2026-04-28 14:42:51 +08:00
9dd73b4d47 Merge branch 'main' into jzh 2026-04-28 11:02:44 +08:00
f2b12bfef7 Merge branch 'main' into jzh 2026-04-27 15:49:05 +08:00
dbeaf79d77 fix(web): snippet init 2026-04-27 15:45:13 +08:00
63dcb4dd6c refactor(web): remove mock data of snippet detail 2026-04-27 15:01:57 +08:00
9df3a7bcf9 Merge branch 'main' into jzh 2026-04-27 14:47:53 +08:00
89163edd16 Merge branch 'main' into jzh 2026-04-27 13:48:47 +08:00
eaa55aab1e Merge branch 'main' into jzh 2026-04-27 11:45:30 +08:00
8d3a690c0a Merge branch 'main' into jzh 2026-04-24 18:22:52 +08:00
5263a65ed6 Merge branch 'main' into jzh 2026-04-24 17:52:27 +08:00
24d3e8edba Merge branch 'main' into jzh 2026-04-23 16:18:36 +08:00
b371dd2cdf Merge branch 'main' into jzh 2026-04-22 20:44:13 +08:00
597ad8c425 Merge branch 'main' into jzh 2026-04-21 21:01:05 +08:00
33f9d96caa Merge branch 'main' into jzh 2026-04-21 11:06:45 +08:00
689571df22 Merge branch 'main' into jzh 2026-04-20 18:01:17 +08:00
a3242f0634 Merge branch 'main' into jzh 2026-04-20 16:42:16 +08:00
f5112928b3 fix(web): snippet graph view port 2026-04-20 16:39:38 +08:00
bcd87ddc58 fix: publish as evaluation 2026-04-20 16:24:53 +08:00
7c8a87af05 Merge branch 'main' into jzh 2026-04-20 16:11:07 +08:00
8e2d507e5c fix(web): fix selection menu hide 2026-04-20 16:08:58 +08:00
b6fbec066d fix(web): nav icons 2026-04-20 15:47:34 +08:00
bd136cadce fix: evaluation switch button 2026-04-20 14:37:13 +08:00
0a934e1143 fix merge error 2026-04-20 13:46:31 +08:00
c44ba62da3 Merge branch 'main' into jzh 2026-04-20 12:06:18 +08:00
76c0aed05c fix merge error 2026-04-20 11:58:17 +08:00
e7fc22c6b3 Merge branch 'main' into jzh 2026-04-20 10:24:40 +08:00
b91727b804 Merge branch 'main' into jzh 2026-04-20 10:23:09 +08:00
534fd79377 Merge branch 'main' into jzh 2026-04-17 16:46:10 +08:00
3ea4742b29 Merge branch 'main' into jzh 2026-04-17 16:26:19 +08:00
364c0eb6e2 Merge branch 'main' into jzh 2026-04-17 14:57:13 +08:00
322b3ff641 fix(web): fix merge error 2026-04-17 14:49:08 +08:00
38736c154b Merge branch 'main' into jzh 2026-04-17 14:05:23 +08:00
129f681c59 fix(web): slient snippet draft fetching 2026-04-16 12:29:07 +08:00
d776fc0827 fix(web): icon missing 2026-04-16 12:18:48 +08:00
7af6074cb5 Merge branch 'main' into jzh 2026-04-16 12:09:59 +08:00
7aa700bf2b fix(web): fix merge 2026-04-16 12:06:53 +08:00
0d47750b15 Merge branch 'main' into jzh 2026-04-16 11:48:48 +08:00
a9dc57eeef fix(web): input fields form & graph publish 2026-04-16 11:07:26 +08:00
5bfebd371d fix(web): snippet draft sync 2026-04-16 10:31:35 +08:00
f1da2c76d1 fix(web): add page title for snippet 2026-04-15 18:22:20 +08:00
b5dc774093 feat(web): empty list of snippet 2026-04-15 16:53:37 +08:00
b7fe45d800 Merge branch 'main' into jzh 2026-04-15 15:57:48 +08:00
7f5bbe0ee3 fix(web): button import 2026-04-15 15:23:50 +08:00
40632589a2 Merge branch 'main' into jzh 2026-04-15 15:01:33 +08:00
e6e063138e Merge branch 'main' into jzh 2026-04-15 10:20:27 +08:00
605af8d60e Merge branch 'main' into jzh 2026-04-14 14:05:44 +08:00
03660c19ef Merge remote-tracking branch 'origin/main'
# Conflicts:
#	api/providers/vdb/vdb-weaviate/src/dify_vdb_weaviate/weaviate_vector.py
2026-04-14 13:54:31 +08:00
8747e3a2d3 Merge branch 'main' into jzh 2026-04-14 10:14:13 +08:00
7fd549fd39 fix: Compatibility issues with the summary index feature when using the weaviate vector database. 2026-04-13 18:44:53 +08:00
1712a2732a Merge branch 'main' into jzh 2026-04-13 17:22:27 +08:00
46bc76bae3 feat(web): add evaluation cell in workflow log 2026-04-13 17:22:04 +08:00
e24b6c27b0 Merge remote-tracking branch 'origin/main' 2026-04-13 14:29:52 +08:00
8c6dda125f Merge branch 'main' into jzh 2026-04-13 14:13:32 +08:00
f6047aafe8 feat(web): metric descriptions 2026-04-13 14:13:04 +08:00
dce5715982 Merge branch 'main' into jzh 2026-04-13 13:35:13 +08:00
ea910b8e7d Merge branch 'main' into jzh 2026-04-13 10:29:55 +08:00
e51af66d95 feat(web): support creator filtering in apps & snippets 2026-04-12 12:19:17 +08:00
f93b287949 feat(web): billing for evaluation & snippets 2026-04-12 11:09:15 +08:00
627fbd2e86 feat(web): rag-pipeline evaluation configuration 2026-04-12 10:36:20 +08:00
e4c056a57a Merge branch 'main' into jzh 2026-04-12 10:16:13 +08:00
23291398ec feat(web): switch warining dialog 2026-04-10 18:13:13 +08:00
79fc352a5a feat(web): evaluation run detail 2026-04-10 17:48:28 +08:00
8b6b3cddea refactor(web): rag-pipeline evaluation 2026-04-10 17:31:10 +08:00
d1ca468c1e Merge branch 'main' into jzh 2026-04-10 17:12:57 +08:00
ce28ad771c feat(web): test history of rag-pipeline evaluation 2026-04-10 17:12:27 +08:00
ba951b01de feat(web): template download 2026-04-10 17:05:14 +08:00
670ab16ea1 feat(web): template download & upload & run in rag-pipeline 2026-04-10 16:51:06 +08:00
4680535ecd Merge branch 'main' into jzh 2026-04-10 13:49:26 +08:00
f96e63460e feat(web): save configuration 2026-04-10 13:49:05 +08:00
2df79c0404 refactor: input fields 2026-04-10 12:04:51 +08:00
acef9630d5 feat(web): input fields display 2026-04-10 11:39:35 +08:00
12c3b2e0cd feat(web): start run 2026-04-10 11:26:12 +08:00
577707ae50 refactor(web): input fields 2026-04-10 11:10:24 +08:00
03325e9750 feat(web): run history of batch test 2026-04-10 10:58:40 +08:00
a7ef8f9c12 Merge branch 'main' into jzh 2026-04-10 10:36:45 +08:00
40284d9f95 refactor(web): batch test of evaluation 2026-04-09 21:19:51 +08:00
5efe8b8bd7 feat(web): add condition 2026-04-09 21:08:45 +08:00
8dc6d736ee refactor(web): store of evaluation 2026-04-09 20:32:53 +08:00
5316372772 feat(web): judgement condition 2026-04-09 20:18:25 +08:00
4d1499ef75 refactor(web): refactor condition group 2026-04-09 19:46:18 +08:00
0438285277 Merge branch 'main' into jzh 2026-04-09 19:24:10 +08:00
4879ea5cd5 feat(web): support variable selecting in variable mapping 2026-04-09 19:23:22 +08:00
2a1761ac06 feat(web): add output 2026-04-09 18:16:30 +08:00
c29245c1cb feat(web): only one evaluation workflow can be added 2026-04-09 17:43:34 +08:00
5069694bba refactor(web): remove unused metric property 2026-04-09 17:30:10 +08:00
d1a80a85c0 refactor(web): evaluation configure schema update 2026-04-09 17:17:15 +08:00
5c93d74dec Merge branch 'main' into jzh 2026-04-09 15:36:00 +08:00
e52dbd49be feat(web): dataset evaluation configure 2026-04-09 15:34:59 +08:00
ccc8a5f278 refactor(web): dataset evaluation 2026-04-09 14:56:22 +08:00
cfb5b9dfea feat(web): dataset evaluation configure fetch 2026-04-09 14:21:01 +08:00
73d95245f8 feat(web): dataset evaluation layout 2026-04-09 13:44:29 +08:00
fb91984fcb feat(web): add evaluation navigation for rag-pipeline 2026-04-09 13:26:43 +08:00
29cb1fa12e Merge branch 'main' into jzh 2026-04-09 13:15:20 +08:00
78240ed199 Merge branch 'main' into jzh 2026-04-09 09:07:12 +08:00
8f8707fd77 Merge branch 'main' into jzh 2026-04-07 16:57:37 +08:00
ed3db06154 feat(web): restrictions of evalution workflow available nodes 2026-04-07 16:12:25 +08:00
7c05a68876 Merge branch 'main' into jzh 2026-04-07 14:41:42 +08:00
6cfc0dd8e1 Merge branch 'main' into jzh 2026-04-07 12:52:13 +08:00
81baeae5c4 fix(web): evaluation workflow switch 2026-04-03 18:22:44 +08:00
a3010bdc0b Merge branch 'main' into jzh 2026-04-03 18:05:54 +08:00
8133e550ed chore: fix pre-hook of web 2026-04-03 16:21:32 +08:00
2bb0eab636 chore(web): mapping row refactor 2026-04-03 16:10:41 +08:00
5311b5d00d feat(web): available evaluation workflow selector 2026-04-03 16:06:33 +08:00
9b02ccdd12 Merge branch 'main' into jzh 2026-04-03 15:15:11 +08:00
231783eebe chore(web): fix lint 2026-04-03 15:13:52 +08:00
756606f478 feat(web): hide card view in evaluation 2026-04-03 14:39:41 +08:00
6651c1c5da feat(web): workflow switch 2026-04-03 14:22:50 +08:00
61e257b2a8 feat(web): app switch api 2026-04-03 13:56:00 +08:00
3ac4caf735 Merge branch 'main' into jzh 2026-04-03 11:28:22 +08:00
268ae1751d Merge branch 'main' into jzh 2026-04-01 09:26:13 +08:00
015cbf850b Merge branch 'main' into jzh 2026-03-31 18:08:24 +08:00
873e13c2fb feat(web): support select node in metric card 2026-03-31 18:07:52 +08:00
688bf7e7a1 feat(web): metric card style 2026-03-31 17:43:56 +08:00
a6ffff3b39 fix(web): fix style of metric selector 2026-03-31 17:22:07 +08:00
023fc55bd5 fix(web): empty state of metric 2026-03-31 17:11:44 +08:00
351b909a53 feat(web): metric card 2026-03-31 17:00:37 +08:00
6bec4f65c9 refactor(web): metric section refactor 2026-03-31 16:28:48 +08:00
74f87ce152 Merge branch 'main' into jzh 2026-03-31 16:13:04 +08:00
92c472ccc7 Merge branch 'main' into jzh 2026-03-30 15:40:23 +08:00
b92b8becd1 feat(web): metric selector 2026-03-30 15:39:52 +08:00
23d0d6a65d chore(web): i18n of metrics 2026-03-30 14:20:43 +08:00
1660067d6e feat(web): judgement model selector 2026-03-30 14:03:37 +08:00
0642475b85 Merge branch 'main' into jzh 2026-03-30 13:30:10 +08:00
8cb634c9bc feat(web): evaluation layout 2026-03-30 11:27:06 +08:00
768b41c3cf Merge branch 'main' into jzh 2026-03-30 11:07:42 +08:00
ca88516d54 refactor(web): refactor evaluation page 2026-03-30 11:06:41 +08:00
871a2a149f refactor(web): split snippet index 2026-03-30 10:32:59 +08:00
60e381eff0 Merge branch 'main' into jzh 2026-03-30 09:48:58 +08:00
768b3eb6f9 feat(web): test run of snippet 2026-03-29 20:55:11 +08:00
2f88da4a6d feat(web): add variable inspect for snippet 2026-03-29 20:23:24 +08:00
a8cdf6964c feat(web): test run button 2026-03-29 20:02:59 +08:00
985c3db4fd feat(web): snippet input field panel layout 2026-03-29 18:02:27 +08:00
9636472db7 refactor(web): snippet main 2026-03-29 17:50:30 +08:00
0ad268aa7d feat(web): snippet publish 2026-03-29 17:29:37 +08:00
a4ea33167d feat(web): block selector in snippet 2026-03-29 17:01:32 +08:00
0f13aabea8 feat(web): input fields in snippet 2026-03-29 16:31:38 +08:00
1e76ef5ccb chore(web): ignore system vars & conversation vars in rag-pipeline and snippet 2026-03-29 15:56:24 +08:00
e6e3229d17 feat(web): input field button style 2026-03-29 15:45:05 +08:00
dccf8e723a feat(web): snippet version panel 2026-03-29 15:26:59 +08:00
c41ba7d627 feat(web): snippet header in graph 2026-03-29 15:02:34 +08:00
a6e9316de3 Merge branch 'main' into jzh 2026-03-29 14:07:49 +08:00
559d326cbd chore(web): mock data of snippet 2026-03-27 17:24:01 +08:00
abedf2506f Merge branch 'main' into jzh 2026-03-27 17:01:27 +08:00
d01428b5bc feat(web): snippet graph draft sync 2026-03-27 16:02:47 +08:00
0de1f17e5c Merge branch 'main' into jzh 2026-03-27 15:23:49 +08:00
17d07a5a43 feat(web): init snippet graph 2026-03-27 15:23:03 +08:00
3bdbea99a3 Merge branch 'main' into jzh 2026-03-27 14:04:10 +08:00
b7683aedb1 Merge branch 'main' into jzh 2026-03-26 21:38:48 +08:00
515036e758 test(web): add tests for snippets 2026-03-26 21:38:22 +08:00
22b382527f feat(web): add snippet to workflow 2026-03-26 21:26:29 +08:00
2cfe4b5b86 feat(web): snippet graph data fetching 2026-03-26 21:11:09 +08:00
6876c8041c feat(web): snippet list data fetching in block selector 2026-03-26 20:58:42 +08:00
7de45584ce refactor: snippets list 2026-03-26 20:41:51 +08:00
5572d7c7e8 Merge branch 'main' into jzh 2026-03-26 20:10:47 +08:00
db0a2fe52e Merge branch 'main' into jzh 2026-03-26 16:29:44 +08:00
f0ae8d6167 fix(web): unused imports caused by merge 2026-03-26 16:28:56 +08:00
2514e181ba Merge branch 'main' into jzh 2026-03-26 16:16:10 +08:00
be2e6e9a14 Merge branch 'main' into jzh 2026-03-26 14:23:29 +08:00
875e2eac1b Merge branch 'main' into jzh 2026-03-26 08:38:57 +08:00
c3c73ceb1f Merge branch 'main' into jzh 2026-03-25 23:02:18 +08:00
6318bf0a2a feat(web): create snippet from workflow 2026-03-25 22:57:48 +08:00
5e1f252046 feat(web): selection context menu style update 2026-03-25 22:36:27 +08:00
df3b960505 fix(web): position of selection context menu in workflow graph 2026-03-25 22:02:50 +08:00
26bc108bf1 chore(web): tests for snippet info 2026-03-25 21:35:36 +08:00
a5cff32743 feat(web): snippet info operations 2026-03-25 21:29:06 +08:00
d418dd8eec Merge branch 'main' into jzh 2026-03-25 20:17:32 +08:00
61702fe346 Merge branch 'main' into jzh 2026-03-25 18:17:03 +08:00
43f0c780c3 Merge branch 'main' into jzh 2026-03-25 15:30:21 +08:00
30ebf2bfa9 Merge branch 'main' into jzh 2026-03-24 07:25:22 +08:00
7e3027b5f7 feat(web): snippet card usage info 2026-03-23 17:02:00 +08:00
b3acf83090 Merge branch 'main' into jzh 2026-03-23 16:46:26 +08:00
36c3d6e48a feat(web): snippet list fetching & display 2026-03-23 16:37:05 +08:00
f782ac6b3c feat(web): create snippets by DSL import 2026-03-23 14:55:36 +08:00
feef2dd1fa feat(web): add snippet creation dialog flow 2026-03-23 11:29:41 +08:00
a716d8789d refactor: extract snippet list components 2026-03-23 10:48:15 +08:00
6816f89189 Merge branch 'main' into jzh 2026-03-23 10:13:45 +08:00
bfcac64a9d Merge branch 'main' into jzh 2026-03-20 15:33:49 +08:00
664eb601a2 feat(web): add api of snippet worfklows 2026-03-20 15:29:53 +08:00
8e5cc4e0aa feat(web): add evaluation api 2026-03-20 15:23:03 +08:00
9f28575903 feat(web): add snippets api 2026-03-20 15:11:33 +08:00
4b9a26a5e6 Merge branch 'main' into jzh 2026-03-20 14:01:34 +08:00
7b85adf1cc Merge branch 'main' into jzh 2026-03-20 10:46:45 +08:00
917d362a58 fix: Querying document list based on hit_count caused slow SQL. 2026-03-19 18:08:00 +08:00
3c27a90eb9 Merge remote-tracking branch 'origin/main'
# Conflicts:
#	api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py
2026-03-19 18:04:33 +08:00
c964708ebe Merge branch 'main' into jzh 2026-03-18 18:07:20 +08:00
883eb498c0 Merge branch 'main' into jzh 2026-03-18 17:40:51 +08:00
b85af2ec47 fix: When can not obtain pipeline template detail failed from upstream service including remote template service and database, return responding error message. 2026-03-18 11:20:50 +08:00
2f0f97aa66 Potential fix for pull request finding
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
2026-03-18 11:09:28 +08:00
a6e03c6735 Potential fix for pull request finding
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
2026-03-18 11:07:00 +08:00
e7cbfb89d6 Potential fix for pull request finding
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
2026-03-18 11:05:39 +08:00
6c2decfbfb Apply suggestion from @gemini-code-assist[bot]
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-03-18 11:00:47 +08:00
b33d6d1d4a Merge branch 'main' into fix/issue-33624 2026-03-18 10:59:28 +08:00
1b32e70dc5 fix: When can not obtain pipeline template detail failed from upstream service including remote template service and database, return responding error message. 2026-03-18 10:51:07 +08:00
4d3738d225 Merge branch 'main' into feat/evaluation-fe 2026-03-17 10:42:44 +08:00
b5e90e77aa Merge remote-tracking branch 'origin/main' 2026-03-17 10:33:00 +08:00
dd0dee739d Merge branch 'main' into jzh 2026-03-16 15:43:20 +08:00
4d19914fcb Merge branch 'main' into feat/evaluation-fe 2026-03-16 10:47:37 +08:00
887c7710e9 feat: evaluation 2026-03-16 10:46:33 +08:00
7a722773c7 feat: snippet canvas 2026-03-13 17:45:04 +08:00
a763aff58b feat: snippets list 2026-03-13 16:12:42 +08:00
c1011f4e5c feat: add to snippet 2026-03-13 14:29:59 +08:00
f7afa103a5 feat: select snippets 2026-03-13 13:43:29 +08:00
d0bd5b473b Merge remote-tracking branch 'origin/main' 2026-03-03 15:37:42 +08:00
08b28b4029 fix: Add the validation of doc_form in the Document-related service APIs. 2026-03-03 14:59:51 +08:00
269bf883c2 fix: Add the validation of doc_form in the Document-related service APIs. 2026-03-03 14:31:51 +08:00
1265 changed files with 29173 additions and 55745 deletions

View File

@ -0,0 +1 @@
../../.agents/skills/frontend-query-mutation

View File

@ -1 +0,0 @@
../../.agents/skills/how-to-write-component

View File

@ -51,15 +51,6 @@ jobs:
with:
files: |
api/**
- name: Check dify-agent inputs
if: github.event_name != 'merge_group'
id: dify-agent-changes
uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
with:
files: |
dify-agent/**/*.py
dify-agent/pyproject.toml
dify-agent/uv.lock
- if: github.event_name != 'merge_group'
uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0
with:
@ -85,17 +76,6 @@ jobs:
# Format code
uv run ruff format ..
- if: github.event_name != 'merge_group' && steps.dify-agent-changes.outputs.any_changed == 'true'
run: |
cd dify-agent
uv sync --dev
# fmt first to avoid line too long
uv run ruff format .
# Fix lint errors
uv run ruff check --fix .
# Format code
uv run ruff format .
- name: count migration progress
if: github.event_name != 'merge_group' && steps.api-changes.outputs.any_changed == 'true'
run: |

View File

@ -1,415 +0,0 @@
name: CLI E2E Tests
on:
workflow_dispatch:
inputs:
cli_ref:
description: "Git ref (default: current branch)"
type: string
required: false
edition:
description: "Dify edition"
type: choice
required: false
default: ee
options: [ee, ce]
test_scope:
description: "smoke = [P0] only / full = all cases"
type: choice
required: false
default: full
options: [smoke, full]
# ── Suite on/off ────────────────────────────────────────────────────────
suite_framework_output_error:
description: "framework + output + error-handling suites"
type: boolean
default: true
suite_discovery:
description: "discovery suite (get app / describe app)"
type: boolean
default: true
suite_run:
description: "run suite (basic / streaming / conversation / file / hitl)"
type: boolean
default: true
suite_auth:
description: "auth suite (login / status / whoami / use / devices / logout)"
type: boolean
default: true
suite_agent:
description: "agent suite"
type: boolean
default: true
permissions:
contents: read
# ── Shared env injected into every E2E job ───────────────────────────────────
# Each job reads DIFY_E2E_TOKEN + app IDs from the provision job outputs,
# so global-setup skips minting and finds existing apps in < 10 s.
env:
DIFY_E2E_NO_KEYRING: "1" # Linux CI has no keychain; skip probe
VITEST_RETRY: "2" # Retry flaky staging responses
jobs:
# ════════════════════════════════════════════════════════════════════════════
# 0. PROVISION — mint token + import DSL fixtures (runs once, outputs IDs)
# ════════════════════════════════════════════════════════════════════════════
provision:
name: "Provision: mint token + DSL apps"
runs-on: ubuntu-latest
timeout-minutes: 10
outputs:
token: ${{ steps.out.outputs.DIFY_E2E_TOKEN }}
workspace_id: ${{ steps.out.outputs.DIFY_E2E_WORKSPACE_ID }}
workspace_name: ${{ steps.out.outputs.DIFY_E2E_WORKSPACE_NAME }}
ws2_id: ${{ steps.out.outputs.DIFY_E2E_WS2_ID }}
chat_app_id: ${{ steps.out.outputs.DIFY_E2E_CHAT_APP_ID }}
workflow_app_id: ${{ steps.out.outputs.DIFY_E2E_WORKFLOW_APP_ID }}
file_app_id: ${{ steps.out.outputs.DIFY_E2E_FILE_APP_ID }}
file_chat_app_id: ${{ steps.out.outputs.DIFY_E2E_FILE_CHAT_APP_ID }}
hitl_app_id: ${{ steps.out.outputs.DIFY_E2E_HITL_APP_ID }}
hitl_external_app_id: ${{ steps.out.outputs.DIFY_E2E_HITL_EXTERNAL_APP_ID }}
hitl_single_action_app_id: ${{ steps.out.outputs.DIFY_E2E_HITL_SINGLE_ACTION_APP_ID }}
hitl_multi_node_app_id: ${{ steps.out.outputs.DIFY_E2E_HITL_MULTI_NODE_APP_ID }}
ws2_app_id: ${{ steps.out.outputs.DIFY_E2E_WS2_APP_ID }}
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v4
with:
ref: ${{ inputs.cli_ref || github.ref }}
persist-credentials: false
- uses: oven-sh/setup-bun@v2
with:
bun-version: latest
- uses: pnpm/action-setup@b906affcce14559ad1aafd4ab0e942779e9f58b1 # v4
with:
package_json_field: packageManager
run_install: false
- name: Install CLI dependencies
working-directory: cli
run: pnpm install --frozen-lockfile
- name: Mint token & provision apps
id: out
working-directory: cli
env:
DIFY_E2E_HOST: ${{ secrets.DIFY_E2E_HOST }}
DIFY_E2E_EMAIL: ${{ secrets.DIFY_E2E_EMAIL }}
DIFY_E2E_PASSWORD: ${{ secrets.DIFY_E2E_PASSWORD }}
DIFY_E2E_TOKEN: ${{ secrets.DIFY_E2E_TOKEN }}
DIFY_E2E_EDITION: ${{ inputs.edition || 'ee' }}
run: bun scripts/e2e-provision.ts
# ════════════════════════════════════════════════════════════════════════════
# 1-B. framework + output + error-handling (parallel with run/discovery)
# ════════════════════════════════════════════════════════════════════════════
suite-framework-output-error:
name: "Suite: framework + output + error-handling"
if: ${{ inputs.suite_framework_output_error != 'false' }}
needs: provision
runs-on: ubuntu-latest
timeout-minutes: 20
defaults:
run:
working-directory: cli
shell: bash
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v4
with:
ref: ${{ inputs.cli_ref || github.ref }}
persist-credentials: false
- uses: ./.github/actions/setup-web
- uses: oven-sh/setup-bun@v2
with: { bun-version: latest }
- uses: pnpm/action-setup@b906affcce14559ad1aafd4ab0e942779e9f58b1 # v4
with: { package_json_field: packageManager, run_install: false }
- run: pnpm install --frozen-lockfile
- run: pnpm tree:gen
- name: Run framework + output + error-handling
env:
DIFY_E2E_HOST: ${{ secrets.DIFY_E2E_HOST }}
DIFY_E2E_EMAIL: ${{ secrets.DIFY_E2E_EMAIL }}
DIFY_E2E_PASSWORD: ${{ secrets.DIFY_E2E_PASSWORD }}
DIFY_E2E_EDITION: ${{ inputs.edition || 'ee' }}
DIFY_E2E_TOKEN: ${{ needs.provision.outputs.token }}
DIFY_E2E_WORKSPACE_ID: ${{ needs.provision.outputs.workspace_id }}
DIFY_E2E_WORKSPACE_NAME: ${{ needs.provision.outputs.workspace_name }}
DIFY_E2E_CHAT_APP_ID: ${{ needs.provision.outputs.chat_app_id }}
DIFY_E2E_WORKFLOW_APP_ID: ${{ needs.provision.outputs.workflow_app_id }}
DIFY_E2E_INCLUDE: "test/e2e/suites/framework/**/*.e2e.ts,test/e2e/suites/output/**/*.e2e.ts,test/e2e/suites/error-handling/**/*.e2e.ts"
run: |
if [ "${{ inputs.test_scope }}" = "smoke" ]; then
pnpm test:e2e -- -t "\[P0\]"
else
pnpm test:e2e
fi
# ════════════════════════════════════════════════════════════════════════════
# 1-C. Discovery (parallel)
# ════════════════════════════════════════════════════════════════════════════
suite-discovery:
name: "Suite: discovery"
if: ${{ inputs.suite_discovery != 'false' }}
needs: provision
runs-on: ubuntu-latest
timeout-minutes: 20
defaults:
run:
working-directory: cli
shell: bash
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v4
with:
ref: ${{ inputs.cli_ref || github.ref }}
persist-credentials: false
- uses: ./.github/actions/setup-web
- uses: oven-sh/setup-bun@v2
with: { bun-version: latest }
- uses: pnpm/action-setup@b906affcce14559ad1aafd4ab0e942779e9f58b1 # v4
with: { package_json_field: packageManager, run_install: false }
- run: pnpm install --frozen-lockfile
- run: pnpm tree:gen
- name: Run discovery suite
env:
DIFY_E2E_HOST: ${{ secrets.DIFY_E2E_HOST }}
DIFY_E2E_EMAIL: ${{ secrets.DIFY_E2E_EMAIL }}
DIFY_E2E_PASSWORD: ${{ secrets.DIFY_E2E_PASSWORD }}
DIFY_E2E_EDITION: ${{ inputs.edition || 'ee' }}
DIFY_E2E_TOKEN: ${{ needs.provision.outputs.token }}
DIFY_E2E_WORKSPACE_ID: ${{ needs.provision.outputs.workspace_id }}
DIFY_E2E_WORKSPACE_NAME: ${{ needs.provision.outputs.workspace_name }}
DIFY_E2E_WS2_ID: ${{ needs.provision.outputs.ws2_id }}
DIFY_E2E_CHAT_APP_ID: ${{ needs.provision.outputs.chat_app_id }}
DIFY_E2E_WORKFLOW_APP_ID: ${{ needs.provision.outputs.workflow_app_id }}
DIFY_E2E_INCLUDE: "test/e2e/suites/discovery/**/*.e2e.ts"
run: |
if [ "${{ inputs.test_scope }}" = "smoke" ]; then
pnpm test:e2e -- -t "\[P0\]"
else
pnpm test:e2e
fi
# ════════════════════════════════════════════════════════════════════════════
# 1-D. Run suite — 5 files in matrix (parallel)
# ════════════════════════════════════════════════════════════════════════════
suite-run:
name: "Suite: run / ${{ matrix.name }}"
if: ${{ inputs.suite_run != 'false' }}
needs: provision
runs-on: ubuntu-latest
timeout-minutes: 20
strategy:
fail-fast: false
matrix:
include:
- name: basic
file: run-app-basic.e2e.ts
- name: streaming
file: run-app-streaming.e2e.ts
- name: conversation
file: run-app-conversation.e2e.ts
- name: file
file: run-app-file.e2e.ts
- name: hitl
file: run-app-hitl.e2e.ts
defaults:
run:
working-directory: cli
shell: bash
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v4
with:
ref: ${{ inputs.cli_ref || github.ref }}
persist-credentials: false
- uses: ./.github/actions/setup-web
- uses: oven-sh/setup-bun@v2
with: { bun-version: latest }
- uses: pnpm/action-setup@b906affcce14559ad1aafd4ab0e942779e9f58b1 # v4
with: { package_json_field: packageManager, run_install: false }
- run: pnpm install --frozen-lockfile
- run: pnpm tree:gen
- name: "Run run/${{ matrix.name }}"
env:
DIFY_E2E_HOST: ${{ secrets.DIFY_E2E_HOST }}
DIFY_E2E_EMAIL: ${{ secrets.DIFY_E2E_EMAIL }}
DIFY_E2E_PASSWORD: ${{ secrets.DIFY_E2E_PASSWORD }}
DIFY_E2E_EDITION: ${{ inputs.edition || 'ee' }}
DIFY_E2E_SSO_TOKEN: ${{ secrets.DIFY_E2E_SSO_TOKEN }}
DIFY_E2E_TOKEN: ${{ needs.provision.outputs.token }}
DIFY_E2E_WORKSPACE_ID: ${{ needs.provision.outputs.workspace_id }}
DIFY_E2E_WORKSPACE_NAME: ${{ needs.provision.outputs.workspace_name }}
DIFY_E2E_CHAT_APP_ID: ${{ needs.provision.outputs.chat_app_id }}
DIFY_E2E_WORKFLOW_APP_ID: ${{ needs.provision.outputs.workflow_app_id }}
DIFY_E2E_FILE_APP_ID: ${{ needs.provision.outputs.file_app_id }}
DIFY_E2E_FILE_CHAT_APP_ID: ${{ needs.provision.outputs.file_chat_app_id }}
DIFY_E2E_HITL_APP_ID: ${{ needs.provision.outputs.hitl_app_id }}
DIFY_E2E_HITL_EXTERNAL_APP_ID: ${{ needs.provision.outputs.hitl_external_app_id }}
DIFY_E2E_HITL_SINGLE_ACTION_APP_ID: ${{ needs.provision.outputs.hitl_single_action_app_id }}
DIFY_E2E_HITL_MULTI_NODE_APP_ID: ${{ needs.provision.outputs.hitl_multi_node_app_id }}
DIFY_E2E_INCLUDE: "test/e2e/suites/run/${{ matrix.file }}"
run: |
if [ "${{ inputs.test_scope }}" = "smoke" ]; then
pnpm test:e2e -- -t "\[P0\]"
else
pnpm test:e2e
fi
- name: Upload results on failure
if: failure()
uses: actions/upload-artifact@v4
with:
name: e2e-run-${{ matrix.name }}-${{ github.run_id }}
path: cli/test-results/
retention-days: 3
# ════════════════════════════════════════════════════════════════════════════
# 1-E. auth/login + status + whoami (parallel, read-only, safe)
# ════════════════════════════════════════════════════════════════════════════
suite-auth-safe:
name: "Suite: auth (login / status / whoami)"
if: ${{ inputs.suite_auth != 'false' }}
needs: provision
runs-on: ubuntu-latest
timeout-minutes: 15
defaults:
run:
working-directory: cli
shell: bash
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v4
with:
ref: ${{ inputs.cli_ref || github.ref }}
persist-credentials: false
- uses: ./.github/actions/setup-web
- uses: oven-sh/setup-bun@v2
with: { bun-version: latest }
- uses: pnpm/action-setup@b906affcce14559ad1aafd4ab0e942779e9f58b1 # v4
with: { package_json_field: packageManager, run_install: false }
- run: pnpm install --frozen-lockfile
- run: pnpm tree:gen
- name: Run auth/login + status + whoami
env:
DIFY_E2E_HOST: ${{ secrets.DIFY_E2E_HOST }}
DIFY_E2E_EMAIL: ${{ secrets.DIFY_E2E_EMAIL }}
DIFY_E2E_PASSWORD: ${{ secrets.DIFY_E2E_PASSWORD }}
DIFY_E2E_EDITION: ${{ inputs.edition || 'ee' }}
DIFY_E2E_TOKEN: ${{ needs.provision.outputs.token }}
DIFY_E2E_WORKSPACE_ID: ${{ needs.provision.outputs.workspace_id }}
DIFY_E2E_WORKSPACE_NAME: ${{ needs.provision.outputs.workspace_name }}
DIFY_E2E_WS2_ID: ${{ needs.provision.outputs.ws2_id }}
DIFY_E2E_INCLUDE: "test/e2e/suites/auth/login.e2e.ts,test/e2e/suites/auth/status.e2e.ts,test/e2e/suites/auth/whoami.e2e.ts"
run: |
if [ "${{ inputs.test_scope }}" = "smoke" ]; then
pnpm test:e2e -- -t "\[P0\]"
else
pnpm test:e2e
fi
# ════════════════════════════════════════════════════════════════════════════
# 2. DESTRUCTIVE — auth/use + devices + logout + agent (serial, runs LAST)
# Must wait for ALL parallel suites to finish to avoid token revocation
# invalidating other in-flight requests.
# ════════════════════════════════════════════════════════════════════════════
suite-last:
name: "Suite: auth-use + devices + logout + agent (last, serial)"
# Runs when auth is selected; also runs after all parallel jobs finish
if: ${{ inputs.suite_auth != 'false' || inputs.suite_agent != 'false' }}
needs:
- provision
- suite-framework-output-error
- suite-discovery
- suite-run
- suite-auth-safe
# `needs` on a skipped job is treated as success — safe to proceed even if
# some suites were disabled via toggle.
runs-on: ubuntu-latest
timeout-minutes: 25
defaults:
run:
working-directory: cli
shell: bash
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v4
with:
ref: ${{ inputs.cli_ref || github.ref }}
persist-credentials: false
- uses: ./.github/actions/setup-web
- uses: oven-sh/setup-bun@v2
with: { bun-version: latest }
- uses: pnpm/action-setup@b906affcce14559ad1aafd4ab0e942779e9f58b1 # v4
with: { package_json_field: packageManager, run_install: false }
- run: pnpm install --frozen-lockfile
- run: pnpm tree:gen
- name: Run use / devices / logout / agent (serial)
env:
DIFY_E2E_HOST: ${{ secrets.DIFY_E2E_HOST }}
DIFY_E2E_EMAIL: ${{ secrets.DIFY_E2E_EMAIL }}
DIFY_E2E_PASSWORD: ${{ secrets.DIFY_E2E_PASSWORD }}
DIFY_E2E_EDITION: ${{ inputs.edition || 'ee' }}
DIFY_E2E_TOKEN: ${{ needs.provision.outputs.token }}
DIFY_E2E_WORKSPACE_ID: ${{ needs.provision.outputs.workspace_id }}
DIFY_E2E_WORKSPACE_NAME: ${{ needs.provision.outputs.workspace_name }}
DIFY_E2E_WS2_ID: ${{ needs.provision.outputs.ws2_id }}
DIFY_E2E_CHAT_APP_ID: ${{ needs.provision.outputs.chat_app_id }}
DIFY_E2E_WORKFLOW_APP_ID: ${{ needs.provision.outputs.workflow_app_id }}
DIFY_E2E_HITL_APP_ID: ${{ needs.provision.outputs.hitl_app_id }}
DIFY_E2E_HITL_EXTERNAL_APP_ID: ${{ needs.provision.outputs.hitl_external_app_id }}
DIFY_E2E_HITL_SINGLE_ACTION_APP_ID: ${{ needs.provision.outputs.hitl_single_action_app_id }}
DIFY_E2E_HITL_MULTI_NODE_APP_ID: ${{ needs.provision.outputs.hitl_multi_node_app_id }}
run: |
# Collect files in safe order: use → devices → logout (revokes last) → agent
FILES=()
if [ "${{ inputs.suite_auth }}" = "true" ]; then
FILES+=(
test/e2e/suites/auth/use.e2e.ts
test/e2e/suites/auth/devices.e2e.ts
test/e2e/suites/auth/logout.e2e.ts
)
fi
if [ "${{ inputs.suite_agent }}" = "true" ]; then
while IFS= read -r f; do FILES+=("$f"); done \
< <(find test/e2e/suites/agent -name '*.e2e.ts' | sort)
fi
[ ${#FILES[@]} -eq 0 ] && { echo "Nothing to run."; exit 0; }
# Pass files via DIFY_E2E_INCLUDE (comma-separated) so vitest
# config's include list is overridden instead of ANDed.
INCLUDE=$(IFS=,; echo "${FILES[*]}")
if [ "${{ inputs.test_scope }}" = "smoke" ]; then
DIFY_E2E_INCLUDE="$INCLUDE" pnpm test:e2e -- -t "\[P0\]"
else
DIFY_E2E_INCLUDE="$INCLUDE" pnpm test:e2e
fi
- name: Upload results on failure
if: failure()
uses: actions/upload-artifact@v4
with:
name: e2e-last-${{ github.run_id }}
path: cli/test-results/
retention-days: 3

3
.gitignore vendored
View File

@ -259,6 +259,3 @@ scripts/stress-test/reports/
.qoder/*
.context/
.eslintcache
# Vitest local reports
web/.vitest-reports/

View File

@ -17,7 +17,7 @@ FROM base AS packages
RUN apt-get update \
&& apt-get install -y --no-install-recommends \
# basic environment
git g++ \
g++ \
# for building gmpy2
libmpfr-dev libmpc-dev
@ -97,6 +97,7 @@ RUN \
# Copy Python environment and packages
ENV VIRTUAL_ENV=/app/api/.venv
COPY --from=packages --chown=dify:dify ${VIRTUAL_ENV} ${VIRTUAL_ENV}
COPY --from=packages --chown=dify:dify /app/dify-agent /app/dify-agent
ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
# Download nltk data

View File

@ -34,7 +34,6 @@ from clients.agent_backend.request_builder import (
DIFY_PLUGIN_TOOLS_LAYER_ID,
WORKFLOW_NODE_JOB_PROMPT_LAYER_ID,
WORKFLOW_USER_PROMPT_LAYER_ID,
AgentBackendAgentAppRunInput,
AgentBackendModelConfig,
AgentBackendOutputConfig,
AgentBackendRunRequestBuilder,
@ -50,7 +49,6 @@ __all__ = [
"DIFY_PLUGIN_TOOLS_LAYER_ID",
"WORKFLOW_NODE_JOB_PROMPT_LAYER_ID",
"WORKFLOW_USER_PROMPT_LAYER_ID",
"AgentBackendAgentAppRunInput",
"AgentBackendError",
"AgentBackendHTTPError",
"AgentBackendInternalEvent",

View File

@ -30,7 +30,6 @@ from dify_agent.layers.execution_context import (
DifyExecutionContextLayerConfig,
)
from dify_agent.layers.output import DIFY_OUTPUT_LAYER_TYPE_ID, DifyOutputLayerConfig
from dify_agent.layers.shell import DIFY_SHELL_LAYER_TYPE_ID, DifyShellLayerConfig
from dify_agent.protocol import (
DIFY_AGENT_HISTORY_LAYER_ID,
DIFY_AGENT_MODEL_LAYER_ID,
@ -46,10 +45,8 @@ from pydantic import BaseModel, ConfigDict, Field, JsonValue, field_validator
AGENT_SOUL_PROMPT_LAYER_ID = "agent_soul_prompt"
WORKFLOW_NODE_JOB_PROMPT_LAYER_ID = "workflow_node_job_prompt"
WORKFLOW_USER_PROMPT_LAYER_ID = "workflow_user_prompt"
AGENT_APP_USER_PROMPT_LAYER_ID = "agent_app_user_prompt"
DIFY_EXECUTION_CONTEXT_LAYER_ID = "execution_context"
DIFY_PLUGIN_TOOLS_LAYER_ID = "tools"
DIFY_SHELL_LAYER_ID = "shell"
# Layer types that hold credentials in their per-run config. These are excluded
# from the cleanup-replay composition (and from the snapshot that is sent with
@ -169,10 +166,6 @@ class AgentBackendWorkflowNodeRunInput(BaseModel):
idempotency_key: str | None = None
output: AgentBackendOutputConfig | None = None
tools: DifyPluginToolsLayerConfig | None = None
# Inject the sandboxed shell layer (dify.shell). Requires the agent backend
# to be wired with a shellctl entrypoint; see configs AGENT_SHELL_ENABLED.
include_shell: bool = False
shell_config: DifyShellLayerConfig | None = None
session_snapshot: CompositorSessionSnapshot | None = None
include_history: bool = True
suspend_on_exit: bool = True
@ -188,154 +181,9 @@ class AgentBackendWorkflowNodeRunInput(BaseModel):
return value
class AgentBackendAgentAppRunInput(BaseModel):
"""Inputs to build one Agent App conversation-turn run request.
Unlike the workflow-node input there is no workflow-node-job prompt and no
previous-node context: the user prompt is the chat message, and multi-turn
continuity comes from ``session_snapshot`` + the history layer keyed by the
conversation.
"""
model: AgentBackendModelConfig
execution_context: DifyExecutionContextLayerConfig
user_prompt: str
agent_soul_prompt: str | None = None
purpose: RunPurpose = "agent_app"
idempotency_key: str | None = None
output: AgentBackendOutputConfig | None = None
tools: DifyPluginToolsLayerConfig | None = None
# Inject the sandboxed shell layer (dify.shell). Requires the agent backend
# to be wired with a shellctl entrypoint; see configs AGENT_SHELL_ENABLED.
include_shell: bool = False
shell_config: DifyShellLayerConfig | None = None
session_snapshot: CompositorSessionSnapshot | None = None
include_history: bool = True
suspend_on_exit: bool = True
metadata: dict[str, JsonValue] = Field(default_factory=dict)
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
@field_validator("user_prompt")
@classmethod
def _reject_blank_prompt(cls, value: str) -> str:
if not value.strip():
raise ValueError("prompt must not be blank")
return value
class AgentBackendRunRequestBuilder:
"""Converts API product state into the public ``dify-agent`` run protocol."""
def build_for_agent_app(self, run_input: AgentBackendAgentAppRunInput) -> CreateRunRequest:
"""Build an Agent App conversation-turn run request.
Layer graph: optional Agent Soul system prompt → user prompt →
execution context → optional history (multi-turn) → LLM → optional
plugin tools → optional structured output. Mirrors the workflow-node
layer ordering minus the workflow-job / previous-node prompt.
"""
layers: list[RunLayerSpec] = []
if run_input.agent_soul_prompt:
layers.append(
RunLayerSpec(
name=AGENT_SOUL_PROMPT_LAYER_ID,
type=PLAIN_PROMPT_LAYER_TYPE_ID,
metadata={**run_input.metadata, "origin": "agent_soul"},
config=PromptLayerConfig(prefix=run_input.agent_soul_prompt),
)
)
layers.extend(
[
RunLayerSpec(
name=AGENT_APP_USER_PROMPT_LAYER_ID,
type=PLAIN_PROMPT_LAYER_TYPE_ID,
metadata={**run_input.metadata, "origin": "agent_app_user_prompt"},
config=PromptLayerConfig(user=run_input.user_prompt),
),
RunLayerSpec(
name=DIFY_EXECUTION_CONTEXT_LAYER_ID,
type=DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID,
metadata=run_input.metadata,
config=run_input.execution_context,
),
]
)
if run_input.include_history:
layers.append(
RunLayerSpec(
name=DIFY_AGENT_HISTORY_LAYER_ID,
type=PYDANTIC_AI_HISTORY_LAYER_TYPE_ID,
metadata={**run_input.metadata, "origin": "agent_session_history"},
)
)
layers.append(
RunLayerSpec(
name=DIFY_AGENT_MODEL_LAYER_ID,
type=DIFY_PLUGIN_LLM_LAYER_TYPE_ID,
deps={"execution_context": DIFY_EXECUTION_CONTEXT_LAYER_ID},
metadata=run_input.metadata,
config=DifyPluginLLMLayerConfig(
plugin_id=run_input.model.plugin_id,
model_provider=run_input.model.model_provider,
model=run_input.model.model,
credentials=run_input.model.credentials,
model_settings=run_input.model.model_settings or None,
),
)
)
if run_input.tools is not None and run_input.tools.tools:
layers.append(
RunLayerSpec(
name=DIFY_PLUGIN_TOOLS_LAYER_ID,
type=DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID,
deps={"execution_context": DIFY_EXECUTION_CONTEXT_LAYER_ID},
metadata=run_input.metadata,
config=run_input.tools,
)
)
if run_input.include_shell:
# Sandboxed bash workspace (dify.shell). The layer declares NoLayerDeps,
# so the spec carries no deps; shellctl connection is server-injected.
layers.append(
RunLayerSpec(
name=DIFY_SHELL_LAYER_ID,
type=DIFY_SHELL_LAYER_TYPE_ID,
metadata=run_input.metadata,
config=run_input.shell_config or DifyShellLayerConfig(),
)
)
if run_input.output is not None:
layers.append(
RunLayerSpec(
name=DIFY_AGENT_OUTPUT_LAYER_ID,
type=DIFY_OUTPUT_LAYER_TYPE_ID,
metadata=run_input.metadata,
config=DifyOutputLayerConfig(
json_schema=run_input.output.json_schema,
description=run_input.output.description,
strict=run_input.output.strict,
),
)
)
return CreateRunRequest(
composition=RunComposition(layers=layers),
purpose=run_input.purpose,
idempotency_key=run_input.idempotency_key,
metadata=run_input.metadata,
session_snapshot=run_input.session_snapshot,
on_exit=LayerExitSignals(
default=ExitIntent.SUSPEND if run_input.suspend_on_exit else ExitIntent.DELETE,
),
)
def build_cleanup_request(
self,
*,
@ -454,18 +302,6 @@ class AgentBackendRunRequestBuilder:
)
)
if run_input.include_shell:
# Sandboxed bash workspace (dify.shell). The layer declares NoLayerDeps,
# so the spec carries no deps; shellctl connection is server-injected.
layers.append(
RunLayerSpec(
name=DIFY_SHELL_LAYER_ID,
type=DIFY_SHELL_LAYER_TYPE_ID,
metadata=run_input.metadata,
config=run_input.shell_config or DifyShellLayerConfig(),
)
)
if run_input.output is not None:
layers.append(
RunLayerSpec(

View File

@ -1,135 +0,0 @@
"""API-side client for the agent backend's read-only workspace file endpoints.
The agent backend exposes ``/workspaces/{session_id}/files{,/preview,/download}``
to inspect a shell-layer sandbox workspace. This thin synchronous client proxies
those reads for the console FS inspector and normalizes transport/HTTP failures
into the API backend's ``AgentBackendError`` boundary, preserving the backend's
status code and ``{code, message}`` detail so the controller can relay them.
"""
from __future__ import annotations
import base64
import binascii
from dataclasses import dataclass
from typing import Literal
import httpx
from pydantic import BaseModel
from clients.agent_backend.errors import AgentBackendHTTPError, AgentBackendTransportError
_DEFAULT_TIMEOUT_SECONDS = 30.0
class WorkspaceFileEntry(BaseModel):
"""One entry in a workspace directory listing."""
name: str
type: Literal["file", "dir", "symlink"]
size: int
mtime: int
class WorkspaceListResult(BaseModel):
"""Directory listing of a workspace path."""
path: str
entries: list[WorkspaceFileEntry]
truncated: bool
class WorkspacePreviewResult(BaseModel):
"""Inline preview of a workspace file."""
path: str
size: int
truncated: bool
binary: bool
text: str | None = None
@dataclass(frozen=True, slots=True)
class WorkspaceDownloadResult:
"""Decoded bytes of a workspace file for download."""
path: str
size: int
truncated: bool
content: bytes
class WorkspaceFilesBackendClient:
"""Synchronous proxy to the agent backend workspace file endpoints."""
def __init__(
self,
base_url: str,
*,
timeout: float = _DEFAULT_TIMEOUT_SECONDS,
transport: httpx.BaseTransport | None = None,
) -> None:
self._base_url = base_url.rstrip("/")
self._timeout = timeout
self._transport = transport
def list_files(self, session_id: str, path: str) -> WorkspaceListResult:
data = self._get(f"/workspaces/{session_id}/files", params={"path": path})
return WorkspaceListResult.model_validate(data)
def preview(self, session_id: str, path: str) -> WorkspacePreviewResult:
data = self._get(f"/workspaces/{session_id}/files/preview", params={"path": path})
return WorkspacePreviewResult.model_validate(data)
def download(self, session_id: str, path: str) -> WorkspaceDownloadResult:
data = self._get(f"/workspaces/{session_id}/files/download", params={"path": path})
encoded = data.get("content_base64")
if not isinstance(encoded, str):
raise AgentBackendHTTPError("agent backend download response missing content", status_code=502, detail=data)
try:
content = base64.b64decode(encoded, validate=True)
except (binascii.Error, ValueError) as exc:
raise AgentBackendHTTPError(
"agent backend returned undecodable download content", status_code=502, detail=str(exc)
) from exc
size = data.get("size")
return WorkspaceDownloadResult(
path=str(data.get("path", path)),
size=int(size) if isinstance(size, (int, float)) else len(content),
truncated=bool(data.get("truncated")),
content=content,
)
def _get(self, route: str, *, params: dict[str, str]) -> dict[str, object]:
url = f"{self._base_url}{route}"
try:
with httpx.Client(timeout=self._timeout, transport=self._transport, trust_env=False) as client:
response = client.get(url, params=params)
except httpx.HTTPError as exc:
raise AgentBackendTransportError(f"failed to reach agent backend workspace endpoint: {exc}") from exc
if response.status_code >= 400:
detail: object
try:
detail = response.json().get("detail", response.text)
except ValueError:
detail = response.text
raise AgentBackendHTTPError(
f"agent backend workspace request failed ({response.status_code})",
status_code=response.status_code,
detail=detail,
)
body = response.json()
if not isinstance(body, dict):
raise AgentBackendHTTPError(
"agent backend workspace response was not an object", status_code=502, detail=body
)
return body
__all__ = [
"WorkspaceDownloadResult",
"WorkspaceFileEntry",
"WorkspaceFilesBackendClient",
"WorkspaceListResult",
"WorkspacePreviewResult",
]

View File

@ -29,7 +29,6 @@ class RemoteSettingsSourceFactory(PydanticBaseSettingsSource):
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
raise NotImplementedError
@override
def __call__(self) -> dict[str, Any]:
current_state = self.current_state
remote_source_name = current_state.get("REMOTE_SETTINGS_SOURCE_NAME")

View File

@ -21,13 +21,3 @@ class AgentBackendConfig(BaseSettings):
description="Scenario used by the fake Agent backend client.",
default="success",
)
AGENT_SHELL_ENABLED: bool = Field(
description=(
"Inject the dify.shell layer (sandboxed bash workspace) into Agent runs. "
"Requires the agent backend to be wired with a shellctl entrypoint; keep it "
"off until shellctl is deployed, otherwise every agent run that includes the "
"shell layer will fail."
),
default=False,
)

View File

@ -81,15 +81,4 @@ default_app_templates: Mapping[AppMode, Mapping] = {
},
},
},
# agent default mode (new Agent App type). The runtime model / prompt / tools
# come from the bound Agent Soul snapshot, so no model_config is seeded in the
# template; create_app still creates a model-less app_model_config row to hold
# app-level presentation features (opener, follow-up, citations, ...).
AppMode.AGENT: {
"app": {
"mode": AppMode.AGENT,
"enable_site": True,
"enable_api": True,
},
},
}

View File

@ -1,40 +1,10 @@
import json
from pydantic import BaseModel, Field, JsonValue
HUMAN_INPUT_FORM_INPUT_EXAMPLE = {
"decision": "approve",
"attachment": {
"transfer_method": "local_file",
"upload_file_id": "4e0d1b87-52f2-49f6-b8c6-95cd9c954b3e",
"type": "document",
},
"attachments": [
{
"transfer_method": "local_file",
"upload_file_id": "1a77f0df-c0e6-461c-987c-e72526f341ee",
"type": "document",
},
{
"transfer_method": "remote_url",
"url": "https://example.com/report.pdf",
"type": "document",
},
],
}
from pydantic import BaseModel, JsonValue
class HumanInputFormSubmitPayload(BaseModel):
inputs: dict[str, JsonValue] = Field(
description=(
"Submitted human input values keyed by output variable name. "
"Use a string for paragraph or select input values, a file mapping for file inputs, "
"and a list of file mappings for file-list inputs. Local file mappings use "
"`transfer_method=local_file` with `upload_file_id`; remote file mappings use "
"`transfer_method=remote_url` with `url` or `remote_url`."
),
examples=[HUMAN_INPUT_FORM_INPUT_EXAMPLE],
)
inputs: dict[str, JsonValue]
action: str

View File

@ -36,8 +36,6 @@ QueryParamDoc = TypedDict(
},
)
JsonResponseWithStatus = tuple[dict[str, Any], int]
class QueryArgs(Protocol):
def to_dict(self, flat: bool = True) -> dict[str, str]: ...

View File

@ -51,9 +51,6 @@ from .agent import roster as agent_roster
from .app import (
advanced_prompt_template,
agent,
agent_app_access,
agent_app_feature,
agent_app_workspace,
annotation,
app,
audio,
@ -122,6 +119,7 @@ from .explore import (
saved_message,
trial,
)
from .snippets import snippet_workflow, snippet_workflow_draft_variable
from .socketio import workflow as socketio_workflow
# Import tag controllers
@ -137,6 +135,7 @@ from .workspace import (
model_providers,
models,
plugin,
snippets,
tool_providers,
trigger_providers,
workspace,
@ -149,9 +148,6 @@ __all__ = [
"activate",
"advanced_prompt_template",
"agent",
"agent_app_access",
"agent_app_feature",
"agent_app_workspace",
"agent_composer",
"agent_providers",
"agent_roster",
@ -212,6 +208,9 @@ __all__ = [
"saved_message",
"setup",
"site",
"snippet_workflow",
"snippet_workflow_draft_variable",
"snippets",
"socketio_workflow",
"spec",
"statistic",

View File

@ -3,13 +3,7 @@ from flask_restx import Resource
from controllers.common.schema import register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
with_current_tenant_id,
with_current_user_id,
)
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from fields.agent_fields import (
AgentAppComposerResponse,
AgentComposerCandidatesResponse,
@ -18,7 +12,7 @@ from fields.agent_fields import (
WorkflowAgentComposerResponse,
)
from libs.helper import dump_response
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from models.model import App, AppMode
from services.agent.composer_service import AgentComposerService
from services.agent.composer_validator import ComposerConfigValidator
@ -44,8 +38,8 @@ class WorkflowAgentComposerApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
@with_current_tenant_id
def get(self, tenant_id: str, app_model: App, node_id: str):
def get(self, app_model: App, node_id: str):
_, tenant_id = current_account_with_tenant()
return dump_response(
WorkflowAgentComposerResponse,
AgentComposerService.load_workflow_composer(
@ -64,9 +58,8 @@ class WorkflowAgentComposerApi(Resource):
@account_initialization_required
@edit_permission_required
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
@with_current_user_id
@with_current_tenant_id
def put(self, tenant_id: str, account_id: str, app_model: App, node_id: str):
def put(self, app_model: App, node_id: str):
account, tenant_id = current_account_with_tenant()
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
return dump_response(
WorkflowAgentComposerResponse,
@ -74,7 +67,7 @@ class WorkflowAgentComposerApi(Resource):
tenant_id=tenant_id,
app_id=app_model.id,
node_id=node_id,
account_id=account_id,
account_id=account.id,
payload=payload,
),
)
@ -120,8 +113,8 @@ class WorkflowAgentComposerImpactApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
@with_current_tenant_id
def post(self, tenant_id: str, app_model: App, node_id: str):
def post(self, app_model: App, node_id: str):
_, tenant_id = current_account_with_tenant()
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
current_snapshot_id = payload.binding.current_snapshot_id if payload.binding else None
if not current_snapshot_id:
@ -145,9 +138,8 @@ class WorkflowAgentComposerSaveToRosterApi(Resource):
@account_initialization_required
@edit_permission_required
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
@with_current_user_id
@with_current_tenant_id
def post(self, tenant_id: str, account_id: str, app_model: App, node_id: str):
def post(self, app_model: App, node_id: str):
account, tenant_id = current_account_with_tenant()
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
return dump_response(
WorkflowAgentComposerResponse,
@ -155,7 +147,7 @@ class WorkflowAgentComposerSaveToRosterApi(Resource):
tenant_id=tenant_id,
app_id=app_model.id,
node_id=node_id,
account_id=account_id,
account_id=account.id,
payload=payload,
),
)
@ -168,8 +160,8 @@ class AgentAppComposerApi(Resource):
@login_required
@account_initialization_required
@get_app_model()
@with_current_tenant_id
def get(self, tenant_id: str, app_model: App):
def get(self, app_model: App):
_, tenant_id = current_account_with_tenant()
return dump_response(
AgentAppComposerResponse,
AgentComposerService.load_agent_app_composer(tenant_id=tenant_id, app_id=app_model.id),
@ -182,16 +174,15 @@ class AgentAppComposerApi(Resource):
@account_initialization_required
@edit_permission_required
@get_app_model()
@with_current_user_id
@with_current_tenant_id
def put(self, tenant_id: str, account_id: str, app_model: App):
def put(self, app_model: App):
account, tenant_id = current_account_with_tenant()
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
return dump_response(
AgentAppComposerResponse,
AgentComposerService.save_agent_app_composer(
tenant_id=tenant_id,
app_id=app_model.id,
account_id=account_id,
account_id=account.id,
payload=payload,
),
)

View File

@ -6,13 +6,7 @@ from pydantic import BaseModel, Field
from controllers.common.schema import query_params_from_model, register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
with_current_tenant_id,
with_current_user_id,
)
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from extensions.ext_database import db
from fields.agent_fields import (
AgentConfigSnapshotDetailResponse,
@ -22,7 +16,7 @@ from fields.agent_fields import (
AgentRosterResponse,
)
from libs.helper import dump_response
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from services.agent.roster_service import AgentRosterService
from services.entities.agent_entities import RosterAgentCreatePayload, RosterAgentUpdatePayload, RosterListQuery
@ -64,8 +58,8 @@ class AgentRosterListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_tenant_id
def get(self, tenant_id: str):
def get(self):
_, tenant_id = current_account_with_tenant()
query = RosterListQuery.model_validate(request.args.to_dict(flat=True))
return dump_response(
AgentRosterListResponse,
@ -80,12 +74,11 @@ class AgentRosterListApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
@with_current_user_id
@with_current_tenant_id
def post(self, tenant_id: str, account_id: str):
def post(self):
account, tenant_id = current_account_with_tenant()
payload = RosterAgentCreatePayload.model_validate(console_ns.payload or {})
service = _agent_roster_service()
agent = service.create_roster_agent(tenant_id=tenant_id, account_id=account_id, payload=payload)
agent = service.create_roster_agent(tenant_id=tenant_id, account_id=account.id, payload=payload)
return dump_response(
AgentRosterResponse,
service.get_roster_agent_detail(tenant_id=tenant_id, agent_id=agent.id),
@ -99,8 +92,8 @@ class AgentInviteOptionsApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_tenant_id
def get(self, tenant_id: str):
def get(self):
_, tenant_id = current_account_with_tenant()
query = AgentInviteOptionsQuery.model_validate(request.args.to_dict(flat=True))
return dump_response(
AgentInviteOptionsResponse,
@ -120,8 +113,8 @@ class AgentRosterDetailApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_tenant_id
def get(self, tenant_id: str, agent_id: UUID):
def get(self, agent_id: UUID):
_, tenant_id = current_account_with_tenant()
return dump_response(
AgentRosterResponse,
_agent_roster_service().get_roster_agent_detail(tenant_id=tenant_id, agent_id=str(agent_id)),
@ -133,14 +126,13 @@ class AgentRosterDetailApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
@with_current_user_id
@with_current_tenant_id
def patch(self, tenant_id: str, account_id: str, agent_id: UUID):
def patch(self, agent_id: UUID):
account, tenant_id = current_account_with_tenant()
payload = RosterAgentUpdatePayload.model_validate(console_ns.payload or {})
return dump_response(
AgentRosterResponse,
_agent_roster_service().update_roster_agent(
tenant_id=tenant_id, agent_id=str(agent_id), account_id=account_id, payload=payload
tenant_id=tenant_id, agent_id=str(agent_id), account_id=account.id, payload=payload
),
)
@ -149,10 +141,9 @@ class AgentRosterDetailApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
@with_current_user_id
@with_current_tenant_id
def delete(self, tenant_id: str, account_id: str, agent_id: UUID):
_agent_roster_service().archive_roster_agent(tenant_id=tenant_id, agent_id=str(agent_id), account_id=account_id)
def delete(self, agent_id: UUID):
account, tenant_id = current_account_with_tenant()
_agent_roster_service().archive_roster_agent(tenant_id=tenant_id, agent_id=str(agent_id), account_id=account.id)
return "", 204
@ -162,8 +153,8 @@ class AgentRosterVersionsApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_tenant_id
def get(self, tenant_id: str, agent_id: UUID):
def get(self, agent_id: UUID):
_, tenant_id = current_account_with_tenant()
return dump_response(
AgentConfigSnapshotListResponse,
{"data": _agent_roster_service().list_agent_versions(tenant_id=tenant_id, agent_id=str(agent_id))},
@ -176,8 +167,8 @@ class AgentRosterVersionDetailApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_tenant_id
def get(self, tenant_id: str, agent_id: UUID, version_id: UUID):
def get(self, agent_id: UUID, version_id: UUID):
_, tenant_id = current_account_with_tenant()
return dump_response(
AgentConfigSnapshotDetailResponse,
_agent_roster_service().get_agent_version_detail(

View File

@ -1,59 +0,0 @@
"""Agent App access & sharing endpoints (read-only workflow references).
An Agent App is backed by a roster Agent that workflow Agent nodes may also
reference. This exposes the read-only "Workflow access" surface from the PRD:
which workflow apps use this Agent, without leaking the workflows' internals.
"""
from flask_restx import Resource
from pydantic import Field
from controllers.common.schema import register_response_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required, with_current_tenant_id
from extensions.ext_database import db
from fields.base import ResponseModel
from libs.login import login_required
from models.model import App, AppMode
from services.agent.roster_service import AgentRosterService
class AgentReferencingWorkflowResponse(ResponseModel):
app_id: str
app_name: str
app_mode: str
workflow_id: str
node_ids: list[str] = Field(default_factory=list)
class AgentReferencingWorkflowsResponse(ResponseModel):
data: list[AgentReferencingWorkflowResponse] = Field(default_factory=list)
register_response_schema_models(console_ns, AgentReferencingWorkflowsResponse)
@console_ns.route("/apps/<uuid:app_id>/agent-referencing-workflows")
class AgentAppReferencingWorkflowsResource(Resource):
@console_ns.doc("list_agent_app_referencing_workflows")
@console_ns.doc(description="List workflow apps that reference this Agent App's bound Agent (read-only)")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.response(
200,
"Referencing workflows listed successfully",
console_ns.models[AgentReferencingWorkflowsResponse.__name__],
)
@console_ns.response(404, "App not found")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.AGENT])
@with_current_tenant_id
def get(self, tenant_id: str, app_model: App):
workflows = AgentRosterService(db.session).list_workflows_referencing_app_agent(
tenant_id=tenant_id, app_id=app_model.id
)
return AgentReferencingWorkflowsResponse(
data=[AgentReferencingWorkflowResponse.model_validate(workflow) for workflow in workflows]
).model_dump(mode="json")

View File

@ -1,93 +0,0 @@
"""Agent App presentation-feature configuration endpoint.
The new Agent App type keeps model / prompt / tools in its bound Agent Soul, so
the legacy ``/model-config`` surface (which writes model, prompt and agent tool
config) is the wrong place to configure its app-level presentation features.
This endpoint exposes only the PRD "Misc Legacy" feature subset — conversation
opener, follow-up suggestions, citations, content moderation and speech — and
persists them onto the app's ``app_model_config`` without touching anything the
Soul owns.
"""
from flask_restx import Resource
from pydantic import BaseModel, Field
from controllers.common.fields import SimpleResultResponse
from controllers.common.schema import register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
with_current_user,
)
from events.app_event import app_model_config_was_updated
from libs.helper import dump_response
from libs.login import login_required
from models import Account
from models.agent_config_entities import (
AgentFeatureToggleConfig,
AgentSensitiveWordAvoidanceFeatureConfig,
AgentSuggestedQuestionsAfterAnswerFeatureConfig,
AgentTextToSpeechFeatureConfig,
)
from models.model import App, AppMode
from services.agent_app_feature_service import AgentAppFeatureConfigService
class AgentAppFeaturesPayload(BaseModel):
"""Presentation features configurable on an Agent App.
All fields are optional; an omitted field is reset to its disabled/empty
default (the config form sends the full desired feature state on save).
"""
opening_statement: str | None = Field(default=None, description="Conversation opener shown before the first turn")
suggested_questions: list[str] | None = Field(
default=None, description="Preset questions shown alongside the opener"
)
suggested_questions_after_answer: AgentSuggestedQuestionsAfterAnswerFeatureConfig | None = Field(
default=None, description="Follow-up suggestions config, e.g. {'enabled': true}"
)
speech_to_text: AgentFeatureToggleConfig | None = Field(default=None, description="Speech-to-text config")
text_to_speech: AgentTextToSpeechFeatureConfig | None = Field(default=None, description="Text-to-speech config")
retriever_resource: AgentFeatureToggleConfig | None = Field(
default=None, description="Citations / attributions config, e.g. {'enabled': true}"
)
sensitive_word_avoidance: AgentSensitiveWordAvoidanceFeatureConfig | None = Field(
default=None, description="Content moderation config"
)
register_schema_models(console_ns, AgentAppFeaturesPayload)
register_response_schema_models(console_ns, SimpleResultResponse)
@console_ns.route("/apps/<uuid:app_id>/agent-features")
class AgentAppFeatureConfigResource(Resource):
@console_ns.doc("update_agent_app_features")
@console_ns.doc(description="Update an Agent App's presentation features (opener, follow-up, citations, ...)")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[AgentAppFeaturesPayload.__name__])
@console_ns.response(200, "Features updated successfully", console_ns.models[SimpleResultResponse.__name__])
@console_ns.response(400, "Invalid configuration")
@console_ns.response(404, "App not found")
@setup_required
@login_required
@edit_permission_required
@account_initialization_required
@get_app_model(mode=[AppMode.AGENT])
@with_current_user
def post(self, current_user: Account, app_model: App):
args = AgentAppFeaturesPayload.model_validate(console_ns.payload or {})
new_app_model_config = AgentAppFeatureConfigService.update_features(
app_model=app_model,
account=current_user,
config=args.model_dump(exclude_none=True),
)
app_model_config_was_updated.send(app_model, app_model_config=new_app_model_config)
return dump_response(SimpleResultResponse, {"result": "success"})

View File

@ -1,319 +0,0 @@
"""Agent App sandbox file-system inspector (read-only).
Exposes the PRD "rc1-like sandbox file system, downloadable not editable" view
for an Agent App conversation: list a directory, preview a file, or download a
file from the conversation's shell-layer workspace. The API never touches
shellctl directly — it resolves the conversation's sandbox ``session_id`` from
the stored session snapshot and proxies to the agent backend's read-only
workspace endpoints.
"""
from typing import Literal
from uuid import UUID
from flask import Response
from flask_restx import Resource, fields
from pydantic import BaseModel, Field
from clients.agent_backend.errors import AgentBackendHTTPError, AgentBackendTransportError
from clients.agent_backend.workspace_files_client import WorkspaceDownloadResult
from controllers.common.schema import (
query_params_from_model,
query_params_from_request,
register_response_schema_models,
)
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from fields.base import ResponseModel
from libs.login import current_account_with_tenant, login_required
from models.model import App, AppMode
from services.agent_app_workspace_service import (
AgentAppWorkspaceService,
AgentWorkspaceInspectorError,
WorkflowAgentWorkspaceService,
)
class _WorkspaceFileDownloadField(fields.Raw):
__schema_type__ = "string"
__schema_format__ = "binary"
class AgentWorkspaceListQuery(BaseModel):
conversation_id: str = Field(min_length=1, description="Agent App conversation ID")
path: str = Field(default=".", description="Directory path relative to the sandbox workspace")
class AgentWorkspaceFileQuery(BaseModel):
conversation_id: str = Field(min_length=1, description="Agent App conversation ID")
path: str = Field(min_length=1, description="File path relative to the sandbox workspace")
class WorkflowAgentWorkspaceListQuery(BaseModel):
path: str = Field(default=".", description="Directory path relative to the sandbox workspace")
node_execution_id: str | None = Field(
default=None,
description=(
"Optional workflow node execution ID. When omitted, the latest active session for the node is used."
),
)
class WorkflowAgentWorkspaceFileQuery(BaseModel):
path: str = Field(min_length=1, description="File path relative to the sandbox workspace")
node_execution_id: str | None = Field(
default=None,
description=(
"Optional workflow node execution ID. When omitted, the latest active session for the node is used."
),
)
class WorkspaceFileEntryResponse(ResponseModel):
name: str
type: Literal["file", "dir", "symlink"]
size: int
mtime: int
class WorkspaceListResponse(ResponseModel):
path: str
entries: list[WorkspaceFileEntryResponse] = Field(default_factory=list)
truncated: bool = False
class WorkspacePreviewResponse(ResponseModel):
path: str
size: int
truncated: bool
binary: bool
text: str | None = None
register_response_schema_models(console_ns, WorkspaceListResponse)
register_response_schema_models(console_ns, WorkspacePreviewResponse)
def _handle(exc: Exception) -> tuple[dict[str, object], int]:
if isinstance(exc, AgentWorkspaceInspectorError):
return {"code": exc.code, "message": exc.message}, exc.status_code
if isinstance(exc, AgentBackendHTTPError):
detail = exc.detail
if isinstance(detail, dict):
return {
"code": detail.get("code", "agent_backend_error"),
"message": detail.get("message", str(exc)),
}, exc.status_code
return {"code": "agent_backend_error", "message": str(detail)}, exc.status_code
if isinstance(exc, AgentBackendTransportError):
return {"code": "agent_backend_unreachable", "message": str(exc)}, 502
raise exc
def _download_response(result: WorkspaceDownloadResult) -> Response | tuple[dict[str, object], int]:
if result.truncated:
return {
"code": "workspace_file_too_large",
"message": (
"file exceeds the workspace download limit; use preview for partial text or download a smaller file"
),
"size": result.size,
}, 413
filename = result.path.rsplit("/", 1)[-1] or "download"
return Response(
result.content,
mimetype="application/octet-stream",
headers={
"Content-Disposition": f'attachment; filename="{filename}"',
"Content-Length": str(len(result.content)),
"X-Workspace-File-Size": str(result.size),
},
)
@console_ns.route("/apps/<uuid:app_id>/agent-workspace/files")
class AgentAppWorkspaceListResource(Resource):
@console_ns.doc("list_agent_app_workspace_files")
@console_ns.doc(description="List a directory in an Agent App conversation's sandbox workspace (read-only)")
@console_ns.doc(params={"app_id": "Application ID", **query_params_from_model(AgentWorkspaceListQuery)})
@console_ns.response(200, "Listing returned", console_ns.models[WorkspaceListResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.AGENT])
def get(self, app_model: App):
_, tenant_id = current_account_with_tenant()
query = query_params_from_request(AgentWorkspaceListQuery)
try:
result = AgentAppWorkspaceService().list_files(
tenant_id=tenant_id,
app_id=app_model.id,
conversation_id=query.conversation_id,
path=query.path,
)
except Exception as exc: # normalized to an HTTP response below
return _handle(exc)
return result.model_dump()
@console_ns.route("/apps/<uuid:app_id>/agent-workspace/files/preview")
class AgentAppWorkspacePreviewResource(Resource):
@console_ns.doc("preview_agent_app_workspace_file")
@console_ns.doc(description="Preview a text/binary file in an Agent App conversation's sandbox workspace")
@console_ns.doc(params={"app_id": "Application ID", **query_params_from_model(AgentWorkspaceFileQuery)})
@console_ns.response(200, "Preview returned", console_ns.models[WorkspacePreviewResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.AGENT])
def get(self, app_model: App):
_, tenant_id = current_account_with_tenant()
query = query_params_from_request(AgentWorkspaceFileQuery)
try:
result = AgentAppWorkspaceService().preview(
tenant_id=tenant_id,
app_id=app_model.id,
conversation_id=query.conversation_id,
path=query.path,
)
except Exception as exc: # normalized to an HTTP response below
return _handle(exc)
return result.model_dump()
@console_ns.route("/apps/<uuid:app_id>/agent-workspace/files/download")
class AgentAppWorkspaceDownloadResource(Resource):
@console_ns.doc("download_agent_app_workspace_file")
@console_ns.doc(description="Download a file from an Agent App conversation's sandbox workspace (read-only)")
@console_ns.doc(params={"app_id": "Application ID", **query_params_from_model(AgentWorkspaceFileQuery)})
@console_ns.doc(produces=["application/octet-stream"])
@console_ns.response(200, "File bytes", _WorkspaceFileDownloadField)
@console_ns.response(413, "File exceeds the workspace download limit")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.AGENT])
def get(self, app_model: App):
_, tenant_id = current_account_with_tenant()
query = query_params_from_request(AgentWorkspaceFileQuery)
try:
result = AgentAppWorkspaceService().download(
tenant_id=tenant_id,
app_id=app_model.id,
conversation_id=query.conversation_id,
path=query.path,
)
except Exception as exc: # normalized to an HTTP response below
return _handle(exc)
return _download_response(result)
@console_ns.route(
"/apps/<uuid:app_id>/workflow-runs/<uuid:workflow_run_id>/agent-nodes/<string:node_id>/workspace/files"
)
class WorkflowAgentWorkspaceListResource(Resource):
@console_ns.doc("list_workflow_agent_workspace_files")
@console_ns.doc(description="List a directory in a Workflow Agent node's sandbox workspace (read-only)")
@console_ns.doc(
params={
"app_id": "Application ID",
"workflow_run_id": "Workflow run ID",
"node_id": "Workflow Agent node ID",
**query_params_from_model(WorkflowAgentWorkspaceListQuery),
}
)
@console_ns.response(200, "Listing returned", console_ns.models[WorkspaceListResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def get(self, app_model: App, workflow_run_id: UUID, node_id: str):
_, tenant_id = current_account_with_tenant()
query = query_params_from_request(WorkflowAgentWorkspaceListQuery)
try:
result = WorkflowAgentWorkspaceService().list_files(
tenant_id=tenant_id,
app_id=app_model.id,
workflow_run_id=str(workflow_run_id),
node_id=node_id,
node_execution_id=query.node_execution_id,
path=query.path,
)
except Exception as exc: # normalized to an HTTP response below
return _handle(exc)
return result.model_dump()
@console_ns.route(
"/apps/<uuid:app_id>/workflow-runs/<uuid:workflow_run_id>/agent-nodes/<string:node_id>/workspace/files/preview"
)
class WorkflowAgentWorkspacePreviewResource(Resource):
@console_ns.doc("preview_workflow_agent_workspace_file")
@console_ns.doc(description="Preview a text/binary file in a Workflow Agent node's sandbox workspace")
@console_ns.doc(
params={
"app_id": "Application ID",
"workflow_run_id": "Workflow run ID",
"node_id": "Workflow Agent node ID",
**query_params_from_model(WorkflowAgentWorkspaceFileQuery),
}
)
@console_ns.response(200, "Preview returned", console_ns.models[WorkspacePreviewResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def get(self, app_model: App, workflow_run_id: UUID, node_id: str):
_, tenant_id = current_account_with_tenant()
query = query_params_from_request(WorkflowAgentWorkspaceFileQuery)
try:
result = WorkflowAgentWorkspaceService().preview(
tenant_id=tenant_id,
app_id=app_model.id,
workflow_run_id=str(workflow_run_id),
node_id=node_id,
node_execution_id=query.node_execution_id,
path=query.path,
)
except Exception as exc: # normalized to an HTTP response below
return _handle(exc)
return result.model_dump()
@console_ns.route(
"/apps/<uuid:app_id>/workflow-runs/<uuid:workflow_run_id>/agent-nodes/<string:node_id>/workspace/files/download"
)
class WorkflowAgentWorkspaceDownloadResource(Resource):
@console_ns.doc("download_workflow_agent_workspace_file")
@console_ns.doc(description="Download a file from a Workflow Agent node's sandbox workspace (read-only)")
@console_ns.doc(
params={
"app_id": "Application ID",
"workflow_run_id": "Workflow run ID",
"node_id": "Workflow Agent node ID",
**query_params_from_model(WorkflowAgentWorkspaceFileQuery),
}
)
@console_ns.doc(produces=["application/octet-stream"])
@console_ns.response(200, "File bytes", _WorkspaceFileDownloadField)
@console_ns.response(413, "File exceeds the workspace download limit")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def get(self, app_model: App, workflow_run_id: UUID, node_id: str):
_, tenant_id = current_account_with_tenant()
query = query_params_from_request(WorkflowAgentWorkspaceFileQuery)
try:
result = WorkflowAgentWorkspaceService().download(
tenant_id=tenant_id,
app_id=app_model.id,
workflow_run_id=str(workflow_run_id),
node_id=node_id,
node_execution_id=query.node_execution_id,
path=query.path,
)
except Exception as exc: # normalized to an HTTP response below
return _handle(exc)
return _download_response(result)

View File

@ -25,9 +25,6 @@ from controllers.console.wraps import (
enterprise_license_required,
is_admin_or_owner_required,
setup_required,
with_current_tenant_id,
with_current_user,
with_current_user_id,
)
from core.ops.ops_trace_manager import OpsTraceManager
from core.rag.entities import PreProcessingRule, Rule, Segmentation
@ -37,8 +34,8 @@ from extensions.ext_database import db
from fields.base import ResponseModel
from graphon.enums import WorkflowExecutionStatus
from libs.helper import build_icon_url, to_timestamp
from libs.login import login_required
from models import Account, App, DatasetPermissionEnum, Workflow
from libs.login import current_account_with_tenant, login_required
from models import App, DatasetPermissionEnum, Workflow
from models.model import IconType
from services.app_dsl_service import AppDslService
from services.app_service import AppListParams, AppService, CreateAppParams
@ -58,7 +55,7 @@ from services.entities.knowledge_entities.knowledge_entities import (
)
from services.feature_service import FeatureService
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "agent", "advanced-chat", "workflow", "completion"]
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
register_enum_models(console_ns, IconType)
@ -69,7 +66,7 @@ _TAG_IDS_BRACKET_PATTERN = re.compile(r"^tag_ids\[(\d+)\]$")
class AppListQuery(BaseModel):
page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)")
limit: int = Field(default=20, ge=1, le=100, description="Page size (1-100)")
mode: Literal["completion", "chat", "advanced-chat", "workflow", "agent-chat", "agent", "channel", "all"] = Field(
mode: Literal["completion", "chat", "advanced-chat", "workflow", "agent-chat", "channel", "all"] = Field(
default="all", description="App mode filter"
)
name: str | None = Field(default=None, description="Filter by app name")
@ -118,9 +115,7 @@ def _normalize_app_list_query_args(query_args: MultiDict[str, str]) -> dict[str,
class CreateAppPayload(BaseModel):
name: str = Field(..., min_length=1, description="App name")
description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
mode: Literal["chat", "agent-chat", "agent", "advanced-chat", "workflow", "completion"] = Field(
..., description="App mode"
)
mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion"] = Field(..., description="App mode")
icon_type: IconType | None = Field(default=None, description="Icon type")
icon: str | None = Field(default=None, description="Icon")
icon_background: str | None = Field(default=None, description="Icon background color")
@ -398,8 +393,6 @@ class AppDetailWithSite(AppDetail):
max_active_requests: int | None = None
deleted_tools: list[DeletedTool] = Field(default_factory=list)
site: Site | None = None
# For Agent App type: the roster Agent backing this app (None otherwise).
bound_agent_id: str | None = None
@computed_field(return_type=str | None) # type: ignore
@property
@ -475,10 +468,10 @@ class AppListApi(Resource):
@account_initialization_required
@enterprise_license_required
@with_session(write=False)
@with_current_user_id
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user_id: str, session: Session):
def get(self, session: Session):
"""Get app list"""
current_user, current_tenant_id = current_account_with_tenant()
args = AppListQuery.model_validate(_normalize_app_list_query_args(request.args))
params = AppListParams(
page=args.page,
@ -491,7 +484,7 @@ class AppListApi(Resource):
# get app list
app_service = AppService()
app_pagination = app_service.get_paginate_apps(current_user_id, current_tenant_id, params)
app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, params)
if not app_pagination:
empty = AppPagination(page=args.page, limit=args.limit, total=0, has_more=False, data=[])
return empty.model_dump(mode="json"), 200
@ -551,10 +544,9 @@ class AppListApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("apps")
@edit_permission_required
@with_current_user
@with_current_tenant_id
def post(self, current_tenant_id: str, current_user: Account):
def post(self):
"""Create app"""
current_user, current_tenant_id = current_account_with_tenant()
args = CreateAppPayload.model_validate(console_ns.payload)
params = CreateAppParams(
name=args.name,
@ -657,10 +649,11 @@ class AppCopyApi(Resource):
@account_initialization_required
@get_app_model(mode=None)
@edit_permission_required
@with_current_user
def post(self, current_user: Account, app_model: App):
def post(self, app_model: App):
"""Copy app"""
# The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
args = CopyAppPayload.model_validate(console_ns.payload or {})
with Session(db.engine, expire_on_commit=False) as session:
@ -739,8 +732,7 @@ class AppPublishToCreatorsPlatformApi(Resource):
@account_initialization_required
@get_app_model(mode=None)
@edit_permission_required
@with_current_user_id
def post(self, current_user_id: str, app_model: App):
def post(self, app_model: App):
"""Publish app to Creators Platform"""
from configs import dify_config
from core.helper.creators import get_redirect_url, upload_dsl
@ -748,11 +740,13 @@ class AppPublishToCreatorsPlatformApi(Resource):
if not dify_config.CREATORS_PLATFORM_FEATURES_ENABLED:
return {"error": "Creators Platform features are not enabled"}, 403
current_user, _ = current_account_with_tenant()
dsl_content = AppDslService.export_dsl(app_model=app_model, include_secret=False)
dsl_bytes = dsl_content.encode("utf-8")
claim_code = upload_dsl(dsl_bytes)
redirect_url = get_redirect_url(current_user_id, claim_code)
redirect_url = get_redirect_url(str(current_user.id), claim_code)
return {"redirect_url": redirect_url}

View File

@ -4,7 +4,7 @@ from typing import Any, Literal
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
from werkzeug.exceptions import InternalServerError, NotFound
import services
from controllers.common.fields import SimpleResultResponse
@ -19,12 +19,7 @@ from controllers.console.app.error import (
ProviderQuotaExceededError,
)
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
with_current_user_id,
)
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import (
@ -46,24 +41,9 @@ from services.errors.llm import InvokeRateLimitError
logger = logging.getLogger(__name__)
def _resolve_debugger_chat_streaming(
*, app_mode: AppMode, response_mode: str, response_mode_provided: bool = True
) -> bool:
"""Agent App runtime is SSE-only until backend blocking runs are supported."""
if app_mode != AppMode.AGENT:
return response_mode != "blocking"
if response_mode_provided and response_mode == "blocking":
raise BadRequest("Agent App only supports streaming response mode.")
return True
class BaseMessagePayload(BaseModel):
inputs: dict[str, Any]
# Agent Apps (AppMode.AGENT) derive their model + prompt from the bound Agent
# Soul, so no override ``model_config`` is sent; chat / agent-chat / completion
# debugging still pass it. Optional here, required in practice by those modes
# downstream when their config is built from args.
model_config_data: dict[str, Any] = Field(default_factory=dict, alias="model_config")
model_config_data: dict[str, Any] = Field(..., alias="model_config")
files: list[Any] | None = Field(default=None, description="Uploaded files")
response_mode: Literal["blocking", "streaming"] = Field(default="blocking", description="Response mode")
retriever_from: str = Field(default="dev", description="Retriever source")
@ -151,13 +131,14 @@ class CompletionMessageStopApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.COMPLETION)
@with_current_user_id
def post(self, current_user_id: str, app_model: App, task_id: str):
def post(self, app_model: App, task_id: str):
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
AppTaskService.stop_task(
task_id=task_id,
invoke_from=InvokeFrom.DEBUGGER,
user_id=current_user_id,
user_id=current_user.id,
app_mode=AppMode.value_of(app_model.mode),
)
@ -176,20 +157,13 @@ class ChatMessageApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.AGENT])
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
@edit_permission_required
def post(self, app_model: App):
raw_payload = console_ns.payload or {}
args_model = ChatMessagePayload.model_validate(raw_payload)
args_model = ChatMessagePayload.model_validate(console_ns.payload)
args = args_model.model_dump(exclude_none=True, by_alias=True)
streaming = _resolve_debugger_chat_streaming(
app_mode=AppMode.value_of(app_model.mode),
response_mode=args_model.response_mode,
response_mode_provided=isinstance(raw_payload, dict) and "response_mode" in raw_payload,
)
if AppMode.value_of(app_model.mode) == AppMode.AGENT:
args["response_mode"] = "streaming"
streaming = args_model.response_mode != "blocking"
args["auto_generate_name"] = False
external_trace_id = get_external_trace_id(request)
@ -237,14 +211,15 @@ class ChatMessageStopApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
@with_current_user_id
def post(self, current_user_id: str, app_model: App, task_id: str):
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
def post(self, app_model: App, task_id: str):
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
AppTaskService.stop_task(
task_id=task_id,
invoke_from=InvokeFrom.DEBUGGER,
user_id=current_user_id,
user_id=current_user.id,
app_mode=AppMode.value_of(app_model.mode),
)

View File

@ -212,7 +212,7 @@ class ChatConversationApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@edit_permission_required
@with_current_user
def get(self, current_user: Account, app_model: App):
@ -323,7 +323,7 @@ class ChatConversationDetailApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@edit_permission_required
@with_current_user
def get(self, current_user: Account, app_model: App, conversation_id: UUID):
@ -340,7 +340,7 @@ class ChatConversationDetailApi(Resource):
@console_ns.response(404, "Conversation not found")
@setup_required
@login_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@account_initialization_required
@edit_permission_required
@with_current_user

View File

@ -180,7 +180,7 @@ class ChatMessageListApi(Resource):
@login_required
@account_initialization_required
@setup_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@edit_permission_required
def get(self, app_model: App):
args = ChatMessagesQuery.model_validate(request.args.to_dict())
@ -337,7 +337,7 @@ class MessageSuggestedQuestionApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@with_current_user
def get(self, current_user: Account, app_model: App, message_id: UUID):
message_id_str = str(message_id)

View File

@ -8,20 +8,14 @@ from pydantic import BaseModel, Field
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
with_current_tenant_id,
with_current_user_id,
)
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from core.agent.entities import AgentToolEntity
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolParameterConfigurationManager
from events.app_event import app_model_config_was_updated
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from models.model import App, AppMode, AppModelConfig
from services.app_model_config_service import AppModelConfigService
@ -58,10 +52,9 @@ class ModelConfigResource(Resource):
@edit_permission_required
@account_initialization_required
@get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION])
@with_current_user_id
@with_current_tenant_id
def post(self, current_tenant_id: str, current_user_id: str, app_model: App):
def post(self, app_model: App):
"""Modify app model config"""
current_user, current_tenant_id = current_account_with_tenant()
# validate config
model_configuration = AppModelConfigService.validate_configuration(
tenant_id=current_tenant_id,
@ -71,8 +64,8 @@ class ModelConfigResource(Resource):
new_app_model_config = AppModelConfig(
app_id=app_model.id,
created_by=current_user_id,
updated_by=current_user_id,
created_by=current_user.id,
updated_by=current_user.id,
)
new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration)
@ -97,7 +90,7 @@ class ModelConfigResource(Resource):
tenant_id=current_tenant_id,
app_id=app_model.id,
agent_tool=agent_tool_entity,
user_id=current_user_id,
user_id=current_user.id,
)
manager = ToolParameterConfigurationManager(
tenant_id=current_tenant_id,
@ -137,7 +130,7 @@ class ModelConfigResource(Resource):
tenant_id=current_tenant_id,
app_id=app_model.id,
agent_tool=agent_tool_entity,
user_id=current_user_id,
user_id=current_user.id,
)
except Exception:
continue
@ -174,7 +167,7 @@ class ModelConfigResource(Resource):
db.session.flush()
app_model.app_model_config_id = new_app_model_config.id
app_model.updated_by = current_user_id
app_model.updated_by = current_user.id
app_model.updated_at = naive_utc_now()
db.session.commit()

View File

@ -290,7 +290,7 @@ class AverageSessionInteractionStatistic(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@with_current_user
def get(self, account: Account, app_model: App):
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))

View File

@ -1,7 +1,7 @@
import logging
from collections.abc import Callable
from functools import wraps
from typing import Any, Concatenate, TypedDict
from typing import Any, TypedDict
from uuid import UUID
from flask import Response, request
@ -214,9 +214,7 @@ workflow_draft_variable_list_model = console_ns.model(
)
def _api_prerequisite[T, **P, R](
f: Callable[Concatenate[T, P], R],
) -> Callable[Concatenate[T, P], R | Response]:
def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
"""Common prerequisites for all draft workflow variable APIs.
It ensures the following conditions are satisfied:
@ -233,8 +231,8 @@ def _api_prerequisite[T, **P, R](
@edit_permission_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@wraps(f)
def wrapper(self: T, *args: P.args, **kwargs: P.kwargs) -> R | Response:
return f(self, *args, **kwargs)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | Response:
return f(*args, **kwargs)
return wrapper

View File

@ -8,16 +8,9 @@ from werkzeug.exceptions import BadRequest
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import (
account_initialization_required,
only_edition_cloud,
setup_required,
with_current_tenant_id,
with_current_user,
)
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
from enums.cloud_plan import CloudPlan
from libs.login import login_required
from models import Account
from libs.login import current_account_with_tenant, login_required
from services.billing_service import BillingService
@ -39,9 +32,8 @@ class Subscription(Resource):
@login_required
@account_initialization_required
@only_edition_cloud
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account):
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
args = SubscriptionQuery.model_validate(request.args.to_dict(flat=True))
BillingService.is_tenant_owner_or_admin(current_user)
return BillingService.get_subscription(args.plan, args.interval, current_user.email, current_tenant_id)
@ -53,9 +45,8 @@ class Invoices(Resource):
@login_required
@account_initialization_required
@only_edition_cloud
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account):
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
BillingService.is_tenant_owner_or_admin(current_user)
return BillingService.get_invoices(current_user.email, current_tenant_id)
@ -72,8 +63,9 @@ class PartnerTenants(Resource):
@login_required
@account_initialization_required
@only_edition_cloud
@with_current_user
def put(self, current_user: Account, partner_key: str):
def put(self, partner_key: str):
current_user, _ = current_account_with_tenant()
try:
args = PartnerTenantsPayload.model_validate(console_ns.payload or {})
click_id = args.click_id

View File

@ -3,18 +3,11 @@ from flask_restx import Resource
from pydantic import BaseModel, Field
from libs.helper import extract_remote_ip
from libs.login import login_required
from models import Account
from libs.login import current_account_with_tenant, login_required
from services.billing_service import BillingService
from .. import console_ns
from ..wraps import (
account_initialization_required,
only_edition_cloud,
setup_required,
with_current_tenant_id,
with_current_user,
)
from ..wraps import account_initialization_required, only_edition_cloud, setup_required
class ComplianceDownloadQuery(BaseModel):
@ -36,9 +29,8 @@ class ComplianceApi(Resource):
@login_required
@account_initialization_required
@only_edition_cloud
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account):
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
args = ComplianceDownloadQuery.model_validate(request.args.to_dict(flat=True))
ip_address = extract_remote_ip(request)

View File

@ -1,37 +1,41 @@
import json
from collections.abc import Generator
from datetime import datetime
from typing import Any, Literal, cast
from uuid import UUID
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_serializer
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import NotFound
from controllers.common.fields import SimpleResultResponse, TextContentResponse
from controllers.common.schema import query_params_from_model, register_response_schema_models, register_schema_models
from controllers.common.schema import get_or_create_model, register_response_schema_models, register_schema_model
from core.datasource.entities.datasource_entities import DatasourceProviderType, OnlineDocumentPagesMessage
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.entities.knowledge_entities import IndexingEstimate
from core.indexing_runner import IndexingRunner
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo
from core.rag.extractor.notion_extractor import NotionExtractor
from extensions.ext_database import db
from fields.base import ResponseModel
from fields.data_source_fields import (
integrate_fields,
integrate_icon_fields,
integrate_list_fields,
integrate_notion_info_list_fields,
integrate_page_fields,
integrate_workspace_fields,
)
from libs.datetime_utils import naive_utc_now
from libs.helper import dump_response, to_timestamp
from libs.login import login_required
from models import Account, DataSourceOauthBinding, Document
from libs.login import current_account_with_tenant, login_required
from models import DataSourceOauthBinding, Document
from services.dataset_service import DatasetService, DocumentService
from services.datasource_provider_service import DatasourceProviderService
from tasks.document_indexing_sync_task import document_indexing_sync_task
from .. import console_ns
from ..wraps import account_initialization_required, setup_required, with_current_tenant_id, with_current_user
from ..wraps import account_initialization_required, setup_required
class NotionEstimatePayload(BaseModel):
@ -50,74 +54,50 @@ class DataSourceNotionPreviewQuery(BaseModel):
credential_id: str = Field(..., description="Credential ID", min_length=1)
class DataSourceIntegrateIconResponse(ResponseModel):
type: str | None = None
url: str | None = None
emoji: str | None = None
register_schema_model(console_ns, NotionEstimatePayload)
register_response_schema_models(console_ns, SimpleResultResponse, TextContentResponse)
class DataSourceIntegratePageResponse(ResponseModel):
page_name: str
page_id: str
page_icon: DataSourceIntegrateIconResponse | None
parent_id: str
type: str
integrate_icon_model = get_or_create_model("DataSourceIntegrateIcon", integrate_icon_fields)
integrate_page_fields_copy = integrate_page_fields.copy()
integrate_page_fields_copy["page_icon"] = fields.Nested(integrate_icon_model, allow_null=True)
integrate_page_model = get_or_create_model("DataSourceIntegratePage", integrate_page_fields_copy)
class DataSourceIntegrateWorkspaceResponse(ResponseModel):
workspace_name: str | None
workspace_id: str | None
workspace_icon: str | None
pages: list[DataSourceIntegratePageResponse]
total: int
integrate_workspace_fields_copy = integrate_workspace_fields.copy()
integrate_workspace_fields_copy["pages"] = fields.List(fields.Nested(integrate_page_model))
integrate_workspace_model = get_or_create_model("DataSourceIntegrateWorkspace", integrate_workspace_fields_copy)
integrate_fields_copy = integrate_fields.copy()
integrate_fields_copy["source_info"] = fields.Nested(integrate_workspace_model)
integrate_model = get_or_create_model("DataSourceIntegrate", integrate_fields_copy)
class DataSourceIntegrateResponse(ResponseModel):
id: str | None
provider: str
created_at: datetime | int | None
is_bound: bool
disabled: bool | None
link: str
source_info: DataSourceIntegrateWorkspaceResponse | None
integrate_list_fields_copy = integrate_list_fields.copy()
integrate_list_fields_copy["data"] = fields.List(fields.Nested(integrate_model))
integrate_list_model = get_or_create_model("DataSourceIntegrateList", integrate_list_fields_copy)
@field_serializer("created_at")
def serialize_created_at(self, value: datetime | int | None) -> int | None:
return to_timestamp(value)
notion_page_fields = {
"page_name": fields.String,
"page_id": fields.String,
"page_icon": fields.Nested(integrate_icon_model, allow_null=True),
"is_bound": fields.Boolean,
"parent_id": fields.String,
"type": fields.String,
}
notion_page_model = get_or_create_model("NotionIntegratePage", notion_page_fields)
notion_workspace_fields = {
"workspace_name": fields.String,
"workspace_id": fields.String,
"workspace_icon": fields.String,
"pages": fields.List(fields.Nested(notion_page_model)),
}
notion_workspace_model = get_or_create_model("NotionIntegrateWorkspace", notion_workspace_fields)
class DataSourceIntegrateListResponse(ResponseModel):
data: list[DataSourceIntegrateResponse]
class NotionIntegratePageResponse(ResponseModel):
page_name: str
page_id: str
page_icon: DataSourceIntegrateIconResponse | None
parent_id: str | None
type: str
is_bound: bool
class NotionIntegrateWorkspaceResponse(ResponseModel):
workspace_name: str | None
workspace_id: str | None
workspace_icon: str | None
pages: list[NotionIntegratePageResponse]
class NotionIntegrateInfoListResponse(ResponseModel):
notion_info: list[NotionIntegrateWorkspaceResponse]
register_schema_models(console_ns, NotionEstimatePayload)
register_response_schema_models(
console_ns,
DataSourceIntegrateListResponse,
IndexingEstimate,
NotionIntegrateInfoListResponse,
SimpleResultResponse,
TextContentResponse,
integrate_notion_info_list_fields_copy = integrate_notion_info_list_fields.copy()
integrate_notion_info_list_fields_copy["notion_info"] = fields.List(fields.Nested(notion_workspace_model))
integrate_notion_info_list_model = get_or_create_model(
"NotionIntegrateInfoList", integrate_notion_info_list_fields_copy
)
@ -129,9 +109,10 @@ class DataSourceApi(Resource):
@setup_required
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[DataSourceIntegrateListResponse.__name__])
@with_current_tenant_id
def get(self, current_tenant_id: str) -> tuple[dict[str, Any], int]:
@marshal_with(integrate_list_model)
def get(self):
_, current_tenant_id = current_account_with_tenant()
# get workspace data source integrates
data_source_integrates = db.session.scalars(
select(DataSourceOauthBinding).where(
@ -173,21 +154,19 @@ class DataSourceApi(Resource):
"link": f"{base_url}{data_source_oauth_base_path}/{provider}",
}
)
return dump_response(DataSourceIntegrateListResponse, {"data": integrate_data}), 200
return {"data": integrate_data}, 200
@setup_required
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
@with_current_tenant_id
def patch(
self, current_tenant_id: str, binding_id: UUID, action: Literal["enable", "disable"]
) -> tuple[dict[str, str], int]:
binding_id_str = str(binding_id)
def patch(self, binding_id, action: Literal["enable", "disable"]):
_, current_tenant_id = current_account_with_tenant()
binding_id = str(binding_id)
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
data_source_binding = session.execute(
select(DataSourceOauthBinding).where(
DataSourceOauthBinding.id == binding_id_str, DataSourceOauthBinding.tenant_id == current_tenant_id
DataSourceOauthBinding.id == binding_id, DataSourceOauthBinding.tenant_id == current_tenant_id
)
).scalar_one_or_none()
if data_source_binding is None:
@ -219,12 +198,12 @@ class DataSourceNotionListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@console_ns.doc(params=query_params_from_model(DataSourceNotionListQuery))
@console_ns.response(200, "Success", console_ns.models[NotionIntegrateInfoListResponse.__name__])
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account) -> tuple[dict[str, Any], int]:
query = DataSourceNotionListQuery.model_validate(request.args.to_dict(flat=True))
@marshal_with(integrate_notion_info_list_model)
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
query = DataSourceNotionListQuery.model_validate(request.args.to_dict())
datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_datasource_credentials(
tenant_id=current_tenant_id,
@ -299,22 +278,22 @@ class DataSourceNotionListApi(Resource):
pages.append(page_info)
except Exception as e:
raise e
notion_info = [{**workspace_info, "pages": pages}] if workspace_info else []
return dump_response(NotionIntegrateInfoListResponse, {"notion_info": notion_info}), 200
return {"notion_info": {**workspace_info, "pages": pages}}, 200
@console_ns.route("/notion/pages/<uuid:page_id>/<string:page_type>/preview")
class DataSourceNotionPreviewApi(Resource):
"""Preview one authorized Notion page through the datasource credential."""
@console_ns.route(
"/notion/pages/<uuid:page_id>/<string:page_type>/preview",
"/datasets/notion-indexing-estimate",
)
class DataSourceNotionApi(Resource):
@setup_required
@login_required
@account_initialization_required
@console_ns.doc(params=query_params_from_model(DataSourceNotionPreviewQuery))
@console_ns.response(200, "Success", console_ns.models[TextContentResponse.__name__])
@with_current_tenant_id
def get(self, current_tenant_id: str, page_id: UUID, page_type: str) -> tuple[dict[str, str], int]:
query = DataSourceNotionPreviewQuery.model_validate(request.args.to_dict(flat=True))
def get(self, page_id: UUID, page_type: str):
_, current_tenant_id = current_account_with_tenant()
query = DataSourceNotionPreviewQuery.model_validate(request.args.to_dict())
datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_datasource_credentials(
@ -337,18 +316,13 @@ class DataSourceNotionPreviewApi(Resource):
text_docs = extractor.extract()
return {"content": "\n".join([doc.page_content for doc in text_docs])}, 200
@console_ns.route("/datasets/notion-indexing-estimate")
class DataSourceNotionIndexingEstimateApi(Resource):
"""Estimate indexing work for selected Notion pages."""
@setup_required
@login_required
@account_initialization_required
@console_ns.expect(console_ns.models[NotionEstimatePayload.__name__])
@console_ns.response(200, "Success", console_ns.models[IndexingEstimate.__name__])
@with_current_tenant_id
def post(self, current_tenant_id: str) -> tuple[dict[str, Any], int]:
def post(self):
_, current_tenant_id = current_account_with_tenant()
payload = NotionEstimatePayload.model_validate(console_ns.payload or {})
args = payload.model_dump()
# validate args
@ -381,7 +355,7 @@ class DataSourceNotionIndexingEstimateApi(Resource):
args["doc_form"],
args["doc_language"],
)
return dump_response(IndexingEstimate, response), 200
return response.model_dump(), 200
@console_ns.route("/datasets/<uuid:dataset_id>/notion/sync")
@ -390,7 +364,7 @@ class DataSourceNotionDatasetSyncApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
def get(self, dataset_id: UUID) -> tuple[dict[str, str], int]:
def get(self, dataset_id: UUID):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
@ -408,7 +382,7 @@ class DataSourceNotionDocumentSyncApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
def get(self, dataset_id: UUID, document_id: UUID) -> tuple[dict[str, str], int]:
def get(self, dataset_id: UUID, document_id: UUID):
dataset_id_str = str(dataset_id)
document_id_str = str(document_id)
dataset = DatasetService.get_dataset(dataset_id_str)

View File

@ -44,8 +44,8 @@ from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
from libs.datetime_utils import naive_utc_now
from libs.helper import dump_response, to_timestamp
from libs.login import login_required
from models import Account, DatasetProcessRule, Document, DocumentSegment, UploadFile
from libs.login import current_account_with_tenant, login_required
from models import DatasetProcessRule, Document, DocumentSegment, UploadFile
from models.dataset import DocumentPipelineExecutionLog
from models.enums import IndexingStatus, SegmentStatus
from services.dataset_service import DatasetService, DocumentService
@ -71,8 +71,6 @@ from ..wraps import (
cloud_edition_billing_rate_limit_check,
cloud_edition_billing_resource_check,
setup_required,
with_current_tenant_id,
with_current_user,
)
logger = logging.getLogger(__name__)
@ -171,9 +169,8 @@ register_response_schema_models(
class DocumentResource(Resource):
def get_document(
self, dataset_id: str, document_id: str, current_user: Account, current_tenant_id: str
) -> Document:
def get_document(self, dataset_id: str, document_id: str) -> Document:
current_user, current_tenant_id = current_account_with_tenant()
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
@ -193,7 +190,8 @@ class DocumentResource(Resource):
return document
def get_batch_documents(self, dataset_id: str, batch: str, current_user: Account) -> Sequence[Document]:
def get_batch_documents(self, dataset_id: str, batch: str) -> Sequence[Document]:
current_user, _ = current_account_with_tenant()
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
@ -220,8 +218,8 @@ class GetProcessRuleApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_user
def get(self, current_user: Account):
def get(self):
current_user, _ = current_account_with_tenant()
req_data = request.args
document_id = req_data.get("document_id")
@ -281,9 +279,8 @@ class DatasetDocumentListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID):
def get(self, dataset_id: UUID):
current_user, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
raw_args = request.args.to_dict()
param = DocumentDatasetListParam.model_validate(raw_args)
@ -408,8 +405,8 @@ class DatasetDocumentListApi(Resource):
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[KnowledgeConfig.__name__])
@console_ns.response(200, "Documents created successfully", console_ns.models[DatasetAndDocumentResponse.__name__])
@with_current_user
def post(self, current_user: Account, dataset_id: UUID):
def post(self, dataset_id: UUID):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -483,10 +480,9 @@ class DatasetInitApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
@with_current_user
@with_current_tenant_id
def post(self, current_tenant_id: str, current_user: Account):
def post(self):
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_dataset_editor:
raise Forbidden()
@ -543,12 +539,11 @@ class DocumentIndexingEstimateApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID):
def get(self, dataset_id: UUID, document_id: UUID):
_, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
document_id_str = str(document_id)
document = self.get_document(dataset_id_str, document_id_str, current_user, current_tenant_id)
document = self.get_document(dataset_id_str, document_id_str)
if document.indexing_status in {IndexingStatus.COMPLETED, IndexingStatus.ERROR}:
raise DocumentAlreadyFinishedError()
@ -609,11 +604,10 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, batch: str):
def get(self, dataset_id: UUID, batch: str):
_, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
documents = self.get_batch_documents(dataset_id_str, batch, current_user)
documents = self.get_batch_documents(dataset_id_str, batch)
if not documents:
return {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}, 200
data_process_rule = documents[0].dataset_process_rule
@ -710,10 +704,9 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
@with_current_user
def get(self, current_user: Account, dataset_id: UUID, batch: str):
def get(self, dataset_id: UUID, batch: str):
dataset_id_str = str(dataset_id)
documents = self.get_batch_documents(dataset_id_str, batch, current_user)
documents = self.get_batch_documents(dataset_id_str, batch)
documents_status = []
for document in documents:
completed_segments = (
@ -766,18 +759,16 @@ class DocumentIndexingStatusApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID):
def get(self, dataset_id: UUID, document_id: UUID):
dataset_id_str = str(dataset_id)
document_id_str = str(document_id)
document = self.get_document(dataset_id_str, document_id_str, current_user, current_tenant_id)
document = self.get_document(dataset_id_str, document_id_str)
completed_segments = (
db.session.scalar(
select(func.count(DocumentSegment.id)).where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == document_id_str,
DocumentSegment.document_id == str(document_id_str),
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
)
@ -786,7 +777,7 @@ class DocumentIndexingStatusApi(DocumentResource):
total_segments = (
db.session.scalar(
select(func.count(DocumentSegment.id)).where(
DocumentSegment.document_id == document_id_str,
DocumentSegment.document_id == str(document_id_str),
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
)
@ -829,12 +820,10 @@ class DocumentApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID):
def get(self, dataset_id: UUID, document_id: UUID):
dataset_id_str = str(dataset_id)
document_id_str = str(document_id)
document = self.get_document(dataset_id_str, document_id_str, current_user, current_tenant_id)
document = self.get_document(dataset_id_str, document_id_str)
metadata = request.args.get("metadata", "all")
if metadata not in self.METADATA_CHOICES:
@ -920,9 +909,7 @@ class DocumentApi(DocumentResource):
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.response(204, "Document deleted successfully")
@with_current_user
@with_current_tenant_id
def delete(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID):
def delete(self, dataset_id: UUID, document_id: UUID):
dataset_id_str = str(dataset_id)
document_id_str = str(document_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -931,7 +918,7 @@ class DocumentApi(DocumentResource):
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
document = self.get_document(dataset_id_str, document_id_str, current_user, current_tenant_id)
document = self.get_document(dataset_id_str, document_id_str)
try:
DocumentService.delete_document(document)
@ -952,11 +939,9 @@ class DocumentDownloadApi(DocumentResource):
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID) -> dict[str, Any]:
def get(self, dataset_id: UUID, document_id: UUID) -> dict[str, Any]:
# Reuse the shared permission/tenant checks implemented in DocumentResource.
document = self.get_document(str(dataset_id), str(document_id), current_user, current_tenant_id)
document = self.get_document(str(dataset_id), str(document_id))
return {"url": DocumentService.get_document_download_url(document)}
@ -971,13 +956,12 @@ class DocumentBatchDownloadZipApi(DocumentResource):
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[DocumentBatchDownloadZipPayload.__name__])
@with_current_user
@with_current_tenant_id
def post(self, current_tenant_id: str, current_user: Account, dataset_id: UUID):
def post(self, dataset_id: UUID):
"""Stream a ZIP archive containing the requested uploaded documents."""
# Parse and validate request payload.
payload = DocumentBatchDownloadZipPayload.model_validate(console_ns.payload or {})
current_user, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
document_ids: list[str] = [str(document_id) for document_id in payload.document_ids]
upload_files, download_name = DocumentService.prepare_document_batch_download_zip(
@ -1019,19 +1003,11 @@ class DocumentProcessingApi(DocumentResource):
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
@with_current_user
@with_current_tenant_id
def patch(
self,
current_tenant_id: str,
current_user: Account,
dataset_id: UUID,
document_id: UUID,
action: Literal["pause", "resume"],
):
def patch(self, dataset_id: UUID, document_id: UUID, action: Literal["pause", "resume"]):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
document_id_str = str(document_id)
document = self.get_document(dataset_id_str, document_id_str, current_user, current_tenant_id)
document = self.get_document(dataset_id_str, document_id_str)
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
if not current_user.is_dataset_editor:
@ -1075,12 +1051,11 @@ class DocumentMetadataApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
@with_current_user
@with_current_tenant_id
def put(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID):
def put(self, dataset_id: UUID, document_id: UUID):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
document_id_str = str(document_id)
document = self.get_document(dataset_id_str, document_id_str, current_user, current_tenant_id)
document = self.get_document(dataset_id_str, document_id_str)
req_data = DocumentMetadataUpdatePayload.model_validate(request.get_json() or {})
@ -1125,10 +1100,8 @@ class DocumentStatusApi(DocumentResource):
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
@with_current_user
def patch(
self, current_user: Account, dataset_id: UUID, action: Literal["enable", "disable", "archive", "un_archive"]
):
def patch(self, dataset_id: UUID, action: Literal["enable", "disable", "archive", "un_archive"]):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
@ -1243,6 +1216,8 @@ class DocumentRetryApi(DocumentResource):
raise NotFound("Dataset not found.")
for document_id in payload.document_ids:
try:
document_id = str(document_id)
document = DocumentService.get_document(dataset.id, document_id)
# 404 if document not found
@ -1273,9 +1248,9 @@ class DocumentRenameApi(DocumentResource):
@account_initialization_required
@console_ns.response(200, "Document renamed successfully", console_ns.models[DocumentResponse.__name__])
@console_ns.expect(console_ns.models[DocumentRenamePayload.__name__])
@with_current_user
def post(self, current_user: Account, dataset_id: UUID, document_id: UUID):
def post(self, dataset_id: UUID, document_id: UUID):
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
current_user, _ = current_account_with_tenant()
if not current_user.is_dataset_editor:
raise Forbidden()
dataset = DatasetService.get_dataset(dataset_id)
@ -1298,9 +1273,9 @@ class WebsiteDocumentSyncApi(DocumentResource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
@with_current_tenant_id
def get(self, current_tenant_id: str, dataset_id: UUID, document_id: UUID):
def get(self, dataset_id: UUID, document_id: UUID):
"""sync website document."""
_, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if not dataset:
@ -1376,8 +1351,7 @@ class DocumentGenerateSummaryApi(Resource):
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
@with_current_user
def post(self, current_user: Account, dataset_id: UUID):
def post(self, dataset_id: UUID):
"""
Generate summary index for specified documents.
@ -1385,6 +1359,7 @@ class DocumentGenerateSummaryApi(Resource):
(indexing_technique must be 'high_quality' and summary_index_setting.enable must be true),
then asynchronously generates summary indexes for the provided documents.
"""
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
# Get dataset
@ -1469,8 +1444,7 @@ class DocumentSummaryStatusApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
@with_current_user
def get(self, current_user: Account, dataset_id: UUID, document_id: UUID):
def get(self, dataset_id: UUID, document_id: UUID):
"""
Get summary index generation status for a document.
@ -1483,6 +1457,7 @@ class DocumentSummaryStatusApi(DocumentResource):
- not_started: Number of segments without summary records
- summaries: List of summary records with status and content preview
"""
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
document_id_str = str(document_id)

View File

@ -33,8 +33,6 @@ from controllers.console.wraps import (
cloud_edition_billing_rate_limit_check,
cloud_edition_billing_resource_check,
setup_required,
with_current_tenant_id,
with_current_user,
)
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_manager import ModelManager
@ -53,8 +51,7 @@ from fields.segment_fields import (
)
from graphon.model_runtime.entities.model_entities import ModelType
from libs.helper import dump_response, escape_like_pattern
from libs.login import login_required
from models import Account
from libs.login import current_account_with_tenant, login_required
from models.dataset import ChildChunk, DocumentSegment
from models.model import UploadFile
from services.dataset_service import DatasetService, DocumentService, SegmentService
@ -167,9 +164,9 @@ class DatasetDocumentSegmentListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID):
def get(self, dataset_id: UUID, document_id: UUID):
current_user, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
document_id_str = str(document_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -277,8 +274,9 @@ class DatasetDocumentSegmentListApi(Resource):
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT)
@console_ns.doc(params=query_params_from_model(SegmentIdListQuery))
@console_ns.response(204, "Segments deleted successfully")
@with_current_user
def delete(self, current_user: Account, dataset_id: UUID, document_id: UUID):
def delete(self, dataset_id: UUID, document_id: UUID):
current_user, _ = current_account_with_tenant()
# check dataset
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -314,16 +312,9 @@ class DatasetDocumentSegmentApi(Resource):
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
@with_current_user
@with_current_tenant_id
def patch(
self,
current_tenant_id: str,
current_user: Account,
dataset_id: UUID,
document_id: UUID,
action: Literal["enable", "disable"],
):
def patch(self, dataset_id: UUID, document_id: UUID, action: Literal["enable", "disable"]):
current_user, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if not dataset:
@ -382,9 +373,9 @@ class DatasetDocumentSegmentAddApi(Resource):
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[SegmentCreatePayload.__name__])
@console_ns.response(200, "Segment created successfully", console_ns.models[SegmentDetailResponse.__name__])
@with_current_user
@with_current_tenant_id
def post(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID):
def post(self, dataset_id: UUID, document_id: UUID):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -440,11 +431,9 @@ class DatasetDocumentSegmentUpdateApi(Resource):
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[SegmentUpdatePayload.__name__])
@console_ns.response(200, "Segment updated successfully", console_ns.models[SegmentDetailResponse.__name__])
@with_current_user
@with_current_tenant_id
def patch(
self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID, segment_id: UUID
):
def patch(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -511,11 +500,9 @@ class DatasetDocumentSegmentUpdateApi(Resource):
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_SEGMENT)
@console_ns.response(204, "Segment deleted successfully")
@with_current_user
@with_current_tenant_id
def delete(
self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID, segment_id: UUID
):
def delete(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -561,9 +548,9 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
@cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[BatchImportPayload.__name__])
@with_current_user
@with_current_tenant_id
def post(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID):
def post(self, dataset_id: UUID, document_id: UUID):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -632,11 +619,9 @@ class ChildChunkAddApi(Resource):
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[ChildChunkCreatePayload.__name__])
@console_ns.response(200, "Child chunk created successfully", console_ns.models[ChildChunkDetailResponse.__name__])
@with_current_user
@with_current_tenant_id
def post(
self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID, segment_id: UUID
):
def post(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -692,8 +677,9 @@ class ChildChunkAddApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_tenant_id
def get(self, current_tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID):
def get(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
_, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -745,11 +731,9 @@ class ChildChunkAddApi(Resource):
console_ns.models[ChildChunkBatchUpdateResponse.__name__],
)
@console_ns.expect(console_ns.models[ChildChunkBatchUpdatePayload.__name__])
@with_current_user
@with_current_tenant_id
def patch(
self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID, segment_id: UUID
):
def patch(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -797,17 +781,9 @@ class ChildChunkUpdateApi(Resource):
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_CHILD_CHUNK)
@console_ns.response(204, "Child chunk deleted successfully")
@with_current_user
@with_current_tenant_id
def delete(
self,
current_tenant_id: str,
current_user: Account,
dataset_id: UUID,
document_id: UUID,
segment_id: UUID,
child_chunk_id: UUID,
):
def delete(self, dataset_id: UUID, document_id: UUID, segment_id: UUID, child_chunk_id: UUID):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -864,17 +840,9 @@ class ChildChunkUpdateApi(Resource):
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_CHILD_CHUNK)
@console_ns.expect(console_ns.models[ChildChunkUpdatePayload.__name__])
@console_ns.response(200, "Child chunk updated successfully", console_ns.models[ChildChunkDetailResponse.__name__])
@with_current_user
@with_current_tenant_id
def patch(
self,
current_tenant_id: str,
current_user: Account,
dataset_id: UUID,
document_id: UUID,
segment_id: UUID,
child_chunk_id: UUID,
):
def patch(self, dataset_id: UUID, document_id: UUID, segment_id: UUID, child_chunk_id: UUID):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)

View File

@ -15,7 +15,6 @@ from controllers.console.wraps import (
edit_permission_required,
setup_required,
with_current_tenant_id,
with_current_user,
)
from fields.dataset_fields import (
dataset_detail_fields,
@ -30,8 +29,7 @@ from fields.dataset_fields import (
vector_setting_fields,
weighted_score_fields,
)
from libs.login import login_required
from models import Account
from libs.login import current_account_with_tenant, login_required
from services.dataset_service import DatasetService
from services.external_knowledge_service import ExternalDatasetService
from services.hit_testing_service import HitTestingService
@ -154,9 +152,8 @@ class ExternalApiTemplateListApi(Resource):
@login_required
@account_initialization_required
@console_ns.expect(console_ns.models[ExternalKnowledgeApiPayload.__name__])
@with_current_user
@with_current_tenant_id
def post(self, current_tenant_id: str, current_user: Account):
def post(self):
current_user, current_tenant_id = current_account_with_tenant()
payload = ExternalKnowledgeApiPayload.model_validate(console_ns.payload or {})
ExternalDatasetService.validate_api_list(payload.settings)
@ -185,8 +182,8 @@ class ExternalApiTemplateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_tenant_id
def get(self, current_tenant_id: str, external_knowledge_api_id: UUID):
def get(self, external_knowledge_api_id: UUID):
_, current_tenant_id = current_account_with_tenant()
external_knowledge_api_id_str = str(external_knowledge_api_id)
external_knowledge_api = ExternalDatasetService.get_external_knowledge_api(
external_knowledge_api_id_str, current_tenant_id
@ -200,9 +197,8 @@ class ExternalApiTemplateApi(Resource):
@login_required
@account_initialization_required
@console_ns.expect(console_ns.models[ExternalKnowledgeApiPayload.__name__])
@with_current_user
@with_current_tenant_id
def patch(self, current_tenant_id: str, current_user: Account, external_knowledge_api_id: UUID):
def patch(self, external_knowledge_api_id: UUID):
current_user, current_tenant_id = current_account_with_tenant()
external_knowledge_api_id_str = str(external_knowledge_api_id)
payload = ExternalKnowledgeApiPayload.model_validate(console_ns.payload or {})
@ -221,9 +217,8 @@ class ExternalApiTemplateApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(204, "External knowledge API deleted successfully")
@with_current_user
@with_current_tenant_id
def delete(self, current_tenant_id: str, current_user: Account, external_knowledge_api_id: UUID):
def delete(self, external_knowledge_api_id: UUID):
current_user, current_tenant_id = current_account_with_tenant()
external_knowledge_api_id_str = str(external_knowledge_api_id)
if not (current_user.has_edit_permission or current_user.is_dataset_operator):
@ -242,8 +237,8 @@ class ExternalApiUseCheckApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_tenant_id
def get(self, current_tenant_id: str, external_knowledge_api_id: UUID):
def get(self, external_knowledge_api_id: UUID):
_, current_tenant_id = current_account_with_tenant()
external_knowledge_api_id_str = str(external_knowledge_api_id)
external_knowledge_api_is_using, count = ExternalDatasetService.external_knowledge_api_use_check(
@ -264,10 +259,9 @@ class ExternalDatasetCreateApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
@with_current_user
@with_current_tenant_id
def post(self, current_tenant_id: str, current_user: Account):
def post(self):
# The role of the current user in the ta table must be admin, owner, or editor
current_user, current_tenant_id = current_account_with_tenant()
payload = ExternalDatasetCreatePayload.model_validate(console_ns.payload or {})
args = payload.model_dump(exclude_none=True)
@ -299,8 +293,8 @@ class ExternalKnowledgeHitTestingApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_user
def post(self, current_user: Account, dataset_id: UUID):
def post(self, dataset_id: UUID):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:

View File

@ -9,18 +9,11 @@ from configs import dify_config
from controllers.common.fields import SimpleResultResponse
from controllers.common.schema import register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
with_current_tenant_id,
with_current_user,
)
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from core.plugin.impl.oauth import OAuthHandler
from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
from graphon.model_runtime.utils.encoders import jsonable_encoder
from libs.login import login_required
from models import Account
from libs.login import current_account_with_tenant, login_required
from models.provider_ids import DatasourceProviderID
from services.datasource_provider_service import DatasourceProviderService
from services.plugin.oauth_service import OAuthProxyService
@ -73,10 +66,11 @@ class DatasourcePluginOAuthAuthorizationUrl(Resource):
@login_required
@account_initialization_required
@edit_permission_required
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account, provider_id: str):
def get(self, provider_id: str):
current_user, current_tenant_id = current_account_with_tenant()
tenant_id = current_tenant_id
credential_id = request.args.get("credential_id")
datasource_provider_id = DatasourceProviderID(provider_id)
provider_name = datasource_provider_id.provider_name
@ -180,8 +174,9 @@ class DatasourceAuth(Resource):
@login_required
@account_initialization_required
@edit_permission_required
@with_current_tenant_id
def post(self, current_tenant_id: str, provider_id: str):
def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
payload = DatasourceCredentialPayload.model_validate(console_ns.payload or {})
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
@ -200,17 +195,15 @@ class DatasourceAuth(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, user: Account, provider_id: str):
def get(self, provider_id: str):
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
_, current_tenant_id = current_account_with_tenant()
datasources = datasource_provider_service.list_datasource_credentials(
tenant_id=current_tenant_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
user=user,
)
return {"result": datasources}, 200
@ -223,8 +216,9 @@ class DatasourceAuthDeleteApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
@with_current_tenant_id
def post(self, current_tenant_id: str, provider_id: str):
def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
datasource_provider_id = DatasourceProviderID(provider_id)
plugin_id = datasource_provider_id.plugin_id
provider_name = datasource_provider_id.provider_name
@ -247,8 +241,9 @@ class DatasourceAuthUpdateApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
@with_current_tenant_id
def post(self, current_tenant_id: str, provider_id: str):
def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
datasource_provider_id = DatasourceProviderID(provider_id)
payload = DatasourceCredentialUpdatePayload.model_validate(console_ns.payload or {})
@ -269,8 +264,9 @@ class DatasourceAuthListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_tenant_id
def get(self, current_tenant_id: str):
def get(self):
_, current_tenant_id = current_account_with_tenant()
datasource_provider_service = DatasourceProviderService()
datasources = datasource_provider_service.get_all_datasource_credentials(tenant_id=current_tenant_id)
return {"result": jsonable_encoder(datasources)}, 200
@ -281,8 +277,9 @@ class DatasourceHardCodeAuthListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_tenant_id
def get(self, current_tenant_id: str):
def get(self):
_, current_tenant_id = current_account_with_tenant()
datasource_provider_service = DatasourceProviderService()
datasources = datasource_provider_service.get_hard_code_datasource_credentials(tenant_id=current_tenant_id)
return {"result": jsonable_encoder(datasources)}, 200
@ -295,8 +292,9 @@ class DatasourceAuthOauthCustomClient(Resource):
@login_required
@account_initialization_required
@edit_permission_required
@with_current_tenant_id
def post(self, current_tenant_id: str, provider_id: str):
def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
payload = DatasourceCustomClientPayload.model_validate(console_ns.payload or {})
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
@ -312,8 +310,9 @@ class DatasourceAuthOauthCustomClient(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
@with_current_tenant_id
def delete(self, current_tenant_id: str, provider_id: str):
def delete(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.remove_oauth_custom_client_params(
@ -331,8 +330,9 @@ class DatasourceAuthDefaultApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
@with_current_tenant_id
def post(self, current_tenant_id: str, provider_id: str):
def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
payload = DatasourceDefaultPayload.model_validate(console_ns.payload or {})
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
@ -352,8 +352,9 @@ class DatasourceUpdateProviderNameApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
@with_current_tenant_id
def post(self, current_tenant_id: str, provider_id: str):
def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
payload = DatasourceUpdateNamePayload.model_validate(console_ns.payload or {})
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()

View File

@ -1,20 +1,13 @@
import logging
from typing import Any
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import NotFound
from controllers.common.fields import SimpleDataResponse
from controllers.common.schema import (
JsonResponseWithStatus,
query_params_from_model,
register_response_schema_models,
register_schema_models,
)
from controllers.common.schema import register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import (
account_initialization_required,
@ -23,132 +16,79 @@ from controllers.console.wraps import (
setup_required,
)
from extensions.ext_database import db
from fields.base import ResponseModel
from libs.helper import dump_response
from libs.login import login_required
from models.dataset import PipelineCustomizedTemplate
from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, PipelineTemplateInfoEntity
from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity
from services.rag_pipeline.rag_pipeline import RagPipelineService
logger: logging.Logger = logging.getLogger(__name__)
class PipelineTemplateListQuery(BaseModel):
type: str = Field(default="built-in", description="Template source: built-in or customized")
language: str = Field(default="en-US", description="Template language")
class PipelineTemplateDetailQuery(BaseModel):
type: str = Field(default="built-in", description="Template source: built-in or customized")
class PipelineTemplateItemResponse(ResponseModel):
id: str
name: str
icon: dict[str, Any]
description: str
position: int
chunk_structure: str
copyright: str | None = None
privacy_policy: str | None = None
class PipelineTemplateListResponse(ResponseModel):
pipeline_templates: list[PipelineTemplateItemResponse]
class PipelineTemplateDetailResponse(ResponseModel):
id: str
name: str
icon_info: dict[str, Any]
description: str
chunk_structure: str
export_data: str
graph: dict[str, Any]
created_by: str | None = None
class CustomizedPipelineTemplatePayload(BaseModel):
name: str = Field(..., min_length=1, max_length=40)
description: str = Field(default="", max_length=400)
icon_info: dict[str, object] = Field(default_factory=lambda: IconInfo(icon="").model_dump())
register_schema_models(
console_ns,
CustomizedPipelineTemplatePayload,
PipelineTemplateDetailQuery,
PipelineTemplateListQuery,
)
register_response_schema_models(
console_ns,
PipelineTemplateDetailResponse,
PipelineTemplateListResponse,
SimpleDataResponse,
)
logger = logging.getLogger(__name__)
@console_ns.route("/rag/pipeline/templates")
class PipelineTemplateListApi(Resource):
@console_ns.doc(params=query_params_from_model(PipelineTemplateListQuery))
@console_ns.response(200, "Pipeline templates", console_ns.models[PipelineTemplateListResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
def get(self) -> JsonResponseWithStatus:
query = PipelineTemplateListQuery.model_validate(request.args.to_dict(flat=True))
def get(self):
type = request.args.get("type", default="built-in", type=str)
language = request.args.get("language", default="en-US", type=str)
# get pipeline templates
pipeline_templates = RagPipelineService.get_pipeline_templates(query.type, query.language)
return dump_response(PipelineTemplateListResponse, pipeline_templates), 200
pipeline_templates = RagPipelineService.get_pipeline_templates(type, language)
return pipeline_templates, 200
@console_ns.route("/rag/pipeline/templates/<string:template_id>")
class PipelineTemplateDetailApi(Resource):
@console_ns.doc(params=query_params_from_model(PipelineTemplateDetailQuery))
@console_ns.response(200, "Pipeline template", console_ns.models[PipelineTemplateDetailResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
def get(self, template_id: str) -> JsonResponseWithStatus:
query = PipelineTemplateDetailQuery.model_validate(request.args.to_dict(flat=True))
def get(self, template_id: str):
type = request.args.get("type", default="built-in", type=str)
rag_pipeline_service = RagPipelineService()
pipeline_template = rag_pipeline_service.get_pipeline_template_detail(template_id, query.type)
pipeline_template = rag_pipeline_service.get_pipeline_template_detail(template_id, type)
if pipeline_template is None:
raise NotFound("Pipeline template not found from upstream service.")
return dump_response(PipelineTemplateDetailResponse, pipeline_template), 200
return {"error": "Pipeline template not found from upstream service."}, 404
return pipeline_template, 200
class Payload(BaseModel):
name: str = Field(..., min_length=1, max_length=40)
description: str = Field(default="", max_length=400)
icon_info: dict[str, object] | None = None
register_schema_models(console_ns, Payload)
register_response_schema_models(console_ns, SimpleDataResponse)
@console_ns.route("/rag/pipeline/customized/templates/<string:template_id>")
class CustomizedPipelineTemplateApi(Resource):
@console_ns.expect(console_ns.models[CustomizedPipelineTemplatePayload.__name__])
@console_ns.response(204, "Pipeline template updated")
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
def patch(self, template_id: str) -> tuple[str, int]:
payload = CustomizedPipelineTemplatePayload.model_validate(console_ns.payload or {})
def patch(self, template_id: str):
payload = Payload.model_validate(console_ns.payload or {})
pipeline_template_info = PipelineTemplateInfoEntity.model_validate(payload.model_dump())
RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
return "", 204
return 200
@console_ns.response(204, "Pipeline template deleted")
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
def delete(self, template_id: str) -> tuple[str, int]:
def delete(self, template_id: str):
RagPipelineService.delete_customized_pipeline_template(template_id)
return "", 204
return 200
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
@console_ns.response(200, "Success", console_ns.models[SimpleDataResponse.__name__])
def post(self, template_id: str) -> JsonResponseWithStatus:
def post(self, template_id: str):
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
template = session.scalar(
select(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id).limit(1)
@ -156,20 +96,19 @@ class CustomizedPipelineTemplateApi(Resource):
if not template:
raise ValueError("Customized pipeline template not found.")
return dump_response(SimpleDataResponse, {"data": template.yaml_content}), 200
return {"data": template.yaml_content}, 200
@console_ns.route("/rag/pipelines/<string:pipeline_id>/customized/publish")
class PublishCustomizedPipelineTemplateApi(Resource):
@console_ns.expect(console_ns.models[CustomizedPipelineTemplatePayload.__name__])
@console_ns.response(204, "Pipeline template published")
@console_ns.expect(console_ns.models[Payload.__name__])
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
@knowledge_pipeline_publish_enabled
def post(self, pipeline_id: str) -> tuple[str, int]:
payload = CustomizedPipelineTemplatePayload.model_validate(console_ns.payload or {})
def post(self, pipeline_id: str):
payload = Payload.model_validate(console_ns.payload or {})
rag_pipeline_service = RagPipelineService()
rag_pipeline_service.publish_customized_pipeline_template(pipeline_id, payload.model_dump())
return "", 204
return {"result": "success"}

View File

@ -1,25 +1,20 @@
from flask_restx import Resource
from flask_restx import Resource, marshal
from pydantic import BaseModel
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
import services
from controllers.common.schema import JsonResponseWithStatus, register_response_schema_models, register_schema_models
from controllers.common.schema import register_schema_model
from controllers.console import console_ns
from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.datasets.rag_pipeline.rag_pipeline_import import RagPipelineImportResponse
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_rate_limit_check,
setup_required,
with_current_tenant_id,
with_current_user,
)
from extensions.ext_database import db
from fields.dataset_fields import DatasetDetailResponse
from libs.helper import dump_response
from libs.login import login_required
from models import Account
from fields.dataset_fields import dataset_detail_fields
from libs.login import current_account_with_tenant, login_required
from models.dataset import DatasetPermissionEnum
from services.dataset_service import DatasetPermissionService, DatasetService
from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity
@ -30,26 +25,19 @@ class RagPipelineDatasetImportPayload(BaseModel):
yaml_content: str
register_schema_models(console_ns, RagPipelineDatasetImportPayload)
register_response_schema_models(console_ns, DatasetDetailResponse, RagPipelineImportResponse)
register_schema_model(console_ns, RagPipelineDatasetImportPayload)
@console_ns.route("/rag/pipeline/dataset")
class CreateRagPipelineDatasetApi(Resource):
@console_ns.expect(console_ns.models[RagPipelineDatasetImportPayload.__name__])
@console_ns.response(
201,
"RAG pipeline dataset import started",
console_ns.models[RagPipelineImportResponse.__name__],
)
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
@with_current_user
@with_current_tenant_id
def post(self, current_tenant_id: str, current_user: Account) -> JsonResponseWithStatus:
def post(self):
payload = RagPipelineDatasetImportPayload.model_validate(console_ns.payload or {})
current_user, current_tenant_id = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor:
raise Forbidden()
@ -82,20 +70,19 @@ class CreateRagPipelineDatasetApi(Resource):
except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError()
return dump_response(RagPipelineImportResponse, import_info), 201
return import_info, 201
@console_ns.route("/rag/pipeline/empty-dataset")
class CreateEmptyRagPipelineDatasetApi(Resource):
@console_ns.response(201, "RAG pipeline dataset created", console_ns.models[DatasetDetailResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
@with_current_user
@with_current_tenant_id
def post(self, current_tenant_id: str, current_user: Account) -> JsonResponseWithStatus:
def post(self):
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_dataset_editor:
raise Forbidden()
dataset = DatasetService.create_empty_rag_pipeline_dataset(
@ -112,4 +99,4 @@ class CreateEmptyRagPipelineDatasetApi(Resource):
partial_member_list=None,
),
)
return dump_response(DatasetDetailResponse, dataset), 201
return marshal(dataset, dataset_detail_fields), 201

View File

@ -1,6 +1,6 @@
import logging
from collections.abc import Callable
from typing import Any, Concatenate, NoReturn
from typing import Any, NoReturn
from uuid import UUID
from flask import Response, request
@ -57,9 +57,7 @@ class WorkflowDraftVariablePatchPayload(BaseModel):
register_schema_models(console_ns, WorkflowDraftVariablePatchPayload)
def _api_prerequisite[T, **P, R](
f: Callable[Concatenate[T, P], R],
) -> Callable[Concatenate[T, P], R | Response]:
def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
"""Common prerequisites for all draft workflow variable APIs.
It ensures the following conditions are satisfied:
@ -74,10 +72,10 @@ def _api_prerequisite[T, **P, R](
@login_required
@account_initialization_required
@get_rag_pipeline
def wrapper(self: T, *args: P.args, **kwargs: P.kwargs) -> R | Response:
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | Response:
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
return f(self, *args, **kwargs)
return f(*args, **kwargs)
return wrapper

View File

@ -1,15 +1,9 @@
from flask import request
from flask_restx import Resource
from flask_restx import Resource, fields, marshal_with # type: ignore
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from controllers.common.fields import SimpleDataResponse
from controllers.common.schema import (
JsonResponseWithStatus,
query_params_from_model,
register_response_schema_models,
register_schema_models,
)
from controllers.common.schema import get_or_create_model, register_schema_models
from controllers.console import console_ns
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import (
@ -18,10 +12,12 @@ from controllers.console.wraps import (
setup_required,
with_current_user,
)
from core.plugin.entities.plugin import PluginDependency
from extensions.ext_database import db
from fields.base import ResponseModel
from libs.helper import dump_response
from fields.rag_pipeline_fields import (
leaked_dependency_fields,
pipeline_import_check_dependencies_fields,
pipeline_import_fields,
)
from libs.login import login_required
from models.account import Account
from models.dataset import Pipeline
@ -42,44 +38,34 @@ class RagPipelineImportPayload(BaseModel):
class IncludeSecretQuery(BaseModel):
include_secret: str = Field(default="false", description="Whether to include secret values in the exported DSL")
class RagPipelineImportResponse(ResponseModel):
id: str
status: ImportStatus
pipeline_id: str | None = None
dataset_id: str | None = None
current_dsl_version: str
imported_dsl_version: str
error: str = ""
class RagPipelineImportCheckDependenciesResponse(ResponseModel):
leaked_dependencies: list[PluginDependency] = Field(default_factory=list)
include_secret: str = Field(default="false")
register_schema_models(console_ns, RagPipelineImportPayload, IncludeSecretQuery)
register_response_schema_models(
console_ns,
RagPipelineImportCheckDependenciesResponse,
RagPipelineImportResponse,
SimpleDataResponse,
pipeline_import_model = get_or_create_model("RagPipelineImport", pipeline_import_fields)
leaked_dependency_model = get_or_create_model("RagPipelineLeakedDependency", leaked_dependency_fields)
pipeline_import_check_dependencies_fields_copy = pipeline_import_check_dependencies_fields.copy()
pipeline_import_check_dependencies_fields_copy["leaked_dependencies"] = fields.List(
fields.Nested(leaked_dependency_model)
)
pipeline_import_check_dependencies_model = get_or_create_model(
"RagPipelineImportCheckDependencies", pipeline_import_check_dependencies_fields_copy
)
@console_ns.route("/rag/pipelines/imports")
class RagPipelineImportApi(Resource):
@console_ns.expect(console_ns.models[RagPipelineImportPayload.__name__])
@console_ns.response(200, "Import completed", console_ns.models[RagPipelineImportResponse.__name__])
@console_ns.response(202, "Import pending confirmation", console_ns.models[RagPipelineImportResponse.__name__])
@console_ns.response(400, "Import failed", console_ns.models[RagPipelineImportResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
@marshal_with(pipeline_import_model)
@console_ns.expect(console_ns.models[RagPipelineImportPayload.__name__])
@with_current_user
def post(self, current_user: Account) -> JsonResponseWithStatus:
def post(self, current_user: Account):
# Check user role first
payload = RagPipelineImportPayload.model_validate(console_ns.payload or {})
@ -107,23 +93,22 @@ class RagPipelineImportApi(Resource):
status = result.status
match status:
case ImportStatus.FAILED:
return dump_response(RagPipelineImportResponse, result), 400
return result.model_dump(mode="json"), 400
case ImportStatus.PENDING:
return dump_response(RagPipelineImportResponse, result), 202
return result.model_dump(mode="json"), 202
case ImportStatus.COMPLETED | ImportStatus.COMPLETED_WITH_WARNINGS:
return dump_response(RagPipelineImportResponse, result), 200
return result.model_dump(mode="json"), 200
@console_ns.route("/rag/pipelines/imports/<string:import_id>/confirm")
class RagPipelineImportConfirmApi(Resource):
@console_ns.response(200, "Import confirmed", console_ns.models[RagPipelineImportResponse.__name__])
@console_ns.response(400, "Import failed", console_ns.models[RagPipelineImportResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
@marshal_with(pipeline_import_model)
@with_current_user
def post(self, current_user: Account, import_id: str) -> JsonResponseWithStatus:
def post(self, current_user: Account, import_id: str):
with Session(db.engine, expire_on_commit=False) as session:
import_service = RagPipelineDslService(session)
account = current_user
@ -135,40 +120,34 @@ class RagPipelineImportConfirmApi(Resource):
# Return appropriate status code based on result
if result.status == ImportStatus.FAILED:
return dump_response(RagPipelineImportResponse, result), 400
return dump_response(RagPipelineImportResponse, result), 200
return result.model_dump(mode="json"), 400
return result.model_dump(mode="json"), 200
@console_ns.route("/rag/pipelines/imports/<string:pipeline_id>/check-dependencies")
class RagPipelineImportCheckDependenciesApi(Resource):
@console_ns.response(
200,
"Dependencies checked",
console_ns.models[RagPipelineImportCheckDependenciesResponse.__name__],
)
@setup_required
@login_required
@get_rag_pipeline
@account_initialization_required
@edit_permission_required
def get(self, pipeline: Pipeline) -> JsonResponseWithStatus:
@marshal_with(pipeline_import_check_dependencies_model)
def get(self, pipeline: Pipeline):
with Session(db.engine, expire_on_commit=False) as session:
import_service = RagPipelineDslService(session)
result = import_service.check_dependencies(pipeline=pipeline)
return dump_response(RagPipelineImportCheckDependenciesResponse, result), 200
return result.model_dump(mode="json"), 200
@console_ns.route("/rag/pipelines/<string:pipeline_id>/exports")
class RagPipelineExportApi(Resource):
@console_ns.doc(params=query_params_from_model(IncludeSecretQuery))
@console_ns.response(200, "Pipeline exported", console_ns.models[SimpleDataResponse.__name__])
@setup_required
@login_required
@get_rag_pipeline
@account_initialization_required
@edit_permission_required
def get(self, pipeline: Pipeline) -> JsonResponseWithStatus:
def get(self, pipeline: Pipeline):
# Add include_secret params
query = IncludeSecretQuery.model_validate(request.args.to_dict())
@ -178,4 +157,4 @@ class RagPipelineExportApi(Resource):
pipeline=pipeline, include_secret=query.include_secret == "true"
)
return dump_response(SimpleDataResponse, {"data": result}), 200
return {"data": result}, 200

View File

@ -18,7 +18,6 @@ from controllers.console.app.error import (
)
from controllers.console.explore.error import NotChatAppError, NotCompletionAppError
from controllers.console.explore.wraps import InstalledAppResource
from controllers.console.wraps import with_current_user_id
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import (
@ -136,18 +135,20 @@ class CompletionApi(InstalledAppResource):
)
class CompletionStopApi(InstalledAppResource):
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
@with_current_user_id
def post(self, current_user_id: str, installed_app: InstalledApp, task_id: str):
def post(self, installed_app: InstalledApp, task_id: str):
app_model = installed_app.app
if app_model is None:
raise AppUnavailableError()
if app_model.mode != AppMode.COMPLETION:
raise NotCompletionAppError()
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
AppTaskService.stop_task(
task_id=task_id,
invoke_from=InvokeFrom.EXPLORE,
user_id=current_user_id,
user_id=current_user.id,
app_mode=AppMode.value_of(app_model.mode),
)
@ -214,8 +215,7 @@ class ChatApi(InstalledAppResource):
)
class ChatStopApi(InstalledAppResource):
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
@with_current_user_id
def post(self, current_user_id: str, installed_app: InstalledApp, task_id: str):
def post(self, installed_app: InstalledApp, task_id: str):
app_model = installed_app.app
if app_model is None:
raise AppUnavailableError()
@ -223,10 +223,13 @@ class ChatStopApi(InstalledAppResource):
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
AppTaskService.stop_task(
task_id=task_id,
invoke_from=InvokeFrom.EXPLORE,
user_id=current_user_id,
user_id=current_user.id,
app_mode=app_mode,
)

View File

@ -12,19 +12,14 @@ from controllers.common.fields import SimpleMessageResponse, SimpleResultMessage
from controllers.common.schema import register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.explore.wraps import InstalledAppResource
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
with_current_tenant_id,
with_current_user,
)
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from extensions.ext_database import db
from fields.base import ResponseModel
from graphon.file import helpers as file_helpers
from libs.datetime_utils import naive_utc_now
from libs.helper import to_timestamp
from libs.login import login_required
from models import Account, App, InstalledApp, RecommendedApp
from libs.login import current_account_with_tenant, login_required
from models import App, InstalledApp, RecommendedApp
from models.model import IconType
from services.account_service import TenantService
from services.enterprise.enterprise_service import EnterpriseService
@ -136,10 +131,9 @@ class InstalledAppsListApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[InstalledAppListResponse.__name__])
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account):
def get(self):
query = InstalledAppsListQuery.model_validate(request.args.to_dict())
current_user, current_tenant_id = current_account_with_tenant()
if query.app_id:
installed_apps = db.session.scalars(
@ -218,8 +212,7 @@ class InstalledAppsListApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("apps")
@console_ns.response(200, "Success", console_ns.models[SimpleMessageResponse.__name__])
@with_current_tenant_id
def post(self, current_tenant_id: str):
def post(self):
payload = InstalledAppCreatePayload.model_validate(console_ns.payload or {})
recommended_app = db.session.scalar(
@ -228,6 +221,8 @@ class InstalledAppsListApi(Resource):
if recommended_app is None:
raise NotFound("Recommended app not found")
_, current_tenant_id = current_account_with_tenant()
app = db.session.get(App, payload.app_id)
if app is None:
@ -267,8 +262,8 @@ class InstalledAppApi(InstalledAppResource):
"""
@console_ns.response(204, "App uninstalled successfully")
@with_current_tenant_id
def delete(self, current_tenant_id: str, installed_app: InstalledApp):
def delete(self, installed_app: InstalledApp):
_, current_tenant_id = current_account_with_tenant()
if installed_app.app_owner_tenant_id == current_tenant_id:
raise BadRequest("You can't uninstall an app owned by the current tenant")

View File

@ -14,13 +14,7 @@ from sqlalchemy.orm import Session, sessionmaker
from controllers.common.human_input import HumanInputFormSubmitPayload
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import (
account_initialization_required,
model_validate,
setup_required,
with_current_tenant_id,
with_current_user,
)
from controllers.console.wraps import account_initialization_required, model_validate, setup_required
from controllers.web.error import InvalidArgumentError, NotFoundError
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
from core.app.apps.base_app_generator import BaseAppGenerator
@ -29,8 +23,8 @@ from core.app.apps.message_generator import MessageGenerator
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface
from extensions.ext_database import db
from libs.login import login_required
from models import Account, App
from libs.login import current_account_with_tenant, login_required
from models import App
from models.enums import CreatorUserRole
from models.model import AppMode
from models.workflow import WorkflowRun
@ -54,8 +48,9 @@ class ConsoleHumanInputFormApi(Resource):
"""Console API for getting human input form definition."""
@staticmethod
def _ensure_console_access(form: Form, current_tenant_id: str) -> None:
"""Ensure a console form token resolves only inside the current tenant."""
def _ensure_console_access(form: Form):
_, current_tenant_id = current_account_with_tenant()
if form.tenant_id != current_tenant_id:
raise NotFoundError("App not found")
@ -67,8 +62,7 @@ class ConsoleHumanInputFormApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_tenant_id
def get(self, current_tenant_id: str, form_token: str):
def get(self, form_token: str):
"""
Get human input form definition by form token.
@ -79,23 +73,15 @@ class ConsoleHumanInputFormApi(Resource):
if form is None:
raise NotFoundError(f"form not found, token={form_token}")
self._ensure_console_access(form, current_tenant_id)
self._ensure_console_access(form)
return _jsonify_form_definition(form)
@account_initialization_required
@login_required
@with_current_user
@with_current_tenant_id
@model_validate(HumanInputFormSubmitPayload)
@console_ns.expect(console_ns.models[HumanInputFormSubmitPayload.__name__])
def post(
self,
payload: HumanInputFormSubmitPayload,
current_tenant_id: str,
current_user: Account,
form_token: str,
):
def post(self, payload: HumanInputFormSubmitPayload, form_token: str):
"""
Submit human input form by form token.
@ -109,12 +95,14 @@ class ConsoleHumanInputFormApi(Resource):
"action": "Approve"
}
"""
current_user, _ = current_account_with_tenant()
service = HumanInputService(db.engine)
form = service.get_form_by_token(form_token)
if form is None:
raise NotFoundError(f"form not found, token={form_token}")
self._ensure_console_access(form, current_tenant_id)
self._ensure_console_access(form)
self._ensure_console_recipient_type(form)
recipient_type = form.recipient_type
# The type checker is not smart enought to validate the following invariant.
@ -138,9 +126,7 @@ class ConsoleWorkflowEventsApi(Resource):
@account_initialization_required
@login_required
@with_current_user
@with_current_tenant_id
def get(self, tenant_id: str, user: Account, workflow_run_id: str):
def get(self, workflow_run_id: str):
"""
Get workflow execution events stream after resume.
@ -148,6 +134,8 @@ class ConsoleWorkflowEventsApi(Resource):
Returns Server-Sent Events stream.
"""
user, tenant_id = current_account_with_tenant()
session_maker = sessionmaker(db.engine)
repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
workflow_run = repo.get_workflow_run_by_id_and_tenant_id(

View File

@ -13,7 +13,7 @@ from controllers.common.errors import (
from controllers.common.schema import register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import with_current_user
from core.file import remote_fetcher
from core.helper import ssrf_proxy
from extensions.ext_database import db
from fields.file_fields import FileWithSignedUrl, RemoteFileInfo
from graphon.file import helpers as file_helpers
@ -36,9 +36,9 @@ class GetRemoteFileInfo(Resource):
@login_required
def get(self, url: str):
decoded_url = helpers.decode_remote_url(url, request.query_string)
resp = remote_fetcher.make_request("HEAD", decoded_url)
resp = ssrf_proxy.head(decoded_url)
if resp.status_code != httpx.codes.OK:
resp = remote_fetcher.make_request("GET", decoded_url, timeout=3)
resp = ssrf_proxy.get(decoded_url, timeout=3)
resp.raise_for_status()
return RemoteFileInfo(
file_type=resp.headers.get("Content-Type", "application/octet-stream"),
@ -58,9 +58,9 @@ class RemoteFileUpload(Resource):
# Try to fetch remote file metadata/content first
try:
resp = remote_fetcher.make_request("HEAD", url=url)
resp = ssrf_proxy.head(url=url)
if resp.status_code != httpx.codes.OK:
resp = remote_fetcher.make_request("GET", url=url, timeout=3, follow_redirects=True)
resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
if resp.status_code != httpx.codes.OK:
# Normalize into a user-friendly error message expected by tests
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}")
@ -74,7 +74,7 @@ class RemoteFileUpload(Resource):
raise FileTooLargeError()
# Load content if needed
content = resp.content if resp.request.method == "GET" else remote_fetcher.make_request("GET", url).content
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
try:
upload_file = FileService(db.engine).upload_file(

View File

@ -0,0 +1,160 @@
import uuid
from typing import Any, Literal
from pydantic import BaseModel, Field, field_validator
class SnippetListQuery(BaseModel):
"""Query parameters for listing snippets."""
page: int = Field(default=1, ge=1, le=99999)
limit: int = Field(default=20, ge=1, le=100)
keyword: str | None = None
is_published: bool | None = Field(default=None, description="Filter by published status")
creators: list[str] | None = Field(default=None, description="Filter by creator account IDs")
tag_ids: list[str] | None = Field(default=None, description="Filter by tag IDs")
@field_validator("creators", mode="before")
@classmethod
def parse_creators(cls, value: object) -> list[str] | None:
"""Normalize creators filter from query string or list input."""
return cls._normalize_string_list(value)
@field_validator("tag_ids", mode="before")
@classmethod
def parse_tag_ids(cls, value: object) -> list[str] | None:
"""Normalize and validate tag IDs from query string or list input."""
items = cls._normalize_string_list(value)
if not items:
return None
try:
return [str(uuid.UUID(item)) for item in items]
except ValueError as exc:
raise ValueError("Invalid UUID format in tag_ids.") from exc
@staticmethod
def _normalize_string_list(value: object) -> list[str] | None:
if value is None:
return None
if isinstance(value, str):
return [item.strip() for item in value.split(",") if item.strip()] or None
if isinstance(value, list):
return [str(item).strip() for item in value if str(item).strip()] or None
return None
class IconInfo(BaseModel):
"""Icon information model."""
icon: str | None = None
icon_type: Literal["emoji", "image"] | None = None
icon_background: str | None = None
icon_url: str | None = None
class InputFieldDefinition(BaseModel):
"""Input field definition for snippet parameters."""
default: str | None = None
hint: bool | None = None
label: str | None = None
max_length: int | None = None
options: list[str] | None = None
placeholder: str | None = None
required: bool | None = None
type: str | None = None # e.g., "text-input"
class CreateSnippetPayload(BaseModel):
"""Payload for creating a new snippet."""
name: str = Field(..., min_length=1, max_length=255)
description: str | None = Field(default=None, max_length=2000)
type: Literal["node", "group"] = "node"
icon_info: IconInfo | None = None
graph: dict[str, Any] | None = None
input_fields: list[InputFieldDefinition] | None = Field(default_factory=list)
class UpdateSnippetPayload(BaseModel):
"""Payload for updating a snippet."""
name: str | None = Field(default=None, min_length=1, max_length=255)
description: str | None = Field(default=None, max_length=2000)
icon_info: IconInfo | None = None
class SnippetDraftSyncPayload(BaseModel):
"""Payload for syncing snippet draft workflow."""
graph: dict[str, Any]
hash: str | None = None
conversation_variables: list[dict[str, Any]] | None = Field(
default=None,
description="Ignored. Snippet workflows do not persist conversation variables.",
)
input_fields: list[dict[str, Any]] | None = None
class SnippetWorkflowListQuery(BaseModel):
"""Query parameters for listing snippet published workflows."""
page: int = Field(default=1, ge=1, le=99999)
limit: int = Field(default=10, ge=1, le=100)
class WorkflowRunQuery(BaseModel):
"""Query parameters for workflow runs."""
last_id: str | None = None
limit: int = Field(default=20, ge=1, le=100)
class SnippetDraftRunPayload(BaseModel):
"""Payload for running snippet draft workflow."""
inputs: dict[str, Any]
files: list[dict[str, Any]] | None = None
class SnippetDraftNodeRunPayload(BaseModel):
"""Payload for running a single node in snippet draft workflow."""
inputs: dict[str, Any]
query: str = ""
files: list[dict[str, Any]] | None = None
class SnippetIterationNodeRunPayload(BaseModel):
"""Payload for running an iteration node in snippet draft workflow."""
inputs: dict[str, Any] | None = None
class SnippetLoopNodeRunPayload(BaseModel):
"""Payload for running a loop node in snippet draft workflow."""
inputs: dict[str, Any] | None = None
class PublishWorkflowPayload(BaseModel):
"""Payload for publishing snippet workflow."""
knowledge_base_setting: dict[str, Any] | None = None
class SnippetImportPayload(BaseModel):
"""Payload for importing snippet from DSL."""
mode: str = Field(..., description="Import mode: yaml-content or yaml-url")
yaml_content: str | None = Field(default=None, description="YAML content (required for yaml-content mode)")
yaml_url: str | None = Field(default=None, description="YAML URL (required for yaml-url mode)")
name: str | None = Field(default=None, description="Override snippet name")
description: str | None = Field(default=None, description="Override snippet description")
snippet_id: str | None = Field(default=None, description="Snippet ID to update (optional)")
class IncludeSecretQuery(BaseModel):
"""Query parameter for including secret variables in export."""
include_secret: str = Field(default="false", description="Whether to include secret variables")

View File

@ -0,0 +1,638 @@
import logging
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar
from flask import request
from flask_restx import Resource
from pydantic import Field
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
from controllers.common.schema import register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.app.error import DraftWorkflowNotExist, DraftWorkflowNotSync
from controllers.console.app.workflow import (
RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE,
WorkflowPaginationResponse,
WorkflowResponse,
)
from controllers.console.snippets.payloads import (
PublishWorkflowPayload,
SnippetDraftNodeRunPayload,
SnippetDraftRunPayload,
SnippetDraftSyncPayload,
SnippetIterationNodeRunPayload,
SnippetLoopNodeRunPayload,
SnippetWorkflowListQuery,
WorkflowRunQuery,
)
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
)
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from fields.workflow_run_fields import (
WorkflowRunDetailResponse,
WorkflowRunNodeExecutionListResponse,
WorkflowRunNodeExecutionResponse,
WorkflowRunPaginationResponse,
)
from graphon.graph_engine.manager import GraphEngineManager
from libs import helper
from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
from models.snippet import CustomizedSnippet
from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError
from services.snippet_generate_service import SnippetGenerateService
from services.snippet_service import SnippetService
logger = logging.getLogger(__name__)
P = ParamSpec("P")
R = TypeVar("R")
# Register Pydantic models with Swagger
class SnippetWorkflowResponse(WorkflowResponse):
input_fields: list[dict] = Field(default_factory=list)
register_schema_models(
console_ns,
SnippetDraftSyncPayload,
SnippetDraftNodeRunPayload,
SnippetDraftRunPayload,
SnippetIterationNodeRunPayload,
SnippetLoopNodeRunPayload,
SnippetWorkflowListQuery,
WorkflowRunQuery,
PublishWorkflowPayload,
)
register_response_schema_models(
console_ns,
SnippetWorkflowResponse,
WorkflowPaginationResponse,
WorkflowRunPaginationResponse,
WorkflowRunDetailResponse,
WorkflowRunNodeExecutionListResponse,
WorkflowRunNodeExecutionResponse,
)
class SnippetNotFoundError(Exception):
"""Snippet not found error."""
pass
def get_snippet(view_func: Callable[P, R]):
"""Decorator to fetch and validate snippet access."""
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs):
if not kwargs.get("snippet_id"):
raise ValueError("missing snippet_id in path parameters")
_, current_tenant_id = current_account_with_tenant()
snippet_id = str(kwargs.get("snippet_id"))
del kwargs["snippet_id"]
snippet = SnippetService.get_snippet_by_id(
snippet_id=snippet_id,
tenant_id=current_tenant_id,
)
if not snippet:
raise NotFound("Snippet not found")
kwargs["snippet"] = snippet
return view_func(*args, **kwargs)
return decorated_view
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft")
class SnippetDraftWorkflowApi(Resource):
@console_ns.doc("get_snippet_draft_workflow")
@console_ns.response(200, "Draft workflow retrieved successfully", console_ns.models[SnippetWorkflowResponse.__name__])
@console_ns.response(404, "Snippet or draft workflow not found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
def get(self, snippet: CustomizedSnippet):
"""Get draft workflow for snippet."""
snippet_service = SnippetService()
workflow = snippet_service.get_draft_workflow(snippet=snippet)
if not workflow:
raise DraftWorkflowNotExist()
db.session.expunge(workflow)
workflow.conversation_variables = []
workflow.input_fields = snippet.input_fields_list
return SnippetWorkflowResponse.model_validate(workflow, from_attributes=True).model_dump(mode="json")
@console_ns.doc("sync_snippet_draft_workflow")
@console_ns.expect(console_ns.models.get(SnippetDraftSyncPayload.__name__))
@console_ns.response(200, "Draft workflow synced successfully")
@console_ns.response(400, "Hash mismatch")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
def post(self, snippet: CustomizedSnippet):
"""Sync draft workflow for snippet."""
current_user, _ = current_account_with_tenant()
payload = SnippetDraftSyncPayload.model_validate(console_ns.payload or {})
try:
snippet_service = SnippetService()
workflow = snippet_service.sync_draft_workflow(
snippet=snippet,
graph=payload.graph,
unique_hash=payload.hash,
account=current_user,
input_fields=payload.input_fields,
)
except WorkflowHashNotEqualError:
raise DraftWorkflowNotSync()
except ValueError as e:
return {"message": str(e)}, 400
return {
"result": "success",
"hash": workflow.unique_hash,
"updated_at": TimestampField().format(workflow.updated_at or workflow.created_at),
}
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/config")
class SnippetDraftConfigApi(Resource):
@console_ns.doc("get_snippet_draft_config")
@console_ns.response(200, "Draft config retrieved successfully")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
def get(self, snippet: CustomizedSnippet):
"""Get snippet draft workflow configuration limits."""
return {
"parallel_depth_limit": 3,
}
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/publish")
class SnippetPublishedWorkflowApi(Resource):
@console_ns.doc("get_snippet_published_workflow")
@console_ns.response(200, "Published workflow retrieved successfully", console_ns.models[SnippetWorkflowResponse.__name__])
@console_ns.response(404, "Snippet not found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
def get(self, snippet: CustomizedSnippet):
"""Get published workflow for snippet."""
if not snippet.is_published:
return None
snippet_service = SnippetService()
workflow = snippet_service.get_published_workflow(snippet=snippet)
if not workflow:
return None
workflow.input_fields = snippet.input_fields_list
return SnippetWorkflowResponse.model_validate(workflow, from_attributes=True).model_dump(mode="json")
@console_ns.doc("publish_snippet_workflow")
@console_ns.expect(console_ns.models.get(PublishWorkflowPayload.__name__))
@console_ns.response(200, "Workflow published successfully")
@console_ns.response(400, "No draft workflow found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
def post(self, snippet: CustomizedSnippet):
"""Publish snippet workflow."""
current_user, _ = current_account_with_tenant()
snippet_service = SnippetService()
with Session(db.engine) as session:
snippet = session.merge(snippet)
try:
workflow = snippet_service.publish_workflow(
session=session,
snippet=snippet,
account=current_user,
)
workflow_created_at = TimestampField().format(workflow.created_at)
session.commit()
except ValueError as e:
return {"message": str(e)}, 400
return {
"result": "success",
"created_at": workflow_created_at,
}
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/default-workflow-block-configs")
class SnippetDefaultBlockConfigsApi(Resource):
@console_ns.doc("get_snippet_default_block_configs")
@console_ns.response(200, "Default block configs retrieved successfully")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
def get(self, snippet: CustomizedSnippet):
"""Get default block configurations for snippet workflow."""
snippet_service = SnippetService()
return snippet_service.get_default_block_configs()
@console_ns.route("/snippets/<uuid:snippet_id>/workflows")
class SnippetPublishedAllWorkflowApi(Resource):
@console_ns.expect(console_ns.models[SnippetWorkflowListQuery.__name__])
@console_ns.doc("get_all_snippet_published_workflows")
@console_ns.doc(description="Get all published workflows for a snippet")
@console_ns.doc(params={"snippet_id": "Snippet ID"})
@console_ns.response(200, "Published workflows retrieved successfully", console_ns.models[WorkflowPaginationResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
def get(self, snippet: CustomizedSnippet):
"""Get all published workflow versions for snippet."""
args = SnippetWorkflowListQuery.model_validate(request.args.to_dict(flat=True))
snippet_service = SnippetService()
with Session(db.engine) as session:
workflows, has_more = snippet_service.get_all_published_workflows(
session=session,
snippet=snippet,
page=args.page,
limit=args.limit,
)
return WorkflowPaginationResponse.model_validate(
{
"items": workflows,
"page": args.page,
"limit": args.limit,
"has_more": has_more,
},
from_attributes=True,
).model_dump(mode="json")
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/<string:workflow_id>/restore")
class SnippetDraftWorkflowRestoreApi(Resource):
@console_ns.doc("restore_snippet_workflow_to_draft")
@console_ns.doc(description="Restore a published snippet workflow version into the draft workflow")
@console_ns.doc(params={"snippet_id": "Snippet ID", "workflow_id": "Published workflow ID"})
@console_ns.response(200, "Workflow restored successfully")
@console_ns.response(400, "Source workflow must be published")
@console_ns.response(404, "Workflow not found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
def post(self, snippet: CustomizedSnippet, workflow_id: str):
"""Restore a published snippet workflow version into the draft workflow."""
current_user, _ = current_account_with_tenant()
snippet_service = SnippetService()
try:
workflow = snippet_service.restore_published_workflow_to_draft(
snippet=snippet,
workflow_id=workflow_id,
account=current_user,
)
except IsDraftWorkflowError as exc:
raise BadRequest(RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE) from exc
except WorkflowNotFoundError as exc:
raise NotFound(str(exc)) from exc
except ValueError as exc:
raise BadRequest(str(exc)) from exc
return {
"result": "success",
"hash": workflow.unique_hash,
"updated_at": TimestampField().format(workflow.updated_at or workflow.created_at),
}
@console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs")
class SnippetWorkflowRunsApi(Resource):
@console_ns.doc("list_snippet_workflow_runs")
@console_ns.response(200, "Workflow runs retrieved successfully", console_ns.models[WorkflowRunPaginationResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@get_snippet
def get(self, snippet: CustomizedSnippet):
"""List workflow runs for snippet."""
query = WorkflowRunQuery.model_validate(
{
"last_id": request.args.get("last_id"),
"limit": request.args.get("limit", type=int, default=20),
}
)
args = {
"last_id": query.last_id,
"limit": query.limit,
}
snippet_service = SnippetService()
result = snippet_service.get_snippet_workflow_runs(snippet=snippet, args=args)
return WorkflowRunPaginationResponse.model_validate(result, from_attributes=True).model_dump(mode="json")
@console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs/<uuid:run_id>")
class SnippetWorkflowRunDetailApi(Resource):
@console_ns.doc("get_snippet_workflow_run_detail")
@console_ns.response(200, "Workflow run detail retrieved successfully", console_ns.models[WorkflowRunDetailResponse.__name__])
@console_ns.response(404, "Workflow run not found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
def get(self, snippet: CustomizedSnippet, run_id):
"""Get workflow run detail for snippet."""
run_id = str(run_id)
snippet_service = SnippetService()
workflow_run = snippet_service.get_snippet_workflow_run(snippet=snippet, run_id=run_id)
if not workflow_run:
raise NotFound("Workflow run not found")
return WorkflowRunDetailResponse.model_validate(workflow_run, from_attributes=True).model_dump(mode="json")
@console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs/<uuid:run_id>/node-executions")
class SnippetWorkflowRunNodeExecutionsApi(Resource):
@console_ns.doc("list_snippet_workflow_run_node_executions")
@console_ns.response(
200,
"Node executions retrieved successfully",
console_ns.models[WorkflowRunNodeExecutionListResponse.__name__],
)
@setup_required
@login_required
@account_initialization_required
@get_snippet
def get(self, snippet: CustomizedSnippet, run_id):
"""List node executions for a workflow run."""
run_id = str(run_id)
snippet_service = SnippetService()
node_executions = snippet_service.get_snippet_workflow_run_node_executions(
snippet=snippet,
run_id=run_id,
)
return WorkflowRunNodeExecutionListResponse.model_validate(
{"data": node_executions}, from_attributes=True
).model_dump(mode="json")
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/nodes/<string:node_id>/run")
class SnippetDraftNodeRunApi(Resource):
@console_ns.doc("run_snippet_draft_node")
@console_ns.doc(description="Run a single node in snippet draft workflow (single-step debugging)")
@console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
@console_ns.expect(console_ns.models.get(SnippetDraftNodeRunPayload.__name__))
@console_ns.response(
200, "Node run completed successfully", console_ns.models[WorkflowRunNodeExecutionResponse.__name__]
)
@console_ns.response(404, "Snippet or draft workflow not found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
def post(self, snippet: CustomizedSnippet, node_id: str):
"""
Run a single node in snippet draft workflow.
Executes a specific node with provided inputs for single-step debugging.
Returns the node execution result including status, outputs, and timing.
"""
current_user, _ = current_account_with_tenant()
payload = SnippetDraftNodeRunPayload.model_validate(console_ns.payload or {})
user_inputs = payload.inputs
# Get draft workflow for file parsing
snippet_service = SnippetService()
draft_workflow = snippet_service.get_draft_workflow(snippet=snippet)
if not draft_workflow:
raise NotFound("Draft workflow not found")
files = SnippetGenerateService.parse_files(draft_workflow, payload.files)
workflow_node_execution = SnippetGenerateService.run_draft_node(
snippet=snippet,
node_id=node_id,
user_inputs=user_inputs,
account=current_user,
query=payload.query,
files=files,
)
return WorkflowRunNodeExecutionResponse.model_validate(
workflow_node_execution, from_attributes=True
).model_dump(mode="json")
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/nodes/<string:node_id>/last-run")
class SnippetDraftNodeLastRunApi(Resource):
@console_ns.doc("get_snippet_draft_node_last_run")
@console_ns.doc(description="Get last run result for a node in snippet draft workflow")
@console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
@console_ns.response(
200, "Node last run retrieved successfully", console_ns.models[WorkflowRunNodeExecutionResponse.__name__]
)
@console_ns.response(404, "Snippet, draft workflow, or node last run not found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
def get(self, snippet: CustomizedSnippet, node_id: str):
"""
Get the last run result for a specific node in snippet draft workflow.
Returns the most recent execution record for the given node,
including status, inputs, outputs, and timing information.
"""
snippet_service = SnippetService()
draft_workflow = snippet_service.get_draft_workflow(snippet=snippet)
if not draft_workflow:
raise NotFound("Draft workflow not found")
node_exec = snippet_service.get_snippet_node_last_run(
snippet=snippet,
workflow=draft_workflow,
node_id=node_id,
)
if node_exec is None:
raise NotFound("Node last run not found")
return WorkflowRunNodeExecutionResponse.model_validate(node_exec, from_attributes=True).model_dump(mode="json")
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/iteration/nodes/<string:node_id>/run")
class SnippetDraftRunIterationNodeApi(Resource):
@console_ns.doc("run_snippet_draft_iteration_node")
@console_ns.doc(description="Run draft workflow iteration node for snippet")
@console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
@console_ns.expect(console_ns.models.get(SnippetIterationNodeRunPayload.__name__))
@console_ns.response(200, "Iteration node run started successfully (SSE stream)")
@console_ns.response(404, "Snippet or draft workflow not found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
def post(self, snippet: CustomizedSnippet, node_id: str):
"""
Run a draft workflow iteration node for snippet.
Iteration nodes execute their internal sub-graph multiple times over an input list.
Returns an SSE event stream with iteration progress and results.
"""
current_user, _ = current_account_with_tenant()
args = SnippetIterationNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
try:
response = SnippetGenerateService.generate_single_iteration(
snippet=snippet, user=current_user, node_id=node_id, args=args, streaming=True
)
return helper.compact_generate_response(response)
except ValueError as e:
raise e
except Exception:
logger.exception("internal server error.")
raise InternalServerError()
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/loop/nodes/<string:node_id>/run")
class SnippetDraftRunLoopNodeApi(Resource):
@console_ns.doc("run_snippet_draft_loop_node")
@console_ns.doc(description="Run draft workflow loop node for snippet")
@console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
@console_ns.expect(console_ns.models.get(SnippetLoopNodeRunPayload.__name__))
@console_ns.response(200, "Loop node run started successfully (SSE stream)")
@console_ns.response(404, "Snippet or draft workflow not found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
def post(self, snippet: CustomizedSnippet, node_id: str):
"""
Run a draft workflow loop node for snippet.
Loop nodes execute their internal sub-graph repeatedly until a condition is met.
Returns an SSE event stream with loop progress and results.
"""
current_user, _ = current_account_with_tenant()
args = SnippetLoopNodeRunPayload.model_validate(console_ns.payload or {})
try:
response = SnippetGenerateService.generate_single_loop(
snippet=snippet, user=current_user, node_id=node_id, args=args, streaming=True
)
return helper.compact_generate_response(response)
except ValueError as e:
raise e
except Exception:
logger.exception("internal server error.")
raise InternalServerError()
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/run")
class SnippetDraftWorkflowRunApi(Resource):
@console_ns.doc("run_snippet_draft_workflow")
@console_ns.expect(console_ns.models.get(SnippetDraftRunPayload.__name__))
@console_ns.response(200, "Draft workflow run started successfully (SSE stream)")
@console_ns.response(404, "Snippet or draft workflow not found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
def post(self, snippet: CustomizedSnippet):
"""
Run draft workflow for snippet.
Executes the snippet's draft workflow with the provided inputs
and returns an SSE event stream with execution progress and results.
"""
current_user, _ = current_account_with_tenant()
payload = SnippetDraftRunPayload.model_validate(console_ns.payload or {})
args = payload.model_dump(exclude_none=True)
try:
response = SnippetGenerateService.generate(
snippet=snippet,
user=current_user,
args=args,
invoke_from=InvokeFrom.DEBUGGER,
streaming=True,
)
return helper.compact_generate_response(response)
except ValueError as e:
raise e
except Exception:
logger.exception("internal server error.")
raise InternalServerError()
@console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs/tasks/<string:task_id>/stop")
class SnippetWorkflowTaskStopApi(Resource):
@console_ns.doc("stop_snippet_workflow_task")
@console_ns.response(200, "Task stopped successfully")
@console_ns.response(404, "Snippet not found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
def post(self, snippet: CustomizedSnippet, task_id: str):
"""
Stop a running snippet workflow task.
Uses both the legacy stop flag mechanism and the graph engine
command channel for backward compatibility.
"""
# Stop using both mechanisms for backward compatibility
# Legacy stop flag mechanism (without user check)
AppQueueManager.set_stop_flag_no_user_check(task_id)
# New graph engine command channel mechanism
GraphEngineManager(redis_client).send_stop_command(task_id)
return {"result": "success"}

View File

@ -0,0 +1,319 @@
"""
Snippet draft workflow variable APIs.
Mirrors console app routes under /apps/.../workflows/draft/variables for snippet scope,
using CustomizedSnippet.id as WorkflowDraftVariable.app_id (same invariant as snippet execution).
Snippet workflows do not expose system variables (`node_id == sys`) or conversation variables
(`node_id == conversation`): paginated list queries exclude those rows; single-variable GET/PATCH/DELETE/reset
reject them; `GET .../system-variables` and `GET .../conversation-variables` return empty lists for API parity.
Other routes mirror `workflow_draft_variable` app APIs under `/snippets/...`.
"""
from collections.abc import Callable
from functools import wraps
from typing import Any, ParamSpec, TypeVar
from flask import Response, request
from flask_restx import Resource, marshal, marshal_with
from sqlalchemy.orm import Session
from controllers.console import console_ns
from controllers.console.app.error import DraftWorkflowNotExist
from controllers.console.app.workflow_draft_variable import (
WorkflowDraftVariableListQuery,
WorkflowDraftVariableUpdatePayload,
_ensure_variable_access,
_file_access_controller,
validate_node_id,
workflow_draft_variable_list_model,
workflow_draft_variable_list_without_value_model,
workflow_draft_variable_model,
)
from controllers.console.snippets.snippet_workflow import get_snippet
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.web.error import InvalidArgumentError, NotFoundError
from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from extensions.ext_database import db
from factories.file_factory import build_from_mapping, build_from_mappings
from factories.variable_factory import build_segment_with_type
from graphon.variables.types import SegmentType
from libs.login import current_user, login_required
from models.snippet import CustomizedSnippet
from models.workflow import WorkflowDraftVariable
from services.snippet_service import SnippetService
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
P = ParamSpec("P")
R = TypeVar("R")
_SNIPPET_EXCLUDED_DRAFT_VARIABLE_NODE_IDS: frozenset[str] = frozenset(
{SYSTEM_VARIABLE_NODE_ID, CONVERSATION_VARIABLE_NODE_ID}
)
def _ensure_snippet_draft_variable_row_allowed(
*,
variable: WorkflowDraftVariable,
variable_id: str,
) -> None:
"""Snippet scope only supports canvas-node draft variables; treat sys/conversation rows as not found."""
if variable.node_id in _SNIPPET_EXCLUDED_DRAFT_VARIABLE_NODE_IDS:
raise NotFoundError(description=f"variable not found, id={variable_id}")
def _snippet_draft_var_prerequisite(f: Callable[P, R]) -> Callable[P, R]:
"""Setup, auth, snippet resolution, and tenant edit permission (same stack as snippet workflow APIs)."""
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
@wraps(f)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
return f(*args, **kwargs)
return wrapper
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/variables")
class SnippetWorkflowVariableCollectionApi(Resource):
@console_ns.expect(console_ns.models[WorkflowDraftVariableListQuery.__name__])
@console_ns.doc("get_snippet_workflow_variables")
@console_ns.doc(description="List draft workflow variables without values (paginated, snippet scope)")
@console_ns.response(
200,
"Workflow variables retrieved successfully",
workflow_draft_variable_list_without_value_model,
)
@_snippet_draft_var_prerequisite
@marshal_with(workflow_draft_variable_list_without_value_model)
def get(self, snippet: CustomizedSnippet) -> WorkflowDraftVariableList:
args = WorkflowDraftVariableListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
snippet_service = SnippetService()
if snippet_service.get_draft_workflow(snippet=snippet) is None:
raise DraftWorkflowNotExist()
with Session(bind=db.engine, expire_on_commit=False) as session:
draft_var_srv = WorkflowDraftVariableService(session=session)
workflow_vars = draft_var_srv.list_variables_without_values(
app_id=snippet.id,
page=args.page,
limit=args.limit,
user_id=current_user.id,
exclude_node_ids=_SNIPPET_EXCLUDED_DRAFT_VARIABLE_NODE_IDS,
)
return workflow_vars
@console_ns.doc("delete_snippet_workflow_variables")
@console_ns.doc(description="Delete all draft workflow variables for the current user (snippet scope)")
@console_ns.response(204, "Workflow variables deleted successfully")
@_snippet_draft_var_prerequisite
def delete(self, snippet: CustomizedSnippet) -> Response:
draft_var_srv = WorkflowDraftVariableService(session=db.session())
draft_var_srv.delete_user_workflow_variables(snippet.id, user_id=current_user.id)
db.session.commit()
return Response("", 204)
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/nodes/<string:node_id>/variables")
class SnippetNodeVariableCollectionApi(Resource):
@console_ns.doc("get_snippet_node_variables")
@console_ns.doc(description="Get variables for a specific node (snippet draft workflow)")
@console_ns.response(200, "Node variables retrieved successfully", workflow_draft_variable_list_model)
@_snippet_draft_var_prerequisite
@marshal_with(workflow_draft_variable_list_model)
def get(self, snippet: CustomizedSnippet, node_id: str) -> WorkflowDraftVariableList:
validate_node_id(node_id)
with Session(bind=db.engine, expire_on_commit=False) as session:
draft_var_srv = WorkflowDraftVariableService(session=session)
node_vars = draft_var_srv.list_node_variables(snippet.id, node_id, user_id=current_user.id)
return node_vars
@console_ns.doc("delete_snippet_node_variables")
@console_ns.doc(description="Delete all variables for a specific node (snippet draft workflow)")
@console_ns.response(204, "Node variables deleted successfully")
@_snippet_draft_var_prerequisite
def delete(self, snippet: CustomizedSnippet, node_id: str) -> Response:
validate_node_id(node_id)
srv = WorkflowDraftVariableService(db.session())
srv.delete_node_variables(snippet.id, node_id, user_id=current_user.id)
db.session.commit()
return Response("", 204)
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/variables/<uuid:variable_id>")
class SnippetVariableApi(Resource):
@console_ns.doc("get_snippet_workflow_variable")
@console_ns.doc(description="Get a specific draft workflow variable (snippet scope)")
@console_ns.response(200, "Variable retrieved successfully", workflow_draft_variable_model)
@console_ns.response(404, "Variable not found")
@_snippet_draft_var_prerequisite
@marshal_with(workflow_draft_variable_model)
def get(self, snippet: CustomizedSnippet, variable_id: str) -> WorkflowDraftVariable:
draft_var_srv = WorkflowDraftVariableService(session=db.session())
variable = _ensure_variable_access(
variable=draft_var_srv.get_variable(variable_id=variable_id),
app_id=snippet.id,
variable_id=variable_id,
)
_ensure_snippet_draft_variable_row_allowed(variable=variable, variable_id=variable_id)
return variable
@console_ns.doc("update_snippet_workflow_variable")
@console_ns.doc(description="Update a draft workflow variable (snippet scope)")
@console_ns.expect(console_ns.models[WorkflowDraftVariableUpdatePayload.__name__])
@console_ns.response(200, "Variable updated successfully", workflow_draft_variable_model)
@console_ns.response(404, "Variable not found")
@_snippet_draft_var_prerequisite
@marshal_with(workflow_draft_variable_model)
def patch(self, snippet: CustomizedSnippet, variable_id: str) -> WorkflowDraftVariable:
draft_var_srv = WorkflowDraftVariableService(session=db.session())
args_model = WorkflowDraftVariableUpdatePayload.model_validate(console_ns.payload or {})
variable = _ensure_variable_access(
variable=draft_var_srv.get_variable(variable_id=variable_id),
app_id=snippet.id,
variable_id=variable_id,
)
_ensure_snippet_draft_variable_row_allowed(variable=variable, variable_id=variable_id)
new_name = args_model.name
raw_value = args_model.value
if new_name is None and raw_value is None:
return variable
new_value = None
if raw_value is not None:
if variable.value_type == SegmentType.FILE:
if not isinstance(raw_value, dict):
raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
raw_value = build_from_mapping(
mapping=raw_value,
tenant_id=snippet.tenant_id,
access_controller=_file_access_controller,
)
elif variable.value_type == SegmentType.ARRAY_FILE:
if not isinstance(raw_value, list):
raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
raw_value = build_from_mappings(
mappings=raw_value,
tenant_id=snippet.tenant_id,
access_controller=_file_access_controller,
)
new_value = build_segment_with_type(variable.value_type, raw_value)
draft_var_srv.update_variable(variable, name=new_name, value=new_value)
db.session.commit()
return variable
@console_ns.doc("delete_snippet_workflow_variable")
@console_ns.doc(description="Delete a draft workflow variable (snippet scope)")
@console_ns.response(204, "Variable deleted successfully")
@console_ns.response(404, "Variable not found")
@_snippet_draft_var_prerequisite
def delete(self, snippet: CustomizedSnippet, variable_id: str) -> Response:
draft_var_srv = WorkflowDraftVariableService(session=db.session())
variable = _ensure_variable_access(
variable=draft_var_srv.get_variable(variable_id=variable_id),
app_id=snippet.id,
variable_id=variable_id,
)
_ensure_snippet_draft_variable_row_allowed(variable=variable, variable_id=variable_id)
draft_var_srv.delete_variable(variable)
db.session.commit()
return Response("", 204)
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/variables/<uuid:variable_id>/reset")
class SnippetVariableResetApi(Resource):
@console_ns.doc("reset_snippet_workflow_variable")
@console_ns.doc(description="Reset a draft workflow variable to its default value (snippet scope)")
@console_ns.response(200, "Variable reset successfully", workflow_draft_variable_model)
@console_ns.response(204, "Variable reset (no content)")
@console_ns.response(404, "Variable not found")
@_snippet_draft_var_prerequisite
def put(self, snippet: CustomizedSnippet, variable_id: str) -> Response | Any:
draft_var_srv = WorkflowDraftVariableService(session=db.session())
snippet_service = SnippetService()
draft_workflow = snippet_service.get_draft_workflow(snippet=snippet)
if draft_workflow is None:
raise NotFoundError(
f"Draft workflow not found, snippet_id={snippet.id}",
)
variable = _ensure_variable_access(
variable=draft_var_srv.get_variable(variable_id=variable_id),
app_id=snippet.id,
variable_id=variable_id,
)
_ensure_snippet_draft_variable_row_allowed(variable=variable, variable_id=variable_id)
resetted = draft_var_srv.reset_variable(draft_workflow, variable)
db.session.commit()
if resetted is None:
return Response("", 204)
return marshal(resetted, workflow_draft_variable_model)
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/conversation-variables")
class SnippetConversationVariableCollectionApi(Resource):
@console_ns.doc("get_snippet_conversation_variables")
@console_ns.doc(
description="Conversation variables are not used in snippet workflows; returns an empty list for API parity"
)
@console_ns.response(200, "Conversation variables retrieved successfully", workflow_draft_variable_list_model)
@_snippet_draft_var_prerequisite
@marshal_with(workflow_draft_variable_list_model)
def get(self, snippet: CustomizedSnippet) -> WorkflowDraftVariableList:
return WorkflowDraftVariableList(variables=[])
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/system-variables")
class SnippetSystemVariableCollectionApi(Resource):
@console_ns.doc("get_snippet_system_variables")
@console_ns.doc(
description="System variables are not used in snippet workflows; returns an empty list for API parity"
)
@console_ns.response(200, "System variables retrieved successfully", workflow_draft_variable_list_model)
@_snippet_draft_var_prerequisite
@marshal_with(workflow_draft_variable_list_model)
def get(self, snippet: CustomizedSnippet) -> WorkflowDraftVariableList:
return WorkflowDraftVariableList(variables=[])
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/environment-variables")
class SnippetEnvironmentVariableCollectionApi(Resource):
@console_ns.doc("get_snippet_environment_variables")
@console_ns.doc(description="Get environment variables from snippet draft workflow graph")
@console_ns.response(200, "Environment variables retrieved successfully")
@console_ns.response(404, "Draft workflow not found")
@_snippet_draft_var_prerequisite
def get(self, snippet: CustomizedSnippet) -> dict[str, list[dict[str, Any]]]:
snippet_service = SnippetService()
workflow = snippet_service.get_draft_workflow(snippet=snippet)
if workflow is None:
raise DraftWorkflowNotExist()
env_vars_list: list[dict[str, Any]] = []
for v in workflow.environment_variables:
env_vars_list.append(
{
"id": v.id,
"type": "env",
"name": v.name,
"description": v.description,
"selector": v.selector,
"value_type": v.value_type.exposed_type().value,
"value": v.value,
"edited": False,
"visible": True,
"editable": True,
}
)
return {"items": env_vars_list}

View File

@ -51,7 +51,7 @@ class TagBindingRemovePayload(BaseModel):
class TagListQueryParam(BaseModel):
type: Literal["knowledge", "app", ""] = Field("", description="Tag type filter")
type: Literal["knowledge", "app", "snippet", ""] = Field("", description="Tag type filter")
keyword: str | None = Field(None, description="Search keyword")
@ -96,7 +96,10 @@ class TagListApi(Resource):
@login_required
@account_initialization_required
@console_ns.doc(
params={"type": 'Tag type filter. Can be "knowledge" or "app".', "keyword": "Search keyword for tag name."}
params={
"type": 'Tag type filter. Can be "knowledge", "app", or "snippet".',
"keyword": "Search keyword for tag name.",
}
)
@console_ns.doc(responses={200: ("Success", [console_ns.models[TagResponse.__name__]])})
@with_current_tenant_id

View File

@ -18,7 +18,7 @@ from controllers.common.fields import (
SimpleResultResponse,
VerificationTokenResponse,
)
from controllers.common.schema import query_params_from_model, register_response_schema_models, register_schema_models
from controllers.common.schema import register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.auth.error import (
EmailAlreadyInUseError,
@ -42,17 +42,15 @@ from controllers.console.wraps import (
enterprise_license_required,
only_edition_cloud,
setup_required,
with_current_tenant_id,
with_current_user,
)
from extensions.ext_database import db
from fields.base import ResponseModel
from fields.member_fields import Account as AccountResponse
from graphon.file import helpers as file_helpers
from libs.datetime_utils import naive_utc_now
from libs.helper import EmailStr, dump_response, extract_remote_ip, timezone, to_timestamp
from libs.login import login_required
from models import Account, AccountIntegrate, InvitationCode
from libs.helper import EmailStr, extract_remote_ip, timezone, to_timestamp
from libs.login import current_account_with_tenant, login_required
from models import AccountIntegrate, InvitationCode
from models.account import AccountStatus, InvitationCodeStatus
from models.enums import CreatorUserRole
from models.model import UploadFile
@ -175,6 +173,7 @@ class CheckEmailUniquePayload(BaseModel):
register_schema_models(
console_ns,
AccountResponse,
AccountInitPayload,
AccountNamePayload,
AccountAvatarPayload,
@ -246,7 +245,6 @@ register_schema_models(
)
register_response_schema_models(
console_ns,
AccountResponse,
AvatarUrlResponse,
SimpleResultDataResponse,
SimpleResultResponse,
@ -260,8 +258,9 @@ class AccountInitApi(Resource):
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
@setup_required
@login_required
@with_current_user
def post(self, account: Account):
def post(self):
account, _ = current_account_with_tenant()
if account.status == "active":
raise AccountAlreadyInitedError()
@ -307,8 +306,8 @@ class AccountProfileApi(Resource):
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
@enterprise_license_required
@with_current_user
def get(self, current_user: Account):
def get(self):
current_user, _ = current_account_with_tenant()
return _serialize_account(current_user)
@ -319,8 +318,8 @@ class AccountNameApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
@with_current_user
def post(self, current_user: Account):
def post(self):
current_user, _ = current_account_with_tenant()
payload = console_ns.payload or {}
args = AccountNamePayload.model_validate(payload)
updated_account = AccountService.update_account(current_user, name=args.name)
@ -330,21 +329,20 @@ class AccountNameApi(Resource):
@console_ns.route("/account/avatar")
class AccountAvatarApi(Resource):
@console_ns.expect(console_ns.models[AccountAvatarQuery.__name__])
@console_ns.doc("get_account_avatar")
@console_ns.doc(description="Get account avatar url")
@console_ns.doc(params=query_params_from_model(AccountAvatarQuery))
@console_ns.response(200, "Success", console_ns.models[AvatarUrlResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account):
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
args = AccountAvatarQuery.model_validate(request.args.to_dict(flat=True))
avatar = args.avatar
if avatar.startswith(("http://", "https://")):
return dump_response(AvatarUrlResponse, {"avatar_url": avatar})
return {"avatar_url": avatar}
upload_file = db.session.scalar(select(UploadFile).where(UploadFile.id == avatar).limit(1))
if upload_file is None:
@ -357,15 +355,15 @@ class AccountAvatarApi(Resource):
raise NotFound("Avatar file not found")
avatar_url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id)
return dump_response(AvatarUrlResponse, {"avatar_url": avatar_url})
return {"avatar_url": avatar_url}
@console_ns.expect(console_ns.models[AccountAvatarPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
@with_current_user
def post(self, current_user: Account):
def post(self):
current_user, _ = current_account_with_tenant()
payload = console_ns.payload or {}
args = AccountAvatarPayload.model_validate(payload)
@ -381,8 +379,8 @@ class AccountInterfaceLanguageApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
@with_current_user
def post(self, current_user: Account):
def post(self):
current_user, _ = current_account_with_tenant()
payload = console_ns.payload or {}
args = AccountInterfaceLanguagePayload.model_validate(payload)
@ -398,8 +396,8 @@ class AccountInterfaceThemeApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
@with_current_user
def post(self, current_user: Account):
def post(self):
current_user, _ = current_account_with_tenant()
payload = console_ns.payload or {}
args = AccountInterfaceThemePayload.model_validate(payload)
@ -415,8 +413,8 @@ class AccountTimezoneApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
@with_current_user
def post(self, current_user: Account):
def post(self):
current_user, _ = current_account_with_tenant()
payload = console_ns.payload or {}
args = AccountTimezonePayload.model_validate(payload)
@ -432,8 +430,8 @@ class AccountPasswordApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
@with_current_user
def post(self, current_user: Account):
def post(self):
current_user, _ = current_account_with_tenant()
payload = console_ns.payload or {}
args = AccountPasswordPayload.model_validate(payload)
@ -451,8 +449,9 @@ class AccountIntegrateApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[AccountIntegrateListResponse.__name__])
@with_current_user
def get(self, account: Account):
def get(self):
account, _ = current_account_with_tenant()
account_integrates = db.session.scalars(
select(AccountIntegrate).where(AccountIntegrate.account_id == account.id)
).all()
@ -496,8 +495,9 @@ class AccountDeleteVerifyApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[SimpleResultDataResponse.__name__])
@with_current_user
def get(self, account: Account):
def get(self):
account, _ = current_account_with_tenant()
token, code = AccountService.generate_account_deletion_verification_code(account)
AccountService.send_account_deletion_verification_email(account, code)
@ -511,8 +511,9 @@ class AccountDeleteApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_user
def post(self, account: Account):
def post(self):
account, _ = current_account_with_tenant()
payload = console_ns.payload or {}
args = AccountDeletePayload.model_validate(payload)
@ -546,8 +547,9 @@ class EducationVerifyApi(Resource):
@only_edition_cloud
@cloud_edition_billing_enabled
@console_ns.response(200, "Success", console_ns.models[EducationVerifyResponse.__name__])
@with_current_user
def get(self, account: Account):
def get(self):
account, _ = current_account_with_tenant()
return EducationVerifyResponse.model_validate(
BillingService.EducationIdentity.verify(account.id, account.email) or {}
).model_dump(mode="json")
@ -561,8 +563,9 @@ class EducationApi(Resource):
@account_initialization_required
@only_edition_cloud
@cloud_edition_billing_enabled
@with_current_user
def post(self, account: Account):
def post(self):
account, _ = current_account_with_tenant()
payload = console_ns.payload or {}
args = EducationActivatePayload.model_validate(payload)
@ -574,8 +577,9 @@ class EducationApi(Resource):
@only_edition_cloud
@cloud_edition_billing_enabled
@console_ns.response(200, "Success", console_ns.models[EducationStatusResponse.__name__])
@with_current_user
def get(self, account: Account):
def get(self):
account, _ = current_account_with_tenant()
res = BillingService.EducationIdentity.status(account.id) or {}
# convert expire_at to UTC timestamp from isoformat
if res and "expire_at" in res:
@ -609,8 +613,8 @@ class ChangeEmailSendEmailApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_user
def post(self, current_user: Account):
def post(self):
current_user, _ = current_account_with_tenant()
payload = console_ns.payload or {}
args = ChangeEmailSendPayload.model_validate(payload)
@ -669,8 +673,8 @@ class ChangeEmailCheckApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_user
def post(self, current_user: Account):
def post(self):
current_user, _ = current_account_with_tenant()
payload = console_ns.payload or {}
args = ChangeEmailValidityPayload.model_validate(payload)
@ -716,8 +720,7 @@ class ChangeEmailResetApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
@with_current_user
def post(self, current_user: Account):
def post(self):
payload = console_ns.payload or {}
args = ChangeEmailResetPayload.model_validate(payload)
normalized_new_email = args.new_email.lower()
@ -728,6 +731,7 @@ class ChangeEmailResetApi(Resource):
if not AccountService.check_email_unique(normalized_new_email):
raise EmailAlreadyInUseError()
current_user, _ = current_account_with_tenant()
reset_data = AccountService.get_change_email_data(args.token)
if not reset_data:
raise InvalidTokenError()

View File

@ -1,15 +1,9 @@
from flask_restx import Resource, fields
from controllers.console import console_ns
from controllers.console.wraps import (
account_initialization_required,
setup_required,
with_current_tenant_id,
with_current_user,
)
from controllers.console.wraps import account_initialization_required, setup_required
from graphon.model_runtime.utils.encoders import jsonable_encoder
from libs.login import login_required
from models import Account
from libs.login import current_account_with_tenant, login_required
from services.agent_service import AgentService
@ -25,10 +19,14 @@ class AgentProviderListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account):
return jsonable_encoder(AgentService.list_agent_providers(current_user.id, current_tenant_id))
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
user = current_user
user_id = user.id
tenant_id = current_tenant_id
return jsonable_encoder(AgentService.list_agent_providers(user_id, tenant_id))
@console_ns.route("/workspaces/current/agent-provider/<path:provider_name>")
@ -44,7 +42,6 @@ class AgentProviderApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account, provider_name: str):
def get(self, provider_name: str):
current_user, current_tenant_id = current_account_with_tenant()
return jsonable_encoder(AgentService.get_agent_provider(current_user.id, current_tenant_id, provider_name))

View File

@ -14,16 +14,10 @@ from pydantic import BaseModel, Field
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import (
account_initialization_required,
is_admin_or_owner_required,
setup_required,
with_current_tenant_id,
with_current_user_id,
)
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from core.plugin.impl.exc import PluginPermissionDeniedError
from graphon.model_runtime.utils.encoders import jsonable_encoder
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from services.plugin.endpoint_service import EndpointService
@ -102,15 +96,17 @@ register_schema_models(
)
def _create_endpoint(tenant_id: str, user_id: str) -> dict[str, bool]:
"""Create a plugin endpoint for the injected workspace and user."""
def _create_endpoint() -> dict[str, bool]:
"""Create a plugin endpoint for the current workspace."""
user, tenant_id = current_account_with_tenant()
args = EndpointCreatePayload.model_validate(console_ns.payload)
try:
return {
"success": EndpointService.create_endpoint(
tenant_id=tenant_id,
user_id=user_id,
user_id=user.id,
plugin_unique_identifier=args.plugin_unique_identifier,
name=args.name,
settings=args.settings,
@ -120,14 +116,16 @@ def _create_endpoint(tenant_id: str, user_id: str) -> dict[str, bool]:
raise ValueError(e.description) from e
def _update_endpoint(tenant_id: str, user_id: str, endpoint_id: str) -> dict[str, bool]:
def _update_endpoint(endpoint_id: str) -> dict[str, bool]:
"""Update a plugin endpoint identified by the canonical path parameter."""
user, tenant_id = current_account_with_tenant()
args = EndpointUpdatePayload.model_validate(console_ns.payload)
return {
"success": EndpointService.update_endpoint(
tenant_id=tenant_id,
user_id=user_id,
user_id=user.id,
endpoint_id=endpoint_id,
name=args.name,
settings=args.settings,
@ -135,12 +133,14 @@ def _update_endpoint(tenant_id: str, user_id: str, endpoint_id: str) -> dict[str
}
def _delete_endpoint(tenant_id: str, user_id: str, endpoint_id: str) -> dict[str, bool]:
def _delete_endpoint(endpoint_id: str) -> dict[str, bool]:
"""Delete a plugin endpoint identified by the canonical path parameter."""
user, tenant_id = current_account_with_tenant()
return {
"success": EndpointService.delete_endpoint(
tenant_id=tenant_id,
user_id=user_id,
user_id=user.id,
endpoint_id=endpoint_id,
)
}
@ -163,10 +163,8 @@ class EndpointCollectionApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
@with_current_user_id
@with_current_tenant_id
def post(self, tenant_id: str, user_id: str):
return _create_endpoint(tenant_id=tenant_id, user_id=user_id)
def post(self):
return _create_endpoint()
@console_ns.route("/workspaces/current/endpoints/create")
@ -191,10 +189,8 @@ class DeprecatedEndpointCreateApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
@with_current_user_id
@with_current_tenant_id
def post(self, tenant_id: str, user_id: str):
return _create_endpoint(tenant_id=tenant_id, user_id=user_id)
def post(self):
return _create_endpoint()
@console_ns.route("/workspaces/current/endpoints/list")
@ -210,9 +206,9 @@ class EndpointListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_user_id
@with_current_tenant_id
def get(self, tenant_id: str, user_id: str):
def get(self):
user, tenant_id = current_account_with_tenant()
args = EndpointListQuery.model_validate(request.args.to_dict(flat=True))
page = args.page
@ -222,7 +218,7 @@ class EndpointListApi(Resource):
{
"endpoints": EndpointService.list_endpoints(
tenant_id=tenant_id,
user_id=user_id,
user_id=user.id,
page=page,
page_size=page_size,
)
@ -243,9 +239,9 @@ class EndpointListForSinglePluginApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_user_id
@with_current_tenant_id
def get(self, tenant_id: str, user_id: str):
def get(self):
user, tenant_id = current_account_with_tenant()
args = EndpointListForPluginQuery.model_validate(request.args.to_dict(flat=True))
page = args.page
@ -256,7 +252,7 @@ class EndpointListForSinglePluginApi(Resource):
{
"endpoints": EndpointService.list_endpoints_for_single_plugin(
tenant_id=tenant_id,
user_id=user_id,
user_id=user.id,
plugin_id=plugin_id,
page=page,
page_size=page_size,
@ -282,10 +278,8 @@ class EndpointItemApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
@with_current_user_id
@with_current_tenant_id
def delete(self, tenant_id: str, user_id: str, id: str):
return _delete_endpoint(tenant_id=tenant_id, user_id=user_id, endpoint_id=id)
def delete(self, id: str):
return _delete_endpoint(endpoint_id=id)
@console_ns.doc("update_endpoint")
@console_ns.doc(description="Update a plugin endpoint")
@ -301,10 +295,8 @@ class EndpointItemApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
@with_current_user_id
@with_current_tenant_id
def patch(self, tenant_id: str, user_id: str, id: str):
return _update_endpoint(tenant_id=tenant_id, user_id=user_id, endpoint_id=id)
def patch(self, id: str):
return _update_endpoint(endpoint_id=id)
@console_ns.route("/workspaces/current/endpoints/delete")
@ -330,11 +322,9 @@ class DeprecatedEndpointDeleteApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
@with_current_user_id
@with_current_tenant_id
def post(self, tenant_id: str, user_id: str):
def post(self):
args = EndpointIdPayload.model_validate(console_ns.payload)
return _delete_endpoint(tenant_id=tenant_id, user_id=user_id, endpoint_id=args.endpoint_id)
return _delete_endpoint(endpoint_id=args.endpoint_id)
@console_ns.route("/workspaces/current/endpoints/update")
@ -360,11 +350,9 @@ class DeprecatedEndpointUpdateApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
@with_current_user_id
@with_current_tenant_id
def post(self, tenant_id: str, user_id: str):
def post(self):
args = LegacyEndpointUpdatePayload.model_validate(console_ns.payload)
return _update_endpoint(tenant_id=tenant_id, user_id=user_id, endpoint_id=args.endpoint_id)
return _update_endpoint(endpoint_id=args.endpoint_id)
@console_ns.route("/workspaces/current/endpoints/enable")
@ -382,14 +370,14 @@ class EndpointEnableApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
@with_current_user_id
@with_current_tenant_id
def post(self, tenant_id: str, user_id: str):
def post(self):
user, tenant_id = current_account_with_tenant()
args = EndpointIdPayload.model_validate(console_ns.payload)
return {
"success": EndpointService.enable_endpoint(
tenant_id=tenant_id, user_id=user_id, endpoint_id=args.endpoint_id
tenant_id=tenant_id, user_id=user.id, endpoint_id=args.endpoint_id
)
}
@ -409,13 +397,13 @@ class EndpointDisableApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
@with_current_user_id
@with_current_tenant_id
def post(self, tenant_id: str, user_id: str):
def post(self):
user, tenant_id = current_account_with_tenant()
args = EndpointIdPayload.model_validate(console_ns.payload)
return {
"success": EndpointService.disable_endpoint(
tenant_id=tenant_id, user_id=user_id, endpoint_id=args.endpoint_id
tenant_id=tenant_id, user_id=user.id, endpoint_id=args.endpoint_id
)
}

View File

@ -4,16 +4,11 @@ from werkzeug.exceptions import Forbidden
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import (
account_initialization_required,
setup_required,
with_current_tenant_id,
with_current_user,
)
from controllers.console.wraps import account_initialization_required, setup_required
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
from libs.login import login_required
from models import Account, TenantAccountRole
from libs.login import current_account_with_tenant, login_required
from models import TenantAccountRole
from services.model_load_balancing_service import ModelLoadBalancingService
@ -34,9 +29,8 @@ class LoadBalancingCredentialsValidateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_user
@with_current_tenant_id
def post(self, current_tenant_id: str, current_user: Account, provider: str):
def post(self, provider: str):
current_user, current_tenant_id = current_account_with_tenant()
if not TenantAccountRole.is_privileged_role(current_user.current_role):
raise Forbidden()
@ -78,9 +72,8 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_user
@with_current_tenant_id
def post(self, current_tenant_id: str, current_user: Account, provider: str, config_id: str):
def post(self, provider: str, config_id: str):
current_user, current_tenant_id = current_account_with_tenant()
if not TenantAccountRole.is_privileged_role(current_user.current_role):
raise Forbidden()

View File

@ -8,19 +8,12 @@ from pydantic import BaseModel, Field, field_validator
from controllers.common.fields import SimpleResultResponse
from controllers.common.schema import register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import (
account_initialization_required,
is_admin_or_owner_required,
setup_required,
with_current_tenant_id,
with_current_user,
)
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
from graphon.model_runtime.utils.encoders import jsonable_encoder
from libs.helper import uuid_value
from libs.login import login_required
from models import Account
from libs.login import current_account_with_tenant, login_required
from services.billing_service import BillingService
from services.model_provider_service import ModelProviderService
@ -102,8 +95,10 @@ class ModelProviderListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_tenant_id
def get(self, tenant_id: str):
def get(self):
_, current_tenant_id = current_account_with_tenant()
tenant_id = current_tenant_id
payload = request.args.to_dict(flat=True)
args = ParserModelList.model_validate(payload)
@ -119,8 +114,9 @@ class ModelProviderCredentialApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_tenant_id
def get(self, tenant_id: str, provider: str):
def get(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
tenant_id = current_tenant_id
# if credential_id is not provided, return current used credential
payload = request.args.to_dict(flat=True)
args = ParserCredentialId.model_validate(payload)
@ -137,8 +133,8 @@ class ModelProviderCredentialApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
@with_current_tenant_id
def post(self, current_tenant_id: str, provider: str):
def post(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
payload = console_ns.payload or {}
args = ParserCredentialCreate.model_validate(payload)
@ -161,8 +157,9 @@ class ModelProviderCredentialApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
@with_current_tenant_id
def put(self, current_tenant_id: str, provider: str):
def put(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
payload = console_ns.payload or {}
args = ParserCredentialUpdate.model_validate(payload)
@ -187,8 +184,8 @@ class ModelProviderCredentialApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
@with_current_tenant_id
def delete(self, current_tenant_id: str, provider: str):
def delete(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
payload = console_ns.payload or {}
args = ParserCredentialDelete.model_validate(payload)
@ -208,8 +205,8 @@ class ModelProviderCredentialSwitchApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
@with_current_tenant_id
def post(self, current_tenant_id: str, provider: str):
def post(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
payload = console_ns.payload or {}
args = ParserCredentialSwitch.model_validate(payload)
@ -228,8 +225,8 @@ class ModelProviderValidateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_tenant_id
def post(self, current_tenant_id: str, provider: str):
def post(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
payload = console_ns.payload or {}
args = ParserCredentialValidate.model_validate(payload)
@ -283,8 +280,11 @@ class PreferredProviderTypeUpdateApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
@with_current_tenant_id
def post(self, tenant_id: str, provider: str):
def post(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
tenant_id = current_tenant_id
payload = console_ns.payload or {}
args = ParserPreferredProviderType.model_validate(payload)
@ -301,11 +301,10 @@ class ModelProviderPaymentCheckoutUrlApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account, provider: str):
def get(self, provider: str):
if provider != "anthropic":
raise ValueError(f"provider name {provider} is invalid")
current_user, current_tenant_id = current_account_with_tenant()
BillingService.is_tenant_owner_or_admin(current_user)
data = BillingService.get_model_provider_payment_link(
provider_name=provider,

View File

@ -13,14 +13,12 @@ from controllers.console.wraps import (
is_admin_or_owner_required,
setup_required,
with_current_tenant_id,
with_current_user,
)
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
from graphon.model_runtime.utils.encoders import jsonable_encoder
from libs.helper import uuid_value
from libs.login import login_required
from models import Account
from services.model_load_balancing_service import ModelLoadBalancingService
from services.model_provider_service import ModelProviderService
@ -195,7 +193,7 @@ class ModelProviderModelApi(Resource):
@login_required
@account_initialization_required
@with_current_tenant_id
def get(self, tenant_id: str, provider: str):
def get(self, tenant_id: str, provider):
model_provider_service = ModelProviderService()
models = model_provider_service.get_models_by_provider(tenant_id=tenant_id, provider=provider)
@ -271,9 +269,8 @@ class ModelProviderModelCredentialApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_user
@with_current_tenant_id
def get(self, tenant_id: str, user: Account, provider: str):
def get(self, tenant_id: str, provider: str):
args = ParserGetCredentials.model_validate(request.args.to_dict(flat=True))
model_provider_service = ModelProviderService()
@ -295,13 +292,9 @@ class ModelProviderModelCredentialApi(Resource):
)
if args.config_from == "predefined-model":
# Only the predefined-model branch needs visibility filtering by user.
# The account is injected once by the handler and only passed into the
# service branch that needs user-scoped credential visibility.
available_credentials = model_provider_service.get_provider_available_credentials(
tenant_id=tenant_id,
provider=provider,
user=user,
)
else:
available_credentials = model_provider_service.get_provider_model_available_credentials(

View File

@ -0,0 +1,407 @@
import logging
import re
from urllib.parse import quote
from flask import Response, request
from flask_restx import Resource, marshal
from sqlalchemy.orm import Session
from werkzeug.datastructures import MultiDict
from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.snippets.payloads import (
CreateSnippetPayload,
IncludeSecretQuery,
SnippetImportPayload,
SnippetListQuery,
UpdateSnippetPayload,
)
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
)
from extensions.ext_database import db
from fields.snippet_fields import snippet_fields, snippet_list_fields, snippet_pagination_fields
from libs.login import current_account_with_tenant, login_required
from models.snippet import SnippetType
from services.app_dsl_service import ImportStatus
from services.snippet_dsl_service import SnippetDslService
from services.snippet_service import SnippetService
logger = logging.getLogger(__name__)
_TAG_IDS_BRACKET_PATTERN = re.compile(r"^tag_ids\[(\d+)\]$")
def _normalize_snippet_list_query_args(query_args: MultiDict[str, str]) -> dict[str, str | list[str]]:
normalized: dict[str, str | list[str]] = {}
indexed_tag_ids: list[tuple[int, str]] = []
for key in query_args:
match = _TAG_IDS_BRACKET_PATTERN.fullmatch(key)
if match:
indexed_tag_ids.extend((int(match.group(1)), value) for value in query_args.getlist(key))
continue
value = query_args.get(key)
if value is not None:
normalized[key] = value
if indexed_tag_ids:
normalized["tag_ids"] = [value for _, value in sorted(indexed_tag_ids)]
return normalized
# Register Pydantic models with Swagger
register_schema_models(
console_ns,
SnippetListQuery,
CreateSnippetPayload,
UpdateSnippetPayload,
SnippetImportPayload,
IncludeSecretQuery,
)
# Create namespace models for marshaling
snippet_model = console_ns.model("Snippet", snippet_fields)
snippet_list_model = console_ns.model("SnippetList", snippet_list_fields)
snippet_pagination_model = console_ns.model("SnippetPagination", snippet_pagination_fields)
@console_ns.route("/workspaces/current/customized-snippets")
class CustomizedSnippetsApi(Resource):
@console_ns.doc("list_customized_snippets")
@console_ns.expect(console_ns.models.get(SnippetListQuery.__name__))
@console_ns.response(200, "Snippets retrieved successfully", snippet_pagination_model)
@setup_required
@login_required
@account_initialization_required
def get(self):
"""List customized snippets with pagination and search."""
_, current_tenant_id = current_account_with_tenant()
query = SnippetListQuery.model_validate(_normalize_snippet_list_query_args(request.args))
snippets, total, has_more = SnippetService.get_snippets(
tenant_id=current_tenant_id,
page=query.page,
limit=query.limit,
keyword=query.keyword,
is_published=query.is_published,
creators=query.creators,
tag_ids=query.tag_ids,
)
return {
"data": marshal(snippets, snippet_list_fields),
"page": query.page,
"limit": query.limit,
"total": total,
"has_more": has_more,
}, 200
@console_ns.doc("create_customized_snippet")
@console_ns.expect(console_ns.models.get(CreateSnippetPayload.__name__))
@console_ns.response(201, "Snippet created successfully", snippet_model)
@console_ns.response(400, "Invalid request")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def post(self):
"""Create a new customized snippet."""
current_user, current_tenant_id = current_account_with_tenant()
payload = CreateSnippetPayload.model_validate(console_ns.payload or {})
try:
snippet_type = SnippetType(payload.type)
except ValueError:
snippet_type = SnippetType.NODE
try:
if payload.graph is not None:
SnippetService.validate_snippet_graph_forbidden_nodes(payload.graph)
snippet = SnippetService.create_snippet(
tenant_id=current_tenant_id,
name=payload.name,
description=payload.description,
snippet_type=snippet_type,
icon_info=payload.icon_info.model_dump() if payload.icon_info else None,
input_fields=[f.model_dump() for f in payload.input_fields] if payload.input_fields else None,
account=current_user,
)
except ValueError as e:
return {"message": str(e)}, 400
return marshal(snippet, snippet_fields), 201
@console_ns.route("/workspaces/current/customized-snippets/<uuid:snippet_id>")
class CustomizedSnippetDetailApi(Resource):
@console_ns.doc("get_customized_snippet")
@console_ns.response(200, "Snippet retrieved successfully", snippet_model)
@console_ns.response(404, "Snippet not found")
@setup_required
@login_required
@account_initialization_required
def get(self, snippet_id: str):
"""Get customized snippet details."""
_, current_tenant_id = current_account_with_tenant()
snippet = SnippetService.get_snippet_by_id(
snippet_id=str(snippet_id),
tenant_id=current_tenant_id,
)
if not snippet:
raise NotFound("Snippet not found")
return marshal(snippet, snippet_fields), 200
@console_ns.doc("update_customized_snippet")
@console_ns.expect(console_ns.models.get(UpdateSnippetPayload.__name__))
@console_ns.response(200, "Snippet updated successfully", snippet_model)
@console_ns.response(400, "Invalid request")
@console_ns.response(404, "Snippet not found")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def patch(self, snippet_id: str):
"""Update customized snippet."""
current_user, current_tenant_id = current_account_with_tenant()
snippet = SnippetService.get_snippet_by_id(
snippet_id=str(snippet_id),
tenant_id=current_tenant_id,
)
if not snippet:
raise NotFound("Snippet not found")
payload = UpdateSnippetPayload.model_validate(console_ns.payload or {})
update_data = payload.model_dump(exclude_unset=True)
if "icon_info" in update_data and update_data["icon_info"] is not None:
update_data["icon_info"] = payload.icon_info.model_dump() if payload.icon_info else None
if not update_data:
return {"message": "No valid fields to update"}, 400
try:
with Session(db.engine, expire_on_commit=False) as session:
snippet = session.merge(snippet)
snippet = SnippetService.update_snippet(
session=session,
snippet=snippet,
account_id=current_user.id,
data=update_data,
)
session.commit()
except ValueError as e:
return {"message": str(e)}, 400
return marshal(snippet, snippet_fields), 200
@console_ns.doc("delete_customized_snippet")
@console_ns.response(204, "Snippet deleted successfully")
@console_ns.response(404, "Snippet not found")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def delete(self, snippet_id: str):
"""Delete customized snippet."""
_, current_tenant_id = current_account_with_tenant()
snippet = SnippetService.get_snippet_by_id(
snippet_id=str(snippet_id),
tenant_id=current_tenant_id,
)
if not snippet:
raise NotFound("Snippet not found")
with Session(db.engine) as session:
snippet = session.merge(snippet)
SnippetService.delete_snippet(
session=session,
snippet=snippet,
)
session.commit()
return "", 204
@console_ns.route("/workspaces/current/customized-snippets/<uuid:snippet_id>/export")
class CustomizedSnippetExportApi(Resource):
@console_ns.doc("export_customized_snippet")
@console_ns.doc(description="Export snippet configuration as DSL")
@console_ns.doc(params={"snippet_id": "Snippet ID to export"})
@console_ns.response(200, "Snippet exported successfully")
@console_ns.response(404, "Snippet not found")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def get(self, snippet_id: str):
"""Export snippet as DSL."""
_, current_tenant_id = current_account_with_tenant()
snippet = SnippetService.get_snippet_by_id(
snippet_id=str(snippet_id),
tenant_id=current_tenant_id,
)
if not snippet:
raise NotFound("Snippet not found")
# Get include_secret parameter
query = IncludeSecretQuery.model_validate(request.args.to_dict())
with Session(db.engine) as session:
export_service = SnippetDslService(session)
result = export_service.export_snippet_dsl(snippet=snippet, include_secret=query.include_secret == "true")
# Set filename with .snippet extension
filename = f"{snippet.name}.snippet"
encoded_filename = quote(filename)
response = Response(
result,
mimetype="application/x-yaml",
)
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
response.headers["Content-Type"] = "application/x-yaml"
return response
@console_ns.route("/workspaces/current/customized-snippets/imports")
class CustomizedSnippetImportApi(Resource):
@console_ns.doc("import_customized_snippet")
@console_ns.doc(description="Import snippet from DSL")
@console_ns.expect(console_ns.models.get(SnippetImportPayload.__name__))
@console_ns.response(200, "Snippet imported successfully")
@console_ns.response(202, "Import pending confirmation")
@console_ns.response(400, "Import failed")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def post(self):
"""Import snippet from DSL."""
current_user, _ = current_account_with_tenant()
payload = SnippetImportPayload.model_validate(console_ns.payload or {})
with Session(db.engine) as session:
import_service = SnippetDslService(session)
result = import_service.import_snippet(
account=current_user,
import_mode=payload.mode,
yaml_content=payload.yaml_content,
yaml_url=payload.yaml_url,
snippet_id=payload.snippet_id,
name=payload.name,
description=payload.description,
)
session.commit()
# Return appropriate status code based on result
status = result.status
if status == ImportStatus.FAILED:
return result.model_dump(mode="json"), 400
elif status == ImportStatus.PENDING:
return result.model_dump(mode="json"), 202
return result.model_dump(mode="json"), 200
@console_ns.route("/workspaces/current/customized-snippets/imports/<string:import_id>/confirm")
class CustomizedSnippetImportConfirmApi(Resource):
@console_ns.doc("confirm_snippet_import")
@console_ns.doc(description="Confirm a pending snippet import")
@console_ns.doc(params={"import_id": "Import ID to confirm"})
@console_ns.response(200, "Import confirmed successfully")
@console_ns.response(400, "Import failed")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def post(self, import_id: str):
"""Confirm a pending snippet import."""
current_user, _ = current_account_with_tenant()
with Session(db.engine) as session:
import_service = SnippetDslService(session)
result = import_service.confirm_import(import_id=import_id, account=current_user)
session.commit()
if result.status == ImportStatus.FAILED:
return result.model_dump(mode="json"), 400
return result.model_dump(mode="json"), 200
@console_ns.route("/workspaces/current/customized-snippets/<uuid:snippet_id>/check-dependencies")
class CustomizedSnippetCheckDependenciesApi(Resource):
@console_ns.doc("check_snippet_dependencies")
@console_ns.doc(description="Check dependencies for a snippet")
@console_ns.doc(params={"snippet_id": "Snippet ID"})
@console_ns.response(200, "Dependencies checked successfully")
@console_ns.response(404, "Snippet not found")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def get(self, snippet_id: str):
"""Check dependencies for a snippet."""
_, current_tenant_id = current_account_with_tenant()
snippet = SnippetService.get_snippet_by_id(
snippet_id=str(snippet_id),
tenant_id=current_tenant_id,
)
if not snippet:
raise NotFound("Snippet not found")
with Session(db.engine) as session:
import_service = SnippetDslService(session)
result = import_service.check_dependencies(snippet=snippet)
return result.model_dump(mode="json"), 200
@console_ns.route("/workspaces/current/customized-snippets/<uuid:snippet_id>/use-count/increment")
class CustomizedSnippetUseCountIncrementApi(Resource):
@console_ns.doc("increment_snippet_use_count")
@console_ns.doc(description="Increment snippet use count by 1")
@console_ns.doc(params={"snippet_id": "Snippet ID"})
@console_ns.response(200, "Use count incremented successfully")
@console_ns.response(404, "Snippet not found")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def post(self, snippet_id: str):
"""Increment snippet use count when it is inserted into a workflow."""
_, current_tenant_id = current_account_with_tenant()
snippet = SnippetService.get_snippet_by_id(
snippet_id=str(snippet_id),
tenant_id=current_tenant_id,
)
if not snippet:
raise NotFound("Snippet not found")
with Session(db.engine) as session:
snippet = session.merge(snippet)
SnippetService.increment_use_count(session=session, snippet=snippet)
session.commit()
session.refresh(snippet)
return {"result": "success", "use_count": snippet.use_count}, 200

View File

@ -69,7 +69,6 @@ class BuiltinToolAddPayload(BaseModel):
credentials: dict[str, Any]
name: str | None = Field(default=None, max_length=30)
type: CredentialType
visibility: str | None = None
class BuiltinToolUpdatePayload(BaseModel):
@ -278,7 +277,7 @@ class ToolBuiltinProviderListToolsApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str):
def get(self, provider):
_, tenant_id = current_account_with_tenant()
return jsonable_encoder(
@ -294,7 +293,7 @@ class ToolBuiltinProviderInfoApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str):
def get(self, provider):
_, tenant_id = current_account_with_tenant()
return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider))
@ -307,7 +306,7 @@ class ToolBuiltinProviderDeleteApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider: str):
def post(self, provider):
_, tenant_id = current_account_with_tenant()
payload = BuiltinToolCredentialDeletePayload.model_validate(console_ns.payload or {})
@ -325,7 +324,7 @@ class ToolBuiltinProviderAddApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider: str):
def post(self, provider):
user, tenant_id = current_account_with_tenant()
user_id = user.id
@ -339,7 +338,6 @@ class ToolBuiltinProviderAddApi(Resource):
credentials=payload.credentials,
name=payload.name,
api_type=CredentialType.of(payload.type),
visibility=payload.visibility,
)
@ -350,7 +348,7 @@ class ToolBuiltinProviderUpdateApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider: str):
def post(self, provider):
user, tenant_id = current_account_with_tenant()
user_id = user.id
@ -372,20 +370,13 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str):
user, tenant_id = current_account_with_tenant()
# Optional list of credential IDs to include even if visibility would hide them
# (used when a workflow/agent node still references another member's only_me credential).
include_credential_ids = request.args.getlist("include_credential_ids") or [
s for s in (request.args.get("include_credential_ids") or "").split(",") if s
]
def get(self, provider):
_, tenant_id = current_account_with_tenant()
return jsonable_encoder(
BuiltinToolManageService.get_builtin_tool_provider_credentials(
tenant_id=tenant_id,
provider_name=provider,
user=user,
include_credential_ids=include_credential_ids or None,
)
)
@ -393,7 +384,7 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/icon")
class ToolBuiltinProviderIconApi(Resource):
@setup_required
def get(self, provider: str):
def get(self, provider):
icon_bytes, mimetype = BuiltinToolManageService.get_builtin_tool_provider_icon(provider)
icon_cache_max_age = dify_config.TOOL_ICON_CACHE_MAX_AGE
return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age)
@ -793,7 +784,7 @@ class ToolPluginOAuthApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def get(self, provider: str):
def get(self, provider):
tool_provider = ToolProviderID(provider)
plugin_id = tool_provider.plugin_id
provider_name = tool_provider.provider_name
@ -831,7 +822,7 @@ class ToolPluginOAuthApi(Resource):
@console_ns.route("/oauth/plugin/<path:provider>/tool/callback")
class ToolOAuthCallback(Resource):
@setup_required
def get(self, provider: str):
def get(self, provider):
context_id = request.cookies.get("context_id")
if not context_id:
raise Forbidden("context_id not found")
@ -868,7 +859,7 @@ class ToolOAuthCallback(Resource):
if not credentials:
raise Exception("the plugin credentials failed")
# add credentials to database — OAuth tokens default to only_me since they're personal
# add credentials to database
BuiltinToolManageService.add_builtin_tool_provider(
user_id=user_id,
tenant_id=tenant_id,
@ -876,7 +867,6 @@ class ToolOAuthCallback(Resource):
credentials=dict(credentials),
expires_at=expires_at,
api_type=CredentialType.OAUTH2,
visibility="only_me",
)
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
@ -888,7 +878,7 @@ class ToolBuiltinProviderSetDefaultApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider: str):
def post(self, provider):
_, current_tenant_id = current_account_with_tenant()
payload = BuiltinProviderDefaultCredentialPayload.model_validate(console_ns.payload or {})
return BuiltinToolManageService.set_default_provider(
@ -920,7 +910,7 @@ class ToolOAuthCustomClient(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str):
def get(self, provider):
_, current_tenant_id = current_account_with_tenant()
return jsonable_encoder(
BuiltinToolManageService.get_custom_oauth_client_params(tenant_id=current_tenant_id, provider=provider)
@ -929,7 +919,7 @@ class ToolOAuthCustomClient(Resource):
@setup_required
@login_required
@account_initialization_required
def delete(self, provider: str):
def delete(self, provider):
_, current_tenant_id = current_account_with_tenant()
return jsonable_encoder(
BuiltinToolManageService.delete_custom_oauth_client_params(tenant_id=current_tenant_id, provider=provider)
@ -941,7 +931,7 @@ class ToolBuiltinProviderGetOauthClientSchemaApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str):
def get(self, provider):
_, current_tenant_id = current_account_with_tenant()
return jsonable_encoder(
BuiltinToolManageService.get_builtin_tool_provider_oauth_client_schema(
@ -955,18 +945,13 @@ class ToolBuiltinProviderGetCredentialInfoApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str):
user, tenant_id = current_account_with_tenant()
include_credential_ids = request.args.getlist("include_credential_ids") or [
s for s in (request.args.get("include_credential_ids") or "").split(",") if s
]
def get(self, provider):
_, tenant_id = current_account_with_tenant()
return jsonable_encoder(
BuiltinToolManageService.get_builtin_tool_provider_credential_info(
tenant_id=tenant_id,
provider=provider,
user=user,
include_credential_ids=include_credential_ids or None,
)
)
@ -1166,7 +1151,7 @@ class ToolMCPDetailApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider_id: str):
def get(self, provider_id):
_, tenant_id = current_account_with_tenant()
with sessionmaker(db.engine).begin() as session:
service = MCPToolManageService(session=session)
@ -1195,7 +1180,7 @@ class ToolMCPUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider_id: str):
def get(self, provider_id):
_, tenant_id = current_account_with_tenant()
with sessionmaker(db.engine).begin() as session:
service = MCPToolManageService(session=session)

View File

@ -77,7 +77,7 @@ class TriggerProviderIconApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str):
def get(self, provider):
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
@ -103,7 +103,7 @@ class TriggerProviderInfoApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str):
def get(self, provider):
"""Get info for a trigger provider"""
user = current_user
assert isinstance(user, Account)
@ -119,18 +119,15 @@ class TriggerSubscriptionListApi(Resource):
@login_required
@edit_permission_required
@account_initialization_required
def get(self, provider: str):
def get(self, provider):
"""List all trigger subscriptions for the current tenant's provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
try:
return jsonable_encoder(
TriggerProviderService.list_trigger_provider_subscriptions(
tenant_id=user.current_tenant_id,
provider_id=TriggerProviderID(provider),
user=user,
tenant_id=user.current_tenant_id, provider_id=TriggerProviderID(provider)
)
)
except ValueError as e:
@ -149,7 +146,7 @@ class TriggerSubscriptionBuilderCreateApi(Resource):
@login_required
@edit_permission_required
@account_initialization_required
def post(self, provider: str):
def post(self, provider):
"""Add a new subscription instance for a trigger provider"""
user = current_user
assert user.current_tenant_id is not None
@ -178,7 +175,7 @@ class TriggerSubscriptionBuilderGetApi(Resource):
@login_required
@edit_permission_required
@account_initialization_required
def get(self, provider: str, subscription_builder_id: str):
def get(self, provider, subscription_builder_id):
"""Get a subscription instance for a trigger provider"""
return jsonable_encoder(
TriggerSubscriptionBuilderService.get_subscription_builder_by_id(subscription_builder_id)
@ -194,7 +191,7 @@ class TriggerSubscriptionBuilderVerifyApi(Resource):
@login_required
@edit_permission_required
@account_initialization_required
def post(self, provider: str, subscription_builder_id: str):
def post(self, provider, subscription_builder_id):
"""Verify and update a subscription instance for a trigger provider"""
user = current_user
assert user.current_tenant_id is not None
@ -226,7 +223,7 @@ class TriggerSubscriptionBuilderUpdateApi(Resource):
@login_required
@edit_permission_required
@account_initialization_required
def post(self, provider: str, subscription_builder_id: str):
def post(self, provider, subscription_builder_id):
"""Update a subscription instance for a trigger provider"""
user = current_user
assert isinstance(user, Account)
@ -260,7 +257,7 @@ class TriggerSubscriptionBuilderLogsApi(Resource):
@login_required
@edit_permission_required
@account_initialization_required
def get(self, provider: str, subscription_builder_id: str):
def get(self, provider, subscription_builder_id):
"""Get the request logs for a subscription instance for a trigger provider"""
user = current_user
assert isinstance(user, Account)
@ -283,7 +280,7 @@ class TriggerSubscriptionBuilderBuildApi(Resource):
@login_required
@edit_permission_required
@account_initialization_required
def post(self, provider: str, subscription_builder_id: str):
def post(self, provider, subscription_builder_id):
"""Build a subscription instance for a trigger provider"""
user = current_user
assert user.current_tenant_id is not None
@ -407,7 +404,7 @@ class TriggerOAuthAuthorizeApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str):
def get(self, provider):
"""Initiate OAuth authorization flow for a trigger provider"""
user = current_user
assert isinstance(user, Account)
@ -489,7 +486,7 @@ class TriggerOAuthAuthorizeApi(Resource):
@console_ns.route("/oauth/plugin/<path:provider>/trigger/callback")
class TriggerOAuthCallbackApi(Resource):
@setup_required
def get(self, provider: str):
def get(self, provider):
"""Handle OAuth callback for trigger provider"""
context_id = request.cookies.get("context_id")
if not context_id:
@ -557,7 +554,7 @@ class TriggerOAuthClientManageApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def get(self, provider: str):
def get(self, provider):
"""Get OAuth client configuration for a provider"""
user = current_user
assert user.current_tenant_id is not None
@ -603,7 +600,7 @@ class TriggerOAuthClientManageApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider: str):
def post(self, provider):
"""Configure custom OAuth client for a provider"""
user = current_user
assert user.current_tenant_id is not None
@ -629,7 +626,7 @@ class TriggerOAuthClientManageApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def delete(self, provider: str):
def delete(self, provider):
"""Remove custom OAuth client configuration"""
user = current_user
assert user.current_tenant_id is not None
@ -657,7 +654,7 @@ class TriggerSubscriptionVerifyApi(Resource):
@login_required
@edit_permission_required
@account_initialization_required
def post(self, provider: str, subscription_id: str):
def post(self, provider, subscription_id):
"""Verify credentials for an existing subscription (edit mode only)"""
user = current_user
assert user.current_tenant_id is not None

View File

@ -25,15 +25,13 @@ from controllers.console.wraps import (
cloud_edition_billing_resource_check,
only_edition_enterprise,
setup_required,
with_current_tenant_id,
with_current_user,
)
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from fields.base import ResponseModel
from libs.helper import TimestampField, dump_response, to_timestamp
from libs.login import login_required
from models.account import Account, Tenant, TenantCustomConfigDict, TenantStatus
from libs.login import current_account_with_tenant, login_required
from models.account import Tenant, TenantCustomConfigDict, TenantStatus
from services.account_service import TenantService
from services.billing_service import BillingService, SubscriptionPlan
from services.enterprise.enterprise_service import EnterpriseService
@ -155,9 +153,8 @@ class TenantListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account):
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
tenants = TenantService.get_join_tenants(current_user)
tenant_dicts = []
is_enterprise_only = dify_config.ENTERPRISE_ENABLED and not dify_config.BILLING_ENABLED
@ -231,11 +228,11 @@ class TenantApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[TenantInfoResponse.__name__])
@with_current_user
def post(self, current_user: Account):
def post(self):
if request.path == "/info":
logger.warning("Deprecated URL /info was used.")
current_user, _ = current_account_with_tenant()
tenant = current_user.current_tenant
if not tenant:
raise ValueError("No current tenant")
@ -259,8 +256,8 @@ class SwitchWorkspaceApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_user
def post(self, current_user: Account):
def post(self):
current_user, _ = current_account_with_tenant()
payload = console_ns.payload or {}
args = SwitchWorkspacePayload.model_validate(payload)
@ -284,8 +281,8 @@ class CustomConfigWorkspaceApi(Resource):
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("workspace_custom")
@with_current_tenant_id
def post(self, current_tenant_id: str):
def post(self):
_, current_tenant_id = current_account_with_tenant()
payload = console_ns.payload or {}
args = WorkspaceCustomConfigPayload.model_validate(payload)
tenant = db.get_or_404(Tenant, current_tenant_id)
@ -311,8 +308,8 @@ class WebappLogoWorkspaceApi(Resource):
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("workspace_custom")
@with_current_user
def post(self, current_user: Account):
def post(self):
current_user, _ = current_account_with_tenant()
# check file
if "file" not in request.files:
raise NoFileUploadedError()
@ -352,8 +349,8 @@ class WorkspaceInfoApi(Resource):
@login_required
@account_initialization_required
# Change workspace name
@with_current_tenant_id
def post(self, current_tenant_id: str):
def post(self):
_, current_tenant_id = current_account_with_tenant()
payload = console_ns.payload or {}
args = WorkspaceInfoPayload.model_validate(payload)
@ -375,12 +372,13 @@ class WorkspacePermissionApi(Resource):
@login_required
@account_initialization_required
@only_edition_enterprise
@with_current_tenant_id
def get(self, current_tenant_id: str):
def get(self):
"""
Get workspace permission settings.
Returns permission flags that control workspace features like member invitations and owner transfer.
"""
_, current_tenant_id = current_account_with_tenant()
if not current_tenant_id:
raise ValueError("No current tenant")

View File

@ -4,7 +4,7 @@ import os
import time
from collections.abc import Callable
from functools import wraps
from typing import Any, Concatenate, overload
from typing import Concatenate
from flask import abort, request
from pydantic import BaseModel, ValidationError
@ -37,21 +37,9 @@ ERROR_MSG_INVALID_ENCRYPTED_DATA = "Invalid encrypted data"
ERROR_MSG_INVALID_ENCRYPTED_CODE = "Invalid encrypted code"
@overload
def account_initialization_required[T, **P, R](
view: Callable[Concatenate[T, P], R],
) -> Callable[Concatenate[T, P], R]: ...
@overload
def account_initialization_required[**P, R](view: Callable[P, R]) -> Callable[P, R]: ...
def account_initialization_required[R](view: Callable[..., R]) -> Callable[..., R]:
def account_initialization_required[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view)
def decorated(*args: Any, **kwargs: Any) -> R:
# The overloads keep Resource methods method-aware for pyrefly while
# preserving support for plain functions used in tests and utilities.
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
# check account initialization
current_user, _ = current_account_with_tenant()
if current_user.status == AccountStatus.UNINITIALIZED:
@ -230,21 +218,9 @@ def cloud_utm_record[**P, R](view: Callable[P, R]) -> Callable[P, R]:
return decorated
@overload
def setup_required[T, **P, R](
view: Callable[Concatenate[T, P], R],
) -> Callable[Concatenate[T, P], R]: ...
@overload
def setup_required[**P, R](view: Callable[P, R]) -> Callable[P, R]: ...
def setup_required[R](view: Callable[..., R]) -> Callable[..., R]:
def setup_required[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view)
def decorated(*args: Any, **kwargs: Any) -> R:
# The overloads keep Resource methods method-aware for pyrefly while
# preserving support for plain functions used in tests and utilities.
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
# check setup
if dify_config.EDITION == "SELF_HOSTED" and not db.session.scalar(select(DifySetup).limit(1)):
if os.environ.get("INIT_PASSWORD"):
@ -576,7 +552,7 @@ def with_current_user_id[T, **P, R](
@wraps(view)
def decorated(self: T, *args: P.args, **kwargs: P.kwargs) -> R:
current_user, _ = current_account_with_tenant()
return view(self, current_user.id, *args, **kwargs)
return view(self, str(current_user.id), *args, **kwargs)
return decorated

View File

@ -7,7 +7,7 @@ from hmac import new as hmac_new
from flask import abort, request
from configs import dify_config
from core.db.session_factory import session_factory
from extensions.ext_database import db
from models.model import EndUser
@ -44,8 +44,6 @@ def enterprise_inner_api_only[**P, R](view: Callable[P, R]) -> Callable[P, R]:
def enterprise_inner_api_user_auth[**P, R](view: Callable[P, R]) -> Callable[P, R]:
"""Inject an EndUser for valid inner API HMAC auth, otherwise pass the request through unchanged."""
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
if not dify_config.INNER_API:
@ -74,9 +72,9 @@ def enterprise_inner_api_user_auth[**P, R](view: Callable[P, R]) -> Callable[P,
if signature_base64 != token:
return view(*args, **kwargs)
with session_factory.create_session() as session:
kwargs["user"] = session.get(EndUser, user_id)
return view(*args, **kwargs)
kwargs["user"] = db.session.get(EndUser, user_id)
return view(*args, **kwargs)
return decorated

View File

@ -147,7 +147,7 @@ class AppDescribeApi(AppReadResource):
class AppListApi(Resource):
@openapi_ns.doc(params=query_params_from_model(AppListQuery))
@openapi_ns.response(200, "App list", openapi_ns.models[AppListResponse.__name__])
@auth_router.guard_workspace(scope=Scope.APPS_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
@auth_router.guard(scope=Scope.APPS_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
def get(self, *, auth_data: AuthData):
try:
query: AppListQuery = AppListQuery.model_validate(request.args.to_dict(flat=True))

View File

@ -1,13 +1,11 @@
from __future__ import annotations
from controllers.openapi.auth.conditions import (
EDITION_CE,
EDITION_EE,
HAS_ALLOWED_ROLES,
LOADED_APP_IS_PRIVATE,
PATH_HAS_APP_ID,
WEBAPP_AUTH_ENABLED,
WORKSPACE_MEMBERSHIP_REQUIRED,
WORKSPACE_SCOPED,
)
from controllers.openapi.auth.data import Edition
from controllers.openapi.auth.flow import When
@ -17,18 +15,14 @@ from controllers.openapi.auth.prepare import (
load_app,
load_app_access_mode,
load_tenant,
load_tenant_from_request,
load_workspace_role,
resolve_external_user,
)
from controllers.openapi.auth.verify import (
check_acl,
check_app_api_enabled,
check_app_access,
check_membership,
check_private_app_permission,
check_scope,
check_workspace_member,
check_workspace_mismatch,
check_workspace_role,
)
from libs.oauth_bearer import TokenType
@ -36,17 +30,13 @@ account_pipeline = AuthPipeline(
prepare=[
When(PATH_HAS_APP_ID, then=load_app),
When(PATH_HAS_APP_ID, then=load_tenant),
When(WORKSPACE_MEMBERSHIP_REQUIRED, then=load_tenant_from_request),
load_account,
When(WORKSPACE_SCOPED, then=load_workspace_role),
load_account, # all tokens here are account tokens
When(PATH_HAS_APP_ID & EDITION_EE, then=load_app_access_mode),
],
auth=[
When(PATH_HAS_APP_ID, then=check_app_api_enabled),
check_scope,
When(WORKSPACE_SCOPED, then=check_workspace_member),
When(PATH_HAS_APP_ID, then=check_workspace_mismatch),
When(HAS_ALLOWED_ROLES, then=check_workspace_role),
When(EDITION_CE & PATH_HAS_APP_ID, then=check_membership),
When(EDITION_EE & PATH_HAS_APP_ID & ~WEBAPP_AUTH_ENABLED, then=check_app_access),
When(PATH_HAS_APP_ID & EDITION_EE & WEBAPP_AUTH_ENABLED, then=check_acl),
When(EDITION_EE & LOADED_APP_IS_PRIVATE, then=check_private_app_permission),
],
@ -60,7 +50,6 @@ external_sso_pipeline = AuthPipeline(
When(PATH_HAS_APP_ID, then=load_app_access_mode),
],
auth=[
When(PATH_HAS_APP_ID, then=check_app_api_enabled),
check_scope,
When(PATH_HAS_APP_ID & WEBAPP_AUTH_ENABLED, then=check_acl),
When(LOADED_APP_IS_PRIVATE, then=check_private_app_permission),

View File

@ -50,11 +50,4 @@ EDITION_SAAS = config_cond(lambda: current_edition() == Edition.SAAS)
WEBAPP_AUTH_ENABLED = config_cond(lambda: FeatureService.get_system_features().webapp_auth.enabled)
WORKSPACE_MEMBERSHIP_REQUIRED = request_cond(lambda ctx: ctx.workspace_membership)
HAS_ALLOWED_ROLES = request_cond(lambda ctx: ctx.allowed_roles is not None)
# Caller must belong to the resolved tenant: either an app-scoped path (tenant
# from the app) or an explicit workspace-membership path (tenant from request).
WORKSPACE_SCOPED = PATH_HAS_APP_ID | WORKSPACE_MEMBERSHIP_REQUIRED
LOADED_APP_IS_PRIVATE = data_cond(lambda data: data.app_access_mode == WebAppAccessMode.PRIVATE)

View File

@ -9,7 +9,7 @@ from werkzeug.exceptions import InternalServerError
from configs import dify_config
from libs.oauth_bearer import Scope, TokenType
from models.account import Account, Tenant, TenantAccountRole
from models.account import Account, Tenant
from models.model import App, EndUser
from services.enterprise.enterprise_service import WebAppAccessMode
@ -41,8 +41,6 @@ class RequestContext(BaseModel):
token_type: TokenType
scope: Scope | None = None
path_params: dict[str, str]
workspace_membership: bool = False
allowed_roles: frozenset[TenantAccountRole] | None = None
class AuthData(BaseModel):
@ -58,14 +56,10 @@ class AuthData(BaseModel):
external_identity: ExternalIdentity | None = None
path_params: dict[str, str] = Field(default_factory=dict)
allowed_roles: frozenset[TenantAccountRole] | None = None
app: App | None = None
tenant: Tenant | None = None
app_access_mode: WebAppAccessMode | None = None
tenant_role: TenantAccountRole | None = None
caller: Account | EndUser | None = None
caller_kind: Literal["account", "end_user"] | None = None

View File

@ -34,7 +34,6 @@ from libs.oauth_bearer import (
reset_auth_ctx,
set_auth_ctx,
)
from models.account import TenantAccountRole
from services.feature_service import FeatureService, LicenseStatus
@ -57,15 +56,11 @@ class AuthPipeline:
view: Callable,
*,
scope: Scope | None,
workspace_membership: bool = False,
allowed_roles: frozenset[TenantAccountRole] | None = None,
) -> Any:
req_ctx = RequestContext(
token_type=identity.token_type,
scope=scope,
path_params=dict(request.view_args or {}),
workspace_membership=workspace_membership,
allowed_roles=allowed_roles,
)
data = AuthData(
@ -76,7 +71,6 @@ class AuthPipeline:
scopes=frozenset(identity.scopes),
tenants=dict(identity.verified_tenants),
required_scope=scope,
allowed_roles=allowed_roles,
path_params=dict(req_ctx.path_params),
external_identity=(
ExternalIdentity(email=identity.subject_email, issuer=identity.subject_issuer)
@ -127,41 +121,6 @@ class PipelineRouter:
scope: Scope | None = None,
allowed_token_types: frozenset[TokenType] | None = None,
edition: frozenset[Edition] | None = None,
workspace_membership: bool = False,
allowed_roles: frozenset[TenantAccountRole] | None = None,
) -> Callable:
return self._make_decorator(
scope=scope,
allowed_token_types=allowed_token_types,
edition=edition,
workspace_membership=workspace_membership,
allowed_roles=allowed_roles,
)
def guard_workspace(
self,
*,
scope: Scope | None = None,
allowed_token_types: frozenset[TokenType] | None = None,
edition: frozenset[Edition] | None = None,
allowed_roles: frozenset[TenantAccountRole] | None = None,
) -> Callable:
return self._make_decorator(
scope=scope,
allowed_token_types=allowed_token_types,
edition=edition,
workspace_membership=True,
allowed_roles=allowed_roles,
)
def _make_decorator(
self,
*,
scope: Scope | None,
allowed_token_types: frozenset[TokenType] | None,
edition: frozenset[Edition] | None,
workspace_membership: bool,
allowed_roles: frozenset[TenantAccountRole] | None,
) -> Callable:
def decorator(view: Callable) -> Callable:
@wraps(view)
@ -173,8 +132,6 @@ class PipelineRouter:
scope=scope,
allowed_token_types=allowed_token_types,
edition=edition,
workspace_membership=workspace_membership,
allowed_roles=allowed_roles,
)
return decorated
@ -190,8 +147,6 @@ class PipelineRouter:
scope: Scope | None,
allowed_token_types: frozenset[TokenType] | None,
edition: frozenset[Edition] | None,
workspace_membership: bool = False,
allowed_roles: frozenset[TenantAccountRole] | None = None,
) -> Any:
# 404 not 403 — this edition doesn't expose the feature at all
if edition is not None and current_edition() not in edition:
@ -227,15 +182,7 @@ class PipelineRouter:
if not license_checked and Edition.EE in route.required_edition:
_check_license()
return route.pipeline._run(
identity,
args,
kwargs,
view,
scope=scope,
workspace_membership=workspace_membership,
allowed_roles=allowed_roles,
)
return route.pipeline._run(identity, args, kwargs, view, scope=scope)
def _should_run(step: Any, req_ctx: RequestContext, data: AuthData | None) -> bool:

View File

@ -1,8 +1,5 @@
from __future__ import annotations
import uuid
from flask import request
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound, Unauthorized
from controllers.openapi.auth.data import AuthData
@ -16,18 +13,16 @@ from services.enterprise.enterprise_service import EnterpriseService, WebAppAcce
def load_app(data: AuthData) -> None:
if data.app is not None:
return
app_id = data.path_params["app_id"]
app = AppService.get_app_by_id(db.session, app_id)
if not app or app.status != "normal":
raise NotFound("app not found")
if not app.enable_api:
raise Forbidden("service_api_disabled")
data.app = app
def load_tenant(data: AuthData) -> None:
if data.tenant is not None:
return
if data.app is None:
raise InternalServerError("pipeline_invariant_violated: app not loaded before load_tenant")
tenant = TenantService.get_tenant_by_id(db.session, str(data.app.tenant_id))
@ -36,25 +31,7 @@ def load_tenant(data: AuthData) -> None:
data.tenant = tenant
def load_tenant_from_request(data: AuthData) -> None:
if data.tenant is not None:
return
workspace_id = data.path_params.get("workspace_id") or request.args.get("workspace_id")
if not workspace_id:
raise NotFound("workspace not found")
try:
uuid.UUID(workspace_id)
except ValueError:
raise NotFound("workspace not found")
tenant = TenantService.get_tenant_by_id(db.session, workspace_id)
if tenant is None or tenant.status == TenantStatus.ARCHIVE:
raise NotFound("workspace not found")
data.tenant = tenant
def load_account(data: AuthData) -> None:
if data.caller is not None:
return
account = AccountService.get_account_by_id(db.session, str(data.account_id))
if account is None:
raise Unauthorized("account not found")
@ -64,19 +41,6 @@ def load_account(data: AuthData) -> None:
data.caller_kind = "account"
def load_workspace_role(data: AuthData) -> None:
if data.tenant_role is not None:
return
if data.tenant is None or data.account_id is None:
return
if data.caller is not None and getattr(data.caller, "status", None) != "active":
return
role = TenantService.get_account_role_in_tenant(db.session, str(data.account_id), str(data.tenant.id))
if role is None:
return
data.tenant_role = role
def resolve_external_user(data: AuthData) -> None:
if data.tenant is None or data.app is None or data.external_identity is None:
raise Unauthorized("missing context for external user resolution")

View File

@ -0,0 +1,77 @@
"""Workspace role gate.
Layered on top of `validate_bearer` + `accept_subjects(SubjectType.ACCOUNT)`
for routes whose access depends on the caller's `TenantAccountJoin.role`
in the workspace named by the `workspace_id` path parameter.
Usage::
@openapi_ns.route("/workspaces/<string:workspace_id>/members")
class Members(Resource):
@validate_bearer(accept=ACCEPT_USER_ANY)
@accept_subjects(SubjectType.ACCOUNT)
@require_workspace_role() # any member
def get(self, workspace_id: str): ...
@validate_bearer(accept=ACCEPT_USER_ANY)
@accept_subjects(SubjectType.ACCOUNT)
@require_workspace_role(TenantAccountRole.OWNER, TenantAccountRole.ADMIN)
def post(self, workspace_id: str): ...
Non-member callers get 404 (matching `GET /openapi/v1/workspaces/<id>`)
so workspace IDs do not leak across tenants. A member without one of the
allowed roles gets 403.
"""
from __future__ import annotations
from collections.abc import Callable
from functools import wraps
from typing import TypeVar
from werkzeug.exceptions import Forbidden, NotFound
from extensions.ext_database import db
from libs.oauth_bearer import try_get_auth_ctx
from models.account import TenantAccountRole
from services.account_service import TenantService
F = TypeVar("F", bound=Callable[..., object])
def require_workspace_role(*allowed_roles: TenantAccountRole) -> Callable[[F], F]:
"""Gate a route on the caller's role in ``workspace_id``.
Pass no roles to require only membership. Pass one or more roles to
require the caller's role be in that set.
"""
allowed = frozenset(allowed_roles)
def deco(fn: F) -> F:
@wraps(fn)
def wrapper(*args: object, **kwargs: object) -> object:
ctx = try_get_auth_ctx()
if ctx is None or ctx.account_id is None:
raise RuntimeError(
"require_workspace_role called without account-bearer context; "
"stack validate_bearer + accept_subjects(SubjectType.ACCOUNT) above it"
)
workspace_id = kwargs.get("workspace_id")
if not workspace_id:
raise RuntimeError("require_workspace_role expects a 'workspace_id' route parameter")
role = TenantService.get_account_role_in_tenant(db.session, str(ctx.account_id), str(workspace_id))
if role is None:
raise NotFound("workspace not found")
if allowed and role not in allowed:
raise Forbidden("insufficient workspace role")
return fn(*args, **kwargs)
return wrapper # type: ignore[return-value]
return deco

View File

@ -1,11 +1,10 @@
from __future__ import annotations
from flask import request
from werkzeug.exceptions import Forbidden, NotFound, UnprocessableEntity
from werkzeug.exceptions import Forbidden, Unauthorized
from controllers.openapi.auth.data import AuthData
from extensions.ext_database import db
from libs.oauth_bearer import Scope, TokenType
from libs.oauth_bearer import Scope, TokenType, check_workspace_membership
from services.account_service import AccountService, TenantService
from services.enterprise.enterprise_service import EnterpriseService, WebAppAccessMode
@ -18,39 +17,17 @@ def check_scope(data: AuthData) -> None:
raise Forbidden("insufficient_scope")
def check_workspace_member(data: AuthData) -> None:
"""Assert the caller belongs to the resolved tenant.
`load_workspace_role` stashes the membership role (None when the caller is
not a member or is inactive). A missing membership surfaces as 404, not
403, so workspace IDs don't leak across tenants.
"""
if data.tenant_role is None:
raise NotFound("workspace not found")
def check_workspace_mismatch(data: AuthData) -> None:
def check_membership(data: AuthData) -> None:
if data.tenant is None:
return
request_workspace_id = data.path_params.get("workspace_id") or request.args.get("workspace_id")
if request_workspace_id and request_workspace_id != str(data.tenant.id):
raise UnprocessableEntity("workspace_id does not match app's workspace")
def check_workspace_role(data: AuthData) -> None:
if data.allowed_roles is None:
return
if data.tenant_role is None:
raise NotFound("workspace not found")
if data.tenant_role not in data.allowed_roles:
raise Forbidden("insufficient workspace role")
def check_app_api_enabled(data: AuthData) -> None:
if data.app is None:
return
if not data.app.enable_api:
raise Forbidden("service_api_disabled")
raise Unauthorized("tenant unset")
if data.account_id is None:
raise Unauthorized("account_id unset")
check_workspace_membership(
account_id=data.account_id,
tenant_id=data.tenant.id,
token_hash=data.token_hash,
membership_cache=data.tenants,
)
def check_app_access(data: AuthData) -> None:

View File

@ -26,12 +26,7 @@ from werkzeug.exceptions import BadRequest
from configs import dify_config
from controllers.common.schema import query_params_from_model
from controllers.console.wraps import (
account_initialization_required,
setup_required,
with_current_tenant_id,
with_current_user,
)
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.openapi import openapi_ns
from controllers.openapi._models import (
AccountPayload,
@ -47,6 +42,7 @@ from controllers.openapi._models import (
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.helper import extract_remote_ip
from libs.login import current_account_with_tenant
from libs.oauth_bearer import MINTABLE_PROFILES, SubjectType, bearer_feature_required
from libs.rate_limit import (
LIMIT_APPROVE_CONSOLE,
@ -54,7 +50,6 @@ from libs.rate_limit import (
LIMIT_LOOKUP_PUBLIC,
rate_limit,
)
from models import Account
from services.account_service import TenantService
from services.oauth_device_flow import (
ACCOUNT_ISSUER_SENTINEL,
@ -211,12 +206,11 @@ class DeviceApproveApi(Resource):
@account_initialization_required
@bearer_feature_required
@rate_limit(LIMIT_APPROVE_CONSOLE)
@with_current_user
@with_current_tenant_id
def post(self, tenant: str, account: Account):
def post(self):
payload = _validate_json(DeviceMutateRequest)
user_code = payload.user_code.strip().upper()
account, tenant = current_account_with_tenant()
store = DeviceFlowRedis(redis_client)
found = store.load_by_user_code(user_code)

View File

@ -5,8 +5,9 @@ endpoints. Account bearers (dfoa_) see every tenant they're a member of.
External SSO bearers (dfoe_) have no account_id and so see an empty list —
that matches /openapi/v1/account.
Member-management endpoints use ``guard_workspace`` which enforces
workspace membership and optional role requirements via the auth pipeline.
Member-management endpoints are gated by both `accept_subjects` (SSO out)
and `require_workspace_role` (membership / role lookup against the path's
``workspace_id``).
"""
from __future__ import annotations
@ -36,6 +37,7 @@ from controllers.openapi._models import (
)
from controllers.openapi.auth.composition import auth_router
from controllers.openapi.auth.data import AuthData
from controllers.openapi.auth.role_gate import require_workspace_role
from extensions.ext_database import db
from libs.oauth_bearer import Scope, TokenType
from models import Account, Tenant, TenantAccountJoin
@ -150,7 +152,8 @@ class WorkspaceSwitchApi(Resource):
"""
@openapi_ns.response(200, "Workspace detail", openapi_ns.models[WorkspaceDetailResponse.__name__])
@auth_router.guard_workspace(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
@auth_router.guard(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
@require_workspace_role()
def post(self, workspace_id: str, *, auth_data: AuthData):
account = _load_account(auth_data.account_id)
@ -176,7 +179,8 @@ class WorkspaceMembersApi(Resource):
@openapi_ns.doc(params=query_params_from_model(MemberListQuery))
@openapi_ns.response(200, "Member list", openapi_ns.models[MemberListResponse.__name__])
@auth_router.guard_workspace(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
@auth_router.guard(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
@require_workspace_role()
def get(self, workspace_id: str, *, auth_data: AuthData):
try:
query = MemberListQuery.model_validate(request.args.to_dict(flat=True))
@ -198,11 +202,8 @@ class WorkspaceMembersApi(Resource):
@openapi_ns.expect(openapi_ns.models[MemberInvitePayload.__name__])
@openapi_ns.response(201, "Member invited", openapi_ns.models[MemberInviteResponse.__name__])
@auth_router.guard_workspace(
scope=Scope.WORKSPACE_WRITE,
allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}),
allowed_roles=frozenset({TenantAccountRole.OWNER, TenantAccountRole.ADMIN}),
)
@auth_router.guard(scope=Scope.WORKSPACE_WRITE, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
@require_workspace_role(TenantAccountRole.OWNER, TenantAccountRole.ADMIN)
def post(self, workspace_id: str, *, auth_data: AuthData):
payload = _validate_body(MemberInvitePayload)
inviter = _load_account(auth_data.account_id)
@ -252,11 +253,8 @@ class WorkspaceMemberApi(Resource):
"""
@openapi_ns.response(200, "Member removed", openapi_ns.models[MemberActionResponse.__name__])
@auth_router.guard_workspace(
scope=Scope.WORKSPACE_WRITE,
allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}),
allowed_roles=frozenset({TenantAccountRole.OWNER, TenantAccountRole.ADMIN}),
)
@auth_router.guard(scope=Scope.WORKSPACE_WRITE, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
@require_workspace_role(TenantAccountRole.OWNER, TenantAccountRole.ADMIN)
def delete(self, workspace_id: str, member_id: str, *, auth_data: AuthData):
operator = _load_account(auth_data.account_id)
tenant = _load_tenant(workspace_id)
@ -286,11 +284,8 @@ class WorkspaceMemberRoleApi(Resource):
@openapi_ns.expect(openapi_ns.models[MemberRoleUpdatePayload.__name__])
@openapi_ns.response(200, "Role updated", openapi_ns.models[MemberActionResponse.__name__])
@auth_router.guard_workspace(
scope=Scope.WORKSPACE_WRITE,
allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}),
allowed_roles=frozenset({TenantAccountRole.OWNER, TenantAccountRole.ADMIN}),
)
@auth_router.guard(scope=Scope.WORKSPACE_WRITE, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
@require_workspace_role(TenantAccountRole.OWNER, TenantAccountRole.ADMIN)
def put(self, workspace_id: str, member_id: str, *, auth_data: AuthData):
payload = _validate_body(MemberRoleUpdatePayload)
operator = _load_account(auth_data.account_id)

View File

@ -28,7 +28,7 @@ from core.errors.error import (
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.helper.trace_id_helper import get_external_trace_id, get_trace_session_id, omit_trace_session_id_from_payload
from core.helper.trace_id_helper import get_external_trace_id
from graphon.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.helper import UUIDStrOrEmpty
@ -41,22 +41,12 @@ from services.errors.llm import InvokeRateLimitError
logger = logging.getLogger(__name__)
def _resolve_agent_app_streaming(*, app_mode: AppMode, response_mode: str | None) -> bool:
"""Agent App runtime is SSE-only until backend blocking runs are supported."""
if app_mode != AppMode.AGENT:
return response_mode == "streaming"
if response_mode == "blocking":
raise BadRequest("Agent App only supports streaming response mode.")
return True
class CompletionRequestPayload(BaseModel):
inputs: dict[str, Any]
query: str = Field(default="")
files: list[dict[str, Any]] | None = None
response_mode: Literal["blocking", "streaming"] | None = None
retriever_from: str = Field(default="dev")
trace_session_id: str | None = Field(default=None, description="Trace session ID for observability grouping")
class ChatRequestPayload(BaseModel):
@ -68,7 +58,6 @@ class ChatRequestPayload(BaseModel):
retriever_from: str = Field(default="dev")
auto_generate_name: bool = Field(default=True, description="Auto generate conversation name")
workflow_id: str | None = Field(default=None, description="Workflow ID for advanced chat")
trace_session_id: str | None = Field(default=None, description="Trace session ID for observability grouping")
@field_validator("conversation_id", mode="before")
@classmethod
@ -114,14 +103,9 @@ class CompletionApi(Resource):
if app_model.mode != AppMode.COMPLETION:
raise AppUnavailableError()
payload = CompletionRequestPayload.model_validate(
omit_trace_session_id_from_payload(service_api_ns.payload) or {}
)
payload = CompletionRequestPayload.model_validate(service_api_ns.payload or {})
external_trace_id = get_external_trace_id(request)
args = payload.model_dump(exclude_none=True)
trace_session_id = get_trace_session_id(request)
if trace_session_id:
args["trace_session_id"] = trace_session_id
if external_trace_id:
args["external_trace_id"] = external_trace_id
@ -213,20 +197,17 @@ class ChatApi(Resource):
Supports conversation management and both blocking and streaming response modes.
"""
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
payload = ChatRequestPayload.model_validate(omit_trace_session_id_from_payload(service_api_ns.payload) or {})
payload = ChatRequestPayload.model_validate(service_api_ns.payload or {})
external_trace_id = get_external_trace_id(request)
args = payload.model_dump(exclude_none=True)
trace_session_id = get_trace_session_id(request)
if trace_session_id:
args["trace_session_id"] = trace_session_id
if external_trace_id:
args["external_trace_id"] = external_trace_id
streaming = _resolve_agent_app_streaming(app_mode=app_mode, response_mode=payload.response_mode)
streaming = payload.response_mode == "streaming"
try:
response = AppGenerateService.generate(
@ -281,7 +262,7 @@ class ChatStopApi(Resource):
def post(self, app_model: App, end_user: EndUser, task_id: str):
"""Stop a running chat message generation."""
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
AppTaskService.stop_task(

View File

@ -155,7 +155,7 @@ class ConversationApi(Resource):
Supports pagination using last_id and limit parameters.
"""
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
query_args = ConversationListQuery.model_validate(request.args.to_dict())
@ -199,7 +199,7 @@ class ConversationDetailApi(Resource):
def delete(self, app_model: App, end_user: EndUser, c_id: UUID):
"""Delete a specific conversation."""
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
conversation_id = str(c_id)
@ -228,7 +228,7 @@ class ConversationRenameApi(Resource):
def post(self, app_model: App, end_user: EndUser, c_id: UUID):
"""Rename a conversation or auto-generate a name."""
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
conversation_id = str(c_id)

View File

@ -7,7 +7,6 @@ paused human input forms in workflow/chatflow runs.
import json
import logging
from collections.abc import Sequence
from flask import Response
from flask_restx import Resource
@ -19,7 +18,6 @@ from controllers.service_api import service_api_ns
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface
from extensions.ext_database import db
from graphon.nodes.human_input.entities import FormInputConfig
from libs.helper import to_timestamp
from models.model import App, EndUser
from services.human_input_service import Form, FormNotFoundError, HumanInputService
@ -30,11 +28,11 @@ logger = logging.getLogger(__name__)
register_schema_models(service_api_ns, HumanInputFormSubmitPayload)
def _jsonify_form_definition(form: Form, *, inputs: Sequence[FormInputConfig] = ()) -> Response:
definition_payload = form.get_definition().model_dump(mode="json")
def _jsonify_form_definition(form: Form) -> Response:
definition_payload = form.get_definition().model_dump()
payload = {
"form_content": definition_payload["rendered_content"],
"inputs": [form_input.model_dump(mode="json") for form_input in inputs],
"inputs": definition_payload["inputs"],
"resolved_default_values": stringify_form_default_values(definition_payload["default_values"]),
"user_actions": definition_payload["user_actions"],
"expiration_time": to_timestamp(form.expiration_time),
@ -77,8 +75,7 @@ class WorkflowHumanInputFormApi(Resource):
_ensure_form_belongs_to_app(form, app_model)
_ensure_form_is_allowed_for_service_api(form)
service.ensure_form_active(form)
inputs = service.resolve_form_inputs(form)
return _jsonify_form_definition(form, inputs=inputs)
return _jsonify_form_definition(form)
@service_api_ns.expect(service_api_ns.models[HumanInputFormSubmitPayload.__name__])
@service_api_ns.doc("submit_human_input_form")

View File

@ -56,7 +56,7 @@ class MessageListApi(Resource):
Retrieves messages with pagination support using first_id.
"""
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
query_args = MessageListQuery.model_validate(request.args.to_dict())
@ -167,7 +167,7 @@ class MessageSuggestedApi(Resource):
"""
message_id_str = str(message_id)
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
try:

View File

@ -30,7 +30,7 @@ from core.errors.error import (
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.helper.trace_id_helper import get_external_trace_id, get_trace_session_id, omit_trace_session_id_from_payload
from core.helper.trace_id_helper import get_external_trace_id
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from fields.base import ResponseModel
@ -54,7 +54,6 @@ logger = logging.getLogger(__name__)
class WorkflowRunPayload(WorkflowRunPayloadBase):
response_mode: Literal["blocking", "streaming"] | None = None
trace_session_id: str | None = Field(default=None, description="Trace session ID for observability grouping")
class WorkflowLogQuery(BaseModel):
@ -273,11 +272,8 @@ class WorkflowRunApi(Resource):
if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError()
payload = WorkflowRunPayload.model_validate(omit_trace_session_id_from_payload(service_api_ns.payload) or {})
payload = WorkflowRunPayload.model_validate(service_api_ns.payload or {})
args = payload.model_dump(exclude_none=True)
trace_session_id = get_trace_session_id(request)
if trace_session_id:
args["trace_session_id"] = trace_session_id
external_trace_id = get_external_trace_id(request)
if external_trace_id:
args["external_trace_id"] = external_trace_id
@ -332,11 +328,8 @@ class WorkflowRunByIdApi(Resource):
if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError()
payload = WorkflowRunPayload.model_validate(omit_trace_session_id_from_payload(service_api_ns.payload) or {})
payload = WorkflowRunPayload.model_validate(service_api_ns.payload or {})
args = payload.model_dump(exclude_none=True)
trace_session_id = get_trace_session_id(request)
if trace_session_id:
args["trace_session_id"] = trace_session_id
# Add workflow_id to args for AppGenerateService
args["workflow_id"] = workflow_id

View File

@ -23,7 +23,6 @@ from . import (
feature,
files,
forgot_password,
human_input_file_upload,
human_input_form,
login,
message,
@ -47,7 +46,6 @@ __all__ = [
"feature",
"files",
"forgot_password",
"human_input_file_upload",
"human_input_form",
"login",
"message",

View File

@ -4,7 +4,7 @@ from typing import Any, cast
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, ConfigDict, Field
from werkzeug.exceptions import Unauthorized
from werkzeug.exceptions import BadRequest, Unauthorized
from constants import HEADER_NAME_APP_CODE
from controllers.common import fields
@ -58,6 +58,9 @@ class AppParameterApi(WebApiResource):
)
def get(self, app_model: App, end_user: EndUser):
"""Retrieve app parameters."""
if not app_model.enable_site:
raise BadRequest("Site is disabled.")
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
workflow = app_model.workflow
if workflow is None:

View File

@ -2,7 +2,7 @@ import logging
from typing import Any, Literal
from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
from werkzeug.exceptions import InternalServerError, NotFound
import services
from controllers.common.fields import SimpleResultResponse
@ -37,15 +37,6 @@ from services.errors.llm import InvokeRateLimitError
logger = logging.getLogger(__name__)
def _resolve_agent_app_streaming(*, app_mode: AppMode, response_mode: str | None) -> bool:
"""Agent App runtime is SSE-only until backend blocking runs are supported."""
if app_mode != AppMode.AGENT:
return response_mode == "streaming"
if response_mode == "blocking":
raise BadRequest("Agent App only supports streaming response mode.")
return True
class CompletionMessagePayload(BaseModel):
inputs: dict[str, Any] = Field(description="Input variables for the completion")
query: str = Field(default="", description="Query text for completion")
@ -180,13 +171,13 @@ class ChatApi(WebApiResource):
)
def post(self, app_model: App, end_user: EndUser):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
payload = ChatMessagePayload.model_validate(web_ns.payload or {})
args = payload.model_dump(exclude_none=True)
streaming = _resolve_agent_app_streaming(app_mode=app_mode, response_mode=payload.response_mode)
streaming = payload.response_mode == "streaming"
args["auto_generate_name"] = False
try:
@ -237,7 +228,7 @@ class ChatStopApi(WebApiResource):
@web_ns.response(200, "Success", web_ns.models[SimpleResultResponse.__name__])
def post(self, app_model: App, end_user: EndUser, task_id: str):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
AppTaskService.stop_task(

View File

@ -83,7 +83,7 @@ class ConversationListApi(WebApiResource):
)
def get(self, app_model: App, end_user: EndUser):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
raw_args = request.args.to_dict()
@ -129,7 +129,7 @@ class ConversationApi(WebApiResource):
)
def delete(self, app_model: App, end_user: EndUser, c_id: UUID):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
conversation_id = str(c_id)
@ -168,7 +168,7 @@ class ConversationRenameApi(WebApiResource):
)
def post(self, app_model: App, end_user: EndUser, c_id: UUID):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
conversation_id = str(c_id)
@ -206,7 +206,7 @@ class ConversationPinApi(WebApiResource):
@web_ns.response(200, "Conversation pinned successfully", web_ns.models[ResultResponse.__name__])
def patch(self, app_model: App, end_user: EndUser, c_id: UUID):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
conversation_id = str(c_id)
@ -237,7 +237,7 @@ class ConversationUnPinApi(WebApiResource):
@web_ns.response(200, "Conversation unpinned successfully", web_ns.models[ResultResponse.__name__])
def patch(self, app_model: App, end_user: EndUser, c_id: UUID):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
conversation_id = str(c_id)

View File

@ -1,212 +0,0 @@
"""HITL human input form file uploads.
This controller exposes a single public upload endpoint for both local files and
remote URLs. The caller always submits a multipart form: when a non-empty
``url`` field is present, the request follows the remote fetch flow; otherwise it
falls back to the local file upload flow.
"""
import httpx
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, ConfigDict, Field, HttpUrl
from sqlalchemy.orm import sessionmaker
import services
from controllers.common import helpers
from controllers.common.errors import (
BlockedFileExtensionError,
FileTooLargeError,
NoFileUploadedError,
RemoteFileUploadError,
TooManyFilesError,
UnsupportedFileTypeError,
)
from controllers.common.schema import register_schema_models
from controllers.web import web_ns
from core.helper import ssrf_proxy
from extensions.ext_database import db
from fields.file_fields import FileResponse, FileWithSignedUrl
from graphon.file import helpers as file_helpers
from libs.exception import BaseHTTPException
from repositories.factory import DifyAPIRepositoryFactory
from services.file_service import FileService
from services.human_input_file_upload_service import (
HITL_UPLOAD_TOKEN_PREFIX,
HumanInputFileUploadService,
InvalidUploadTokenError,
)
class InvalidUploadTokenBadRequestError(BaseHTTPException):
error_code = "invalid_upload_token"
description = "Invalid upload token."
code = 400
class InvalidUploadTokenUnauthorizedError(BaseHTTPException):
error_code = "invalid_upload_token"
description = "Upload token is required."
code = 401
class InvalidUploadTokenForbiddenError(BaseHTTPException):
error_code = "invalid_upload_token"
description = "Upload token is invalid or expired."
code = 403
class HumanInputFileUploadFormPayload(BaseModel):
"""Parsed multipart form fields for HITL uploads."""
model_config = ConfigDict(extra="ignore")
url: HttpUrl | None = Field(default=None, description="Remote file URL")
register_schema_models(web_ns, HumanInputFileUploadFormPayload, FileResponse, FileWithSignedUrl)
def _create_upload_service() -> HumanInputFileUploadService:
session_factory = sessionmaker(bind=db.engine)
workflow_run_repository = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_factory)
return HumanInputFileUploadService(
session_factory=session_factory,
workflow_run_repository=workflow_run_repository,
)
def _extract_hitl_upload_token() -> str:
"""Read HITL upload token from Authorization without invoking other bearer auth chains."""
authorization = request.headers.get("Authorization")
if authorization is None:
raise InvalidUploadTokenUnauthorizedError()
parts = authorization.split()
if len(parts) != 2:
raise InvalidUploadTokenUnauthorizedError()
scheme, token = parts
if scheme.lower() != "bearer":
raise InvalidUploadTokenBadRequestError()
if not token:
raise InvalidUploadTokenUnauthorizedError()
if not token.startswith(HITL_UPLOAD_TOKEN_PREFIX):
raise InvalidUploadTokenBadRequestError()
return token
def _validate_context(service: HumanInputFileUploadService, token: str):
try:
return service.validate_upload_token(token)
except InvalidUploadTokenError as exc:
raise InvalidUploadTokenForbiddenError() from exc
def _parse_local_upload_file():
if "file" not in request.files:
raise NoFileUploadedError()
if len(request.files) > 1:
raise TooManyFilesError()
file = request.files["file"]
if not file.filename:
from controllers.common.errors import FilenameNotExistsError
raise FilenameNotExistsError()
return file
def _parse_upload_form() -> HumanInputFileUploadFormPayload:
return HumanInputFileUploadFormPayload.model_validate(request.form.to_dict(flat=True))
def _upload_local_file(context):
file = _parse_local_upload_file()
try:
upload_file = FileService(db.engine).upload_file(
filename=file.filename or "",
content=file.read(),
mimetype=file.mimetype,
user=context.owner,
source=None,
)
except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
except services.errors.file.BlockedFileExtensionError as exc:
raise BlockedFileExtensionError() from exc
response = FileResponse.model_validate(upload_file, from_attributes=True)
return upload_file.id, response
def _upload_remote_file(context, url: str):
try:
resp = ssrf_proxy.head(url=url)
if resp.status_code != httpx.codes.OK:
resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
if resp.status_code != httpx.codes.OK:
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}")
except httpx.RequestError as exc:
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {str(exc)}")
file_info = helpers.guess_file_info_from_response(resp)
if not FileService.is_file_size_within_limit(extension=file_info.extension, file_size=file_info.size):
raise FileTooLargeError()
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
try:
upload_file = FileService(db.engine).upload_file(
filename=file_info.filename,
content=content,
mimetype=file_info.mimetype,
user=context.owner,
source_url=url,
)
except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
except services.errors.file.BlockedFileExtensionError as exc:
raise BlockedFileExtensionError() from exc
response = FileWithSignedUrl(
id=upload_file.id,
name=upload_file.name,
size=upload_file.size,
extension=upload_file.extension,
url=file_helpers.get_signed_file_url(upload_file_id=upload_file.id),
mime_type=upload_file.mime_type,
created_by=upload_file.created_by,
created_at=int(upload_file.created_at.timestamp()),
)
return upload_file.id, response
@web_ns.route("/human-input-forms/files")
@web_ns.response(201, "File uploaded successfully", web_ns.models[FileResponse.__name__])
class HumanInputFileUploadApi(Resource):
def post(self):
"""Upload one local file or remote URL file for a HITL human input form."""
token = _extract_hitl_upload_token()
upload_service = _create_upload_service()
context = _validate_context(upload_service, token)
form = _parse_upload_form()
# The browser always submits multipart/form-data. A non-empty `url`
# switches the endpoint into the remote-fetch flow; otherwise the
# request must carry a local `file`.
if form.url is not None:
file_id, response = _upload_remote_file(context=context, url=str(form.url))
else:
file_id, response = _upload_local_file(context=context)
upload_service.record_upload_file(context=context, file_id=file_id)
return response.model_dump(mode="json"), 201

View File

@ -4,42 +4,27 @@ Web App Human Input Form APIs.
import json
import logging
from collections.abc import Sequence
from typing import Any, NotRequired, TypedDict
from flask import Response, request
from flask_restx import Resource
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import Forbidden
from configs import dify_config
from controllers.common.human_input import HumanInputFormSubmitPayload, stringify_form_default_values
from controllers.common.schema import register_schema_models
from controllers.web import web_ns
from controllers.web.error import NotFoundError, WebFormRateLimitExceededError
from controllers.web.site import serialize_app_site_payload
from extensions.ext_database import db
from graphon.nodes.human_input.entities import FormInputConfig
from libs.helper import RateLimiter, extract_remote_ip, to_timestamp
from models.account import TenantStatus
from models.model import App, Site
from repositories.factory import DifyAPIRepositoryFactory
from services.human_input_file_upload_service import HumanInputFileUploadService
from services.human_input_service import Form, FormNotFoundError, HumanInputService
logger = logging.getLogger(__name__)
class HumanInputUploadTokenResponse(BaseModel):
upload_token: str
expires_at: int
register_schema_models(web_ns, HumanInputUploadTokenResponse)
_FORM_SUBMIT_RATE_LIMITER = RateLimiter(
prefix="web_form_submit_rate_limit",
max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS,
@ -50,20 +35,6 @@ _FORM_ACCESS_RATE_LIMITER = RateLimiter(
max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS,
time_window=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_WINDOW_SECONDS,
)
_FORM_UPLOAD_TOKEN_RATE_LIMITER = RateLimiter(
prefix="web_form_upload_token_rate_limit",
max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS,
time_window=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_WINDOW_SECONDS,
)
def _create_upload_service() -> HumanInputFileUploadService:
session_factory = sessionmaker(bind=db.engine)
workflow_run_repository = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_factory)
return HumanInputFileUploadService(
session_factory=session_factory,
workflow_run_repository=workflow_run_repository,
)
class FormDefinitionPayload(TypedDict):
@ -75,17 +46,12 @@ class FormDefinitionPayload(TypedDict):
site: NotRequired[dict]
def _jsonify_form_definition(
form: Form,
*,
inputs: Sequence[FormInputConfig] = (),
site_payload: dict | None = None,
) -> Response:
def _jsonify_form_definition(form: Form, site_payload: dict | None = None) -> Response:
"""Return the form payload (optionally with site) as a JSON response."""
definition_payload = form.get_definition().model_dump(mode="json")
definition_payload = form.get_definition().model_dump()
payload: FormDefinitionPayload = {
"form_content": definition_payload["rendered_content"],
"inputs": [i.model_dump(mode="json") for i in inputs],
"inputs": definition_payload["inputs"],
"resolved_default_values": stringify_form_default_values(definition_payload["default_values"]),
"user_actions": definition_payload["user_actions"],
"expiration_time": to_timestamp(form.expiration_time),
@ -95,33 +61,6 @@ def _jsonify_form_definition(
return Response(json.dumps(payload, ensure_ascii=False), mimetype="application/json")
@web_ns.route("/form/human_input/<string:form_token>/upload-token")
class HumanInputFormUploadTokenApi(Resource):
"""API for issuing HITL upload tokens for active human input forms."""
def post(self, form_token: str):
"""
Issue an upload token for a human input form.
POST /api/form/human_input/<form_token>/upload-token
"""
ip_address = extract_remote_ip(request)
if _FORM_UPLOAD_TOKEN_RATE_LIMITER.is_rate_limited(ip_address):
raise WebFormRateLimitExceededError()
_FORM_UPLOAD_TOKEN_RATE_LIMITER.increment_rate_limit(ip_address)
try:
token = _create_upload_service().issue_upload_token(form_token)
except FormNotFoundError:
raise NotFoundError("Form not found")
response = HumanInputUploadTokenResponse(
upload_token=token.upload_token,
expires_at=to_timestamp(token.expires_at),
)
return response.model_dump(mode="json"), 200
@web_ns.route("/form/human_input/<string:form_token>")
class HumanInputFormApi(Resource):
"""API for getting and submitting human input forms via the web app."""
@ -150,13 +89,8 @@ class HumanInputFormApi(Resource):
service.ensure_form_active(form)
app_model, site = _get_app_site_from_form(form)
inputs = service.resolve_form_inputs(form)
return _jsonify_form_definition(
form,
inputs=inputs,
site_payload=serialize_app_site_payload(app_model, site, None),
)
return _jsonify_form_definition(form, site_payload=serialize_app_site_payload(app_model, site, None))
# def post(self, _app_model: App, _end_user: EndUser, form_token: str):
def post(self, form_token: str):

View File

@ -83,7 +83,7 @@ class MessageListApi(WebApiResource):
)
def get(self, app_model: App, end_user: EndUser):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
raw_args = request.args.to_dict()
@ -225,7 +225,7 @@ class MessageSuggestedQuestionApi(WebApiResource):
)
def get(self, app_model: App, end_user: EndUser, message_id: UUID):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
message_id_str = str(message_id)

View File

@ -9,7 +9,7 @@ from controllers.common.errors import (
RemoteFileUploadError,
UnsupportedFileTypeError,
)
from core.file import remote_fetcher
from core.helper import ssrf_proxy
from extensions.ext_database import db
from fields.file_fields import FileWithSignedUrl, RemoteFileInfo
from graphon.file import helpers as file_helpers
@ -60,10 +60,10 @@ class RemoteFileInfoApi(WebApiResource):
HTTPException: If the remote file cannot be accessed
"""
decoded_url = helpers.decode_remote_url(url, request.query_string)
resp = remote_fetcher.make_request("HEAD", decoded_url)
resp = ssrf_proxy.head(decoded_url)
if resp.status_code != httpx.codes.OK:
# failed back to get method
resp = remote_fetcher.make_request("GET", decoded_url, timeout=3)
resp = ssrf_proxy.get(decoded_url, timeout=3)
resp.raise_for_status()
info = RemoteFileInfo(
file_type=resp.headers.get("Content-Type", "application/octet-stream"),
@ -112,9 +112,9 @@ class RemoteFileUploadApi(WebApiResource):
url = str(payload.url)
try:
resp = remote_fetcher.make_request("HEAD", url=url)
resp = ssrf_proxy.head(url=url)
if resp.status_code != httpx.codes.OK:
resp = remote_fetcher.make_request("GET", url=url, timeout=3, follow_redirects=True)
resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
if resp.status_code != httpx.codes.OK:
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}")
except httpx.RequestError as e:
@ -125,7 +125,7 @@ class RemoteFileUploadApi(WebApiResource):
if not FileService.is_file_size_within_limit(extension=file_info.extension, file_size=file_info.size):
raise FileTooLargeError
content = resp.content if resp.request.method == "GET" else remote_fetcher.make_request("GET", url).content
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
try:
upload_file = FileService(db.engine).upload_file(

View File

@ -237,9 +237,7 @@ class EasyUIBasedAppConfig(AppConfig):
"""
app_model_config_from: EasyUIBasedAppModelConfigFrom
# Optional: an Agent App has no legacy app_model_config row, so the id may be
# absent (persistence then stores NULL for the conversation's id).
app_model_config_id: str | None = None
app_model_config_id: str
app_model_config_dict: dict[str, Any]
model: ModelConfigEntity
prompt_template: PromptTemplateEntity

View File

@ -1,6 +1,4 @@
import json
import re
from typing import Any
from core.app.app_config.entities import RagPipelineVariableEntity
from graphon.variables.input_entities import VariableEntity
@ -22,32 +20,10 @@ class WorkflowVariablesConfigManager:
# variables
for variable in user_input_form:
cls._normalize_json_schema(variable)
variables.append(VariableEntity.model_validate(variable))
return variables
@staticmethod
def _normalize_json_schema(variable: dict[str, Any]) -> None:
"""
Normalize ``json_schema`` from a JSON string to a dict.
The workflow graph is stored as JSON in the database. When a JSON
object variable carries a ``json_schema`` field, nested dicts are
preserved correctly, but older data or certain serialization paths
may store it as a JSON *string* instead of a native dict.
``VariableEntity.json_schema`` expects ``dict | None``, so we
deserialize the string here before handing it to Pydantic.
"""
json_schema = variable.get("json_schema")
if isinstance(json_schema, str):
try:
variable["json_schema"] = json.loads(json_schema)
except (json.JSONDecodeError, TypeError):
# Leave as-is; Pydantic validation will surface the error.
pass
@classmethod
def convert_rag_pipeline_variable(cls, workflow: Workflow, start_node_id: str) -> list[RagPipelineVariableEntity]:
"""

View File

@ -40,7 +40,7 @@ from core.app.entities.task_entities import (
ChatbotAppStreamResponse,
)
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
from core.helper.trace_id_helper import extract_external_trace_id_from_args, extract_trace_session_id_from_args
from core.helper.trace_id_helper import extract_external_trace_id_from_args
from core.ops.ops_trace_manager import TraceQueueManager
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
from core.repositories import DifyCoreRepositoryFactory
@ -64,12 +64,6 @@ from services.workflow_draft_variable_service import (
logger = logging.getLogger(__name__)
def _extract_trace_session_id_from_debug_args(args: Mapping[str, Any] | Any) -> dict[str, str]:
if isinstance(args, Mapping):
return extract_trace_session_id_from_args(args)
return extract_trace_session_id_from_args({"trace_session_id": getattr(args, "trace_session_id", None)})
class AdvancedChatAppGenerator(MessageBasedAppGenerator):
_dialogue_count: int
@ -146,7 +140,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
extras = {
"auto_generate_conversation_name": args.get("auto_generate_name", False),
**extract_external_trace_id_from_args(args),
**extract_trace_session_id_from_args(args),
}
# get conversation
@ -338,10 +331,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
user_id=user.id,
stream=streaming,
invoke_from=InvokeFrom.DEBUGGER,
extras={
"auto_generate_conversation_name": False,
**_extract_trace_session_id_from_debug_args(args),
},
extras={"auto_generate_conversation_name": False},
single_iteration_run=AdvancedChatAppGenerateEntity.SingleIterationRunEntity(
node_id=node_id, inputs=args["inputs"]
),
@ -427,10 +417,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
user_id=user.id,
stream=streaming,
invoke_from=InvokeFrom.DEBUGGER,
extras={
"auto_generate_conversation_name": False,
**_extract_trace_session_id_from_debug_args(args),
},
extras={"auto_generate_conversation_name": False},
single_loop_run=AdvancedChatAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args.inputs),
)
contexts.plugin_tool_providers.set({})

View File

@ -131,7 +131,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
user_id=self.application_generate_entity.user_id,
invoke_from=invoke_from,
user_from=user_from,
trace_session_id=self.application_generate_entity.extras.get("trace_session_id"),
)
elif self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
# Handle single iteration or single loop run
@ -140,7 +139,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
single_iteration_run=self.application_generate_entity.single_iteration_run,
single_loop_run=self.application_generate_entity.single_loop_run,
user_id=self.application_generate_entity.user_id,
trace_session_id=self.application_generate_entity.extras.get("trace_session_id"),
)
else:
inputs = self.application_generate_entity.inputs
@ -201,7 +199,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
user_from=user_from,
invoke_from=invoke_from,
root_node_id=root_node_id,
trace_session_id=self.application_generate_entity.extras.get("trace_session_id"),
)
db.session.close()

View File

@ -1,106 +0,0 @@
"""Build the EasyUI-style app config for an Agent App from its Agent Soul.
An Agent App has no legacy ``app_model_config``: its model / prompt live in the
bound Agent Soul snapshot. To ride the existing chat message + SSE pipeline we
synthesize an ``app_model_config``-shaped dict from the Soul (model + system
prompt) plus any app-level feature flags (opening statement, follow-up, …)
stored on ``app_model_config`` when present, then reuse the same sub-managers
the chat app type uses.
"""
from typing import Any, cast
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager
from core.app.app_config.easy_ui_based_app.model_config.manager import ModelConfigManager
from core.app.app_config.easy_ui_based_app.prompt_template.manager import PromptTemplateConfigManager
from core.app.app_config.easy_ui_based_app.variables.manager import BasicVariablesConfigManager
from core.app.app_config.entities import (
EasyUIBasedAppConfig,
EasyUIBasedAppModelConfigFrom,
PromptTemplateEntity,
)
from models.agent_config_entities import AgentSoulConfig
from models.model import App, AppMode, AppModelConfig, AppModelConfigDict, Conversation
class AgentAppConfig(EasyUIBasedAppConfig):
"""Agent App config entity (EasyUI-shaped so it rides the chat pipeline).
``app_model_config_id`` is inherited as ``str | None``: an Agent App may have
no legacy ``app_model_config`` row, in which case persistence stores ``NULL``
for the conversation's ``app_model_config_id``.
"""
class AgentAppConfigManager(BaseAppConfigManager):
@classmethod
def get_app_config(
cls,
*,
app_model: App,
agent_soul: AgentSoulConfig,
app_model_config: AppModelConfig | None = None,
conversation: Conversation | None = None,
) -> AgentAppConfig:
"""Build the Agent App config from the Agent Soul (+ optional feature flags)."""
config_dict = cls._synthesize_config_dict(agent_soul, app_model_config)
# The synthesized dict is shaped like an app_model_config; the EasyUI
# sub-managers type their param as AppModelConfigDict (a TypedDict).
typed_config = cast(AppModelConfigDict, config_dict)
app_mode = AppMode.value_of(app_model.mode)
app_config = AgentAppConfig(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
app_mode=app_mode,
# The config is derived from the Agent Soul snapshot, not a legacy
# app_model_config row; the id is informational only.
app_model_config_from=EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG,
app_model_config_id=app_model_config.id if app_model_config else None,
app_model_config_dict=config_dict,
model=ModelConfigManager.convert(config=typed_config),
prompt_template=PromptTemplateConfigManager.convert(config=typed_config),
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict),
dataset=DatasetConfigManager.convert(config=typed_config),
additional_features=cls.convert_features(config_dict, app_mode),
)
app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert(
config=typed_config
)
return app_config
@staticmethod
def _synthesize_config_dict(
agent_soul: AgentSoulConfig,
app_model_config: AppModelConfig | None,
) -> dict[str, Any]:
"""Shape a Soul + feature flags into an ``app_model_config``-style dict.
Feature flags (opening statement / follow-up / tts / stt / citations /
moderation / annotation) come from ``app_model_config`` when present
(Q3: stored there), otherwise defaults; model + prompt always come from
the Agent Soul (the single source of truth for those).
"""
base: dict[str, Any] = dict(app_model_config.to_dict()) if app_model_config else {}
model = agent_soul.model
if model is not None:
base["model"] = {
"provider": model.model_provider,
"name": model.model,
"mode": "chat",
"completion_params": model.model_settings.model_dump(mode="json", exclude_none=True),
}
# The Agent Soul system prompt rides the EasyUI "simple" prompt slot; the
# agent backend is the real prompt authority, this only feeds the chat
# pipeline's bookkeeping (token counting, persistence).
base["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE.value
base["pre_prompt"] = agent_soul.prompt.system_prompt or ""
# Agent App takes the user message directly; no completion-style inputs form.
base.setdefault("user_input_form", [])
return base
__all__ = ["AgentAppConfig", "AgentAppConfigManager"]

View File

@ -1,329 +0,0 @@
"""Agent App generator: orchestrate one conversation turn for an Agent App.
Mirrors the agent_chat generator (conversation + message + queue + streamed
response over the EasyUI chat pipeline), but the backing config comes from the
bound Agent Soul and the answer is produced by ``AgentAppRunner`` calling the
dify-agent backend rather than an in-process LLM/ReAct loop.
"""
from __future__ import annotations
import contextvars
import logging
import threading
import uuid
from collections.abc import Generator, Mapping
from typing import Any
from flask import Flask, current_app
from sqlalchemy import select
from clients.agent_backend import AgentBackendRunEventAdapter
from clients.agent_backend.factory import create_agent_backend_run_client
from configs import dify_config
from constants import UUID_NIL
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
from core.app.apps.agent_app.app_config_manager import AgentAppConfigManager
from core.app.apps.agent_app.app_runner import AgentAppRunner
from core.app.apps.agent_app.generate_response_converter import AgentAppGenerateResponseConverter
from core.app.apps.agent_app.runtime_request_builder import AgentAppRuntimeRequestBuilder
from core.app.apps.agent_app.session_store import AgentAppRuntimeSessionStore
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import (
AgentAppGenerateEntity,
DifyRunContext,
InvokeFrom,
UserFrom,
)
from core.app.llm.model_access import build_dify_model_access
from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from models import Account, App, EndUser, Message
from models.agent import Agent, AgentConfigSnapshot, AgentScope, AgentSource, AgentStatus
from models.agent_config_entities import AgentSoulConfig
from services.conversation_service import ConversationService
logger = logging.getLogger(__name__)
class AgentAppGeneratorError(ValueError):
"""Raised when an Agent App turn cannot be set up."""
class AgentAppGenerator(MessageBasedAppGenerator):
def generate(
self,
*,
app_model: App,
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
) -> Mapping[str, Any] | Generator[Mapping | str, None, None]:
if not streaming:
raise AgentAppGeneratorError("Agent App only supports streaming mode")
query = args.get("query")
if not isinstance(query, str) or not query.strip():
raise AgentAppGeneratorError("query is required")
query = query.replace("\x00", "")
inputs = args["inputs"]
# Resolve the bound roster Agent + its published Agent Soul snapshot.
agent, snapshot, agent_soul = self._resolve_agent(app_model)
conversation = None
conversation_id = args.get("conversation_id")
if conversation_id:
conversation = ConversationService.get_conversation(
app_model=app_model, conversation_id=conversation_id, user=user
)
# Build the EasyUI-shaped config from the Agent Soul so the chat pipeline
# can persist usage; the answer itself comes from the agent backend.
app_model_config = app_model.app_model_config
app_config = AgentAppConfigManager.get_app_config(
app_model=app_model,
agent_soul=agent_soul,
app_model_config=app_model_config,
conversation=conversation,
)
model_conf = ModelConfigConverter.convert(app_config)
trace_manager = TraceQueueManager(app_model.id, user.id if isinstance(user, Account) else user.session_id)
application_generate_entity = AgentAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
model_conf=model_conf,
conversation_id=conversation.id if conversation else None,
inputs=self._prepare_user_inputs(
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
),
query=query,
files=[],
parent_message_id=(
args.get("parent_message_id")
if invoke_from not in {InvokeFrom.SERVICE_API, InvokeFrom.OPENAPI}
else UUID_NIL
),
user_id=user.id,
stream=streaming,
invoke_from=invoke_from,
extras={
"auto_generate_conversation_name": args.get("auto_generate_name", True),
},
call_depth=0,
trace_manager=trace_manager,
agent_id=agent.id,
agent_config_snapshot_id=snapshot.id,
)
conversation, message = self._init_generate_records(application_generate_entity, conversation)
queue_manager = MessageBasedAppQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id,
app_mode=conversation.mode,
message_id=message.id,
)
context = contextvars.copy_context()
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"context": context,
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"conversation_id": conversation.id,
"message_id": message.id,
"user_from": UserFrom.ACCOUNT if isinstance(user, Account) else UserFrom.END_USER,
},
)
worker_thread.start()
response = self._handle_response(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message,
user=user,
stream=streaming,
)
return AgentAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
def _generate_worker(
self,
*,
flask_app: Flask,
context: contextvars.Context,
application_generate_entity: AgentAppGenerateEntity,
queue_manager: AppQueueManager,
conversation_id: str,
message_id: str,
user_from: UserFrom,
) -> None:
from libs.flask_utils import preserve_flask_contexts
with preserve_flask_contexts(flask_app, context_vars=context):
try:
conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id)
app_config = application_generate_entity.app_config
# Apply app-level input guards (content moderation + annotation
# reply) before reaching the Agent backend, mirroring the EasyUI
# chat / agent-chat runners. These can short-circuit the turn.
app_model = db.session.get(App, app_config.app_id)
if app_model is None:
raise AgentAppGeneratorError("App not found")
handled, query = self._run_input_guards(
application_generate_entity=application_generate_entity,
app_model=app_model,
message=message,
queue_manager=queue_manager,
)
if handled:
return
dify_context = DifyRunContext(
tenant_id=app_config.tenant_id,
app_id=app_config.app_id,
user_id=application_generate_entity.user_id,
user_from=user_from,
invoke_from=application_generate_entity.invoke_from,
)
credentials_provider, _ = build_dify_model_access(dify_context)
_, _, agent_soul = self._resolve_agent_by_id(
tenant_id=app_config.tenant_id,
agent_id=application_generate_entity.agent_id,
snapshot_id=application_generate_entity.agent_config_snapshot_id,
)
runner = AgentAppRunner(
request_builder=AgentAppRuntimeRequestBuilder(credentials_provider=credentials_provider),
agent_backend_client=create_agent_backend_run_client(
base_url=dify_config.AGENT_BACKEND_BASE_URL,
use_fake=dify_config.AGENT_BACKEND_USE_FAKE,
fake_scenario=dify_config.AGENT_BACKEND_FAKE_SCENARIO,
),
event_adapter=AgentBackendRunEventAdapter(),
session_store=AgentAppRuntimeSessionStore(),
)
runner.run(
dify_context=dify_context,
agent_id=application_generate_entity.agent_id,
agent_config_snapshot_id=application_generate_entity.agent_config_snapshot_id,
agent_soul=agent_soul,
conversation_id=conversation.id,
query=query,
message_id=message.id,
model_name=application_generate_entity.model_conf.model,
queue_manager=queue_manager,
)
except GenerateTaskStoppedError:
pass
except Exception as e:
logger.exception("Unknown Error in Agent App generate worker")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
finally:
db.session.close()
def _run_input_guards(
self,
*,
application_generate_entity: AgentAppGenerateEntity,
app_model: App,
message: Message,
queue_manager: AppQueueManager,
) -> tuple[bool, str]:
"""Apply input moderation + annotation reply before the backend call.
Returns ``(handled, query)``: when ``handled`` is True a direct answer
has already been published (a blocked/preset moderation response or a
matched annotation) and the backend turn must be skipped. Otherwise
``query`` is the possibly moderation-overridden query to send onward.
"""
from core.app.apps.agent_app.app_runner import publish_text_answer
from core.app.entities.queue_entities import QueueAnnotationReplyEvent
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
from core.moderation.base import ModerationError
from core.moderation.input_moderation import InputModeration
app_config = application_generate_entity.app_config
model_name = application_generate_entity.model_conf.model
query = application_generate_entity.query
# content moderation (sensitive_word_avoidance); a blocked input yields a
# preset answer, an "overridden" action returns a sanitized query.
try:
_, _, query = InputModeration().check(
app_id=app_config.app_id,
tenant_id=app_config.tenant_id,
app_config=app_config,
inputs=dict(application_generate_entity.inputs),
query=query or "",
message_id=message.id,
trace_manager=application_generate_entity.trace_manager,
)
except ModerationError as e:
publish_text_answer(queue_manager=queue_manager, model_name=model_name, answer=str(e))
return True, query
# annotation reply: a matching annotation answers the turn deterministically.
if query:
annotation_reply = AnnotationReplyFeature().query(
app_record=app_model,
message=message,
query=query,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
)
if annotation_reply:
queue_manager.publish(
QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id),
PublishFrom.APPLICATION_MANAGER,
)
publish_text_answer(queue_manager=queue_manager, model_name=model_name, answer=annotation_reply.content)
return True, query
return False, query
def _resolve_agent(self, app_model: App) -> tuple[Agent, AgentConfigSnapshot, AgentSoulConfig]:
agent = db.session.scalar(
select(Agent).where(
Agent.app_id == app_model.id,
Agent.scope == AgentScope.ROSTER,
Agent.source == AgentSource.AGENT_APP,
Agent.status == AgentStatus.ACTIVE,
)
)
if agent is None:
raise AgentAppGeneratorError("Agent App has no bound Agent")
return self._resolve_agent_by_id(
tenant_id=app_model.tenant_id, agent_id=agent.id, snapshot_id=agent.active_config_snapshot_id
)
@staticmethod
def _resolve_agent_by_id(
*, tenant_id: str, agent_id: str, snapshot_id: str | None
) -> tuple[Agent, AgentConfigSnapshot, AgentSoulConfig]:
agent = db.session.scalar(select(Agent).where(Agent.id == agent_id, Agent.tenant_id == tenant_id))
if agent is None:
raise AgentAppGeneratorError("Agent not found")
if not snapshot_id:
raise AgentAppGeneratorError("Agent has no published version")
snapshot = db.session.scalar(select(AgentConfigSnapshot).where(AgentConfigSnapshot.id == snapshot_id))
if snapshot is None:
raise AgentAppGeneratorError("Agent published version not found")
agent_soul = AgentSoulConfig.model_validate(snapshot.config_snapshot_dict)
return agent, snapshot, agent_soul
__all__ = ["AgentAppGenerator", "AgentAppGeneratorError"]

View File

@ -1,200 +0,0 @@
"""Agent App runner: drive one conversation turn through the dify-agent backend.
Unlike the legacy ``AgentChatAppRunner`` (which runs an in-process ReAct loop),
this runner delegates to the Agent backend: build the run request from the
Agent Soul + conversation, create the run, consume its event stream, and
republish the assistant answer as chat queue events so the existing
EasyUI chat task pipeline persists the message and streams SSE. The conversation
``session_snapshot`` is saved on success for multi-turn continuity (S3).
"""
from __future__ import annotations
import json
import logging
from typing import Any
from pydantic import JsonValue
from clients.agent_backend import (
AgentBackendError,
AgentBackendInternalEventType,
AgentBackendRunClient,
AgentBackendRunEventAdapter,
AgentBackendRunSucceededInternalEvent,
AgentBackendStreamInternalEvent,
)
from core.app.apps.agent_app.runtime_request_builder import (
AgentAppRuntimeBuildContext,
AgentAppRuntimeRequestBuilder,
)
from core.app.apps.agent_app.session_store import AgentAppRuntimeSessionStore, AgentAppSessionScope
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import DifyRunContext
from core.app.entities.queue_entities import QueueLLMChunkEvent, QueueMessageEndEvent
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from graphon.model_runtime.entities.message_entities import AssistantPromptMessage
from models.agent_config_entities import AgentSoulConfig
logger = logging.getLogger(__name__)
def publish_text_answer(*, queue_manager: AppQueueManager, model_name: str, answer: str) -> None:
"""Publish a complete assistant answer as one chunk + message-end.
The EasyUI chat task pipeline consumes a QueueLLMChunkEvent stream followed
by a QueueMessageEndEvent; emitting the whole answer as a single chunk lets
both the backend-produced answer and short-circuited answers (moderation /
annotation reply) share the exact same persistence + SSE path.
"""
chunk = LLMResultChunk(
model=model_name,
prompt_messages=[],
delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=answer)),
)
queue_manager.publish(QueueLLMChunkEvent(chunk=chunk), PublishFrom.APPLICATION_MANAGER)
queue_manager.publish(
QueueMessageEndEvent(
llm_result=LLMResult(
model=model_name,
prompt_messages=[],
message=AssistantPromptMessage(content=answer),
usage=LLMUsage.empty_usage(),
),
),
PublishFrom.APPLICATION_MANAGER,
)
class AgentAppRunner:
"""Runs one Agent App conversation turn against the Agent backend."""
def __init__(
self,
*,
request_builder: AgentAppRuntimeRequestBuilder,
agent_backend_client: AgentBackendRunClient,
event_adapter: AgentBackendRunEventAdapter,
session_store: AgentAppRuntimeSessionStore,
) -> None:
self._request_builder = request_builder
self._agent_backend_client = agent_backend_client
self._event_adapter = event_adapter
self._session_store = session_store
def run(
self,
*,
dify_context: DifyRunContext,
agent_id: str,
agent_config_snapshot_id: str,
agent_soul: AgentSoulConfig,
conversation_id: str,
query: str,
message_id: str,
model_name: str,
queue_manager: AppQueueManager,
) -> None:
scope = AgentAppSessionScope(
tenant_id=dify_context.tenant_id,
app_id=dify_context.app_id,
conversation_id=conversation_id,
agent_id=agent_id,
agent_config_snapshot_id=agent_config_snapshot_id,
)
session_snapshot = self._session_store.load_active_snapshot(scope)
runtime = self._request_builder.build(
AgentAppRuntimeBuildContext(
dify_context=dify_context,
agent_id=agent_id,
agent_config_snapshot_id=agent_config_snapshot_id,
agent_soul=agent_soul,
conversation_id=conversation_id,
user_query=query,
idempotency_key=message_id,
session_snapshot=session_snapshot,
)
)
create_response = self._agent_backend_client.create_run(runtime.request)
terminal = self._consume_stream(create_response.run_id, queue_manager=queue_manager)
if not isinstance(terminal, AgentBackendRunSucceededInternalEvent):
error = getattr(terminal, "error", None) or "Agent backend run did not complete successfully."
raise AgentBackendError(str(error))
answer = self._extract_answer(terminal.output)
self._publish_answer(queue_manager=queue_manager, model_name=model_name, answer=answer)
self._save_session(scope=scope, backend_run_id=terminal.run_id, snapshot=terminal.session_snapshot)
def _consume_stream(self, run_id: str, *, queue_manager: AppQueueManager):
terminal = None
for public_event in self._agent_backend_client.stream_events(run_id):
if queue_manager.is_stopped():
self._cancel_run(run_id)
raise GenerateTaskStoppedError()
for internal_event in self._event_adapter.adapt(public_event):
if queue_manager.is_stopped():
self._cancel_run(run_id)
raise GenerateTaskStoppedError()
if internal_event.type in (
AgentBackendInternalEventType.RUN_STARTED,
AgentBackendInternalEventType.STREAM_EVENT,
):
# Stream deltas are accumulated by the backend into the
# terminal output; token-level forwarding is an S3 refinement.
if isinstance(internal_event, AgentBackendStreamInternalEvent):
continue
continue
terminal = internal_event
break
if terminal is not None:
break
return terminal
def _cancel_run(self, run_id: str) -> None:
try:
self._agent_backend_client.cancel_run(run_id)
except Exception:
logger.warning("Failed to cancel stopped Agent App backend run: run_id=%s", run_id, exc_info=True)
def _publish_answer(self, *, queue_manager: AppQueueManager, model_name: str, answer: str) -> None:
# MVP: emit the full answer as a single chunk + message-end. The chat
# task pipeline streams the chunk over SSE and persists the message.
publish_text_answer(queue_manager=queue_manager, model_name=model_name, answer=answer)
def _save_session(self, *, scope: AgentAppSessionScope, backend_run_id: str, snapshot: Any) -> None:
try:
self._session_store.save_active_snapshot(scope=scope, backend_run_id=backend_run_id, snapshot=snapshot)
except Exception:
logger.warning(
"Failed to persist Agent App conversation session snapshot: "
"tenant_id=%s app_id=%s conversation_id=%s agent_id=%s",
scope.tenant_id,
scope.app_id,
scope.conversation_id,
scope.agent_id,
exc_info=True,
)
@staticmethod
def _extract_answer(output: JsonValue) -> str:
"""Normalize the backend's terminal output to assistant text.
Free-text Agent Apps return a plain string; if a structured output is
configured the value is a JSON object, which we serialize so the chat
message always has a string body.
"""
if isinstance(output, str):
return output
if isinstance(output, dict):
text = output.get("text")
if isinstance(text, str):
return text
return json.dumps(output, ensure_ascii=False)
return json.dumps(output, ensure_ascii=False)
__all__ = ["AgentAppRunner", "publish_text_answer"]

View File

@ -1,15 +0,0 @@
"""Response converter for the Agent App type.
The Agent App streams the same chatbot response shape as the chat / agent-chat
app types, so it reuses that converter wholesale; kept as a distinct subclass so
the app type owns its converter and can diverge later.
"""
from core.app.apps.agent_chat.generate_response_converter import AgentChatAppGenerateResponseConverter
class AgentAppGenerateResponseConverter(AgentChatAppGenerateResponseConverter):
pass
__all__ = ["AgentAppGenerateResponseConverter"]

View File

@ -1,184 +0,0 @@
"""Build dify-agent run requests for one Agent App conversation turn.
Mirrors the workflow ``WorkflowAgentRuntimeRequestBuilder`` but for the Agent
App surface: the user prompt is the chat message (no workflow-node job / no
previous-node context), and multi-turn continuity flows through the
conversation-keyed ``session_snapshot`` plus the history layer.
"""
from __future__ import annotations
from collections.abc import Mapping
from dataclasses import dataclass
from typing import Any, Protocol, cast
from agenton.compositor import CompositorSessionSnapshot
from dify_agent.layers.execution_context import DifyExecutionContextLayerConfig
from dify_agent.protocol import CreateRunRequest
from clients.agent_backend import (
AgentBackendAgentAppRunInput,
AgentBackendModelConfig,
AgentBackendRunRequestBuilder,
redact_for_agent_backend_log,
)
from configs import dify_config
from core.app.entities.app_invoke_entities import DifyRunContext
from core.workflow.nodes.agent_v2.plugin_tools_builder import (
WorkflowAgentPluginToolsBuilder,
WorkflowAgentPluginToolsBuildError,
)
from core.workflow.nodes.agent_v2.runtime_request_builder import build_shell_layer_config
from models.agent_config_entities import AgentSoulConfig
from models.provider_ids import ModelProviderID
class AgentAppRuntimeRequestBuildError(ValueError):
"""Raised when Agent App state cannot be mapped to a valid run request."""
def __init__(self, error_code: str, message: str) -> None:
self.error_code = error_code
super().__init__(message)
class CredentialsProvider(Protocol):
def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]: ...
@dataclass(frozen=True, slots=True)
class AgentAppRuntimeBuildContext:
dify_context: DifyRunContext
agent_id: str
agent_config_snapshot_id: str
agent_soul: AgentSoulConfig
conversation_id: str
user_query: str
idempotency_key: str
session_snapshot: CompositorSessionSnapshot | None = None
@dataclass(frozen=True, slots=True)
class AgentAppRuntimeRequest:
request: CreateRunRequest
redacted_request: dict[str, Any]
metadata: dict[str, Any]
class AgentAppRuntimeRequestBuilder:
"""Build dify-agent run requests from Agent App conversation state."""
def __init__(
self,
*,
credentials_provider: CredentialsProvider,
request_builder: AgentBackendRunRequestBuilder | None = None,
plugin_tools_builder: WorkflowAgentPluginToolsBuilder | None = None,
) -> None:
self._credentials_provider = credentials_provider
self._request_builder = request_builder or AgentBackendRunRequestBuilder()
self._plugin_tools_builder = plugin_tools_builder or WorkflowAgentPluginToolsBuilder()
def build(self, context: AgentAppRuntimeBuildContext) -> AgentAppRuntimeRequest:
agent_soul = context.agent_soul
if agent_soul.model is None:
raise AgentAppRuntimeRequestBuildError(
"agent_model_not_configured",
"Agent App requires the Agent Soul model to be configured.",
)
metadata = self._build_metadata(context)
credentials = self._credentials_provider.fetch(agent_soul.model.model_provider, agent_soul.model.model)
try:
tools_layer = self._plugin_tools_builder.build(
tenant_id=context.dify_context.tenant_id,
app_id=context.dify_context.app_id,
user_id=context.dify_context.user_id,
tools=agent_soul.tools,
invoke_from=context.dify_context.invoke_from,
)
except WorkflowAgentPluginToolsBuildError as error:
raise AgentAppRuntimeRequestBuildError(error.error_code, str(error)) from error
if tools_layer is not None or agent_soul.tools.cli_tools:
metadata["agent_tools"] = {
"dify_tool_count": len(tools_layer.tools) if tools_layer is not None else 0,
"dify_tool_names": [tool.name or tool.tool_name for tool in tools_layer.tools]
if tools_layer is not None
else [],
"cli_tool_count": len(agent_soul.tools.cli_tools),
}
request = self._request_builder.build_for_agent_app(
AgentBackendAgentAppRunInput(
model=AgentBackendModelConfig(
plugin_id=self._plugin_daemon_plugin_id(
plugin_id=agent_soul.model.plugin_id,
model_provider=agent_soul.model.model_provider,
),
model_provider=self._plugin_daemon_provider_name(agent_soul.model.model_provider),
model=agent_soul.model.model,
credentials=self._normalize_credentials(credentials),
model_settings=agent_soul.model.model_settings.model_dump(mode="json", exclude_none=True),
),
execution_context=DifyExecutionContextLayerConfig(
tenant_id=context.dify_context.tenant_id,
user_id=context.dify_context.user_id,
app_id=context.dify_context.app_id,
conversation_id=context.conversation_id,
agent_id=context.agent_id,
agent_config_version_id=context.agent_config_snapshot_id,
invoke_from="agent_app",
),
agent_soul_prompt=agent_soul.prompt.system_prompt or None,
user_prompt=context.user_query,
tools=tools_layer,
include_shell=dify_config.AGENT_SHELL_ENABLED,
shell_config=build_shell_layer_config(agent_soul),
session_snapshot=context.session_snapshot,
idempotency_key=context.idempotency_key,
metadata=metadata,
)
)
redacted = cast(dict[str, Any], redact_for_agent_backend_log(request))
return AgentAppRuntimeRequest(request=request, redacted_request=redacted, metadata=metadata)
@staticmethod
def _build_metadata(context: AgentAppRuntimeBuildContext) -> dict[str, Any]:
return {
"tenant_id": context.dify_context.tenant_id,
"app_id": context.dify_context.app_id,
"conversation_id": context.conversation_id,
"agent_id": context.agent_id,
"agent_config_snapshot_id": context.agent_config_snapshot_id,
}
@staticmethod
def _plugin_daemon_plugin_id(*, plugin_id: str, model_provider: str) -> str:
"""Return the transport plugin id expected by plugin-daemon headers."""
if plugin_id.count("/") == 1:
return plugin_id
if plugin_id:
return ModelProviderID(plugin_id).plugin_id
return ModelProviderID(model_provider).plugin_id
@staticmethod
def _plugin_daemon_provider_name(model_provider: str) -> str:
"""Return the provider name expected by plugin-daemon dispatch payloads."""
return ModelProviderID(model_provider).provider_name
@staticmethod
def _normalize_credentials(credentials: Mapping[str, Any]) -> dict[str, str | int | float | bool | None]:
normalized: dict[str, str | int | float | bool | None] = {}
for key, value in credentials.items():
if isinstance(value, str | int | float | bool) or value is None:
normalized[key] = value
else:
normalized[key] = str(value)
return normalized
__all__ = [
"AgentAppRuntimeBuildContext",
"AgentAppRuntimeRequest",
"AgentAppRuntimeRequestBuildError",
"AgentAppRuntimeRequestBuilder",
]

View File

@ -1,146 +0,0 @@
"""Conversation-keyed Agent backend session store for the Agent App type.
Shares the unified ``agent_runtime_sessions`` table with the workflow Agent
Node store, but owns rows with ``owner_type = conversation``: one Agent App
conversation maps to one Agent session, so multi-turn chat re-enters the same
``session_snapshot``. Cross-conversation memory (PRD Global / Per app) is a
phase-2 concern and not modeled here.
"""
from __future__ import annotations
from dataclasses import dataclass
from agenton.compositor import CompositorSessionSnapshot
from sqlalchemy import select
from core.db.session_factory import session_factory
from libs.datetime_utils import naive_utc_now
from models.agent import (
AgentRuntimeSession,
AgentRuntimeSessionOwnerType,
AgentRuntimeSessionStatus,
)
@dataclass(frozen=True, slots=True)
class AgentAppSessionScope:
"""Identity of one Agent App conversation session."""
tenant_id: str
app_id: str
conversation_id: str
agent_id: str
agent_config_snapshot_id: str
class AgentAppRuntimeSessionStore:
"""Persists Agent backend session snapshots for Agent App conversations."""
def load_active_snapshot(self, scope: AgentAppSessionScope) -> CompositorSessionSnapshot | None:
with session_factory.create_session() as session:
row = session.scalar(self._active_stmt(scope))
if row is None:
return None
return CompositorSessionSnapshot.model_validate_json(row.session_snapshot)
def load_active_snapshot_for_conversation(
self, *, tenant_id: str, app_id: str, conversation_id: str
) -> CompositorSessionSnapshot | None:
"""Load a conversation's active snapshot without the agent/config scope.
One Agent App conversation maps to one active session, so the workspace
inspector can resolve it from the conversation alone (it does not know
which agent config version a past turn ran under).
"""
stmt = (
select(AgentRuntimeSession)
.where(
AgentRuntimeSession.owner_type == AgentRuntimeSessionOwnerType.CONVERSATION,
AgentRuntimeSession.tenant_id == tenant_id,
AgentRuntimeSession.app_id == app_id,
AgentRuntimeSession.conversation_id == conversation_id,
AgentRuntimeSession.status == AgentRuntimeSessionStatus.ACTIVE,
)
.order_by(AgentRuntimeSession.updated_at.desc())
)
with session_factory.create_session() as session:
row = session.scalar(stmt)
if row is None:
return None
return CompositorSessionSnapshot.model_validate_json(row.session_snapshot)
def save_active_snapshot(
self,
*,
scope: AgentAppSessionScope,
backend_run_id: str,
snapshot: CompositorSessionSnapshot | None,
) -> None:
if snapshot is None:
return
snapshot_json = snapshot.model_dump_json()
with session_factory.create_session() as session:
row = session.scalar(self._scope_stmt(scope))
if row is None:
row = AgentRuntimeSession(
tenant_id=scope.tenant_id,
app_id=scope.app_id,
owner_type=AgentRuntimeSessionOwnerType.CONVERSATION,
agent_id=scope.agent_id,
agent_config_snapshot_id=scope.agent_config_snapshot_id,
conversation_id=scope.conversation_id,
backend_run_id=backend_run_id,
session_snapshot=snapshot_json,
composition_layer_specs="[]",
status=AgentRuntimeSessionStatus.ACTIVE,
)
session.add(row)
else:
row.backend_run_id = backend_run_id
row.session_snapshot = snapshot_json
row.status = AgentRuntimeSessionStatus.ACTIVE
row.cleaned_at = None
session.flush()
other_rows = session.scalars(
select(AgentRuntimeSession).where(
AgentRuntimeSession.owner_type == AgentRuntimeSessionOwnerType.CONVERSATION,
AgentRuntimeSession.tenant_id == scope.tenant_id,
AgentRuntimeSession.app_id == scope.app_id,
AgentRuntimeSession.conversation_id == scope.conversation_id,
AgentRuntimeSession.status == AgentRuntimeSessionStatus.ACTIVE,
AgentRuntimeSession.id != row.id,
)
).all()
for other_row in other_rows:
other_row.status = AgentRuntimeSessionStatus.CLEANED
other_row.cleaned_at = naive_utc_now()
session.commit()
def mark_cleaned(self, *, scope: AgentAppSessionScope, backend_run_id: str | None = None) -> None:
with session_factory.create_session() as session:
row = session.scalar(self._active_stmt(scope))
if row is None:
return
if backend_run_id is not None:
row.backend_run_id = backend_run_id
row.status = AgentRuntimeSessionStatus.CLEANED
row.cleaned_at = naive_utc_now()
session.commit()
@staticmethod
def _scope_stmt(scope: AgentAppSessionScope):
return select(AgentRuntimeSession).where(
AgentRuntimeSession.owner_type == AgentRuntimeSessionOwnerType.CONVERSATION,
AgentRuntimeSession.tenant_id == scope.tenant_id,
AgentRuntimeSession.conversation_id == scope.conversation_id,
AgentRuntimeSession.agent_id == scope.agent_id,
AgentRuntimeSession.agent_config_snapshot_id == scope.agent_config_snapshot_id,
)
@classmethod
def _active_stmt(cls, scope: AgentAppSessionScope):
return cls._scope_stmt(scope).where(AgentRuntimeSession.status == AgentRuntimeSessionStatus.ACTIVE)
__all__ = ["AgentAppRuntimeSessionStore", "AgentAppSessionScope"]

View File

@ -20,7 +20,6 @@ from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom
from core.helper.trace_id_helper import extract_trace_session_id_from_args
from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from factories import file_factory
@ -97,10 +96,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
query = query.replace("\x00", "")
inputs = args["inputs"]
extras = {
"auto_generate_conversation_name": args.get("auto_generate_name", True),
**extract_trace_session_id_from_args(args),
}
extras = {"auto_generate_conversation_name": args.get("auto_generate_name", True)}
# get conversation
conversation = None

View File

@ -134,10 +134,6 @@ class AppQueueManager(ABC):
self._check_for_sqlalchemy_models(event.model_dump())
self._publish(event, pub_from)
def is_stopped(self) -> bool:
"""Return whether the current task has been manually stopped."""
return self._is_stopped()
@abstractmethod
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
"""

View File

@ -20,7 +20,6 @@ from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom
from core.helper.trace_id_helper import extract_trace_session_id_from_args
from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from factories import file_factory
@ -90,10 +89,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
query = query.replace("\x00", "")
inputs = args["inputs"]
extras = {
"auto_generate_conversation_name": args.get("auto_generate_name", True),
**extract_trace_session_id_from_args(args),
}
extras = {"auto_generate_conversation_name": args.get("auto_generate_name", True)}
# get conversation
conversation = None

View File

@ -52,11 +52,15 @@ from core.tools.tool_manager import ToolManager
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
from core.trigger.trigger_manager import TriggerManager
from core.workflow.human_input_forms import load_form_tokens_by_form_id
from core.workflow.human_input_policy import (
HumanInputSurface,
enrich_human_input_pause_reasons,
resolve_human_input_pause_reason_inputs,
)
from core.workflow.human_input_policy import HumanInputSurface, enrich_human_input_pause_reasons
# Maps the entry surface a workflow was invoked from to the HITL surface that
# its resume tokens must be filtered for. Surfaces not in this map fall back to
# the general priority ordering (typically CONSOLE > BACKSTAGE).
_INVOKE_FROM_TO_HITL_SURFACE: Mapping[InvokeFrom, HumanInputSurface] = {
InvokeFrom.SERVICE_API: HumanInputSurface.SERVICE_API,
InvokeFrom.OPENAPI: HumanInputSurface.OPENAPI,
}
from core.workflow.system_variables import SystemVariableKey, system_variables_to_mapping
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
@ -79,14 +83,6 @@ from models.human_input import HumanInputForm
from models.workflow import WorkflowRun
from services.variable_truncator import BaseTruncator, DummyVariableTruncator, VariableTruncator
# Maps the entry surface a workflow was invoked from to the HITL surface that
# its resume tokens must be filtered for. Surfaces not in this map fall back to
# the general priority ordering (typically CONSOLE > BACKSTAGE).
_INVOKE_FROM_TO_HITL_SURFACE: Mapping[InvokeFrom, HumanInputSurface] = {
InvokeFrom.SERVICE_API: HumanInputSurface.SERVICE_API,
InvokeFrom.OPENAPI: HumanInputSurface.OPENAPI,
}
NodeExecutionId = NewType("NodeExecutionId", str)
logger = logging.getLogger(__name__)
@ -331,13 +327,8 @@ class WorkflowResponseConverter:
encoded_outputs = self._encode_outputs(event.outputs) or {}
if self._application_generate_entity.invoke_from == InvokeFrom.SERVICE_API:
encoded_outputs = {}
variable_pool = graph_runtime_state.variable_pool
resolved_reasons = resolve_human_input_pause_reason_inputs(
event.reasons,
variable_pool=variable_pool,
)
pause_reasons = [reason.model_dump(mode="json") for reason in resolved_reasons]
human_input_form_ids = [reason.form_id for reason in resolved_reasons if isinstance(reason, HumanInputRequired)]
pause_reasons = [reason.model_dump(mode="json") for reason in event.reasons]
human_input_form_ids = [reason.form_id for reason in event.reasons if isinstance(reason, HumanInputRequired)]
expiration_times_by_form_id: dict[str, datetime] = {}
display_in_ui_by_form_id: dict[str, bool] = {}
form_token_by_form_id: dict[str, str] = {}
@ -374,7 +365,7 @@ class WorkflowResponseConverter:
responses: list[StreamResponse] = []
for reason in resolved_reasons:
for reason in event.reasons:
if isinstance(reason, HumanInputRequired):
expiration_time = expiration_times_by_form_id.get(reason.form_id)
if expiration_time is None:
@ -422,19 +413,17 @@ class WorkflowResponseConverter:
self, *, event: QueueHumanInputFormFilledEvent, task_id: str
) -> HumanInputFormFilledResponse:
run_id = self._ensure_workflow_run_id()
data = HumanInputFormFilledResponse.Data(
node_id=event.node_id,
node_title=event.node_title,
rendered_content=event.rendered_content,
action_id=event.action_id,
action_text=event.action_text,
return HumanInputFormFilledResponse(
task_id=task_id,
workflow_run_id=run_id,
data=HumanInputFormFilledResponse.Data(
node_id=event.node_id,
node_title=event.node_title,
rendered_content=event.rendered_content,
action_id=event.action_id,
action_text=event.action_text,
),
)
if event.submitted_data is not None:
runtime_type_converter = WorkflowRuntimeTypeConverter()
data.submitted_data = runtime_type_converter.value_to_json_encodable_recursive(event.submitted_data)
return HumanInputFormFilledResponse(task_id=task_id, workflow_run_id=run_id, data=data)
def human_input_form_timeout_to_stream_response(
self, *, event: QueueHumanInputFormTimeoutEvent, task_id: str

View File

@ -20,7 +20,6 @@ from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom
from core.helper.trace_id_helper import extract_trace_session_id_from_args
from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from factories import file_factory
@ -149,9 +148,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
user_id=user.id,
stream=streaming,
invoke_from=invoke_from,
extras={
**extract_trace_session_id_from_args(args),
},
extras={},
trace_manager=trace_manager,
)

View File

@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Any, Literal, overload
from flask import Flask, current_app
from pydantic import ValidationError
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session, sessionmaker
import contexts
from configs import dify_config
@ -32,11 +32,7 @@ from core.app.entities.task_entities import (
)
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
from core.db.session_factory import session_factory
from core.helper.trace_id_helper import (
extract_external_trace_id_from_args,
extract_parent_trace_context_from_args,
extract_trace_session_id_from_args,
)
from core.helper.trace_id_helper import extract_external_trace_id_from_args, extract_parent_trace_context_from_args
from core.ops.ops_trace_manager import TraceQueueManager
from core.repositories import DifyCoreRepositoryFactory
from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository
@ -61,13 +57,26 @@ SKIP_PREPARE_USER_INPUTS_KEY = "_skip_prepare_user_inputs"
logger = logging.getLogger(__name__)
def _extract_trace_session_id_from_debug_args(args: Mapping[str, Any] | Any) -> dict[str, str]:
if isinstance(args, Mapping):
return extract_trace_session_id_from_args(args)
return extract_trace_session_id_from_args({"trace_session_id": getattr(args, "trace_session_id", None)})
class WorkflowAppGenerator(BaseAppGenerator):
@staticmethod
def _ensure_snippet_start_node_in_worker(*, session: Session, workflow: Workflow) -> Workflow:
"""Re-apply snippet virtual Start injection after worker reloads workflow from DB."""
if workflow.kind_or_standard != "snippet":
return workflow
from models.snippet import CustomizedSnippet
from services.snippet_generate_service import SnippetGenerateService
snippet = session.scalar(
select(CustomizedSnippet).where(
CustomizedSnippet.id == workflow.app_id,
CustomizedSnippet.tenant_id == workflow.tenant_id,
)
)
if snippet is None:
return workflow
return SnippetGenerateService.ensure_start_node_for_worker(workflow, snippet)
@staticmethod
def _should_prepare_user_inputs(args: Mapping[str, Any]) -> bool:
return not bool(args.get(SKIP_PREPARE_USER_INPUTS_KEY))
@ -177,7 +186,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
extras = {
**extract_external_trace_id_from_args(args),
**extract_parent_trace_context_from_args(args),
**extract_trace_session_id_from_args(args),
}
workflow_run_id = str(workflow_run_id or uuid.uuid4())
# FIXME (Yeuoly): we need to remove the SKIP_PREPARE_USER_INPUTS_KEY from the args
@ -421,10 +429,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
user_id=user.id,
stream=streaming,
invoke_from=InvokeFrom.DEBUGGER,
extras={
"auto_generate_conversation_name": False,
**_extract_trace_session_id_from_debug_args(args),
},
extras={"auto_generate_conversation_name": False},
single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity(
node_id=node_id, inputs=args["inputs"]
),
@ -510,10 +515,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
user_id=user.id,
stream=streaming,
invoke_from=InvokeFrom.DEBUGGER,
extras={
"auto_generate_conversation_name": False,
**_extract_trace_session_id_from_debug_args(args),
},
extras={"auto_generate_conversation_name": False},
single_loop_run=WorkflowAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args.inputs or {}),
workflow_execution_id=str(uuid.uuid4()),
)
@ -592,6 +594,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
if workflow is None:
raise ValueError("Workflow not found")
workflow = self._ensure_snippet_start_node_in_worker(session=session, workflow=workflow)
# Determine system_user_id based on invocation source
is_external_api_call = application_generate_entity.invoke_from in {
InvokeFrom.WEB_APP,

View File

@ -87,7 +87,6 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
user_from=user_from,
invoke_from=invoke_from,
root_node_id=self._root_node_id,
trace_session_id=self.application_generate_entity.extras.get("trace_session_id"),
)
elif self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
@ -95,7 +94,6 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
single_iteration_run=self.application_generate_entity.single_iteration_run,
single_loop_run=self.application_generate_entity.single_loop_run,
user_id=self.application_generate_entity.user_id,
trace_session_id=self.application_generate_entity.extras.get("trace_session_id"),
)
else:
inputs = self.application_generate_entity.inputs
@ -130,7 +128,6 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
user_from=user_from,
invoke_from=invoke_from,
root_node_id=root_node_id,
trace_session_id=self.application_generate_entity.extras.get("trace_session_id"),
)
# RUN WORKFLOW

View File

@ -118,7 +118,6 @@ class WorkflowBasedAppRunner:
tenant_id: str = "",
user_id: str = "",
root_node_id: str | None = None,
trace_session_id: str | None = None,
) -> Graph:
"""
Init graph
@ -139,7 +138,6 @@ class WorkflowBasedAppRunner:
user_id=user_id,
user_from=user_from,
invoke_from=invoke_from,
trace_session_id=trace_session_id,
)
graph_init_context = DifyGraphInitContext(
workflow_id=workflow_id,
@ -173,7 +171,6 @@ class WorkflowBasedAppRunner:
single_loop_run: Any | None = None,
*,
user_id: str,
trace_session_id: str | None = None,
) -> tuple[Graph, VariablePool, GraphRuntimeState]:
"""
Prepare graph, variable pool, and runtime state for single node execution
@ -211,7 +208,6 @@ class WorkflowBasedAppRunner:
node_type_filter_key="iteration_id",
node_type_label="iteration",
user_id=user_id,
trace_session_id=trace_session_id,
)
elif single_loop_run:
graph, variable_pool = self._get_graph_and_variable_pool_for_single_node_run(
@ -222,7 +218,6 @@ class WorkflowBasedAppRunner:
node_type_filter_key="loop_id",
node_type_label="loop",
user_id=user_id,
trace_session_id=trace_session_id,
)
else:
raise ValueError("Neither single_iteration_run nor single_loop_run is specified")
@ -241,7 +236,6 @@ class WorkflowBasedAppRunner:
node_type_label: str = "node", # 'iteration' or 'loop' for error messages
*,
user_id: str = "",
trace_session_id: str | None = None,
) -> tuple[Graph, VariablePool]:
"""
Get graph and variable pool for single node execution (iteration or loop).
@ -307,7 +301,6 @@ class WorkflowBasedAppRunner:
user_id=user_id,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
trace_session_id=trace_session_id,
)
graph_init_context = DifyGraphInitContext(
workflow_id=workflow.id,
@ -442,7 +435,6 @@ class WorkflowBasedAppRunner:
rendered_content=event.rendered_content,
action_id=event.action_id,
action_text=event.action_text,
submitted_data=event.submitted_data,
)
)
case NodeRunHumanInputFormTimeoutEvent():

Some files were not shown because too many files have changed in this diff Show More