Compare commits

..

203 Commits

Author SHA1 Message Date
24876bb05d fix: avoid duplicating lines when merging text for summarization (#37093)
Co-authored-by: bymle <229636660+bymle@users.noreply.github.com>
2026-06-05 07:02:57 +00:00
yyh
0cdd478f25 fix(web): stabilize block selector layout (#37089) 2026-06-05 07:00:03 +00:00
0db9714eb6 fix(web): attach Amplitude user ID before firing registration event (#37091)
Co-authored-by: CodingOnStar <hanxujiang@dify.com>
Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-05 06:31:27 +00:00
yyh
9da4d167fa fix(explore): render human input preview handles (#37086) 2026-06-05 03:32:29 +00:00
a1ad4be61e fix(api): expose device-flow approve rate limit as env var (#37083) 2026-06-05 02:56:23 +00:00
8cb2cffbf7 feat: improve output node (#35511) 2026-06-05 02:14:23 +00:00
yyh
a8f009a965 fix(ui): align form control focus rings (#37069) 2026-06-04 14:12:28 +00:00
0bfbd2061e feat: enhance go to anything (#32130)
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
2026-06-04 11:06:17 +00:00
c8abb11bf0 feat: support custom trace session id for Phoenix tracing (#37056)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-06-04 08:42:03 +00:00
yyh
f9320b2c91 fix(api): return agent timestamps as epoch seconds (#37057)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-06-04 08:27:37 +00:00
f0fd7ddb60 feat(cli): unified help system (#36896)
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-06-04 07:27:28 +00:00
b77f5f1e4a fix: agent tool selector marketplace checks for local and builtin tools (#37037) 2026-06-04 06:04:09 +00:00
b67c3a5f76 refactor(api): migrate tenant/user via DI for several endpoints (#37026) 2026-06-04 05:52:59 +00:00
5b5a06136a fix(agent): complete CLI-tool + env shell bootstrap & add composer validation (ENG-367/368) (#37033)
Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-06-04 05:46:42 +00:00
6e3c9597ff chore(i18n): sync translations with en-US (#37035)
Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com>
2026-06-04 02:31:52 +00:00
3c98f96ae8 feat(api): introduce select, file and file list form input types to Human Input node (#36322)
Co-authored-by: JzoNg <jzongcode@gmail.com>
Co-authored-by: GPT 5.4 <codex@openai.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
2026-06-04 01:54:28 +00:00
44725dde74 feat(agent): Sandbox / CLI Agent (dify.shell) + read-only sandbox file inspector (#36984)
Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-06-03 22:37:31 +00:00
d3058d63bd refactor(api): migrate console.datasets.data_source to BaseModel (#36624) 2026-06-03 19:38:39 +00:00
4fc62d3b38 refactor(api): migrate console.datasets.rag_pipeline partially to BaseModel (#36649) 2026-06-03 17:44:10 +00:00
e14cb209a4 chore: add missing @override decorator to api/core/rag/extractor (#37013)
Co-authored-by: mac <mac@1234.local>
2026-06-03 12:34:10 +00:00
bb3c9929f9 chore: add missing @override decorators to api/libs (#37012) 2026-06-03 12:17:50 +00:00
35a55813d2 chore(i18n): sync translations with en-US (#37011)
Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com>
2026-06-03 09:42:37 +00:00
a247d625e5 chore(deps): bump pyjwt to 2.13.0 (#37008) 2026-06-03 09:39:58 +00:00
yyh
5c7f05bd10 fix(web): auth form state management (#37003) 2026-06-03 09:14:01 +00:00
02e1a60cde chore: add missing @override decorator to api/configs (#37006)
Co-authored-by: mac <mac@1234.local>
2026-06-03 09:11:50 +00:00
57b573d02b refactor(api): migrate tenant/user via DI for several endpoints (#37004) 2026-06-03 08:59:00 +00:00
yyh
9de40e8f21 chore: update Claude skill links (#36997) 2026-06-03 08:00:35 +00:00
cad0942f4d fix(api): enforce workspace membership + role checks in auth pipeline (#36931)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-06-03 07:31:47 +00:00
cb9b1b593e feat: add Milvus TLS env examples (#36980) 2026-06-03 07:16:18 +00:00
2a8bdc2373 fix: pydantic_core._pydantic_core.ValidationError: 2 validation errors for DatasetDetailResponse (#36753) 2026-06-03 07:10:55 +00:00
ee6a07d13c refactor: use explicit session in inner api user auth (#36995) 2026-06-03 07:06:38 +00:00
yyh
2d6c9300e3 fix(api): tighten agent v2 generated contracts (#36989)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-06-03 06:52:40 +00:00
d6b4c800c2 refactor(web): migrate account education notice storage (#36991)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-06-03 06:39:22 +00:00
yyh
1b37635f92 fix: configure server console api url (#36958)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-06-03 06:22:46 +00:00
86af36429d fix: create app from template modal has no backdrop (#36987)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-06-03 06:14:46 +00:00
b96ea94505 chore: add :str to <path: parameter (#36913)
Co-authored-by: 99 <wh2099@pm.me>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-06-03 05:25:11 +00:00
d649cccda0 chore: add missing @override decorato to api/extensions (#36941)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
2026-06-03 05:25:08 +00:00
5cbbd78f38 refactor(web): migrate chat sidebar collapse storage (#36963)
Co-authored-by: lmlm <7487674+popsiclelmlm@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-06-03 04:40:48 +00:00
5a0ad4ecd9 fix: normalize json_schema from string to dict in VariableEntity (#36777) 2026-06-03 04:33:25 +00:00
1e76b9e1b8 refactor(web): migrate workflow-node-panel-width to useSetLocalStorage (#36983)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-06-03 04:32:41 +00:00
1b972c4e09 refactor(api): migrate tenant/user via DI for several endpoints (#36971)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-06-03 04:24:17 +00:00
7968d2c3c8 refactor(web): migrate workflow-variable-inpsect-panel-height to useSetLocalStorage (#36982)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-06-03 03:48:59 +00:00
7507e9ba67 refactor(web): migrate debug-and-preview-panel-width to useSetLocalStorage (#36977)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-06-03 03:27:15 +00:00
y
ca31762e26 refactor(web): migrate education verifying storage to useLocalStorage (#36934)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-06-03 02:16:59 +00:00
f591da7865 ci: ruff cover agent (#36949) 2026-06-02 11:40:19 +00:00
f19679b217 refactor: improve network error and allow verbose output (#36923) 2026-06-02 10:43:40 +00:00
b682591c7a refactor(web): migrate question classifier label hint storage (#36932)
Co-authored-by: lmlm <7487674+popsiclelmlm@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-06-02 10:28:50 +00:00
8f6b59feff refactor(web): migrate rag recommendations collapsed storage (#36940)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-06-02 09:08:51 +00:00
99833f65d8 refactor(web): migrate NEED_REFRESH_APP_LIST_KEY to useLocalStorage/useSetLocalStorage (#36908)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: yyh <yuanyouhuilyz@gmail.com>
2026-06-02 08:41:01 +00:00
yyh
696fc5c213 refactor(web): manage goto anything open state with atom (#36938) 2026-06-02 08:23:18 +00:00
eae44cfecb feat(dify-agent): add shell layer (#36838) 2026-06-02 07:54:52 +00:00
yyh
dea4e66456 fix(web): use generated account-profile contracts (#36927)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-06-02 07:28:05 +00:00
3cd0da303a refactor: remove unused Flask-RESTX field dicts from end_user and conversation_variable fields (#28015) (#36929) 2026-06-02 07:27:23 +00:00
888483a2f8 fix: user token (#36930) 2026-06-02 07:20:07 +00:00
7056985f72 refactor: inject current user id in stop message endpoints (#36925)
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-06-02 06:48:10 +00:00
6ce61eae59 fix(cli): invalidate app metadata cache on 422 to clear stale data (#36921) 2026-06-02 05:20:33 +00:00
yyh
079af312c6 fix(contracts): include account avatar url in profile schema (#36924)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-06-02 04:30:47 +00:00
0da13dfe4d refactor(cli): unify token storage behind Store + add host/account switching (#36830) 2026-06-02 04:05:53 +00:00
1ff4d75084 refactor(web): migrate anthropic quota notice storage (#36922)
Co-authored-by: lmlm <7487674+popsiclelmlm@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: yyh <92089059+lyzno1@users.noreply.github.com>
2026-06-02 04:05:15 +00:00
e35d23c3cb feat(api): Agent App type S1 — AppMode.AGENT + create flow + binding (#36829)
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-06-02 03:50:10 +00:00
e530e84772 refactor(web): migrate NOTE_SHOW_AUTHOR_STORAGE_KEY to useLocalStorage/useSetLocalStorage (#36915)
Signed-off-by: Cocoon-Break <54054995+kuishou68@users.noreply.github.com>
Co-authored-by: lingxiu58 <86288566+lingxiu58@users.noreply.github.com>
Co-authored-by: pojian68 <232320289+pojian68@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: yyh <92089059+lyzno1@users.noreply.github.com>
2026-06-02 03:44:47 +00:00
2257a4f1ef refactor(web): migrate workflow featured collapsed storage (#36918)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-06-02 03:40:59 +00:00
yyh
f465dc5090 fix(web): defer react-scan loader (#36920) 2026-06-02 03:34:55 +00:00
5c1cfe6ada chore: ignore .vinext (#36914) 2026-06-02 02:43:15 +00:00
8d401d84c7 chore(api): adjust migration timestamp metadata for a1b2c3d4e5f6 (#36910) 2026-06-02 02:22:47 +00:00
b74287c2ab chore: update deps (#36911) 2026-06-02 02:10:59 +00:00
c64d3e98c4 fix(tools): use short-lived sessions for icon lookups to prevent idle-in-transaction (#36903)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-06-02 01:59:10 +00:00
yyh
a3265f722e docs: add client state guidelines (#36900) 2026-06-01 11:44:50 +00:00
5658065b97 test: satisfy strict pyrefly for migrated container tests (#36791) 2026-06-01 11:09:40 +00:00
yyh
8fc2807194 feat(web): create system-features vertical (#36894) 2026-06-01 10:15:25 +00:00
fc7716704d chore: not request system-features for cloud edition (#36891)
Co-authored-by: yyh <yuanyouhuilyz@gmail.com>
2026-06-01 09:31:16 +00:00
71ffaacb58 fix(api): centralize remote file retrieval (#36399)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-06-01 09:25:08 +00:00
cfc1cf2b8c refactor(cli/http): replace ky with a self-contained HTTP client (#36711)
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-06-01 09:04:42 +00:00
yyh
055d9b9f0a refactor(web): migrate local storage hook usage (#36890) 2026-06-01 08:20:13 +00:00
yyh
21711bebeb refactor(web): migrate local storage access to react hook (#36888) 2026-06-01 07:57:54 +00:00
yyh
becccbf288 fix(web): read pnpm config env in standalone start (#36887) 2026-06-01 07:18:50 +00:00
86497045c9 feat: per-credential visibility control for plugin credentials (#35468)
Co-authored-by: Yang <yang@Yangs-MacBook-Pro.local>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-06-01 05:56:18 +00:00
687a177b24 chore: add override decorators to core repositories (#36885) 2026-06-01 05:24:21 +00:00
4a6d278354 refactor(web): mark workflow run props readonly (#36857) 2026-06-01 05:06:21 +00:00
yyh
7d69302e9f chore: update deps (#36884) 2026-06-01 04:28:04 +00:00
yyh
bcd573e560 fix(web): respect marketplace feature flag in model selector (#36883) 2026-06-01 04:11:58 +00:00
yyh
07c0c4e7b1 chore(web): remove TanStack devtools (#36882) 2026-06-01 03:57:50 +00:00
yyh
a8a2ca7b98 chore(cli): move eslint config into cli package (#36878) 2026-06-01 03:54:14 +00:00
de47d43b65 refactor: convert isinstance chains to match/case syntax (#36862)
Co-authored-by: krishkantiuj-ren <hiccup.cc.3@gmail.com>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
2026-06-01 03:45:19 +00:00
240912cef5 fix(api): preserve hierarchical estimate rules (#36852)
Co-authored-by: root <kinsonnee@gmail.com>
2026-06-01 03:16:09 +00:00
72e040ead3 docs: add security policy (#36873) 2026-06-01 09:58:32 +08:00
c0ee821d45 refactor: use absolute path for inter dir importing (#36822)
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
2026-06-01 01:32:16 +00:00
c7c3296572 fix: MCP search results include only MCP providers (#36871)
Co-authored-by: LL201314-II <you@example.com>
2026-06-01 01:13:51 +00:00
e7be04fd58 fix(api): dedup EndUser in plugin get_user by session_id for Reverse Invocation (#36742) 2026-06-01 00:57:29 +00:00
df6b5be50a refactor: convert isinstance chains to match/case (part 5) (#36503) 2026-05-31 15:08:59 +00:00
8e5f09091b refactor: convert if isinstance chains to match case (#36846)
Co-authored-by: duongynhi000005-oss <duongynhi000005-oss@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
2026-05-31 15:05:43 +00:00
0a3005701f refactor: inject current user into user-only controllers (#36754)
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
2026-05-31 15:03:15 +00:00
d8571ce965 refactor: convert isinstance chains to match/case (part 4) (#36274)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
2026-05-31 14:44:17 +00:00
f241ae25be fix: #36585 dep inject current user id (#36845)
Co-authored-by: duongynhi000005-oss <duongynhi000005-oss@users.noreply.github.com>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-31 14:37:39 +00:00
c6474a2a8b refactor: convert isinstance chains to match/case (part 8) (#36869) 2026-05-31 14:11:05 +00:00
yyh
480d05bc48 fix(web): prefetch workspace and guard routes with contract query (#36870) 2026-05-31 14:02:00 +00:00
yyh
f75725ccd9 feat(web): add server oRPC client (#36856) 2026-05-31 13:14:28 +00:00
yyh
2fe8c48255 refactor(web): scope workflow hotkeys (#36860)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-31 13:14:13 +00:00
ec5404cc9d chore: split trial models to a single API (#36796)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-31 13:09:13 +00:00
yyh
20f62b9919 fix(web): use generated current workspace query (#36843)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-31 13:04:18 +00:00
04f5555580 chore: split to single app_dsl_version API (#36864)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-31 12:13:44 +00:00
129af96c23 chore: add missing @override decorators to pipeline WorkflowAppGenerateResponseConverter (#36859)
Co-authored-by: krishkantiuj-ren <hiccup.cc.3@gmail.com>
2026-05-31 12:02:17 +00:00
df40960f5d chore: dep inject for model (#36750)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: WH-2099 <wh2099@pm.me>
2026-05-30 17:40:46 +00:00
599960024d refactor(api): migrate console/service_api.dataset.document to BaseModel (#36506)
Co-authored-by: WH-2099 <wh2099@pm.me>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-30 14:38:27 +00:00
yyh
6805d9bfc0 fix(auth): reset profile query after login (#36851) 2026-05-30 14:34:04 +00:00
928f888ef5 refactor(api): migrate console/service_api.dataset.segment to BaseModel (#36522)
Co-authored-by: WH-2099 <wh2099@pm.me>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-30 13:54:01 +00:00
yyh
f46c03460e fix(auth): avoid leaking request origin in refresh redirects (#36847) 2026-05-30 05:55:18 +00:00
0b60338ad5 chore: reuse injected SQLAlchemy sessions in app read paths (#36798) 2026-05-30 00:23:58 +00:00
yyh
91ac465982 fix(web): use default profile query cache (#36832) 2026-05-29 14:18:39 +00:00
yyh
9490d63c50 refactor(web): remove app initializer and move auth boot logic to route boundaries (#36818) 2026-05-29 12:26:34 +00:00
ae538ced47 chore: using single SSH_SCRIPT for saas dev (#36827) 2026-05-29 10:07:15 +00:00
487249728b fix: remove unnecessary # type: ignore comments (#24494) (#36825) 2026-05-29 09:41:32 +00:00
372a2e3e9c refactor: convert isinstance chains to match/case (part 7) (#35902) (#36826) 2026-05-29 09:40:33 +00:00
4939a9c33d refactor: add ts common style check for web and cli (#36823) 2026-05-29 09:26:32 +00:00
b6f92f1dc4 fix(cli): fix style (#36821) 2026-05-29 08:34:36 +00:00
ce276573a8 chore: deploy saas dev workflow (#36819) 2026-05-29 08:30:55 +00:00
5070cc9668 refactor(cli): optimize error handling in flag parsing (#36810) 2026-05-29 07:39:26 +00:00
a392a72960 chore: not store search tag condition in url (#36814) 2026-05-29 07:30:35 +00:00
30270b5c30 fix(device): surface SSO errors on /device and fix CLI null-account crash on external-SSO login (#36781) 2026-05-29 06:51:34 +00:00
24715a9570 chore: unified plugin status icon position (#36816) 2026-05-29 06:45:25 +00:00
c530a5d272 fix(api): validate annotation list pagination query (#36807)
Co-authored-by: root <kinsonnee@gmail.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-29 06:25:48 +00:00
418ee7398e fix: install failed plugin dose not show icon (#36811) 2026-05-29 06:07:43 +00:00
78f40c0d25 test: stabilize modal context pricing test (#36524) 2026-05-29 05:19:37 +00:00
2cc567c6a3 feat: add DTO for agent api (#36797)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-29 03:36:41 +00:00
a180ab19e4 chore: type check test container tests (#36790) 2026-05-29 01:54:25 +00:00
13eaa436e7 test: isolate Redis state in container tests (#36740) 2026-05-28 12:42:25 +00:00
3596d12e4c refactor(cli): use Store interface as token storage (#36726) 2026-05-28 10:02:51 +00:00
e8de10a3b5 feat(docker): add missing OPENAPI_* env vars to shared.env.example (#36752) 2026-05-28 08:52:03 +00:00
f5ab5e7eb3 fix: fix cannot extract elements from a scalar (#36769) 2026-05-28 07:31:36 +00:00
0c40e1c2a0 feat: add cross-environment app migration workflow (#36765)
Co-authored-by: XW <wei.xu1@wiz.ai>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-28 07:30:33 +00:00
c29d76757e docs(api): fix typo in vector migration docstrings (#36741) 2026-05-28 07:15:34 +00:00
91c1d3ad81 fix: handle null plugin badges (#36767) 2026-05-28 07:00:32 +00:00
57b02e341c refactor: add @override decorators to storage backend subclasses (#36406) (#36755) 2026-05-28 06:04:47 +00:00
b94ff65e9f fix(docker): copy dify-agent source into production stage (#36757) 2026-05-28 06:01:11 +00:00
678260e34e test: migrate workspace members tests to containers (#36738)
Co-authored-by: jamesrayammons <63717587+jamesrayammons@users.noreply.github.com>
2026-05-28 06:01:05 +00:00
739e34d08a fix(docker): pin web docker node version (#36756) 2026-05-28 05:25:41 +00:00
825fb9cb89 chore(codeowners): add Riskey for service API docs (#36731) 2026-05-28 05:06:12 +00:00
0e1f19a380 refactor: inject tenant id in tenant-only console handlers (#36751) 2026-05-28 03:50:28 +00:00
332d1ea533 chore: install dify-agent as editable (#36735) 2026-05-28 01:26:06 +00:00
9cdeffd0b1 feat(api): agent backend session lifecycle for workflow agent nodes (#36724)
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-27 15:00:21 +00:00
09ef785a20 test: move delete account task to container integration (#36733)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-27 13:58:58 +00:00
d2788d7aba feat(openapi): redesign auth pipeline with per-token-type routing (#36693)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-27 12:45:30 +00:00
yyh
cee90a4e82 feat(ui): add kbd primitive (#36729) 2026-05-27 11:58:13 +00:00
b2710b875b refactor: use match case for draft variable serialization (#36716)
Co-authored-by: unknown <EI05187@apwx.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-27 09:59:28 +00:00
6464255d33 fix: fix DocumentSegment.keywords can not a valid json (#36715) 2026-05-27 08:42:48 +00:00
yyh
50face5760 fix(ui): chip style (#36720) 2026-05-27 08:30:43 +00:00
b034449a0c refactor(api): migrate console/service_api.dataset.hit_testing to BaseModel (#36533) 2026-05-27 06:51:42 +00:00
a8d380bcaf refactor(cli): add kvstore and platform interface (#36687) 2026-05-27 05:30:12 +00:00
bee21c9f86 feat(api): support explicit TLS for Milvus vector store (#36265) 2026-05-27 05:17:27 +00:00
cab215e209 fix(web): add loading skeletons for tools and knowledge lists (#36712) 2026-05-27 05:07:40 +00:00
7ae4ca9a60 chore: add pnpm-managed node runtime (#36531) 2026-05-27 04:49:37 +00:00
d342ff1a1e refactor: convert isinstance chains to match/case (part 6) (#36705)
Signed-off-by: EvanYao826 <155432245+EvanYao826@users.noreply.github.com>
2026-05-27 04:09:01 +00:00
4384d8910e chore(api): polishhelp output for legacy-model-types migration script (#36707) 2026-05-27 03:29:08 +00:00
yyh
fc773b9f57 chore(web): restrict legacy service fetch imports (#36701) 2026-05-27 03:08:35 +00:00
6e1e0d9439 feat(openapi,cli): workspace switch + member management (#36651)
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-27 03:05:47 +00:00
5c5a6e83e5 feat(api): introduce model-type migration script (#36520) 2026-05-27 02:12:11 +00:00
yyh
dade318f00 fix(tools): improve custom collection modal scrolling (#36694) 2026-05-27 02:07:50 +00:00
ebff9a3639 feat: add agent backend plugin layer (#36686)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-27 02:03:51 +00:00
yyh
58b8fc21d4 fix(plugin): align local install modal spacing (#36689)
Co-authored-by: wangxiaolei <fatelei@gmail.com>
2026-05-27 01:12:57 +00:00
e0ad088657 chore: add App type annotations to api endpoints (#36675) 2026-05-26 15:35:48 +00:00
323b2b82e0 chore: add EndUser and App type annotations to api endpoints (#36677) 2026-05-26 09:43:00 +00:00
7d45335a32 fix(chat): close streaming LLM generator when stop response is triggered (#36227)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-26 09:23:26 +00:00
f5d664887b chore: backend feature api exclude_vector_space (#36642)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-26 08:50:54 +00:00
5aa24c25d9 chore: add InstalledApp type annotations to api endpoints (#36678) 2026-05-26 08:32:38 +00:00
eed8d659d1 refactor(api): migrate tenant/user via DI: apikey, extension, data_source_bearer, oauth_server (#36660) 2026-05-26 08:22:35 +00:00
59e99ee1ae refactor(api): migrate console tags to tenant/user via DI and improve tests (#36658)
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
2026-05-26 08:20:10 +00:00
yyh
533929d314 fix(dify-ui): align picker stories with Base UI (#36680) 2026-05-26 07:59:59 +00:00
fb07b43107 feat(api): Node Output Inspector service + 3 REST endpoints (Stage 4 §8) (#36644)
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-26 07:34:33 +00:00
0dad426101 chore: add dependabot to lts branch (#36424) 2026-05-26 07:08:08 +00:00
2a1df4de62 chore(deps): bump boto3 from 1.43.10 to 1.43.14 in /api in the storage group (#36595)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-05-26 06:47:59 +00:00
2b97f6c8c2 chore: inject tenant id in extension handlers (#36656) 2026-05-26 05:45:03 +00:00
75d6511284 chore: inject account context in file handlers (#36655) 2026-05-26 05:43:57 +00:00
fd059720e5 chore: inject tenant id in feature handlers (#36654) 2026-05-26 05:36:02 +00:00
2a5f7bb1aa chore: inject current user in explore message handlers (#36652) 2026-05-26 05:31:51 +00:00
0f06aa2fdd feat(dify-agent): sync agent progress (#36633)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-26 03:14:10 +00:00
yyh
884e2b864b feat(dify-ui): add textarea primitive (#36547) 2026-05-26 02:33:32 +00:00
a728e0ac69 feat: adding dify cli (#36348)
Co-authored-by: GareArc <garethcxy@dify.ai>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: L1nSn0w <l1nsn0w@qq.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-authored-by: gigglewang <gigglewang@dify.ai>
Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
Co-authored-by: Xiyuan Chen <52963600+GareArc@users.noreply.github.com>
2026-05-26 01:12:36 +00:00
7d464d014c fix: remove unused datasource_parameters from Notion pre-import query (#36627) 2026-05-26 01:05:30 +00:00
0ce0127e7e fix(security): reject path traversal sequences before plugin daemon forward (GHSA-gvc6-fh3x-89xh) (#35796)
Co-authored-by: Ido Shani <ido@zafran.io>
Co-authored-by: -LAN- <laipz8200@outlook.com>
2026-05-25 16:17:39 +00:00
25da7ae0d9 chore: dep inject for sql session (#36545)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: WH-2099 <wh2099@pm.me>
2026-05-25 14:24:58 +00:00
4d6f8eba2a fix: normalize summary_index_setting None to fix preview error (#36626) 2026-05-25 13:42:45 +00:00
87268f0662 chore: inject current user in console handlers (#36628) 2026-05-25 13:14:08 +00:00
135e01930b chore: example of current user id dep injection (#36588)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-25 11:31:40 +00:00
yyh
fe86fa31ec fix: normalize app icon picker dialog state (#36621) 2026-05-25 10:39:52 +00:00
b1f0a11d84 feat: output declaration and inspector (#36618)
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-25 10:08:58 +00:00
fbfb4b3a00 chore: use dify_config.BILLING_ENABLED (#36619) 2026-05-25 09:41:01 +00:00
3a467d1d63 fix: member invite limits with dedup, locking, and accurate new-member counting (#36512) 2026-05-25 08:58:42 +00:00
yyh
23539c5bcc feat(dify-ui): add status and progress primitives (#36615) 2026-05-25 08:31:52 +00:00
9ddd98a265 fix(api): preserve dataset nested null shapes (#36611)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: wangxiaolei <fatelei@gmail.com>
2026-05-25 08:06:33 +00:00
yyh
ecfee2f072 fix: center align slider thumb (#36614) 2026-05-25 07:55:30 +00:00
345ba80942 fix: type mismatches (route says uuid: but handler says str) (#36612) 2026-05-25 07:33:32 +00:00
e617435d03 fix: replace .distinct() with .group_by(Conversation.id) for PostgreSQL JSON compatibility (#36610)
Co-authored-by: cocoon <kuishou68@users.noreply.github.com>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-25 07:15:24 +00:00
5f7eb7bde9 feat: add workflow_version to workflow_agent_node_bindings (#36603)
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-25 06:26:19 +00:00
yyh
eb41c9b769 chore: upgrade dependencies (#36606) 2026-05-25 05:42:35 +00:00
yyh
8876efb419 refactor(dify-ui): rename toggle group to segmented control (#36605) 2026-05-25 04:57:39 +00:00
adb14d23de feat(dify-agent): add history layer and structural output layer (#36600) 2026-05-25 04:28:17 +00:00
6f1623e02a chore(i18n): sync translations with en-US (#36599)
Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com>
Co-authored-by: yyh <92089059+lyzno1@users.noreply.github.com>
2026-05-25 03:06:45 +00:00
67d99723ea fix: External retrieval model response rejects empty score threshold bug (#36577)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-25 03:01:06 +00:00
639e12a306 fix: request /api/datasets raise exception (#36591)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-25 02:27:54 +00:00
yyh
ed17b6161f refactor(dify-ui): refine switch contract (#36539) 2026-05-25 02:22:43 +00:00
yyh
baf0cf8e4e chore(web): remove select-auto in body (#36554) 2026-05-25 02:22:39 +00:00
yyh
1e9c94b788 fix(web): clean up header logo accessibility (#36567) 2026-05-25 02:22:34 +00:00
yyh
ffd336cfe8 feat: add and unify pagination components across UI and app surfaces (#36569)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-25 02:22:31 +00:00
1953 changed files with 131409 additions and 22296 deletions

View File

@ -1 +0,0 @@
../../.agents/skills/frontend-query-mutation

View File

@ -0,0 +1 @@
../../.agents/skills/how-to-write-component

15
.dockerignore Normal file
View File

@ -0,0 +1,15 @@
**/node_modules
**/.pnpm-store
**/dist
**/.next
**/.turbo
**/.cache
**/__pycache__
**/*.pyc
**/.mypy_cache
**/.ruff_cache
.git
.github
*.md
!web/README.md
!api/README.md

4
.gitattributes vendored
View File

@ -5,3 +5,7 @@
# them.
*.sh text eol=lf
# Codegen output must stay byte-identical across platforms so
# `pnpm tree:check` in CI does not trip on CRLF rewrites.
*.generated.ts text eol=lf

5
.github/CODEOWNERS vendored
View File

@ -18,6 +18,10 @@
# Docs
/docs/ @crazywoola
# CLI
/cli/ @langgenius/maintainers
/.github/workflows/cli-tests.yml @langgenius/maintainers
# Backend (default owner, more specific rules below will override)
/api/ @QuantumGhost
@ -162,6 +166,7 @@
# Frontend - App - API Documentation
/web/app/components/develop/ @JzoNgKVO @iamjoel
/web/app/components/develop/template/*.mdx @JzoNgKVO @iamjoel @RiskeyL
# Frontend - App - Logs and Annotations
/web/app/components/app/workflow-log/ @JzoNgKVO @iamjoel

111
.github/dependabot.yml vendored
View File

@ -110,3 +110,114 @@ updates:
github-actions-dependencies:
patterns:
- "*"
- package-ecosystem: "uv"
directory: "/api"
target-branch: "lts/1.13.x"
open-pull-requests-limit: 10
schedule:
interval: "weekly"
groups:
flask:
patterns:
- "flask"
- "flask-*"
- "werkzeug"
- "gunicorn"
google:
patterns:
- "google-*"
- "googleapis-*"
opentelemetry:
patterns:
- "opentelemetry-*"
pydantic:
patterns:
- "pydantic"
- "pydantic-*"
llm:
patterns:
- "langfuse"
- "langsmith"
- "litellm"
- "mlflow*"
- "opik"
- "weave*"
- "arize*"
- "tiktoken"
- "transformers"
database:
patterns:
- "sqlalchemy"
- "psycopg2*"
- "psycogreen"
- "redis*"
- "alembic*"
storage:
patterns:
- "boto3*"
- "botocore*"
- "azure-*"
- "bce-*"
- "cos-python-*"
- "esdk-obs-*"
- "google-cloud-storage"
- "opendal"
- "oss2"
- "supabase*"
- "tos*"
vdb:
patterns:
- "alibabacloud*"
- "chromadb"
- "clickhouse-*"
- "clickzetta-*"
- "couchbase"
- "elasticsearch"
- "opensearch-py"
- "oracledb"
- "pgvect*"
- "pymilvus"
- "pymochow"
- "pyobvector"
- "qdrant-client"
- "intersystems-*"
- "tablestore"
- "tcvectordb"
- "tidb-vector"
- "upstash-*"
- "volcengine-*"
- "weaviate-*"
- "xinference-*"
- "mo-vector"
- "mysql-connector-*"
dev:
patterns:
- "coverage"
- "dotenv-linter"
- "faker"
- "lxml-stubs"
- "basedpyright"
- "ruff"
- "pytest*"
- "types-*"
- "boto3-stubs"
- "hypothesis"
- "pandas-stubs"
- "scipy-stubs"
- "import-linter"
- "celery-types"
- "mypy*"
- "pyrefly"
python-packages:
patterns:
- "*"
- package-ecosystem: "github-actions"
directory: "/"
target-branch: "lts/1.13.x"
open-pull-requests-limit: 5
schedule:
interval: "weekly"
groups:
github-actions-dependencies:
patterns:
- "*"

View File

@ -51,6 +51,15 @@ jobs:
with:
files: |
api/**
- name: Check dify-agent inputs
if: github.event_name != 'merge_group'
id: dify-agent-changes
uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
with:
files: |
dify-agent/**/*.py
dify-agent/pyproject.toml
dify-agent/uv.lock
- if: github.event_name != 'merge_group'
uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0
with:
@ -76,6 +85,17 @@ jobs:
# Format code
uv run ruff format ..
- if: github.event_name != 'merge_group' && steps.dify-agent-changes.outputs.any_changed == 'true'
run: |
cd dify-agent
uv sync --dev
# fmt first to avoid line too long
uv run ruff format .
# Fix lint errors
uv run ruff check --fix .
# Format code
uv run ruff format .
- name: count migration progress
if: github.event_name != 'merge_group' && steps.api-changes.outputs.any_changed == 'true'
run: |

88
.github/workflows/cli-release.yml vendored Normal file
View File

@ -0,0 +1,88 @@
name: CLI Release
on:
workflow_dispatch:
push:
tags:
- 'difyctl-v*'
concurrency:
group: cli-release-${{ github.ref }}
cancel-in-progress: true
jobs:
release:
name: build standalone binaries (all targets)
runs-on: depot-ubuntu-24.04
if: github.repository == 'langgenius/dify'
permissions:
contents: write
defaults:
run:
shell: bash
working-directory: ./cli
steps:
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
fetch-depth: 0
- name: Setup web environment
uses: ./.github/actions/setup-web
- name: Setup Bun
uses: oven-sh/setup-bun@4bc047ad259df6fc24a6c9b0f9a0cb08cf17fbe5 # v2.0.2
with:
bun-version: latest
- name: Read cli/package.json
id: manifest
run: |
version=$(node -p "require('./package.json').version")
channel=$(node -p "require('./package.json').difyctl.channel")
minDify=$(node -p "require('./package.json').difyctl.compat.minDify")
maxDify=$(node -p "require('./package.json').difyctl.compat.maxDify")
{
echo "version=$version"
echo "channel=$channel"
echo "minDify=$minDify"
echo "maxDify=$maxDify"
} >> "$GITHUB_OUTPUT"
- name: Validate manifest
run: scripts/release-validate-manifest.sh
- name: Install cross-arch native prebuilds
# Re-installs node_modules with every @napi-rs/keyring platform variant
# so `bun build --compile` can embed the right .node into each target.
working-directory: ./
run: NPM_CONFIG_USERCONFIG="$PWD/cli/scripts/cross-arch.npmrc" pnpm install --frozen-lockfile
- name: Compile standalone binaries (all targets)
env:
CLI_VERSION: ${{ steps.manifest.outputs.version }}
DIFYCTL_CHANNEL: ${{ steps.manifest.outputs.channel }}
DIFYCTL_MIN_DIFY: ${{ steps.manifest.outputs.minDify }}
DIFYCTL_MAX_DIFY: ${{ steps.manifest.outputs.maxDify }}
run: |
DIFYCTL_COMMIT="$(git rev-parse HEAD)" \
DIFYCTL_BUILD_DATE="$(git log -1 --format=%cI HEAD)" \
pnpm build:bin
- name: Generate sha256 checksum file
env:
CLI_VERSION: ${{ steps.manifest.outputs.version }}
run: scripts/release-write-checksums.sh
- name: Publish GitHub Release
uses: softprops/action-gh-release@72f2c25fcb47643c292f7107632f7a47c1df5cd8 # v2.3.2
with:
tag_name: difyctl-v${{ steps.manifest.outputs.version }}
name: difyctl ${{ steps.manifest.outputs.version }}
prerelease: ${{ steps.manifest.outputs.channel != 'stable' }}
generate_release_notes: true
fail_on_unmatched_files: true
files: |
cli/dist/bin/difyctl-v*

60
.github/workflows/cli-smoke.yml vendored Normal file
View File

@ -0,0 +1,60 @@
name: CLI Smoke (live dify)
on:
workflow_dispatch:
inputs:
dify_version:
description: "Dify image tag to test against (e.g. 1.7.0)"
type: string
required: true
cli_ref:
description: "Git ref to build the cli from (default: current branch)"
type: string
required: false
permissions:
contents: read
jobs:
smoke:
runs-on: ubuntu-latest
timeout-minutes: 30
defaults:
run:
shell: bash
steps:
- name: Checkout cli ref
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
ref: ${{ inputs.cli_ref || github.ref }}
persist-credentials: false
- name: Setup web environment
uses: ./.github/actions/setup-web
- name: Bring up dify
env:
DIFY_VERSION: ${{ inputs.dify_version }}
run: |
cd docker
cp .env.example .env
DIFY_API_IMAGE_TAG="$DIFY_VERSION" \
DIFY_WEB_IMAGE_TAG="$DIFY_VERSION" \
docker compose up -d api worker web db redis
for i in $(seq 1 60); do
if curl -fsS http://localhost:5001/health >/dev/null 2>&1; then
echo "dify api ready after ${i}s"
break
fi
sleep 1
done
- name: Run smoke against live dify
working-directory: ./cli
run: pnpm exec tsx scripts/run-smoke.ts --base-url http://localhost:5001
- name: Dump dify logs on failure
if: failure()
run: |
cd docker
docker compose logs api worker web --tail=200

50
.github/workflows/cli-tests.yml vendored Normal file
View File

@ -0,0 +1,50 @@
name: CLI Tests
on:
workflow_call:
secrets:
CODECOV_TOKEN:
required: false
permissions:
contents: read
concurrency:
group: cli-tests-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
test:
name: CLI Tests (${{ matrix.os }})
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [depot-ubuntu-24.04, windows-latest, macos-latest]
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
defaults:
run:
shell: bash
working-directory: ./cli
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Setup web environment
uses: ./.github/actions/setup-web
- name: CI pipeline (typecheck, lint, coverage, build)
run: pnpm ci
- name: Report coverage
if: ${{ env.CODECOV_TOKEN != '' && matrix.os == 'depot-ubuntu-24.04' }}
uses: codecov/codecov-action@57e3a136b779b570ffcdbf80b3bdc90e7fab3de2 # v6.0.0
with:
directory: cli/coverage
flags: cli
env:
CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }}

View File

@ -1,4 +1,4 @@
name: Deploy Agent Dev
name: Deploy SaaS
permissions:
contents: read
@ -7,7 +7,7 @@ on:
workflow_run:
workflows: ["Build and Push API & Web"]
branches:
- "deploy/agent-dev"
- "deploy/saas"
types:
- completed
@ -16,13 +16,13 @@ jobs:
runs-on: depot-ubuntu-24.04
if: |
github.event.workflow_run.conclusion == 'success' &&
github.event.workflow_run.head_branch == 'deploy/agent-dev'
github.event.workflow_run.head_branch == 'deploy/saas'
steps:
- name: Deploy to server
uses: appleboy/ssh-action@0ff4204d59e8e51228ff73bce53f80d53301dee2 # v1.2.5
with:
host: ${{ secrets.AGENT_DEV_SSH_HOST }}
host: ${{ secrets.SAAS_DEV_SSH_HOST }}
username: ${{ secrets.SSH_USER }}
key: ${{ secrets.SSH_PRIVATE_KEY }}
script: |
${{ vars.SSH_SCRIPT || secrets.SSH_SCRIPT }}
${{ vars.SSH_SCRIPT_SAAS_DEV || secrets.SSH_SCRIPT_SAAS_DEV }}

View File

@ -42,6 +42,7 @@ jobs:
runs-on: depot-ubuntu-24.04
outputs:
api-changed: ${{ steps.changes.outputs.api }}
cli-changed: ${{ steps.changes.outputs.cli }}
e2e-changed: ${{ steps.changes.outputs.e2e }}
web-changed: ${{ steps.changes.outputs.web }}
vdb-changed: ${{ steps.changes.outputs.vdb }}
@ -62,6 +63,18 @@ jobs:
- 'docker/generate_docker_compose'
- 'docker/ssrf_proxy/**'
- 'docker/volumes/sandbox/conf/**'
cli:
- 'cli/**'
- 'packages/tsconfig/**'
- 'package.json'
- 'pnpm-lock.yaml'
- 'pnpm-workspace.yaml'
- 'eslint.config.mjs'
- '.npmrc'
- '.nvmrc'
- '.github/workflows/cli-tests.yml'
- '.github/workflows/cli-docker-build.yml'
- '.github/actions/setup-web/**'
web:
- 'web/**'
- 'packages/**'
@ -184,6 +197,66 @@ jobs:
echo "API tests were not required, but the skip job finished with result: $SKIP_RESULT" >&2
exit 1
cli-tests-run:
name: Run CLI Tests
needs:
- pre_job
- check-changes
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.cli-changed == 'true'
uses: ./.github/workflows/cli-tests.yml
secrets: inherit
cli-tests-skip:
name: Skip CLI Tests
needs:
- pre_job
- check-changes
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.cli-changed != 'true'
runs-on: depot-ubuntu-24.04
steps:
- name: Report skipped CLI tests
run: echo "No CLI-related changes detected; skipping CLI tests."
cli-tests:
name: CLI Tests
if: ${{ always() }}
needs:
- pre_job
- check-changes
- cli-tests-run
- cli-tests-skip
runs-on: depot-ubuntu-24.04
steps:
- name: Finalize CLI Tests status
env:
SHOULD_SKIP_WORKFLOW: ${{ needs.pre_job.outputs.should_skip }}
TESTS_CHANGED: ${{ needs.check-changes.outputs.cli-changed }}
RUN_RESULT: ${{ needs.cli-tests-run.result }}
SKIP_RESULT: ${{ needs.cli-tests-skip.result }}
run: |
if [[ "$SHOULD_SKIP_WORKFLOW" == 'true' ]]; then
echo "CLI tests were skipped because this workflow run duplicated a successful or newer run."
exit 0
fi
if [[ "$TESTS_CHANGED" == 'true' ]]; then
if [[ "$RUN_RESULT" == 'success' ]]; then
echo "CLI tests ran successfully."
exit 0
fi
echo "CLI tests were required but finished with result: $RUN_RESULT" >&2
exit 1
fi
if [[ "$SKIP_RESULT" == 'success' ]]; then
echo "CLI tests were skipped because no CLI-related files changed."
exit 0
fi
echo "CLI tests were not required, but the skip job finished with result: $SKIP_RESULT" >&2
exit 1
web-tests-run:
name: Run Web Tests
needs:

View File

@ -95,6 +95,51 @@ jobs:
if: steps.changed-files.outputs.any_changed == 'true'
uses: ./.github/actions/setup-web
- name: Web tsslint
if: steps.changed-files.outputs.any_changed == 'true'
env:
NODE_OPTIONS: --max-old-space-size=4096
run: vp run lint:tss
- name: Web dead code check
if: steps.changed-files.outputs.any_changed == 'true'
run: vp run knip
ts-common-style:
name: TS Common
runs-on: depot-ubuntu-24.04
permissions:
checks: write
pull-requests: read
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
with:
files: |
web/**
cli/**
e2e/**
sdks/nodejs-client/**
packages/**
package.json
pnpm-lock.yaml
pnpm-workspace.yaml
.nvmrc
eslint.config.mjs
.github/workflows/style.yml
.github/actions/setup-web/**
- name: Setup web environment
if: steps.changed-files.outputs.any_changed == 'true'
uses: ./.github/actions/setup-web
- name: Restore ESLint cache
if: steps.changed-files.outputs.any_changed == 'true'
id: eslint-cache-restore
@ -105,28 +150,14 @@ jobs:
restore-keys: |
${{ runner.os }}-eslint-${{ hashFiles('pnpm-lock.yaml', 'eslint.config.mjs', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-
- name: Web style check
- name: Style check
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: .
run: vp run lint:ci
- name: Web tsslint
- name: Type check
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
env:
NODE_OPTIONS: --max-old-space-size=4096
run: vp run lint:tss
- name: Web type check
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: .
run: vp run type-check
- name: Web dead code check
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: vp run knip
- name: Save ESLint cache
if: steps.changed-files.outputs.any_changed == 'true' && success() && steps.eslint-cache-restore.outputs.cache-hit != 'true'
uses: actions/cache/save@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5

12
.gitignore vendored
View File

@ -115,6 +115,12 @@ venv/
ENV/
env.bak/
venv.bak/
# cli/ has a src/env/ module (DIFY_* registry) — don't treat it as a venv
!/cli/src/env/
!/cli/src/commands/env/
# cli/scripts/lib/ holds TS build helpers (resolve-buildinfo etc.) — don't treat as Python lib/
!/cli/scripts/lib/
.conda/
# Spyder project settings
@ -247,8 +253,12 @@ scripts/stress-test/reports/
# settings
*.local.json
*.local.md
*.local.toml
# Code Agent Folder
.qoder/*
.context/*
.context/
.eslintcache
# Vitest local reports
web/.vitest-reports/

27
SECURITY.md Normal file
View File

@ -0,0 +1,27 @@
# Security Policy
## Reporting a Vulnerability
If you believe you have found a security vulnerability in Dify, please report it privately through GitHub Security Advisories:
https://github.com/langgenius/dify/security/advisories/new
Please do not report security vulnerabilities through public GitHub issues, discussions, or pull requests.
When submitting a report, include as much relevant information as you can safely provide, such as:
- A description of the vulnerability
- Steps to reproduce, if safe to share privately
- Affected components, versions, or configurations
- Potential impact
- Any suggested mitigation or fix, if available
The maintainers will review reports submitted through GitHub Security Advisories and coordinate follow-up there.
## Public Disclosure
Please avoid publicly disclosing details of a vulnerability until it has been reviewed and, where appropriate, a fix or mitigation has been made available.
## Security Updates
Security fixes may be released through normal project releases or other appropriate channels. Users are encouraged to keep Dify deployments up to date.

View File

@ -17,7 +17,7 @@ FROM base AS packages
RUN apt-get update \
&& apt-get install -y --no-install-recommends \
# basic environment
g++ \
git g++ \
# for building gmpy2
libmpfr-dev libmpc-dev
@ -27,7 +27,7 @@ COPY api/providers ./providers
COPY dify-agent/pyproject.toml dify-agent/README.md /app/dify-agent/
COPY dify-agent/src /app/dify-agent/src
# Trust the checked-in lock during image builds; local path sources are copied from the repository context.
RUN uv sync --frozen --no-dev
RUN uv sync --frozen --no-dev --no-editable
# production stage
FROM base AS production

View File

@ -159,6 +159,7 @@ def initialize_extensions(app: DifyApp):
ext_logstore,
ext_mail,
ext_migrate,
ext_oauth_bearer,
ext_orjson,
ext_otel,
ext_proxy_fix,
@ -203,6 +204,7 @@ def initialize_extensions(app: DifyApp):
ext_enterprise_telemetry,
ext_request_logging,
ext_session_factory,
ext_oauth_bearer,
]
for ext in extensions:
short_name = ext.__name__.split(".")[-1]
@ -221,10 +223,11 @@ def initialize_extensions(app: DifyApp):
def create_migrations_app() -> DifyApp:
app = create_flask_app_with_configs()
from extensions import ext_database, ext_migrate
from extensions import ext_commands, ext_database, ext_migrate
# Initialize only required extensions
ext_database.init_app(app)
ext_migrate.init_app(app)
ext_commands.init_app(app)
return app

View File

@ -30,21 +30,27 @@ from clients.agent_backend.factory import create_agent_backend_run_client
from clients.agent_backend.fake_client import FakeAgentBackendRunClient, FakeAgentBackendScenario
from clients.agent_backend.request_builder import (
AGENT_SOUL_PROMPT_LAYER_ID,
DIFY_PLUGIN_CONTEXT_LAYER_ID,
DIFY_EXECUTION_CONTEXT_LAYER_ID,
DIFY_PLUGIN_TOOLS_LAYER_ID,
WORKFLOW_NODE_JOB_PROMPT_LAYER_ID,
WORKFLOW_USER_PROMPT_LAYER_ID,
AgentBackendAgentAppRunInput,
AgentBackendModelConfig,
AgentBackendOutputConfig,
AgentBackendRunRequestBuilder,
AgentBackendWorkflowNodeRunInput,
CleanupLayerSpec,
extract_cleanup_layer_specs,
redact_for_agent_backend_log,
)
__all__ = [
"AGENT_SOUL_PROMPT_LAYER_ID",
"DIFY_PLUGIN_CONTEXT_LAYER_ID",
"DIFY_EXECUTION_CONTEXT_LAYER_ID",
"DIFY_PLUGIN_TOOLS_LAYER_ID",
"WORKFLOW_NODE_JOB_PROMPT_LAYER_ID",
"WORKFLOW_USER_PROMPT_LAYER_ID",
"AgentBackendAgentAppRunInput",
"AgentBackendError",
"AgentBackendHTTPError",
"AgentBackendInternalEvent",
@ -66,9 +72,11 @@ __all__ = [
"AgentBackendTransportError",
"AgentBackendValidationError",
"AgentBackendWorkflowNodeRunInput",
"CleanupLayerSpec",
"DifyAgentBackendRunClient",
"FakeAgentBackendRunClient",
"FakeAgentBackendScenario",
"create_agent_backend_run_client",
"extract_cleanup_layer_specs",
"redact_for_agent_backend_log",
]

View File

@ -20,6 +20,8 @@ from dify_agent.protocol import (
RunEvent,
RunFailedEvent,
RunFailedEventData,
RunPausedEvent,
RunPausedEventData,
RunStartedEvent,
RunStatusResponse,
RunSucceededEvent,
@ -34,6 +36,7 @@ class FakeAgentBackendScenario(StrEnum):
SUCCESS = "success"
FAILED = "failed"
PAUSED = "paused"
class FakeAgentBackendRunClient:
@ -89,6 +92,13 @@ class FakeAgentBackendRunClient:
updated_at=_FIXED_TIME,
error="fake failure",
)
case FakeAgentBackendScenario.PAUSED:
return RunStatusResponse(
run_id=run_id,
status="paused",
created_at=_FIXED_TIME,
updated_at=_FIXED_TIME,
)
def _events(self, run_id: str) -> tuple[RunEvent, ...]:
match self.scenario:
@ -115,3 +125,17 @@ class FakeAgentBackendRunClient:
data=RunFailedEventData(error="fake failure", reason="unit_test"),
),
)
case FakeAgentBackendScenario.PAUSED:
return (
RunStartedEvent(id="1-0", run_id=run_id, created_at=_FIXED_TIME),
RunPausedEvent(
id="2-0",
run_id=run_id,
created_at=_FIXED_TIME,
data=RunPausedEventData(
reason="human_input_required",
message="Agent requested human input.",
session_snapshot=CompositorSessionSnapshot(layers=[]),
),
),
)

View File

@ -4,29 +4,38 @@ This module is intentionally an adapter, not a wire DTO package. The emitted
object is always ``dify_agent.protocol.CreateRunRequest`` so the Agent backend
protocol has a single owner. API-only context such as Agent Soul vs workflow job
prompt is preserved in layer names and metadata until the dedicated product
schemas land in later phases.
schemas land in later phases. Dify-owned execution identifiers are emitted as an
explicit ``dify.execution_context`` layer so the run request stays fully
composition-driven.
"""
from __future__ import annotations
from typing import ClassVar
from typing import ClassVar, cast
from agenton.compositor import CompositorSessionSnapshot
from agenton.compositor.schemas import LayerSessionSnapshot
from agenton.layers import ExitIntent
from agenton_collections.layers.plain import PLAIN_PROMPT_LAYER_TYPE_ID, PromptLayerConfig
from agenton_collections.layers.pydantic_ai import PYDANTIC_AI_HISTORY_LAYER_TYPE_ID
from dify_agent.layers.dify_plugin import (
DIFY_PLUGIN_LAYER_TYPE_ID,
DIFY_PLUGIN_LLM_LAYER_TYPE_ID,
DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID,
DifyPluginCredentialValue,
DifyPluginLayerConfig,
DifyPluginLLMLayerConfig,
DifyPluginToolsLayerConfig,
)
from dify_agent.layers.execution_context import (
DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID,
DifyExecutionContextLayerConfig,
)
from dify_agent.layers.output import DIFY_OUTPUT_LAYER_TYPE_ID, DifyOutputLayerConfig
from dify_agent.layers.shell import DIFY_SHELL_LAYER_TYPE_ID, DifyShellLayerConfig
from dify_agent.protocol import (
DIFY_AGENT_HISTORY_LAYER_ID,
DIFY_AGENT_MODEL_LAYER_ID,
DIFY_AGENT_OUTPUT_LAYER_ID,
CreateRunRequest,
ExecutionContext,
LayerExitSignals,
RunComposition,
RunLayerSpec,
@ -37,17 +46,96 @@ from pydantic import BaseModel, ConfigDict, Field, JsonValue, field_validator
AGENT_SOUL_PROMPT_LAYER_ID = "agent_soul_prompt"
WORKFLOW_NODE_JOB_PROMPT_LAYER_ID = "workflow_node_job_prompt"
WORKFLOW_USER_PROMPT_LAYER_ID = "workflow_user_prompt"
DIFY_PLUGIN_CONTEXT_LAYER_ID = "plugin"
AGENT_APP_USER_PROMPT_LAYER_ID = "agent_app_user_prompt"
DIFY_EXECUTION_CONTEXT_LAYER_ID = "execution_context"
DIFY_PLUGIN_TOOLS_LAYER_ID = "tools"
DIFY_SHELL_LAYER_ID = "shell"
# Layer types that hold credentials in their per-run config. These are excluded
# from the cleanup-replay composition (and from the snapshot that is sent with
# the cleanup request) because we deliberately do not persist plaintext
# credentials between runs.
_CLEANUP_EXCLUDED_LAYER_TYPES: tuple[str, ...] = (
DIFY_PLUGIN_LLM_LAYER_TYPE_ID,
DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID,
)
class CleanupLayerSpec(BaseModel):
"""One layer node replayed by an Agent backend cleanup-only run.
Cleanup composition cannot include credential-bearing plugin layers, so we
persist only the non-plugin layer specs together with the original config.
Storing the config (rather than just ``name``/``type``) means cleanup does
not depend on the original build-time inputs being re-derivable.
"""
name: str
type: str
deps: dict[str, str] = Field(default_factory=dict)
metadata: dict[str, JsonValue] = Field(default_factory=dict)
config: JsonValue = None
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
def extract_cleanup_layer_specs(composition: RunComposition) -> list[CleanupLayerSpec]:
"""Project the in-flight composition into the persistable cleanup spec list.
Plugin layers are intentionally dropped (their configs hold credentials and
the lifecycle contract says "do not include an LLM layer" during cleanup).
The filtered names must later drive snapshot filtering so the agenton
compositor's name-order check still passes for the cleanup run.
"""
excluded = set(_CLEANUP_EXCLUDED_LAYER_TYPES)
specs: list[CleanupLayerSpec] = []
for layer in composition.layers:
if layer.type in excluded:
continue
config_value: JsonValue = None
if isinstance(layer.config, BaseModel):
config_value = layer.config.model_dump(mode="json", warnings=False)
else:
# ``RunLayerSpec.config`` is typed as ``LayerConfigInput`` which
# includes ``Mapping[str, object] | bytes``. In the cleanup-replay
# pipeline our builder only emits BaseModel-derived configs or
# ``None``, so the wider input alias narrows safely here.
config_value = cast(JsonValue, layer.config)
specs.append(
CleanupLayerSpec(
name=layer.name,
type=layer.type,
deps=dict(layer.deps),
metadata=dict(layer.metadata),
config=config_value,
)
)
return specs
def _filter_snapshot_to_specs(
snapshot: CompositorSessionSnapshot,
specs: list[CleanupLayerSpec],
) -> CompositorSessionSnapshot:
"""Keep only snapshot layers whose names appear in the cleanup spec list.
The agenton compositor rejects a snapshot whose layer-name sequence does
not match the active composition exactly. Cleanup-replay drops plugin
layers, so we must drop the matching snapshot entries here.
"""
kept_names = {spec.name for spec in specs}
filtered_layers: list[LayerSessionSnapshot] = [layer for layer in snapshot.layers if layer.name in kept_names]
if len(filtered_layers) == len(snapshot.layers):
return snapshot
return CompositorSessionSnapshot(schema_version=snapshot.schema_version, layers=filtered_layers)
class AgentBackendModelConfig(BaseModel):
"""API-side model/plugin selection before it is converted to Dify Agent layers."""
tenant_id: str
plugin_id: str
model_provider: str
model: str
user_id: str | None = None
credentials: dict[str, DifyPluginCredentialValue] = Field(default_factory=dict)
model_settings: dict[str, JsonValue] = Field(default_factory=dict)
@ -55,10 +143,14 @@ class AgentBackendModelConfig(BaseModel):
class AgentBackendOutputConfig(BaseModel):
"""API-side structured output declaration for the conventional output layer."""
"""API-side structured output declaration for the conventional output layer.
The structured-output tool name is fixed to ``final_output`` inside
``dify_agent.layers.output`` so callers only control the JSON Schema plus
optional description/strictness metadata.
"""
json_schema: dict[str, JsonValue]
name: str = "final_result"
description: str | None = None
strict: bool | None = None
@ -69,15 +161,21 @@ class AgentBackendWorkflowNodeRunInput(BaseModel):
"""Inputs needed to build the first workflow-node-oriented Agent backend run request."""
model: AgentBackendModelConfig
execution_context: ExecutionContext
execution_context: DifyExecutionContextLayerConfig
workflow_node_job_prompt: str
user_prompt: str
agent_soul_prompt: str | None = None
purpose: RunPurpose = "workflow_node"
idempotency_key: str | None = None
output: AgentBackendOutputConfig | None = None
tools: DifyPluginToolsLayerConfig | None = None
# Inject the sandboxed shell layer (dify.shell). Requires the agent backend
# to be wired with a shellctl entrypoint; see configs AGENT_SHELL_ENABLED.
include_shell: bool = False
shell_config: DifyShellLayerConfig | None = None
session_snapshot: CompositorSessionSnapshot | None = None
suspend_on_exit: bool = False
include_history: bool = True
suspend_on_exit: bool = True
metadata: dict[str, JsonValue] = Field(default_factory=dict)
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
@ -90,9 +188,198 @@ class AgentBackendWorkflowNodeRunInput(BaseModel):
return value
class AgentBackendAgentAppRunInput(BaseModel):
"""Inputs to build one Agent App conversation-turn run request.
Unlike the workflow-node input there is no workflow-node-job prompt and no
previous-node context: the user prompt is the chat message, and multi-turn
continuity comes from ``session_snapshot`` + the history layer keyed by the
conversation.
"""
model: AgentBackendModelConfig
execution_context: DifyExecutionContextLayerConfig
user_prompt: str
agent_soul_prompt: str | None = None
purpose: RunPurpose = "agent_app"
idempotency_key: str | None = None
output: AgentBackendOutputConfig | None = None
tools: DifyPluginToolsLayerConfig | None = None
# Inject the sandboxed shell layer (dify.shell). Requires the agent backend
# to be wired with a shellctl entrypoint; see configs AGENT_SHELL_ENABLED.
include_shell: bool = False
shell_config: DifyShellLayerConfig | None = None
session_snapshot: CompositorSessionSnapshot | None = None
include_history: bool = True
suspend_on_exit: bool = True
metadata: dict[str, JsonValue] = Field(default_factory=dict)
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
@field_validator("user_prompt")
@classmethod
def _reject_blank_prompt(cls, value: str) -> str:
if not value.strip():
raise ValueError("prompt must not be blank")
return value
class AgentBackendRunRequestBuilder:
"""Converts API product state into the public ``dify-agent`` run protocol."""
def build_for_agent_app(self, run_input: AgentBackendAgentAppRunInput) -> CreateRunRequest:
"""Build an Agent App conversation-turn run request.
Layer graph: optional Agent Soul system prompt → user prompt →
execution context → optional history (multi-turn) → LLM → optional
plugin tools → optional structured output. Mirrors the workflow-node
layer ordering minus the workflow-job / previous-node prompt.
"""
layers: list[RunLayerSpec] = []
if run_input.agent_soul_prompt:
layers.append(
RunLayerSpec(
name=AGENT_SOUL_PROMPT_LAYER_ID,
type=PLAIN_PROMPT_LAYER_TYPE_ID,
metadata={**run_input.metadata, "origin": "agent_soul"},
config=PromptLayerConfig(prefix=run_input.agent_soul_prompt),
)
)
layers.extend(
[
RunLayerSpec(
name=AGENT_APP_USER_PROMPT_LAYER_ID,
type=PLAIN_PROMPT_LAYER_TYPE_ID,
metadata={**run_input.metadata, "origin": "agent_app_user_prompt"},
config=PromptLayerConfig(user=run_input.user_prompt),
),
RunLayerSpec(
name=DIFY_EXECUTION_CONTEXT_LAYER_ID,
type=DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID,
metadata=run_input.metadata,
config=run_input.execution_context,
),
]
)
if run_input.include_history:
layers.append(
RunLayerSpec(
name=DIFY_AGENT_HISTORY_LAYER_ID,
type=PYDANTIC_AI_HISTORY_LAYER_TYPE_ID,
metadata={**run_input.metadata, "origin": "agent_session_history"},
)
)
layers.append(
RunLayerSpec(
name=DIFY_AGENT_MODEL_LAYER_ID,
type=DIFY_PLUGIN_LLM_LAYER_TYPE_ID,
deps={"execution_context": DIFY_EXECUTION_CONTEXT_LAYER_ID},
metadata=run_input.metadata,
config=DifyPluginLLMLayerConfig(
plugin_id=run_input.model.plugin_id,
model_provider=run_input.model.model_provider,
model=run_input.model.model,
credentials=run_input.model.credentials,
model_settings=run_input.model.model_settings or None,
),
)
)
if run_input.tools is not None and run_input.tools.tools:
layers.append(
RunLayerSpec(
name=DIFY_PLUGIN_TOOLS_LAYER_ID,
type=DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID,
deps={"execution_context": DIFY_EXECUTION_CONTEXT_LAYER_ID},
metadata=run_input.metadata,
config=run_input.tools,
)
)
if run_input.include_shell:
# Sandboxed bash workspace (dify.shell). The layer declares NoLayerDeps,
# so the spec carries no deps; shellctl connection is server-injected.
layers.append(
RunLayerSpec(
name=DIFY_SHELL_LAYER_ID,
type=DIFY_SHELL_LAYER_TYPE_ID,
metadata=run_input.metadata,
config=run_input.shell_config or DifyShellLayerConfig(),
)
)
if run_input.output is not None:
layers.append(
RunLayerSpec(
name=DIFY_AGENT_OUTPUT_LAYER_ID,
type=DIFY_OUTPUT_LAYER_TYPE_ID,
metadata=run_input.metadata,
config=DifyOutputLayerConfig(
json_schema=run_input.output.json_schema,
description=run_input.output.description,
strict=run_input.output.strict,
),
)
)
return CreateRunRequest(
composition=RunComposition(layers=layers),
purpose=run_input.purpose,
idempotency_key=run_input.idempotency_key,
metadata=run_input.metadata,
session_snapshot=run_input.session_snapshot,
on_exit=LayerExitSignals(
default=ExitIntent.SUSPEND if run_input.suspend_on_exit else ExitIntent.DELETE,
),
)
def build_cleanup_request(
self,
*,
session_snapshot: CompositorSessionSnapshot,
composition_layer_specs: list[CleanupLayerSpec],
idempotency_key: str | None = None,
metadata: dict[str, JsonValue] | None = None,
) -> CreateRunRequest:
"""Build a lifecycle-only cleanup request that replays the prior layers.
The agenton compositor enforces that the session snapshot's layer names
match the active composition in order, so cleanup must replay the same
non-plugin layer graph that produced the snapshot. Plugin layers
(``dify.plugin.llm``, ``dify.plugin.tools``) are excluded from both the
composition and the snapshot before submission because their configs
require credentials that are not persisted between runs.
"""
if not composition_layer_specs:
raise ValueError(
"build_cleanup_request requires composition_layer_specs; an empty "
"composition would fail the agent backend's snapshot validation."
)
request_metadata = dict(metadata or {})
request_metadata["agent_backend_lifecycle"] = "session_cleanup"
layers = [
RunLayerSpec(
name=spec.name,
type=spec.type,
deps=dict(spec.deps),
metadata=dict(spec.metadata),
config=spec.config,
)
for spec in composition_layer_specs
]
filtered_snapshot = _filter_snapshot_to_specs(session_snapshot, composition_layer_specs)
return CreateRunRequest(
composition=RunComposition(layers=layers),
purpose="workflow_node",
idempotency_key=idempotency_key,
metadata=request_metadata,
session_snapshot=filtered_snapshot,
on_exit=LayerExitSignals(default=ExitIntent.DELETE),
)
def build_for_workflow_node(self, run_input: AgentBackendWorkflowNodeRunInput) -> CreateRunRequest:
"""Build a workflow Agent Node run request without defining another wire schema."""
layers: list[RunLayerSpec] = []
@ -121,21 +408,32 @@ class AgentBackendRunRequestBuilder:
config=PromptLayerConfig(user=run_input.user_prompt),
),
RunLayerSpec(
name=DIFY_PLUGIN_CONTEXT_LAYER_ID,
type=DIFY_PLUGIN_LAYER_TYPE_ID,
name=DIFY_EXECUTION_CONTEXT_LAYER_ID,
type=DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID,
metadata=run_input.metadata,
config=DifyPluginLayerConfig(
tenant_id=run_input.model.tenant_id,
plugin_id=run_input.model.plugin_id,
user_id=run_input.model.user_id,
),
config=run_input.execution_context,
),
]
)
if run_input.include_history:
layers.append(
RunLayerSpec(
name=DIFY_AGENT_HISTORY_LAYER_ID,
type=PYDANTIC_AI_HISTORY_LAYER_TYPE_ID,
metadata={**run_input.metadata, "origin": "agent_session_history"},
)
)
layers.extend(
[
RunLayerSpec(
name=DIFY_AGENT_MODEL_LAYER_ID,
type=DIFY_PLUGIN_LLM_LAYER_TYPE_ID,
deps={"plugin": DIFY_PLUGIN_CONTEXT_LAYER_ID},
deps={"execution_context": DIFY_EXECUTION_CONTEXT_LAYER_ID},
metadata=run_input.metadata,
config=DifyPluginLLMLayerConfig(
plugin_id=run_input.model.plugin_id,
model_provider=run_input.model.model_provider,
model=run_input.model.model,
credentials=run_input.model.credentials,
@ -145,6 +443,29 @@ class AgentBackendRunRequestBuilder:
]
)
if run_input.tools is not None and run_input.tools.tools:
layers.append(
RunLayerSpec(
name=DIFY_PLUGIN_TOOLS_LAYER_ID,
type=DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID,
deps={"execution_context": DIFY_EXECUTION_CONTEXT_LAYER_ID},
metadata=run_input.metadata,
config=run_input.tools,
)
)
if run_input.include_shell:
# Sandboxed bash workspace (dify.shell). The layer declares NoLayerDeps,
# so the spec carries no deps; shellctl connection is server-injected.
layers.append(
RunLayerSpec(
name=DIFY_SHELL_LAYER_ID,
type=DIFY_SHELL_LAYER_TYPE_ID,
metadata=run_input.metadata,
config=run_input.shell_config or DifyShellLayerConfig(),
)
)
if run_input.output is not None:
layers.append(
RunLayerSpec(
@ -153,7 +474,6 @@ class AgentBackendRunRequestBuilder:
metadata=run_input.metadata,
config=DifyOutputLayerConfig(
json_schema=run_input.output.json_schema,
name=run_input.output.name,
description=run_input.output.description,
strict=run_input.output.strict,
),
@ -162,7 +482,6 @@ class AgentBackendRunRequestBuilder:
return CreateRunRequest(
composition=RunComposition(layers=layers),
execution_context=run_input.execution_context,
purpose=run_input.purpose,
idempotency_key=run_input.idempotency_key,
metadata=run_input.metadata,

View File

@ -0,0 +1,135 @@
"""API-side client for the agent backend's read-only workspace file endpoints.
The agent backend exposes ``/workspaces/{session_id}/files{,/preview,/download}``
to inspect a shell-layer sandbox workspace. This thin synchronous client proxies
those reads for the console FS inspector and normalizes transport/HTTP failures
into the API backend's ``AgentBackendError`` boundary, preserving the backend's
status code and ``{code, message}`` detail so the controller can relay them.
"""
from __future__ import annotations
import base64
import binascii
from dataclasses import dataclass
from typing import Literal
import httpx
from pydantic import BaseModel
from clients.agent_backend.errors import AgentBackendHTTPError, AgentBackendTransportError
_DEFAULT_TIMEOUT_SECONDS = 30.0
class WorkspaceFileEntry(BaseModel):
"""One entry in a workspace directory listing."""
name: str
type: Literal["file", "dir", "symlink"]
size: int
mtime: int
class WorkspaceListResult(BaseModel):
"""Directory listing of a workspace path."""
path: str
entries: list[WorkspaceFileEntry]
truncated: bool
class WorkspacePreviewResult(BaseModel):
"""Inline preview of a workspace file."""
path: str
size: int
truncated: bool
binary: bool
text: str | None = None
@dataclass(frozen=True, slots=True)
class WorkspaceDownloadResult:
"""Decoded bytes of a workspace file for download."""
path: str
size: int
truncated: bool
content: bytes
class WorkspaceFilesBackendClient:
"""Synchronous proxy to the agent backend workspace file endpoints."""
def __init__(
self,
base_url: str,
*,
timeout: float = _DEFAULT_TIMEOUT_SECONDS,
transport: httpx.BaseTransport | None = None,
) -> None:
self._base_url = base_url.rstrip("/")
self._timeout = timeout
self._transport = transport
def list_files(self, session_id: str, path: str) -> WorkspaceListResult:
data = self._get(f"/workspaces/{session_id}/files", params={"path": path})
return WorkspaceListResult.model_validate(data)
def preview(self, session_id: str, path: str) -> WorkspacePreviewResult:
data = self._get(f"/workspaces/{session_id}/files/preview", params={"path": path})
return WorkspacePreviewResult.model_validate(data)
def download(self, session_id: str, path: str) -> WorkspaceDownloadResult:
data = self._get(f"/workspaces/{session_id}/files/download", params={"path": path})
encoded = data.get("content_base64")
if not isinstance(encoded, str):
raise AgentBackendHTTPError("agent backend download response missing content", status_code=502, detail=data)
try:
content = base64.b64decode(encoded, validate=True)
except (binascii.Error, ValueError) as exc:
raise AgentBackendHTTPError(
"agent backend returned undecodable download content", status_code=502, detail=str(exc)
) from exc
size = data.get("size")
return WorkspaceDownloadResult(
path=str(data.get("path", path)),
size=int(size) if isinstance(size, (int, float)) else len(content),
truncated=bool(data.get("truncated")),
content=content,
)
def _get(self, route: str, *, params: dict[str, str]) -> dict[str, object]:
url = f"{self._base_url}{route}"
try:
with httpx.Client(timeout=self._timeout, transport=self._transport, trust_env=False) as client:
response = client.get(url, params=params)
except httpx.HTTPError as exc:
raise AgentBackendTransportError(f"failed to reach agent backend workspace endpoint: {exc}") from exc
if response.status_code >= 400:
detail: object
try:
detail = response.json().get("detail", response.text)
except ValueError:
detail = response.text
raise AgentBackendHTTPError(
f"agent backend workspace request failed ({response.status_code})",
status_code=response.status_code,
detail=detail,
)
body = response.json()
if not isinstance(body, dict):
raise AgentBackendHTTPError(
"agent backend workspace response was not an object", status_code=502, detail=body
)
return body
__all__ = [
"WorkspaceDownloadResult",
"WorkspaceFileEntry",
"WorkspaceFilesBackendClient",
"WorkspaceListResult",
"WorkspacePreviewResult",
]

View File

@ -3,6 +3,13 @@ CLI command modules extracted from `commands.py`.
"""
from .account import create_tenant, reset_email, reset_password
from .data_migrate import data_migrate, legacy_model_types
from .data_migration import (
export_migration_data,
export_migration_data_template,
import_migration_data,
migration_data_wizard,
)
from .plugin import (
extract_plugins,
extract_unique_plugins,
@ -25,7 +32,12 @@ from .retention import (
restore_workflow_runs,
)
from .storage import clear_orphaned_file_records, file_usage, migrate_oss, remove_orphaned_files_on_storage
from .system import convert_to_agent_apps, fix_app_site_missing, reset_encrypt_key_pair, upgrade_db
from .system import (
convert_to_agent_apps,
fix_app_site_missing,
reset_encrypt_key_pair,
upgrade_db,
)
from .vector import (
add_qdrant_index,
migrate_annotation_vector_database,
@ -44,18 +56,24 @@ __all__ = [
"clear_orphaned_file_records",
"convert_to_agent_apps",
"create_tenant",
"data_migrate",
"delete_archived_workflow_runs",
"export_app_messages",
"export_migration_data",
"export_migration_data_template",
"extract_plugins",
"extract_unique_plugins",
"file_usage",
"fix_app_site_missing",
"import_migration_data",
"install_plugins",
"install_rag_pipeline_plugins",
"legacy_model_types",
"migrate_annotation_vector_database",
"migrate_data_for_plugin",
"migrate_knowledge_vector_database",
"migrate_oss",
"migration_data_wizard",
"old_metadata_migration",
"remove_orphaned_files_on_storage",
"reset_email",

View File

@ -0,0 +1,179 @@
import io
import os
import sys
from contextlib import AbstractContextManager, nullcontext
from pathlib import Path
from typing import cast
import click
from extensions.ext_database import db
from graphon.model_runtime.entities.model_entities import ModelType
from services.legacy_model_type_migration import (
VALID_TABLE_NAMES,
LegacyModelTypeMigrationService,
load_tenant_ids_from_file,
)
_SUPPORTED_MODEL_TYPE_CHOICES = (
ModelType.LLM.value,
ModelType.TEXT_EMBEDDING.value,
ModelType.RERANK.value,
)
_DEFAULT_CONCURRENCY = os.cpu_count() or 1
def _normalize_multi_value_option(
values: tuple[str, ...],
*,
valid_values: tuple[str, ...],
option_name: str,
) -> tuple[str, ...]:
normalized_values: list[str] = []
seen_values: set[str] = set()
for value in values:
for item in value.split(","):
normalized_item = item.strip()
if not normalized_item:
continue
if normalized_item not in valid_values:
raise click.BadParameter(
f"invalid value '{normalized_item}'. valid values: {', '.join(valid_values)}",
param_hint=option_name,
)
if normalized_item in seen_values:
continue
seen_values.add(normalized_item)
normalized_values.append(normalized_item)
return tuple(normalized_values)
@click.group(
"data-migrate",
help="Online data migration commands.",
)
def data_migrate() -> None:
"""Namespace for production data migration commands."""
@click.command(
"legacy-model-types",
help=(
"Migrate legacy provider model_type values to canonical values. "
"Default is dry-run and emits JSON lines only. "
"If --tables includes provider_model_credentials, the command may also update "
"provider_models and load_balancing_model_configs references so merged credentials stay reachable."
),
)
@click.option(
"--apply",
is_flag=True,
default=False,
help="Apply the migration. Default is dry-run.",
)
@click.option(
"--tables",
"tables",
multiple=True,
type=str,
help=(
"Limit migration to specific tables. Accepts comma-separated values or repeated flags.\n"
"\n"
"Options: load_balancing_model_configs, provider_model_credentials, "
"provider_model_settings, provider_models, tenant_default_models.\n\n"
"When provider_model_credentials is selected, provider_models and "
"load_balancing_model_configs may also be updated for credential reference rewrites.\n"
"\n"
"If unspecified, all relevant tables are migrated."
),
)
@click.option(
"--model-types",
"model_types",
multiple=True,
type=str,
help=(
"Canonical model types to migrate. Accepts comma-separated values or repeated flags.\n"
"\n"
"Options: llm,text-embedding,rerank\n"
"\n"
"If unspecified, all relevant legacy model types are migrated."
),
)
@click.option(
"--tenant-id-file",
type=click.Path(exists=True, dir_okay=False, readable=True, resolve_path=True),
help="Optional file containing tenant ids, one per line.",
)
@click.option(
"--output",
type=click.Path(dir_okay=False, resolve_path=True, path_type=Path),
help=(
"Optional file path for JSON lines event logs. Defaults to stdout.\n"
"It's highly recommended to save the event logs to a file and preserve it for a period of time."
),
)
@click.option(
"--concurrency",
type=click.IntRange(min=1),
default=_DEFAULT_CONCURRENCY,
show_default=True,
help="Number of tenant-level worker threads to run in parallel.",
)
def legacy_model_types(
apply: bool,
tables: tuple[str, ...],
model_types: tuple[str, ...],
tenant_id_file: str | None,
output: Path | None,
concurrency: int = _DEFAULT_CONCURRENCY,
) -> None:
"""
Migrate legacy provider-related model_type values and emit JSON lines events.
"""
normalized_tables = _normalize_multi_value_option(
tables,
valid_values=VALID_TABLE_NAMES,
option_name="--tables",
)
normalized_model_types = _normalize_multi_value_option(
model_types,
valid_values=_SUPPORTED_MODEL_TYPE_CHOICES,
option_name="--model-types",
)
selected_model_types = (
tuple(ModelType.value_of(model_type) for model_type in normalized_model_types)
if normalized_model_types
else (
ModelType.LLM,
ModelType.TEXT_EMBEDDING,
ModelType.RERANK,
)
)
tenant_ids = load_tenant_ids_from_file(tenant_id_file) if tenant_id_file else None
output_context: AbstractContextManager[io.TextIOBase]
if output is None:
output_context = nullcontext(cast(io.TextIOBase, sys.stdout))
else:
try:
output_context = output.open("w", encoding="utf-8")
except OSError as exc:
raise click.ClickException(f"failed to open output file '{output}': {exc.strerror or exc}") from exc
with output_context as output_stream:
LegacyModelTypeMigrationService(
engine=db.engine,
apply=apply,
concurrency=concurrency,
output=cast(io.TextIOBase, output_stream),
tables=normalized_tables or None,
model_types=selected_model_types,
tenant_ids=tenant_ids,
).migrate()
data_migrate.add_command(legacy_model_types)

View File

@ -0,0 +1,754 @@
from __future__ import annotations
import json
from datetime import datetime
from pathlib import Path
from typing import Any, cast
from uuid import UUID
import click
import sqlalchemy as sa
import yaml
from extensions.ext_database import db
from models import Tenant
from models.model import App
from models.tools import ApiToolProvider, MCPToolProvider, WorkflowToolProvider
from services.app_dsl_service import AppDslService
from services.data_migration.dependency_discovery_service import DependencyDiscoveryService
from services.data_migration.entities import (
DependencyKind,
ImportOptions,
MigrationDataError,
ReportContext,
ResourceReportItem,
)
from services.data_migration.export_service import ExportConfigParser, MigrationExportService
from services.data_migration.import_service import ImportRequest, MigrationImportService
from services.data_migration.package_service import MigrationPackageService
from services.data_migration.report_service import MigrationReportService
ID_STRATEGY_CHOICES = ["preserve-id", "generate-new-id"]
CONFLICT_STRATEGY_CHOICES = ["fail", "skip", "update"]
SUPPORTED_WIZARD_APP_MODES = ["workflow", "advanced-chat"]
WizardToolMap = dict[str, dict[str, str | None]]
WizardToolSelection = dict[str, list[str]]
def _scripted_export_template() -> dict[str, Any]:
return {
"source_tenant": {
"mode": "single",
"id": "",
"name": "admin's Workspace",
},
"apps": {
"modes": ["workflow", "advanced-chat"],
"ids": [],
"all": True,
},
"include_referenced_tools": True,
"additional_tools": {
"api_tools": [],
"workflow_tools": [],
"mcp_tools": [],
},
"include_secrets": False,
"import_options": {
"create_app_api_token_on_import": False,
"id_strategy": "preserve-id",
"conflict_strategy": "fail",
},
}
@click.command("app-migration-template", help="Print or write a scripted export config JSON template.")
@click.option(
"--output",
"output_file",
required=False,
type=click.Path(dir_okay=False),
help="Path to write the export config JSON template. Prints to stdout when omitted.",
)
@click.option("--overwrite", is_flag=True, default=False, help="Overwrite output if it already exists.")
def export_migration_data_template(output_file: str | None, overwrite: bool) -> None:
template_json = json.dumps(_scripted_export_template(), indent=2, ensure_ascii=False) + "\n"
if output_file is None:
click.echo(template_json, nl=False)
return
path = Path(output_file)
if path.exists() and not overwrite:
raise click.ClickException(f"Output file already exists: {output_file}")
path.write_text(template_json)
click.echo(click.style(f"Output written to {output_file}", fg="green"))
@click.command("export-app-migration", help="Export workflow migration data to a versioned JSON package.")
@click.option(
"--input",
"input_file",
required=False,
type=click.Path(exists=True, dir_okay=False),
help="Path to export config JSON.",
)
@click.option(
"--output",
"output_file",
required=False,
type=click.Path(dir_okay=False),
help="Path to migration package JSON.",
)
@click.option("--overwrite", is_flag=True, default=False, help="Overwrite output if it already exists.")
def export_migration_data(input_file: str | None, output_file: str | None, overwrite: bool) -> None:
try:
_require_options(("--input", input_file), ("--output", output_file))
assert input_file is not None
assert output_file is not None
raw_config = _load_json_object(input_file, "Export config")
selection = ExportConfigParser().parse(raw_config)
result = MigrationExportService().export(selection)
MigrationPackageService().save_package(result.package, output_file, overwrite=overwrite)
click.echo(click.style(f"Output written to {output_file}", fg="green"))
_render_report(result.report_items, context=_with_output_path(result.report_context, output_file))
except MigrationDataError as exc:
raise click.ClickException(str(exc)) from exc
@click.command("import-app-migration", help="Import a versioned migration data package.")
@click.option(
"--input",
"input_file",
required=False,
type=click.Path(exists=True, dir_okay=False),
help="Path to migration package JSON.",
)
@click.option("--target-tenant", default=None, help="Target tenant/workspace name. Overrides package metadata.")
@click.option("--operator-email", default=None, help="Operator account email in the target tenant.")
@click.option(
"--id-strategy",
default=None,
type=click.Choice(ID_STRATEGY_CHOICES),
help="Override package ID strategy.",
)
@click.option(
"--conflict-strategy",
default=None,
type=click.Choice(CONFLICT_STRATEGY_CHOICES),
help="Override package conflict strategy.",
)
@click.option(
"--create-app-api-token-on-import/--no-create-app-api-token-on-import",
default=None,
help="Override package app API token creation behavior.",
)
def import_migration_data(
input_file: str | None,
target_tenant: str | None,
operator_email: str | None,
id_strategy: str | None,
conflict_strategy: str | None,
create_app_api_token_on_import: bool | None,
) -> None:
try:
_require_options(("--input", input_file))
assert input_file is not None
package = MigrationPackageService().load_package(input_file)
result = MigrationImportService().import_package(
ImportRequest(
package=package,
cli_target_tenant=target_tenant,
operator_email=operator_email,
options_override=_build_options_override(
package.metadata.import_options,
id_strategy=id_strategy,
conflict_strategy=conflict_strategy,
create_app_api_token_on_import=create_app_api_token_on_import,
),
)
)
_render_report(result.report_items, context=result.report_context)
except MigrationDataError as exc:
raise click.ClickException(str(exc)) from exc
def parse_index_selection(raw: str, values: list[str]) -> list[str]:
normalized = raw.strip().lower()
if normalized == "all":
return values
selected: list[str] = []
for part in raw.split(","):
stripped = part.strip()
if not stripped:
continue
try:
index = int(stripped)
except ValueError as exc:
raise click.ClickException(f"Selection must be 'all' or comma-separated numbers: {raw}") from exc
if index < 1 or index > len(values):
raise click.ClickException(f"Selection index out of range: {index}")
selected.append(values[index - 1])
return list(dict.fromkeys(selected))
def _print_wizard_step(title: str) -> None:
click.echo("")
click.echo(f"==== {title} ====")
def _print_wizard_substep(title: str) -> None:
click.echo("")
click.echo(f"-- {title} --")
@click.command("app-migration-wizard", help="Interactively export workflow migration data.")
def migration_data_wizard() -> None:
try:
tenant = _prompt_source_tenant()
apps = _eligible_apps_for_tenant(tenant.id)
app_ids = _prompt_app_ids(apps)
_print_wizard_step("Referenced Tools")
include_referenced_tools = click.confirm(
"Automatically export tools referenced by selected apps? [y/n, default: y]",
default=True,
show_default=False,
)
auto_tools = _discover_auto_tools([app for app in apps if app.id in set(app_ids)], include_referenced_tools)
auto_tools = _resolve_auto_tool_names(tenant.id, auto_tools)
_print_auto_tools(auto_tools)
additional_tools = _prompt_additional_tools(tenant.id, auto_tools)
include_secrets, create_tokens, id_strategy, conflict_strategy = _prompt_import_options()
_print_wizard_step("Output")
output_file, overwrite = _prompt_output_file()
selection = ExportConfigParser().parse(
{
"source_tenant": {"mode": "single", "id": tenant.id, "name": tenant.name},
"apps": {"ids": app_ids, "all": False},
"include_referenced_tools": include_referenced_tools,
"additional_tools": additional_tools,
"include_secrets": include_secrets,
"import_options": {
"create_app_api_token_on_import": create_tokens,
"id_strategy": id_strategy,
"conflict_strategy": conflict_strategy,
},
}
)
_confirm_wizard_summary(
tenant_name=tenant.name,
app_names=[app.name for app in apps if app.id in set(app_ids)],
auto_tools=auto_tools,
additional_tools=additional_tools,
manual_labels=_selected_tool_labels_for_tenant(tenant.id, additional_tools),
include_referenced_tools=include_referenced_tools,
include_secrets=include_secrets,
create_tokens=create_tokens,
id_strategy=id_strategy,
conflict_strategy=conflict_strategy,
output_file=output_file,
)
result = MigrationExportService().export(selection)
MigrationPackageService().save_package(result.package, output_file, overwrite=overwrite)
click.echo(click.style(f"Output written to {output_file}", fg="green"))
_print_wizard_step("Report")
_render_report(result.report_items, context=_with_output_path(result.report_context, output_file))
except MigrationDataError as exc:
raise click.ClickException(str(exc)) from exc
def _load_json_object(path: str, label: str) -> dict[str, Any]:
try:
with Path(path).open(encoding="utf-8") as file:
raw = json.load(file)
except json.JSONDecodeError as exc:
raise MigrationDataError(f"{label} JSON is invalid: {exc.msg}") from exc
if not isinstance(raw, dict):
raise MigrationDataError(f"{label} JSON must be an object.")
return raw
def _require_options(*options: tuple[str, object | None]) -> None:
missing_options = [name for name, value in options if value is None]
if missing_options:
raise click.UsageError(f"Missing option(s): {', '.join(missing_options)}.")
def _build_options_override(
package_options: ImportOptions,
*,
id_strategy: str | None,
conflict_strategy: str | None,
create_app_api_token_on_import: bool | None,
) -> ImportOptions | None:
if id_strategy is None and conflict_strategy is None and create_app_api_token_on_import is None:
return None
return ImportOptions.from_mapping(
{
"id_strategy": id_strategy or package_options.id_strategy,
"conflict_strategy": conflict_strategy or package_options.conflict_strategy,
"create_app_api_token_on_import": (
create_app_api_token_on_import
if create_app_api_token_on_import is not None
else package_options.create_app_api_token_on_import
),
}
)
def _prompt_source_tenant() -> Tenant:
tenants = list(db.session.scalars(sa.select(Tenant).order_by(Tenant.name.asc())).all())
if not tenants:
raise MigrationDataError("No tenants found.")
_print_wizard_step("Source Tenant")
click.echo("Source tenants:")
for index, tenant in enumerate(tenants, 1):
click.echo(f"{index}. {tenant.name} ({tenant.id})")
tenant_index = click.prompt("Select one source tenant by number", type=int, default=1, show_default=True)
if tenant_index < 1 or tenant_index > len(tenants):
raise click.ClickException(f"Selection index out of range: {tenant_index}")
return tenants[tenant_index - 1]
def _eligible_apps_for_tenant(tenant_id: str) -> list[App]:
return list(
db.session.scalars(
sa.select(App)
.where(App.tenant_id == tenant_id, App.mode.in_(SUPPORTED_WIZARD_APP_MODES))
.order_by(App.name.asc())
).all()
)
def _prompt_app_ids(apps: list[App]) -> list[str]:
if not apps:
raise MigrationDataError("No workflow or advanced-chat apps found for the selected tenant.")
_print_wizard_step("App Selection")
click.echo("Currently supported app types: workflow and chatflow.")
click.echo("Workflow/chatflow apps:")
for index, app in enumerate(apps, 1):
mode = app.mode.value if hasattr(app.mode, "value") else app.mode
click.echo(f"{index}. {app.name} [{mode}] ({app.id})")
app_ids = parse_index_selection(
click.prompt("Select apps by number, comma-separated numbers, or all", default="all"),
[app.id for app in apps],
)
selected_apps = [app for app in apps if app.id in set(app_ids)]
click.echo("Selected apps:")
for app in selected_apps:
click.echo(f"- {app.name} ({app.id})")
return app_ids
def _prompt_import_options() -> tuple[bool, bool, str, str]:
_print_wizard_step("Import Options")
_print_wizard_substep("Secrets")
click.echo("Secrets include workflow/app DSL secret values, custom API tool credentials,")
click.echo("and full MCP provider connection data such as server URL, headers, authentication, and tool list.")
click.echo("If you choose no, credentials are omitted or masked,")
click.echo("and MCP providers are exported as dependency metadata only.")
click.echo("Treat the output JSON as sensitive if you choose yes.")
include_secrets = click.confirm(
"Include secrets in output JSON? [y/n, default: n]",
default=False,
show_default=False,
)
_print_wizard_substep("App API Tokens")
click.echo("When enabled, import will create an app API token if the imported app has none,")
click.echo("or reuse an existing app API token if one already exists.")
create_tokens = click.confirm(
"Create or reuse app API tokens during import? [y/n, default: n]",
default=False,
show_default=False,
)
_print_wizard_substep("ID Strategy")
click.echo("ID strategy controls whether imported app and tool IDs preserve source IDs")
click.echo("or use target-generated IDs.")
click.echo("preserve-id: keep source IDs where the target service supports it.")
click.echo("generate-new-id: let the target environment generate new IDs and rewrite references via mapping.")
id_strategy = click.prompt(
"Import ID strategy. Enter one of: preserve-id, generate-new-id",
type=click.Choice(ID_STRATEGY_CHOICES),
default="preserve-id",
show_default=True,
)
_print_wizard_substep("Conflict Strategy")
click.echo("Conflict strategy controls what import does when a target resource already exists.")
click.echo("fail: stop at the first conflict; previously committed resources are not rolled back.")
click.echo("skip: keep the existing target resource and skip importing that resource.")
click.echo("update: update the existing target resource in place.")
conflict_strategy = click.prompt(
"Import conflict strategy. Enter one of: fail, skip, update",
type=click.Choice(CONFLICT_STRATEGY_CHOICES),
default="update",
show_default=True,
)
return include_secrets, create_tokens, id_strategy, conflict_strategy
def _discover_auto_tools(apps: list[App], include_referenced_tools: bool) -> WizardToolMap:
auto_tools: WizardToolMap = {"api_tools": {}, "workflow_tools": {}, "mcp_tools": {}}
if not include_referenced_tools:
return auto_tools
discovery_service = DependencyDiscoveryService()
for app in apps:
dsl_content = AppDslService.export_dsl(app_model=app, include_secret=False)
raw_dsl = yaml.safe_load(dsl_content) if dsl_content else {}
dsl = raw_dsl if isinstance(raw_dsl, dict) else {}
for dependency in discovery_service.discover_from_dsl(dsl):
if dependency.kind == DependencyKind.API_TOOL:
auto_tools["api_tools"][dependency.provider_name or dependency.provider_id] = dependency.provider_id
elif dependency.kind == DependencyKind.WORKFLOW_TOOL:
auto_tools["workflow_tools"][dependency.provider_name or dependency.provider_id] = (
dependency.provider_id
)
elif dependency.kind == DependencyKind.MCP_TOOL:
auto_tools["mcp_tools"][dependency.provider_name or dependency.provider_id] = dependency.provider_id
return auto_tools
def _resolve_auto_tool_names(tenant_id: str, auto_tools: WizardToolMap) -> WizardToolMap:
return {
"api_tools": _resolve_api_tool_names(tenant_id, auto_tools["api_tools"]),
"workflow_tools": _resolve_workflow_tool_names(tenant_id, auto_tools["workflow_tools"]),
"mcp_tools": _resolve_mcp_tool_names(tenant_id, auto_tools["mcp_tools"]),
}
def _resolve_api_tool_names(tenant_id: str, tools: dict[str, str | None]) -> dict[str, str | None]:
resolved: dict[str, str | None] = {}
for name, identifier in tools.items():
predicates = [ApiToolProvider.name == name]
if _is_uuid_string(identifier):
predicates.append(ApiToolProvider.id == identifier)
provider = db.session.scalar(
sa.select(ApiToolProvider).where(
ApiToolProvider.tenant_id == tenant_id,
sa.or_(*predicates),
)
)
resolved[provider.name if provider else name] = provider.id if provider else identifier
return resolved
def _resolve_workflow_tool_names(tenant_id: str, tools: dict[str, str | None]) -> dict[str, str | None]:
resolved: dict[str, str | None] = {}
for name, identifier in tools.items():
predicates = [WorkflowToolProvider.name == name]
if _is_uuid_string(identifier):
predicates.append(WorkflowToolProvider.id == identifier)
provider = db.session.scalar(
sa.select(WorkflowToolProvider).where(
WorkflowToolProvider.tenant_id == tenant_id,
sa.or_(*predicates),
)
)
resolved[provider.name if provider else name] = provider.id if provider else identifier
return resolved
def _resolve_mcp_tool_names(tenant_id: str, tools: dict[str, str | None]) -> dict[str, str | None]:
resolved: dict[str, str | None] = {}
for name, identifier in tools.items():
predicates = [MCPToolProvider.name == name]
if identifier:
predicates.append(MCPToolProvider.server_identifier == identifier)
if _is_uuid_string(identifier):
predicates.append(MCPToolProvider.id == identifier)
provider = db.session.scalar(
sa.select(MCPToolProvider).where(
MCPToolProvider.tenant_id == tenant_id,
sa.or_(*predicates),
)
)
resolved[provider.name if provider else name] = provider.id if provider else identifier
return resolved
def _is_uuid_string(value: str | None) -> bool:
if not value:
return False
try:
UUID(value)
except ValueError:
return False
return True
def _print_auto_tools(auto_tools: WizardToolMap) -> None:
_print_wizard_step("Automatically Discovered Tools")
click.echo("Automatically discovered tools:")
_print_auto_tool_category("Custom API tools", auto_tools["api_tools"])
_print_auto_tool_category("Workflow tools", auto_tools["workflow_tools"])
_print_auto_tool_category("MCP tools", auto_tools["mcp_tools"])
def _print_auto_tool_category(label: str, values: dict[str, str | None]) -> None:
click.echo(label)
if not values:
click.echo("- none")
return
for name, identifier in sorted(values.items()):
click.echo(f"- {_format_tool_name_id(name, identifier)}")
def _prompt_additional_tools(tenant_id: str, auto_tools: WizardToolMap) -> WizardToolSelection:
selections: WizardToolSelection = {"api_tools": [], "workflow_tools": [], "mcp_tools": []}
_print_wizard_step("Additional Tools")
if not click.confirm(
"Export additional tools manually? [y/n, default: n]",
default=False,
show_default=False,
):
_print_final_tool_selection(auto_tools, selections, {})
return selections
manual_labels: dict[str, str] = {}
api_tool_options = [
(tool.name, tool.name, tool.id)
for tool in db.session.scalars(
sa.select(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id).order_by(ApiToolProvider.name)
).all()
]
selections["api_tools"] = _prompt_tool_category(
"Custom API tools",
api_tool_options,
auto_tools=auto_tools["api_tools"],
)
manual_labels.update(_selected_tool_labels(api_tool_options, selections["api_tools"]))
workflow_tool_options = [
(tool.id, tool.name, tool.id)
for tool in db.session.scalars(
sa.select(WorkflowToolProvider)
.where(WorkflowToolProvider.tenant_id == tenant_id)
.order_by(WorkflowToolProvider.name)
).all()
]
selections["workflow_tools"] = _prompt_tool_category(
"Workflow tools",
workflow_tool_options,
auto_tools=auto_tools["workflow_tools"],
)
manual_labels.update(_selected_tool_labels(workflow_tool_options, selections["workflow_tools"]))
mcp_tool_options = [
(tool.id, tool.name, tool.server_identifier)
for tool in db.session.scalars(
sa.select(MCPToolProvider).where(MCPToolProvider.tenant_id == tenant_id).order_by(MCPToolProvider.name)
).all()
]
selections["mcp_tools"] = _prompt_tool_category(
"MCP tools",
mcp_tool_options,
auto_tools=auto_tools["mcp_tools"],
)
manual_labels.update(_selected_tool_labels(mcp_tool_options, selections["mcp_tools"]))
_print_final_tool_selection(auto_tools, selections, manual_labels)
return selections
def _selected_tool_labels_for_tenant(tenant_id: str, selected_tools: WizardToolSelection) -> dict[str, str]:
labels: dict[str, str] = {}
if selected_tools["api_tools"]:
labels.update(
_selected_tool_labels(
[
(tool.name, tool.name, tool.id)
for tool in db.session.scalars(
sa.select(ApiToolProvider)
.where(ApiToolProvider.tenant_id == tenant_id)
.order_by(ApiToolProvider.name)
).all()
],
selected_tools["api_tools"],
)
)
if selected_tools["workflow_tools"]:
labels.update(
_selected_tool_labels(
[
(tool.id, tool.name, tool.id)
for tool in db.session.scalars(
sa.select(WorkflowToolProvider)
.where(WorkflowToolProvider.tenant_id == tenant_id)
.order_by(WorkflowToolProvider.name)
).all()
],
selected_tools["workflow_tools"],
)
)
if selected_tools["mcp_tools"]:
labels.update(
_selected_tool_labels(
[
(tool.id, tool.name, tool.server_identifier)
for tool in db.session.scalars(
sa.select(MCPToolProvider)
.where(MCPToolProvider.tenant_id == tenant_id)
.order_by(MCPToolProvider.name)
).all()
],
selected_tools["mcp_tools"],
)
)
return labels
def _selected_tool_labels(options: list[tuple[str, str, str]], selected_values: list[str]) -> dict[str, str]:
selected = set(selected_values)
return {value: _format_tool_name_id(name, detail) for value, name, detail in options if value in selected}
def _prompt_tool_category(
label: str,
options: list[tuple[str, str, str]],
*,
auto_tools: dict[str, str | None],
) -> list[str]:
if not options:
click.echo(f"{label}: none")
return []
_print_wizard_step(label)
for index, (value, name, detail) in enumerate(options, 1):
marker = "[auto]" if _is_auto_tool(value, name, detail, auto_tools) else "[ ]"
click.echo(f"{index}. {marker} {name} ({detail})")
raw = click.prompt(
f"Select {label.lower()} by number, comma-separated numbers, all, or empty",
default="",
show_default=cast(Any, "empty"),
)
if not raw.strip():
return []
return parse_index_selection(raw, [value for value, _, _ in options])
def _is_auto_tool(value: str, name: str, detail: str, auto_tools: dict[str, str | None]) -> bool:
return name in auto_tools or value in auto_tools or value in auto_tools.values() or detail in auto_tools.values()
def _print_final_tool_selection(
auto_tools: WizardToolMap,
additional_tools: WizardToolSelection,
manual_labels: dict[str, str],
) -> None:
_print_wizard_step("Final Tool Selection")
_print_tool_selection_body(auto_tools, additional_tools, manual_labels)
def _print_tool_selection_body(
auto_tools: WizardToolMap,
additional_tools: WizardToolSelection,
manual_labels: dict[str, str],
) -> None:
click.echo("Final tools to export:")
_print_final_tool_category(
"Custom API tools",
auto_tools["api_tools"],
additional_tools["api_tools"],
manual_labels,
)
_print_final_tool_category(
"Workflow tools",
auto_tools["workflow_tools"],
additional_tools["workflow_tools"],
manual_labels,
)
_print_final_tool_category("MCP tools", auto_tools["mcp_tools"], additional_tools["mcp_tools"], manual_labels)
def _print_final_tool_category(
label: str,
auto_tools: dict[str, str | None],
manual_values: list[str],
manual_labels: dict[str, str],
) -> None:
click.echo(label)
lines = [f"- [auto] {_format_tool_name_id(name, identifier)}" for name, identifier in sorted(auto_tools.items())]
auto_identifiers = {identifier for identifier in auto_tools.values() if identifier}
lines.extend(
f"- [manual] {manual_labels.get(value, value)}"
for value in manual_values
if value not in auto_tools and value not in auto_identifiers
)
if not lines:
click.echo("- none")
return
for line in lines:
click.echo(line)
def _format_tool_name_id(name: str, identifier: str | None) -> str:
if identifier and identifier != name:
return f"{name}: {identifier}"
return name
def _confirm_wizard_summary(
*,
tenant_name: str,
app_names: list[str],
auto_tools: WizardToolMap,
additional_tools: WizardToolSelection,
manual_labels: dict[str, str],
include_referenced_tools: bool,
include_secrets: bool,
create_tokens: bool,
id_strategy: str,
conflict_strategy: str,
output_file: str,
) -> None:
_print_wizard_step("Summary")
click.echo("Migration export summary:")
click.echo(f"source tenant: {tenant_name}")
click.echo(f"selected apps: {len(app_names)}")
for app_name in app_names:
click.echo(f"- {app_name}")
click.echo(f"auto referenced tools: {str(include_referenced_tools).lower()}")
_print_tool_selection_body(auto_tools, additional_tools, manual_labels)
click.echo(f"include secrets: {str(include_secrets).lower()}")
click.echo(f"create app api token on import: {str(create_tokens).lower()}")
click.echo(f"id strategy: {id_strategy}")
click.echo(f"conflict strategy: {conflict_strategy}")
click.echo(f"output path: {output_file}")
if not click.confirm("Write migration package? [y/n, default: y]", default=True, show_default=False):
raise click.Abort()
def _prompt_output_file() -> tuple[str, bool]:
default_output = f"migration-data-{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"
output_file = click.prompt("Output path", default=default_output, show_default=True)
if output_file.lower() in {"y", "yes", "n", "no"}:
raise click.ClickException("Output path must be a file path. Press Enter to use the default path.")
overwrite = False
if Path(output_file).exists():
overwrite = click.confirm(
"Output file exists. Overwrite? [y/n, default: n]",
default=False,
show_default=False,
)
if not overwrite:
raise click.ClickException(f"Output file already exists: {output_file}")
return output_file, overwrite
def _with_output_path(context: ReportContext | None, output_path: str) -> ReportContext:
if context is None:
return ReportContext(output_path=output_path)
return ReportContext(
output_path=output_path,
source_scope=context.source_scope,
selected_app_count=context.selected_app_count,
include_secrets=context.include_secrets,
target_tenant=context.target_tenant,
operator_email=context.operator_email,
app_api_tokens_created=context.app_api_tokens_created,
app_api_tokens_reused=context.app_api_tokens_reused,
id_mapping_count=context.id_mapping_count,
id_mappings=context.id_mappings,
)
def _render_report(report_items: list[ResourceReportItem], *, context: ReportContext | None = None) -> None:
for line in MigrationReportService().render(report_items, context=context):
click.echo(line)

View File

@ -30,7 +30,7 @@ def vdb_migrate(scope: str):
def migrate_annotation_vector_database():
"""
Migrate annotation datas to target vector database .
Migrate annotation data to target vector database.
"""
click.echo(click.style("Starting annotation data migration.", fg="green"))
create_count = 0
@ -140,7 +140,7 @@ def migrate_annotation_vector_database():
def migrate_knowledge_vector_database():
"""
Migrate vector database datas to target vector database .
Migrate vector database data to target vector database.
"""
click.echo(click.style("Starting vector database migration.", fg="green"))
create_count = 0

View File

@ -29,6 +29,7 @@ class RemoteSettingsSourceFactory(PydanticBaseSettingsSource):
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
raise NotImplementedError
@override
def __call__(self) -> dict[str, Any]:
current_state = self.current_state
remote_source_name = current_state.get("REMOTE_SETTINGS_SOURCE_NAME")

View File

@ -1,3 +1,5 @@
from typing import Literal
from pydantic import Field
from pydantic_settings import BaseSettings
@ -23,7 +25,7 @@ class DeploymentConfig(BaseSettings):
default=False,
)
EDITION: str = Field(
EDITION: Literal["SELF_HOSTED", "CLOUD"] = Field(
description="Deployment edition of the application (e.g., 'SELF_HOSTED', 'CLOUD')",
default="SELF_HOSTED",
)

View File

@ -21,3 +21,13 @@ class AgentBackendConfig(BaseSettings):
description="Scenario used by the fake Agent backend client.",
default="success",
)
AGENT_SHELL_ENABLED: bool = Field(
description=(
"Inject the dify.shell layer (sandboxed bash workspace) into Agent runs. "
"Requires the agent backend to be wired with a shellctl entrypoint; keep it "
"off until shellctl is deployed, otherwise every agent run that includes the "
"shell layer will fail."
),
default=False,
)

View File

@ -525,6 +525,44 @@ class HttpConfig(BaseSettings):
def WEB_API_CORS_ALLOW_ORIGINS(self) -> list[str]:
return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(",")
OPENAPI_ENABLED: bool = Field(
description=(
"Enable the /openapi/v1/* endpoint group used by difyctl and other "
"programmatic clients. Set to true to activate; disabled by default."
),
validation_alias=AliasChoices("OPENAPI_ENABLED"),
default=False,
)
inner_OPENAPI_CORS_ALLOW_ORIGINS: str = Field(
description=(
"Comma-separated allowlist for /openapi/v1/* CORS. "
"Default empty = same-origin only. Browser-cookie routes within "
"the group reject cross-origin OPTIONS regardless of this list."
),
validation_alias=AliasChoices("OPENAPI_CORS_ALLOW_ORIGINS"),
default="",
)
@computed_field
def OPENAPI_CORS_ALLOW_ORIGINS(self) -> list[str]:
return [o for o in self.inner_OPENAPI_CORS_ALLOW_ORIGINS.split(",") if o]
inner_OPENAPI_KNOWN_CLIENT_IDS: str = Field(
description=(
"Comma-separated client_id values accepted at "
"POST /openapi/v1/oauth/device/code. New CLIs / SDKs added here "
"without code changes. Unknown client_id returns 400 unsupported_client."
),
validation_alias=AliasChoices("OPENAPI_KNOWN_CLIENT_IDS"),
default="difyctl",
)
@computed_field # type: ignore[misc]
@property
def OPENAPI_KNOWN_CLIENT_IDS(self) -> frozenset[str]:
return frozenset(c for c in self.inner_OPENAPI_KNOWN_CLIENT_IDS.split(",") if c)
HTTP_REQUEST_MAX_CONNECT_TIMEOUT: int = Field(
ge=1, description="Maximum connection timeout in seconds for HTTP requests", default=10
)
@ -900,6 +938,22 @@ class AuthConfig(BaseSettings):
default=86400,
)
ENABLE_OAUTH_BEARER: bool = Field(
description="Enable OAuth bearer authentication (device-flow + Service API /v1/* bearer middleware).",
default=True,
)
OPENAPI_RATE_LIMIT_PER_TOKEN: PositiveInt = Field(
description="Per-token rate limit on /openapi/v1/* (requests per minute). "
"Bucket keyed on sha256(token), shared across api replicas via Redis.",
default=60,
)
DEVICE_FLOW_APPROVE_RATE_LIMIT_PER_HOUR: PositiveInt = Field(
description="Max device-flow approve requests per session per hour on /openapi/oauth/device/approve.",
default=10,
)
class ModerationConfig(BaseSettings):
"""
@ -1186,6 +1240,14 @@ class CeleryScheduleTasksConfig(BaseSettings):
description="Enable scheduled workflow run cleanup task",
default=False,
)
ENABLE_CLEAN_OAUTH_ACCESS_TOKENS_TASK: bool = Field(
description="Enable scheduled cleanup of revoked/expired OAuth access-token rows past retention.",
default=True,
)
OAUTH_ACCESS_TOKEN_RETENTION_DAYS: PositiveInt = Field(
description="Days to retain revoked OAuth access-token rows before deletion.",
default=30,
)
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK: bool = Field(
description="Enable mail clean document notify task",
default=False,

View File

@ -41,3 +41,21 @@ class MilvusConfig(BaseSettings):
description='Milvus text analyzer parameters, e.g., {"type": "chinese"} for Chinese segmentation support.',
default=None,
)
MILVUS_SECURE: bool = Field(
description="Enable TLS for the Milvus connection (one-way TLS). When True, the client uses gRPC over TLS "
"and verifies the server certificate. Equivalent to passing secure=True to pymilvus.",
default=False,
)
MILVUS_SERVER_PEM_PATH: str | None = Field(
description="Filesystem path inside the container to the Milvus server certificate (PEM). Mount this via "
"a Kubernetes secret. Used as pymilvus's server_pem_path when MILVUS_SECURE is True.",
default=None,
)
MILVUS_SERVER_NAME: str | None = Field(
description="Server name (TLS SNI / certificate CN or SAN) to verify against the Milvus server certificate. "
"Required when MILVUS_SERVER_PEM_PATH is set.",
default=None,
)

View File

@ -81,4 +81,15 @@ default_app_templates: Mapping[AppMode, Mapping] = {
},
},
},
# agent default mode (new Agent App type). The runtime model / prompt / tools
# come from the bound Agent Soul snapshot, so no model_config is seeded in the
# template; create_app still creates a model-less app_model_config row to hold
# app-level presentation features (opener, follow-up, citations, ...).
AppMode.AGENT: {
"app": {
"mode": AppMode.AGENT,
"enable_site": True,
"enable_api": True,
},
},
}

View File

@ -1,10 +1,40 @@
import json
from pydantic import BaseModel, JsonValue
from pydantic import BaseModel, Field, JsonValue
HUMAN_INPUT_FORM_INPUT_EXAMPLE = {
"decision": "approve",
"attachment": {
"transfer_method": "local_file",
"upload_file_id": "4e0d1b87-52f2-49f6-b8c6-95cd9c954b3e",
"type": "document",
},
"attachments": [
{
"transfer_method": "local_file",
"upload_file_id": "1a77f0df-c0e6-461c-987c-e72526f341ee",
"type": "document",
},
{
"transfer_method": "remote_url",
"url": "https://example.com/report.pdf",
"type": "document",
},
],
}
class HumanInputFormSubmitPayload(BaseModel):
inputs: dict[str, JsonValue]
inputs: dict[str, JsonValue] = Field(
description=(
"Submitted human input values keyed by output variable name. "
"Use a string for paragraph or select input values, a file mapping for file inputs, "
"and a list of file mappings for file-list inputs. Local file mappings use "
"`transfer_method=local_file` with `upload_file_id`; remote file mappings use "
"`transfer_method=remote_url` with `url` or `remote_url`."
),
examples=[HUMAN_INPUT_FORM_INPUT_EXAMPLE],
)
action: str

View File

@ -6,10 +6,11 @@ These helpers keep that translation centralized so models registered through
`register_schema_models` emit resolvable Swagger 2.0 references.
"""
from collections.abc import Mapping
from collections.abc import Iterable, Mapping
from enum import StrEnum
from typing import Any, Literal, NotRequired, TypedDict
from typing import Any, Literal, NotRequired, Protocol, TypedDict
from flask import request
from flask_restx import Namespace
from pydantic import BaseModel, TypeAdapter
@ -35,6 +36,14 @@ QueryParamDoc = TypedDict(
},
)
JsonResponseWithStatus = tuple[dict[str, Any], int]
class QueryArgs(Protocol):
def to_dict(self, flat: bool = True) -> dict[str, str]: ...
def getlist(self, key: str) -> list[str]: ...
def _register_json_schema(namespace: Namespace, name: str, schema: dict) -> None:
"""Register a JSON schema and promote any nested Pydantic `$defs`."""
@ -167,6 +176,58 @@ def query_params_from_model(model: type[BaseModel]) -> dict[str, QueryParamDoc]:
return params
def query_params_from_request[ModelT: BaseModel](
model: type[ModelT],
*,
list_fields: Iterable[str] = (),
args: QueryArgs | None = None,
use_defaults_for_malformed_ints: bool = False,
) -> ModelT:
"""Validate query args with Pydantic while preserving Flask query parsing behavior.
Repeated params need explicit ``getlist()`` handling because Werkzeug's
``to_dict()`` keeps only one value. For malformed scalar integers, Flask's
For endpoints migrated from ``request.args.get(..., type=int, default=...)``,
set ``use_defaults_for_malformed_ints`` to preserve Flask's fallback to
defaults for malformed optional integer params.
"""
query_args = args or request.args
params: dict[str, Any] = query_args.to_dict()
for field_name in list_fields:
params[field_name] = query_args.getlist(field_name)
if use_defaults_for_malformed_ints:
_drop_malformed_defaulted_integer_params(model, params)
return model.model_validate(params)
def _drop_malformed_defaulted_integer_params(model: type[BaseModel], params: dict[str, Any]) -> None:
properties = model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0).get("properties", {})
if not isinstance(properties, Mapping):
return
for name, value in list(params.items()):
if not isinstance(value, str):
continue
field = model.model_fields.get(name)
if field is None or field.is_required():
continue
property_schema = properties.get(name)
if not isinstance(property_schema, Mapping):
continue
if _nullable_property_schema(property_schema).get("type") != "integer":
continue
try:
int(value)
except ValueError:
params.pop(name)
def _query_param_from_property(property_schema: Mapping[str, Any], *, required: bool) -> QueryParamDoc:
param_schema = _nullable_property_schema(property_schema)
param_doc: QueryParamDoc = {"in": "query", "required": required}
@ -239,6 +300,7 @@ __all__ = [
"DEFAULT_REF_TEMPLATE_SWAGGER_2_0",
"get_or_create_model",
"query_params_from_model",
"query_params_from_request",
"register_enum_models",
"register_response_schema_model",
"register_response_schema_models",

View File

@ -51,6 +51,9 @@ from .agent import roster as agent_roster
from .app import (
advanced_prompt_template,
agent,
agent_app_access,
agent_app_feature,
agent_app_workspace,
annotation,
app,
audio,
@ -68,6 +71,7 @@ from .app import (
workflow_app_log,
workflow_comment,
workflow_draft_variable,
workflow_node_output_inspector,
workflow_run,
workflow_statistic,
workflow_trigger,
@ -145,6 +149,9 @@ __all__ = [
"activate",
"advanced_prompt_template",
"agent",
"agent_app_access",
"agent_app_feature",
"agent_app_workspace",
"agent_composer",
"agent_providers",
"agent_roster",
@ -218,6 +225,7 @@ __all__ = [
"workflow_app_log",
"workflow_comment",
"workflow_draft_variable",
"workflow_node_output_inspector",
"workflow_run",
"workflow_statistic",
"workflow_trigger",

View File

@ -1,153 +1,229 @@
from flask_restx import Resource
from controllers.common.schema import register_schema_models
from controllers.common.schema import register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from libs.login import current_account_with_tenant, login_required
from models.model import AppMode
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
with_current_tenant_id,
with_current_user_id,
)
from fields.agent_fields import (
AgentAppComposerResponse,
AgentComposerCandidatesResponse,
AgentComposerImpactResponse,
AgentComposerValidateResponse,
WorkflowAgentComposerResponse,
)
from libs.helper import dump_response
from libs.login import login_required
from models.model import App, AppMode
from services.agent.composer_service import AgentComposerService
from services.agent.composer_validator import ComposerConfigValidator
from services.entities.agent_entities import ComposerSavePayload
register_schema_models(console_ns, ComposerSavePayload)
register_response_schema_models(
console_ns,
AgentAppComposerResponse,
AgentComposerCandidatesResponse,
AgentComposerImpactResponse,
AgentComposerValidateResponse,
WorkflowAgentComposerResponse,
)
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/agent-composer")
class WorkflowAgentComposerApi(Resource):
@console_ns.response(
200, "Workflow agent composer state", console_ns.models[WorkflowAgentComposerResponse.__name__]
)
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
def get(self, app_model, node_id: str):
_, tenant_id = current_account_with_tenant()
return AgentComposerService.load_workflow_composer(
tenant_id=tenant_id,
app_id=app_model.id,
node_id=node_id,
@with_current_tenant_id
def get(self, tenant_id: str, app_model: App, node_id: str):
return dump_response(
WorkflowAgentComposerResponse,
AgentComposerService.load_workflow_composer(
tenant_id=tenant_id,
app_id=app_model.id,
node_id=node_id,
),
)
@console_ns.expect(console_ns.models[ComposerSavePayload.__name__])
@console_ns.response(
200, "Workflow agent composer saved", console_ns.models[WorkflowAgentComposerResponse.__name__]
)
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
def put(self, app_model, node_id: str):
account, tenant_id = current_account_with_tenant()
@with_current_user_id
@with_current_tenant_id
def put(self, tenant_id: str, account_id: str, app_model: App, node_id: str):
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
return AgentComposerService.save_workflow_composer(
tenant_id=tenant_id,
app_id=app_model.id,
node_id=node_id,
account_id=account.id,
payload=payload,
return dump_response(
WorkflowAgentComposerResponse,
AgentComposerService.save_workflow_composer(
tenant_id=tenant_id,
app_id=app_model.id,
node_id=node_id,
account_id=account_id,
payload=payload,
),
)
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/agent-composer/validate")
class WorkflowAgentComposerValidateApi(Resource):
@console_ns.expect(console_ns.models[ComposerSavePayload.__name__])
@console_ns.response(
200, "Workflow agent composer validation result", console_ns.models[AgentComposerValidateResponse.__name__]
)
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
def post(self, app_model, node_id: str):
def post(self, app_model: App, node_id: str):
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
ComposerConfigValidator.validate_save_payload(payload)
return {"result": "success", "errors": []}
return dump_response(AgentComposerValidateResponse, {"result": "success", "errors": []})
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/agent-composer/candidates")
class WorkflowAgentComposerCandidatesApi(Resource):
@console_ns.response(
200, "Workflow agent composer candidates", console_ns.models[AgentComposerCandidatesResponse.__name__]
)
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
def get(self, app_model, node_id: str):
return AgentComposerService.get_workflow_candidates(app_id=app_model.id)
def get(self, app_model: App, node_id: str):
return dump_response(
AgentComposerCandidatesResponse,
AgentComposerService.get_workflow_candidates(app_id=app_model.id),
)
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/agent-composer/impact")
class WorkflowAgentComposerImpactApi(Resource):
@console_ns.expect(console_ns.models[ComposerSavePayload.__name__])
@console_ns.response(200, "Workflow agent composer impact", console_ns.models[AgentComposerImpactResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
def post(self, app_model, node_id: str):
_, tenant_id = current_account_with_tenant()
@with_current_tenant_id
def post(self, tenant_id: str, app_model: App, node_id: str):
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
current_snapshot_id = payload.binding.current_snapshot_id if payload.binding else None
if not current_snapshot_id:
return {"current_snapshot_id": None, "workflow_node_count": 0, "bindings": []}
return AgentComposerService.calculate_impact(tenant_id=tenant_id, current_snapshot_id=current_snapshot_id)
return dump_response(
AgentComposerImpactResponse, {"current_snapshot_id": None, "workflow_node_count": 0, "bindings": []}
)
return dump_response(
AgentComposerImpactResponse,
AgentComposerService.calculate_impact(tenant_id=tenant_id, current_snapshot_id=current_snapshot_id),
)
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/agent-composer/save-to-roster")
class WorkflowAgentComposerSaveToRosterApi(Resource):
@console_ns.expect(console_ns.models[ComposerSavePayload.__name__])
@console_ns.response(
200, "Workflow agent composer saved to roster", console_ns.models[WorkflowAgentComposerResponse.__name__]
)
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
def post(self, app_model, node_id: str):
account, tenant_id = current_account_with_tenant()
@with_current_user_id
@with_current_tenant_id
def post(self, tenant_id: str, account_id: str, app_model: App, node_id: str):
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
return AgentComposerService.save_workflow_composer(
tenant_id=tenant_id,
app_id=app_model.id,
node_id=node_id,
account_id=account.id,
payload=payload,
return dump_response(
WorkflowAgentComposerResponse,
AgentComposerService.save_workflow_composer(
tenant_id=tenant_id,
app_id=app_model.id,
node_id=node_id,
account_id=account_id,
payload=payload,
),
)
@console_ns.route("/apps/<uuid:app_id>/agent-composer")
class AgentAppComposerApi(Resource):
@console_ns.response(200, "Agent app composer state", console_ns.models[AgentAppComposerResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@get_app_model()
def get(self, app_model):
_, tenant_id = current_account_with_tenant()
return AgentComposerService.load_agent_app_composer(tenant_id=tenant_id, app_id=app_model.id)
@with_current_tenant_id
def get(self, tenant_id: str, app_model: App):
return dump_response(
AgentAppComposerResponse,
AgentComposerService.load_agent_app_composer(tenant_id=tenant_id, app_id=app_model.id),
)
@console_ns.expect(console_ns.models[ComposerSavePayload.__name__])
@console_ns.response(200, "Agent app composer saved", console_ns.models[AgentAppComposerResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
@get_app_model()
def put(self, app_model):
account, tenant_id = current_account_with_tenant()
@with_current_user_id
@with_current_tenant_id
def put(self, tenant_id: str, account_id: str, app_model: App):
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
return AgentComposerService.save_agent_app_composer(
tenant_id=tenant_id,
app_id=app_model.id,
account_id=account.id,
payload=payload,
return dump_response(
AgentAppComposerResponse,
AgentComposerService.save_agent_app_composer(
tenant_id=tenant_id,
app_id=app_model.id,
account_id=account_id,
payload=payload,
),
)
@console_ns.route("/apps/<uuid:app_id>/agent-composer/validate")
class AgentAppComposerValidateApi(Resource):
@console_ns.expect(console_ns.models[ComposerSavePayload.__name__])
@console_ns.response(
200, "Agent app composer validation result", console_ns.models[AgentComposerValidateResponse.__name__]
)
@setup_required
@login_required
@account_initialization_required
@get_app_model()
def post(self, app_model):
def post(self, app_model: App):
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
ComposerConfigValidator.validate_save_payload(payload)
return {"result": "success", "errors": []}
return dump_response(AgentComposerValidateResponse, {"result": "success", "errors": []})
@console_ns.route("/apps/<uuid:app_id>/agent-composer/candidates")
class AgentAppComposerCandidatesApi(Resource):
@console_ns.response(
200, "Agent app composer candidates", console_ns.models[AgentComposerCandidatesResponse.__name__]
)
@setup_required
@login_required
@account_initialization_required
@get_app_model()
def get(self, app_model):
return AgentComposerService.get_agent_app_candidates(app_id=app_model.id)
def get(self, app_model: App):
return dump_response(
AgentComposerCandidatesResponse,
AgentComposerService.get_agent_app_candidates(app_id=app_model.id),
)

View File

@ -4,11 +4,25 @@ from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field
from controllers.common.schema import register_schema_models
from controllers.common.schema import query_params_from_model, register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
with_current_tenant_id,
with_current_user_id,
)
from extensions.ext_database import db
from libs.login import current_account_with_tenant, login_required
from fields.agent_fields import (
AgentConfigSnapshotDetailResponse,
AgentConfigSnapshotListResponse,
AgentInviteOptionsResponse,
AgentRosterListResponse,
AgentRosterResponse,
)
from libs.helper import dump_response
from libs.login import login_required
from services.agent.roster_service import AgentRosterService
from services.entities.agent_entities import RosterAgentCreatePayload, RosterAgentUpdatePayload, RosterListQuery
@ -29,6 +43,14 @@ register_schema_models(
RosterAgentUpdatePayload,
RosterListQuery,
)
register_response_schema_models(
console_ns,
AgentConfigSnapshotDetailResponse,
AgentConfigSnapshotListResponse,
AgentInviteOptionsResponse,
AgentRosterListResponse,
AgentRosterResponse,
)
def _agent_roster_service() -> AgentRosterService:
@ -37,96 +59,130 @@ def _agent_roster_service() -> AgentRosterService:
@console_ns.route("/agents")
class AgentRosterListApi(Resource):
@console_ns.doc(params=query_params_from_model(RosterListQuery))
@console_ns.response(200, "Agent roster list", console_ns.models[AgentRosterListResponse.__name__])
@setup_required
@login_required
@account_initialization_required
def get(self):
_, tenant_id = current_account_with_tenant()
@with_current_tenant_id
def get(self, tenant_id: str):
query = RosterListQuery.model_validate(request.args.to_dict(flat=True))
return _agent_roster_service().list_roster_agents(
tenant_id=tenant_id, page=query.page, limit=query.limit, keyword=query.keyword
return dump_response(
AgentRosterListResponse,
_agent_roster_service().list_roster_agents(
tenant_id=tenant_id, page=query.page, limit=query.limit, keyword=query.keyword
),
)
@console_ns.expect(console_ns.models[RosterAgentCreatePayload.__name__])
@console_ns.response(201, "Agent created", console_ns.models[AgentRosterResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def post(self):
account, tenant_id = current_account_with_tenant()
@with_current_user_id
@with_current_tenant_id
def post(self, tenant_id: str, account_id: str):
payload = RosterAgentCreatePayload.model_validate(console_ns.payload or {})
service = _agent_roster_service()
agent = service.create_roster_agent(tenant_id=tenant_id, account_id=account.id, payload=payload)
return service.get_roster_agent_detail(tenant_id=tenant_id, agent_id=agent.id), 201
agent = service.create_roster_agent(tenant_id=tenant_id, account_id=account_id, payload=payload)
return dump_response(
AgentRosterResponse,
service.get_roster_agent_detail(tenant_id=tenant_id, agent_id=agent.id),
), 201
@console_ns.route("/agents/invite-options")
class AgentInviteOptionsApi(Resource):
@console_ns.doc(params=query_params_from_model(AgentInviteOptionsQuery))
@console_ns.response(200, "Agent invite options", console_ns.models[AgentInviteOptionsResponse.__name__])
@setup_required
@login_required
@account_initialization_required
def get(self):
_, tenant_id = current_account_with_tenant()
@with_current_tenant_id
def get(self, tenant_id: str):
query = AgentInviteOptionsQuery.model_validate(request.args.to_dict(flat=True))
return _agent_roster_service().list_invite_options(
tenant_id=tenant_id,
page=query.page,
limit=query.limit,
keyword=query.keyword,
app_id=query.app_id,
return dump_response(
AgentInviteOptionsResponse,
_agent_roster_service().list_invite_options(
tenant_id=tenant_id,
page=query.page,
limit=query.limit,
keyword=query.keyword,
app_id=query.app_id,
),
)
@console_ns.route("/agents/<uuid:agent_id>")
class AgentRosterDetailApi(Resource):
@console_ns.response(200, "Agent detail", console_ns.models[AgentRosterResponse.__name__])
@setup_required
@login_required
@account_initialization_required
def get(self, agent_id: UUID):
_, tenant_id = current_account_with_tenant()
return _agent_roster_service().get_roster_agent_detail(tenant_id=tenant_id, agent_id=str(agent_id))
@console_ns.expect(console_ns.models[RosterAgentUpdatePayload.__name__])
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def patch(self, agent_id: UUID):
account, tenant_id = current_account_with_tenant()
payload = RosterAgentUpdatePayload.model_validate(console_ns.payload or {})
return _agent_roster_service().update_roster_agent(
tenant_id=tenant_id, agent_id=str(agent_id), account_id=account.id, payload=payload
@with_current_tenant_id
def get(self, tenant_id: str, agent_id: UUID):
return dump_response(
AgentRosterResponse,
_agent_roster_service().get_roster_agent_detail(tenant_id=tenant_id, agent_id=str(agent_id)),
)
@console_ns.expect(console_ns.models[RosterAgentUpdatePayload.__name__])
@console_ns.response(200, "Agent updated", console_ns.models[AgentRosterResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def delete(self, agent_id: UUID):
account, tenant_id = current_account_with_tenant()
_agent_roster_service().archive_roster_agent(tenant_id=tenant_id, agent_id=str(agent_id), account_id=account.id)
@with_current_user_id
@with_current_tenant_id
def patch(self, tenant_id: str, account_id: str, agent_id: UUID):
payload = RosterAgentUpdatePayload.model_validate(console_ns.payload or {})
return dump_response(
AgentRosterResponse,
_agent_roster_service().update_roster_agent(
tenant_id=tenant_id, agent_id=str(agent_id), account_id=account_id, payload=payload
),
)
@console_ns.response(204, "Agent archived")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
@with_current_user_id
@with_current_tenant_id
def delete(self, tenant_id: str, account_id: str, agent_id: UUID):
_agent_roster_service().archive_roster_agent(tenant_id=tenant_id, agent_id=str(agent_id), account_id=account_id)
return "", 204
@console_ns.route("/agents/<uuid:agent_id>/versions")
class AgentRosterVersionsApi(Resource):
@console_ns.response(200, "Agent versions", console_ns.models[AgentConfigSnapshotListResponse.__name__])
@setup_required
@login_required
@account_initialization_required
def get(self, agent_id: UUID):
_, tenant_id = current_account_with_tenant()
return {"data": _agent_roster_service().list_agent_versions(tenant_id=tenant_id, agent_id=str(agent_id))}
@with_current_tenant_id
def get(self, tenant_id: str, agent_id: UUID):
return dump_response(
AgentConfigSnapshotListResponse,
{"data": _agent_roster_service().list_agent_versions(tenant_id=tenant_id, agent_id=str(agent_id))},
)
@console_ns.route("/agents/<uuid:agent_id>/versions/<uuid:version_id>")
class AgentRosterVersionDetailApi(Resource):
@console_ns.response(200, "Agent version detail", console_ns.models[AgentConfigSnapshotDetailResponse.__name__])
@setup_required
@login_required
@account_initialization_required
def get(self, agent_id: UUID, version_id: UUID):
_, tenant_id = current_account_with_tenant()
return _agent_roster_service().get_agent_version_detail(
tenant_id=tenant_id,
agent_id=str(agent_id),
version_id=str(version_id),
@with_current_tenant_id
def get(self, tenant_id: str, agent_id: UUID, version_id: UUID):
return dump_response(
AgentConfigSnapshotDetailResponse,
_agent_roster_service().get_agent_version_detail(
tenant_id=tenant_id,
agent_id=str(agent_id),
version_id=str(version_id),
),
)

View File

@ -9,18 +9,25 @@ from sqlalchemy import delete, func, select
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import Forbidden
from controllers.common.schema import register_schema_models
from controllers.common.schema import register_response_schema_models
from extensions.ext_database import db
from fields.base import ResponseModel
from libs.helper import to_timestamp
from libs.login import current_account_with_tenant, login_required
from libs.helper import dump_response, to_timestamp
from libs.login import login_required
from models import Account
from models.dataset import Dataset
from models.enums import ApiTokenType
from models.model import ApiToken, App
from services.api_token_service import ApiTokenCache
from . import console_ns
from .wraps import account_initialization_required, edit_permission_required, setup_required
from .wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
with_current_tenant_id,
with_current_user,
)
class ApiKeyItem(ResponseModel):
@ -40,7 +47,7 @@ class ApiKeyList(ResponseModel):
data: list[ApiKeyItem]
register_schema_models(console_ns, ApiKeyItem, ApiKeyList)
register_response_schema_models(console_ns, ApiKeyItem, ApiKeyList)
def _get_resource(resource_id, tenant_id, resource_model):
@ -64,10 +71,11 @@ class BaseApiKeyListResource(Resource):
token_prefix: str | None = None
max_keys = 10
def get(self, resource_id):
def get(self, resource_id: str, current_tenant_id: str) -> dict[str, object]:
return dump_response(ApiKeyList, self._get_api_key_list(resource_id, current_tenant_id))
def _get_api_key_list(self, resource_id: str, current_tenant_id: str) -> ApiKeyList:
assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id)
_, current_tenant_id = current_account_with_tenant()
_get_resource(resource_id, current_tenant_id, self.resource_model)
keys = db.session.scalars(
@ -75,13 +83,14 @@ class BaseApiKeyListResource(Resource):
ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id
)
).all()
return ApiKeyList.model_validate({"data": keys}, from_attributes=True).model_dump(mode="json")
return ApiKeyList.model_validate({"data": keys}, from_attributes=True)
@edit_permission_required
def post(self, resource_id):
def post(self, resource_id: str, current_tenant_id: str) -> tuple[dict[str, object], int]:
return dump_response(ApiKeyItem, self._create_api_key(resource_id, current_tenant_id)), 201
def _create_api_key(self, resource_id: str, current_tenant_id: str) -> ApiToken:
assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id)
_, current_tenant_id = current_account_with_tenant()
_get_resource(resource_id, current_tenant_id, self.resource_model)
current_key_count: int = (
db.session.scalar(
@ -108,7 +117,7 @@ class BaseApiKeyListResource(Resource):
api_token.type = self.resource_type
db.session.add(api_token)
db.session.commit()
return ApiKeyItem.model_validate(api_token, from_attributes=True).model_dump(mode="json"), 201
return api_token
class BaseApiKeyResource(Resource):
@ -118,9 +127,20 @@ class BaseApiKeyResource(Resource):
resource_model: type | None = None
resource_id_field: str | None = None
def delete(self, resource_id: str, api_key_id: str):
def delete(
self, resource_id: str, api_key_id: str, current_tenant_id: str, current_user: Account
) -> tuple[str, int]:
self._delete_api_key(resource_id, api_key_id, current_tenant_id, current_user)
return "", 204
def _delete_api_key(
self,
resource_id: str,
api_key_id: str,
current_tenant_id: str,
current_user: Account,
) -> None:
assert self.resource_id_field is not None, "resource_id_field must be set"
current_user, current_tenant_id = current_account_with_tenant()
_get_resource(resource_id, current_tenant_id, self.resource_model)
if not current_user.is_admin_or_owner:
@ -147,8 +167,6 @@ class BaseApiKeyResource(Resource):
db.session.execute(delete(ApiToken).where(ApiToken.id == api_key_id))
db.session.commit()
return "", 204
@console_ns.route("/apps/<uuid:resource_id>/api-keys")
class AppApiKeyListResource(BaseApiKeyListResource):
@ -156,18 +174,21 @@ class AppApiKeyListResource(BaseApiKeyListResource):
@console_ns.doc(description="Get all API keys for an app")
@console_ns.doc(params={"resource_id": "App ID"})
@console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__])
def get(self, resource_id: UUID):
@with_current_tenant_id
def get(self, current_tenant_id: str, resource_id: UUID) -> dict[str, object]:
"""Get all API keys for an app"""
return super().get(resource_id)
return dump_response(ApiKeyList, self._get_api_key_list(str(resource_id), current_tenant_id))
@console_ns.doc("create_app_api_key")
@console_ns.doc(description="Create a new API key for an app")
@console_ns.doc(params={"resource_id": "App ID"})
@console_ns.response(201, "API key created successfully", console_ns.models[ApiKeyItem.__name__])
@console_ns.response(400, "Maximum keys exceeded")
def post(self, resource_id: UUID):
@with_current_tenant_id
@edit_permission_required
def post(self, current_tenant_id: str, resource_id: UUID) -> tuple[dict[str, object], int]:
"""Create a new API key for an app"""
return super().post(resource_id)
return dump_response(ApiKeyItem, self._create_api_key(str(resource_id), current_tenant_id)), 201
resource_type = ApiTokenType.APP
resource_model = App
@ -181,9 +202,14 @@ class AppApiKeyResource(BaseApiKeyResource):
@console_ns.doc(description="Delete an API key for an app")
@console_ns.doc(params={"resource_id": "App ID", "api_key_id": "API key ID"})
@console_ns.response(204, "API key deleted successfully")
def delete(self, resource_id: UUID, api_key_id: UUID):
@with_current_user
@with_current_tenant_id
def delete(
self, current_tenant_id: str, current_user: Account, resource_id: UUID, api_key_id: UUID
) -> tuple[str, int]:
"""Delete an API key for an app"""
return super().delete(str(resource_id), str(api_key_id))
self._delete_api_key(str(resource_id), str(api_key_id), current_tenant_id, current_user)
return "", 204
resource_type = ApiTokenType.APP
resource_model = App
@ -196,18 +222,21 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
@console_ns.doc(description="Get all API keys for a dataset")
@console_ns.doc(params={"resource_id": "Dataset ID"})
@console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__])
def get(self, resource_id: UUID):
@with_current_tenant_id
def get(self, current_tenant_id: str, resource_id: UUID) -> dict[str, object]:
"""Get all API keys for a dataset"""
return super().get(resource_id)
return dump_response(ApiKeyList, self._get_api_key_list(str(resource_id), current_tenant_id))
@console_ns.doc("create_dataset_api_key")
@console_ns.doc(description="Create a new API key for a dataset")
@console_ns.doc(params={"resource_id": "Dataset ID"})
@console_ns.response(201, "API key created successfully", console_ns.models[ApiKeyItem.__name__])
@console_ns.response(400, "Maximum keys exceeded")
def post(self, resource_id: UUID):
@with_current_tenant_id
@edit_permission_required
def post(self, current_tenant_id: str, resource_id: UUID) -> tuple[dict[str, object], int]:
"""Create a new API key for a dataset"""
return super().post(resource_id)
return dump_response(ApiKeyItem, self._create_api_key(str(resource_id), current_tenant_id)), 201
resource_type = ApiTokenType.DATASET
resource_model = Dataset
@ -221,9 +250,14 @@ class DatasetApiKeyResource(BaseApiKeyResource):
@console_ns.doc(description="Delete an API key for a dataset")
@console_ns.doc(params={"resource_id": "Dataset ID", "api_key_id": "API key ID"})
@console_ns.response(204, "API key deleted successfully")
def delete(self, resource_id: UUID, api_key_id: UUID):
@with_current_user
@with_current_tenant_id
def delete(
self, current_tenant_id: str, current_user: Account, resource_id: UUID, api_key_id: UUID
) -> tuple[str, int]:
"""Delete an API key for a dataset"""
return super().delete(str(resource_id), str(api_key_id))
self._delete_api_key(str(resource_id), str(api_key_id), current_tenant_id, current_user)
return "", 204
resource_type = ApiTokenType.DATASET
resource_model = Dataset

View File

@ -8,7 +8,7 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from libs.helper import uuid_value
from libs.login import login_required
from models.model import AppMode
from models.model import App, AppMode
from services.agent_service import AgentService
@ -39,7 +39,7 @@ class AgentLogApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.AGENT_CHAT])
def get(self, app_model):
def get(self, app_model: App):
"""Get agent logs"""
args = AgentLogQuery.model_validate(request.args.to_dict(flat=True))

View File

@ -0,0 +1,59 @@
"""Agent App access & sharing endpoints (read-only workflow references).
An Agent App is backed by a roster Agent that workflow Agent nodes may also
reference. This exposes the read-only "Workflow access" surface from the PRD:
which workflow apps use this Agent, without leaking the workflows' internals.
"""
from flask_restx import Resource
from pydantic import Field
from controllers.common.schema import register_response_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required, with_current_tenant_id
from extensions.ext_database import db
from fields.base import ResponseModel
from libs.login import login_required
from models.model import App, AppMode
from services.agent.roster_service import AgentRosterService
class AgentReferencingWorkflowResponse(ResponseModel):
app_id: str
app_name: str
app_mode: str
workflow_id: str
node_ids: list[str] = Field(default_factory=list)
class AgentReferencingWorkflowsResponse(ResponseModel):
data: list[AgentReferencingWorkflowResponse] = Field(default_factory=list)
register_response_schema_models(console_ns, AgentReferencingWorkflowsResponse)
@console_ns.route("/apps/<uuid:app_id>/agent-referencing-workflows")
class AgentAppReferencingWorkflowsResource(Resource):
@console_ns.doc("list_agent_app_referencing_workflows")
@console_ns.doc(description="List workflow apps that reference this Agent App's bound Agent (read-only)")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.response(
200,
"Referencing workflows listed successfully",
console_ns.models[AgentReferencingWorkflowsResponse.__name__],
)
@console_ns.response(404, "App not found")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.AGENT])
@with_current_tenant_id
def get(self, tenant_id: str, app_model: App):
workflows = AgentRosterService(db.session).list_workflows_referencing_app_agent(
tenant_id=tenant_id, app_id=app_model.id
)
return AgentReferencingWorkflowsResponse(
data=[AgentReferencingWorkflowResponse.model_validate(workflow) for workflow in workflows]
).model_dump(mode="json")

View File

@ -0,0 +1,93 @@
"""Agent App presentation-feature configuration endpoint.
The new Agent App type keeps model / prompt / tools in its bound Agent Soul, so
the legacy ``/model-config`` surface (which writes model, prompt and agent tool
config) is the wrong place to configure its app-level presentation features.
This endpoint exposes only the PRD "Misc Legacy" feature subset — conversation
opener, follow-up suggestions, citations, content moderation and speech — and
persists them onto the app's ``app_model_config`` without touching anything the
Soul owns.
"""
from flask_restx import Resource
from pydantic import BaseModel, Field
from controllers.common.fields import SimpleResultResponse
from controllers.common.schema import register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
with_current_user,
)
from events.app_event import app_model_config_was_updated
from libs.helper import dump_response
from libs.login import login_required
from models import Account
from models.agent_config_entities import (
AgentFeatureToggleConfig,
AgentSensitiveWordAvoidanceFeatureConfig,
AgentSuggestedQuestionsAfterAnswerFeatureConfig,
AgentTextToSpeechFeatureConfig,
)
from models.model import App, AppMode
from services.agent_app_feature_service import AgentAppFeatureConfigService
class AgentAppFeaturesPayload(BaseModel):
"""Presentation features configurable on an Agent App.
All fields are optional; an omitted field is reset to its disabled/empty
default (the config form sends the full desired feature state on save).
"""
opening_statement: str | None = Field(default=None, description="Conversation opener shown before the first turn")
suggested_questions: list[str] | None = Field(
default=None, description="Preset questions shown alongside the opener"
)
suggested_questions_after_answer: AgentSuggestedQuestionsAfterAnswerFeatureConfig | None = Field(
default=None, description="Follow-up suggestions config, e.g. {'enabled': true}"
)
speech_to_text: AgentFeatureToggleConfig | None = Field(default=None, description="Speech-to-text config")
text_to_speech: AgentTextToSpeechFeatureConfig | None = Field(default=None, description="Text-to-speech config")
retriever_resource: AgentFeatureToggleConfig | None = Field(
default=None, description="Citations / attributions config, e.g. {'enabled': true}"
)
sensitive_word_avoidance: AgentSensitiveWordAvoidanceFeatureConfig | None = Field(
default=None, description="Content moderation config"
)
register_schema_models(console_ns, AgentAppFeaturesPayload)
register_response_schema_models(console_ns, SimpleResultResponse)
@console_ns.route("/apps/<uuid:app_id>/agent-features")
class AgentAppFeatureConfigResource(Resource):
@console_ns.doc("update_agent_app_features")
@console_ns.doc(description="Update an Agent App's presentation features (opener, follow-up, citations, ...)")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[AgentAppFeaturesPayload.__name__])
@console_ns.response(200, "Features updated successfully", console_ns.models[SimpleResultResponse.__name__])
@console_ns.response(400, "Invalid configuration")
@console_ns.response(404, "App not found")
@setup_required
@login_required
@edit_permission_required
@account_initialization_required
@get_app_model(mode=[AppMode.AGENT])
@with_current_user
def post(self, current_user: Account, app_model: App):
args = AgentAppFeaturesPayload.model_validate(console_ns.payload or {})
new_app_model_config = AgentAppFeatureConfigService.update_features(
app_model=app_model,
account=current_user,
config=args.model_dump(exclude_none=True),
)
app_model_config_was_updated.send(app_model, app_model_config=new_app_model_config)
return dump_response(SimpleResultResponse, {"result": "success"})

View File

@ -0,0 +1,319 @@
"""Agent App sandbox file-system inspector (read-only).
Exposes the PRD "rc1-like sandbox file system, downloadable not editable" view
for an Agent App conversation: list a directory, preview a file, or download a
file from the conversation's shell-layer workspace. The API never touches
shellctl directly — it resolves the conversation's sandbox ``session_id`` from
the stored session snapshot and proxies to the agent backend's read-only
workspace endpoints.
"""
from typing import Literal
from uuid import UUID
from flask import Response
from flask_restx import Resource, fields
from pydantic import BaseModel, Field
from clients.agent_backend.errors import AgentBackendHTTPError, AgentBackendTransportError
from clients.agent_backend.workspace_files_client import WorkspaceDownloadResult
from controllers.common.schema import (
query_params_from_model,
query_params_from_request,
register_response_schema_models,
)
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from fields.base import ResponseModel
from libs.login import current_account_with_tenant, login_required
from models.model import App, AppMode
from services.agent_app_workspace_service import (
AgentAppWorkspaceService,
AgentWorkspaceInspectorError,
WorkflowAgentWorkspaceService,
)
class _WorkspaceFileDownloadField(fields.Raw):
__schema_type__ = "string"
__schema_format__ = "binary"
class AgentWorkspaceListQuery(BaseModel):
conversation_id: str = Field(min_length=1, description="Agent App conversation ID")
path: str = Field(default=".", description="Directory path relative to the sandbox workspace")
class AgentWorkspaceFileQuery(BaseModel):
conversation_id: str = Field(min_length=1, description="Agent App conversation ID")
path: str = Field(min_length=1, description="File path relative to the sandbox workspace")
class WorkflowAgentWorkspaceListQuery(BaseModel):
path: str = Field(default=".", description="Directory path relative to the sandbox workspace")
node_execution_id: str | None = Field(
default=None,
description=(
"Optional workflow node execution ID. When omitted, the latest active session for the node is used."
),
)
class WorkflowAgentWorkspaceFileQuery(BaseModel):
path: str = Field(min_length=1, description="File path relative to the sandbox workspace")
node_execution_id: str | None = Field(
default=None,
description=(
"Optional workflow node execution ID. When omitted, the latest active session for the node is used."
),
)
class WorkspaceFileEntryResponse(ResponseModel):
name: str
type: Literal["file", "dir", "symlink"]
size: int
mtime: int
class WorkspaceListResponse(ResponseModel):
path: str
entries: list[WorkspaceFileEntryResponse] = Field(default_factory=list)
truncated: bool = False
class WorkspacePreviewResponse(ResponseModel):
path: str
size: int
truncated: bool
binary: bool
text: str | None = None
register_response_schema_models(console_ns, WorkspaceListResponse)
register_response_schema_models(console_ns, WorkspacePreviewResponse)
def _handle(exc: Exception) -> tuple[dict[str, object], int]:
if isinstance(exc, AgentWorkspaceInspectorError):
return {"code": exc.code, "message": exc.message}, exc.status_code
if isinstance(exc, AgentBackendHTTPError):
detail = exc.detail
if isinstance(detail, dict):
return {
"code": detail.get("code", "agent_backend_error"),
"message": detail.get("message", str(exc)),
}, exc.status_code
return {"code": "agent_backend_error", "message": str(detail)}, exc.status_code
if isinstance(exc, AgentBackendTransportError):
return {"code": "agent_backend_unreachable", "message": str(exc)}, 502
raise exc
def _download_response(result: WorkspaceDownloadResult) -> Response | tuple[dict[str, object], int]:
if result.truncated:
return {
"code": "workspace_file_too_large",
"message": (
"file exceeds the workspace download limit; use preview for partial text or download a smaller file"
),
"size": result.size,
}, 413
filename = result.path.rsplit("/", 1)[-1] or "download"
return Response(
result.content,
mimetype="application/octet-stream",
headers={
"Content-Disposition": f'attachment; filename="{filename}"',
"Content-Length": str(len(result.content)),
"X-Workspace-File-Size": str(result.size),
},
)
@console_ns.route("/apps/<uuid:app_id>/agent-workspace/files")
class AgentAppWorkspaceListResource(Resource):
@console_ns.doc("list_agent_app_workspace_files")
@console_ns.doc(description="List a directory in an Agent App conversation's sandbox workspace (read-only)")
@console_ns.doc(params={"app_id": "Application ID", **query_params_from_model(AgentWorkspaceListQuery)})
@console_ns.response(200, "Listing returned", console_ns.models[WorkspaceListResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.AGENT])
def get(self, app_model: App):
_, tenant_id = current_account_with_tenant()
query = query_params_from_request(AgentWorkspaceListQuery)
try:
result = AgentAppWorkspaceService().list_files(
tenant_id=tenant_id,
app_id=app_model.id,
conversation_id=query.conversation_id,
path=query.path,
)
except Exception as exc: # normalized to an HTTP response below
return _handle(exc)
return result.model_dump()
@console_ns.route("/apps/<uuid:app_id>/agent-workspace/files/preview")
class AgentAppWorkspacePreviewResource(Resource):
@console_ns.doc("preview_agent_app_workspace_file")
@console_ns.doc(description="Preview a text/binary file in an Agent App conversation's sandbox workspace")
@console_ns.doc(params={"app_id": "Application ID", **query_params_from_model(AgentWorkspaceFileQuery)})
@console_ns.response(200, "Preview returned", console_ns.models[WorkspacePreviewResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.AGENT])
def get(self, app_model: App):
_, tenant_id = current_account_with_tenant()
query = query_params_from_request(AgentWorkspaceFileQuery)
try:
result = AgentAppWorkspaceService().preview(
tenant_id=tenant_id,
app_id=app_model.id,
conversation_id=query.conversation_id,
path=query.path,
)
except Exception as exc: # normalized to an HTTP response below
return _handle(exc)
return result.model_dump()
@console_ns.route("/apps/<uuid:app_id>/agent-workspace/files/download")
class AgentAppWorkspaceDownloadResource(Resource):
@console_ns.doc("download_agent_app_workspace_file")
@console_ns.doc(description="Download a file from an Agent App conversation's sandbox workspace (read-only)")
@console_ns.doc(params={"app_id": "Application ID", **query_params_from_model(AgentWorkspaceFileQuery)})
@console_ns.doc(produces=["application/octet-stream"])
@console_ns.response(200, "File bytes", _WorkspaceFileDownloadField)
@console_ns.response(413, "File exceeds the workspace download limit")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.AGENT])
def get(self, app_model: App):
_, tenant_id = current_account_with_tenant()
query = query_params_from_request(AgentWorkspaceFileQuery)
try:
result = AgentAppWorkspaceService().download(
tenant_id=tenant_id,
app_id=app_model.id,
conversation_id=query.conversation_id,
path=query.path,
)
except Exception as exc: # normalized to an HTTP response below
return _handle(exc)
return _download_response(result)
@console_ns.route(
"/apps/<uuid:app_id>/workflow-runs/<uuid:workflow_run_id>/agent-nodes/<string:node_id>/workspace/files"
)
class WorkflowAgentWorkspaceListResource(Resource):
@console_ns.doc("list_workflow_agent_workspace_files")
@console_ns.doc(description="List a directory in a Workflow Agent node's sandbox workspace (read-only)")
@console_ns.doc(
params={
"app_id": "Application ID",
"workflow_run_id": "Workflow run ID",
"node_id": "Workflow Agent node ID",
**query_params_from_model(WorkflowAgentWorkspaceListQuery),
}
)
@console_ns.response(200, "Listing returned", console_ns.models[WorkspaceListResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def get(self, app_model: App, workflow_run_id: UUID, node_id: str):
_, tenant_id = current_account_with_tenant()
query = query_params_from_request(WorkflowAgentWorkspaceListQuery)
try:
result = WorkflowAgentWorkspaceService().list_files(
tenant_id=tenant_id,
app_id=app_model.id,
workflow_run_id=str(workflow_run_id),
node_id=node_id,
node_execution_id=query.node_execution_id,
path=query.path,
)
except Exception as exc: # normalized to an HTTP response below
return _handle(exc)
return result.model_dump()
@console_ns.route(
"/apps/<uuid:app_id>/workflow-runs/<uuid:workflow_run_id>/agent-nodes/<string:node_id>/workspace/files/preview"
)
class WorkflowAgentWorkspacePreviewResource(Resource):
@console_ns.doc("preview_workflow_agent_workspace_file")
@console_ns.doc(description="Preview a text/binary file in a Workflow Agent node's sandbox workspace")
@console_ns.doc(
params={
"app_id": "Application ID",
"workflow_run_id": "Workflow run ID",
"node_id": "Workflow Agent node ID",
**query_params_from_model(WorkflowAgentWorkspaceFileQuery),
}
)
@console_ns.response(200, "Preview returned", console_ns.models[WorkspacePreviewResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def get(self, app_model: App, workflow_run_id: UUID, node_id: str):
_, tenant_id = current_account_with_tenant()
query = query_params_from_request(WorkflowAgentWorkspaceFileQuery)
try:
result = WorkflowAgentWorkspaceService().preview(
tenant_id=tenant_id,
app_id=app_model.id,
workflow_run_id=str(workflow_run_id),
node_id=node_id,
node_execution_id=query.node_execution_id,
path=query.path,
)
except Exception as exc: # normalized to an HTTP response below
return _handle(exc)
return result.model_dump()
@console_ns.route(
"/apps/<uuid:app_id>/workflow-runs/<uuid:workflow_run_id>/agent-nodes/<string:node_id>/workspace/files/download"
)
class WorkflowAgentWorkspaceDownloadResource(Resource):
@console_ns.doc("download_workflow_agent_workspace_file")
@console_ns.doc(description="Download a file from a Workflow Agent node's sandbox workspace (read-only)")
@console_ns.doc(
params={
"app_id": "Application ID",
"workflow_run_id": "Workflow run ID",
"node_id": "Workflow Agent node ID",
**query_params_from_model(WorkflowAgentWorkspaceFileQuery),
}
)
@console_ns.doc(produces=["application/octet-stream"])
@console_ns.response(200, "File bytes", _WorkspaceFileDownloadField)
@console_ns.response(413, "File exceeds the workspace download limit")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def get(self, app_model: App, workflow_run_id: UUID, node_id: str):
_, tenant_id = current_account_with_tenant()
query = query_params_from_request(WorkflowAgentWorkspaceFileQuery)
try:
result = WorkflowAgentWorkspaceService().download(
tenant_id=tenant_id,
app_id=app_model.id,
workflow_run_id=str(workflow_run_id),
node_id=node_id,
node_execution_id=query.node_execution_id,
path=query.path,
)
except Exception as exc: # normalized to an HTTP response below
return _handle(exc)
return _download_response(result)

View File

@ -16,7 +16,7 @@ from controllers.common.fields import RedirectUrlResponse, SimpleResultResponse
from controllers.common.helpers import FileInfo
from controllers.common.schema import register_enum_models, register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.app.wraps import get_app_model, with_session
from controllers.console.workspace.models import LoadBalancingPayload
from controllers.console.wraps import (
account_initialization_required,
@ -25,8 +25,10 @@ from controllers.console.wraps import (
enterprise_license_required,
is_admin_or_owner_required,
setup_required,
with_current_tenant_id,
with_current_user,
with_current_user_id,
)
from core.db.session_factory import session_factory
from core.ops.ops_trace_manager import OpsTraceManager
from core.rag.entities import PreProcessingRule, Rule, Segmentation
from core.rag.retrieval.retrieval_methods import RetrievalMethod
@ -35,8 +37,8 @@ from extensions.ext_database import db
from fields.base import ResponseModel
from graphon.enums import WorkflowExecutionStatus
from libs.helper import build_icon_url, to_timestamp
from libs.login import current_account_with_tenant, login_required
from models import App, DatasetPermissionEnum, Workflow
from libs.login import login_required
from models import Account, App, DatasetPermissionEnum, Workflow
from models.model import IconType
from services.app_dsl_service import AppDslService
from services.app_service import AppListParams, AppService, CreateAppParams
@ -56,7 +58,7 @@ from services.entities.knowledge_entities.knowledge_entities import (
)
from services.feature_service import FeatureService
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "agent", "advanced-chat", "workflow", "completion"]
register_enum_models(console_ns, IconType)
@ -67,7 +69,7 @@ _TAG_IDS_BRACKET_PATTERN = re.compile(r"^tag_ids\[(\d+)\]$")
class AppListQuery(BaseModel):
page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)")
limit: int = Field(default=20, ge=1, le=100, description="Page size (1-100)")
mode: Literal["completion", "chat", "advanced-chat", "workflow", "agent-chat", "channel", "all"] = Field(
mode: Literal["completion", "chat", "advanced-chat", "workflow", "agent-chat", "agent", "channel", "all"] = Field(
default="all", description="App mode filter"
)
name: str | None = Field(default=None, description="Filter by app name")
@ -116,7 +118,9 @@ def _normalize_app_list_query_args(query_args: MultiDict[str, str]) -> dict[str,
class CreateAppPayload(BaseModel):
name: str = Field(..., min_length=1, description="App name")
description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion"] = Field(..., description="App mode")
mode: Literal["chat", "agent-chat", "agent", "advanced-chat", "workflow", "completion"] = Field(
..., description="App mode"
)
icon_type: IconType | None = Field(default=None, description="Icon type")
icon: str | None = Field(default=None, description="Icon")
icon_background: str | None = Field(default=None, description="Icon background color")
@ -394,6 +398,8 @@ class AppDetailWithSite(AppDetail):
max_active_requests: int | None = None
deleted_tools: list[DeletedTool] = Field(default_factory=list)
site: Site | None = None
# For Agent App type: the roster Agent backing this app (None otherwise).
bound_agent_id: str | None = None
@computed_field(return_type=str | None) # type: ignore
@property
@ -468,10 +474,11 @@ class AppListApi(Resource):
@login_required
@account_initialization_required
@enterprise_license_required
def get(self):
@with_session(write=False)
@with_current_user_id
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user_id: str, session: Session):
"""Get app list"""
current_user, current_tenant_id = current_account_with_tenant()
args = AppListQuery.model_validate(_normalize_app_list_query_args(request.args))
params = AppListParams(
page=args.page,
@ -484,7 +491,7 @@ class AppListApi(Resource):
# get app list
app_service = AppService()
app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, params)
app_pagination = app_service.get_paginate_apps(current_user_id, current_tenant_id, params)
if not app_pagination:
empty = AppPagination(page=args.page, limit=args.limit, total=0, has_more=False, data=[])
return empty.model_dump(mode="json"), 200
@ -505,7 +512,7 @@ class AppListApi(Resource):
draft_trigger_app_ids: set[str] = set()
if workflow_capable_app_ids:
draft_workflows = (
db.session.execute(
session.execute(
select(Workflow).where(
Workflow.version == Workflow.VERSION_DRAFT,
Workflow.app_id.in_(workflow_capable_app_ids),
@ -544,9 +551,10 @@ class AppListApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("apps")
@edit_permission_required
def post(self):
@with_current_user
@with_current_tenant_id
def post(self, current_tenant_id: str, current_user: Account):
"""Create app"""
current_user, current_tenant_id = current_account_with_tenant()
args = CreateAppPayload.model_validate(console_ns.payload)
params = CreateAppParams(
name=args.name,
@ -574,7 +582,7 @@ class AppApi(Resource):
@account_initialization_required
@enterprise_license_required
@get_app_model(mode=None)
def get(self, app_model):
def get(self, app_model: App):
"""Get app detail"""
app_service = AppService()
@ -582,7 +590,7 @@ class AppApi(Resource):
if FeatureService.get_system_features().webapp_auth.enabled:
app_setting = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=str(app_model.id))
app_model.access_mode = app_setting.access_mode
app_model.access_mode = app_setting.access_mode # type: ignore[attr-defined]
response_model = AppDetailWithSite.model_validate(app_model, from_attributes=True)
return response_model.model_dump(mode="json")
@ -599,7 +607,7 @@ class AppApi(Resource):
@account_initialization_required
@get_app_model(mode=None)
@edit_permission_required
def put(self, app_model):
def put(self, app_model: App):
"""Update app"""
args = UpdateAppPayload.model_validate(console_ns.payload)
@ -628,7 +636,7 @@ class AppApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
def delete(self, app_model):
def delete(self, app_model: App):
"""Delete app"""
app_service = AppService()
app_service.delete_app(app_model)
@ -649,11 +657,10 @@ class AppCopyApi(Resource):
@account_initialization_required
@get_app_model(mode=None)
@edit_permission_required
def post(self, app_model):
@with_current_user
def post(self, current_user: Account, app_model: App):
"""Copy app"""
# The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
args = CopyAppPayload.model_validate(console_ns.payload or {})
with Session(db.engine, expire_on_commit=False) as session:
@ -710,7 +717,7 @@ class AppExportApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
def get(self, app_model):
def get(self, app_model: App):
"""Export app"""
args = AppExportQuery.model_validate(request.args.to_dict(flat=True))
@ -732,7 +739,8 @@ class AppPublishToCreatorsPlatformApi(Resource):
@account_initialization_required
@get_app_model(mode=None)
@edit_permission_required
def post(self, app_model):
@with_current_user_id
def post(self, current_user_id: str, app_model: App):
"""Publish app to Creators Platform"""
from configs import dify_config
from core.helper.creators import get_redirect_url, upload_dsl
@ -740,13 +748,11 @@ class AppPublishToCreatorsPlatformApi(Resource):
if not dify_config.CREATORS_PLATFORM_FEATURES_ENABLED:
return {"error": "Creators Platform features are not enabled"}, 403
current_user, _ = current_account_with_tenant()
dsl_content = AppDslService.export_dsl(app_model=app_model, include_secret=False)
dsl_bytes = dsl_content.encode("utf-8")
claim_code = upload_dsl(dsl_bytes)
redirect_url = get_redirect_url(str(current_user.id), claim_code)
redirect_url = get_redirect_url(current_user_id, claim_code)
return {"redirect_url": redirect_url}
@ -763,7 +769,7 @@ class AppNameApi(Resource):
@account_initialization_required
@get_app_model(mode=None)
@edit_permission_required
def post(self, app_model):
def post(self, app_model: App):
args = AppNamePayload.model_validate(console_ns.payload)
app_service = AppService()
@ -785,7 +791,7 @@ class AppIconApi(Resource):
@account_initialization_required
@get_app_model(mode=None)
@edit_permission_required
def post(self, app_model):
def post(self, app_model: App):
args = AppIconPayload.model_validate(console_ns.payload or {})
app_service = AppService()
@ -812,7 +818,7 @@ class AppSiteStatus(Resource):
@account_initialization_required
@get_app_model(mode=None)
@edit_permission_required
def post(self, app_model):
def post(self, app_model: App):
args = AppSiteStatusPayload.model_validate(console_ns.payload)
app_service = AppService()
@ -834,7 +840,7 @@ class AppApiStatus(Resource):
@is_admin_or_owner_required
@account_initialization_required
@get_app_model(mode=None)
def post(self, app_model):
def post(self, app_model: App):
args = AppApiStatusPayload.model_validate(console_ns.payload)
app_service = AppService()
@ -852,11 +858,11 @@ class AppTraceApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_session
@get_app_model
def get(self, app_model):
def get(self, session: Session, app_model: App):
"""Get app trace"""
with session_factory.create_session() as session:
app_trace_config = OpsTraceManager.get_app_tracing_config(app_model.id, session)
app_trace_config = OpsTraceManager.get_app_tracing_config(app_model.id, session)
return app_trace_config
@ -875,7 +881,7 @@ class AppTraceApi(Resource):
@account_initialization_required
@edit_permission_required
@get_app_model
def post(self, app_model):
def post(self, app_model: App):
# add app trace
args = AppTracePayload.model_validate(console_ns.payload)

View File

@ -9,9 +9,11 @@ from controllers.console.wraps import (
cloud_edition_billing_resource_check,
edit_permission_required,
setup_required,
with_current_user,
)
from extensions.ext_database import db
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models.account import Account
from models.model import App
from services.app_dsl_service import AppDslService, Import
from services.enterprise.enterprise_service import EnterpriseService
@ -48,9 +50,9 @@ class AppImportApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("apps")
@edit_permission_required
def post(self):
@with_current_user
def post(self, current_user: Account):
# Check user role first
current_user, _ = current_account_with_tenant()
args = AppImportPayload.model_validate(console_ns.payload)
# AppDslService performs internal commits for some creation paths, so use a plain
@ -97,10 +99,9 @@ class AppImportConfirmApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
def post(self, import_id: str):
@with_current_user
def post(self, current_user: Account, import_id: str):
# Check user role first
current_user, _ = current_account_with_tenant()
with Session(db.engine, expire_on_commit=False) as session:
import_service = AppDslService(session)
# Confirm import

View File

@ -70,7 +70,7 @@ class ChatMessageAudioApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
def post(self, app_model):
def post(self, app_model: App):
file = request.files["file"]
try:
@ -171,7 +171,7 @@ class TextModesApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, app_model):
def get(self, app_model: App):
try:
args = TextToSpeechVoiceQuery.model_validate(request.args.to_dict(flat=True))

View File

@ -4,7 +4,7 @@ from typing import Any, Literal
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import InternalServerError, NotFound
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
import services
from controllers.common.fields import SimpleResultResponse
@ -19,7 +19,12 @@ from controllers.console.app.error import (
ProviderQuotaExceededError,
)
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
with_current_user_id,
)
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import (
@ -33,7 +38,7 @@ from libs import helper
from libs.helper import uuid_value
from libs.login import current_user, login_required
from models import Account
from models.model import AppMode
from models.model import App, AppMode
from services.app_generate_service import AppGenerateService
from services.app_task_service import AppTaskService
from services.errors.llm import InvokeRateLimitError
@ -41,9 +46,24 @@ from services.errors.llm import InvokeRateLimitError
logger = logging.getLogger(__name__)
def _resolve_debugger_chat_streaming(
*, app_mode: AppMode, response_mode: str, response_mode_provided: bool = True
) -> bool:
"""Agent App runtime is SSE-only until backend blocking runs are supported."""
if app_mode != AppMode.AGENT:
return response_mode != "blocking"
if response_mode_provided and response_mode == "blocking":
raise BadRequest("Agent App only supports streaming response mode.")
return True
class BaseMessagePayload(BaseModel):
inputs: dict[str, Any]
model_config_data: dict[str, Any] = Field(..., alias="model_config")
# Agent Apps (AppMode.AGENT) derive their model + prompt from the bound Agent
# Soul, so no override ``model_config`` is sent; chat / agent-chat / completion
# debugging still pass it. Optional here, required in practice by those modes
# downstream when their config is built from args.
model_config_data: dict[str, Any] = Field(default_factory=dict, alias="model_config")
files: list[Any] | None = Field(default=None, description="Uploaded files")
response_mode: Literal["blocking", "streaming"] = Field(default="blocking", description="Response mode")
retriever_from: str = Field(default="dev", description="Retriever source")
@ -84,7 +104,7 @@ class CompletionMessageApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.COMPLETION)
def post(self, app_model):
def post(self, app_model: App):
args_model = CompletionMessagePayload.model_validate(console_ns.payload)
args = args_model.model_dump(exclude_none=True, by_alias=True)
@ -131,14 +151,13 @@ class CompletionMessageStopApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.COMPLETION)
def post(self, app_model, task_id: str):
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
@with_current_user_id
def post(self, current_user_id: str, app_model: App, task_id: str):
AppTaskService.stop_task(
task_id=task_id,
invoke_from=InvokeFrom.DEBUGGER,
user_id=current_user.id,
user_id=current_user_id,
app_mode=AppMode.value_of(app_model.mode),
)
@ -157,13 +176,20 @@ class ChatMessageApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.AGENT])
@edit_permission_required
def post(self, app_model):
args_model = ChatMessagePayload.model_validate(console_ns.payload)
def post(self, app_model: App):
raw_payload = console_ns.payload or {}
args_model = ChatMessagePayload.model_validate(raw_payload)
args = args_model.model_dump(exclude_none=True, by_alias=True)
streaming = args_model.response_mode != "blocking"
streaming = _resolve_debugger_chat_streaming(
app_mode=AppMode.value_of(app_model.mode),
response_mode=args_model.response_mode,
response_mode_provided=isinstance(raw_payload, dict) and "response_mode" in raw_payload,
)
if AppMode.value_of(app_model.mode) == AppMode.AGENT:
args["response_mode"] = "streaming"
args["auto_generate_name"] = False
external_trace_id = get_external_trace_id(request)
@ -211,15 +237,14 @@ class ChatMessageStopApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
def post(self, app_model, task_id: str):
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
@with_current_user_id
def post(self, current_user_id: str, app_model: App, task_id: str):
AppTaskService.stop_task(
task_id=task_id,
invoke_from=InvokeFrom.DEBUGGER,
user_id=current_user.id,
user_id=current_user_id,
app_mode=AppMode.value_of(app_model.mode),
)

View File

@ -12,7 +12,12 @@ from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
with_current_user,
)
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from fields.conversation_fields import (
@ -31,9 +36,10 @@ from fields.conversation_fields import (
ConversationWithSummaryPagination as ConversationWithSummaryPaginationResponse,
)
from libs.datetime_utils import naive_utc_now, parse_time_range
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models import Conversation, EndUser, Message, MessageAnnotation
from models.model import AppMode
from models.account import Account
from models.model import App, AppMode
from services.conversation_service import ConversationService
from services.errors.conversation import ConversationNotExistsError
@ -93,8 +99,8 @@ class CompletionConversationApi(Resource):
@account_initialization_required
@get_app_model(mode=AppMode.COMPLETION)
@edit_permission_required
def get(self, app_model):
current_user, _ = current_account_with_tenant()
@with_current_user
def get(self, current_user: Account, app_model: App):
args = CompletionConversationQuery.model_validate(request.args.to_dict(flat=True))
query = sa.select(Conversation).where(
@ -134,7 +140,7 @@ class CompletionConversationApi(Resource):
.join( # type: ignore
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
)
.distinct()
.group_by(Conversation.id)
)
elif args.annotation_status == "not_annotated":
query = (
@ -165,10 +171,11 @@ class CompletionConversationDetailApi(Resource):
@account_initialization_required
@get_app_model(mode=AppMode.COMPLETION)
@edit_permission_required
def get(self, app_model, conversation_id: UUID):
@with_current_user
def get(self, current_user: Account, app_model: App, conversation_id: UUID):
conversation_id_str = str(conversation_id)
return ConversationMessageDetailResponse.model_validate(
_get_conversation(app_model, conversation_id_str), from_attributes=True
_get_conversation(current_user, app_model, conversation_id_str), from_attributes=True
).model_dump(mode="json")
@console_ns.doc("delete_completion_conversation")
@ -182,8 +189,8 @@ class CompletionConversationDetailApi(Resource):
@account_initialization_required
@get_app_model(mode=AppMode.COMPLETION)
@edit_permission_required
def delete(self, app_model, conversation_id: UUID):
current_user, _ = current_account_with_tenant()
@with_current_user
def delete(self, current_user: Account, app_model: App, conversation_id: UUID):
conversation_id_str = str(conversation_id)
try:
@ -205,10 +212,10 @@ class ChatConversationApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
@edit_permission_required
def get(self, app_model):
current_user, _ = current_account_with_tenant()
@with_current_user
def get(self, current_user: Account, app_model: App):
args = ChatConversationQuery.model_validate(request.args.to_dict(flat=True))
subquery = (
@ -272,7 +279,7 @@ class ChatConversationApi(Resource):
.join( # type: ignore
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
)
.distinct()
.group_by(Conversation.id)
)
case "not_annotated":
query = (
@ -316,12 +323,13 @@ class ChatConversationDetailApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
@edit_permission_required
def get(self, app_model, conversation_id: UUID):
@with_current_user
def get(self, current_user: Account, app_model: App, conversation_id: UUID):
conversation_id_str = str(conversation_id)
return ConversationDetailResponse.model_validate(
_get_conversation(app_model, conversation_id_str), from_attributes=True
_get_conversation(current_user, app_model, conversation_id_str), from_attributes=True
).model_dump(mode="json")
@console_ns.doc("delete_chat_conversation")
@ -332,11 +340,11 @@ class ChatConversationDetailApi(Resource):
@console_ns.response(404, "Conversation not found")
@setup_required
@login_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
@account_initialization_required
@edit_permission_required
def delete(self, app_model, conversation_id: UUID):
current_user, _ = current_account_with_tenant()
@with_current_user
def delete(self, current_user: Account, app_model: App, conversation_id: UUID):
conversation_id_str = str(conversation_id)
try:
@ -347,8 +355,7 @@ class ChatConversationDetailApi(Resource):
return "", 204
def _get_conversation(app_model, conversation_id):
current_user, _ = current_account_with_tenant()
def _get_conversation(current_user: Account, app_model, conversation_id):
conversation = db.session.scalar(
sa.select(Conversation).where(Conversation.id == conversation_id, Conversation.app_id == app_model.id).limit(1)
)

View File

@ -19,7 +19,7 @@ from fields.base import ResponseModel
from libs.helper import to_timestamp
from libs.login import login_required
from models import ConversationVariable
from models.model import AppMode
from models.model import App, AppMode
class ConversationVariablesQuery(BaseModel):
@ -94,7 +94,7 @@ class ConversationVariablesApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.ADVANCED_CHAT)
def get(self, app_model):
def get(self, app_model: App):
args = ConversationVariablesQuery.model_validate(request.args.to_dict(flat=True))
stmt = (

View File

@ -1,7 +1,9 @@
from collections.abc import Sequence
from typing import Literal
from flask_restx import Resource
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from controllers.common.schema import register_enum_models, register_schema_models
from controllers.console import console_ns
@ -11,7 +13,8 @@ from controllers.console.app.error import (
ProviderNotInitializeError,
ProviderQuotaExceededError,
)
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console.app.wraps import with_session
from controllers.console.wraps import account_initialization_required, setup_required, with_current_tenant_id
from core.app.app_config.entities import ModelConfig
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.helper.code_executor.code_node_provider import CodeNodeProvider
@ -19,11 +22,11 @@ from core.helper.code_executor.javascript.javascript_code_provider import Javasc
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload
from core.llm_generator.llm_generator import LLMGenerator
from extensions.ext_database import db
from graphon.model_runtime.entities.llm_entities import LLMMode
from graphon.model_runtime.errors.invoke import InvokeError
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models import App
from services.workflow_generator_service import WorkflowGeneratorService
from services.workflow_service import WorkflowService
@ -41,6 +44,24 @@ class InstructionTemplatePayload(BaseModel):
type: str = Field(..., description="Instruction template type")
class WorkflowGeneratePayload(BaseModel):
"""Payload for the cmd+k `/create` and `/refine` workflow generator endpoint.
See ``services/workflow_generator_service.py`` for behaviour. Errors are
surfaced through the same envelope as ``/rule-generate`` so the frontend
can reuse its existing handler.
"""
mode: Literal["workflow", "advanced-chat"] = Field(..., description="Target app mode for the generated graph")
instruction: str = Field(..., description="Natural-language workflow description")
ideal_output: str = Field(default="", description="Optional sample output for grounding")
model_config_data: ModelConfig = Field(..., alias="model_config", description="Model configuration")
current_graph: dict | None = Field(
default=None,
description="Existing draft graph to refine (cmd+k `/refine`); omit for create-from-scratch",
)
register_enum_models(console_ns, LLMMode)
register_schema_models(
console_ns,
@ -49,6 +70,7 @@ register_schema_models(
RuleStructuredOutputPayload,
InstructionGeneratePayload,
InstructionTemplatePayload,
WorkflowGeneratePayload,
ModelConfig,
)
@ -64,9 +86,9 @@ class RuleGenerateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
@with_current_tenant_id
def post(self, current_tenant_id: str):
args = RuleGeneratePayload.model_validate(console_ns.payload)
_, current_tenant_id = current_account_with_tenant()
try:
rules = LLMGenerator.generate_rule_config(tenant_id=current_tenant_id, args=args)
@ -93,9 +115,9 @@ class RuleCodeGenerateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
@with_current_tenant_id
def post(self, current_tenant_id: str):
args = RuleCodeGeneratePayload.model_validate(console_ns.payload)
_, current_tenant_id = current_account_with_tenant()
try:
code_result = LLMGenerator.generate_code(
@ -125,9 +147,9 @@ class RuleStructuredOutputGenerateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
@with_current_tenant_id
def post(self, current_tenant_id: str):
args = RuleStructuredOutputPayload.model_validate(console_ns.payload)
_, current_tenant_id = current_account_with_tenant()
try:
structured_output = LLMGenerator.generate_structured_output(
@ -157,9 +179,10 @@ class InstructionGenerateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
@with_current_tenant_id
@with_session(write=False)
def post(self, session: Session, current_tenant_id: str):
args = InstructionGeneratePayload.model_validate(console_ns.payload)
_, current_tenant_id = current_account_with_tenant()
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
code_provider: type[CodeNodeProvider] | None = next(
(p for p in providers if p.is_accept_language(args.language)), None
@ -168,10 +191,10 @@ class InstructionGenerateApi(Resource):
try:
# Generate from nothing for a workflow node
if (args.current in (code_template, "")) and args.node_id != "":
app = db.session.get(App, args.flow_id)
app = session.get(App, args.flow_id)
if not app:
return {"error": f"app {args.flow_id} not found"}, 400
workflow = WorkflowService().get_draft_workflow(app_model=app)
workflow = WorkflowService().get_draft_workflow(app_model=app, session=session)
if not workflow:
return {"error": f"workflow {args.flow_id} not found"}, 400
nodes: Sequence = workflow.graph_dict["nodes"]
@ -263,3 +286,56 @@ class InstructionGenerationTemplateApi(Resource):
return {"data": INSTRUCTION_GENERATE_TEMPLATE_CODE}
case _:
raise ValueError(f"Invalid type: {args.type}")
@console_ns.route("/workflow-generate")
class WorkflowGenerateApi(Resource):
"""Generate a Workflow / Chatflow draft graph from a natural-language description.
Triggered by the cmd+k `/create` slash command. Returns a graph payload
shaped exactly like ``WorkflowService.sync_draft_workflow``'s input, so the
frontend can hand it straight to ``/apps/{id}/workflows/draft``.
"""
@console_ns.doc("generate_workflow_graph")
@console_ns.doc(description="Generate a Dify workflow graph from natural language")
@console_ns.expect(console_ns.models[WorkflowGeneratePayload.__name__])
@console_ns.response(200, "Workflow graph generated successfully")
@console_ns.response(400, "Invalid request parameters")
@console_ns.response(402, "Provider quota exceeded")
@setup_required
@login_required
@account_initialization_required
@with_current_tenant_id
def post(self, current_tenant_id: str):
args = WorkflowGeneratePayload.model_validate(console_ns.payload)
# Reject obviously-empty instructions at the boundary — Pydantic only
# validates ``instruction`` is a str, but a whitespace-only string
# would still hit the LLM and waste a planner+builder roundtrip on a
# response that the postprocess validator would reject anyway.
if not args.instruction.strip():
return {
"error": "Instruction is required",
"errors": [{"code": "EMPTY_INSTRUCTION", "detail": "Instruction is required"}],
}, 400
try:
result = WorkflowGeneratorService.generate_workflow_graph(
tenant_id=current_tenant_id,
mode=args.mode,
instruction=args.instruction,
model_config=args.model_config_data,
ideal_output=args.ideal_output,
current_graph=args.current_graph,
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
return result

View File

@ -11,13 +11,18 @@ from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
with_current_tenant_id,
)
from extensions.ext_database import db
from fields.base import ResponseModel
from libs.helper import to_timestamp
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models.enums import AppMCPServerStatus
from models.model import AppMCPServer
from models.model import App, AppMCPServer
class MCPServerCreatePayload(BaseModel):
@ -73,7 +78,7 @@ class AppMCPServerController(Resource):
@account_initialization_required
@setup_required
@get_app_model
def get(self, app_model):
def get(self, app_model: App):
server = db.session.scalar(select(AppMCPServer).where(AppMCPServer.app_id == app_model.id).limit(1))
if server is None:
return {}
@ -92,8 +97,8 @@ class AppMCPServerController(Resource):
@login_required
@setup_required
@edit_permission_required
def post(self, app_model):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def post(self, current_tenant_id: str, app_model: App):
payload = MCPServerCreatePayload.model_validate(console_ns.payload or {})
description = payload.description
@ -127,7 +132,7 @@ class AppMCPServerController(Resource):
@setup_required
@account_initialization_required
@edit_permission_required
def put(self, app_model):
def put(self, app_model: App):
payload = MCPServerUpdatePayload.model_validate(console_ns.payload or {})
server = db.session.get(AppMCPServer, payload.id)
if not server:
@ -163,8 +168,8 @@ class AppMCPServerRefreshController(Resource):
@login_required
@account_initialization_required
@edit_permission_required
def get(self, server_id: UUID):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def get(self, current_tenant_id: str, server_id: UUID):
server = db.session.scalar(
select(AppMCPServer)
.where(AppMCPServer.id == server_id, AppMCPServer.tenant_id == current_tenant_id)

View File

@ -25,6 +25,7 @@ from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
with_current_user,
)
from core.app.entities.app_invoke_entities import InvokeFrom
from core.entities.execution_extra_content import ExecutionExtraContentDomainModel
@ -43,9 +44,10 @@ from fields.conversation_fields import (
from graphon.model_runtime.errors.invoke import InvokeError
from libs.helper import to_timestamp, uuid_value
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models.account import Account
from models.enums import FeedbackFromSource, FeedbackRating
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
from models.model import App, AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
from services.message_service import MessageService, attach_message_extra_contents
@ -178,9 +180,9 @@ class ChatMessageListApi(Resource):
@login_required
@account_initialization_required
@setup_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
@edit_permission_required
def get(self, app_model):
def get(self, app_model: App):
args = ChatMessagesQuery.model_validate(request.args.to_dict())
conversation = db.session.scalar(
@ -257,9 +259,8 @@ class MessageFeedbackApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, app_model):
current_user, _ = current_account_with_tenant()
@with_current_user
def post(self, current_user: Account, app_model: App):
args = MessageFeedbackPayload.model_validate(console_ns.payload)
message_id = str(args.message_id)
@ -314,7 +315,7 @@ class MessageAnnotationCountApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, app_model):
def get(self, app_model: App):
count = db.session.scalar(
select(func.count(MessageAnnotation.id)).where(MessageAnnotation.app_id == app_model.id)
)
@ -336,9 +337,9 @@ class MessageSuggestedQuestionApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
def get(self, app_model, message_id: UUID):
current_user, _ = current_account_with_tenant()
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
@with_current_user
def get(self, current_user: Account, app_model: App, message_id: UUID):
message_id_str = str(message_id)
try:
@ -379,7 +380,7 @@ class MessageFeedbackExportApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, app_model):
def get(self, app_model: App):
args = FeedbackExportQuery.model_validate(request.args.to_dict())
# Import the service function
@ -417,7 +418,7 @@ class MessageApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, app_model, message_id: str):
def get(self, app_model: App, message_id: UUID):
message_id_str = str(message_id)
message = db.session.scalar(

View File

@ -8,15 +8,21 @@ from pydantic import BaseModel, Field
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
with_current_tenant_id,
with_current_user_id,
)
from core.agent.entities import AgentToolEntity
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolParameterConfigurationManager
from events.app_event import app_model_config_was_updated
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from libs.login import current_account_with_tenant, login_required
from models.model import AppMode, AppModelConfig
from libs.login import login_required
from models.model import App, AppMode, AppModelConfig
from services.app_model_config_service import AppModelConfigService
@ -52,9 +58,10 @@ class ModelConfigResource(Resource):
@edit_permission_required
@account_initialization_required
@get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION])
def post(self, app_model):
@with_current_user_id
@with_current_tenant_id
def post(self, current_tenant_id: str, current_user_id: str, app_model: App):
"""Modify app model config"""
current_user, current_tenant_id = current_account_with_tenant()
# validate config
model_configuration = AppModelConfigService.validate_configuration(
tenant_id=current_tenant_id,
@ -64,8 +71,8 @@ class ModelConfigResource(Resource):
new_app_model_config = AppModelConfig(
app_id=app_model.id,
created_by=current_user.id,
updated_by=current_user.id,
created_by=current_user_id,
updated_by=current_user_id,
)
new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration)
@ -90,7 +97,7 @@ class ModelConfigResource(Resource):
tenant_id=current_tenant_id,
app_id=app_model.id,
agent_tool=agent_tool_entity,
user_id=current_user.id,
user_id=current_user_id,
)
manager = ToolParameterConfigurationManager(
tenant_id=current_tenant_id,
@ -130,7 +137,7 @@ class ModelConfigResource(Resource):
tenant_id=current_tenant_id,
app_id=app_model.id,
agent_tool=agent_tool_entity,
user_id=current_user.id,
user_id=current_user_id,
)
except Exception:
continue
@ -167,7 +174,7 @@ class ModelConfigResource(Resource):
db.session.flush()
app_model.app_model_config_id = new_app_model_config.id
app_model.updated_by = current_user.id
app_model.updated_by = current_user_id
app_model.updated_at = naive_utc_now()
db.session.commit()

View File

@ -14,12 +14,15 @@ from controllers.console.wraps import (
edit_permission_required,
is_admin_or_owner_required,
setup_required,
with_current_user,
)
from extensions.ext_database import db
from fields.base import ResponseModel
from libs.datetime_utils import naive_utc_now
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models import Site
from models.account import Account
from models.model import App
class AppSiteUpdatePayload(BaseModel):
@ -84,9 +87,9 @@ class AppSite(Resource):
@edit_permission_required
@account_initialization_required
@get_app_model
def post(self, app_model):
@with_current_user
def post(self, current_user: Account, app_model: App):
args = AppSiteUpdatePayload.model_validate(console_ns.payload or {})
current_user, _ = current_account_with_tenant()
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
if not site:
raise NotFound
@ -133,8 +136,8 @@ class AppSiteAccessTokenReset(Resource):
@is_admin_or_owner_required
@account_initialization_required
@get_app_model
def post(self, app_model):
current_user, _ = current_account_with_tenant()
@with_current_user
def post(self, current_user: Account, app_model: App):
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
if not site:

View File

@ -8,13 +8,15 @@ from pydantic import BaseModel, Field, field_validator
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console.wraps import account_initialization_required, setup_required, with_current_user
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from libs.datetime_utils import parse_time_range
from libs.helper import convert_datetime_to_date
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models import AppMode
from models.account import Account
from models.model import App
class StatisticTimeRangeQuery(BaseModel):
@ -47,9 +49,8 @@ class DailyMessageStatistic(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, app_model):
account, _ = current_account_with_tenant()
@with_current_user
def get(self, account: Account, app_model: App):
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
converted_created_at = convert_datetime_to_date("created_at")
@ -61,8 +62,12 @@ FROM
WHERE
app_id = :app_id
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
arg_dict: dict[str, object] = {
"tz": account.timezone,
"app_id": app_model.id,
"invoke_from": InvokeFrom.DEBUGGER,
}
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
@ -104,9 +109,8 @@ class DailyConversationStatistic(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, app_model):
account, _ = current_account_with_tenant()
@with_current_user
def get(self, account: Account, app_model: App):
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
converted_created_at = convert_datetime_to_date("created_at")
@ -118,8 +122,12 @@ FROM
WHERE
app_id = :app_id
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
arg_dict: dict[str, object] = {
"tz": account.timezone,
"app_id": app_model.id,
"invoke_from": InvokeFrom.DEBUGGER,
}
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
@ -160,9 +168,8 @@ class DailyTerminalsStatistic(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, app_model):
account, _ = current_account_with_tenant()
@with_current_user
def get(self, account: Account, app_model: App):
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
converted_created_at = convert_datetime_to_date("created_at")
@ -174,8 +181,12 @@ FROM
WHERE
app_id = :app_id
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
arg_dict: dict[str, object] = {
"tz": account.timezone,
"app_id": app_model.id,
"invoke_from": InvokeFrom.DEBUGGER,
}
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
@ -217,9 +228,8 @@ class DailyTokenCostStatistic(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, app_model):
account, _ = current_account_with_tenant()
@with_current_user
def get(self, account: Account, app_model: App):
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
converted_created_at = convert_datetime_to_date("created_at")
@ -232,8 +242,12 @@ FROM
WHERE
app_id = :app_id
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
arg_dict: dict[str, object] = {
"tz": account.timezone,
"app_id": app_model.id,
"invoke_from": InvokeFrom.DEBUGGER,
}
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
@ -276,10 +290,9 @@ class AverageSessionInteractionStatistic(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
def get(self, app_model):
account, _ = current_account_with_tenant()
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
@with_current_user
def get(self, account: Account, app_model: App):
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
converted_created_at = convert_datetime_to_date("c.created_at")
@ -299,8 +312,12 @@ FROM
WHERE
c.app_id = :app_id
AND m.invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
arg_dict: dict[str, object] = {
"tz": account.timezone,
"app_id": app_model.id,
"invoke_from": InvokeFrom.DEBUGGER,
}
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
@ -353,9 +370,8 @@ class UserSatisfactionRateStatistic(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, app_model):
account, _ = current_account_with_tenant()
@with_current_user
def get(self, account: Account, app_model: App):
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
converted_created_at = convert_datetime_to_date("m.created_at")
@ -371,8 +387,12 @@ LEFT JOIN
WHERE
m.app_id = :app_id
AND m.invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
arg_dict: dict[str, object] = {
"tz": account.timezone,
"app_id": app_model.id,
"invoke_from": InvokeFrom.DEBUGGER,
}
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
@ -419,9 +439,8 @@ class AverageResponseTimeStatistic(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.COMPLETION)
def get(self, app_model):
account, _ = current_account_with_tenant()
@with_current_user
def get(self, account: Account, app_model: App):
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
converted_created_at = convert_datetime_to_date("created_at")
@ -433,8 +452,12 @@ FROM
WHERE
app_id = :app_id
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
arg_dict: dict[str, object] = {
"tz": account.timezone,
"app_id": app_model.id,
"invoke_from": InvokeFrom.DEBUGGER,
}
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
@ -476,8 +499,8 @@ class TokensPerSecondStatistic(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, app_model):
account, _ = current_account_with_tenant()
@with_current_user
def get(self, account: Account, app_model: App):
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
converted_created_at = convert_datetime_to_date("created_at")
@ -492,8 +515,12 @@ FROM
WHERE
app_id = :app_id
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
arg_dict: dict[str, object] = {
"tz": account.timezone,
"app_id": app_model.id,
"invoke_from": InvokeFrom.DEBUGGER,
}
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)

View File

@ -1,7 +1,8 @@
import logging
from collections.abc import Callable
from functools import wraps
from typing import Any, TypedDict
from typing import Any, Concatenate, TypedDict
from uuid import UUID
from flask import Response, request
from flask_restx import Resource, fields, marshal, marshal_with
@ -82,13 +83,14 @@ def _serialize_var_value(variable: WorkflowDraftVariable):
# create a copy of the value to avoid affecting the model cache.
value = value.model_copy(deep=True)
# Refresh the url signature before returning it to client.
if isinstance(value, FileSegment):
file = value.value
file.remote_url = file.generate_url()
elif isinstance(value, ArrayFileSegment):
files = value.value
for file in files:
match value:
case FileSegment():
file = value.value
file.remote_url = file.generate_url()
case ArrayFileSegment():
files = value.value
for file in files:
file.remote_url = file.generate_url()
return _convert_values_to_json_serializable_object(value)
@ -212,7 +214,9 @@ workflow_draft_variable_list_model = console_ns.model(
)
def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
def _api_prerequisite[T, **P, R](
f: Callable[Concatenate[T, P], R],
) -> Callable[Concatenate[T, P], R | Response]:
"""Common prerequisites for all draft workflow variable APIs.
It ensures the following conditions are satisfied:
@ -229,8 +233,8 @@ def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
@edit_permission_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@wraps(f)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | Response:
return f(*args, **kwargs)
def wrapper(self: T, *args: P.args, **kwargs: P.kwargs) -> R | Response:
return f(self, *args, **kwargs)
return wrapper
@ -345,14 +349,15 @@ class VariableApi(Resource):
@console_ns.response(404, "Variable not found")
@_api_prerequisite
@marshal_with(workflow_draft_variable_model)
def get(self, app_model: App, variable_id: str):
def get(self, app_model: App, variable_id: UUID):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
variable_id_str = str(variable_id)
variable = _ensure_variable_access(
variable=draft_var_srv.get_variable(variable_id=variable_id),
variable=draft_var_srv.get_variable(variable_id=variable_id_str),
app_id=app_model.id,
variable_id=variable_id,
variable_id=variable_id_str,
)
return variable
@ -363,7 +368,7 @@ class VariableApi(Resource):
@console_ns.response(404, "Variable not found")
@_api_prerequisite
@marshal_with(workflow_draft_variable_model)
def patch(self, app_model: App, variable_id: str):
def patch(self, app_model: App, variable_id: UUID):
# Request payload for file types:
#
# Local File:
@ -390,10 +395,11 @@ class VariableApi(Resource):
)
args_model = WorkflowDraftVariableUpdatePayload.model_validate(console_ns.payload or {})
variable_id_str = str(variable_id)
variable = _ensure_variable_access(
variable=draft_var_srv.get_variable(variable_id=variable_id),
variable=draft_var_srv.get_variable(variable_id=variable_id_str),
app_id=app_model.id,
variable_id=variable_id,
variable_id=variable_id_str,
)
new_name = args_model.name
@ -434,14 +440,15 @@ class VariableApi(Resource):
@console_ns.response(204, "Variable deleted successfully")
@console_ns.response(404, "Variable not found")
@_api_prerequisite
def delete(self, app_model: App, variable_id: str):
def delete(self, app_model: App, variable_id: UUID):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
variable_id_str = str(variable_id)
variable = _ensure_variable_access(
variable=draft_var_srv.get_variable(variable_id=variable_id),
variable=draft_var_srv.get_variable(variable_id=variable_id_str),
app_id=app_model.id,
variable_id=variable_id,
variable_id=variable_id_str,
)
draft_var_srv.delete_variable(variable)
db.session.commit()
@ -457,7 +464,7 @@ class VariableResetApi(Resource):
@console_ns.response(204, "Variable reset (no content)")
@console_ns.response(404, "Variable not found")
@_api_prerequisite
def put(self, app_model: App, variable_id: str):
def put(self, app_model: App, variable_id: UUID):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
@ -468,10 +475,11 @@ class VariableResetApi(Resource):
raise NotFoundError(
f"Draft workflow not found, app_id={app_model.id}",
)
variable_id_str = str(variable_id)
variable = _ensure_variable_access(
variable=draft_var_srv.get_variable(variable_id=variable_id),
variable=draft_var_srv.get_variable(variable_id=variable_id_str),
app_id=app_model.id,
variable_id=variable_id,
variable_id=variable_id_str,
)
resetted = draft_var_srv.reset_variable(draft_workflow, variable)

View File

@ -0,0 +1,415 @@
"""Console REST endpoints for the Node Output Inspector (Stage 4 §8 / §10.3).
PRD §Node Output Inspector replaces the consumer-organized Variable Inspector
with a producer-organized view of each node's declared outputs and their
per-run status. This module exposes two parallel sets of three read-only
endpoints — one for ``/workflows/draft/runs/...`` (Composer test runs) and one
for ``/workflows/published/runs/...`` (real App API / webapp / webhook /
schedule / plugin triggers). Both sets share the same service code, the same
response shapes, and the same error codes; the URL is the *only* difference,
so the frontend can pick the right prefix based on which run-detail page the
user is on.
Decision D-1 (published Inspector deferred) was lifted 2026-05-26 — the
``published_run_inspector_not_implemented`` 404 code is therefore no longer
produced.
URLs follow the design doc and reuse the existing
``/apps/<uuid:app_id>/workflows/draft/...`` prefix from
:mod:`controllers.console.app.workflow_draft_variable`. The
``published`` prefix mirrors it shape-for-shape.
"""
from __future__ import annotations
import json
import logging
from collections.abc import Iterator
from uuid import UUID
from flask import Response
from flask_restx import Resource
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from libs.exception import BaseHTTPException
from libs.login import login_required
from models import App, AppMode
from services.workflow import inspector_events
from services.workflow.node_output_inspector_service import (
NodeOutputInspectorError,
NodeOutputInspectorService,
)
logger = logging.getLogger(__name__)
# Heartbeat cadence — every N empty subscribe ticks emit a SSE comment so
# intervening proxies (nginx, ingress) don't reap the idle connection.
# ``inspector_events.subscribe`` ticks at 1s, so 15 → 15s heartbeat.
_HEARTBEAT_EVERY_TICKS = 15
# Hard ceiling on a single stream — if we never see a terminal workflow
# event (engine crashed, redis dropped the message), force-close after this
# many ticks (= seconds).
_STREAM_HARD_TIMEOUT_TICKS = 1800 # 30 min
def _service() -> NodeOutputInspectorService:
"""One-line factory so tests can monkeypatch a stub if needed."""
return NodeOutputInspectorService()
def _serve_snapshot(app_model: App, run_id: UUID) -> dict:
"""Resource-body shared by draft + published snapshot endpoints.
Pulled out so the 6 REST routes don't duplicate the same 6-line try/except
+ ``model_dump`` ritual — the routes shrink to one-liners and the actual
behaviour lives here, where unit tests can hit it without spinning up
Flask request context.
"""
try:
snapshot = _service().snapshot_workflow_run(app_model=app_model, workflow_run_id=str(run_id))
except NodeOutputInspectorError as error:
raise _InspectorNotFound(error) from error
return snapshot.model_dump(mode="json")
def _serve_node_detail(app_model: App, run_id: UUID, node_id: str) -> dict:
"""Resource-body shared by draft + published node-detail endpoints."""
try:
view = _service().node_detail(
app_model=app_model,
workflow_run_id=str(run_id),
node_id=node_id,
)
except NodeOutputInspectorError as error:
raise _InspectorNotFound(error) from error
return view.model_dump(mode="json")
def _serve_output_preview(app_model: App, run_id: UUID, node_id: str, output_name: str) -> dict:
"""Resource-body shared by draft + published output-preview endpoints."""
try:
preview = _service().output_preview(
app_model=app_model,
workflow_run_id=str(run_id),
node_id=node_id,
output_name=output_name,
)
except NodeOutputInspectorError as error:
raise _InspectorNotFound(error) from error
return preview.model_dump(mode="json")
class _InspectorNotFound(BaseHTTPException):
"""404 that preserves the inspector's specific error code.
Without this the response body collapses to a generic ``not_found`` code
and clients lose the ability to distinguish, e.g.,
``workflow_run_not_found`` from ``published_run_inspector_not_implemented``.
"""
code = 404
def __init__(self, error: NodeOutputInspectorError) -> None:
self.error_code = error.code
super().__init__(description=str(error))
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/runs/<uuid:run_id>/node-outputs")
class WorkflowDraftRunNodeOutputsApi(Resource):
"""Whole-run snapshot organized by producer node."""
@console_ns.doc("get_workflow_draft_run_node_outputs")
@console_ns.doc(description="Snapshot of every node's declared outputs for a draft workflow run.")
@console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
@console_ns.response(404, "Workflow run not found")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def get(self, app_model: App, run_id: UUID):
return _serve_snapshot(app_model, run_id)
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/runs/<uuid:run_id>/node-outputs/<string:node_id>")
class WorkflowDraftRunNodeOutputDetailApi(Resource):
"""One node's declared outputs + per-output status."""
@console_ns.doc("get_workflow_draft_run_node_output_detail")
@console_ns.doc(description="One node's declared outputs for a draft workflow run.")
@console_ns.doc(
params={
"app_id": "Application ID",
"run_id": "Workflow run ID",
"node_id": "Node ID inside the workflow graph",
}
)
@console_ns.response(404, "Workflow run / node not found")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def get(self, app_model: App, run_id: UUID, node_id: str):
return _serve_node_detail(app_model, run_id, node_id)
@console_ns.route(
"/apps/<uuid:app_id>/workflows/draft/runs/<uuid:run_id>/node-outputs/<string:node_id>/<string:output_name>/preview"
)
class WorkflowDraftRunNodeOutputPreviewApi(Resource):
"""Full value for one declared output (with signed URL for file refs)."""
@console_ns.doc("get_workflow_draft_run_node_output_preview")
@console_ns.doc(description="Full value for one declared output, including signed download URL for files.")
@console_ns.doc(
params={
"app_id": "Application ID",
"run_id": "Workflow run ID",
"node_id": "Node ID inside the workflow graph",
"output_name": "Declared output name as exposed by Composer",
}
)
@console_ns.response(404, "Workflow run / node / output not found")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def get(self, app_model: App, run_id: UUID, node_id: str, output_name: str):
return _serve_output_preview(app_model, run_id, node_id, output_name)
# ──────────────────────────────────────────────────────────────────────────────
# SSE event stream — shared generator used by draft + published variants
# ──────────────────────────────────────────────────────────────────────────────
def _sse_envelope(event: str, data: dict | str, event_id: int) -> str:
"""Format one SSE record per D-5 ``{event, data, id}`` envelope.
``data`` is JSON-serialized when given as a dict; raw strings are
forwarded unchanged so we can also emit ``:keepalive`` comment lines.
"""
payload = data if isinstance(data, str) else json.dumps(data, ensure_ascii=False)
return f"event: {event}\nid: {event_id}\ndata: {payload}\n\n"
def _stream_inspector_events(app_model: App, run_id: UUID) -> Iterator[str]:
"""Yield SSE-framed strings for one workflow run.
The stream begins with a full ``snapshot`` event so the client has a
starting state without needing a separate REST GET. Then for every
``node_changed`` message from the pub/sub channel we re-read that node
from DB and push a fresh ``node_changed`` event. When the workflow run
reaches a terminal state we push one final ``workflow_run_completed``
event and close the stream.
Failures inside the loop are caught and surfaced as ``error`` events so
the frontend can show a banner rather than seeing the connection drop
silently. The Inspector never raises across the SSE boundary.
"""
service = _service()
run_id_str = str(run_id)
# Initial snapshot — also flushes a 404 back at the client right away
# if the run is gone (raised before yielding any bytes, so Flask turns it
# into the normal HTTP 404 path).
try:
snapshot = service.snapshot_workflow_run(app_model=app_model, workflow_run_id=run_id_str)
except NodeOutputInspectorError as error:
raise _InspectorNotFound(error) from error
event_id = 0
yield _sse_envelope("snapshot", snapshot.model_dump(mode="json"), event_id)
# If the run already finished by the time the client connected, emit
# the terminal envelope synchronously and close — no point subscribing.
# The enum value for partial success is the hyphenated ``partial-succeeded``
# (graphon.enums.WorkflowExecutionStatus), not ``partial_succeeded``.
if snapshot.workflow_run_status.value in {"succeeded", "failed", "stopped", "partial-succeeded"}:
event_id += 1
yield _sse_envelope(
"workflow_run_completed",
{"workflow_run_id": run_id_str, "workflow_run_status": snapshot.workflow_run_status.value},
event_id,
)
return
# Live subscription
ticks_since_heartbeat = 0
total_ticks = 0
for message in inspector_events.subscribe(run_id_str, timeout_seconds=1.0):
total_ticks += 1
if total_ticks > _STREAM_HARD_TIMEOUT_TICKS:
logger.warning(
"Inspector SSE: forcing close after %ds without terminal event for run %s",
_STREAM_HARD_TIMEOUT_TICKS,
run_id_str,
)
return
# Heartbeat sentinel — ``inspector_events.subscribe`` synthesizes a
# ``node_changed`` message with both fields ``None`` on every redis
# timeout. Real ``workflow_completed`` messages keep their kind even
# when status couldn't be resolved (publisher race), so checking kind
# first makes the heartbeat branch safe.
if message.kind == "node_changed" and message.node_id is None and message.status is None:
ticks_since_heartbeat += 1
if ticks_since_heartbeat >= _HEARTBEAT_EVERY_TICKS:
yield ":keepalive\n\n"
ticks_since_heartbeat = 0
continue
ticks_since_heartbeat = 0
if message.kind == "workflow_completed":
event_id += 1
yield _sse_envelope(
"workflow_run_completed",
{"workflow_run_id": run_id_str, "workflow_run_status": message.status or "unknown"},
event_id,
)
return
# node_changed: recompute the node slice from DB
if not message.node_id:
continue
try:
node_view = service.node_detail(
app_model=app_model,
workflow_run_id=run_id_str,
node_id=message.node_id,
)
except NodeOutputInspectorError:
# Node may not appear in the graph yet (race with persistence); skip.
continue
except Exception:
logger.warning(
"Inspector SSE: node_detail failed for run %s node %s",
run_id_str,
message.node_id,
exc_info=True,
)
event_id += 1
yield _sse_envelope(
"error",
{"node_id": message.node_id, "message": "failed to refresh node detail"},
event_id,
)
continue
event_id += 1
yield _sse_envelope("node_changed", node_view.model_dump(mode="json"), event_id)
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/runs/<uuid:run_id>/node-outputs/events")
class WorkflowDraftRunNodeOutputEventsApi(Resource):
"""SSE stream of inspector deltas for a draft run."""
@console_ns.doc("stream_workflow_draft_run_node_output_events")
@console_ns.doc(description="Server-Sent Events stream of inspector deltas for a draft workflow run.")
@console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
@console_ns.response(404, "Workflow run not found")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def get(self, app_model: App, run_id: UUID):
return Response(
_stream_inspector_events(app_model, run_id),
mimetype="text/event-stream",
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
)
# ──────────────────────────────────────────────────────────────────────────────
# Published-run endpoints — symmetric to the draft trio above
# ──────────────────────────────────────────────────────────────────────────────
@console_ns.route("/apps/<uuid:app_id>/workflows/published/runs/<uuid:run_id>/node-outputs")
class WorkflowPublishedRunNodeOutputsApi(Resource):
"""Whole-run snapshot for a *published* workflow run.
Same response shape as the ``/draft/`` variant — frontend can multiplex
based on which page (Composer test-run vs. Run History) is mounted.
"""
@console_ns.doc("get_workflow_published_run_node_outputs")
@console_ns.doc(description="Snapshot of every node's declared outputs for a published workflow run.")
@console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
@console_ns.response(404, "Workflow run not found")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def get(self, app_model: App, run_id: UUID):
return _serve_snapshot(app_model, run_id)
@console_ns.route("/apps/<uuid:app_id>/workflows/published/runs/<uuid:run_id>/node-outputs/<string:node_id>")
class WorkflowPublishedRunNodeOutputDetailApi(Resource):
"""One node's declared outputs + per-output status (published run)."""
@console_ns.doc("get_workflow_published_run_node_output_detail")
@console_ns.doc(description="One node's declared outputs for a published workflow run.")
@console_ns.doc(
params={
"app_id": "Application ID",
"run_id": "Workflow run ID",
"node_id": "Node ID inside the workflow graph",
}
)
@console_ns.response(404, "Workflow run / node not found")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def get(self, app_model: App, run_id: UUID, node_id: str):
return _serve_node_detail(app_model, run_id, node_id)
@console_ns.route(
"/apps/<uuid:app_id>/workflows/published/runs/<uuid:run_id>"
"/node-outputs/<string:node_id>/<string:output_name>/preview"
)
class WorkflowPublishedRunNodeOutputPreviewApi(Resource):
"""Full value for one declared output of a published run."""
@console_ns.doc("get_workflow_published_run_node_output_preview")
@console_ns.doc(description="Full value for one declared output of a published run.")
@console_ns.doc(
params={
"app_id": "Application ID",
"run_id": "Workflow run ID",
"node_id": "Node ID inside the workflow graph",
"output_name": "Declared output name as exposed by Composer",
}
)
@console_ns.response(404, "Workflow run / node / output not found")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def get(self, app_model: App, run_id: UUID, node_id: str, output_name: str):
return _serve_output_preview(app_model, run_id, node_id, output_name)
@console_ns.route("/apps/<uuid:app_id>/workflows/published/runs/<uuid:run_id>/node-outputs/events")
class WorkflowPublishedRunNodeOutputEventsApi(Resource):
"""SSE stream of inspector deltas for a published run."""
@console_ns.doc("stream_workflow_published_run_node_output_events")
@console_ns.doc(description="Server-Sent Events stream of inspector deltas for a published workflow run.")
@console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
@console_ns.response(404, "Workflow run not found")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def get(self, app_model: App, run_id: UUID):
return Response(
_stream_inspector_events(app_model, run_id),
mimetype="text/event-stream",
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
)

View File

@ -189,7 +189,7 @@ class WorkflowRunExportApi(Resource):
@login_required
@account_initialization_required
@get_app_model()
def get(self, app_model: App, run_id: str):
def get(self, app_model: App, run_id: UUID):
tenant_id = str(app_model.tenant_id)
app_id = str(app_model.id)
run_id_str = str(run_id)

View File

@ -6,12 +6,13 @@ from sqlalchemy.orm import sessionmaker
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console.wraps import account_initialization_required, setup_required, with_current_user
from extensions.ext_database import db
from libs.datetime_utils import parse_time_range
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models.account import Account
from models.enums import WorkflowRunTriggeredFrom
from models.model import AppMode
from models.model import App, AppMode
from repositories.factory import DifyAPIRepositoryFactory
@ -46,9 +47,8 @@ class WorkflowDailyRunsStatistic(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, app_model):
account, _ = current_account_with_tenant()
@with_current_user
def get(self, account: Account, app_model: App):
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True))
assert account.timezone is not None
@ -86,9 +86,8 @@ class WorkflowDailyTerminalsStatistic(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, app_model):
account, _ = current_account_with_tenant()
@with_current_user
def get(self, account: Account, app_model: App):
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True))
assert account.timezone is not None
@ -126,9 +125,8 @@ class WorkflowDailyTokenCostStatistic(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, app_model):
account, _ = current_account_with_tenant()
@with_current_user
def get(self, account: Account, app_model: App):
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True))
assert account.timezone is not None
@ -166,9 +164,8 @@ class WorkflowAverageAppInteractionStatistic(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
def get(self, app_model):
account, _ = current_account_with_tenant()
@with_current_user
def get(self, account: Account, app_model: App):
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True))
assert account.timezone is not None

View File

@ -1,16 +1,38 @@
"""Controller decorators for console app resources.
`with_session` opens one SQLAlchemy session for a request handler and injects it
as the first argument after `self`. Handlers use a transaction by default so
migrated write paths keep commit/rollback handling; pure read handlers may opt
out with `write=False`. App-loading decorators prefer that injected session when
present, while still supporting existing handlers that have not been migrated
yet and still rely on Flask-SQLAlchemy's scoped `db.session`.
"""
from collections.abc import Callable
from functools import wraps
from typing import overload
from typing import Concatenate, cast, overload
from sqlalchemy import select
from sqlalchemy.orm import Session
from controllers.console.app.error import AppNotFoundError
from core.db.session_factory import session_factory
from extensions.ext_database import db
from libs.login import current_account_with_tenant
from models import App, AppMode
def _load_app_model(app_id: str) -> App | None:
def _load_app_model(session: Session, app_id: str) -> App | None:
"""Load the tenant-scoped app row with the request session owned by `with_session`."""
_, current_tenant_id = current_account_with_tenant()
app_model = session.scalar(
select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
)
return app_model
def _load_app_model_from_scoped_session(app_id: str) -> App | None:
"""Load the app row for legacy handlers that have not adopted request session injection yet."""
_, current_tenant_id = current_account_with_tenant()
app_model = db.session.scalar(
select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
@ -23,6 +45,63 @@ def _load_app_model_with_trial(app_id: str) -> App | None:
return app_model
@overload
def with_session[T, **P, R](
view: Callable[Concatenate[T, Session, P], R],
*,
write: bool = True,
) -> Callable[Concatenate[T, P], R]: ...
@overload
def with_session[T, **P, R](
view: None = None,
*,
write: bool = True,
) -> Callable[[Callable[Concatenate[T, Session, P], R]], Callable[Concatenate[T, P], R]]: ...
def with_session[T, **P, R](
view: Callable[Concatenate[T, Session, P], R] | None = None,
*,
write: bool = True,
) -> (
Callable[Concatenate[T, P], R] | Callable[[Callable[Concatenate[T, Session, P], R]], Callable[Concatenate[T, P], R]]
):
"""Inject a request-scoped session, using a transaction only for write handlers."""
def decorator(view: Callable[Concatenate[T, Session, P], R]) -> Callable[Concatenate[T, P], R]:
@wraps(view)
def wrapper(self: T, *args: P.args, **kwargs: P.kwargs) -> R:
if write:
with session_factory.get_session_maker().begin() as session:
return view(self, session, *args, **kwargs)
with session_factory.create_session() as session:
return view(self, session, *args, **kwargs)
return wrapper
if view is None:
return decorator
return decorator(view)
def _get_injected_session(args: tuple[object, ...]) -> Session | None:
"""Return the request session inserted by `with_session`, if this handler has been migrated."""
if len(args) < 2:
return None
candidate = args[1]
if isinstance(candidate, Session):
return candidate
if hasattr(candidate, "scalar") and hasattr(candidate, "commit") and hasattr(candidate, "rollback"):
return cast(Session, candidate)
return None
@overload
def get_app_model[**P, R](
view: Callable[P, R],
@ -44,6 +123,13 @@ def get_app_model[**P, R](
*,
mode: AppMode | list[AppMode] | None = None,
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
"""Inject the App model for handlers that receive an `app_id` path parameter.
New handlers may compose `@with_session` above this decorator so the app row
is loaded through the same request-scoped session used by the controller.
Existing handlers continue to work through `db.session` until migrated.
"""
def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
@ -55,7 +141,11 @@ def get_app_model[**P, R](
del kwargs["app_id"]
app_model = _load_app_model(app_id)
session = _get_injected_session(args)
if session is None:
app_model = _load_app_model_from_scoped_session(app_id)
else:
app_model = _load_app_model(session, app_id)
if not app_model:
raise AppNotFoundError()

View File

@ -5,12 +5,12 @@ from pydantic import BaseModel, Field
from controllers.common.schema import register_response_schema_models, register_schema_models
from fields.base import ResponseModel
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from services.auth.api_key_auth_service import ApiKeyAuthService
from .. import console_ns
from ..auth.error import ApiKeyAuthFailedError
from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required, with_current_tenant_id
class ApiKeyAuthBindingPayload(BaseModel):
@ -42,8 +42,8 @@ class ApiKeyAuthDataSource(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def get(self, current_tenant_id: str):
data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_tenant_id)
if data_source_api_key_bindings:
return {
@ -69,9 +69,9 @@ class ApiKeyAuthDataSourceBinding(Resource):
@account_initialization_required
@is_admin_or_owner_required
@console_ns.expect(console_ns.models[ApiKeyAuthBindingPayload.__name__])
def post(self):
@with_current_tenant_id
def post(self, current_tenant_id: str):
# The role of the current user in the table must be admin or owner
_, current_tenant_id = current_account_with_tenant()
payload = ApiKeyAuthBindingPayload.model_validate(console_ns.payload)
data = payload.model_dump()
ApiKeyAuthService.validate_api_key_auth_args(data)
@ -89,10 +89,9 @@ class ApiKeyAuthDataSourceBindingDelete(Resource):
@account_initialization_required
@is_admin_or_owner_required
@console_ns.response(204, "Binding deleted successfully")
def delete(self, binding_id: UUID):
@with_current_tenant_id
def delete(self, current_tenant_id: str, binding_id: UUID):
# The role of the current user in the table must be admin or owner
_, current_tenant_id = current_account_with_tenant()
ApiKeyAuthService.delete_provider_auth(current_tenant_id, str(binding_id))
return "", 204

View File

@ -32,11 +32,11 @@ from controllers.console.wraps import (
decrypt_password_field,
email_password_login_enabled,
setup_required,
with_current_user,
)
from events.tenant_event import tenant_was_created
from libs.helper import EmailStr, extract_remote_ip
from libs.helper import timezone as validate_timezone_string
from libs.login import current_account_with_tenant
from libs.token import (
clear_access_token_from_cookie,
clear_csrf_token_from_cookie,
@ -46,6 +46,7 @@ from libs.token import (
set_csrf_token_to_cookie,
set_refresh_token_to_cookie,
)
from models.account import Account
from services.account_service import AccountService, InvitationDetailDict, RegisterService, TenantService
from services.billing_service import BillingService
from services.entities.auth_entities import LoginFailureReason, LoginPayloadBase
@ -172,9 +173,8 @@ class LoginApi(Resource):
class LogoutApi(Resource):
@setup_required
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
def post(self):
current_user, _ = current_account_with_tenant()
account = current_user
@with_current_user
def post(self, account: Account):
if isinstance(account, flask_login.AnonymousUserMixin):
response = make_response({"result": "success"})
else:

View File

@ -8,9 +8,9 @@ from flask_restx import Resource
from pydantic import BaseModel
from werkzeug.exceptions import BadRequest, NotFound
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console.wraps import account_initialization_required, setup_required, with_current_user
from graphon.model_runtime.utils.encoders import jsonable_encoder
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models import Account
from models.model import OAuthProviderApp
from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType, OAuthServerService
@ -133,12 +133,10 @@ class OAuthServerUserAuthorizeApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_user
@oauth_server_client_id_required
def post(self, oauth_provider_app: OAuthProviderApp):
current_user, _ = current_account_with_tenant()
account = current_user
user_account_id = account.id
def post(self, oauth_provider_app: OAuthProviderApp, current_user: Account):
user_account_id = current_user.id
code = OAuthServerService.sign_oauth_authorization_code(oauth_provider_app.client_id, user_account_id)
return jsonable_encoder(
{

View File

@ -8,9 +8,16 @@ from werkzeug.exceptions import BadRequest
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
from controllers.console.wraps import (
account_initialization_required,
only_edition_cloud,
setup_required,
with_current_tenant_id,
with_current_user,
)
from enums.cloud_plan import CloudPlan
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models import Account
from services.billing_service import BillingService
@ -32,8 +39,9 @@ class Subscription(Resource):
@login_required
@account_initialization_required
@only_edition_cloud
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account):
args = SubscriptionQuery.model_validate(request.args.to_dict(flat=True))
BillingService.is_tenant_owner_or_admin(current_user)
return BillingService.get_subscription(args.plan, args.interval, current_user.email, current_tenant_id)
@ -45,8 +53,9 @@ class Invoices(Resource):
@login_required
@account_initialization_required
@only_edition_cloud
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account):
BillingService.is_tenant_owner_or_admin(current_user)
return BillingService.get_invoices(current_user.email, current_tenant_id)
@ -63,9 +72,8 @@ class PartnerTenants(Resource):
@login_required
@account_initialization_required
@only_edition_cloud
def put(self, partner_key: str):
current_user, _ = current_account_with_tenant()
@with_current_user
def put(self, current_user: Account, partner_key: str):
try:
args = PartnerTenantsPayload.model_validate(console_ns.payload or {})
click_id = args.click_id

View File

@ -3,11 +3,18 @@ from flask_restx import Resource
from pydantic import BaseModel, Field
from libs.helper import extract_remote_ip
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models import Account
from services.billing_service import BillingService
from .. import console_ns
from ..wraps import account_initialization_required, only_edition_cloud, setup_required
from ..wraps import (
account_initialization_required,
only_edition_cloud,
setup_required,
with_current_tenant_id,
with_current_user,
)
class ComplianceDownloadQuery(BaseModel):
@ -29,8 +36,9 @@ class ComplianceApi(Resource):
@login_required
@account_initialization_required
@only_edition_cloud
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account):
args = ComplianceDownloadQuery.model_validate(request.args.to_dict(flat=True))
ip_address = extract_remote_ip(request)

View File

@ -1,41 +1,37 @@
import json
from collections.abc import Generator
from datetime import datetime
from typing import Any, Literal, cast
from uuid import UUID
from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from flask_restx import Resource
from pydantic import BaseModel, Field, field_serializer
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import NotFound
from controllers.common.fields import SimpleResultResponse, TextContentResponse
from controllers.common.schema import get_or_create_model, register_response_schema_models, register_schema_model
from controllers.common.schema import query_params_from_model, register_response_schema_models, register_schema_models
from core.datasource.entities.datasource_entities import DatasourceProviderType, OnlineDocumentPagesMessage
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.entities.knowledge_entities import IndexingEstimate
from core.indexing_runner import IndexingRunner
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo
from core.rag.extractor.notion_extractor import NotionExtractor
from extensions.ext_database import db
from fields.data_source_fields import (
integrate_fields,
integrate_icon_fields,
integrate_list_fields,
integrate_notion_info_list_fields,
integrate_page_fields,
integrate_workspace_fields,
)
from fields.base import ResponseModel
from libs.datetime_utils import naive_utc_now
from libs.login import current_account_with_tenant, login_required
from models import DataSourceOauthBinding, Document
from libs.helper import dump_response, to_timestamp
from libs.login import login_required
from models import Account, DataSourceOauthBinding, Document
from services.dataset_service import DatasetService, DocumentService
from services.datasource_provider_service import DatasourceProviderService
from tasks.document_indexing_sync_task import document_indexing_sync_task
from .. import console_ns
from ..wraps import account_initialization_required, setup_required
from ..wraps import account_initialization_required, setup_required, with_current_tenant_id, with_current_user
class NotionEstimatePayload(BaseModel):
@ -48,57 +44,80 @@ class NotionEstimatePayload(BaseModel):
class DataSourceNotionListQuery(BaseModel):
dataset_id: str | None = Field(default=None, description="Dataset ID")
credential_id: str = Field(..., description="Credential ID", min_length=1)
datasource_parameters: dict[str, Any] | None = Field(default=None, description="Datasource parameters JSON string")
class DataSourceNotionPreviewQuery(BaseModel):
credential_id: str = Field(..., description="Credential ID", min_length=1)
register_schema_model(console_ns, NotionEstimatePayload)
register_response_schema_models(console_ns, SimpleResultResponse, TextContentResponse)
class DataSourceIntegrateIconResponse(ResponseModel):
type: str | None = None
url: str | None = None
emoji: str | None = None
integrate_icon_model = get_or_create_model("DataSourceIntegrateIcon", integrate_icon_fields)
class DataSourceIntegratePageResponse(ResponseModel):
page_name: str
page_id: str
page_icon: DataSourceIntegrateIconResponse | None
parent_id: str
type: str
integrate_page_fields_copy = integrate_page_fields.copy()
integrate_page_fields_copy["page_icon"] = fields.Nested(integrate_icon_model, allow_null=True)
integrate_page_model = get_or_create_model("DataSourceIntegratePage", integrate_page_fields_copy)
integrate_workspace_fields_copy = integrate_workspace_fields.copy()
integrate_workspace_fields_copy["pages"] = fields.List(fields.Nested(integrate_page_model))
integrate_workspace_model = get_or_create_model("DataSourceIntegrateWorkspace", integrate_workspace_fields_copy)
class DataSourceIntegrateWorkspaceResponse(ResponseModel):
workspace_name: str | None
workspace_id: str | None
workspace_icon: str | None
pages: list[DataSourceIntegratePageResponse]
total: int
integrate_fields_copy = integrate_fields.copy()
integrate_fields_copy["source_info"] = fields.Nested(integrate_workspace_model)
integrate_model = get_or_create_model("DataSourceIntegrate", integrate_fields_copy)
integrate_list_fields_copy = integrate_list_fields.copy()
integrate_list_fields_copy["data"] = fields.List(fields.Nested(integrate_model))
integrate_list_model = get_or_create_model("DataSourceIntegrateList", integrate_list_fields_copy)
class DataSourceIntegrateResponse(ResponseModel):
id: str | None
provider: str
created_at: datetime | int | None
is_bound: bool
disabled: bool | None
link: str
source_info: DataSourceIntegrateWorkspaceResponse | None
notion_page_fields = {
"page_name": fields.String,
"page_id": fields.String,
"page_icon": fields.Nested(integrate_icon_model, allow_null=True),
"is_bound": fields.Boolean,
"parent_id": fields.String,
"type": fields.String,
}
notion_page_model = get_or_create_model("NotionIntegratePage", notion_page_fields)
@field_serializer("created_at")
def serialize_created_at(self, value: datetime | int | None) -> int | None:
return to_timestamp(value)
notion_workspace_fields = {
"workspace_name": fields.String,
"workspace_id": fields.String,
"workspace_icon": fields.String,
"pages": fields.List(fields.Nested(notion_page_model)),
}
notion_workspace_model = get_or_create_model("NotionIntegrateWorkspace", notion_workspace_fields)
integrate_notion_info_list_fields_copy = integrate_notion_info_list_fields.copy()
integrate_notion_info_list_fields_copy["notion_info"] = fields.List(fields.Nested(notion_workspace_model))
integrate_notion_info_list_model = get_or_create_model(
"NotionIntegrateInfoList", integrate_notion_info_list_fields_copy
class DataSourceIntegrateListResponse(ResponseModel):
data: list[DataSourceIntegrateResponse]
class NotionIntegratePageResponse(ResponseModel):
page_name: str
page_id: str
page_icon: DataSourceIntegrateIconResponse | None
parent_id: str | None
type: str
is_bound: bool
class NotionIntegrateWorkspaceResponse(ResponseModel):
workspace_name: str | None
workspace_id: str | None
workspace_icon: str | None
pages: list[NotionIntegratePageResponse]
class NotionIntegrateInfoListResponse(ResponseModel):
notion_info: list[NotionIntegrateWorkspaceResponse]
register_schema_models(console_ns, NotionEstimatePayload)
register_response_schema_models(
console_ns,
DataSourceIntegrateListResponse,
IndexingEstimate,
NotionIntegrateInfoListResponse,
SimpleResultResponse,
TextContentResponse,
)
@ -110,10 +129,9 @@ class DataSourceApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(integrate_list_model)
def get(self):
_, current_tenant_id = current_account_with_tenant()
@console_ns.response(200, "Success", console_ns.models[DataSourceIntegrateListResponse.__name__])
@with_current_tenant_id
def get(self, current_tenant_id: str) -> tuple[dict[str, Any], int]:
# get workspace data source integrates
data_source_integrates = db.session.scalars(
select(DataSourceOauthBinding).where(
@ -155,19 +173,21 @@ class DataSourceApi(Resource):
"link": f"{base_url}{data_source_oauth_base_path}/{provider}",
}
)
return {"data": integrate_data}, 200
return dump_response(DataSourceIntegrateListResponse, {"data": integrate_data}), 200
@setup_required
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
def patch(self, binding_id, action: Literal["enable", "disable"]):
_, current_tenant_id = current_account_with_tenant()
binding_id = str(binding_id)
@with_current_tenant_id
def patch(
self, current_tenant_id: str, binding_id: UUID, action: Literal["enable", "disable"]
) -> tuple[dict[str, str], int]:
binding_id_str = str(binding_id)
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
data_source_binding = session.execute(
select(DataSourceOauthBinding).where(
DataSourceOauthBinding.id == binding_id, DataSourceOauthBinding.tenant_id == current_tenant_id
DataSourceOauthBinding.id == binding_id_str, DataSourceOauthBinding.tenant_id == current_tenant_id
)
).scalar_one_or_none()
if data_source_binding is None:
@ -199,15 +219,12 @@ class DataSourceNotionListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(integrate_notion_info_list_model)
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
query = DataSourceNotionListQuery.model_validate(request.args.to_dict())
# Get datasource_parameters from query string (optional, for GitHub and other datasources)
datasource_parameters = query.datasource_parameters or {}
@console_ns.doc(params=query_params_from_model(DataSourceNotionListQuery))
@console_ns.response(200, "Success", console_ns.models[NotionIntegrateInfoListResponse.__name__])
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account) -> tuple[dict[str, Any], int]:
query = DataSourceNotionListQuery.model_validate(request.args.to_dict(flat=True))
datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_datasource_credentials(
tenant_id=current_tenant_id,
@ -255,7 +272,7 @@ class DataSourceNotionListApi(Resource):
online_document_result: Generator[OnlineDocumentPagesMessage, None, None] = (
datasource_runtime.get_online_document_pages(
user_id=current_user.id,
datasource_parameters=datasource_parameters,
datasource_parameters={},
provider_type=datasource_runtime.datasource_provider_type(),
)
)
@ -282,22 +299,22 @@ class DataSourceNotionListApi(Resource):
pages.append(page_info)
except Exception as e:
raise e
return {"notion_info": {**workspace_info, "pages": pages}}, 200
notion_info = [{**workspace_info, "pages": pages}] if workspace_info else []
return dump_response(NotionIntegrateInfoListResponse, {"notion_info": notion_info}), 200
@console_ns.route(
"/notion/pages/<uuid:page_id>/<string:page_type>/preview",
"/datasets/notion-indexing-estimate",
)
class DataSourceNotionApi(Resource):
@console_ns.route("/notion/pages/<uuid:page_id>/<string:page_type>/preview")
class DataSourceNotionPreviewApi(Resource):
"""Preview one authorized Notion page through the datasource credential."""
@setup_required
@login_required
@account_initialization_required
@console_ns.doc(params=query_params_from_model(DataSourceNotionPreviewQuery))
@console_ns.response(200, "Success", console_ns.models[TextContentResponse.__name__])
def get(self, page_id: UUID, page_type: str):
_, current_tenant_id = current_account_with_tenant()
query = DataSourceNotionPreviewQuery.model_validate(request.args.to_dict())
@with_current_tenant_id
def get(self, current_tenant_id: str, page_id: UUID, page_type: str) -> tuple[dict[str, str], int]:
query = DataSourceNotionPreviewQuery.model_validate(request.args.to_dict(flat=True))
datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_datasource_credentials(
@ -320,13 +337,18 @@ class DataSourceNotionApi(Resource):
text_docs = extractor.extract()
return {"content": "\n".join([doc.page_content for doc in text_docs])}, 200
@console_ns.route("/datasets/notion-indexing-estimate")
class DataSourceNotionIndexingEstimateApi(Resource):
"""Estimate indexing work for selected Notion pages."""
@setup_required
@login_required
@account_initialization_required
@console_ns.expect(console_ns.models[NotionEstimatePayload.__name__])
def post(self):
_, current_tenant_id = current_account_with_tenant()
@console_ns.response(200, "Success", console_ns.models[IndexingEstimate.__name__])
@with_current_tenant_id
def post(self, current_tenant_id: str) -> tuple[dict[str, Any], int]:
payload = NotionEstimatePayload.model_validate(console_ns.payload or {})
args = payload.model_dump()
# validate args
@ -359,7 +381,7 @@ class DataSourceNotionApi(Resource):
args["doc_form"],
args["doc_language"],
)
return response.model_dump(), 200
return dump_response(IndexingEstimate, response), 200
@console_ns.route("/datasets/<uuid:dataset_id>/notion/sync")
@ -368,7 +390,7 @@ class DataSourceNotionDatasetSyncApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
def get(self, dataset_id: UUID):
def get(self, dataset_id: UUID) -> tuple[dict[str, str], int]:
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
@ -386,7 +408,7 @@ class DataSourceNotionDocumentSyncApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
def get(self, dataset_id: UUID, document_id: UUID):
def get(self, dataset_id: UUID, document_id: UUID) -> tuple[dict[str, str], int]:
dataset_id_str = str(dataset_id)
document_id_str = str(document_id)
dataset = DatasetService.get_dataset(dataset_id_str)

View File

@ -9,7 +9,7 @@ from uuid import UUID
import sqlalchemy as sa
from flask import request, send_file
from flask_restx import Resource, marshal
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import asc, desc, func, select
from werkzeug.exceptions import Forbidden, NotFound
@ -34,16 +34,18 @@ from core.rag.index_processor.constant.index_type import IndexTechniqueType
from extensions.ext_database import db
from fields.base import ResponseModel
from fields.document_fields import (
document_fields,
document_status_fields,
document_with_segments_fields,
DocumentMetadataResponse,
DocumentResponse,
DocumentStatusListResponse,
DocumentStatusResponse,
normalize_enum,
)
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
from libs.datetime_utils import naive_utc_now
from libs.helper import to_timestamp
from libs.login import current_account_with_tenant, login_required
from models import DatasetProcessRule, Document, DocumentSegment, UploadFile
from libs.helper import dump_response, to_timestamp
from libs.login import login_required
from models import Account, DatasetProcessRule, Document, DocumentSegment, UploadFile
from models.dataset import DocumentPipelineExecutionLog
from models.enums import IndexingStatus, SegmentStatus
from services.dataset_service import DatasetService, DocumentService
@ -69,17 +71,13 @@ from ..wraps import (
cloud_edition_billing_rate_limit_check,
cloud_edition_billing_resource_check,
setup_required,
with_current_tenant_id,
with_current_user,
)
logger = logging.getLogger(__name__)
def _normalize_enum(value: Any) -> Any:
if isinstance(value, str) or value is None:
return value
return getattr(value, "value", value)
class DatasetResponse(ResponseModel):
id: str
name: str
@ -93,7 +91,7 @@ class DatasetResponse(ResponseModel):
@field_validator("data_source_type", "indexing_technique", mode="before")
@classmethod
def _normalize_enum_fields(cls, value: Any) -> Any:
return _normalize_enum(value)
return normalize_enum(value)
@field_validator("created_at", mode="before")
@classmethod
@ -101,61 +99,10 @@ class DatasetResponse(ResponseModel):
return to_timestamp(value)
class DocumentMetadataResponse(ResponseModel):
id: str
name: str
type: str
value: str | None = None
class DocumentResponse(ResponseModel):
id: str
position: int | None = None
data_source_type: str | None = None
data_source_info: Any = Field(default=None, validation_alias="data_source_info_dict")
data_source_detail_dict: Any = None
dataset_process_rule_id: str | None = None
name: str
created_from: str | None = None
created_by: str | None = None
created_at: int | None = None
tokens: int | None = None
indexing_status: str | None = None
error: str | None = None
enabled: bool | None = None
disabled_at: int | None = None
disabled_by: str | None = None
archived: bool | None = None
display_status: str | None = None
word_count: int | None = None
hit_count: int | None = None
doc_form: str | None = None
doc_metadata: list[DocumentMetadataResponse] = Field(default_factory=list, validation_alias="doc_metadata_details")
summary_index_status: str | None = None
need_summary: bool | None = None
@field_validator("data_source_type", "indexing_status", "display_status", "doc_form", mode="before")
@classmethod
def _normalize_enum_fields(cls, value: Any) -> Any:
return _normalize_enum(value)
@field_validator("doc_metadata", mode="before")
@classmethod
def _normalize_doc_metadata(cls, value: Any) -> list[Any]:
if value is None:
return []
return value
@field_validator("created_at", "disabled_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return to_timestamp(value)
class DocumentWithSegmentsResponse(DocumentResponse):
process_rule_dict: Any = None
completed_segments: int | None = None
total_segments: int | None = None
completed_segments: int | None = Field(default=None, exclude_if=lambda value: value is None)
total_segments: int | None = Field(default=None, exclude_if=lambda value: value is None)
class DatasetAndDocumentResponse(ResponseModel):
@ -190,6 +137,14 @@ class DocumentDatasetListParam(BaseModel):
fetch_val: str = Field("false", alias="fetch")
class DocumentWithSegmentsListResponse(ResponseModel):
data: list[DocumentWithSegmentsResponse]
has_more: bool
limit: int
total: int
page: int
register_schema_models(
console_ns,
KnowledgeConfig,
@ -200,18 +155,25 @@ register_schema_models(
GenerateSummaryPayload,
DocumentMetadataUpdatePayload,
DocumentBatchDownloadZipPayload,
)
register_response_schema_models(
console_ns,
SimpleResultMessageResponse,
SimpleResultResponse,
UrlResponse,
DatasetResponse,
DocumentMetadataResponse,
DocumentResponse,
DocumentWithSegmentsResponse,
DatasetAndDocumentResponse,
DocumentWithSegmentsListResponse,
)
register_response_schema_models(console_ns, SimpleResultMessageResponse, SimpleResultResponse, UrlResponse)
class DocumentResource(Resource):
def get_document(self, dataset_id: str, document_id: str) -> Document:
current_user, current_tenant_id = current_account_with_tenant()
def get_document(
self, dataset_id: str, document_id: str, current_user: Account, current_tenant_id: str
) -> Document:
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
@ -231,8 +193,7 @@ class DocumentResource(Resource):
return document
def get_batch_documents(self, dataset_id: str, batch: str) -> Sequence[Document]:
current_user, _ = current_account_with_tenant()
def get_batch_documents(self, dataset_id: str, batch: str, current_user: Account) -> Sequence[Document]:
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
@ -259,8 +220,8 @@ class GetProcessRuleApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
current_user, _ = current_account_with_tenant()
@with_current_user
def get(self, current_user: Account):
req_data = request.args
document_id = req_data.get("document_id")
@ -312,12 +273,17 @@ class DatasetDocumentListApi(Resource):
"status": "Filter documents by display status",
}
)
@console_ns.response(200, "Documents retrieved successfully")
@console_ns.response(
200,
"Documents retrieved successfully",
console_ns.models[DocumentWithSegmentsListResponse.__name__],
)
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id: UUID):
current_user, current_tenant_id = current_account_with_tenant()
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID):
dataset_id_str = str(dataset_id)
raw_args = request.args.to_dict()
param = DocumentDatasetListParam.model_validate(raw_args)
@ -425,18 +391,15 @@ class DatasetDocumentListApi(Resource):
)
document.completed_segments = completed_segments
document.total_segments = total_segments
data = marshal(documents, document_with_segments_fields)
else:
data = marshal(documents, document_fields)
response = {
"data": data,
"data": documents,
"has_more": len(documents) == limit,
"limit": limit,
"total": paginated_documents.total,
"page": page,
}
return response
return dump_response(DocumentWithSegmentsListResponse, response)
@setup_required
@login_required
@ -445,8 +408,8 @@ class DatasetDocumentListApi(Resource):
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[KnowledgeConfig.__name__])
@console_ns.response(200, "Documents created successfully", console_ns.models[DatasetAndDocumentResponse.__name__])
def post(self, dataset_id: UUID):
current_user, _ = current_account_with_tenant()
@with_current_user
def post(self, current_user: Account, dataset_id: UUID):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -482,9 +445,7 @@ class DatasetDocumentListApi(Resource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
return DatasetAndDocumentResponse.model_validate(
{"dataset": dataset, "documents": documents, "batch": batch}, from_attributes=True
).model_dump(mode="json")
return dump_response(DatasetAndDocumentResponse, {"dataset": dataset, "documents": documents, "batch": batch})
@setup_required
@login_required
@ -522,9 +483,10 @@ class DatasetInitApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self):
@with_current_user
@with_current_tenant_id
def post(self, current_tenant_id: str, current_user: Account):
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_dataset_editor:
raise Forbidden()
@ -567,9 +529,7 @@ class DatasetInitApi(Resource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
return DatasetAndDocumentResponse.model_validate(
{"dataset": dataset, "documents": documents, "batch": batch}, from_attributes=True
).model_dump(mode="json")
return dump_response(DatasetAndDocumentResponse, {"dataset": dataset, "documents": documents, "batch": batch})
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-estimate")
@ -583,11 +543,12 @@ class DocumentIndexingEstimateApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id: UUID, document_id: UUID):
_, current_tenant_id = current_account_with_tenant()
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID):
dataset_id_str = str(dataset_id)
document_id_str = str(document_id)
document = self.get_document(dataset_id_str, document_id_str)
document = self.get_document(dataset_id_str, document_id_str, current_user, current_tenant_id)
if document.indexing_status in {IndexingStatus.COMPLETED, IndexingStatus.ERROR}:
raise DocumentAlreadyFinishedError()
@ -648,10 +609,11 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id: UUID, batch: str):
_, current_tenant_id = current_account_with_tenant()
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, batch: str):
dataset_id_str = str(dataset_id)
documents = self.get_batch_documents(dataset_id_str, batch)
documents = self.get_batch_documents(dataset_id_str, batch, current_user)
if not documents:
return {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}, 200
data_process_rule = documents[0].dataset_process_rule
@ -742,12 +704,16 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
@console_ns.route("/datasets/<uuid:dataset_id>/batch/<string:batch>/indexing-status")
class DocumentBatchIndexingStatusApi(DocumentResource):
@console_ns.response(
200, "Indexing status retrieved successfully", console_ns.models[DocumentStatusListResponse.__name__]
)
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id: UUID, batch: str):
@with_current_user
def get(self, current_user: Account, dataset_id: UUID, batch: str):
dataset_id_str = str(dataset_id)
documents = self.get_batch_documents(dataset_id_str, batch)
documents = self.get_batch_documents(dataset_id_str, batch, current_user)
documents_status = []
for document in documents:
completed_segments = (
@ -784,9 +750,8 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
"completed_segments": completed_segments,
"total_segments": total_segments,
}
documents_status.append(marshal(document_dict, document_status_fields))
data = {"data": documents_status}
return data
documents_status.append(document_dict)
return dump_response(DocumentStatusListResponse, {"data": documents_status})
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-status")
@ -794,21 +759,25 @@ class DocumentIndexingStatusApi(DocumentResource):
@console_ns.doc("get_document_indexing_status")
@console_ns.doc(description="Get document indexing status")
@console_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
@console_ns.response(200, "Indexing status retrieved successfully")
@console_ns.response(
200, "Indexing status retrieved successfully", console_ns.models[DocumentStatusResponse.__name__]
)
@console_ns.response(404, "Document not found")
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id: UUID, document_id: UUID):
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID):
dataset_id_str = str(dataset_id)
document_id_str = str(document_id)
document = self.get_document(dataset_id_str, document_id_str)
document = self.get_document(dataset_id_str, document_id_str, current_user, current_tenant_id)
completed_segments = (
db.session.scalar(
select(func.count(DocumentSegment.id)).where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document_id_str),
DocumentSegment.document_id == document_id_str,
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
)
@ -817,7 +786,7 @@ class DocumentIndexingStatusApi(DocumentResource):
total_segments = (
db.session.scalar(
select(func.count(DocumentSegment.id)).where(
DocumentSegment.document_id == str(document_id_str),
DocumentSegment.document_id == document_id_str,
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
)
@ -839,7 +808,7 @@ class DocumentIndexingStatusApi(DocumentResource):
"completed_segments": completed_segments,
"total_segments": total_segments,
}
return marshal(document_dict, document_status_fields)
return dump_response(DocumentStatusResponse, document_dict)
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
@ -860,10 +829,12 @@ class DocumentApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id: UUID, document_id: UUID):
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID):
dataset_id_str = str(dataset_id)
document_id_str = str(document_id)
document = self.get_document(dataset_id_str, document_id_str)
document = self.get_document(dataset_id_str, document_id_str, current_user, current_tenant_id)
metadata = request.args.get("metadata", "all")
if metadata not in self.METADATA_CHOICES:
@ -949,7 +920,9 @@ class DocumentApi(DocumentResource):
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.response(204, "Document deleted successfully")
def delete(self, dataset_id: UUID, document_id: UUID):
@with_current_user
@with_current_tenant_id
def delete(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID):
dataset_id_str = str(dataset_id)
document_id_str = str(document_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -958,7 +931,7 @@ class DocumentApi(DocumentResource):
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
document = self.get_document(dataset_id_str, document_id_str)
document = self.get_document(dataset_id_str, document_id_str, current_user, current_tenant_id)
try:
DocumentService.delete_document(document)
@ -979,9 +952,11 @@ class DocumentDownloadApi(DocumentResource):
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def get(self, dataset_id: str, document_id: str) -> dict[str, Any]:
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID) -> dict[str, Any]:
# Reuse the shared permission/tenant checks implemented in DocumentResource.
document = self.get_document(str(dataset_id), str(document_id))
document = self.get_document(str(dataset_id), str(document_id), current_user, current_tenant_id)
return {"url": DocumentService.get_document_download_url(document)}
@ -996,12 +971,13 @@ class DocumentBatchDownloadZipApi(DocumentResource):
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[DocumentBatchDownloadZipPayload.__name__])
def post(self, dataset_id: str):
@with_current_user
@with_current_tenant_id
def post(self, current_tenant_id: str, current_user: Account, dataset_id: UUID):
"""Stream a ZIP archive containing the requested uploaded documents."""
# Parse and validate request payload.
payload = DocumentBatchDownloadZipPayload.model_validate(console_ns.payload or {})
current_user, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
document_ids: list[str] = [str(document_id) for document_id in payload.document_ids]
upload_files, download_name = DocumentService.prepare_document_batch_download_zip(
@ -1043,11 +1019,19 @@ class DocumentProcessingApi(DocumentResource):
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id: UUID, document_id: UUID, action: Literal["pause", "resume"]):
current_user, _ = current_account_with_tenant()
@with_current_user
@with_current_tenant_id
def patch(
self,
current_tenant_id: str,
current_user: Account,
dataset_id: UUID,
document_id: UUID,
action: Literal["pause", "resume"],
):
dataset_id_str = str(dataset_id)
document_id_str = str(document_id)
document = self.get_document(dataset_id_str, document_id_str)
document = self.get_document(dataset_id_str, document_id_str, current_user, current_tenant_id)
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
if not current_user.is_dataset_editor:
@ -1091,11 +1075,12 @@ class DocumentMetadataApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
def put(self, dataset_id: UUID, document_id: UUID):
current_user, _ = current_account_with_tenant()
@with_current_user
@with_current_tenant_id
def put(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID):
dataset_id_str = str(dataset_id)
document_id_str = str(document_id)
document = self.get_document(dataset_id_str, document_id_str)
document = self.get_document(dataset_id_str, document_id_str, current_user, current_tenant_id)
req_data = DocumentMetadataUpdatePayload.model_validate(request.get_json() or {})
@ -1140,8 +1125,10 @@ class DocumentStatusApi(DocumentResource):
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
def patch(self, dataset_id: UUID, action: Literal["enable", "disable", "archive", "un_archive"]):
current_user, _ = current_account_with_tenant()
@with_current_user
def patch(
self, current_user: Account, dataset_id: UUID, action: Literal["enable", "disable", "archive", "un_archive"]
):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
@ -1256,8 +1243,6 @@ class DocumentRetryApi(DocumentResource):
raise NotFound("Dataset not found.")
for document_id in payload.document_ids:
try:
document_id = str(document_id)
document = DocumentService.get_document(dataset.id, document_id)
# 404 if document not found
@ -1288,9 +1273,9 @@ class DocumentRenameApi(DocumentResource):
@account_initialization_required
@console_ns.response(200, "Document renamed successfully", console_ns.models[DocumentResponse.__name__])
@console_ns.expect(console_ns.models[DocumentRenamePayload.__name__])
def post(self, dataset_id: UUID, document_id: UUID):
@with_current_user
def post(self, current_user: Account, dataset_id: UUID, document_id: UUID):
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
current_user, _ = current_account_with_tenant()
if not current_user.is_dataset_editor:
raise Forbidden()
dataset = DatasetService.get_dataset(dataset_id)
@ -1304,7 +1289,7 @@ class DocumentRenameApi(DocumentResource):
except services.errors.document.DocumentIndexingError:
raise DocumentIndexingError("Cannot delete document during indexing.")
return DocumentResponse.model_validate(document, from_attributes=True).model_dump(mode="json")
return dump_response(DocumentResponse, document)
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/website-sync")
@ -1313,9 +1298,9 @@ class WebsiteDocumentSyncApi(DocumentResource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
def get(self, dataset_id: UUID, document_id: UUID):
@with_current_tenant_id
def get(self, current_tenant_id: str, dataset_id: UUID, document_id: UUID):
"""sync website document."""
_, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if not dataset:
@ -1391,7 +1376,8 @@ class DocumentGenerateSummaryApi(Resource):
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id: UUID):
@with_current_user
def post(self, current_user: Account, dataset_id: UUID):
"""
Generate summary index for specified documents.
@ -1399,7 +1385,6 @@ class DocumentGenerateSummaryApi(Resource):
(indexing_technique must be 'high_quality' and summary_index_setting.enable must be true),
then asynchronously generates summary indexes for the provided documents.
"""
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
# Get dataset
@ -1484,7 +1469,8 @@ class DocumentSummaryStatusApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id: UUID, document_id: UUID):
@with_current_user
def get(self, current_user: Account, dataset_id: UUID, document_id: UUID):
"""
Get summary index generation status for a document.
@ -1497,7 +1483,6 @@ class DocumentSummaryStatusApi(DocumentResource):
- not_started: Number of segments without summary records
- summaries: List of summary records with status and content preview
"""
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
document_id_str = str(document_id)

View File

@ -1,11 +1,12 @@
import uuid
from typing import Literal
from typing import cast as type_cast
from uuid import UUID
from flask import request
from flask_restx import Resource, marshal
from flask_restx import Resource
from pydantic import BaseModel, Field
from sqlalchemy import String, cast, func, or_, select
from sqlalchemy import String, case, cast, func, literal, or_, select
from sqlalchemy.dialects.postgresql import JSONB
from werkzeug.exceptions import Forbidden, NotFound
@ -13,7 +14,12 @@ import services
from configs import dify_config
from controllers.common.controller_schemas import ChildChunkCreatePayload, ChildChunkUpdatePayload
from controllers.common.fields import SimpleResultResponse
from controllers.common.schema import register_response_schema_models, register_schema_models
from controllers.common.schema import (
query_params_from_model,
query_params_from_request,
register_response_schema_models,
register_schema_models,
)
from controllers.console import console_ns
from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.error import (
@ -27,6 +33,8 @@ from controllers.console.wraps import (
cloud_edition_billing_rate_limit_check,
cloud_edition_billing_resource_check,
setup_required,
with_current_tenant_id,
with_current_user,
)
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_manager import ModelManager
@ -34,30 +42,29 @@ from core.rag.index_processor.constant.index_type import IndexTechniqueType
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from fields.base import ResponseModel
from fields.segment_fields import child_chunk_fields, segment_fields
from fields.segment_fields import (
ChildChunkDetailResponse,
ChildChunkListResponse,
ChildChunkResponse,
SegmentDetailResponse,
SegmentResponse,
segment_response_with_summary,
segment_responses_with_summaries,
)
from graphon.model_runtime.entities.model_entities import ModelType
from libs.helper import escape_like_pattern
from libs.login import current_account_with_tenant, login_required
from libs.helper import dump_response, escape_like_pattern
from libs.login import login_required
from models import Account
from models.dataset import ChildChunk, DocumentSegment
from models.model import UploadFile
from services.dataset_service import DatasetService, DocumentService, SegmentService
from services.entities.knowledge_entities.knowledge_entities import ChildChunkUpdateArgs, SegmentUpdateArgs
from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError
from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError
from services.summary_index_service import SummaryIndexService
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
def _get_segment_with_summary(segment, dataset_id):
"""Helper function to marshal segment and add summary information."""
from services.summary_index_service import SummaryIndexService
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore
# Query summary for this segment (only enabled summaries)
summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id)
segment_dict["summary"] = summary.summary_content if summary else None
return segment_dict
class SegmentListQuery(BaseModel):
limit: int = Field(default=20, ge=1, le=100)
status: list[str] = Field(default_factory=list)
@ -67,6 +74,16 @@ class SegmentListQuery(BaseModel):
page: int = Field(default=1, ge=1)
class SegmentIdListQuery(BaseModel):
segment_id: list[str] = Field(default_factory=list, description="Segment IDs")
class ChildChunkListQuery(BaseModel):
limit: int = Field(default=20, ge=1, le=100)
keyword: str | None = None
page: int = Field(default=1, ge=1)
class SegmentCreatePayload(BaseModel):
content: str
answer: str | None = None
@ -92,13 +109,35 @@ class SegmentBatchImportStatusResponse(ResponseModel):
job_status: str
class ConsoleSegmentListResponse(ResponseModel):
data: list[SegmentResponse]
limit: int
total: int
total_pages: int
page: int
class ChildChunkBatchUpdateResponse(ResponseModel):
data: list[ChildChunkResponse]
class ChildChunkBatchUpdatePayload(BaseModel):
chunks: list[ChildChunkUpdateArgs]
class SegmentDocParams:
DATASET_DOCUMENT = {"dataset_id": "Dataset ID", "document_id": "Document ID"}
DATASET_DOCUMENT_ACTION = {**DATASET_DOCUMENT, "action": "Action"}
DATASET_DOCUMENT_SEGMENT = {**DATASET_DOCUMENT, "segment_id": "Segment ID"}
DATASET_DOCUMENT_PARENT_SEGMENT = {**DATASET_DOCUMENT, "segment_id": "Parent segment ID"}
DATASET_DOCUMENT_CHILD_CHUNK = {**DATASET_DOCUMENT_PARENT_SEGMENT, "child_chunk_id": "Child chunk ID"}
register_schema_models(
console_ns,
SegmentListQuery,
SegmentIdListQuery,
ChildChunkListQuery,
SegmentCreatePayload,
SegmentUpdatePayload,
BatchImportPayload,
@ -107,17 +146,30 @@ register_schema_models(
ChildChunkBatchUpdatePayload,
ChildChunkUpdateArgs,
)
register_response_schema_models(console_ns, SegmentBatchImportStatusResponse, SimpleResultResponse)
register_response_schema_models(
console_ns,
SegmentResponse,
ConsoleSegmentListResponse,
SegmentDetailResponse,
ChildChunkDetailResponse,
ChildChunkListResponse,
ChildChunkBatchUpdateResponse,
SegmentBatchImportStatusResponse,
SimpleResultResponse,
)
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
class DatasetDocumentSegmentListApi(Resource):
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT)
@console_ns.doc(params=query_params_from_model(SegmentListQuery))
@console_ns.response(200, "Segments retrieved successfully", console_ns.models[ConsoleSegmentListResponse.__name__])
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id: UUID, document_id: UUID):
current_user, current_tenant_id = current_account_with_tenant()
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID):
dataset_id_str = str(dataset_id)
document_id_str = str(document_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -134,12 +186,7 @@ class DatasetDocumentSegmentListApi(Resource):
if not document:
raise NotFound("Document not found.")
args = SegmentListQuery.model_validate(
{
**request.args.to_dict(),
"status": request.args.getlist("status"),
}
)
args = query_params_from_request(SegmentListQuery, list_fields=("status",))
page = args.page
limit = min(args.limit, 100)
@ -169,9 +216,17 @@ class DatasetDocumentSegmentListApi(Resource):
# Use database-specific methods for JSON array search
if dify_config.SQLALCHEMY_DATABASE_URI_SCHEME == "postgresql":
# PostgreSQL: Use jsonb_array_elements_text to properly handle Unicode/Chinese text
# Feed the set-returning function a JSON array in every row. Filtering in
# the subquery is not enough because PostgreSQL can still evaluate the
# SRF on scalar JSON before applying the predicate.
keywords_jsonb = cast(DocumentSegment.keywords, JSONB)
keywords_array = case(
(func.jsonb_typeof(keywords_jsonb) == "array", keywords_jsonb),
else_=cast(literal("[]"), JSONB),
)
keywords_condition = func.array_to_string(
func.array(
select(func.jsonb_array_elements_text(cast(DocumentSegment.keywords, JSONB)))
select(func.jsonb_array_elements_text(keywords_array))
.correlate(DocumentSegment)
.scalar_subquery()
),
@ -197,42 +252,33 @@ class DatasetDocumentSegmentListApi(Resource):
segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
# Query summaries for all segments in this page (batch query for efficiency)
segment_ids = [segment.id for segment in segments.items]
summaries = {}
segment_list = list(segments.items)
segment_ids = [segment.id for segment in segment_list]
summaries: dict[str, str | None] = {}
if segment_ids:
from services.summary_index_service import SummaryIndexService
summary_records = SummaryIndexService.get_segments_summaries(
segment_ids=segment_ids, dataset_id=dataset_id_str
)
# Only include enabled summaries (already filtered by service)
summaries = {chunk_id: summary.summary_content for chunk_id, summary in summary_records.items()}
# Add summary to each segment
segments_with_summary = []
for segment in segments.items:
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore
segment_dict["summary"] = summaries.get(segment.id)
segments_with_summary.append(segment_dict)
response = {
"data": segments_with_summary,
"data": segment_responses_with_summaries(segment_list, summaries),
"limit": limit,
"total": segments.total,
"total_pages": segments.pages,
"page": page,
}
return response, 200
return dump_response(ConsoleSegmentListResponse, response), 200
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT)
@console_ns.doc(params=query_params_from_model(SegmentIdListQuery))
@console_ns.response(204, "Segments deleted successfully")
def delete(self, dataset_id: UUID, document_id: UUID):
current_user, _ = current_account_with_tenant()
@with_current_user
def delete(self, current_user: Account, dataset_id: UUID, document_id: UUID):
# check dataset
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -260,15 +306,24 @@ class DatasetDocumentSegmentListApi(Resource):
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment/<string:action>")
class DatasetDocumentSegmentApi(Resource):
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_ACTION)
@console_ns.doc(params=query_params_from_model(SegmentIdListQuery))
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
def patch(self, dataset_id: UUID, document_id: UUID, action: Literal["enable", "disable"]):
current_user, current_tenant_id = current_account_with_tenant()
@with_current_user
@with_current_tenant_id
def patch(
self,
current_tenant_id: str,
current_user: Account,
dataset_id: UUID,
document_id: UUID,
action: Literal["enable", "disable"],
):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if not dataset:
@ -313,11 +368,12 @@ class DatasetDocumentSegmentApi(Resource):
SegmentService.update_segments_status(segment_ids, action, dataset, document)
except Exception as e:
raise InvalidActionError(str(e))
return {"result": "success"}, 200
return dump_response(SimpleResultResponse, {"result": "success"}), 200
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment")
class DatasetDocumentSegmentAddApi(Resource):
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT)
@setup_required
@login_required
@account_initialization_required
@ -325,9 +381,10 @@ class DatasetDocumentSegmentAddApi(Resource):
@cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[SegmentCreatePayload.__name__])
def post(self, dataset_id: UUID, document_id: UUID):
current_user, current_tenant_id = current_account_with_tenant()
@console_ns.response(200, "Segment created successfully", console_ns.models[SegmentDetailResponse.__name__])
@with_current_user
@with_current_tenant_id
def post(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID):
# check dataset
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -364,21 +421,30 @@ class DatasetDocumentSegmentAddApi(Resource):
payload = SegmentCreatePayload.model_validate(console_ns.payload or {})
payload_dict = payload.model_dump(exclude_none=True)
SegmentService.segment_create_args_validate(payload_dict, document)
segment = SegmentService.create_segment(payload_dict, document, dataset)
return {"data": _get_segment_with_summary(segment, dataset_id_str), "doc_form": document.doc_form}, 200
segment = type_cast(DocumentSegment, SegmentService.create_segment(payload_dict, document, dataset))
summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id_str)
response = {
"data": segment_response_with_summary(segment, summary.summary_content if summary else None),
"doc_form": document.doc_form,
}
return dump_response(SegmentDetailResponse, response), 200
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>")
class DatasetDocumentSegmentUpdateApi(Resource):
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_SEGMENT)
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[SegmentUpdatePayload.__name__])
def patch(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
current_user, current_tenant_id = current_account_with_tenant()
@console_ns.response(200, "Segment updated successfully", console_ns.models[SegmentDetailResponse.__name__])
@with_current_user
@with_current_tenant_id
def patch(
self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID, segment_id: UUID
):
# check dataset
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -432,16 +498,24 @@ class DatasetDocumentSegmentUpdateApi(Resource):
segment = SegmentService.update_segment(
SegmentUpdateArgs.model_validate(payload.model_dump(exclude_none=True)), segment, document, dataset
)
return {"data": _get_segment_with_summary(segment, dataset_id_str), "doc_form": document.doc_form}, 200
summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id_str)
response = {
"data": segment_response_with_summary(segment, summary.summary_content if summary else None),
"doc_form": document.doc_form,
}
return dump_response(SegmentDetailResponse, response), 200
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_SEGMENT)
@console_ns.response(204, "Segment deleted successfully")
def delete(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
current_user, current_tenant_id = current_account_with_tenant()
@with_current_user
@with_current_tenant_id
def delete(
self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID, segment_id: UUID
):
# check dataset
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -487,9 +561,9 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
@cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[BatchImportPayload.__name__])
def post(self, dataset_id: UUID, document_id: UUID):
current_user, current_tenant_id = current_account_with_tenant()
@with_current_user
@with_current_tenant_id
def post(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID):
# check dataset
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -515,11 +589,11 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
try:
# async job
job_id = str(uuid.uuid4())
indexing_cache_key = f"segment_batch_import_{str(job_id)}"
indexing_cache_key = f"segment_batch_import_{job_id}"
# send batch add segments task
redis_client.setnx(indexing_cache_key, "waiting")
batch_create_segment_to_index_task.delay(
str(job_id),
job_id,
upload_file_id,
dataset_id_str,
document_id_str,
@ -528,7 +602,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
)
except Exception as e:
return {"error": str(e)}, 500
return {"job_id": job_id, "job_status": "waiting"}, 200
return dump_response(SegmentBatchImportStatusResponse, {"job_id": job_id, "job_status": "waiting"}), 200
@console_ns.response(200, "Batch import status", console_ns.models[SegmentBatchImportStatusResponse.__name__])
@setup_required
@ -543,11 +617,13 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
if cache_result is None:
raise ValueError("The job does not exist.")
return {"job_id": job_id, "job_status": cache_result.decode()}, 200
response = {"job_id": job_id, "job_status": cache_result.decode()}
return dump_response(SegmentBatchImportStatusResponse, response), 200
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks")
class ChildChunkAddApi(Resource):
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_PARENT_SEGMENT)
@setup_required
@login_required
@account_initialization_required
@ -555,9 +631,12 @@ class ChildChunkAddApi(Resource):
@cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[ChildChunkCreatePayload.__name__])
def post(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
current_user, current_tenant_id = current_account_with_tenant()
@console_ns.response(200, "Child chunk created successfully", console_ns.models[ChildChunkDetailResponse.__name__])
@with_current_user
@with_current_tenant_id
def post(
self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID, segment_id: UUID
):
# check dataset
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -605,14 +684,16 @@ class ChildChunkAddApi(Resource):
child_chunk = SegmentService.create_child_chunk(payload.content, segment, document, dataset)
except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e))
return {"data": marshal(child_chunk, child_chunk_fields)}, 200
return dump_response(ChildChunkDetailResponse, {"data": child_chunk}), 200
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_PARENT_SEGMENT)
@console_ns.doc(params=query_params_from_model(ChildChunkListQuery))
@console_ns.response(200, "Child chunks retrieved successfully", console_ns.models[ChildChunkListResponse.__name__])
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def get(self, current_tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID):
# check dataset
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -634,13 +715,7 @@ class ChildChunkAddApi(Resource):
)
if not segment:
raise NotFound("Segment not found.")
args = SegmentListQuery.model_validate(
{
"limit": request.args.get("limit", default=20, type=int),
"keyword": request.args.get("keyword"),
"page": request.args.get("page", default=1, type=int),
}
)
args = query_params_from_request(ChildChunkListQuery, use_defaults_for_malformed_ints=True)
page = args.page
limit = min(args.limit, 100)
@ -649,22 +724,32 @@ class ChildChunkAddApi(Resource):
child_chunks = SegmentService.get_child_chunks(
segment_id_str, document_id_str, dataset_id_str, page, limit, keyword
)
return {
"data": marshal(child_chunks.items, child_chunk_fields),
response = {
"data": child_chunks.items,
"total": child_chunks.total,
"total_pages": child_chunks.pages,
"page": page,
"limit": limit,
}, 200
}
return dump_response(ChildChunkListResponse, response), 200
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
current_user, current_tenant_id = current_account_with_tenant()
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_PARENT_SEGMENT)
@console_ns.response(
200,
"Child chunks updated successfully",
console_ns.models[ChildChunkBatchUpdateResponse.__name__],
)
@console_ns.expect(console_ns.models[ChildChunkBatchUpdatePayload.__name__])
@with_current_user
@with_current_tenant_id
def patch(
self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID, segment_id: UUID
):
# check dataset
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -699,7 +784,7 @@ class ChildChunkAddApi(Resource):
child_chunks = SegmentService.update_child_chunks(payload.chunks, segment, document, dataset)
except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e))
return {"data": marshal(child_chunks, child_chunk_fields)}, 200
return dump_response(ChildChunkBatchUpdateResponse, {"data": child_chunks}), 200
@console_ns.route(
@ -710,10 +795,19 @@ class ChildChunkUpdateApi(Resource):
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_CHILD_CHUNK)
@console_ns.response(204, "Child chunk deleted successfully")
def delete(self, dataset_id: UUID, document_id: UUID, segment_id: UUID, child_chunk_id: UUID):
current_user, current_tenant_id = current_account_with_tenant()
@with_current_user
@with_current_tenant_id
def delete(
self,
current_tenant_id: str,
current_user: Account,
dataset_id: UUID,
document_id: UUID,
segment_id: UUID,
child_chunk_id: UUID,
):
# check dataset
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -740,7 +834,7 @@ class ChildChunkUpdateApi(Resource):
child_chunk = db.session.scalar(
select(ChildChunk)
.where(
ChildChunk.id == str(child_chunk_id_str),
ChildChunk.id == child_chunk_id_str,
ChildChunk.tenant_id == current_tenant_id,
ChildChunk.segment_id == segment.id,
ChildChunk.document_id == document_id_str,
@ -767,10 +861,20 @@ class ChildChunkUpdateApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_CHILD_CHUNK)
@console_ns.expect(console_ns.models[ChildChunkUpdatePayload.__name__])
def patch(self, dataset_id: UUID, document_id: UUID, segment_id: UUID, child_chunk_id: UUID):
current_user, current_tenant_id = current_account_with_tenant()
@console_ns.response(200, "Child chunk updated successfully", console_ns.models[ChildChunkDetailResponse.__name__])
@with_current_user
@with_current_tenant_id
def patch(
self,
current_tenant_id: str,
current_user: Account,
dataset_id: UUID,
document_id: UUID,
segment_id: UUID,
child_chunk_id: UUID,
):
# check dataset
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -797,7 +901,7 @@ class ChildChunkUpdateApi(Resource):
child_chunk = db.session.scalar(
select(ChildChunk)
.where(
ChildChunk.id == str(child_chunk_id_str),
ChildChunk.id == child_chunk_id_str,
ChildChunk.tenant_id == current_tenant_id,
ChildChunk.segment_id == segment.id,
ChildChunk.document_id == document_id_str,
@ -819,4 +923,4 @@ class ChildChunkUpdateApi(Resource):
child_chunk = SegmentService.update_child_chunk(payload.content, child_chunk, segment, document, dataset)
except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e))
return {"data": marshal(child_chunk, child_chunk_fields)}, 200
return dump_response(ChildChunkDetailResponse, {"data": child_chunk}), 200

View File

@ -10,7 +10,13 @@ from controllers.common.fields import UsageCountResponse
from controllers.common.schema import get_or_create_model, register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
with_current_tenant_id,
with_current_user,
)
from fields.dataset_fields import (
dataset_detail_fields,
dataset_retrieval_model_fields,
@ -24,7 +30,8 @@ from fields.dataset_fields import (
vector_setting_fields,
weighted_score_fields,
)
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models import Account
from services.dataset_service import DatasetService
from services.external_knowledge_service import ExternalDatasetService
from services.hit_testing_service import HitTestingService
@ -126,9 +133,9 @@ class ExternalApiTemplateListApi(Resource):
@console_ns.response(200, "External API templates retrieved successfully")
@setup_required
@login_required
@with_current_tenant_id
@account_initialization_required
def get(self):
_, current_tenant_id = current_account_with_tenant()
def get(self, current_tenant_id: str):
query = ExternalApiTemplateListQuery.model_validate(request.args.to_dict())
external_knowledge_apis, total = ExternalDatasetService.get_external_knowledge_apis(
@ -147,8 +154,9 @@ class ExternalApiTemplateListApi(Resource):
@login_required
@account_initialization_required
@console_ns.expect(console_ns.models[ExternalKnowledgeApiPayload.__name__])
def post(self):
current_user, current_tenant_id = current_account_with_tenant()
@with_current_user
@with_current_tenant_id
def post(self, current_tenant_id: str, current_user: Account):
payload = ExternalKnowledgeApiPayload.model_validate(console_ns.payload or {})
ExternalDatasetService.validate_api_list(payload.settings)
@ -177,8 +185,8 @@ class ExternalApiTemplateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, external_knowledge_api_id: UUID):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def get(self, current_tenant_id: str, external_knowledge_api_id: UUID):
external_knowledge_api_id_str = str(external_knowledge_api_id)
external_knowledge_api = ExternalDatasetService.get_external_knowledge_api(
external_knowledge_api_id_str, current_tenant_id
@ -192,8 +200,9 @@ class ExternalApiTemplateApi(Resource):
@login_required
@account_initialization_required
@console_ns.expect(console_ns.models[ExternalKnowledgeApiPayload.__name__])
def patch(self, external_knowledge_api_id: UUID):
current_user, current_tenant_id = current_account_with_tenant()
@with_current_user
@with_current_tenant_id
def patch(self, current_tenant_id: str, current_user: Account, external_knowledge_api_id: UUID):
external_knowledge_api_id_str = str(external_knowledge_api_id)
payload = ExternalKnowledgeApiPayload.model_validate(console_ns.payload or {})
@ -212,8 +221,9 @@ class ExternalApiTemplateApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(204, "External knowledge API deleted successfully")
def delete(self, external_knowledge_api_id: UUID):
current_user, current_tenant_id = current_account_with_tenant()
@with_current_user
@with_current_tenant_id
def delete(self, current_tenant_id: str, current_user: Account, external_knowledge_api_id: UUID):
external_knowledge_api_id_str = str(external_knowledge_api_id)
if not (current_user.has_edit_permission or current_user.is_dataset_operator):
@ -232,8 +242,8 @@ class ExternalApiUseCheckApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, external_knowledge_api_id: UUID):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def get(self, current_tenant_id: str, external_knowledge_api_id: UUID):
external_knowledge_api_id_str = str(external_knowledge_api_id)
external_knowledge_api_is_using, count = ExternalDatasetService.external_knowledge_api_use_check(
@ -254,9 +264,10 @@ class ExternalDatasetCreateApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
def post(self):
@with_current_user
@with_current_tenant_id
def post(self, current_tenant_id: str, current_user: Account):
# The role of the current user in the ta table must be admin, owner, or editor
current_user, current_tenant_id = current_account_with_tenant()
payload = ExternalDatasetCreatePayload.model_validate(console_ns.payload or {})
args = payload.model_dump(exclude_none=True)
@ -288,8 +299,8 @@ class ExternalKnowledgeHitTestingApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, dataset_id: UUID):
current_user, _ = current_account_with_tenant()
@with_current_user
def post(self, current_user: Account, dataset_id: UUID):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:

View File

@ -1,15 +1,12 @@
from __future__ import annotations
from datetime import datetime
from typing import Any
from uuid import UUID
from flask_restx import Resource
from pydantic import Field, field_validator
from controllers.common.schema import register_schema_models
from fields.base import ResponseModel
from libs.helper import to_timestamp
from controllers.common.schema import register_response_schema_models, register_schema_models
from fields.hit_testing_fields import HitTestingResponse
from libs.helper import dump_response
from libs.login import login_required
from .. import console_ns
@ -20,86 +17,8 @@ from ..wraps import (
setup_required,
)
class HitTestingDocument(ResponseModel):
id: str | None = None
data_source_type: str | None = None
name: str | None = None
doc_type: str | None = None
doc_metadata: Any | None = None
class HitTestingSegment(ResponseModel):
id: str | None = None
position: int | None = None
document_id: str | None = None
content: str | None = None
sign_content: str | None = None
answer: str | None = None
word_count: int | None = None
tokens: int | None = None
keywords: list[str] = Field(default_factory=list)
index_node_id: str | None = None
index_node_hash: str | None = None
hit_count: int | None = None
enabled: bool | None = None
disabled_at: int | None = None
disabled_by: str | None = None
status: str | None = None
created_by: str | None = None
created_at: int | None = None
indexing_at: int | None = None
completed_at: int | None = None
error: str | None = None
stopped_at: int | None = None
document: HitTestingDocument | None = None
@field_validator("disabled_at", "created_at", "indexing_at", "completed_at", "stopped_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return to_timestamp(value)
class HitTestingChildChunk(ResponseModel):
id: str | None = None
content: str | None = None
position: int | None = None
score: float | None = None
class HitTestingFile(ResponseModel):
id: str | None = None
name: str | None = None
size: int | None = None
extension: str | None = None
mime_type: str | None = None
source_url: str | None = None
class HitTestingRecord(ResponseModel):
segment: HitTestingSegment | None = None
child_chunks: list[HitTestingChildChunk] = Field(default_factory=list)
score: float | None = None
tsne_position: Any | None = None
files: list[HitTestingFile] = Field(default_factory=list)
summary: str | None = None
class HitTestingResponse(ResponseModel):
query: str
records: list[HitTestingRecord] = Field(default_factory=list)
register_schema_models(
console_ns,
HitTestingPayload,
HitTestingDocument,
HitTestingSegment,
HitTestingChildChunk,
HitTestingFile,
HitTestingRecord,
HitTestingResponse,
)
register_schema_models(console_ns, HitTestingPayload)
register_response_schema_models(console_ns, HitTestingResponse)
@console_ns.route("/datasets/<uuid:dataset_id>/hit-testing")
@ -119,12 +38,11 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id: UUID):
def post(self, dataset_id: UUID) -> dict[str, object]:
dataset_id_str = str(dataset_id)
dataset = self.get_and_validate_dataset(dataset_id_str)
payload = HitTestingPayload.model_validate(console_ns.payload or {})
args = payload.model_dump(exclude_none=True)
args = self.parse_args(console_ns.payload)
self.hit_testing_args_check(args)
return HitTestingResponse.model_validate(self.perform_hit_testing(dataset, args)).model_dump(mode="json")
return dump_response(HitTestingResponse, self.perform_hit_testing(dataset, args))

View File

@ -1,7 +1,6 @@
import logging
from typing import Any
from typing import Any, cast
from flask_restx import marshal
from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
@ -19,10 +18,10 @@ from core.errors.error import (
ProviderTokenNotInitError,
QuotaExceededError,
)
from fields.hit_testing_fields import hit_testing_record_fields
from graphon.model_runtime.errors.invoke import InvokeError
from libs.login import current_user
from models.account import Account
from models.dataset import Dataset
from services.dataset_service import DatasetService
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
from services.hit_testing_service import HitTestingService
@ -38,16 +37,6 @@ class HitTestingPayload(BaseModel):
class DatasetsHitTestingBase:
@staticmethod
def _extract_hit_testing_query(query: Any) -> str:
"""Return the query string from the service response shape."""
if isinstance(query, dict):
content = query.get("content")
if isinstance(content, str):
return content
raise ValueError("Invalid hit testing query response")
@staticmethod
def _prepare_hit_testing_records(records: Any) -> list[dict[str, Any]]:
"""Ensure collection fields match the API schema before response validation."""
@ -63,6 +52,7 @@ class DatasetsHitTestingBase:
segment = normalized_record.get("segment")
if isinstance(segment, dict):
normalized_segment = dict(segment)
normalized_segment.setdefault("sign_content", None)
if normalized_segment.get("keywords") is None:
normalized_segment["keywords"] = []
normalized_record["segment"] = normalized_segment
@ -73,12 +63,15 @@ class DatasetsHitTestingBase:
if normalized_record.get("files") is None:
normalized_record["files"] = []
normalized_record.setdefault("tsne_position", None)
normalized_record.setdefault("summary", None)
normalized_records.append(normalized_record)
return normalized_records
@staticmethod
def get_and_validate_dataset(dataset_id: str):
def get_and_validate_dataset(dataset_id: str) -> Dataset:
assert isinstance(current_user, Account)
dataset = DatasetService.get_dataset(dataset_id)
if dataset is None:
@ -92,33 +85,35 @@ class DatasetsHitTestingBase:
return dataset
@staticmethod
def hit_testing_args_check(args: dict[str, Any]):
def hit_testing_args_check(args: dict[str, Any]) -> None:
HitTestingService.hit_testing_args_check(args)
@staticmethod
def parse_args(payload: dict[str, Any]) -> dict[str, Any]:
def parse_args(payload: dict[str, Any] | None) -> dict[str, Any]:
"""Validate and return hit-testing arguments from an incoming payload."""
hit_testing_payload = HitTestingPayload.model_validate(payload or {})
return hit_testing_payload.model_dump(exclude_none=True)
@staticmethod
def perform_hit_testing(dataset, args):
def perform_hit_testing(dataset: Dataset, args: dict[str, Any]) -> dict[str, Any]:
assert isinstance(current_user, Account)
try:
response = HitTestingService.retrieve(
dataset=dataset,
query=args.get("query"),
query=cast(str, args.get("query")),
account=current_user,
retrieval_model=args.get("retrieval_model"),
external_retrieval_model=args.get("external_retrieval_model"),
external_retrieval_model=cast(dict[str, Any], args.get("external_retrieval_model")),
attachment_ids=args.get("attachment_ids"),
limit=10,
)
query = response.get("query")
if not isinstance(query, dict) or not isinstance(query.get("content"), str):
raise ValueError("Invalid hit testing query response")
return {
"query": DatasetsHitTestingBase._extract_hit_testing_query(response.get("query")),
"records": DatasetsHitTestingBase._prepare_hit_testing_records(
marshal(response.get("records", []), hit_testing_record_fields)
),
"query": {"content": query["content"]},
"records": DatasetsHitTestingBase._prepare_hit_testing_records(response.get("records", [])),
}
except services.errors.index.IndexNotInitializedError:
raise DatasetNotInitializedError()

View File

@ -7,14 +7,20 @@ from werkzeug.exceptions import NotFound
from controllers.common.controller_schemas import MetadataUpdatePayload
from controllers.common.schema import register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from controllers.console.wraps import (
account_initialization_required,
enterprise_license_required,
setup_required,
with_current_user,
)
from fields.dataset_fields import (
DatasetMetadataBuiltInFieldsResponse,
DatasetMetadataListResponse,
DatasetMetadataResponse,
)
from libs.helper import dump_response
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models.account import Account
from services.dataset_service import DatasetService
from services.entities.knowledge_entities.knowledge_entities import (
DocumentMetadataOperation,
@ -43,8 +49,8 @@ class DatasetMetadataCreateApi(Resource):
@enterprise_license_required
@console_ns.response(201, "Metadata created successfully", console_ns.models[DatasetMetadataResponse.__name__])
@console_ns.expect(console_ns.models[MetadataArgs.__name__])
def post(self, dataset_id: UUID):
current_user, _ = current_account_with_tenant()
@with_current_user
def post(self, current_user: Account, dataset_id: UUID):
metadata_args = MetadataArgs.model_validate(console_ns.payload or {})
dataset_id_str = str(dataset_id)
@ -80,8 +86,8 @@ class DatasetMetadataApi(Resource):
@enterprise_license_required
@console_ns.response(200, "Metadata updated successfully", console_ns.models[DatasetMetadataResponse.__name__])
@console_ns.expect(console_ns.models[MetadataUpdatePayload.__name__])
def patch(self, dataset_id: UUID, metadata_id: UUID):
current_user, _ = current_account_with_tenant()
@with_current_user
def patch(self, current_user: Account, dataset_id: UUID, metadata_id: UUID):
payload = MetadataUpdatePayload.model_validate(console_ns.payload or {})
name = payload.name
@ -100,8 +106,8 @@ class DatasetMetadataApi(Resource):
@account_initialization_required
@enterprise_license_required
@console_ns.response(204, "Metadata deleted successfully")
def delete(self, dataset_id: UUID, metadata_id: UUID):
current_user, _ = current_account_with_tenant()
@with_current_user
def delete(self, current_user: Account, dataset_id: UUID, metadata_id: UUID):
dataset_id_str = str(dataset_id)
metadata_id_str = str(metadata_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -137,8 +143,8 @@ class DatasetMetadataBuiltInFieldActionApi(Resource):
@account_initialization_required
@enterprise_license_required
@console_ns.response(204, "Action completed successfully")
def post(self, dataset_id: UUID, action: Literal["enable", "disable"]):
current_user, _ = current_account_with_tenant()
@with_current_user
def post(self, current_user: Account, dataset_id: UUID, action: Literal["enable", "disable"]):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
@ -165,8 +171,8 @@ class DocumentMetadataEditApi(Resource):
204,
"Documents metadata updated successfully",
)
def post(self, dataset_id: UUID):
current_user, _ = current_account_with_tenant()
@with_current_user
def post(self, current_user: Account, dataset_id: UUID):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:

View File

@ -9,11 +9,18 @@ from configs import dify_config
from controllers.common.fields import SimpleResultResponse
from controllers.common.schema import register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
with_current_tenant_id,
with_current_user,
)
from core.plugin.impl.oauth import OAuthHandler
from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
from graphon.model_runtime.utils.encoders import jsonable_encoder
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models import Account
from models.provider_ids import DatasourceProviderID
from services.datasource_provider_service import DatasourceProviderService
from services.plugin.oauth_service import OAuthProxyService
@ -66,11 +73,10 @@ class DatasourcePluginOAuthAuthorizationUrl(Resource):
@login_required
@account_initialization_required
@edit_permission_required
def get(self, provider_id: str):
current_user, current_tenant_id = current_account_with_tenant()
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account, provider_id: str):
tenant_id = current_tenant_id
credential_id = request.args.get("credential_id")
datasource_provider_id = DatasourceProviderID(provider_id)
provider_name = datasource_provider_id.provider_name
@ -174,9 +180,8 @@ class DatasourceAuth(Resource):
@login_required
@account_initialization_required
@edit_permission_required
def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def post(self, current_tenant_id: str, provider_id: str):
payload = DatasourceCredentialPayload.model_validate(console_ns.payload or {})
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
@ -195,15 +200,17 @@ class DatasourceAuth(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider_id: str):
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, user: Account, provider_id: str):
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
_, current_tenant_id = current_account_with_tenant()
datasources = datasource_provider_service.list_datasource_credentials(
tenant_id=current_tenant_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
user=user,
)
return {"result": datasources}, 200
@ -216,9 +223,8 @@ class DatasourceAuthDeleteApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def post(self, current_tenant_id: str, provider_id: str):
datasource_provider_id = DatasourceProviderID(provider_id)
plugin_id = datasource_provider_id.plugin_id
provider_name = datasource_provider_id.provider_name
@ -241,9 +247,8 @@ class DatasourceAuthUpdateApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def post(self, current_tenant_id: str, provider_id: str):
datasource_provider_id = DatasourceProviderID(provider_id)
payload = DatasourceCredentialUpdatePayload.model_validate(console_ns.payload or {})
@ -264,9 +269,8 @@ class DatasourceAuthListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def get(self, current_tenant_id: str):
datasource_provider_service = DatasourceProviderService()
datasources = datasource_provider_service.get_all_datasource_credentials(tenant_id=current_tenant_id)
return {"result": jsonable_encoder(datasources)}, 200
@ -277,9 +281,8 @@ class DatasourceHardCodeAuthListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def get(self, current_tenant_id: str):
datasource_provider_service = DatasourceProviderService()
datasources = datasource_provider_service.get_hard_code_datasource_credentials(tenant_id=current_tenant_id)
return {"result": jsonable_encoder(datasources)}, 200
@ -292,9 +295,8 @@ class DatasourceAuthOauthCustomClient(Resource):
@login_required
@account_initialization_required
@edit_permission_required
def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def post(self, current_tenant_id: str, provider_id: str):
payload = DatasourceCustomClientPayload.model_validate(console_ns.payload or {})
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
@ -310,9 +312,8 @@ class DatasourceAuthOauthCustomClient(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
def delete(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def delete(self, current_tenant_id: str, provider_id: str):
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.remove_oauth_custom_client_params(
@ -330,9 +331,8 @@ class DatasourceAuthDefaultApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def post(self, current_tenant_id: str, provider_id: str):
payload = DatasourceDefaultPayload.model_validate(console_ns.payload or {})
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
@ -352,9 +352,8 @@ class DatasourceUpdateProviderNameApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def post(self, current_tenant_id: str, provider_id: str):
payload = DatasourceUpdateNamePayload.model_validate(console_ns.payload or {})
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()

View File

@ -1,13 +1,20 @@
import logging
from typing import Any
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import NotFound
from controllers.common.fields import SimpleDataResponse
from controllers.common.schema import register_response_schema_models, register_schema_models
from controllers.common.schema import (
JsonResponseWithStatus,
query_params_from_model,
register_response_schema_models,
register_schema_models,
)
from controllers.console import console_ns
from controllers.console.wraps import (
account_initialization_required,
@ -16,79 +23,132 @@ from controllers.console.wraps import (
setup_required,
)
from extensions.ext_database import db
from fields.base import ResponseModel
from libs.helper import dump_response
from libs.login import login_required
from models.dataset import PipelineCustomizedTemplate
from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity
from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, PipelineTemplateInfoEntity
from services.rag_pipeline.rag_pipeline import RagPipelineService
logger = logging.getLogger(__name__)
logger: logging.Logger = logging.getLogger(__name__)
class PipelineTemplateListQuery(BaseModel):
type: str = Field(default="built-in", description="Template source: built-in or customized")
language: str = Field(default="en-US", description="Template language")
class PipelineTemplateDetailQuery(BaseModel):
type: str = Field(default="built-in", description="Template source: built-in or customized")
class PipelineTemplateItemResponse(ResponseModel):
id: str
name: str
icon: dict[str, Any]
description: str
position: int
chunk_structure: str
copyright: str | None = None
privacy_policy: str | None = None
class PipelineTemplateListResponse(ResponseModel):
pipeline_templates: list[PipelineTemplateItemResponse]
class PipelineTemplateDetailResponse(ResponseModel):
id: str
name: str
icon_info: dict[str, Any]
description: str
chunk_structure: str
export_data: str
graph: dict[str, Any]
created_by: str | None = None
class CustomizedPipelineTemplatePayload(BaseModel):
name: str = Field(..., min_length=1, max_length=40)
description: str = Field(default="", max_length=400)
icon_info: dict[str, object] = Field(default_factory=lambda: IconInfo(icon="").model_dump())
register_schema_models(
console_ns,
CustomizedPipelineTemplatePayload,
PipelineTemplateDetailQuery,
PipelineTemplateListQuery,
)
register_response_schema_models(
console_ns,
PipelineTemplateDetailResponse,
PipelineTemplateListResponse,
SimpleDataResponse,
)
@console_ns.route("/rag/pipeline/templates")
class PipelineTemplateListApi(Resource):
@console_ns.doc(params=query_params_from_model(PipelineTemplateListQuery))
@console_ns.response(200, "Pipeline templates", console_ns.models[PipelineTemplateListResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
def get(self):
type = request.args.get("type", default="built-in", type=str)
language = request.args.get("language", default="en-US", type=str)
def get(self) -> JsonResponseWithStatus:
query = PipelineTemplateListQuery.model_validate(request.args.to_dict(flat=True))
# get pipeline templates
pipeline_templates = RagPipelineService.get_pipeline_templates(type, language)
return pipeline_templates, 200
pipeline_templates = RagPipelineService.get_pipeline_templates(query.type, query.language)
return dump_response(PipelineTemplateListResponse, pipeline_templates), 200
@console_ns.route("/rag/pipeline/templates/<string:template_id>")
class PipelineTemplateDetailApi(Resource):
@console_ns.doc(params=query_params_from_model(PipelineTemplateDetailQuery))
@console_ns.response(200, "Pipeline template", console_ns.models[PipelineTemplateDetailResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
def get(self, template_id: str):
type = request.args.get("type", default="built-in", type=str)
def get(self, template_id: str) -> JsonResponseWithStatus:
query = PipelineTemplateDetailQuery.model_validate(request.args.to_dict(flat=True))
rag_pipeline_service = RagPipelineService()
pipeline_template = rag_pipeline_service.get_pipeline_template_detail(template_id, type)
pipeline_template = rag_pipeline_service.get_pipeline_template_detail(template_id, query.type)
if pipeline_template is None:
return {"error": "Pipeline template not found from upstream service."}, 404
return pipeline_template, 200
class Payload(BaseModel):
name: str = Field(..., min_length=1, max_length=40)
description: str = Field(default="", max_length=400)
icon_info: dict[str, object] | None = None
register_schema_models(console_ns, Payload)
register_response_schema_models(console_ns, SimpleDataResponse)
raise NotFound("Pipeline template not found from upstream service.")
return dump_response(PipelineTemplateDetailResponse, pipeline_template), 200
@console_ns.route("/rag/pipeline/customized/templates/<string:template_id>")
class CustomizedPipelineTemplateApi(Resource):
@console_ns.expect(console_ns.models[CustomizedPipelineTemplatePayload.__name__])
@console_ns.response(204, "Pipeline template updated")
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
def patch(self, template_id: str):
payload = Payload.model_validate(console_ns.payload or {})
def patch(self, template_id: str) -> tuple[str, int]:
payload = CustomizedPipelineTemplatePayload.model_validate(console_ns.payload or {})
pipeline_template_info = PipelineTemplateInfoEntity.model_validate(payload.model_dump())
RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
return 200
return "", 204
@console_ns.response(204, "Pipeline template deleted")
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
def delete(self, template_id: str):
def delete(self, template_id: str) -> tuple[str, int]:
RagPipelineService.delete_customized_pipeline_template(template_id)
return 200
return "", 204
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
@console_ns.response(200, "Success", console_ns.models[SimpleDataResponse.__name__])
def post(self, template_id: str):
def post(self, template_id: str) -> JsonResponseWithStatus:
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
template = session.scalar(
select(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id).limit(1)
@ -96,19 +156,20 @@ class CustomizedPipelineTemplateApi(Resource):
if not template:
raise ValueError("Customized pipeline template not found.")
return {"data": template.yaml_content}, 200
return dump_response(SimpleDataResponse, {"data": template.yaml_content}), 200
@console_ns.route("/rag/pipelines/<string:pipeline_id>/customized/publish")
class PublishCustomizedPipelineTemplateApi(Resource):
@console_ns.expect(console_ns.models[Payload.__name__])
@console_ns.expect(console_ns.models[CustomizedPipelineTemplatePayload.__name__])
@console_ns.response(204, "Pipeline template published")
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
@knowledge_pipeline_publish_enabled
def post(self, pipeline_id: str):
payload = Payload.model_validate(console_ns.payload or {})
def post(self, pipeline_id: str) -> tuple[str, int]:
payload = CustomizedPipelineTemplatePayload.model_validate(console_ns.payload or {})
rag_pipeline_service = RagPipelineService()
rag_pipeline_service.publish_customized_pipeline_template(pipeline_id, payload.model_dump())
return {"result": "success"}
return "", 204

View File

@ -1,20 +1,25 @@
from flask_restx import Resource, marshal
from flask_restx import Resource
from pydantic import BaseModel
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
import services
from controllers.common.schema import register_schema_model
from controllers.common.schema import JsonResponseWithStatus, register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.datasets.rag_pipeline.rag_pipeline_import import RagPipelineImportResponse
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_rate_limit_check,
setup_required,
with_current_tenant_id,
with_current_user,
)
from extensions.ext_database import db
from fields.dataset_fields import dataset_detail_fields
from libs.login import current_account_with_tenant, login_required
from fields.dataset_fields import DatasetDetailResponse
from libs.helper import dump_response
from libs.login import login_required
from models import Account
from models.dataset import DatasetPermissionEnum
from services.dataset_service import DatasetPermissionService, DatasetService
from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity
@ -25,19 +30,26 @@ class RagPipelineDatasetImportPayload(BaseModel):
yaml_content: str
register_schema_model(console_ns, RagPipelineDatasetImportPayload)
register_schema_models(console_ns, RagPipelineDatasetImportPayload)
register_response_schema_models(console_ns, DatasetDetailResponse, RagPipelineImportResponse)
@console_ns.route("/rag/pipeline/dataset")
class CreateRagPipelineDatasetApi(Resource):
@console_ns.expect(console_ns.models[RagPipelineDatasetImportPayload.__name__])
@console_ns.response(
201,
"RAG pipeline dataset import started",
console_ns.models[RagPipelineImportResponse.__name__],
)
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self):
@with_current_user
@with_current_tenant_id
def post(self, current_tenant_id: str, current_user: Account) -> JsonResponseWithStatus:
payload = RagPipelineDatasetImportPayload.model_validate(console_ns.payload or {})
current_user, current_tenant_id = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor:
raise Forbidden()
@ -70,19 +82,20 @@ class CreateRagPipelineDatasetApi(Resource):
except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError()
return import_info, 201
return dump_response(RagPipelineImportResponse, import_info), 201
@console_ns.route("/rag/pipeline/empty-dataset")
class CreateEmptyRagPipelineDatasetApi(Resource):
@console_ns.response(201, "RAG pipeline dataset created", console_ns.models[DatasetDetailResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self):
@with_current_user
@with_current_tenant_id
def post(self, current_tenant_id: str, current_user: Account) -> JsonResponseWithStatus:
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_dataset_editor:
raise Forbidden()
dataset = DatasetService.create_empty_rag_pipeline_dataset(
@ -99,4 +112,4 @@ class CreateEmptyRagPipelineDatasetApi(Resource):
partial_member_list=None,
),
)
return marshal(dataset, dataset_detail_fields), 201
return dump_response(DatasetDetailResponse, dataset), 201

View File

@ -1,6 +1,7 @@
import logging
from collections.abc import Callable
from typing import Any, NoReturn
from typing import Any, Concatenate, NoReturn
from uuid import UUID
from flask import Response, request
from flask_restx import Resource, marshal, marshal_with
@ -56,7 +57,9 @@ class WorkflowDraftVariablePatchPayload(BaseModel):
register_schema_models(console_ns, WorkflowDraftVariablePatchPayload)
def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
def _api_prerequisite[T, **P, R](
f: Callable[Concatenate[T, P], R],
) -> Callable[Concatenate[T, P], R | Response]:
"""Common prerequisites for all draft workflow variable APIs.
It ensures the following conditions are satisfied:
@ -71,10 +74,10 @@ def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
@login_required
@account_initialization_required
@get_rag_pipeline
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | Response:
def wrapper(self: T, *args: P.args, **kwargs: P.kwargs) -> R | Response:
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
return f(*args, **kwargs)
return f(self, *args, **kwargs)
return wrapper
@ -168,21 +171,22 @@ class RagPipelineVariableApi(Resource):
@_api_prerequisite
@marshal_with(workflow_draft_variable_model)
def get(self, pipeline: Pipeline, variable_id: str):
def get(self, pipeline: Pipeline, variable_id: UUID):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
variable = draft_var_srv.get_variable(variable_id=variable_id)
variable_id_str = str(variable_id)
variable = draft_var_srv.get_variable(variable_id=variable_id_str)
if variable is None:
raise NotFoundError(description=f"variable not found, id={variable_id}")
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
if variable.app_id != pipeline.id:
raise NotFoundError(description=f"variable not found, id={variable_id}")
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
return variable
@_api_prerequisite
@marshal_with(workflow_draft_variable_model)
@console_ns.expect(console_ns.models[WorkflowDraftVariablePatchPayload.__name__])
def patch(self, pipeline: Pipeline, variable_id: str):
def patch(self, pipeline: Pipeline, variable_id: UUID):
# Request payload for file types:
#
# Local File:
@ -210,11 +214,12 @@ class RagPipelineVariableApi(Resource):
payload = WorkflowDraftVariablePatchPayload.model_validate(console_ns.payload or {})
args = payload.model_dump(exclude_none=True)
variable = draft_var_srv.get_variable(variable_id=variable_id)
variable_id_str = str(variable_id)
variable = draft_var_srv.get_variable(variable_id=variable_id_str)
if variable is None:
raise NotFoundError(description=f"variable not found, id={variable_id}")
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
if variable.app_id != pipeline.id:
raise NotFoundError(description=f"variable not found, id={variable_id}")
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
new_name = args.get(self._PATCH_NAME_FIELD, None)
raw_value = args.get(self._PATCH_VALUE_FIELD, None)
@ -250,15 +255,16 @@ class RagPipelineVariableApi(Resource):
return variable
@_api_prerequisite
def delete(self, pipeline: Pipeline, variable_id: str):
def delete(self, pipeline: Pipeline, variable_id: UUID):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
variable = draft_var_srv.get_variable(variable_id=variable_id)
variable_id_str = str(variable_id)
variable = draft_var_srv.get_variable(variable_id=variable_id_str)
if variable is None:
raise NotFoundError(description=f"variable not found, id={variable_id}")
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
if variable.app_id != pipeline.id:
raise NotFoundError(description=f"variable not found, id={variable_id}")
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
draft_var_srv.delete_variable(variable)
db.session.commit()
return Response("", 204)
@ -267,7 +273,7 @@ class RagPipelineVariableApi(Resource):
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables/<uuid:variable_id>/reset")
class RagPipelineVariableResetApi(Resource):
@_api_prerequisite
def put(self, pipeline: Pipeline, variable_id: str):
def put(self, pipeline: Pipeline, variable_id: UUID):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
@ -278,11 +284,12 @@ class RagPipelineVariableResetApi(Resource):
raise NotFoundError(
f"Draft workflow not found, pipeline_id={pipeline.id}",
)
variable = draft_var_srv.get_variable(variable_id=variable_id)
variable_id_str = str(variable_id)
variable = draft_var_srv.get_variable(variable_id=variable_id_str)
if variable is None:
raise NotFoundError(description=f"variable not found, id={variable_id}")
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
if variable.app_id != pipeline.id:
raise NotFoundError(description=f"variable not found, id={variable_id}")
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
resetted = draft_var_srv.reset_variable(draft_workflow, variable)
db.session.commit()

View File

@ -1,23 +1,29 @@
from flask import request
from flask_restx import Resource, fields, marshal_with # type: ignore
from flask_restx import Resource
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from controllers.common.schema import get_or_create_model, register_schema_models
from controllers.common.fields import SimpleDataResponse
from controllers.common.schema import (
JsonResponseWithStatus,
query_params_from_model,
register_response_schema_models,
register_schema_models,
)
from controllers.console import console_ns
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
with_current_user,
)
from core.plugin.entities.plugin import PluginDependency
from extensions.ext_database import db
from fields.rag_pipeline_fields import (
leaked_dependency_fields,
pipeline_import_check_dependencies_fields,
pipeline_import_fields,
)
from libs.login import current_account_with_tenant, login_required
from fields.base import ResponseModel
from libs.helper import dump_response
from libs.login import login_required
from models.account import Account
from models.dataset import Pipeline
from services.entities.dsl_entities import ImportStatus
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
@ -36,35 +42,45 @@ class RagPipelineImportPayload(BaseModel):
class IncludeSecretQuery(BaseModel):
include_secret: str = Field(default="false")
include_secret: str = Field(default="false", description="Whether to include secret values in the exported DSL")
class RagPipelineImportResponse(ResponseModel):
id: str
status: ImportStatus
pipeline_id: str | None = None
dataset_id: str | None = None
current_dsl_version: str
imported_dsl_version: str
error: str = ""
class RagPipelineImportCheckDependenciesResponse(ResponseModel):
leaked_dependencies: list[PluginDependency] = Field(default_factory=list)
register_schema_models(console_ns, RagPipelineImportPayload, IncludeSecretQuery)
pipeline_import_model = get_or_create_model("RagPipelineImport", pipeline_import_fields)
leaked_dependency_model = get_or_create_model("RagPipelineLeakedDependency", leaked_dependency_fields)
pipeline_import_check_dependencies_fields_copy = pipeline_import_check_dependencies_fields.copy()
pipeline_import_check_dependencies_fields_copy["leaked_dependencies"] = fields.List(
fields.Nested(leaked_dependency_model)
)
pipeline_import_check_dependencies_model = get_or_create_model(
"RagPipelineImportCheckDependencies", pipeline_import_check_dependencies_fields_copy
register_response_schema_models(
console_ns,
RagPipelineImportCheckDependenciesResponse,
RagPipelineImportResponse,
SimpleDataResponse,
)
@console_ns.route("/rag/pipelines/imports")
class RagPipelineImportApi(Resource):
@console_ns.expect(console_ns.models[RagPipelineImportPayload.__name__])
@console_ns.response(200, "Import completed", console_ns.models[RagPipelineImportResponse.__name__])
@console_ns.response(202, "Import pending confirmation", console_ns.models[RagPipelineImportResponse.__name__])
@console_ns.response(400, "Import failed", console_ns.models[RagPipelineImportResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
@marshal_with(pipeline_import_model)
@console_ns.expect(console_ns.models[RagPipelineImportPayload.__name__])
def post(self):
@with_current_user
def post(self, current_user: Account) -> JsonResponseWithStatus:
# Check user role first
current_user, _ = current_account_with_tenant()
payload = RagPipelineImportPayload.model_validate(console_ns.payload or {})
# Use a plain Session so that caught exceptions inside the service
@ -91,23 +107,23 @@ class RagPipelineImportApi(Resource):
status = result.status
match status:
case ImportStatus.FAILED:
return result.model_dump(mode="json"), 400
return dump_response(RagPipelineImportResponse, result), 400
case ImportStatus.PENDING:
return result.model_dump(mode="json"), 202
return dump_response(RagPipelineImportResponse, result), 202
case ImportStatus.COMPLETED | ImportStatus.COMPLETED_WITH_WARNINGS:
return result.model_dump(mode="json"), 200
return dump_response(RagPipelineImportResponse, result), 200
@console_ns.route("/rag/pipelines/imports/<string:import_id>/confirm")
class RagPipelineImportConfirmApi(Resource):
@console_ns.response(200, "Import confirmed", console_ns.models[RagPipelineImportResponse.__name__])
@console_ns.response(400, "Import failed", console_ns.models[RagPipelineImportResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
@marshal_with(pipeline_import_model)
def post(self, import_id: str):
current_user, _ = current_account_with_tenant()
@with_current_user
def post(self, current_user: Account, import_id: str) -> JsonResponseWithStatus:
with Session(db.engine, expire_on_commit=False) as session:
import_service = RagPipelineDslService(session)
account = current_user
@ -119,34 +135,40 @@ class RagPipelineImportConfirmApi(Resource):
# Return appropriate status code based on result
if result.status == ImportStatus.FAILED:
return result.model_dump(mode="json"), 400
return result.model_dump(mode="json"), 200
return dump_response(RagPipelineImportResponse, result), 400
return dump_response(RagPipelineImportResponse, result), 200
@console_ns.route("/rag/pipelines/imports/<string:pipeline_id>/check-dependencies")
class RagPipelineImportCheckDependenciesApi(Resource):
@console_ns.response(
200,
"Dependencies checked",
console_ns.models[RagPipelineImportCheckDependenciesResponse.__name__],
)
@setup_required
@login_required
@get_rag_pipeline
@account_initialization_required
@edit_permission_required
@marshal_with(pipeline_import_check_dependencies_model)
def get(self, pipeline: Pipeline):
def get(self, pipeline: Pipeline) -> JsonResponseWithStatus:
with Session(db.engine, expire_on_commit=False) as session:
import_service = RagPipelineDslService(session)
result = import_service.check_dependencies(pipeline=pipeline)
return result.model_dump(mode="json"), 200
return dump_response(RagPipelineImportCheckDependenciesResponse, result), 200
@console_ns.route("/rag/pipelines/<string:pipeline_id>/exports")
class RagPipelineExportApi(Resource):
@console_ns.doc(params=query_params_from_model(IncludeSecretQuery))
@console_ns.response(200, "Pipeline exported", console_ns.models[SimpleDataResponse.__name__])
@setup_required
@login_required
@get_rag_pipeline
@account_initialization_required
@edit_permission_required
def get(self, pipeline: Pipeline):
def get(self, pipeline: Pipeline) -> JsonResponseWithStatus:
# Add include_secret params
query = IncludeSecretQuery.model_validate(request.args.to_dict())
@ -156,4 +178,4 @@ class RagPipelineExportApi(Resource):
pipeline=pipeline, include_secret=query.include_secret == "true"
)
return {"data": result}, 200
return dump_response(SimpleDataResponse, {"data": result}), 200

View File

@ -901,7 +901,7 @@ class RagPipelineWorkflowRunNodeExecutionListApi(Resource):
@login_required
@account_initialization_required
@get_rag_pipeline
def get(self, pipeline: Pipeline, run_id: str):
def get(self, pipeline: Pipeline, run_id: UUID):
"""
Get workflow run node execution list
"""

View File

@ -20,6 +20,7 @@ from controllers.console.app.error import (
from controllers.console.explore.wraps import InstalledAppResource
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from graphon.model_runtime.errors.invoke import InvokeError
from models.model import InstalledApp
from services.audio_service import AudioService
from services.errors.audio import (
AudioTooLargeServiceError,
@ -40,8 +41,10 @@ register_schema_model(console_ns, TextToAudioPayload)
endpoint="installed_app_audio",
)
class ChatAudioApi(InstalledAppResource):
def post(self, installed_app):
def post(self, installed_app: InstalledApp):
app_model = installed_app.app
if app_model is None:
raise AppUnavailableError()
file = request.files["file"]
@ -81,8 +84,10 @@ class ChatAudioApi(InstalledAppResource):
)
class ChatTextApi(InstalledAppResource):
@console_ns.expect(console_ns.models[TextToAudioPayload.__name__])
def post(self, installed_app):
def post(self, installed_app: InstalledApp):
app_model = installed_app.app
if app_model is None:
raise AppUnavailableError()
try:
payload = TextToAudioPayload.model_validate(console_ns.payload or {})

View File

@ -18,6 +18,7 @@ from controllers.console.app.error import (
)
from controllers.console.explore.error import NotChatAppError, NotCompletionAppError
from controllers.console.explore.wraps import InstalledAppResource
from controllers.console.wraps import with_current_user_id
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import (
@ -31,7 +32,7 @@ from libs import helper
from libs.datetime_utils import naive_utc_now
from libs.login import current_user
from models import Account
from models.model import AppMode
from models.model import AppMode, InstalledApp
from services.app_generate_service import AppGenerateService
from services.app_task_service import AppTaskService
from services.errors.llm import InvokeRateLimitError
@ -83,8 +84,10 @@ register_response_schema_models(console_ns, SimpleResultResponse)
)
class CompletionApi(InstalledAppResource):
@console_ns.expect(console_ns.models[CompletionMessageExplorePayload.__name__])
def post(self, installed_app):
def post(self, installed_app: InstalledApp):
app_model = installed_app.app
if app_model is None:
raise AppUnavailableError()
if app_model.mode != AppMode.COMPLETION:
raise NotCompletionAppError()
@ -133,18 +136,18 @@ class CompletionApi(InstalledAppResource):
)
class CompletionStopApi(InstalledAppResource):
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
def post(self, installed_app, task_id: str):
@with_current_user_id
def post(self, current_user_id: str, installed_app: InstalledApp, task_id: str):
app_model = installed_app.app
if app_model is None:
raise AppUnavailableError()
if app_model.mode != AppMode.COMPLETION:
raise NotCompletionAppError()
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
AppTaskService.stop_task(
task_id=task_id,
invoke_from=InvokeFrom.EXPLORE,
user_id=current_user.id,
user_id=current_user_id,
app_mode=AppMode.value_of(app_model.mode),
)
@ -157,8 +160,10 @@ class CompletionStopApi(InstalledAppResource):
)
class ChatApi(InstalledAppResource):
@console_ns.expect(console_ns.models[ChatMessagePayload.__name__])
def post(self, installed_app):
def post(self, installed_app: InstalledApp):
app_model = installed_app.app
if app_model is None:
raise AppUnavailableError()
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
@ -209,19 +214,19 @@ class ChatApi(InstalledAppResource):
)
class ChatStopApi(InstalledAppResource):
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
def post(self, installed_app, task_id: str):
@with_current_user_id
def post(self, current_user_id: str, installed_app: InstalledApp, task_id: str):
app_model = installed_app.app
if app_model is None:
raise AppUnavailableError()
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
AppTaskService.stop_task(
task_id=task_id,
invoke_from=InvokeFrom.EXPLORE,
user_id=current_user.id,
user_id=current_user_id,
app_mode=app_mode,
)

View File

@ -8,6 +8,7 @@ from werkzeug.exceptions import NotFound
from controllers.common.controller_schemas import ConversationRenamePayload
from controllers.common.schema import register_response_schema_models, register_schema_models
from controllers.console.app.error import AppUnavailableError
from controllers.console.explore.error import NotChatAppError
from controllers.console.explore.wraps import InstalledAppResource
from core.app.entities.app_invoke_entities import InvokeFrom
@ -20,7 +21,7 @@ from fields.conversation_fields import (
from libs.helper import UUIDStrOrEmpty
from libs.login import current_user
from models import Account
from models.model import AppMode
from models.model import AppMode, InstalledApp
from services.conversation_service import ConversationService
from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError
from services.web_conversation_service import WebConversationService
@ -44,8 +45,10 @@ register_response_schema_models(console_ns, ResultResponse)
)
class ConversationListApi(InstalledAppResource):
@console_ns.expect(console_ns.models[ConversationListQuery.__name__])
def get(self, installed_app):
def get(self, installed_app: InstalledApp):
app_model = installed_app.app
if app_model is None:
raise AppUnavailableError()
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
@ -92,8 +95,10 @@ class ConversationListApi(InstalledAppResource):
)
class ConversationApi(InstalledAppResource):
@console_ns.response(204, "Conversation deleted successfully")
def delete(self, installed_app, c_id: UUID):
def delete(self, installed_app: InstalledApp, c_id: UUID):
app_model = installed_app.app
if app_model is None:
raise AppUnavailableError()
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
@ -115,8 +120,10 @@ class ConversationApi(InstalledAppResource):
)
class ConversationRenameApi(InstalledAppResource):
@console_ns.expect(console_ns.models[ConversationRenamePayload.__name__])
def post(self, installed_app, c_id: UUID):
def post(self, installed_app: InstalledApp, c_id: UUID):
app_model = installed_app.app
if app_model is None:
raise AppUnavailableError()
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
@ -146,8 +153,10 @@ class ConversationRenameApi(InstalledAppResource):
)
class ConversationPinApi(InstalledAppResource):
@console_ns.response(200, "Success", console_ns.models[ResultResponse.__name__])
def patch(self, installed_app, c_id: UUID):
def patch(self, installed_app: InstalledApp, c_id: UUID):
app_model = installed_app.app
if app_model is None:
raise AppUnavailableError()
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
@ -170,8 +179,10 @@ class ConversationPinApi(InstalledAppResource):
)
class ConversationUnPinApi(InstalledAppResource):
@console_ns.response(200, "Success", console_ns.models[ResultResponse.__name__])
def patch(self, installed_app, c_id: UUID):
def patch(self, installed_app: InstalledApp, c_id: UUID):
app_model = installed_app.app
if app_model is None:
raise AppUnavailableError()
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()

View File

@ -12,14 +12,19 @@ from controllers.common.fields import SimpleMessageResponse, SimpleResultMessage
from controllers.common.schema import register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.explore.wraps import InstalledAppResource
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
with_current_tenant_id,
with_current_user,
)
from extensions.ext_database import db
from fields.base import ResponseModel
from graphon.file import helpers as file_helpers
from libs.datetime_utils import naive_utc_now
from libs.helper import to_timestamp
from libs.login import current_account_with_tenant, login_required
from models import App, InstalledApp, RecommendedApp
from libs.login import login_required
from models import Account, App, InstalledApp, RecommendedApp
from models.model import IconType
from services.account_service import TenantService
from services.enterprise.enterprise_service import EnterpriseService
@ -131,9 +136,10 @@ class InstalledAppsListApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[InstalledAppListResponse.__name__])
def get(self):
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account):
query = InstalledAppsListQuery.model_validate(request.args.to_dict())
current_user, current_tenant_id = current_account_with_tenant()
if query.app_id:
installed_apps = db.session.scalars(
@ -212,7 +218,8 @@ class InstalledAppsListApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("apps")
@console_ns.response(200, "Success", console_ns.models[SimpleMessageResponse.__name__])
def post(self):
@with_current_tenant_id
def post(self, current_tenant_id: str):
payload = InstalledAppCreatePayload.model_validate(console_ns.payload or {})
recommended_app = db.session.scalar(
@ -221,8 +228,6 @@ class InstalledAppsListApi(Resource):
if recommended_app is None:
raise NotFound("Recommended app not found")
_, current_tenant_id = current_account_with_tenant()
app = db.session.get(App, payload.app_id)
if app is None:
@ -262,8 +267,8 @@ class InstalledAppApi(InstalledAppResource):
"""
@console_ns.response(204, "App uninstalled successfully")
def delete(self, installed_app):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def delete(self, current_tenant_id: str, installed_app: InstalledApp):
if installed_app.app_owner_tenant_id == current_tenant_id:
raise BadRequest("You can't uninstall an app owned by the current tenant")
@ -273,7 +278,7 @@ class InstalledAppApi(InstalledAppResource):
return "", 204
@console_ns.response(200, "Success", console_ns.models[SimpleResultMessageResponse.__name__])
def patch(self, installed_app):
def patch(self, installed_app: InstalledApp):
payload = InstalledAppUpdatePayload.model_validate(console_ns.payload or {})
commit_args = False

View File

@ -10,6 +10,7 @@ from controllers.common.controller_schemas import MessageFeedbackPayload, Messag
from controllers.common.schema import register_response_schema_models, register_schema_models
from controllers.console.app.error import (
AppMoreLikeThisDisabledError,
AppUnavailableError,
CompletionRequestError,
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
@ -21,15 +22,16 @@ from controllers.console.explore.error import (
NotCompletionAppError,
)
from controllers.console.explore.wraps import InstalledAppResource
from controllers.console.wraps import with_current_user
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from fields.conversation_fields import ResultResponse
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem, SuggestedQuestionsResponse
from graphon.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.login import current_account_with_tenant
from models import Account
from models.enums import FeedbackRating
from models.model import AppMode
from models.model import AppMode, InstalledApp
from services.app_generate_service import AppGenerateService
from services.errors.app import MoreLikeThisDisabledError
from services.errors.conversation import ConversationNotExistsError
@ -59,9 +61,11 @@ register_response_schema_models(console_ns, ResultResponse, SuggestedQuestionsRe
)
class MessageListApi(InstalledAppResource):
@console_ns.expect(console_ns.models[MessageListQuery.__name__])
def get(self, installed_app):
current_user, _ = current_account_with_tenant()
@with_current_user
def get(self, current_user: Account, installed_app: InstalledApp):
app_model = installed_app.app
if app_model is None:
raise AppUnavailableError()
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
@ -96,9 +100,11 @@ class MessageListApi(InstalledAppResource):
class MessageFeedbackApi(InstalledAppResource):
@console_ns.expect(console_ns.models[MessageFeedbackPayload.__name__])
@console_ns.response(200, "Feedback submitted successfully", console_ns.models[ResultResponse.__name__])
def post(self, installed_app, message_id: UUID):
current_user, _ = current_account_with_tenant()
@with_current_user
def post(self, current_user: Account, installed_app: InstalledApp, message_id: UUID):
app_model = installed_app.app
if app_model is None:
raise AppUnavailableError()
message_id_str = str(message_id)
@ -124,9 +130,11 @@ class MessageFeedbackApi(InstalledAppResource):
)
class MessageMoreLikeThisApi(InstalledAppResource):
@console_ns.expect(console_ns.models[MoreLikeThisQuery.__name__])
def get(self, installed_app, message_id: UUID):
current_user, _ = current_account_with_tenant()
@with_current_user
def get(self, current_user: Account, installed_app: InstalledApp, message_id: UUID):
app_model = installed_app.app
if app_model is None:
raise AppUnavailableError()
if app_model.mode != "completion":
raise NotCompletionAppError()
@ -170,9 +178,11 @@ class MessageMoreLikeThisApi(InstalledAppResource):
)
class MessageSuggestedQuestionApi(InstalledAppResource):
@console_ns.response(200, "Success", console_ns.models[SuggestedQuestionsResponse.__name__])
def get(self, installed_app, message_id: UUID):
current_user, _ = current_account_with_tenant()
@with_current_user
def get(self, current_user: Account, installed_app: InstalledApp, message_id: UUID):
app_model = installed_app.app
if app_model is None:
raise AppUnavailableError()
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()

View File

@ -7,11 +7,14 @@ from werkzeug.exceptions import NotFound
from controllers.common.controller_schemas import SavedMessageCreatePayload, SavedMessageListQuery
from controllers.common.schema import register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.app.error import AppUnavailableError
from controllers.console.explore.error import NotCompletionAppError
from controllers.console.explore.wraps import InstalledAppResource
from controllers.console.wraps import with_current_user
from fields.conversation_fields import ResultResponse
from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
from libs.login import current_account_with_tenant
from models import Account
from models.model import InstalledApp
from services.errors.message import MessageNotExistsError
from services.saved_message_service import SavedMessageService
@ -22,9 +25,11 @@ register_response_schema_models(console_ns, ResultResponse)
@console_ns.route("/installed-apps/<uuid:installed_app_id>/saved-messages", endpoint="installed_app_saved_messages")
class SavedMessageListApi(InstalledAppResource):
@console_ns.expect(console_ns.models[SavedMessageListQuery.__name__])
def get(self, installed_app):
current_user, _ = current_account_with_tenant()
@with_current_user
def get(self, current_user: Account, installed_app: InstalledApp):
app_model = installed_app.app
if app_model is None:
raise AppUnavailableError()
if app_model.mode != "completion":
raise NotCompletionAppError()
@ -46,9 +51,11 @@ class SavedMessageListApi(InstalledAppResource):
@console_ns.expect(console_ns.models[SavedMessageCreatePayload.__name__])
@console_ns.response(200, "Success", console_ns.models[ResultResponse.__name__])
def post(self, installed_app):
current_user, _ = current_account_with_tenant()
@with_current_user
def post(self, current_user: Account, installed_app: InstalledApp):
app_model = installed_app.app
if app_model is None:
raise AppUnavailableError()
if app_model.mode != "completion":
raise NotCompletionAppError()
@ -67,9 +74,11 @@ class SavedMessageListApi(InstalledAppResource):
)
class SavedMessageApi(InstalledAppResource):
@console_ns.response(204, "Saved message deleted successfully")
def delete(self, installed_app, message_id: UUID):
current_user, _ = current_account_with_tenant()
@with_current_user
def delete(self, current_user: Account, installed_app: InstalledApp, message_id: UUID):
app_model = installed_app.app
if app_model is None:
raise AppUnavailableError()
message_id_str = str(message_id)

View File

@ -13,6 +13,7 @@ from controllers.console.app.error import (
)
from controllers.console.explore.error import NotWorkflowAppError
from controllers.console.explore.wraps import InstalledAppResource
from controllers.console.wraps import with_current_user
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
@ -25,7 +26,7 @@ from extensions.ext_redis import redis_client
from graphon.graph_engine.manager import GraphEngineManager
from graphon.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.login import current_account_with_tenant
from models import Account
from models.model import AppMode, InstalledApp
from services.app_generate_service import AppGenerateService
from services.errors.llm import InvokeRateLimitError
@ -41,11 +42,11 @@ register_response_schema_models(console_ns, SimpleResultResponse)
@console_ns.route("/installed-apps/<uuid:installed_app_id>/workflows/run")
class InstalledAppWorkflowRunApi(InstalledAppResource):
@console_ns.expect(console_ns.models[WorkflowRunPayload.__name__])
def post(self, installed_app: InstalledApp):
@with_current_user
def post(self, current_user: Account, installed_app: InstalledApp):
"""
Run workflow
"""
current_user, _ = current_account_with_tenant()
app_model = installed_app.app
if not app_model:
raise NotWorkflowAppError()

View File

@ -9,14 +9,14 @@ from pydantic import BaseModel, Field, TypeAdapter, field_validator
from constants import HIDDEN_VALUE
from fields.base import ResponseModel
from libs.helper import to_timestamp
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models.api_based_extension import APIBasedExtension
from services.api_based_extension_service import APIBasedExtensionService
from services.code_based_extension_service import CodeBasedExtensionService
from ..common.schema import DEFAULT_REF_TEMPLATE_SWAGGER_2_0, register_schema_models
from . import console_ns
from .wraps import account_initialization_required, setup_required
from .wraps import account_initialization_required, setup_required, with_current_tenant_id
class CodeBasedExtensionQuery(BaseModel):
@ -116,11 +116,11 @@ class APIBasedExtensionAPI(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
_, tenant_id = current_account_with_tenant()
@with_current_tenant_id
def get(self, current_tenant_id: str):
return [
_serialize_api_based_extension(extension)
for extension in APIBasedExtensionService.get_all_by_tenant_id(tenant_id)
for extension in APIBasedExtensionService.get_all_by_tenant_id(current_tenant_id)
]
@console_ns.doc("create_api_based_extension")
@ -130,9 +130,9 @@ class APIBasedExtensionAPI(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
@with_current_tenant_id
def post(self, current_tenant_id: str):
payload = APIBasedExtensionPayload.model_validate(console_ns.payload or {})
_, current_tenant_id = current_account_with_tenant()
extension_data = APIBasedExtension(
tenant_id=current_tenant_id,
@ -153,12 +153,12 @@ class APIBasedExtensionDetailAPI(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, id: UUID):
@with_current_tenant_id
def get(self, current_tenant_id: str, id: UUID):
api_based_extension_id = str(id)
_, tenant_id = current_account_with_tenant()
return _serialize_api_based_extension(
APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)
)
@console_ns.doc("update_api_based_extension")
@ -169,9 +169,9 @@ class APIBasedExtensionDetailAPI(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, id: UUID):
@with_current_tenant_id
def post(self, current_tenant_id: str, id: UUID):
api_based_extension_id = str(id)
_, current_tenant_id = current_account_with_tenant()
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)
@ -197,9 +197,9 @@ class APIBasedExtensionDetailAPI(Resource):
@setup_required
@login_required
@account_initialization_required
def delete(self, id: UUID):
@with_current_tenant_id
def delete(self, current_tenant_id: str, id: UUID):
api_based_extension_id = str(id)
_, current_tenant_id = current_account_with_tenant()
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)

View File

@ -2,13 +2,36 @@ from flask_restx import Resource
from werkzeug.exceptions import Unauthorized
from controllers.common.schema import register_response_schema_models
from libs.login import current_account_with_tenant, current_user, login_required
from services.feature_service import FeatureModel, FeatureService, LimitationModel, SystemFeatureModel
from fields.base import ResponseModel
from libs.helper import dump_response
from libs.login import current_user, login_required
from services.feature_service import (
FeatureModel,
FeatureService,
LimitationModel,
SystemFeatureModel,
)
from . import console_ns
from .wraps import account_initialization_required, cloud_utm_record, setup_required
from .wraps import account_initialization_required, cloud_utm_record, setup_required, with_current_tenant_id
register_response_schema_models(console_ns, FeatureModel, LimitationModel, SystemFeatureModel)
class TrialModelsResponse(ResponseModel):
trial_models: list[str]
class AppDslVersionResponse(ResponseModel):
app_dsl_version: str
register_response_schema_models(
console_ns,
AppDslVersionResponse,
FeatureModel,
LimitationModel,
SystemFeatureModel,
TrialModelsResponse,
)
@console_ns.route("/features")
@ -24,10 +47,9 @@ class FeatureApi(Resource):
@login_required
@account_initialization_required
@cloud_utm_record
def get(self):
@with_current_tenant_id
def get(self, current_tenant_id: str):
"""Get feature configuration for current tenant"""
_, current_tenant_id = current_account_with_tenant()
payload = FeatureService.get_features(
current_tenant_id,
exclude_vector_space=True,
@ -49,13 +71,49 @@ class FeatureVectorSpaceApi(Resource):
@login_required
@account_initialization_required
@cloud_utm_record
def get(self):
@with_current_tenant_id
def get(self, current_tenant_id: str):
"""Get vector-space usage and limit for current tenant"""
_, current_tenant_id = current_account_with_tenant()
return FeatureService.get_vector_space(current_tenant_id).model_dump()
@console_ns.route("/trial-models")
class TrialModelsApi(Resource):
@console_ns.doc("get_trial_models")
@console_ns.doc(description="Get hosted trial model provider configuration")
@console_ns.response(
200,
"Success",
console_ns.models[TrialModelsResponse.__name__],
)
@setup_required
@login_required
@account_initialization_required
def get(self):
"""Get hosted trial model provider configuration for model-provider pages."""
return dump_response(
TrialModelsResponse,
{"trial_models": FeatureService.get_trial_models()},
)
@console_ns.route("/app-dsl-version")
class AppDslVersionApi(Resource):
@console_ns.doc("get_app_dsl_version")
@console_ns.doc(description="Get current app DSL version")
@console_ns.response(
200,
"Success",
console_ns.models[AppDslVersionResponse.__name__],
)
def get(self):
"""Get current app DSL version for workflow clipboard compatibility."""
return dump_response(
AppDslVersionResponse,
{"app_dsl_version": FeatureService.get_app_dsl_version()},
)
@console_ns.route("/system-features")
class SystemFeatureApi(Resource):
@console_ns.doc("get_system_features")

View File

@ -22,10 +22,13 @@ from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
setup_required,
with_current_tenant_id,
with_current_user,
)
from extensions.ext_database import db
from fields.file_fields import FileResponse, UploadConfig
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models import Account
from services.file_service import FileService
from . import console_ns
@ -62,8 +65,8 @@ class FileApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("documents")
@console_ns.response(201, "File uploaded successfully", console_ns.models[FileResponse.__name__])
def post(self):
current_user, _ = current_account_with_tenant()
@with_current_user
def post(self, current_user: Account):
source_str = request.form.get("source")
source: Literal["datasets"] | None = "datasets" if source_str == "datasets" else None
@ -107,10 +110,10 @@ class FilePreviewApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[TextContentResponse.__name__])
def get(self, file_id: UUID):
@with_current_tenant_id
def get(self, current_tenant_id: str, file_id: UUID):
file_id_str = str(file_id)
_, tenant_id = current_account_with_tenant()
text = FileService(db.engine).get_file_preview(file_id_str, tenant_id)
text = FileService(db.engine).get_file_preview(file_id_str, current_tenant_id)
return {"content": text}

View File

@ -12,8 +12,15 @@ from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from controllers.common.human_input import HumanInputFormSubmitPayload
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console.wraps import (
account_initialization_required,
model_validate,
setup_required,
with_current_tenant_id,
with_current_user,
)
from controllers.web.error import InvalidArgumentError, NotFoundError
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
from core.app.apps.base_app_generator import BaseAppGenerator
@ -22,8 +29,8 @@ from core.app.apps.message_generator import MessageGenerator
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface
from extensions.ext_database import db
from libs.login import current_account_with_tenant, login_required
from models import App
from libs.login import login_required
from models import Account, App
from models.enums import CreatorUserRole
from models.model import AppMode
from models.workflow import WorkflowRun
@ -33,6 +40,8 @@ from services.workflow_event_snapshot_service import build_workflow_event_stream
logger = logging.getLogger(__name__)
register_schema_models(console_ns, HumanInputFormSubmitPayload)
def _jsonify_form_definition(form: Form) -> Response:
payload = form.get_definition().model_dump()
@ -45,9 +54,8 @@ class ConsoleHumanInputFormApi(Resource):
"""Console API for getting human input form definition."""
@staticmethod
def _ensure_console_access(form: Form):
_, current_tenant_id = current_account_with_tenant()
def _ensure_console_access(form: Form, current_tenant_id: str) -> None:
"""Ensure a console form token resolves only inside the current tenant."""
if form.tenant_id != current_tenant_id:
raise NotFoundError("App not found")
@ -59,7 +67,8 @@ class ConsoleHumanInputFormApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, form_token: str):
@with_current_tenant_id
def get(self, current_tenant_id: str, form_token: str):
"""
Get human input form definition by form token.
@ -70,13 +79,23 @@ class ConsoleHumanInputFormApi(Resource):
if form is None:
raise NotFoundError(f"form not found, token={form_token}")
self._ensure_console_access(form)
self._ensure_console_access(form, current_tenant_id)
return _jsonify_form_definition(form)
@account_initialization_required
@login_required
def post(self, form_token: str):
@with_current_user
@with_current_tenant_id
@model_validate(HumanInputFormSubmitPayload)
@console_ns.expect(console_ns.models[HumanInputFormSubmitPayload.__name__])
def post(
self,
payload: HumanInputFormSubmitPayload,
current_tenant_id: str,
current_user: Account,
form_token: str,
):
"""
Submit human input form by form token.
@ -90,15 +109,12 @@ class ConsoleHumanInputFormApi(Resource):
"action": "Approve"
}
"""
payload = HumanInputFormSubmitPayload.model_validate(request.get_json())
current_user, _ = current_account_with_tenant()
service = HumanInputService(db.engine)
form = service.get_form_by_token(form_token)
if form is None:
raise NotFoundError(f"form not found, token={form_token}")
self._ensure_console_access(form)
self._ensure_console_access(form, current_tenant_id)
self._ensure_console_recipient_type(form)
recipient_type = form.recipient_type
# The type checker is not smart enought to validate the following invariant.
@ -122,7 +138,9 @@ class ConsoleWorkflowEventsApi(Resource):
@account_initialization_required
@login_required
def get(self, workflow_run_id: str):
@with_current_user
@with_current_tenant_id
def get(self, tenant_id: str, user: Account, workflow_run_id: str):
"""
Get workflow execution events stream after resume.
@ -130,8 +148,6 @@ class ConsoleWorkflowEventsApi(Resource):
Returns Server-Sent Events stream.
"""
user, tenant_id = current_account_with_tenant()
session_maker = sessionmaker(db.engine)
repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
workflow_run = repo.get_workflow_run_by_id_and_tenant_id(

View File

@ -8,8 +8,14 @@ from pydantic import BaseModel, Field
from controllers.common.fields import SimpleResultResponse
from controllers.common.schema import register_response_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
from libs.login import current_account_with_tenant, login_required
from controllers.console.wraps import (
account_initialization_required,
only_edition_cloud,
setup_required,
with_current_user,
)
from libs.login import login_required
from models import Account
from services.billing_service import BillingService
# Notification content is stored under three lang tags.
@ -70,11 +76,10 @@ class NotificationApi(Resource):
)
@setup_required
@login_required
@with_current_user
@account_initialization_required
@only_edition_cloud
def get(self):
current_user, _ = current_account_with_tenant()
def get(self, current_user: Account):
result = BillingService.get_account_notification(str(current_user.id))
# Proto JSON uses camelCase field names (Kratos default marshaling).
@ -113,11 +118,11 @@ class NotificationDismissApi(Resource):
)
@setup_required
@login_required
@with_current_user
@account_initialization_required
@only_edition_cloud
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
def post(self):
current_user, _ = current_account_with_tenant()
def post(self, current_user: Account):
payload = DismissNotificationPayload.model_validate(request.get_json())
BillingService.dismiss_notification(
notification_id=payload.notification_id,

View File

@ -12,11 +12,13 @@ from controllers.common.errors import (
)
from controllers.common.schema import register_response_schema_models, register_schema_models
from controllers.console import console_ns
from core.helper import ssrf_proxy
from controllers.console.wraps import with_current_user
from core.file import remote_fetcher
from extensions.ext_database import db
from fields.file_fields import FileWithSignedUrl, RemoteFileInfo
from graphon.file import helpers as file_helpers
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models import Account
from services.file_service import FileService
@ -34,9 +36,9 @@ class GetRemoteFileInfo(Resource):
@login_required
def get(self, url: str):
decoded_url = helpers.decode_remote_url(url, request.query_string)
resp = ssrf_proxy.head(decoded_url)
resp = remote_fetcher.make_request("HEAD", decoded_url)
if resp.status_code != httpx.codes.OK:
resp = ssrf_proxy.get(decoded_url, timeout=3)
resp = remote_fetcher.make_request("GET", decoded_url, timeout=3)
resp.raise_for_status()
return RemoteFileInfo(
file_type=resp.headers.get("Content-Type", "application/octet-stream"),
@ -49,15 +51,16 @@ class RemoteFileUpload(Resource):
@console_ns.expect(console_ns.models[RemoteFileUploadPayload.__name__])
@console_ns.response(201, "File uploaded successfully", console_ns.models[FileWithSignedUrl.__name__])
@login_required
def post(self):
@with_current_user
def post(self, current_user: Account):
payload = RemoteFileUploadPayload.model_validate(console_ns.payload)
url = payload.url
# Try to fetch remote file metadata/content first
try:
resp = ssrf_proxy.head(url=url)
resp = remote_fetcher.make_request("HEAD", url=url)
if resp.status_code != httpx.codes.OK:
resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
resp = remote_fetcher.make_request("GET", url=url, timeout=3, follow_redirects=True)
if resp.status_code != httpx.codes.OK:
# Normalize into a user-friendly error message expected by tests
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}")
@ -71,15 +74,14 @@ class RemoteFileUpload(Resource):
raise FileTooLargeError()
# Load content if needed
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
content = resp.content if resp.request.method == "GET" else remote_fetcher.make_request("GET", url).content
try:
user, _ = current_account_with_tenant()
upload_file = FileService(db.engine).upload_file(
filename=file_info.filename,
content=content,
mimetype=file_info.mimetype,
user=user,
user=current_user,
source_url=url,
)
except services.errors.file.FileTooLargeError as file_too_large_error:

View File

@ -9,9 +9,16 @@ from werkzeug.exceptions import Forbidden
from controllers.common.fields import SimpleResultResponse
from controllers.common.schema import register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
with_current_tenant_id,
with_current_user,
)
from fields.base import ResponseModel
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models import Account
from models.enums import TagType
from services.tag_service import (
SaveTagPayload,
@ -92,8 +99,8 @@ class TagListApi(Resource):
params={"type": 'Tag type filter. Can be "knowledge" or "app".', "keyword": "Search keyword for tag name."}
)
@console_ns.doc(responses={200: ("Success", [console_ns.models[TagResponse.__name__]])})
def get(self):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def get(self, current_tenant_id: str):
raw_args = request.args.to_dict()
param = TagListQueryParam.model_validate(raw_args)
tags = TagService.get_tags(param.type, current_tenant_id, param.keyword)
@ -109,9 +116,9 @@ class TagListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
current_user, _ = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, or editor
@with_current_user
def post(self, current_user: Account):
# Allow users with edit permission, or dataset editors (including dataset operators).
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
@ -132,8 +139,8 @@ class TagUpdateDeleteApi(Resource):
@setup_required
@login_required
@account_initialization_required
def patch(self, tag_id: UUID):
current_user, _ = current_account_with_tenant()
@with_current_user
def patch(self, current_user: Account, tag_id: UUID):
tag_id_str = str(tag_id)
# The role of the current user in the ta table must be admin, owner, or editor
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
@ -163,20 +170,19 @@ class TagUpdateDeleteApi(Resource):
return "", 204
def _require_tag_binding_edit_permission() -> None:
def _require_tag_binding_edit_permission(current_user: Account) -> None:
"""
Ensure the current account can edit tag bindings.
Tag binding operations are allowed for users who can edit resources (app/dataset) within the current tenant.
"""
current_user, _ = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
def _create_tag_bindings() -> tuple[dict[str, str], int]:
_require_tag_binding_edit_permission()
def _create_tag_bindings(current_user: Account) -> tuple[dict[str, str], int]:
_require_tag_binding_edit_permission(current_user)
payload = TagBindingPayload.model_validate(console_ns.payload or {})
TagService.save_tag_binding(
@ -189,8 +195,8 @@ def _create_tag_bindings() -> tuple[dict[str, str], int]:
return {"result": "success"}, 200
def _remove_tag_bindings() -> tuple[dict[str, str], int]:
_require_tag_binding_edit_permission()
def _remove_tag_bindings(current_user: Account) -> tuple[dict[str, str], int]:
_require_tag_binding_edit_permission(current_user)
payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
TagService.delete_tag_binding(
@ -213,8 +219,9 @@ class TagBindingCollectionApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
return _create_tag_bindings()
@with_current_user
def post(self, current_user: Account):
return _create_tag_bindings(current_user)
@console_ns.route("/tag-bindings/remove")
@ -228,5 +235,6 @@ class TagBindingRemoveApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
return _remove_tag_bindings()
@with_current_user
def post(self, current_user: Account):
return _remove_tag_bindings(current_user)

View File

@ -18,7 +18,7 @@ from controllers.common.fields import (
SimpleResultResponse,
VerificationTokenResponse,
)
from controllers.common.schema import register_response_schema_models, register_schema_models
from controllers.common.schema import query_params_from_model, register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.auth.error import (
EmailAlreadyInUseError,
@ -42,15 +42,17 @@ from controllers.console.wraps import (
enterprise_license_required,
only_edition_cloud,
setup_required,
with_current_tenant_id,
with_current_user,
)
from extensions.ext_database import db
from fields.base import ResponseModel
from fields.member_fields import Account as AccountResponse
from graphon.file import helpers as file_helpers
from libs.datetime_utils import naive_utc_now
from libs.helper import EmailStr, extract_remote_ip, timezone, to_timestamp
from libs.login import current_account_with_tenant, login_required
from models import AccountIntegrate, InvitationCode
from libs.helper import EmailStr, dump_response, extract_remote_ip, timezone, to_timestamp
from libs.login import login_required
from models import Account, AccountIntegrate, InvitationCode
from models.account import AccountStatus, InvitationCodeStatus
from models.enums import CreatorUserRole
from models.model import UploadFile
@ -173,7 +175,6 @@ class CheckEmailUniquePayload(BaseModel):
register_schema_models(
console_ns,
AccountResponse,
AccountInitPayload,
AccountNamePayload,
AccountAvatarPayload,
@ -245,6 +246,7 @@ register_schema_models(
)
register_response_schema_models(
console_ns,
AccountResponse,
AvatarUrlResponse,
SimpleResultDataResponse,
SimpleResultResponse,
@ -258,9 +260,8 @@ class AccountInitApi(Resource):
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
@setup_required
@login_required
def post(self):
account, _ = current_account_with_tenant()
@with_current_user
def post(self, account: Account):
if account.status == "active":
raise AccountAlreadyInitedError()
@ -306,8 +307,8 @@ class AccountProfileApi(Resource):
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
@enterprise_license_required
def get(self):
current_user, _ = current_account_with_tenant()
@with_current_user
def get(self, current_user: Account):
return _serialize_account(current_user)
@ -318,8 +319,8 @@ class AccountNameApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
def post(self):
current_user, _ = current_account_with_tenant()
@with_current_user
def post(self, current_user: Account):
payload = console_ns.payload or {}
args = AccountNamePayload.model_validate(payload)
updated_account = AccountService.update_account(current_user, name=args.name)
@ -329,20 +330,21 @@ class AccountNameApi(Resource):
@console_ns.route("/account/avatar")
class AccountAvatarApi(Resource):
@console_ns.expect(console_ns.models[AccountAvatarQuery.__name__])
@console_ns.doc("get_account_avatar")
@console_ns.doc(description="Get account avatar url")
@console_ns.doc(params=query_params_from_model(AccountAvatarQuery))
@console_ns.response(200, "Success", console_ns.models[AvatarUrlResponse.__name__])
@setup_required
@login_required
@account_initialization_required
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account):
args = AccountAvatarQuery.model_validate(request.args.to_dict(flat=True))
avatar = args.avatar
if avatar.startswith(("http://", "https://")):
return {"avatar_url": avatar}
return dump_response(AvatarUrlResponse, {"avatar_url": avatar})
upload_file = db.session.scalar(select(UploadFile).where(UploadFile.id == avatar).limit(1))
if upload_file is None:
@ -355,15 +357,15 @@ class AccountAvatarApi(Resource):
raise NotFound("Avatar file not found")
avatar_url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id)
return {"avatar_url": avatar_url}
return dump_response(AvatarUrlResponse, {"avatar_url": avatar_url})
@console_ns.expect(console_ns.models[AccountAvatarPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
def post(self):
current_user, _ = current_account_with_tenant()
@with_current_user
def post(self, current_user: Account):
payload = console_ns.payload or {}
args = AccountAvatarPayload.model_validate(payload)
@ -379,8 +381,8 @@ class AccountInterfaceLanguageApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
def post(self):
current_user, _ = current_account_with_tenant()
@with_current_user
def post(self, current_user: Account):
payload = console_ns.payload or {}
args = AccountInterfaceLanguagePayload.model_validate(payload)
@ -396,8 +398,8 @@ class AccountInterfaceThemeApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
def post(self):
current_user, _ = current_account_with_tenant()
@with_current_user
def post(self, current_user: Account):
payload = console_ns.payload or {}
args = AccountInterfaceThemePayload.model_validate(payload)
@ -413,8 +415,8 @@ class AccountTimezoneApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
def post(self):
current_user, _ = current_account_with_tenant()
@with_current_user
def post(self, current_user: Account):
payload = console_ns.payload or {}
args = AccountTimezonePayload.model_validate(payload)
@ -430,8 +432,8 @@ class AccountPasswordApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
def post(self):
current_user, _ = current_account_with_tenant()
@with_current_user
def post(self, current_user: Account):
payload = console_ns.payload or {}
args = AccountPasswordPayload.model_validate(payload)
@ -449,9 +451,8 @@ class AccountIntegrateApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[AccountIntegrateListResponse.__name__])
def get(self):
account, _ = current_account_with_tenant()
@with_current_user
def get(self, account: Account):
account_integrates = db.session.scalars(
select(AccountIntegrate).where(AccountIntegrate.account_id == account.id)
).all()
@ -495,9 +496,8 @@ class AccountDeleteVerifyApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[SimpleResultDataResponse.__name__])
def get(self):
account, _ = current_account_with_tenant()
@with_current_user
def get(self, account: Account):
token, code = AccountService.generate_account_deletion_verification_code(account)
AccountService.send_account_deletion_verification_email(account, code)
@ -511,9 +511,8 @@ class AccountDeleteApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
account, _ = current_account_with_tenant()
@with_current_user
def post(self, account: Account):
payload = console_ns.payload or {}
args = AccountDeletePayload.model_validate(payload)
@ -547,9 +546,8 @@ class EducationVerifyApi(Resource):
@only_edition_cloud
@cloud_edition_billing_enabled
@console_ns.response(200, "Success", console_ns.models[EducationVerifyResponse.__name__])
def get(self):
account, _ = current_account_with_tenant()
@with_current_user
def get(self, account: Account):
return EducationVerifyResponse.model_validate(
BillingService.EducationIdentity.verify(account.id, account.email) or {}
).model_dump(mode="json")
@ -563,9 +561,8 @@ class EducationApi(Resource):
@account_initialization_required
@only_edition_cloud
@cloud_edition_billing_enabled
def post(self):
account, _ = current_account_with_tenant()
@with_current_user
def post(self, account: Account):
payload = console_ns.payload or {}
args = EducationActivatePayload.model_validate(payload)
@ -577,9 +574,8 @@ class EducationApi(Resource):
@only_edition_cloud
@cloud_edition_billing_enabled
@console_ns.response(200, "Success", console_ns.models[EducationStatusResponse.__name__])
def get(self):
account, _ = current_account_with_tenant()
@with_current_user
def get(self, account: Account):
res = BillingService.EducationIdentity.status(account.id) or {}
# convert expire_at to UTC timestamp from isoformat
if res and "expire_at" in res:
@ -613,8 +609,8 @@ class ChangeEmailSendEmailApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
current_user, _ = current_account_with_tenant()
@with_current_user
def post(self, current_user: Account):
payload = console_ns.payload or {}
args = ChangeEmailSendPayload.model_validate(payload)
@ -673,8 +669,8 @@ class ChangeEmailCheckApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
current_user, _ = current_account_with_tenant()
@with_current_user
def post(self, current_user: Account):
payload = console_ns.payload or {}
args = ChangeEmailValidityPayload.model_validate(payload)
@ -720,7 +716,8 @@ class ChangeEmailResetApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
def post(self):
@with_current_user
def post(self, current_user: Account):
payload = console_ns.payload or {}
args = ChangeEmailResetPayload.model_validate(payload)
normalized_new_email = args.new_email.lower()
@ -731,7 +728,6 @@ class ChangeEmailResetApi(Resource):
if not AccountService.check_email_unique(normalized_new_email):
raise EmailAlreadyInUseError()
current_user, _ = current_account_with_tenant()
reset_data = AccountService.get_change_email_data(args.token)
if not reset_data:
raise InvalidTokenError()

View File

@ -1,9 +1,15 @@
from flask_restx import Resource, fields
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console.wraps import (
account_initialization_required,
setup_required,
with_current_tenant_id,
with_current_user,
)
from graphon.model_runtime.utils.encoders import jsonable_encoder
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models import Account
from services.agent_service import AgentService
@ -19,14 +25,10 @@ class AgentProviderListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
user = current_user
user_id = user.id
tenant_id = current_tenant_id
return jsonable_encoder(AgentService.list_agent_providers(user_id, tenant_id))
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account):
return jsonable_encoder(AgentService.list_agent_providers(current_user.id, current_tenant_id))
@console_ns.route("/workspaces/current/agent-provider/<path:provider_name>")
@ -42,6 +44,7 @@ class AgentProviderApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider_name: str):
current_user, current_tenant_id = current_account_with_tenant()
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account, provider_name: str):
return jsonable_encoder(AgentService.get_agent_provider(current_user.id, current_tenant_id, provider_name))

View File

@ -14,10 +14,16 @@ from pydantic import BaseModel, Field
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from controllers.console.wraps import (
account_initialization_required,
is_admin_or_owner_required,
setup_required,
with_current_tenant_id,
with_current_user_id,
)
from core.plugin.impl.exc import PluginPermissionDeniedError
from graphon.model_runtime.utils.encoders import jsonable_encoder
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from services.plugin.endpoint_service import EndpointService
@ -96,17 +102,15 @@ register_schema_models(
)
def _create_endpoint() -> dict[str, bool]:
"""Create a plugin endpoint for the current workspace."""
user, tenant_id = current_account_with_tenant()
def _create_endpoint(tenant_id: str, user_id: str) -> dict[str, bool]:
"""Create a plugin endpoint for the injected workspace and user."""
args = EndpointCreatePayload.model_validate(console_ns.payload)
try:
return {
"success": EndpointService.create_endpoint(
tenant_id=tenant_id,
user_id=user.id,
user_id=user_id,
plugin_unique_identifier=args.plugin_unique_identifier,
name=args.name,
settings=args.settings,
@ -116,16 +120,14 @@ def _create_endpoint() -> dict[str, bool]:
raise ValueError(e.description) from e
def _update_endpoint(endpoint_id: str) -> dict[str, bool]:
def _update_endpoint(tenant_id: str, user_id: str, endpoint_id: str) -> dict[str, bool]:
"""Update a plugin endpoint identified by the canonical path parameter."""
user, tenant_id = current_account_with_tenant()
args = EndpointUpdatePayload.model_validate(console_ns.payload)
return {
"success": EndpointService.update_endpoint(
tenant_id=tenant_id,
user_id=user.id,
user_id=user_id,
endpoint_id=endpoint_id,
name=args.name,
settings=args.settings,
@ -133,14 +135,12 @@ def _update_endpoint(endpoint_id: str) -> dict[str, bool]:
}
def _delete_endpoint(endpoint_id: str) -> dict[str, bool]:
def _delete_endpoint(tenant_id: str, user_id: str, endpoint_id: str) -> dict[str, bool]:
"""Delete a plugin endpoint identified by the canonical path parameter."""
user, tenant_id = current_account_with_tenant()
return {
"success": EndpointService.delete_endpoint(
tenant_id=tenant_id,
user_id=user.id,
user_id=user_id,
endpoint_id=endpoint_id,
)
}
@ -163,8 +163,10 @@ class EndpointCollectionApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self):
return _create_endpoint()
@with_current_user_id
@with_current_tenant_id
def post(self, tenant_id: str, user_id: str):
return _create_endpoint(tenant_id=tenant_id, user_id=user_id)
@console_ns.route("/workspaces/current/endpoints/create")
@ -189,8 +191,10 @@ class DeprecatedEndpointCreateApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self):
return _create_endpoint()
@with_current_user_id
@with_current_tenant_id
def post(self, tenant_id: str, user_id: str):
return _create_endpoint(tenant_id=tenant_id, user_id=user_id)
@console_ns.route("/workspaces/current/endpoints/list")
@ -206,9 +210,9 @@ class EndpointListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
user, tenant_id = current_account_with_tenant()
@with_current_user_id
@with_current_tenant_id
def get(self, tenant_id: str, user_id: str):
args = EndpointListQuery.model_validate(request.args.to_dict(flat=True))
page = args.page
@ -218,7 +222,7 @@ class EndpointListApi(Resource):
{
"endpoints": EndpointService.list_endpoints(
tenant_id=tenant_id,
user_id=user.id,
user_id=user_id,
page=page,
page_size=page_size,
)
@ -239,9 +243,9 @@ class EndpointListForSinglePluginApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
user, tenant_id = current_account_with_tenant()
@with_current_user_id
@with_current_tenant_id
def get(self, tenant_id: str, user_id: str):
args = EndpointListForPluginQuery.model_validate(request.args.to_dict(flat=True))
page = args.page
@ -252,7 +256,7 @@ class EndpointListForSinglePluginApi(Resource):
{
"endpoints": EndpointService.list_endpoints_for_single_plugin(
tenant_id=tenant_id,
user_id=user.id,
user_id=user_id,
plugin_id=plugin_id,
page=page,
page_size=page_size,
@ -278,8 +282,10 @@ class EndpointItemApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def delete(self, id: str):
return _delete_endpoint(endpoint_id=id)
@with_current_user_id
@with_current_tenant_id
def delete(self, tenant_id: str, user_id: str, id: str):
return _delete_endpoint(tenant_id=tenant_id, user_id=user_id, endpoint_id=id)
@console_ns.doc("update_endpoint")
@console_ns.doc(description="Update a plugin endpoint")
@ -295,8 +301,10 @@ class EndpointItemApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def patch(self, id: str):
return _update_endpoint(endpoint_id=id)
@with_current_user_id
@with_current_tenant_id
def patch(self, tenant_id: str, user_id: str, id: str):
return _update_endpoint(tenant_id=tenant_id, user_id=user_id, endpoint_id=id)
@console_ns.route("/workspaces/current/endpoints/delete")
@ -322,9 +330,11 @@ class DeprecatedEndpointDeleteApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self):
@with_current_user_id
@with_current_tenant_id
def post(self, tenant_id: str, user_id: str):
args = EndpointIdPayload.model_validate(console_ns.payload)
return _delete_endpoint(endpoint_id=args.endpoint_id)
return _delete_endpoint(tenant_id=tenant_id, user_id=user_id, endpoint_id=args.endpoint_id)
@console_ns.route("/workspaces/current/endpoints/update")
@ -350,9 +360,11 @@ class DeprecatedEndpointUpdateApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self):
@with_current_user_id
@with_current_tenant_id
def post(self, tenant_id: str, user_id: str):
args = LegacyEndpointUpdatePayload.model_validate(console_ns.payload)
return _update_endpoint(endpoint_id=args.endpoint_id)
return _update_endpoint(tenant_id=tenant_id, user_id=user_id, endpoint_id=args.endpoint_id)
@console_ns.route("/workspaces/current/endpoints/enable")
@ -370,14 +382,14 @@ class EndpointEnableApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self):
user, tenant_id = current_account_with_tenant()
@with_current_user_id
@with_current_tenant_id
def post(self, tenant_id: str, user_id: str):
args = EndpointIdPayload.model_validate(console_ns.payload)
return {
"success": EndpointService.enable_endpoint(
tenant_id=tenant_id, user_id=user.id, endpoint_id=args.endpoint_id
tenant_id=tenant_id, user_id=user_id, endpoint_id=args.endpoint_id
)
}
@ -397,13 +409,13 @@ class EndpointDisableApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self):
user, tenant_id = current_account_with_tenant()
@with_current_user_id
@with_current_tenant_id
def post(self, tenant_id: str, user_id: str):
args = EndpointIdPayload.model_validate(console_ns.payload)
return {
"success": EndpointService.disable_endpoint(
tenant_id=tenant_id, user_id=user.id, endpoint_id=args.endpoint_id
tenant_id=tenant_id, user_id=user_id, endpoint_id=args.endpoint_id
)
}

View File

@ -4,11 +4,16 @@ from werkzeug.exceptions import Forbidden
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console.wraps import (
account_initialization_required,
setup_required,
with_current_tenant_id,
with_current_user,
)
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
from libs.login import current_account_with_tenant, login_required
from models import TenantAccountRole
from libs.login import login_required
from models import Account, TenantAccountRole
from services.model_load_balancing_service import ModelLoadBalancingService
@ -29,8 +34,9 @@ class LoadBalancingCredentialsValidateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider: str):
current_user, current_tenant_id = current_account_with_tenant()
@with_current_user
@with_current_tenant_id
def post(self, current_tenant_id: str, current_user: Account, provider: str):
if not TenantAccountRole.is_privileged_role(current_user.current_role):
raise Forbidden()
@ -72,8 +78,9 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider: str, config_id: str):
current_user, current_tenant_id = current_account_with_tenant()
@with_current_user
@with_current_tenant_id
def post(self, current_tenant_id: str, current_user: Account, provider: str, config_id: str):
if not TenantAccountRole.is_privileged_role(current_user.current_role):
raise Forbidden()

View File

@ -4,6 +4,7 @@ from uuid import UUID
from flask import abort, request
from flask_restx import Resource
from pydantic import BaseModel, Field, TypeAdapter
from sqlalchemy import func, select
import services
from configs import dify_config
@ -22,15 +23,16 @@ from controllers.console.auth.error import (
from controllers.console.error import EmailSendIpLimitError, WorkspaceMembersLimitExceeded
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
is_allow_transfer_owner,
setup_required,
with_current_user,
)
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from fields.member_fields import AccountWithRole, AccountWithRoleList
from libs.helper import extract_remote_ip
from libs.login import current_account_with_tenant, login_required
from models.account import Account, TenantAccountRole
from libs.login import login_required
from models.account import Account, TenantAccountJoin, TenantAccountRole
from services.account_service import AccountService, RegisterService, TenantService
from services.errors.account import AccountAlreadyInTenantError
from services.feature_service import FeatureService
@ -76,7 +78,55 @@ register_response_schema_models(console_ns, SimpleResultDataResponse, Verificati
def _is_role_enabled(role: TenantAccountRole | str, tenant_id: str) -> bool:
if role != TenantAccountRole.DATASET_OPERATOR:
return True
return FeatureService.get_features(tenant_id=tenant_id).dataset_operator_enabled
return FeatureService.get_features(tenant_id=tenant_id, exclude_vector_space=True).dataset_operator_enabled
def _normalize_invitee_emails(emails: list[str]) -> list[str]:
return list(dict.fromkeys(email.lower() for email in emails))
def _count_new_member_invites(tenant_id: str, emails: list[str]) -> int:
new_member_count = 0
for email in emails:
account = AccountService.get_account_by_email_with_case_fallback(email)
if not account:
new_member_count += 1
continue
exists = db.session.scalar(
select(TenantAccountJoin.id)
.where(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.account_id == account.id)
.limit(1)
)
if not exists:
new_member_count += 1
return new_member_count
def _count_current_members(tenant_id: str) -> int:
return (
db.session.scalar(select(func.count(TenantAccountJoin.id)).where(TenantAccountJoin.tenant_id == tenant_id)) or 0
)
def _check_member_invite_limits(tenant_id: str, new_member_count: int) -> None:
if new_member_count <= 0:
return
features = FeatureService.get_features(tenant_id=tenant_id, exclude_vector_space=True)
if dify_config.ENTERPRISE_ENABLED:
workspace_members = features.workspace_members
if workspace_members.enabled is True and not workspace_members.is_available(new_member_count):
raise WorkspaceMembersLimitExceeded()
return
if dify_config.BILLING_ENABLED and features.billing.enabled is True:
members = features.members
current_member_count = _count_current_members(tenant_id)
if 0 < members.limit < current_member_count + new_member_count:
raise WorkspaceMembersLimitExceeded()
@console_ns.route("/workspaces/current/members")
@ -87,8 +137,8 @@ class MemberListApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[AccountWithRoleList.__name__])
def get(self):
current_user, _ = current_account_with_tenant()
@with_current_user
def get(self, current_user: Account):
if not current_user.current_tenant:
raise ValueError("No current tenant")
members = TenantService.get_tenant_members(current_user.current_tenant)
@ -105,17 +155,16 @@ class MemberInviteEmailApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("members")
def post(self):
@with_current_user
def post(self, current_user: Account):
payload = console_ns.payload or {}
args = MemberInvitePayload.model_validate(payload)
invitee_emails = args.emails
invitee_emails = _normalize_invitee_emails(args.emails)
invitee_role = args.role
interface_language = args.language
if not TenantAccountRole.is_non_owner_role(invitee_role):
return {"code": "invalid-role", "message": "Invalid role"}, 400
current_user, _ = current_account_with_tenant()
inviter = current_user
if not inviter.current_tenant:
raise ValueError("No current tenant")
@ -130,37 +179,36 @@ class MemberInviteEmailApi(Resource):
invitation_results = []
console_web_url = dify_config.CONSOLE_WEB_URL
workspace_members = FeatureService.get_features(tenant_id=inviter.current_tenant.id).workspace_members
tenant_id = inviter.current_tenant.id
with redis_client.lock(f"workspace_member_invite:{tenant_id}", timeout=60):
new_member_count = _count_new_member_invites(tenant_id, invitee_emails)
_check_member_invite_limits(tenant_id, new_member_count)
if not workspace_members.is_available(len(invitee_emails)):
raise WorkspaceMembersLimitExceeded()
for invitee_email in invitee_emails:
normalized_invitee_email = invitee_email.lower()
try:
if not inviter.current_tenant:
raise ValueError("No current tenant")
token = RegisterService.invite_new_member(
tenant=inviter.current_tenant,
email=invitee_email,
language=interface_language,
role=invitee_role,
inviter=inviter,
)
encoded_invitee_email = parse.quote(normalized_invitee_email)
invitation_results.append(
{
"status": "success",
"email": normalized_invitee_email,
"url": f"{console_web_url}/activate?email={encoded_invitee_email}&token={token}",
}
)
except AccountAlreadyInTenantError:
invitation_results.append(
{"status": "success", "email": normalized_invitee_email, "url": f"{console_web_url}/signin"}
)
except Exception as e:
invitation_results.append({"status": "failed", "email": normalized_invitee_email, "message": str(e)})
for invitee_email in invitee_emails:
try:
if not inviter.current_tenant:
raise ValueError("No current tenant")
token = RegisterService.invite_new_member(
tenant=inviter.current_tenant,
email=invitee_email,
language=interface_language,
role=invitee_role,
inviter=inviter,
)
encoded_invitee_email = parse.quote(invitee_email)
invitation_results.append(
{
"status": "success",
"email": invitee_email,
"url": f"{console_web_url}/activate?email={encoded_invitee_email}&token={token}",
}
)
except AccountAlreadyInTenantError:
invitation_results.append(
{"status": "success", "email": invitee_email, "url": f"{console_web_url}/signin"}
)
except Exception as e:
invitation_results.append({"status": "failed", "email": invitee_email, "message": str(e)})
return {
"result": "success",
@ -176,8 +224,8 @@ class MemberCancelInviteApi(Resource):
@setup_required
@login_required
@account_initialization_required
def delete(self, member_id: UUID):
current_user, _ = current_account_with_tenant()
@with_current_user
def delete(self, current_user: Account, member_id: UUID):
if not current_user.current_tenant:
raise ValueError("No current tenant")
member = db.session.get(Account, str(member_id))
@ -209,14 +257,14 @@ class MemberUpdateRoleApi(Resource):
@setup_required
@login_required
@account_initialization_required
def put(self, member_id: UUID):
@with_current_user
def put(self, current_user: Account, member_id: UUID):
payload = console_ns.payload or {}
args = MemberRoleUpdatePayload.model_validate(payload)
new_role = args.role
if not TenantAccountRole.is_valid_role(new_role):
return {"code": "invalid-role", "message": "Invalid role"}, 400
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
raise ValueError("No current tenant")
if not _is_role_enabled(new_role, current_user.current_tenant.id):
@ -250,8 +298,8 @@ class DatasetOperatorMemberListApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[AccountWithRoleList.__name__])
def get(self):
current_user, _ = current_account_with_tenant()
@with_current_user
def get(self, current_user: Account):
if not current_user.current_tenant:
raise ValueError("No current tenant")
members = TenantService.get_dataset_operator_members(current_user.current_tenant)
@ -270,13 +318,13 @@ class SendOwnerTransferEmailApi(Resource):
@login_required
@account_initialization_required
@is_allow_transfer_owner
def post(self):
@with_current_user
def post(self, current_user: Account):
payload = console_ns.payload or {}
args = OwnerTransferEmailPayload.model_validate(payload)
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError()
current_user, _ = current_account_with_tenant()
# check if the current user is the owner of the workspace
if not current_user.current_tenant:
raise ValueError("No current tenant")
@ -308,11 +356,11 @@ class OwnerTransferCheckApi(Resource):
@login_required
@account_initialization_required
@is_allow_transfer_owner
def post(self):
@with_current_user
def post(self, current_user: Account):
payload = console_ns.payload or {}
args = OwnerTransferCheckPayload.model_validate(payload)
# check if the current user is the owner of the workspace
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
raise ValueError("No current tenant")
if not TenantService.is_owner(current_user, current_user.current_tenant):
@ -352,12 +400,12 @@ class OwnerTransfer(Resource):
@login_required
@account_initialization_required
@is_allow_transfer_owner
def post(self, member_id: UUID):
@with_current_user
def post(self, current_user: Account, member_id: UUID):
payload = console_ns.payload or {}
args = OwnerTransferPayload.model_validate(payload)
# check if the current user is the owner of the workspace
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
raise ValueError("No current tenant")
if not TenantService.is_owner(current_user, current_user.current_tenant):

View File

@ -8,12 +8,19 @@ from pydantic import BaseModel, Field, field_validator
from controllers.common.fields import SimpleResultResponse
from controllers.common.schema import register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from controllers.console.wraps import (
account_initialization_required,
is_admin_or_owner_required,
setup_required,
with_current_tenant_id,
with_current_user,
)
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
from graphon.model_runtime.utils.encoders import jsonable_encoder
from libs.helper import uuid_value
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models import Account
from services.billing_service import BillingService
from services.model_provider_service import ModelProviderService
@ -95,10 +102,8 @@ class ModelProviderListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
_, current_tenant_id = current_account_with_tenant()
tenant_id = current_tenant_id
@with_current_tenant_id
def get(self, tenant_id: str):
payload = request.args.to_dict(flat=True)
args = ParserModelList.model_validate(payload)
@ -114,9 +119,8 @@ class ModelProviderCredentialApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
tenant_id = current_tenant_id
@with_current_tenant_id
def get(self, tenant_id: str, provider: str):
# if credential_id is not provided, return current used credential
payload = request.args.to_dict(flat=True)
args = ParserCredentialId.model_validate(payload)
@ -133,8 +137,8 @@ class ModelProviderCredentialApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def post(self, current_tenant_id: str, provider: str):
payload = console_ns.payload or {}
args = ParserCredentialCreate.model_validate(payload)
@ -157,9 +161,8 @@ class ModelProviderCredentialApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def put(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def put(self, current_tenant_id: str, provider: str):
payload = console_ns.payload or {}
args = ParserCredentialUpdate.model_validate(payload)
@ -184,8 +187,8 @@ class ModelProviderCredentialApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def delete(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def delete(self, current_tenant_id: str, provider: str):
payload = console_ns.payload or {}
args = ParserCredentialDelete.model_validate(payload)
@ -205,8 +208,8 @@ class ModelProviderCredentialSwitchApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def post(self, current_tenant_id: str, provider: str):
payload = console_ns.payload or {}
args = ParserCredentialSwitch.model_validate(payload)
@ -225,8 +228,8 @@ class ModelProviderValidateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def post(self, current_tenant_id: str, provider: str):
payload = console_ns.payload or {}
args = ParserCredentialValidate.model_validate(payload)
@ -280,11 +283,8 @@ class PreferredProviderTypeUpdateApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
tenant_id = current_tenant_id
@with_current_tenant_id
def post(self, tenant_id: str, provider: str):
payload = console_ns.payload or {}
args = ParserPreferredProviderType.model_validate(payload)
@ -301,10 +301,11 @@ class ModelProviderPaymentCheckoutUrlApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str):
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account, provider: str):
if provider != "anthropic":
raise ValueError(f"provider name {provider} is invalid")
current_user, current_tenant_id = current_account_with_tenant()
BillingService.is_tenant_owner_or_admin(current_user)
data = BillingService.get_model_provider_payment_link(
provider_name=provider,

View File

@ -8,12 +8,19 @@ from pydantic import BaseModel, Field, field_validator
from controllers.common.fields import SimpleResultResponse
from controllers.common.schema import register_enum_models, register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from controllers.console.wraps import (
account_initialization_required,
is_admin_or_owner_required,
setup_required,
with_current_tenant_id,
with_current_user,
)
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
from graphon.model_runtime.utils.encoders import jsonable_encoder
from libs.helper import uuid_value
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models import Account
from services.model_load_balancing_service import ModelLoadBalancingService
from services.model_provider_service import ModelProviderService
@ -138,9 +145,8 @@ class DefaultModelApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
_, tenant_id = current_account_with_tenant()
@with_current_tenant_id
def get(self, tenant_id: str):
args = ParserGetDefault.model_validate(request.args.to_dict(flat=True))
model_provider_service = ModelProviderService()
@ -156,9 +162,8 @@ class DefaultModelApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self):
_, tenant_id = current_account_with_tenant()
@with_current_tenant_id
def post(self, tenant_id: str):
args = ParserPostDefault.model_validate(console_ns.payload)
model_provider_service = ModelProviderService()
model_settings = args.model_settings
@ -189,9 +194,8 @@ class ModelProviderModelApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
_, tenant_id = current_account_with_tenant()
@with_current_tenant_id
def get(self, tenant_id: str, provider: str):
model_provider_service = ModelProviderService()
models = model_provider_service.get_models_by_provider(tenant_id=tenant_id, provider=provider)
@ -202,9 +206,9 @@ class ModelProviderModelApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider: str):
@with_current_tenant_id
def post(self, tenant_id: str, provider: str):
# To save the model's load balance configs
_, tenant_id = current_account_with_tenant()
args = ParserPostModels.model_validate(console_ns.payload)
if args.config_from == "custom-model":
@ -249,9 +253,8 @@ class ModelProviderModelApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def delete(self, provider: str):
_, tenant_id = current_account_with_tenant()
@with_current_tenant_id
def delete(self, tenant_id: str, provider: str):
args = ParserDeleteModels.model_validate(console_ns.payload)
model_provider_service = ModelProviderService()
@ -268,9 +271,9 @@ class ModelProviderModelCredentialApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str):
_, tenant_id = current_account_with_tenant()
@with_current_user
@with_current_tenant_id
def get(self, tenant_id: str, user: Account, provider: str):
args = ParserGetCredentials.model_validate(request.args.to_dict(flat=True))
model_provider_service = ModelProviderService()
@ -292,9 +295,13 @@ class ModelProviderModelCredentialApi(Resource):
)
if args.config_from == "predefined-model":
# Only the predefined-model branch needs visibility filtering by user.
# The account is injected once by the handler and only passed into the
# service branch that needs user-scoped credential visibility.
available_credentials = model_provider_service.get_provider_available_credentials(
tenant_id=tenant_id,
provider=provider,
user=user,
)
else:
available_credentials = model_provider_service.get_provider_model_available_credentials(
@ -323,9 +330,8 @@ class ModelProviderModelCredentialApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider: str):
_, tenant_id = current_account_with_tenant()
@with_current_tenant_id
def post(self, tenant_id: str, provider: str):
args = ParserCreateCredential.model_validate(console_ns.payload)
model_provider_service = ModelProviderService()
@ -355,8 +361,8 @@ class ModelProviderModelCredentialApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def put(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def put(self, current_tenant_id: str, provider: str):
args = ParserUpdateCredential.model_validate(console_ns.payload)
model_provider_service = ModelProviderService()
@ -382,8 +388,8 @@ class ModelProviderModelCredentialApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def delete(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def delete(self, current_tenant_id: str, provider: str):
args = ParserDeleteCredential.model_validate(console_ns.payload)
model_provider_service = ModelProviderService()
@ -406,8 +412,8 @@ class ModelProviderModelCredentialSwitchApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def post(self, current_tenant_id: str, provider: str):
args = ParserSwitch.model_validate(console_ns.payload)
service = ModelProviderService()
@ -430,9 +436,8 @@ class ModelProviderModelEnableApi(Resource):
@setup_required
@login_required
@account_initialization_required
def patch(self, provider: str):
_, tenant_id = current_account_with_tenant()
@with_current_tenant_id
def patch(self, tenant_id: str, provider: str):
args = ParserDeleteModels.model_validate(console_ns.payload)
model_provider_service = ModelProviderService()
@ -452,9 +457,8 @@ class ModelProviderModelDisableApi(Resource):
@setup_required
@login_required
@account_initialization_required
def patch(self, provider: str):
_, tenant_id = current_account_with_tenant()
@with_current_tenant_id
def patch(self, tenant_id: str, provider: str):
args = ParserDeleteModels.model_validate(console_ns.payload)
model_provider_service = ModelProviderService()
@ -480,8 +484,8 @@ class ModelProviderModelValidateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider: str):
_, tenant_id = current_account_with_tenant()
@with_current_tenant_id
def post(self, tenant_id: str, provider: str):
args = ParserValidate.model_validate(console_ns.payload)
model_provider_service = ModelProviderService()
@ -515,9 +519,9 @@ class ModelProviderModelParameterRuleApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str):
@with_current_tenant_id
def get(self, tenant_id: str, provider: str):
args = ParserParameter.model_validate(request.args.to_dict(flat=True))
_, tenant_id = current_account_with_tenant()
model_provider_service = ModelProviderService()
parameter_rules = model_provider_service.get_model_parameter_rules(
@ -532,8 +536,8 @@ class ModelProviderAvailableModelApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, model_type: str):
_, tenant_id = current_account_with_tenant()
@with_current_tenant_id
def get(self, tenant_id: str, model_type: str):
model_provider_service = ModelProviderService()
models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type)

View File

@ -69,6 +69,7 @@ class BuiltinToolAddPayload(BaseModel):
credentials: dict[str, Any]
name: str | None = Field(default=None, max_length=30)
type: CredentialType
visibility: str | None = None
class BuiltinToolUpdatePayload(BaseModel):
@ -277,7 +278,7 @@ class ToolBuiltinProviderListToolsApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
def get(self, provider: str):
_, tenant_id = current_account_with_tenant()
return jsonable_encoder(
@ -293,7 +294,7 @@ class ToolBuiltinProviderInfoApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
def get(self, provider: str):
_, tenant_id = current_account_with_tenant()
return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider))
@ -306,7 +307,7 @@ class ToolBuiltinProviderDeleteApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider):
def post(self, provider: str):
_, tenant_id = current_account_with_tenant()
payload = BuiltinToolCredentialDeletePayload.model_validate(console_ns.payload or {})
@ -324,7 +325,7 @@ class ToolBuiltinProviderAddApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider):
def post(self, provider: str):
user, tenant_id = current_account_with_tenant()
user_id = user.id
@ -338,6 +339,7 @@ class ToolBuiltinProviderAddApi(Resource):
credentials=payload.credentials,
name=payload.name,
api_type=CredentialType.of(payload.type),
visibility=payload.visibility,
)
@ -348,7 +350,7 @@ class ToolBuiltinProviderUpdateApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider):
def post(self, provider: str):
user, tenant_id = current_account_with_tenant()
user_id = user.id
@ -370,13 +372,20 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
_, tenant_id = current_account_with_tenant()
def get(self, provider: str):
user, tenant_id = current_account_with_tenant()
# Optional list of credential IDs to include even if visibility would hide them
# (used when a workflow/agent node still references another member's only_me credential).
include_credential_ids = request.args.getlist("include_credential_ids") or [
s for s in (request.args.get("include_credential_ids") or "").split(",") if s
]
return jsonable_encoder(
BuiltinToolManageService.get_builtin_tool_provider_credentials(
tenant_id=tenant_id,
provider_name=provider,
user=user,
include_credential_ids=include_credential_ids or None,
)
)
@ -384,7 +393,7 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/icon")
class ToolBuiltinProviderIconApi(Resource):
@setup_required
def get(self, provider):
def get(self, provider: str):
icon_bytes, mimetype = BuiltinToolManageService.get_builtin_tool_provider_icon(provider)
icon_cache_max_age = dify_config.TOOL_ICON_CACHE_MAX_AGE
return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age)
@ -784,7 +793,7 @@ class ToolPluginOAuthApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def get(self, provider):
def get(self, provider: str):
tool_provider = ToolProviderID(provider)
plugin_id = tool_provider.plugin_id
provider_name = tool_provider.provider_name
@ -822,7 +831,7 @@ class ToolPluginOAuthApi(Resource):
@console_ns.route("/oauth/plugin/<path:provider>/tool/callback")
class ToolOAuthCallback(Resource):
@setup_required
def get(self, provider):
def get(self, provider: str):
context_id = request.cookies.get("context_id")
if not context_id:
raise Forbidden("context_id not found")
@ -859,7 +868,7 @@ class ToolOAuthCallback(Resource):
if not credentials:
raise Exception("the plugin credentials failed")
# add credentials to database
# add credentials to database — OAuth tokens default to only_me since they're personal
BuiltinToolManageService.add_builtin_tool_provider(
user_id=user_id,
tenant_id=tenant_id,
@ -867,6 +876,7 @@ class ToolOAuthCallback(Resource):
credentials=dict(credentials),
expires_at=expires_at,
api_type=CredentialType.OAUTH2,
visibility="only_me",
)
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
@ -878,7 +888,7 @@ class ToolBuiltinProviderSetDefaultApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider):
def post(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
payload = BuiltinProviderDefaultCredentialPayload.model_validate(console_ns.payload or {})
return BuiltinToolManageService.set_default_provider(
@ -910,7 +920,7 @@ class ToolOAuthCustomClient(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
def get(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
return jsonable_encoder(
BuiltinToolManageService.get_custom_oauth_client_params(tenant_id=current_tenant_id, provider=provider)
@ -919,7 +929,7 @@ class ToolOAuthCustomClient(Resource):
@setup_required
@login_required
@account_initialization_required
def delete(self, provider):
def delete(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
return jsonable_encoder(
BuiltinToolManageService.delete_custom_oauth_client_params(tenant_id=current_tenant_id, provider=provider)
@ -931,7 +941,7 @@ class ToolBuiltinProviderGetOauthClientSchemaApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
def get(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
return jsonable_encoder(
BuiltinToolManageService.get_builtin_tool_provider_oauth_client_schema(
@ -945,13 +955,18 @@ class ToolBuiltinProviderGetCredentialInfoApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
_, tenant_id = current_account_with_tenant()
def get(self, provider: str):
user, tenant_id = current_account_with_tenant()
include_credential_ids = request.args.getlist("include_credential_ids") or [
s for s in (request.args.get("include_credential_ids") or "").split(",") if s
]
return jsonable_encoder(
BuiltinToolManageService.get_builtin_tool_provider_credential_info(
tenant_id=tenant_id,
provider=provider,
user=user,
include_credential_ids=include_credential_ids or None,
)
)
@ -1151,7 +1166,7 @@ class ToolMCPDetailApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider_id):
def get(self, provider_id: str):
_, tenant_id = current_account_with_tenant()
with sessionmaker(db.engine).begin() as session:
service = MCPToolManageService(session=session)
@ -1180,7 +1195,7 @@ class ToolMCPUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider_id):
def get(self, provider_id: str):
_, tenant_id = current_account_with_tenant()
with sessionmaker(db.engine).begin() as session:
service = MCPToolManageService(session=session)

View File

@ -77,7 +77,7 @@ class TriggerProviderIconApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
def get(self, provider: str):
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
@ -103,7 +103,7 @@ class TriggerProviderInfoApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
def get(self, provider: str):
"""Get info for a trigger provider"""
user = current_user
assert isinstance(user, Account)
@ -119,15 +119,18 @@ class TriggerSubscriptionListApi(Resource):
@login_required
@edit_permission_required
@account_initialization_required
def get(self, provider):
def get(self, provider: str):
"""List all trigger subscriptions for the current tenant's provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
try:
return jsonable_encoder(
TriggerProviderService.list_trigger_provider_subscriptions(
tenant_id=user.current_tenant_id, provider_id=TriggerProviderID(provider)
tenant_id=user.current_tenant_id,
provider_id=TriggerProviderID(provider),
user=user,
)
)
except ValueError as e:
@ -146,7 +149,7 @@ class TriggerSubscriptionBuilderCreateApi(Resource):
@login_required
@edit_permission_required
@account_initialization_required
def post(self, provider):
def post(self, provider: str):
"""Add a new subscription instance for a trigger provider"""
user = current_user
assert user.current_tenant_id is not None
@ -175,7 +178,7 @@ class TriggerSubscriptionBuilderGetApi(Resource):
@login_required
@edit_permission_required
@account_initialization_required
def get(self, provider, subscription_builder_id):
def get(self, provider: str, subscription_builder_id: str):
"""Get a subscription instance for a trigger provider"""
return jsonable_encoder(
TriggerSubscriptionBuilderService.get_subscription_builder_by_id(subscription_builder_id)
@ -191,7 +194,7 @@ class TriggerSubscriptionBuilderVerifyApi(Resource):
@login_required
@edit_permission_required
@account_initialization_required
def post(self, provider, subscription_builder_id):
def post(self, provider: str, subscription_builder_id: str):
"""Verify and update a subscription instance for a trigger provider"""
user = current_user
assert user.current_tenant_id is not None
@ -223,7 +226,7 @@ class TriggerSubscriptionBuilderUpdateApi(Resource):
@login_required
@edit_permission_required
@account_initialization_required
def post(self, provider, subscription_builder_id):
def post(self, provider: str, subscription_builder_id: str):
"""Update a subscription instance for a trigger provider"""
user = current_user
assert isinstance(user, Account)
@ -257,7 +260,7 @@ class TriggerSubscriptionBuilderLogsApi(Resource):
@login_required
@edit_permission_required
@account_initialization_required
def get(self, provider, subscription_builder_id):
def get(self, provider: str, subscription_builder_id: str):
"""Get the request logs for a subscription instance for a trigger provider"""
user = current_user
assert isinstance(user, Account)
@ -280,7 +283,7 @@ class TriggerSubscriptionBuilderBuildApi(Resource):
@login_required
@edit_permission_required
@account_initialization_required
def post(self, provider, subscription_builder_id):
def post(self, provider: str, subscription_builder_id: str):
"""Build a subscription instance for a trigger provider"""
user = current_user
assert user.current_tenant_id is not None
@ -404,7 +407,7 @@ class TriggerOAuthAuthorizeApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
def get(self, provider: str):
"""Initiate OAuth authorization flow for a trigger provider"""
user = current_user
assert isinstance(user, Account)
@ -486,7 +489,7 @@ class TriggerOAuthAuthorizeApi(Resource):
@console_ns.route("/oauth/plugin/<path:provider>/trigger/callback")
class TriggerOAuthCallbackApi(Resource):
@setup_required
def get(self, provider):
def get(self, provider: str):
"""Handle OAuth callback for trigger provider"""
context_id = request.cookies.get("context_id")
if not context_id:
@ -554,7 +557,7 @@ class TriggerOAuthClientManageApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def get(self, provider):
def get(self, provider: str):
"""Get OAuth client configuration for a provider"""
user = current_user
assert user.current_tenant_id is not None
@ -600,7 +603,7 @@ class TriggerOAuthClientManageApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider):
def post(self, provider: str):
"""Configure custom OAuth client for a provider"""
user = current_user
assert user.current_tenant_id is not None
@ -626,7 +629,7 @@ class TriggerOAuthClientManageApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def delete(self, provider):
def delete(self, provider: str):
"""Remove custom OAuth client configuration"""
user = current_user
assert user.current_tenant_id is not None
@ -654,7 +657,7 @@ class TriggerSubscriptionVerifyApi(Resource):
@login_required
@edit_permission_required
@account_initialization_required
def post(self, provider, subscription_id):
def post(self, provider: str, subscription_id: str):
"""Verify credentials for an existing subscription (edit mode only)"""
user = current_user
assert user.current_tenant_id is not None

View File

@ -25,13 +25,15 @@ from controllers.console.wraps import (
cloud_edition_billing_resource_check,
only_edition_enterprise,
setup_required,
with_current_tenant_id,
with_current_user,
)
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from fields.base import ResponseModel
from libs.helper import TimestampField, to_timestamp
from libs.login import current_account_with_tenant, login_required
from models.account import Tenant, TenantCustomConfigDict, TenantStatus
from libs.helper import TimestampField, dump_response, to_timestamp
from libs.login import login_required
from models.account import Account, Tenant, TenantCustomConfigDict, TenantStatus
from services.account_service import TenantService
from services.billing_service import BillingService, SubscriptionPlan
from services.enterprise.enterprise_service import EnterpriseService
@ -56,6 +58,11 @@ class WorkspaceCustomConfigPayload(BaseModel):
replace_webapp_logo: str | None = None
class WorkspaceCustomConfigResponse(ResponseModel):
remove_webapp_brand: bool | None = None
replace_webapp_logo: str | None = None
class WorkspaceInfoPayload(BaseModel):
name: str
@ -69,7 +76,7 @@ class TenantInfoResponse(ResponseModel):
role: str | None = None
in_trial: bool | None = None
trial_end_reason: str | None = None
custom_config: dict | None = None
custom_config: WorkspaceCustomConfigResponse | None = None
trial_credits: int | None = None
trial_credits_used: int | None = None
next_credit_reset_date: int | None = None
@ -101,9 +108,13 @@ register_schema_models(
SwitchWorkspacePayload,
WorkspaceCustomConfigPayload,
WorkspaceInfoPayload,
TenantInfoResponse,
)
register_response_schema_models(console_ns, WorkspacePermissionResponse)
register_response_schema_models(
console_ns,
TenantInfoResponse,
WorkspaceCustomConfigResponse,
WorkspacePermissionResponse,
)
provider_fields = {
"provider_name": fields.String,
@ -144,8 +155,9 @@ class TenantListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account):
tenants = TenantService.get_join_tenants(current_user)
tenant_dicts = []
is_enterprise_only = dify_config.ENTERPRISE_ENABLED and not dify_config.BILLING_ENABLED
@ -166,10 +178,10 @@ class TenantListApi(Resource):
if tenant_plan:
plan = tenant_plan["plan"] or CloudPlan.SANDBOX
else:
features = FeatureService.get_features(tenant.id)
features = FeatureService.get_features(tenant.id, exclude_vector_space=True)
plan = features.billing.subscription.plan or CloudPlan.SANDBOX
elif not is_enterprise_only:
features = FeatureService.get_features(tenant.id)
features = FeatureService.get_features(tenant.id, exclude_vector_space=True)
plan = features.billing.subscription.plan or CloudPlan.SANDBOX
# Create a dictionary with tenant attributes
@ -219,11 +231,11 @@ class TenantApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[TenantInfoResponse.__name__])
def post(self):
@with_current_user
def post(self, current_user: Account):
if request.path == "/info":
logger.warning("Deprecated URL /info was used.")
current_user, _ = current_account_with_tenant()
tenant = current_user.current_tenant
if not tenant:
raise ValueError("No current tenant")
@ -238,13 +250,7 @@ class TenantApi(Resource):
else:
raise Unauthorized("workspace is archived")
return (
TenantInfoResponse.model_validate(
WorkspaceService.get_tenant_info(tenant),
from_attributes=True,
).model_dump(mode="json"),
200,
)
return dump_response(TenantInfoResponse, WorkspaceService.get_tenant_info(tenant)), 200
@console_ns.route("/workspaces/switch")
@ -253,8 +259,8 @@ class SwitchWorkspaceApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
current_user, _ = current_account_with_tenant()
@with_current_user
def post(self, current_user: Account):
payload = console_ns.payload or {}
args = SwitchWorkspacePayload.model_validate(payload)
@ -278,8 +284,8 @@ class CustomConfigWorkspaceApi(Resource):
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("workspace_custom")
def post(self):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def post(self, current_tenant_id: str):
payload = console_ns.payload or {}
args = WorkspaceCustomConfigPayload.model_validate(payload)
tenant = db.get_or_404(Tenant, current_tenant_id)
@ -305,8 +311,8 @@ class WebappLogoWorkspaceApi(Resource):
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("workspace_custom")
def post(self):
current_user, _ = current_account_with_tenant()
@with_current_user
def post(self, current_user: Account):
# check file
if "file" not in request.files:
raise NoFileUploadedError()
@ -346,8 +352,8 @@ class WorkspaceInfoApi(Resource):
@login_required
@account_initialization_required
# Change workspace name
def post(self):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def post(self, current_tenant_id: str):
payload = console_ns.payload or {}
args = WorkspaceInfoPayload.model_validate(payload)
@ -369,13 +375,12 @@ class WorkspacePermissionApi(Resource):
@login_required
@account_initialization_required
@only_edition_enterprise
def get(self):
@with_current_tenant_id
def get(self, current_tenant_id: str):
"""
Get workspace permission settings.
Returns permission flags that control workspace features like member invitations and owner transfer.
"""
_, current_tenant_id = current_account_with_tenant()
if not current_tenant_id:
raise ValueError("No current tenant")

Some files were not shown because too many files have changed in this diff Show More