Compare commits

..

324 Commits

Author SHA1 Message Date
ec8b5f23d3 fix: nextjs security update (e-260) (#29553) 2025-12-12 11:42:48 +08:00
Joe
173110e04d fix: dataset editor 2025-06-20 16:10:22 +08:00
63f3af8bc4 feat: use default access mode when importing dsl (#21231) 2025-06-19 17:15:59 +08:00
3e60e682d1 Fix/webapp auth failed (#21149) 2025-06-18 11:25:45 +08:00
0c01f7498d Feat/webapp verified sso 260 (#20815) 2025-06-09 15:11:30 +09:00
c7d4026800 fix: remove all app token when logout 2025-06-06 15:53:40 +08:00
512c1938c1 Feat/webapp verified sso 260: fetch previous app session in public token exchange (#20740) 2025-06-06 16:52:15 +09:00
78cf376872 Feat/webapp verified sso 260: bad import path (#20734) 2025-06-06 16:09:45 +09:00
e312894bc9 Feat/webapp verified sso 260: add token exchange for public app (#20731) 2025-06-06 15:49:08 +09:00
26f291396d Fix/webapp no permission page 260 (#20730) 2025-06-06 14:27:25 +08:00
4835d78529 Merge tag '0.15.8' into e-260
0.15.8
2025-06-06 12:26:42 +08:00
05b746b350 Feat/webapp verified sso 260 (#20690) 2025-06-05 18:36:59 +09:00
94289b8af9 Feat/webapp verified sso 260 (#20684) 2025-06-05 17:31:08 +09:00
dcf4e5a30f Feat/webapp verified sso 260 (#20678) 2025-06-05 16:17:44 +09:00
05903e3251 Feat/webapp verified sso 260 (#20496) 2025-06-05 16:00:37 +09:00
1357999a4c fix: merge web app access scope control (#20675) 2025-06-05 14:37:35 +08:00
4b938ab18d chore: Bump version
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-05-30 16:25:40 +08:00
88356de923 fix: Refactor web reader to use readabilipy (#19789)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-05-30 16:23:17 +08:00
5f09900dca chore(api): Upgrade dependencies (#19736)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-05-15 14:47:15 +08:00
9ac99abf20 docs(CHANGELOG): Update CHANGELOG
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-05-14 18:03:05 +08:00
32588f562e feat(model): fix and re-add gpt-4.1.
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-05-14 18:02:32 +08:00
36f8bd3f1a chore: frontend third-part package security issue (#19655) 2025-05-14 14:08:05 +08:00
4466088f2e fix: invitations get suspended when an existing member appears (#19585) 2025-05-13 13:54:01 +08:00
c919074e06 docs(CHANGELOG.md): Update CHANGELOG.md
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-05-13 10:31:40 +08:00
88cd9aedb7 add gunicorn keepalive setting (#19537)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: Bowen Liang <liang.bowen.123@qq.com>
2025-05-13 10:28:13 +08:00
16a4f77fb4 fix(config): Allow DB_EXTRAS to set search_path via options
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-05-13 10:19:08 +08:00
3401c52665 chore(pyproject.toml): Upgrade huggingface-hub, transformers and resend (#19563)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-05-12 23:21:57 +08:00
bc882ac4a1 fix: only owner can edit members in workspace (#19321) 2025-05-07 14:16:54 +08:00
1c2e8e1ce7 fix removing member without permission (#16332) (#19275)
Co-authored-by: Linh Nguyen <55907715+batman0911@users.noreply.github.com>
Co-authored-by: crazywoola <427733928@qq.com>
2025-05-06 15:41:50 +08:00
33d2c9d2ca Merge branch 'release/0.15-support' into e-260 2025-04-28 18:18:54 +08:00
4fa3d78ed8 Revert "feat : add GPT4.1 in the model providers" (#19002) 2025-04-28 18:15:24 +08:00
849994d35e Merge tag '0.15.7' into e-260
0.15.7
2025-04-28 17:17:26 +08:00
2fce4a338c fix: get realtime groups and members data every time user open the di… (#18988) 2025-04-28 17:01:07 +08:00
5f7f851b17 fix: Refines None checks in result transformation
Simplifies the code by replacing type checks for None with
direct comparisons, improving readability and consistency in
handling None values during output validation.

Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-04-28 15:40:14 +08:00
559ab46ee1 fix: Removes redundant token calculations and updates dependencies
Eliminates unnecessary pre-calculation of token limits and recalculation of max tokens
across multiple app runners, simplifying the logic for prompt handling.

Updates tiktoken library from version 0.8.0 to 0.9.0 for improved tokenization performance.

Increases default token limit in TokenBufferMemory to accommodate larger prompt messages.

These changes streamline the token management process and leverage the latest
improvements in the tiktoken library.

Fixes potential token overflow issues and prepares the system for handling larger
inputs more efficiently.

Relates to internal optimization tasks.

Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-04-28 15:39:12 +08:00
df98223c8c chore: Updates to version 0.15.7 with new model support
Adds support for GPT-4.1 and Amazon Bedrock DeepSeek-R1 models.
Fixes issues with app creation from template categories and
DSL version checks.

Updates version numbers in configuration files and Docker
setup to 0.15.7 for consistency.

Addresses issues #18807, #18868, #18872, #18878, and #18912.

Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-04-28 14:19:07 +08:00
144f9507f8 feat : add GPT4.1 in the model providers (#18912) 2025-04-27 19:31:20 +08:00
2e097a1ac0 add bedrock deepseek-r1 (#18908) 2025-04-27 19:30:42 +08:00
9f7d8a981f Patch: hotfix/create from template category (#18807) (#18868) 2025-04-27 14:47:18 +08:00
c4729f8c20 fix: check dsl version when create app from explore template (#18872)… (#18881)
Co-authored-by: zxhlyh <jasonapring2015@outlook.com>
2025-04-27 14:32:28 +08:00
5cb1cf9eca Patch: Hotfix/create from template category (#18807) (#18869) 2025-04-27 14:28:05 +08:00
40b31bafd5 fix: check dsl version when create app from explore template (#18872) (#18878) 2025-04-27 14:21:45 +08:00
d38a2c95fb docs(CHANGELOG): Update CHANGELOG.md
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-04-25 18:31:08 +08:00
7d18e2a0ef feat(app_dsl_service): Refines version compatibility logic
Updates logic to handle various version comparisons, ensuring
more precise status returns based on version differences.
Improves handling of older and newer versions to prevent
mismatches and ensure appropriate compatibility status.

Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-04-25 18:27:31 +08:00
024f242251 add bedrock claude-sonnet-3.7 (#18788) 2025-04-25 17:35:12 +08:00
de14a55bde fix: i18n update (#18787) 2025-04-25 17:28:32 +08:00
cbb1d722a5 fix: break switch logic if the sso protocol is empty (#18783) 2025-04-25 17:19:44 +08:00
1769ce16f3 fix: disable batch run button when user has no permission. (#18777) 2025-04-25 16:46:58 +08:00
170139bb0f fix: update sso protocol default value to '' (#18773) 2025-04-25 16:20:20 +08:00
ede0deb447 Fix/web app signin error (#18765) 2025-04-25 15:34:18 +08:00
d40f2e7d94 fix: web app login show undefined error message (#18757) 2025-04-25 14:09:38 +08:00
70ebfc064b fix: stop auto retry login when web app return error (#18747) 2025-04-25 12:09:13 +08:00
d6c252d77e Merge branch 'feat/webapp-auth-api' into e-260 2025-04-24 23:48:25 -04:00
fc3d3e0565 fix: wrong web sso protocal source in json 2025-04-24 23:48:18 -04:00
b786bbdab5 fix: add workspace limitation in invite-login API (#18724) 2025-04-25 09:52:47 +08:00
f45321dd27 fix: handle WorkspacesLimitExceededError in forgot_password.py (#18716) 2025-04-24 18:41:10 +08:00
746d4d8ead fix: update i18n (#18711) 2025-04-24 18:14:03 +08:00
7c31e3b6ba Hotfix/revert webapp login page (#18706) 2025-04-24 17:54:03 +08:00
7c1116f139 update. 2025-04-24 15:27:04 +08:00
b82cc1c2e8 feat: priced limit (#17683) 2025-04-24 14:58:34 +08:00
fee51ba994 Feat/e permission (#18656) 2025-04-24 13:10:01 +08:00
2259dfdc58 Merge branch 'feat/webapp-auth-api' into e-260 2025-04-23 23:10:02 -04:00
3761944a3f fix: remove debug logs 2025-04-23 23:09:45 -04:00
a239e756b0 Merge tag '0.15.6' into e-260
0.15.6
2025-04-23 22:41:12 -04:00
ac54dd89f4 fix: change rel url value to target_ref 2025-04-23 22:39:21 -04:00
5310ed4b54 Merge branch 'feat/webapp-auth-api' into e-260 2025-04-23 22:38:04 -04:00
bfdce78ca5 chore(*): Bump up to 0.15.6
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-04-23 14:06:46 +08:00
00c2258352 CHANGELOG): Adds initial changelog for version 0.15.6
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-04-23 13:55:33 +08:00
09f8da1429 fix: allow empty list api 2025-04-22 22:20:29 -04:00
9f07584a00 Feat/e license limit (#18436)
Co-authored-by: Garfield Dai <dai.hai@foxmail.com>
2025-04-23 00:23:38 +08:00
a1b3d41712 fix: clickjacking (#18552) 2025-04-22 17:08:52 +08:00
fcc274d679 fix: add filter in installedapp list api 2025-04-22 02:54:30 -04:00
14f378bbc6 Merge branch 'feat/webapp-auth-api' into e-260 2025-04-21 22:18:35 -04:00
669fb6be0f fix: wrong field name 2025-04-21 22:18:16 -04:00
724ffe55c9 fix: add back sso system feature 2025-04-21 22:02:50 -04:00
bfa5828259 fix: temp fix for unauthorized user in explore page 2025-04-21 19:40:51 -04:00
455d14296f fix: get app id from upstream decorator 2025-04-21 19:03:10 -04:00
d1a25e54e5 fix: add logging 2025-04-21 18:48:24 -04:00
9462ed7bbf fix: add auth constraint to explore apps 2025-04-21 18:47:24 -04:00
c6e63ac816 Revert "fix: update webapp auth api path"
This reverts commit a27db51b83.
2025-04-21 02:07:43 -04:00
a27db51b83 fix: update webapp auth api path 2025-04-21 02:06:07 -04:00
e52a9fbfb7 fix: remove curr user in webapp permission api 2025-04-20 23:33:51 -04:00
2af1dd6de3 feat: add webapp auth apis 2025-04-20 23:30:59 -04:00
b26e20fe34 fix: fix vertex gemini 2.0 flash 001 schema (#18405)
Co-authored-by: achmad-kautsar <achmad.kautsar@insignia.co.id>
2025-04-19 22:04:13 +08:00
161ff432f1 fix: update reset password token when email code verify success (#18362) 2025-04-18 17:15:15 +08:00
509733fbf0 fix: update reset password token when email code verify success (#18367) 2025-04-18 17:15:02 +08:00
99a9def623 fix: reset_password security issue (#18366) 2025-04-18 05:04:44 -04:00
7770a45253 fix: add password security update 2025-04-18 05:02:26 -04:00
bafdbade52 fix: wrong json structure 2025-04-11 17:19:34 -04:00
fa76590c24 chore: add log 2025-04-11 16:59:52 -04:00
d5b75470e4 fix: bad request 2025-04-11 16:48:09 -04:00
5f87bdbe3a fix: add batch get access mode api 2025-04-11 15:24:32 -04:00
cb13b53ccd fix: update webapp sso features 2025-04-11 03:25:58 -04:00
a1dc3cfdec fix: update code for access denied error 2025-04-11 02:45:46 -04:00
7a4ec9cf23 fix: change error code for webapp auth 2025-04-11 02:41:02 -04:00
4785c061a9 feat: add webapp clean up 2025-04-10 15:19:28 -04:00
4105c8ff70 fix: bad api call 2025-04-10 06:27:00 -04:00
b922c8c215 fix: make app private when created 2025-04-10 00:36:35 -04:00
cbea30e65f fix: bad field name 2025-04-09 17:21:16 -04:00
e9a207b38e fix: adjust enterprise api 2025-04-09 16:30:41 -04:00
5e50570739 fix: update webapp jwt claim and add user accessibility support 2025-04-07 18:41:02 -04:00
46d43e6758 feat: add web app auth 2025-04-07 17:03:26 -04:00
fe1846c437 fix: change gemini-2.0-flash to validate google api #17082 (#17115) 2025-03-30 13:04:12 +08:00
1045f6db7a fix: wrong arg parsing 2025-03-26 01:37:45 -04:00
50d36612f0 fix: bad import 2025-03-26 00:34:04 -04:00
e38631db8a feat: add inner mail api 2025-03-25 21:47:30 -04:00
7f63cd52a2 update. 2025-03-24 23:08:54 +08:00
8e75eb5c63 fix: update version to 0.15.5 in packaging and docker-compose files
Sgned-off-by: -LAN- <lapz8200@outlook.com>
2025-03-24 16:47:06 +08:00
970508fcb6 fix: update GitHub Actions workflow to trigger on tags
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-03-24 16:45:29 +08:00
5b357fdbf0 Merge branch 'release/0.15.5' into e-0154 2025-03-24 16:42:11 +08:00
9283a5414f fix: update yarn.lock 2025-03-24 16:41:07 +08:00
8923e64b8d Merge branch 'release/0.15.5' into e-0154 2025-03-24 15:40:32 +08:00
2a2a0e9be9 fix: update DifySandbox image version to 0.2.11 in docker-compose files
Sgned-off-by: -LAN- <laipz8200@outlook.com>
2025-03-24 15:37:55 +08:00
061a765b7d fix: sanitizer svg to avoid xss (#16608) 2025-03-24 14:48:40 +08:00
acd7fead87 feat: remove Vanna provider and associated assets from the project
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-03-24 14:34:03 +08:00
64e9d96d84 chore: compatible with es5 (#14268) 2025-03-24 13:17:48 +08:00
d27de3818c Merge branch 'release/0.15.5' into e-0154 2025-03-24 11:46:30 +08:00
bbb080d5b2 fix: update chatbot help doc link on the create app form 2025-03-24 11:28:35 +08:00
8c025abb3b Merge branch 'release/0.15.5' into e-0154 2025-03-24 10:32:56 +08:00
c01d8a70f3 fix: upgrade nextjs to v14.2.25. a security patch for CVE-2025-29927. 2025-03-24 10:32:18 +08:00
98606ca558 fix: upgrade nextjs to v14.2.25 2025-03-24 10:12:21 +08:00
adf3e18ebd Merge tag '0.15.4' into e-0154 2025-03-21 18:29:43 +08:00
1ca15989e0 chore: update version to 0.15.4 in configuration and docker files
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-03-21 16:39:06 +08:00
8b5a3a9424 Merge branch 'release/0.15.4' of github.com:langgenius/dify into release/0.15.4 2025-03-21 16:31:06 +08:00
42ddcf1edd chore: remove 0.15.3 branch config in the build action
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-03-21 16:30:33 +08:00
21561df10f fix: xss in render svg (#16437) 2025-03-21 15:24:58 +08:00
4327ec8c4c fix license expireAt field typo (#16428) 2025-03-21 13:43:43 +08:00
bbc5ec8301 fix: expired date calc error 2025-03-21 11:00:07 +08:00
4a51a72c1d Merge branch 'e-0154' into deploy/enterprise 2025-03-20 17:34:52 +08:00
4b6adffa8e fix: hide copyright on forgot-password/install/reset-password page 2025-03-20 17:34:19 +08:00
c7fd73d330 Merge branch 'e-0154' into deploy/enterprise 2025-03-20 10:13:09 +08:00
8a709e445a fix: remove Dify from Service API doc 2025-03-20 10:12:27 +08:00
f02b77b99f fix: Decouple login page logo component to avoid conflict with internal logo 2025-03-20 10:11:26 +08:00
abc625bcce Merge branch 'e-0154' into deploy/enterprise 2025-03-18 22:35:39 -04:00
b6bc1f8bc4 fix: adjust logic for branding toggle 2025-03-18 22:35:27 -04:00
b8f9037cd3 Merge branch 'e-0154' into deploy/enterprise 2025-03-18 16:13:14 +08:00
02606ba3c7 fix: cannot update webapp copyright info 2025-03-18 16:12:52 +08:00
79311d3fb5 Merge branch 'e-0154' into deploy/enterprise 2025-03-18 03:53:18 -04:00
31086a1fbf feat: add webapp copyright feature 2025-03-18 03:53:07 -04:00
6ae5d052e5 Merge branch 'e-0154' into deploy/enterprise 2025-03-18 14:55:36 +08:00
c794ecf101 fix: user can edit webapp copyright info only if webapp_copyright_enabled is true 2025-03-18 14:54:34 +08:00
d887aae012 Merge branch 'e-0154' into deploy/enterprise 2025-03-18 01:55:38 -04:00
1b1e96eff7 fix: typo 2025-03-18 01:55:27 -04:00
eecd091063 Merge branch 'e-0154' into deploy/enterprise 2025-03-17 15:34:49 -04:00
d38f2cb380 fix: change subject title 2025-03-17 15:34:28 -04:00
56aaee5558 fix: wrong branding title 2025-03-17 15:01:31 -04:00
d72b4752c9 fix: wrong title location 2025-03-17 15:00:04 -04:00
ea769c6483 Merge branch 'e-0154' into deploy/enterprise 2025-03-17 14:24:00 -04:00
ec194fa3d4 fix: invalid email template variables 2025-03-17 14:23:46 -04:00
b877039859 Merge branch 'e-0154' into deploy/enterprise 2025-03-17 10:37:20 +08:00
54634f26d2 fix: show copyright in webapp 2025-03-17 10:36:51 +08:00
3bef91a2cd fix: show loading icon when fetching system features 2025-03-15 12:01:30 +08:00
7da45ba589 fix: show loading icon when fetching system features 2025-03-15 12:00:22 +08:00
e0232c67cc fix: update document title and favicon in client side 2025-03-15 12:00:22 +08:00
1dc4a229d4 Merge branch 'e-0154' into deploy/enterprise 2025-03-14 16:37:02 -04:00
0e0bada1f3 fix: missing json keys 2025-03-14 16:36:49 -04:00
5366a814f9 fix: update json keys 2025-03-14 16:35:05 -04:00
f1240a22db fix: remove default value 2025-03-14 13:26:44 -04:00
66f35c2b7e Merge branch 'e-0154' into deploy/enterprise 2025-03-15 01:25:15 +08:00
766ee48531 fix: update document title and favicon in client side 2025-03-15 01:25:04 +08:00
083045f45c Merge branch 'e-0154' into deploy/enterprise 2025-03-14 20:49:17 +08:00
fe237802c9 fix: update Dify text 2025-03-14 19:10:03 +08:00
00b923651f fix: update document title with system features config 2025-03-14 19:10:03 +08:00
24fce3cc64 chore: use global zustand manage systemFeatures and share between all pages 2025-03-14 19:10:03 +08:00
8ba969f67d fix: add ci workflow 2025-03-13 17:15:11 -04:00
6844d59371 fix: add default title name 2025-03-13 17:07:45 -04:00
fe5529db85 Trigger workflow 2025-03-13 17:04:13 -04:00
d89034d913 feat: add application title 2025-03-13 15:49:04 -04:00
360fbeb108 fix: update email template, add application_title 2025-03-13 17:28:49 +08:00
e7c2fa1cfa fix: remove system feature is_branding 2025-03-12 10:48:58 -04:00
735f09d977 fix: build failed due to getPrevChatList no longer exists (#13383) 2025-03-12 10:22:33 +08:00
f83a5e3e49 fix: wrong type 2025-03-11 07:46:48 -04:00
01a8d4efcc fix: remove dify from invite template 2025-03-11 19:25:30 +08:00
fdb1e649d4 feat: add branding support 2025-03-11 07:14:52 -04:00
0856792a57 fix: add email templates that are no brands or logo 2025-03-11 16:03:15 +08:00
0e33a3aa5f chore: add ci 2025-02-19 14:34:36 +08:00
d3895bcd6b revert 2025-02-19 14:32:28 +08:00
eeb390650b fix: build failed 2025-02-19 14:32:28 +08:00
ca19bd31d4 chore(*): Bump version to 0.15.3 (#13308)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-02-07 15:20:05 +08:00
413dfd5628 feat: add completion mode and context size options for LLM configuration (#13325)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-02-07 15:08:53 +08:00
f9515901cc fix: Azure AI Foundry model cannot be used in the workflow (#13323)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-02-07 14:52:57 +08:00
3f42fabff8 chore:improve thinking display for llm from xinference and ollama pro… (#13318) 2025-02-07 14:29:29 +08:00
1caa578771 chore(*): Update style of thinking (#13319)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-02-07 14:06:35 +08:00
b7c11c1818 Fix the problem of Workflow terminates after parallel tasks execution, merge node not triggered (#12498)
Co-authored-by: Novice Lee <novicelee@NoviPro.local>
2025-02-07 13:56:08 +08:00
3eb3db0663 chore: refactor the OpenAICompatible and improve thinking display (#13299) 2025-02-07 13:28:46 +08:00
be46f32056 fix(credits): require model name equals (#13314)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-02-07 13:28:17 +08:00
6e5c915f96 feat(model): add deepseek-r1 for openrouter (#13312) 2025-02-07 12:39:13 +08:00
04d13a8116 feat(credits): Allow to configure model-credit mapping (#13274)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-02-07 11:01:31 +08:00
e638ede3f2 Update README_TR.md (#13294) 2025-02-07 09:11:39 +08:00
2348abe4bf feat: added a couple of models not defined in vertex ai, that were already … (#13296) 2025-02-07 09:11:25 +08:00
f7e7a399d9 feat:add think tag display for xinference deepseek r1 (#13291) 2025-02-06 22:04:58 +08:00
ba91f34636 fix: incorrect transferMethod assignment for remote file (#13286) 2025-02-06 19:32:21 +08:00
16865d43a8 feat: add deepseek models for volcengine provider (#13283)
Co-authored-by: zhaoqingyu.1075 <zhaoqingyu.1075@bytedance.com>
2025-02-06 18:20:03 +08:00
0d13aee15c feat:add deepseek r1 think display for ollama provider (#13272) 2025-02-06 15:32:10 +08:00
49b4144ffd fix: add dataset edit permissions (#13223) 2025-02-06 14:26:16 +08:00
186e2d972e chore(deps): bump katex from 0.16.10 to 0.16.21 in /web (#13270)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-02-06 13:27:07 +08:00
40dd63ecef Upgrade oracle models (#13174)
Co-authored-by: engchina <atjapan2015@gmail.com>
2025-02-06 13:24:27 +08:00
6d66d6da15 feat(model_providers): Support deepseek-r1 for Nvidia Catalog (#13269)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-02-06 13:03:19 +08:00
03ec3513f3 Fix bug large data no render (#12683)
Co-authored-by: ex_wenyan.wei <ex_wenyan.wei@tcl.com>
2025-02-06 13:00:04 +08:00
87763fc234 feat(model_providers): Support deepseek for Azure AI Foundry (#13267)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-02-06 12:45:48 +08:00
f6c44cae2e feat(model): add gemini-2.0 model (#13266) 2025-02-06 12:28:59 +08:00
xhe
da2ee04fce fix: correct linewrap think display in generic openai api (#13260)
Signed-off-by: xhe <xw897002528@gmail.com>
2025-02-06 10:53:08 +08:00
7673c36af3 feat(model): add gemini-2.0-flash-thinking-exp-01-21 (#13230) 2025-02-06 10:01:00 +08:00
9457b2af2f feat: added models :gemini 2.0 flash 001 and gemini 2.0 pro exp 02-05 (#13247) 2025-02-06 09:58:39 +08:00
7203991032 feat: add parameter "reasoning_effort" and Openai o3-mini (#13243) 2025-02-06 09:29:48 +08:00
xhe
5a685f7156 feat: add think display for volcengine and generic openapi (#13234)
Signed-off-by: xhe <xw897002528@gmail.com>
2025-02-06 09:24:40 +08:00
a6a25030ad fix: updated _position.yaml to include the latest model already integ… (#13245) 2025-02-06 09:21:51 +08:00
00458a31d5 feat: added deepseek r1 and v3 to siliconflow (#13238) 2025-02-05 21:59:18 +08:00
c6ddf6d6cc feat(model_providers): Add Groq DeepSeek-R1-Distill-Llama-70b (#13229)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-02-05 19:15:29 +08:00
34b21b3065 feat: Add o3-mini and o3-mini-2025-01-31 model variants (#13129)
Co-authored-by: crazywoola <427733928@qq.com>
2025-02-05 17:04:45 +08:00
8fbb355cd2 chore: squash system dependencies installation steps (#13206) 2025-02-05 16:42:53 +08:00
e8b3b7e578 Fix new variables in the conversation opener would override prompt_variables (#13191) 2025-02-05 16:16:00 +08:00
59ca44f493 chore(model_runtime): Move deepseek ahead in the providers list. (#13197)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-02-05 16:08:28 +08:00
9e1457c2c3 fix: mypy checks violation in AzureBlobStorage (#13215) 2025-02-05 15:56:23 +08:00
fac83e14bc Use DefaultAzureCredential for managed identity in azure blob extention (#11559) 2025-02-05 13:43:43 +08:00
a97cec57e4 fix: SSRF proxy file descriptor leak in concurrent requests (#13108) 2025-02-05 13:10:27 +08:00
38c10b47d3 Feat: add linkedin to readme (#13203) 2025-02-05 12:27:58 +08:00
1a2523fd15 feat: bedrock_endpoint_url (#12838) 2025-02-05 12:24:24 +08:00
03243cb422 Modify params for bedrock retrieve generate (#13182) 2025-02-05 12:17:42 +08:00
2ad7ee0344 chore: add tests for build docker image when dockerfile changed (#10732) 2025-02-05 11:40:22 +08:00
55ce3618ce fix: Dollar Sign Handling in Markdown (#13178)
Co-authored-by: crazywoola <427733928@qq.com>
2025-02-05 11:00:56 +08:00
e9e34c1ab2 Install apt dependencies using bookworm source, consistent with base image. Remove unnecessary, error-prone pins (#13176) 2025-02-05 10:07:22 +08:00
d4c916b496 chore(pyproject): Add type stubs into pyproject.toml (#13145)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-02-04 12:01:28 +08:00
8fbc9c9342 Solve circular dependency issue between workflow/constants.ts file and default.ts file (#13165) 2025-02-04 09:26:01 +08:00
1b6fd9dfe8 fix: set indexing technique from dataset during update-by-text (#13155) 2025-02-03 11:06:03 +08:00
304467e3f5 fix: not install libmagic raise error (#13146) 2025-02-03 11:05:20 +08:00
7452032d81 add azure openai api version 2024-12-01-preview (#13135) 2025-02-03 11:04:20 +08:00
87e2048f1b nitpick: fix small typos in template.en.mdx (#13156) 2025-02-03 11:03:11 +08:00
d876084392 chore: upgrade libldap2 (#13158) 2025-02-03 11:02:14 +08:00
840729afa5 feat: the think tag display of siliconflow's deepseek r1 (#13153) 2025-02-02 21:55:13 +08:00
941ad03f3c pass model and cost so that langfuse can show cost (#13117) 2025-02-02 15:27:27 +08:00
d73d191f99 feature. add feat to modify metadata via dataset api (#13116) 2025-02-02 15:27:12 +08:00
c2664e0283 chore: fix wrong VectorType match case (#13123) 2025-02-02 15:26:59 +08:00
ee61cede4e test(huggingface_hub): Skip the failed test temporarily. (#13142)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-02-02 14:47:26 +08:00
b47669b80b fix: deduct LLM quota after processing invoke result (#13075)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-02-02 12:05:11 +08:00
c0d0c63592 feat: switch to chat messages before regenerated (#11301)
Co-authored-by: zuodongxu <192560071+zuodongxu@users.noreply.github.com>
2025-01-31 13:05:10 +08:00
b09c39c8dc refactor: avoid to use extra space when finding model by name (#13043) 2025-01-30 15:08:29 +08:00
b4b09ddc3c add tongyi qwen2.5-14b/7b-instruct-1m model (#13089) 2025-01-29 11:58:01 +08:00
d0a21086bd refactor: Update Firecrawl API parameters and default settings (#13082) 2025-01-29 11:21:05 +08:00
d44882c1b5 refactor: reduce duplciate code by inheritance (#13073) 2025-01-28 10:52:01 +08:00
23c68efa2d fix: fix the formatter is not applied on log file (#12704) 2025-01-28 10:49:58 +08:00
560c5de1b7 Fixed Novita AI color and added DeepSeek R1 model (#13074) 2025-01-28 10:38:54 +08:00
5d91dbd000 Set default LOG_LEVEL to INFO for celery workers and beat (#13066)
Co-authored-by: Abdullah AlOsaimi <189027247+osaimi@users.noreply.github.com>
2025-01-27 17:09:41 +08:00
6c31ee36cd fix qwen-vl blocking mode (#13052) 2025-01-27 11:35:23 +08:00
edc29780ed fix: "Model schema not found" error only in agents (#12655) (#12760) 2025-01-27 11:33:13 +08:00
aad7e4dd1c fix:Improve MIME type detection for remote URL uploads using python-magic (#12693) 2025-01-27 11:33:03 +08:00
a6a727e8a4 feat: add inner API to create workspace without requiring email (#13021) 2025-01-26 15:36:56 +08:00
d1fc65fabc fix: adjust iteration node dark style (#13051) 2025-01-26 11:19:41 +08:00
d4be5ef9de Update Novita AI predefined models (#13045) 2025-01-26 09:25:29 +08:00
1374be5a31 fix: Unexpected tag creation when pressing enter during tag conversion (#13041) 2025-01-25 19:30:26 +08:00
b2bbc28580 support bedrock kb: retrieve and generate (#13027) 2025-01-25 17:28:06 +08:00
59b3e672aa feat: add agent thinking content display of deepseek R1 (#12949) 2025-01-24 20:13:42 +08:00
a2f8bce8f5 chore: add Japanese translation: model_providers/bedrock (#13016) 2025-01-24 18:43:33 +08:00
a2b9adb3a2 Change typo in translation (#13004) 2025-01-24 13:48:21 +08:00
28067640b5 fix: wrong zh_Hans translation: Ohio (#13006) 2025-01-24 13:41:20 +08:00
da67916843 feat: add glm-4-air-0111 (#12997)
Co-authored-by: lowell <lowell.hu@zkteco.in>
2025-01-24 10:04:46 +08:00
e54ce479ad Feat/prompt editor dark theme (#12976) 2025-01-23 16:20:00 +08:00
6024d8a42d refactor: Update Firecrawl to use v1 API (#12574)
Co-authored-by: Ademílson Tonato <ademilson.tonato@refurbed.com>
2025-01-23 11:14:48 +08:00
f565f08aa0 fix: get property of string type variable caused page crash (#12969) 2025-01-23 11:02:29 +08:00
fd4afe09f8 fix: tools translate search (#12950)
Co-authored-by: lowell <lowell.hu@zkteco.in>
2025-01-22 19:27:02 +08:00
dd0904f95c feat: add giteeAI risk control identification. (#12946) 2025-01-22 19:26:25 +08:00
4c3076f2a4 feat: add pg vector index (#12338)
Co-authored-by: huangzhuo <huangzhuo1@xiaomi.com>
2025-01-22 17:07:18 +08:00
1e73f63ff8 chore: update version to 0.15.2 in packaging and docker configurations (#12940)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-01-22 16:40:44 +08:00
d167d5b1be feat(ark): support doubao 1.5 series of models (#12935) 2025-01-22 15:25:57 +08:00
71fa14f791 fix: resolve clipboard.writeText failure under HTTP protocol (#12936) 2025-01-22 15:18:23 +08:00
8dd1873e76 feat: workflow note dark theme (#12932) 2025-01-22 14:22:33 +08:00
f91f5c7401 fix(batch_create_segment_to_index_task): count max_position in memory. (#12929) 2025-01-22 13:39:02 +08:00
c62b7cc679 chore(build): bump poetry from 1.x to 2.x (#12369) 2025-01-22 13:38:24 +08:00
3ee213ddca add milvus full text search setting (#12930) 2025-01-22 13:36:39 +08:00
8429877b02 fix: Agent is configured for ReAct inference mode, an error is reported when viewing the agent log (#12920)
Co-authored-by: crazywoola <427733928@qq.com>
2025-01-22 13:20:32 +08:00
05a0faff6a fix: app token's last_used_at can't be updated when last_used_at is null (#12770) 2025-01-22 11:01:45 +08:00
e09f6e4987 feat: support config chunk length by env (#12925) 2025-01-22 10:43:40 +08:00
e23f4b0265 feat: add gemini-2.0-flash-thinking-exp-01-21 (#12924) 2025-01-22 10:14:37 +08:00
f582d4a13e feat: Add ability to change profile avatar (#12642) 2025-01-22 10:11:31 +08:00
2f41bd495d fix:Fix a bug that returns null when the passed path is a file. (#12775)
Co-authored-by: 刘江波 <jiangbo721@163.com>
2025-01-22 10:10:03 +08:00
162a8c4393 fix update segment keyword with same content (#12908) 2025-01-21 19:19:32 +08:00
3d1ce4c53f bug: fixed bedrock rerank bug (#12774)
Co-authored-by: hobo.l <hobo.l@binance.com>
2025-01-21 19:09:36 +08:00
6db3ae9b8e chore: remove webapp ga (#12909) 2025-01-21 18:38:33 +08:00
6d0cb9dc33 fix: variable panel scrollable (#12769)
Co-authored-by: zhaoqingyu.1075 <zhaoqingyu.1075@bytedance.com>
2025-01-21 17:50:42 +08:00
46e95e8309 fix: OpenAI o1 Bad Request Error (#12839) 2025-01-21 15:29:13 +08:00
a7b9375877 Update deepseek model configuration (#12899) 2025-01-21 15:28:11 +08:00
0c6a8a130e fix: external dataset hit test display issue(#12564) (#12612)
Co-authored-by: zhuxinliang <zhuxinliang@didiglobal.com>
2025-01-21 14:31:45 +08:00
9903f1e703 add deepseek-reasoner (#12898) 2025-01-21 12:40:58 +08:00
6fad719e42 chore(fix): Invalid quotes for using Array[String] in HTTP request node as JSON body (#12761) 2025-01-21 10:38:44 +08:00
9aaee8ee47 fix: Issues related to the deletion of conversation_id (#12488) (#12665) 2025-01-21 10:25:35 +08:00
166221d784 chore(lint): fix quotes for f-string formatting by bumping ruff to 0.9.x (#12702) 2025-01-21 10:12:29 +08:00
925d69a2ee feat:Support Minimax-Text-01 (#12763) 2025-01-21 10:08:53 +08:00
5ff08e241a fix: serply credential check query might return empty records (#12784) 2025-01-21 09:38:56 +08:00
3defd24087 feat: allow updating chunk settings for the existing documents (#12833) 2025-01-21 09:25:40 +08:00
9d86147d20 fix: SparkLite API Auth error (#12781) (#12790) 2025-01-20 22:21:21 +08:00
80801ac4ab fix: "parmas" spelling mistake. (#12875) 2025-01-20 22:18:30 +08:00
210926cd91 Fix suggested_question_prompt (#12738) 2025-01-20 22:16:30 +08:00
677a69deed fix(i18n): correct typo in zh-Hant translation (#12852) 2025-01-20 22:15:41 +08:00
8dfdee21ce chore: fix chinese translation for 'recall' (#12772)
Co-authored-by: zhaoqingyu.1075 <zhaoqingyu.1075@bytedance.com>
2025-01-20 22:15:26 +08:00
6ea77ab4cd fix: DeepSeek API Error with response format active (text and json_object) (#12747) 2025-01-20 22:04:18 +08:00
e3c996688d feat: enhance credential extraction logic based on configurate method (#12853) 2025-01-20 21:59:22 +08:00
bc3a570dda fix: Fix rerank model switching issue (#12721)
ok
2025-01-14 15:42:45 +08:00
0800021a2d chore: translate i18n files (#12708)
Co-authored-by: JzoNgKVO <27049666+JzoNgKVO@users.noreply.github.com>
2025-01-14 13:35:23 +08:00
435eddd867 Feat: copyright modification (#12707) 2025-01-14 10:00:57 +08:00
6e0fb055d1 chore: bump version to 0.15.1 (#12690)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-01-13 19:21:06 +08:00
eux
1e9ac7ffeb feat: add table of contents to Knowledge API doc (#12688) 2025-01-13 18:31:43 +08:00
b4873ecb43 [fix] support feature restore (#12563) 2025-01-13 18:29:06 +08:00
mbo
1859d57784 api tool support multiple env url (#12249)
Co-authored-by: mabo <mabo@aeyes.ai>
2025-01-13 17:49:30 +08:00
69d58fbb50 Add new integration with Opik Tracking tool (#11501) 2025-01-13 17:41:44 +08:00
cb34991663 fix: add type hints for App model and improve error handling in audio services (#12677)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-01-13 15:55:16 +08:00
c700364e1c fix: Update variable handling in VariableAssignerNode and clean up app_dsl_service (#12672)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-01-13 15:54:26 +08:00
9a6b1dc3a1 Revert "Feat/new saas billing" (#12673) 2025-01-13 15:17:43 +08:00
54b5b80a07 fix(workflow): fix answer node stream processing in conditional branches (#12510) 2025-01-13 14:54:21 +08:00
831459b895 fix: ruff with statements (#12578)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2025-01-13 09:55:55 +08:00
4e101604c3 fix: ruff check for True if ... else (#12576)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
2025-01-13 09:38:48 +08:00
a6455269f0 chore: Adjust translations to align with Taiwanese Mandarin conventions (#12633) 2025-01-13 09:12:43 +08:00
cd257b91c5 Fix pandas indexing method for knowledge base imports (#12637) (#12638)
Co-authored-by: CN-P5 <heibai2006@qq.com>
2025-01-13 09:06:59 +08:00
d8f57bf899 Feat/new saas billing (#12591) 2025-01-12 14:50:46 +08:00
989fb11fd7 improve the readability of the function generate_api_key (#12552) 2025-01-09 21:30:17 +08:00
140965b738 chore: translate i18n files (#12543)
Co-authored-by: WTW0313 <30284043+WTW0313@users.noreply.github.com>
2025-01-09 20:30:06 +08:00
14ee51aead Feat/add knowledge include all filter (#12537) 2025-01-09 20:21:25 +08:00
2e97ba5700 fix: Add datasets list access control and fix datasets config display issue (#12533)
Co-authored-by: nite-knite <nkCoding@gmail.com>
2025-01-09 17:44:11 +08:00
f549d53b68 fix: sum costs return error value on overview page (#12534) 2025-01-09 16:04:14 +08:00
a085ad4719 feat: show workflow running status (#12531) 2025-01-09 15:36:13 +08:00
f230a9232e fix: Parsing OpenAPI spec for external tools (#12518) (#12530) 2025-01-09 15:30:43 +08:00
e84bf35e2a fix: same chunk insert deadlock (#12502)
Co-authored-by: huangzhuo <huangzhuo1@xiaomi.com>
2025-01-09 15:16:41 +08:00
eux
20f090537f feat: add GET upload file API endpoint to dataset service api (#11899) 2025-01-09 14:52:09 +08:00
dbe7a7c4fd Fix: Add a INFO-level log when fallback to gpt2tokenizer (#12508) 2025-01-09 14:37:46 +08:00
b7a4e3903e fix: add last_refresh_time to track the validity of is_other_tab_refreshing (#12517) 2025-01-09 10:40:45 +08:00
3476 changed files with 290972 additions and 71535 deletions

View File

@ -1,12 +1,11 @@
#!/bin/bash
npm add -g pnpm@9.12.2
cd web && pnpm install
cd web && npm install
pipx install poetry
echo 'alias start-api="cd /workspaces/dify/api && poetry run python -m flask run --host 0.0.0.0 --port=5001 --debug"' >> ~/.bashrc
echo 'alias start-worker="cd /workspaces/dify/api && poetry run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion"' >> ~/.bashrc
echo 'alias start-web="cd /workspaces/dify/web && pnpm dev"' >> ~/.bashrc
echo 'alias start-web="cd /workspaces/dify/web && npm run dev"' >> ~/.bashrc
echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify up -d"' >> ~/.bashrc
echo 'alias stop-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify down"' >> ~/.bashrc

View File

@ -4,7 +4,6 @@ on:
pull_request:
branches:
- main
- plugins/beta
paths:
- api/**
- docker/**
@ -48,9 +47,15 @@ jobs:
- name: Run Unit tests
run: poetry run -P api bash dev/pytest/pytest_unit_tests.sh
- name: Run ModelRuntime
run: poetry run -P api bash dev/pytest/pytest_model_runtime.sh
- name: Run dify config tests
run: poetry run -P api python dev/pytest/pytest_config_tests.py
- name: Run Tool
run: poetry run -P api bash dev/pytest/pytest_tools.sh
- name: Run mypy
run: |
poetry run -C api python -m mypy --install-types --non-interactive .

View File

@ -5,8 +5,7 @@ on:
branches:
- "main"
- "deploy/dev"
- "plugins/beta"
- "dev/plugin-deploy"
- "e-260"
release:
types: [published]
@ -141,16 +140,3 @@ jobs:
- name: Inspect image
run: |
docker buildx imagetools inspect ${{ env[matrix.image_name_env] }}:${{ steps.meta.outputs.version }}
- name: print context var
uses: actions/checkout@v4
- name: deploy pod in plugin env
if: github.ref == 'refs/heads/dev/plugin-deploy'
env:
IMAGEHASH: ${{ github.sha }}
APICMD: "${{ secrets.PLUGIN_CD_API_CURL }}"
WEBCMD: "${{ secrets.PLUGIN_CD_WEB_CURL }}"
run: |
bash -c "${APICMD/yourNewVersion/$IMAGEHASH}"
bash -c "${WEBCMD/yourNewVersion/$IMAGEHASH}"

View File

@ -4,7 +4,6 @@ on:
pull_request:
branches:
- main
- plugins/beta
paths:
- api/migrations/**
- .github/workflows/db-migration-test.yml

29
.github/workflows/deploy-enterprise.yml vendored Normal file
View File

@ -0,0 +1,29 @@
name: Deploy Enterprise
permissions:
contents: read
on:
workflow_run:
workflows: ["Build and Push API & Web"]
branches:
- "deploy/enterprise"
types:
- completed
jobs:
deploy:
runs-on: ubuntu-latest
if: |
github.event.workflow_run.conclusion == 'success' &&
github.event.workflow_run.head_branch == 'deploy/enterprise'
steps:
- name: Deploy to server
uses: appleboy/ssh-action@v0.1.8
with:
host: ${{ secrets.ENTERPRISE_SSH_HOST }}
username: ${{ secrets.ENTERPRISE_SSH_USER }}
password: ${{ secrets.ENTERPRISE_SSH_PASSWORD }}
script: |
${{ vars.ENTERPRISE_SSH_SCRIPT || secrets.ENTERPRISE_SSH_SCRIPT }}

View File

@ -1,23 +0,0 @@
name: Deploy Plugin Dev
on:
workflow_run:
workflows: ["Build and Push API & Web"]
branches:
- "dev/plugin-deploy"
types:
- completed
jobs:
deploy:
runs-on: ubuntu-latest
if: |
github.event.workflow_run.conclusion == 'success'
steps:
- name: Deploy to server
uses: appleboy/ssh-action@v0.1.8
with:
host: ${{ secrets.SSH_HOST }}
username: ${{ secrets.SSH_USER }}
key: ${{ secrets.SSH_PRIVATE_KEY }}
script: "echo 123"

View File

@ -4,7 +4,6 @@ on:
pull_request:
branches:
- main
- plugins/beta
concurrency:
group: style-${{ github.head_ref || github.run_id }}
@ -67,23 +66,17 @@ jobs:
with:
files: web/**
- name: Install pnpm
uses: pnpm/action-setup@v4
with:
version: 10
run_install: false
- name: Setup NodeJS
uses: actions/setup-node@v4
if: steps.changed-files.outputs.any_changed == 'true'
with:
node-version: 20
cache: pnpm
cache: yarn
cache-dependency-path: ./web/package.json
- name: Web dependencies
if: steps.changed-files.outputs.any_changed == 'true'
run: pnpm install --frozen-lockfile
run: yarn install --frozen-lockfile
- name: Web style check
if: steps.changed-files.outputs.any_changed == 'true'
@ -141,7 +134,7 @@ jobs:
if: steps.changed-files.outputs.any_changed == 'true'
env:
BASH_SEVERITY: warning
DEFAULT_BRANCH: plugins/beta
DEFAULT_BRANCH: main
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
IGNORE_GENERATED_FILES: true
IGNORE_GITIGNORED_FILES: true

View File

@ -32,10 +32,10 @@ jobs:
with:
node-version: ${{ matrix.node-version }}
cache: ''
cache-dependency-path: 'pnpm-lock.yaml'
cache-dependency-path: 'yarn.lock'
- name: Install Dependencies
run: pnpm install
run: yarn install
- name: Test
run: pnpm test
run: yarn test

View File

@ -38,11 +38,11 @@ jobs:
- name: Install dependencies
if: env.FILES_CHANGED == 'true'
run: pnpm install --frozen-lockfile
run: yarn install --frozen-lockfile
- name: Run npm script
if: env.FILES_CHANGED == 'true'
run: pnpm run auto-gen-i18n
run: npm run auto-gen-i18n
- name: Create Pull Request
if: env.FILES_CHANGED == 'true'

View File

@ -34,13 +34,13 @@ jobs:
if: steps.changed-files.outputs.any_changed == 'true'
with:
node-version: 20
cache: pnpm
cache: yarn
cache-dependency-path: ./web/package.json
- name: Install dependencies
if: steps.changed-files.outputs.any_changed == 'true'
run: pnpm install --frozen-lockfile
run: yarn install --frozen-lockfile
- name: Run tests
if: steps.changed-files.outputs.any_changed == 'true'
run: pnpm test
run: yarn test

7
.gitignore vendored
View File

@ -175,7 +175,6 @@ docker/volumes/pgvector/data/*
docker/volumes/pgvecto_rs/data/*
docker/volumes/couchbase/*
docker/volumes/oceanbase/*
docker/volumes/plugin_daemon/*
!docker/volumes/oceanbase/init.d
docker/nginx/conf.d/default.conf
@ -194,9 +193,3 @@ api/.vscode
.idea/
.vscode
# pnpm
/.pnpm-store
# plugin migrate
plugins.jsonl

4
.markdownlint.json Normal file
View File

@ -0,0 +1,4 @@
{
"MD024": false,
"MD013": false
}

45
CHANGELOG.md Normal file
View File

@ -0,0 +1,45 @@
# Changelog
All notable changes to Dify will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [0.15.8] - 2025-05-30
### Added
- Added gunicorn keepalive setting (#19537)
### Fixed
- Fixed database configuration to allow DB_EXTRAS to set search_path via options (#16a4f77)
- Fixed frontend third-party package security issues (#19655)
- Updated dependencies: huggingface-hub (~0.16.4 to ~0.31.0), transformers (~4.35.0 to ~4.39.0), and resend (~0.7.0 to ~2.9.0) (#19563)
- Downgrade boto3 from 1.36 to 1.35 (#19736)
## [0.15.7] - 2025-04-27
### Added
- Added support for GPT-4.1 in model providers (#18912)
- Added support for Amazon Bedrock DeepSeek-R1 model (#18908)
- Added support for Amazon Bedrock Claude Sonnet 3.7 model (#18788)
- Refined version compatibility logic in app DSL service
### Fixed
- Fixed issue with creating apps from template categories (#18807, #18868)
- Fixed DSL version check when creating apps from explore templates (#18872, #18878)
## [0.15.6] - 2025-04-22
### Security
- Fixed clickjacking vulnerability (#18552)
- Fixed reset password security issue (#18366)
- Updated reset password token when email code verification succeeds (#18362)
### Fixed
- Fixed Vertex AI Gemini 2.0 Flash 001 schema (#18405)

View File

@ -1,10 +1,7 @@
.env
*.env.*
storage/generate_files/*
storage/privkeys/*
storage/tools/*
storage/upload_files/*
# Logs
logs
@ -12,8 +9,6 @@ logs
# jetbrains
.idea
.mypy_cache
.ruff_cache
# venv
.venv

View File

@ -409,6 +409,7 @@ MAX_VARIABLE_SIZE=204800
APP_MAX_EXECUTION_TIME=1200
APP_MAX_ACTIVE_REQUESTS=0
# Celery beat configuration
CELERY_BEAT_SCHEDULER_TIME=1
@ -421,22 +422,6 @@ POSITION_PROVIDER_PINS=
POSITION_PROVIDER_INCLUDES=
POSITION_PROVIDER_EXCLUDES=
# Plugin configuration
PLUGIN_DAEMON_KEY=lYkiYYT6owG+71oLerGzA7GXCgOT++6ovaezWAjpCjf+Sjc3ZtU+qUEi
PLUGIN_DAEMON_URL=http://127.0.0.1:5002
PLUGIN_REMOTE_INSTALL_PORT=5003
PLUGIN_REMOTE_INSTALL_HOST=localhost
PLUGIN_MAX_PACKAGE_SIZE=15728640
INNER_API_KEY=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1
INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1
# Marketplace configuration
MARKETPLACE_ENABLED=true
MARKETPLACE_API_URL=https://marketplace.dify.ai
# Endpoint configuration
ENDPOINT_URL_TEMPLATE=http://localhost:5002/e/{hook_id}
# Reset password token expiry minutes
RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5
@ -445,4 +430,7 @@ CREATE_TIDB_SERVICE_JOB_ENABLED=false
# Maximum number of submitted thread count in a ThreadPool for parallel node execution
MAX_SUBMIT_COUNT=100
# Lockout duration in seconds
LOGIN_LOCKOUT_DURATION=86400
LOGIN_LOCKOUT_DURATION=86400
# Prevent Clickjacking
ALLOW_EMBED=false

View File

@ -55,7 +55,7 @@ RUN \
# basic environment
curl nodejs libgmp-dev libmpfr-dev libmpc-dev \
# For Security
# expat libldap-2.5-0 perl libsqlite3-0 zlib1g \
expat libldap-2.5-0 perl libsqlite3-0 zlib1g \
# install a chinese font to support the use of tools like matplotlib
fonts-noto-cjk \
# install libmagic to support the use of python-magic guess MIMETYPE
@ -71,10 +71,6 @@ ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
# Download nltk data
RUN python -c "import nltk; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger')"
ENV TIKTOKEN_CACHE_DIR=/app/api/.tiktoken_cache
RUN python -c "import tiktoken; tiktoken.encoding_for_model('gpt2')"
# Copy source code
COPY . /app/api/

View File

@ -25,8 +25,6 @@ from models.dataset import Document as DatasetDocument
from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation
from models.provider import Provider, ProviderModel
from services.account_service import RegisterService, TenantService
from services.plugin.data_migration import PluginDataMigration
from services.plugin.plugin_migration import PluginMigration
@click.command("reset-password", help="Reset the account password.")
@ -526,7 +524,7 @@ def add_qdrant_doc_id_index(field: str):
)
)
except Exception:
except Exception as e:
click.echo(click.style("Failed to create Qdrant client.", fg="red"))
click.echo(click.style(f"Index creation complete. Created {create_count} collection indexes.", fg="green"))
@ -595,7 +593,7 @@ def upgrade_db():
click.echo(click.style("Database migration successful!", fg="green"))
except Exception:
except Exception as e:
logging.exception("Failed to execute database migration")
finally:
lock.release()
@ -641,7 +639,7 @@ where sites.id is null limit 1000"""
account = accounts[0]
print("Fixing missing site for app {}".format(app.id))
app_was_created.send(app, account=account)
except Exception:
except Exception as e:
failed_app_ids.append(app_id)
click.echo(click.style("Failed to fix missing site for app {}".format(app_id), fg="red"))
logging.exception(f"Failed to fix app related site missing issue, app_id: {app_id}")
@ -651,68 +649,3 @@ where sites.id is null limit 1000"""
break
click.echo(click.style("Fix for missing app-related sites completed successfully!", fg="green"))
@click.command("migrate-data-for-plugin", help="Migrate data for plugin.")
def migrate_data_for_plugin():
"""
Migrate data for plugin.
"""
click.echo(click.style("Starting migrate data for plugin.", fg="white"))
PluginDataMigration.migrate()
click.echo(click.style("Migrate data for plugin completed.", fg="green"))
@click.command("extract-plugins", help="Extract plugins.")
@click.option("--output_file", prompt=True, help="The file to store the extracted plugins.", default="plugins.jsonl")
@click.option("--workers", prompt=True, help="The number of workers to extract plugins.", default=10)
def extract_plugins(output_file: str, workers: int):
"""
Extract plugins.
"""
click.echo(click.style("Starting extract plugins.", fg="white"))
PluginMigration.extract_plugins(output_file, workers)
click.echo(click.style("Extract plugins completed.", fg="green"))
@click.command("extract-unique-identifiers", help="Extract unique identifiers.")
@click.option(
"--output_file",
prompt=True,
help="The file to store the extracted unique identifiers.",
default="unique_identifiers.json",
)
@click.option(
"--input_file", prompt=True, help="The file to store the extracted unique identifiers.", default="plugins.jsonl"
)
def extract_unique_plugins(output_file: str, input_file: str):
"""
Extract unique plugins.
"""
click.echo(click.style("Starting extract unique plugins.", fg="white"))
PluginMigration.extract_unique_plugins_to_file(input_file, output_file)
click.echo(click.style("Extract unique plugins completed.", fg="green"))
@click.command("install-plugins", help="Install plugins.")
@click.option(
"--input_file", prompt=True, help="The file to store the extracted unique identifiers.", default="plugins.jsonl"
)
@click.option(
"--output_file", prompt=True, help="The file to store the installed plugins.", default="installed_plugins.jsonl"
)
def install_plugins(input_file: str, output_file: str):
"""
Install plugins.
"""
click.echo(click.style("Starting install plugins.", fg="white"))
PluginMigration.install_plugins(input_file, output_file)
click.echo(click.style("Install plugins completed.", fg="green"))

View File

@ -134,60 +134,6 @@ class CodeExecutionSandboxConfig(BaseSettings):
)
class PluginConfig(BaseSettings):
"""
Plugin configs
"""
PLUGIN_DAEMON_URL: HttpUrl = Field(
description="Plugin API URL",
default="http://localhost:5002",
)
PLUGIN_DAEMON_KEY: str = Field(
description="Plugin API key",
default="plugin-api-key",
)
INNER_API_KEY_FOR_PLUGIN: str = Field(description="Inner api key for plugin", default="inner-api-key")
PLUGIN_REMOTE_INSTALL_HOST: str = Field(
description="Plugin Remote Install Host",
default="localhost",
)
PLUGIN_REMOTE_INSTALL_PORT: PositiveInt = Field(
description="Plugin Remote Install Port",
default=5003,
)
PLUGIN_MAX_PACKAGE_SIZE: PositiveInt = Field(
description="Maximum allowed size for plugin packages in bytes",
default=15728640,
)
PLUGIN_MAX_BUNDLE_SIZE: PositiveInt = Field(
description="Maximum allowed size for plugin bundles in bytes",
default=15728640 * 12,
)
class MarketplaceConfig(BaseSettings):
"""
Configuration for marketplace
"""
MARKETPLACE_ENABLED: bool = Field(
description="Enable or disable marketplace",
default=True,
)
MARKETPLACE_API_URL: HttpUrl = Field(
description="Marketplace API URL",
default="https://marketplace.dify.ai",
)
class EndpointConfig(BaseSettings):
"""
Configuration for various application endpoints and URLs
@ -214,10 +160,6 @@ class EndpointConfig(BaseSettings):
default="",
)
ENDPOINT_URL_TEMPLATE: str = Field(
description="Template url for endpoint plugin", default="http://localhost:5002/e/{hook_id}"
)
class FileAccessConfig(BaseSettings):
"""
@ -556,11 +498,6 @@ class AuthConfig(BaseSettings):
default=86400,
)
FORGOT_PASSWORD_LOCKOUT_DURATION: PositiveInt = Field(
description="Time (in seconds) a user must wait before retrying password reset after exceeding the rate limit.",
default=86400,
)
class ModerationConfig(BaseSettings):
"""
@ -851,8 +788,6 @@ class FeatureConfig(
AuthConfig, # Changed from OAuthConfig to AuthConfig
BillingConfig,
CodeExecutionSandboxConfig,
PluginConfig,
MarketplaceConfig,
DataSetConfig,
EndpointConfig,
FileAccessConfig,

View File

@ -1,5 +1,5 @@
from typing import Any, Literal, Optional
from urllib.parse import quote_plus
from urllib.parse import parse_qsl, quote_plus
from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt, computed_field
from pydantic_settings import BaseSettings
@ -166,14 +166,28 @@ class DatabaseConfig(BaseSettings):
default=False,
)
@computed_field
@computed_field # type: ignore[misc]
@property
def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]:
# Parse DB_EXTRAS for 'options'
db_extras_dict = dict(parse_qsl(self.DB_EXTRAS))
options = db_extras_dict.get("options", "")
# Always include timezone
timezone_opt = "-c timezone=UTC"
if options:
# Merge user options and timezone
merged_options = f"{options} {timezone_opt}"
else:
merged_options = timezone_opt
connect_args = {"options": merged_options}
return {
"pool_size": self.SQLALCHEMY_POOL_SIZE,
"max_overflow": self.SQLALCHEMY_MAX_OVERFLOW,
"pool_recycle": self.SQLALCHEMY_POOL_RECYCLE,
"pool_pre_ping": self.SQLALCHEMY_POOL_PRE_PING,
"connect_args": {"options": "-c timezone=UTC"},
"connect_args": connect_args,
}

View File

@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
CURRENT_VERSION: str = Field(
description="Dify version",
default="1.0.0",
default="0.15.8",
)
COMMIT_SHA: str = Field(

View File

@ -1,19 +1,9 @@
from contextvars import ContextVar
from threading import Lock
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.workflow.entities.variable_pool import VariablePool
tenant_id: ContextVar[str] = ContextVar("tenant_id")
workflow_variable_pool: ContextVar["VariablePool"] = ContextVar("workflow_variable_pool")
plugin_tool_providers: ContextVar[dict[str, "PluginToolProviderController"]] = ContextVar("plugin_tool_providers")
plugin_tool_providers_lock: ContextVar[Lock] = ContextVar("plugin_tool_providers_lock")
plugin_model_providers: ContextVar[list["PluginModelProviderEntity"] | None] = ContextVar("plugin_model_providers")
plugin_model_providers_lock: ContextVar[Lock] = ContextVar("plugin_model_providers_lock")

View File

@ -2,7 +2,7 @@ from flask import Blueprint
from libs.external_api import ExternalApi
from .app.app_import import AppImportApi, AppImportCheckDependenciesApi, AppImportConfirmApi
from .app.app_import import AppImportApi, AppImportConfirmApi
from .explore.audio import ChatAudioApi, ChatTextApi
from .explore.completion import ChatApi, ChatStopApi, CompletionApi, CompletionStopApi
from .explore.conversation import (
@ -40,7 +40,6 @@ api.add_resource(RemoteFileUploadApi, "/remote-files/upload")
# Import App
api.add_resource(AppImportApi, "/apps/imports")
api.add_resource(AppImportConfirmApi, "/apps/imports/<string:import_id>/confirm")
api.add_resource(AppImportCheckDependenciesApi, "/apps/imports/<string:app_id>/check-dependencies")
# Import other controllers
from . import admin, apikey, extension, feature, ping, setup, version
@ -167,15 +166,4 @@ api.add_resource(
from .tag import tags
# Import workspace controllers
from .workspace import (
account,
agent_providers,
endpoint,
load_balancing_config,
members,
model_providers,
models,
plugin,
tool_providers,
workspace,
)
from .workspace import account, load_balancing_config, members, model_providers, models, tool_providers, workspace

View File

@ -2,8 +2,6 @@ from functools import wraps
from flask import request
from flask_restful import Resource, reqparse # type: ignore
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound, Unauthorized
from configs import dify_config
@ -56,8 +54,7 @@ class InsertExploreAppListApi(Resource):
parser.add_argument("position", type=int, required=True, nullable=False, location="json")
args = parser.parse_args()
with Session(db.engine) as session:
app = session.execute(select(App).filter(App.id == args["app_id"])).scalar_one_or_none()
app = App.query.filter(App.id == args["app_id"]).first()
if not app:
raise NotFound(f"App '{args['app_id']}' is not found")
@ -73,10 +70,7 @@ class InsertExploreAppListApi(Resource):
privacy_policy = site.privacy_policy or args["privacy_policy"] or ""
custom_disclaimer = site.custom_disclaimer or args["custom_disclaimer"] or ""
with Session(db.engine) as session:
recommended_app = session.execute(
select(RecommendedApp).filter(RecommendedApp.app_id == args["app_id"])
).scalar_one_or_none()
recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first()
if not recommended_app:
recommended_app = RecommendedApp(
@ -116,27 +110,17 @@ class InsertExploreAppApi(Resource):
@only_edition_cloud
@admin_required
def delete(self, app_id):
with Session(db.engine) as session:
recommended_app = session.execute(
select(RecommendedApp).filter(RecommendedApp.app_id == str(app_id))
).scalar_one_or_none()
recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == str(app_id)).first()
if not recommended_app:
return {"result": "success"}, 204
with Session(db.engine) as session:
app = session.execute(select(App).filter(App.id == recommended_app.app_id)).scalar_one_or_none()
app = App.query.filter(App.id == recommended_app.app_id).first()
if app:
app.is_public = False
with Session(db.engine) as session:
installed_apps = session.execute(
select(InstalledApp).filter(
InstalledApp.app_id == recommended_app.app_id,
InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id,
)
).all()
installed_apps = InstalledApp.query.filter(
InstalledApp.app_id == recommended_app.app_id, InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id
).all()
for installed_app in installed_apps:
db.session.delete(installed_app)

View File

@ -3,8 +3,6 @@ from typing import Any
import flask_restful # type: ignore
from flask_login import current_user # type: ignore
from flask_restful import Resource, fields, marshal_with
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from extensions.ext_database import db
@ -28,16 +26,7 @@ api_key_list = {"data": fields.List(fields.Nested(api_key_fields), attribute="it
def _get_resource(resource_id, tenant_id, resource_model):
if resource_model == App:
with Session(db.engine) as session:
resource = session.execute(
select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id)
).scalar_one_or_none()
else:
with Session(db.engine) as session:
resource = session.execute(
select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id)
).scalar_one_or_none()
resource = resource_model.query.filter_by(id=resource_id, tenant_id=tenant_id).first()
if resource is None:
flask_restful.abort(404, message=f"{resource_model.__name__} not found.")

View File

@ -2,30 +2,28 @@ import uuid
from typing import cast
from flask_login import current_user # type: ignore
from flask_restful import Resource, inputs, marshal, marshal_with, reqparse # type: ignore
from flask_restful import (Resource, inputs, marshal, # type: ignore
marshal_with, reqparse)
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden, abort
from controllers.console import api
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
enterprise_license_required,
setup_required,
)
from controllers.console.wraps import (account_initialization_required,
cloud_edition_billing_resource_check,
enterprise_license_required,
setup_required)
from core.ops.ops_trace_manager import OpsTraceManager
from extensions.ext_database import db
from fields.app_fields import (
app_detail_fields,
app_detail_fields_with_site,
app_pagination_fields,
)
from fields.app_fields import (app_detail_fields, app_detail_fields_with_site,
app_pagination_fields)
from libs.login import login_required
from models import Account, App
from services.app_dsl_service import AppDslService, ImportMode
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
@ -67,7 +65,17 @@ class AppListApi(Resource):
if not app_pagination:
return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False}
return marshal(app_pagination, app_pagination_fields)
if FeatureService.get_system_features().webapp_auth.enabled:
app_ids = [str(app.id) for app in app_pagination.items]
res = EnterpriseService.WebAppAuth.batch_get_app_access_mode_by_id(app_ids=app_ids)
if len(res) != len(app_ids):
raise BadRequest("Invalid app id in webapp auth")
for app in app_pagination.items:
if str(app.id) in res:
app.access_mode = res[str(app.id)].access_mode
return marshal(app_pagination, app_pagination_fields), 200
@setup_required
@login_required
@ -111,6 +119,10 @@ class AppApi(Resource):
app_model = app_service.get_app(app_model)
if FeatureService.get_system_features().webapp_auth.enabled:
app_setting = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=str(app_model.id))
app_model.access_mode = app_setting.access_mode
return app_model
@setup_required

View File

@ -5,17 +5,17 @@ from flask_restful import Resource, marshal_with, reqparse # type: ignore
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import (
account_initialization_required,
setup_required,
)
from extensions.ext_database import db
from fields.app_fields import app_import_check_dependencies_fields, app_import_fields
from fields.app_fields import app_import_fields
from libs.login import login_required
from models import Account
from models.model import App
from services.app_dsl_service import AppDslService, ImportStatus
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
class AppImportApi(Resource):
@ -58,7 +58,9 @@ class AppImportApi(Resource):
app_id=args.get("app_id"),
)
session.commit()
if result.app_id and FeatureService.get_system_features().webapp_auth.enabled:
# update web app setting as private
EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, "private")
# Return appropriate status code based on result
status = result.status
if status == ImportStatus.FAILED.value:
@ -90,20 +92,3 @@ class AppImportConfirmApi(Resource):
if result.status == ImportStatus.FAILED.value:
return result.model_dump(mode="json"), 400
return result.model_dump(mode="json"), 200
class AppImportCheckDependenciesApi(Resource):
@setup_required
@login_required
@get_app_model
@account_initialization_required
@marshal_with(app_import_check_dependencies_fields)
def get(self, app_model: App):
if not current_user.is_editor:
raise Forbidden()
with Session(db.engine) as session:
import_service = AppDslService(session)
result = import_service.check_dependencies(app_model=app_model)
return result.model_dump(mode="json"), 200

View File

@ -2,7 +2,6 @@ from datetime import UTC, datetime
from flask_login import current_user # type: ignore
from flask_restful import Resource, marshal_with, reqparse # type: ignore
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, NotFound
from constants.languages import supported_language
@ -51,37 +50,33 @@ class AppSite(Resource):
if not current_user.is_editor:
raise Forbidden()
with Session(db.engine) as session:
site = session.query(Site).filter(Site.app_id == app_model.id).first()
site = Site.query.filter(Site.app_id == app_model.id).one_or_404()
if not site:
raise NotFound
for attr_name in [
"title",
"icon_type",
"icon",
"icon_background",
"description",
"default_language",
"chat_color_theme",
"chat_color_theme_inverted",
"customize_domain",
"copyright",
"privacy_policy",
"custom_disclaimer",
"customize_token_strategy",
"prompt_public",
"show_workflow_steps",
"use_icon_as_answer_icon",
]:
value = args.get(attr_name)
if value is not None:
setattr(site, attr_name, value)
for attr_name in [
"title",
"icon_type",
"icon",
"icon_background",
"description",
"default_language",
"chat_color_theme",
"chat_color_theme_inverted",
"customize_domain",
"copyright",
"privacy_policy",
"custom_disclaimer",
"customize_token_strategy",
"prompt_public",
"show_workflow_steps",
"use_icon_as_answer_icon",
]:
value = args.get(attr_name)
if value is not None:
setattr(site, attr_name, value)
site.updated_by = current_user.id
site.updated_at = datetime.now(UTC).replace(tzinfo=None)
session.commit()
site.updated_by = current_user.id
site.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit()
return site

View File

@ -20,7 +20,6 @@ from libs import helper
from libs.helper import TimestampField, uuid_value
from libs.login import current_user, login_required
from models import App
from models.account import Account
from models.model import AppMode
from services.app_generate_service import AppGenerateService
from services.errors.app import WorkflowHashNotEqualError
@ -97,9 +96,6 @@ class DraftWorkflowApi(Resource):
else:
abort(415)
if not isinstance(current_user, Account):
raise Forbidden()
workflow_service = WorkflowService()
try:
@ -143,9 +139,6 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
if not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, location="json")
parser.add_argument("query", type=str, required=True, location="json", default="")
@ -167,7 +160,7 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
raise ConversationCompletedError()
except ValueError as e:
raise e
except Exception:
except Exception as e:
logging.exception("internal server error.")
raise InternalServerError()
@ -185,9 +178,6 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
if not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, location="json")
args = parser.parse_args()
@ -204,7 +194,7 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
raise ConversationCompletedError()
except ValueError as e:
raise e
except Exception:
except Exception as e:
logging.exception("internal server error.")
raise InternalServerError()
@ -222,9 +212,6 @@ class WorkflowDraftRunIterationNodeApi(Resource):
if not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, location="json")
args = parser.parse_args()
@ -241,7 +228,7 @@ class WorkflowDraftRunIterationNodeApi(Resource):
raise ConversationCompletedError()
except ValueError as e:
raise e
except Exception:
except Exception as e:
logging.exception("internal server error.")
raise InternalServerError()
@ -259,9 +246,6 @@ class DraftWorkflowRunApi(Resource):
if not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("files", type=list, required=False, location="json")
@ -310,20 +294,13 @@ class DraftWorkflowNodeRunApi(Resource):
if not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args()
inputs = args.get("inputs")
if inputs == None:
raise ValueError("missing inputs")
workflow_service = WorkflowService()
workflow_node_execution = workflow_service.run_draft_workflow_node(
app_model=app_model, node_id=node_id, user_inputs=inputs, account=current_user
app_model=app_model, node_id=node_id, user_inputs=args.get("inputs"), account=current_user
)
return workflow_node_execution
@ -362,9 +339,6 @@ class PublishedWorkflowApi(Resource):
if not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account):
raise Forbidden()
workflow_service = WorkflowService()
workflow = workflow_service.publish_workflow(app_model=app_model, account=current_user)
@ -402,17 +376,12 @@ class DefaultBlockConfigApi(Resource):
if not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("q", type=str, location="args")
args = parser.parse_args()
q = args.get("q")
filters = None
if q:
if args.get("q"):
try:
filters = json.loads(args.get("q", ""))
except json.JSONDecodeError:
@ -438,9 +407,6 @@ class ConvertToWorkflowApi(Resource):
if not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account):
raise Forbidden()
if request.data:
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=False, nullable=True, location="json")

View File

@ -3,20 +3,16 @@ import secrets
from flask import request
from flask_restful import Resource, reqparse # type: ignore
from sqlalchemy import select
from sqlalchemy.orm import Session
from constants.languages import languages
from controllers.console import api
from controllers.console.auth.error import (
EmailCodeError,
EmailPasswordResetLimitError,
InvalidEmailError,
InvalidTokenError,
PasswordMismatchError,
)
from controllers.console.error import AccountInFreezeError, AccountNotFound, EmailSendIpLimitError
from controllers.console.wraps import setup_required
from controllers.console.auth.error import (EmailCodeError, InvalidEmailError,
InvalidTokenError,
PasswordMismatchError)
from controllers.console.error import (AccountInFreezeError, AccountNotFound,
EmailSendIpLimitError)
from controllers.console.wraps import (email_password_login_enabled,
setup_required)
from events.tenant_event import tenant_was_created
from extensions.ext_database import db
from libs.helper import email, extract_remote_ip
@ -24,12 +20,14 @@ from libs.password import hash_password, valid_password
from models.account import Account
from services.account_service import AccountService, TenantService
from services.errors.account import AccountRegisterError
from services.errors.workspace import WorkSpaceNotAllowedCreateError
from services.errors.workspace import (WorkSpaceNotAllowedCreateError,
WorkspacesLimitExceededError)
from services.feature_service import FeatureService
class ForgotPasswordSendEmailApi(Resource):
@setup_required
@email_password_login_enabled
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=email, required=True, location="json")
@ -45,8 +43,7 @@ class ForgotPasswordSendEmailApi(Resource):
else:
language = "en-US"
with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
account = Account.query.filter_by(email=args["email"]).first()
token = None
if account is None:
if FeatureService.get_system_features().is_allow_register:
@ -62,6 +59,7 @@ class ForgotPasswordSendEmailApi(Resource):
class ForgotPasswordCheckApi(Resource):
@setup_required
@email_password_login_enabled
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=str, required=True, location="json")
@ -71,10 +69,6 @@ class ForgotPasswordCheckApi(Resource):
user_email = args["email"]
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args["email"])
if is_forgot_password_error_rate_limit:
raise EmailPasswordResetLimitError()
token_data = AccountService.get_reset_password_data(args["token"])
if token_data is None:
raise InvalidTokenError()
@ -83,15 +77,22 @@ class ForgotPasswordCheckApi(Resource):
raise InvalidEmailError()
if args["code"] != token_data.get("code"):
AccountService.add_forgot_password_error_rate_limit(args["email"])
raise EmailCodeError()
AccountService.reset_forgot_password_error_rate_limit(args["email"])
return {"is_valid": True, "email": token_data.get("email")}
# Verified, revoke the first token
AccountService.revoke_reset_password_token(args["token"])
# Refresh token data by generating a new token
_, new_token = AccountService.generate_reset_password_token(
user_email, code=args["code"], additional_data={"phase": "reset"}
)
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
class ForgotPasswordResetApi(Resource):
@setup_required
@email_password_login_enabled
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
@ -110,6 +111,9 @@ class ForgotPasswordResetApi(Resource):
if reset_data is None:
raise InvalidTokenError()
# Must use token in reset phase
if reset_data.get("phase", "") != "reset":
raise InvalidTokenError()
AccountService.revoke_reset_password_token(token)
@ -119,8 +123,7 @@ class ForgotPasswordResetApi(Resource):
password_hashed = hash_password(new_password, salt)
base64_password_hashed = base64.b64encode(password_hashed).decode()
with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=reset_data.get("email"))).scalar_one_or_none()
account = Account.query.filter_by(email=reset_data.get("email")).first()
if account:
account.password = base64_password_hashed
account.password_salt = base64_salt
@ -141,8 +144,10 @@ class ForgotPasswordResetApi(Resource):
)
except WorkSpaceNotAllowedCreateError:
pass
except AccountRegisterError:
except AccountRegisterError as are:
raise AccountInFreezeError()
except WorkspacesLimitExceededError:
pass
return {"result": "success"}

View File

@ -21,8 +21,9 @@ from controllers.console.error import (
AccountNotFound,
EmailSendIpLimitError,
NotAllowedCreateWorkspace,
WorkspacesLimitExceeded,
)
from controllers.console.wraps import setup_required
from controllers.console.wraps import email_password_login_enabled, setup_required
from events.tenant_event import tenant_was_created
from libs.helper import email, extract_remote_ip
from libs.password import valid_password
@ -30,7 +31,7 @@ from models.account import Account
from services.account_service import AccountService, RegisterService, TenantService
from services.billing_service import BillingService
from services.errors.account import AccountRegisterError
from services.errors.workspace import WorkSpaceNotAllowedCreateError
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError
from services.feature_service import FeatureService
@ -38,6 +39,7 @@ class LoginApi(Resource):
"""Resource for user login."""
@setup_required
@email_password_login_enabled
def post(self):
"""Authenticate user and login."""
parser = reqparse.RequestParser()
@ -87,10 +89,15 @@ class LoginApi(Resource):
# SELF_HOSTED only have one workspace
tenants = TenantService.get_join_tenants(account)
if len(tenants) == 0:
return {
"result": "fail",
"data": "workspace not found, please contact system admin to invite you to join in a workspace",
}
system_features = FeatureService.get_system_features()
if system_features.is_allow_create_workspace and not system_features.license.workspaces.is_available():
raise WorkspacesLimitExceeded()
else:
return {
"result": "fail",
"data": "workspace not found, please contact system admin to invite you to join in a workspace",
}
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(args["email"])
@ -110,6 +117,7 @@ class LogoutApi(Resource):
class ResetPasswordSendEmailApi(Resource):
@setup_required
@email_password_login_enabled
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=email, required=True, location="json")
@ -196,6 +204,9 @@ class EmailCodeLoginApi(Resource):
if account:
tenant = TenantService.get_join_tenants(account)
if not tenant:
workspaces = FeatureService.get_system_features().license.workspaces
if not workspaces.is_available():
raise WorkspacesLimitExceeded()
if not FeatureService.get_system_features().is_allow_create_workspace:
raise NotAllowedCreateWorkspace()
else:
@ -213,6 +224,8 @@ class EmailCodeLoginApi(Resource):
return NotAllowedCreateWorkspace()
except AccountRegisterError as are:
raise AccountInFreezeError()
except WorkspacesLimitExceededError:
raise WorkspacesLimitExceeded()
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(args["email"])
return {"result": "success", "data": token_pair.model_dump()}

View File

@ -5,8 +5,6 @@ from typing import Optional
import requests
from flask import current_app, redirect, request
from flask_restful import Resource # type: ignore
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import Unauthorized
from configs import dify_config
@ -137,8 +135,7 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) ->
account: Optional[Account] = Account.get_by_openid(provider, user_info.id)
if not account:
with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=user_info.email)).scalar_one_or_none()
account = Account.query.filter_by(email=user_info.email).first()
return account

View File

@ -4,8 +4,6 @@ import json
from flask import request
from flask_login import current_user # type: ignore
from flask_restful import Resource, marshal_with, reqparse # type: ignore
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from controllers.console import api
@ -78,10 +76,7 @@ class DataSourceApi(Resource):
def patch(self, binding_id, action):
binding_id = str(binding_id)
action = str(action)
with Session(db.engine) as session:
data_source_binding = session.execute(
select(DataSourceOauthBinding).filter_by(id=binding_id)
).scalar_one_or_none()
data_source_binding = DataSourceOauthBinding.query.filter_by(id=binding_id).first()
if data_source_binding is None:
raise NotFound("Data source binding not found.")
# enable binding
@ -113,53 +108,47 @@ class DataSourceNotionListApi(Resource):
def get(self):
dataset_id = request.args.get("dataset_id", default=None, type=str)
exist_page_ids = []
with Session(db.engine) as session:
# import notion in the exist dataset
if dataset_id:
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
if dataset.data_source_type != "notion_import":
raise ValueError("Dataset is not notion type.")
documents = session.execute(
select(Document).filter_by(
dataset_id=dataset_id,
tenant_id=current_user.current_tenant_id,
data_source_type="notion_import",
enabled=True,
)
).all()
if documents:
for document in documents:
data_source_info = json.loads(document.data_source_info)
exist_page_ids.append(data_source_info["notion_page_id"])
# get all authorized pages
data_source_bindings = session.scalars(
select(DataSourceOauthBinding).filter_by(
tenant_id=current_user.current_tenant_id, provider="notion", disabled=False
)
# import notion in the exist dataset
if dataset_id:
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
if dataset.data_source_type != "notion_import":
raise ValueError("Dataset is not notion type.")
documents = Document.query.filter_by(
dataset_id=dataset_id,
tenant_id=current_user.current_tenant_id,
data_source_type="notion_import",
enabled=True,
).all()
if not data_source_bindings:
return {"notion_info": []}, 200
pre_import_info_list = []
for data_source_binding in data_source_bindings:
source_info = data_source_binding.source_info
pages = source_info["pages"]
# Filter out already bound pages
for page in pages:
if page["page_id"] in exist_page_ids:
page["is_bound"] = True
else:
page["is_bound"] = False
pre_import_info = {
"workspace_name": source_info["workspace_name"],
"workspace_icon": source_info["workspace_icon"],
"workspace_id": source_info["workspace_id"],
"pages": pages,
}
pre_import_info_list.append(pre_import_info)
return {"notion_info": pre_import_info_list}, 200
if documents:
for document in documents:
data_source_info = json.loads(document.data_source_info)
exist_page_ids.append(data_source_info["notion_page_id"])
# get all authorized pages
data_source_bindings = DataSourceOauthBinding.query.filter_by(
tenant_id=current_user.current_tenant_id, provider="notion", disabled=False
).all()
if not data_source_bindings:
return {"notion_info": []}, 200
pre_import_info_list = []
for data_source_binding in data_source_bindings:
source_info = data_source_binding.source_info
pages = source_info["pages"]
# Filter out already bound pages
for page in pages:
if page["page_id"] in exist_page_ids:
page["is_bound"] = True
else:
page["is_bound"] = False
pre_import_info = {
"workspace_name": source_info["workspace_name"],
"workspace_icon": source_info["workspace_icon"],
"workspace_id": source_info["workspace_id"],
"pages": pages,
}
pre_import_info_list.append(pre_import_info)
return {"notion_info": pre_import_info_list}, 200
class DataSourceNotionApi(Resource):
@ -169,17 +158,14 @@ class DataSourceNotionApi(Resource):
def get(self, workspace_id, page_id, page_type):
workspace_id = str(workspace_id)
page_id = str(page_id)
with Session(db.engine) as session:
data_source_binding = session.execute(
select(DataSourceOauthBinding).filter(
db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
)
)
).scalar_one_or_none()
data_source_binding = DataSourceOauthBinding.query.filter(
db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
)
).first()
if not data_source_binding:
raise NotFound("Data source binding not found.")

View File

@ -620,7 +620,6 @@ class DatasetRetrievalSettingApi(Resource):
match vector_type:
case (
VectorType.RELYT
| VectorType.PGVECTOR
| VectorType.TIDB_VECTOR
| VectorType.CHROMA
| VectorType.TENCENT

View File

@ -7,6 +7,7 @@ from flask import request
from flask_login import current_user # type: ignore
from flask_restful import Resource, fields, marshal, marshal_with, reqparse # type: ignore
from sqlalchemy import asc, desc
from transformers.hf_argparser import string_to_bool # type: ignore
from werkzeug.exceptions import Forbidden, NotFound
import services
@ -39,7 +40,6 @@ from core.indexing_runner import IndexingRunner
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.plugin.manager.exc import PluginDaemonClientSideError
from core.rag.extractor.entity.extract_setting import ExtractSetting
from extensions.ext_database import db
from extensions.ext_redis import redis_client
@ -150,20 +150,8 @@ class DatasetDocumentListApi(Resource):
sort = request.args.get("sort", default="-created_at", type=str)
# "yes", "true", "t", "y", "1" convert to True, while others convert to False.
try:
fetch_val = request.args.get("fetch", default="false")
if isinstance(fetch_val, bool):
fetch = fetch_val
else:
if fetch_val.lower() in ("yes", "true", "t", "y", "1"):
fetch = True
elif fetch_val.lower() in ("no", "false", "f", "n", "0"):
fetch = False
else:
raise ArgumentTypeError(
f"Truthy value expected: got {fetch_val} but expected one of yes/no, true/false, t/f, y/n, 1/0 "
f"(case insensitive)."
)
except (ArgumentTypeError, ValueError, Exception):
fetch = string_to_bool(request.args.get("fetch", default="false"))
except (ArgumentTypeError, ValueError, Exception) as e:
fetch = False
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
@ -322,7 +310,7 @@ class DatasetInitApi(Resource):
@cloud_edition_billing_resource_check("vector_space")
def post(self):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
if not current_user.is_dataset_editor:
raise Forbidden()
parser = reqparse.RequestParser()
@ -441,8 +429,6 @@ class DocumentIndexingEstimateApi(DocumentResource):
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except PluginDaemonClientSideError as ex:
raise ProviderNotInitializeError(ex.description)
except Exception as e:
raise IndexingEstimateError(str(e))
@ -543,8 +529,6 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except PluginDaemonClientSideError as ex:
raise ProviderNotInitializeError(ex.description)
except Exception as e:
raise IndexingEstimateError(str(e))
@ -700,7 +684,7 @@ class DocumentProcessingApi(DocumentResource):
document = self.get_document(dataset_id, document_id)
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
if not current_user.is_dataset_editor:
raise Forbidden()
if action == "pause":
@ -764,7 +748,7 @@ class DocumentMetadataApi(DocumentResource):
doc_metadata = req_data.get("doc_metadata")
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
if not current_user.is_dataset_editor:
raise Forbidden()
if doc_type is None or doc_metadata is None:

View File

@ -122,7 +122,7 @@ class DatasetDocumentSegmentListApi(Resource):
segment_ids = request.args.getlist("segment_id")
# The role of the current user in the ta table must be admin or owner
if not current_user.is_editor:
if not current_user.is_dataset_editor:
raise Forbidden()
try:
DatasetService.check_dataset_permission(dataset, current_user)
@ -149,7 +149,7 @@ class DatasetDocumentSegmentApi(Resource):
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
if not current_user.is_dataset_editor:
raise Forbidden()
try:
@ -202,7 +202,7 @@ class DatasetDocumentSegmentAddApi(Resource):
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
if not current_user.is_editor:
if not current_user.is_dataset_editor:
raise Forbidden()
# check embedding model setting
if dataset.indexing_technique == "high_quality":
@ -277,7 +277,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
if not segment:
raise NotFound("Segment not found.")
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
if not current_user.is_dataset_editor:
raise Forbidden()
try:
DatasetService.check_dataset_permission(dataset, current_user)
@ -320,7 +320,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
if not segment:
raise NotFound("Segment not found.")
# The role of the current user in the ta table must be admin or owner
if not current_user.is_editor:
if not current_user.is_dataset_editor:
raise Forbidden()
try:
DatasetService.check_dataset_permission(dataset, current_user)
@ -420,7 +420,7 @@ class ChildChunkAddApi(Resource):
).first()
if not segment:
raise NotFound("Segment not found.")
if not current_user.is_editor:
if not current_user.is_dataset_editor:
raise Forbidden()
# check embedding model setting
if dataset.indexing_technique == "high_quality":
@ -520,7 +520,7 @@ class ChildChunkAddApi(Resource):
if not segment:
raise NotFound("Segment not found.")
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
if not current_user.is_dataset_editor:
raise Forbidden()
try:
DatasetService.check_dataset_permission(dataset, current_user)
@ -570,7 +570,7 @@ class ChildChunkUpdateApi(Resource):
if not child_chunk:
raise NotFound("Child chunk not found.")
# The role of the current user in the ta table must be admin or owner
if not current_user.is_editor:
if not current_user.is_dataset_editor:
raise Forbidden()
try:
DatasetService.check_dataset_permission(dataset, current_user)
@ -614,7 +614,7 @@ class ChildChunkUpdateApi(Resource):
if not child_chunk:
raise NotFound("Child chunk not found.")
# The role of the current user in the ta table must be admin or owner
if not current_user.is_editor:
if not current_user.is_dataset_editor:
raise Forbidden()
try:
DatasetService.check_dataset_permission(dataset, current_user)

View File

@ -46,6 +46,18 @@ class NotAllowedCreateWorkspace(BaseHTTPException):
code = 400
class WorkspaceMembersLimitExceeded(BaseHTTPException):
error_code = "limit_exceeded"
description = "Unable to add member because the maximum workspace's member limit was exceeded"
code = 400
class WorkspacesLimitExceeded(BaseHTTPException):
error_code = "limit_exceeded"
description = "Unable to create workspace because the maximum workspace limit was exceeded"
code = 400
class AccountBannedError(BaseHTTPException):
error_code = "account_banned"
description = "Account is banned."

View File

@ -23,3 +23,9 @@ class AppSuggestedQuestionsAfterAnswerDisabledError(BaseHTTPException):
error_code = "app_suggested_questions_after_answer_disabled"
description = "Function Suggested questions after answer disabled."
code = 403
class AppAccessDeniedError(BaseHTTPException):
error_code = "access_denied"
description = "App access denied."
code = 403

View File

@ -1,20 +1,26 @@
import logging
from datetime import UTC, datetime
from typing import Any
from flask import request
from flask_login import current_user # type: ignore
from flask_restful import Resource, inputs, marshal_with, reqparse # type: ignore
from flask_restful import (Resource, inputs, marshal_with, # type: ignore
reqparse)
from sqlalchemy import and_
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
from controllers.console import api
from controllers.console.explore.wraps import InstalledAppResource
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from controllers.console.wraps import (account_initialization_required,
cloud_edition_billing_resource_check)
from extensions.ext_database import db
from fields.installed_app_fields import installed_app_list_fields
from libs.login import login_required
from models import App, InstalledApp, RecommendedApp
from services.account_service import TenantService
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
class InstalledAppsListApi(Resource):
@ -48,6 +54,30 @@ class InstalledAppsListApi(Resource):
for installed_app in installed_apps
if installed_app.app is not None
]
# filter out apps that user doesn't have access to
if FeatureService.get_system_features().webapp_auth.enabled:
user_id = current_user.id
res = []
app_ids = [installed_app["app"].id for installed_app in installed_app_list]
webapp_settings = EnterpriseService.WebAppAuth.batch_get_app_access_mode_by_id(app_ids)
for installed_app in installed_app_list:
webapp_setting = webapp_settings.get(installed_app["app"].id)
if not webapp_setting:
continue
if webapp_setting.access_mode == "sso_verified":
continue
app_code = AppService.get_app_code_by_id(str(installed_app["app"].id))
if EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(
user_id=user_id,
app_code=app_code,
):
res.append(installed_app)
installed_app_list = res
logging.info(
f"installed_app_list: {installed_app_list}, user_id: {user_id}"
)
installed_app_list.sort(
key=lambda app: (
-app["is_pinned"],

View File

@ -4,10 +4,14 @@ from flask_login import current_user # type: ignore
from flask_restful import Resource # type: ignore
from werkzeug.exceptions import NotFound
from controllers.console.explore.error import AppAccessDeniedError
from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db
from libs.login import login_required
from models import InstalledApp
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
def installed_app_required(view=None):
@ -48,6 +52,30 @@ def installed_app_required(view=None):
return decorator
def user_allowed_to_access_app(view=None):
def decorator(view):
@wraps(view)
def decorated(installed_app: InstalledApp, *args, **kwargs):
feature = FeatureService.get_system_features()
if feature.webapp_auth.enabled:
app_id = installed_app.app_id
app_code = AppService.get_app_code_by_id(app_id)
res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(
user_id=str(current_user.id),
app_code=app_code,
)
if not res:
raise AppAccessDeniedError()
return view(installed_app, *args, **kwargs)
return decorated
if view:
return decorator(view)
return decorator
class InstalledAppResource(Resource):
# must be reversed if there are multiple decorators
method_decorators = [installed_app_required, account_initialization_required, login_required]
method_decorators = [user_allowed_to_access_app, installed_app_required, account_initialization_required, login_required]

View File

@ -2,11 +2,8 @@ import os
from flask import session
from flask_restful import Resource, reqparse # type: ignore
from sqlalchemy import select
from sqlalchemy.orm import Session
from configs import dify_config
from extensions.ext_database import db
from libs.helper import StrLen
from models.model import DifySetup
from services.account_service import TenantService
@ -45,11 +42,7 @@ class InitValidateAPI(Resource):
def get_init_validate_status():
if dify_config.EDITION == "SELF_HOSTED":
if os.environ.get("INIT_PASSWORD"):
if session.get("is_init_validated"):
return True
with Session(db.engine) as db_session:
return db_session.execute(select(DifySetup)).scalar_one_or_none()
return session.get("is_init_validated") or DifySetup.query.first()
return True

View File

@ -4,7 +4,7 @@ from flask_restful import Resource, reqparse # type: ignore
from configs import dify_config
from libs.helper import StrLen, email, extract_remote_ip
from libs.password import valid_password
from models.model import DifySetup, db
from models.model import DifySetup
from services.account_service import RegisterService, TenantService
from . import api
@ -52,9 +52,8 @@ class SetupApi(Resource):
def get_setup_status():
if dify_config.EDITION == "SELF_HOSTED":
return db.session.query(DifySetup).first()
else:
return True
return DifySetup.query.first()
return True
api.add_resource(SetupApi, "/setup")

View File

@ -1,56 +0,0 @@
from functools import wraps
from flask_login import current_user # type: ignore
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from extensions.ext_database import db
from models.account import TenantPluginPermission
def plugin_permission_required(
install_required: bool = False,
debug_required: bool = False,
):
def interceptor(view):
@wraps(view)
def decorated(*args, **kwargs):
user = current_user
tenant_id = user.current_tenant_id
with Session(db.engine) as session:
permission = (
session.query(TenantPluginPermission)
.filter(
TenantPluginPermission.tenant_id == tenant_id,
)
.first()
)
if not permission:
# no permission set, allow access for everyone
return view(*args, **kwargs)
if install_required:
if permission.install_permission == TenantPluginPermission.InstallPermission.NOBODY:
raise Forbidden()
if permission.install_permission == TenantPluginPermission.InstallPermission.ADMINS:
if not user.is_admin_or_owner:
raise Forbidden()
if permission.install_permission == TenantPluginPermission.InstallPermission.EVERYONE:
pass
if debug_required:
if permission.debug_permission == TenantPluginPermission.DebugPermission.NOBODY:
raise Forbidden()
if permission.debug_permission == TenantPluginPermission.DebugPermission.ADMINS:
if not user.is_admin_or_owner:
raise Forbidden()
if permission.debug_permission == TenantPluginPermission.DebugPermission.EVERYONE:
pass
return view(*args, **kwargs)
return decorated
return interceptor

View File

@ -1,36 +0,0 @@
from flask_login import current_user # type: ignore
from flask_restful import Resource # type: ignore
from controllers.console import api
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from libs.login import login_required
from services.agent_service import AgentService
class AgentProviderListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
user = current_user
user_id = user.id
tenant_id = user.current_tenant_id
return jsonable_encoder(AgentService.list_agent_providers(user_id, tenant_id))
class AgentProviderApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider_name: str):
user = current_user
user_id = user.id
tenant_id = user.current_tenant_id
return jsonable_encoder(AgentService.get_agent_provider(user_id, tenant_id, provider_name))
api.add_resource(AgentProviderListApi, "/workspaces/current/agent-providers")
api.add_resource(AgentProviderApi, "/workspaces/current/agent-provider/<path:provider_name>")

View File

@ -1,205 +0,0 @@
from flask_login import current_user # type: ignore
from flask_restful import Resource, reqparse # type: ignore
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from libs.login import login_required
from services.plugin.endpoint_service import EndpointService
class EndpointCreateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
user = current_user
if not user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("plugin_unique_identifier", type=str, required=True)
parser.add_argument("settings", type=dict, required=True)
parser.add_argument("name", type=str, required=True)
args = parser.parse_args()
plugin_unique_identifier = args["plugin_unique_identifier"]
settings = args["settings"]
name = args["name"]
return {
"success": EndpointService.create_endpoint(
tenant_id=user.current_tenant_id,
user_id=user.id,
plugin_unique_identifier=plugin_unique_identifier,
name=name,
settings=settings,
)
}
class EndpointListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
user = current_user
parser = reqparse.RequestParser()
parser.add_argument("page", type=int, required=True, location="args")
parser.add_argument("page_size", type=int, required=True, location="args")
args = parser.parse_args()
page = args["page"]
page_size = args["page_size"]
return jsonable_encoder(
{
"endpoints": EndpointService.list_endpoints(
tenant_id=user.current_tenant_id,
user_id=user.id,
page=page,
page_size=page_size,
)
}
)
class EndpointListForSinglePluginApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
user = current_user
parser = reqparse.RequestParser()
parser.add_argument("page", type=int, required=True, location="args")
parser.add_argument("page_size", type=int, required=True, location="args")
parser.add_argument("plugin_id", type=str, required=True, location="args")
args = parser.parse_args()
page = args["page"]
page_size = args["page_size"]
plugin_id = args["plugin_id"]
return jsonable_encoder(
{
"endpoints": EndpointService.list_endpoints_for_single_plugin(
tenant_id=user.current_tenant_id,
user_id=user.id,
plugin_id=plugin_id,
page=page,
page_size=page_size,
)
}
)
class EndpointDeleteApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
user = current_user
parser = reqparse.RequestParser()
parser.add_argument("endpoint_id", type=str, required=True)
args = parser.parse_args()
if not user.is_admin_or_owner:
raise Forbidden()
endpoint_id = args["endpoint_id"]
return {
"success": EndpointService.delete_endpoint(
tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id
)
}
class EndpointUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
user = current_user
parser = reqparse.RequestParser()
parser.add_argument("endpoint_id", type=str, required=True)
parser.add_argument("settings", type=dict, required=True)
parser.add_argument("name", type=str, required=True)
args = parser.parse_args()
endpoint_id = args["endpoint_id"]
settings = args["settings"]
name = args["name"]
if not user.is_admin_or_owner:
raise Forbidden()
return {
"success": EndpointService.update_endpoint(
tenant_id=user.current_tenant_id,
user_id=user.id,
endpoint_id=endpoint_id,
name=name,
settings=settings,
)
}
class EndpointEnableApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
user = current_user
parser = reqparse.RequestParser()
parser.add_argument("endpoint_id", type=str, required=True)
args = parser.parse_args()
endpoint_id = args["endpoint_id"]
if not user.is_admin_or_owner:
raise Forbidden()
return {
"success": EndpointService.enable_endpoint(
tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id
)
}
class EndpointDisableApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
user = current_user
parser = reqparse.RequestParser()
parser.add_argument("endpoint_id", type=str, required=True)
args = parser.parse_args()
endpoint_id = args["endpoint_id"]
if not user.is_admin_or_owner:
raise Forbidden()
return {
"success": EndpointService.disable_endpoint(
tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id
)
}
api.add_resource(EndpointCreateApi, "/workspaces/current/endpoints/create")
api.add_resource(EndpointListApi, "/workspaces/current/endpoints/list")
api.add_resource(EndpointListForSinglePluginApi, "/workspaces/current/endpoints/list/plugin")
api.add_resource(EndpointDeleteApi, "/workspaces/current/endpoints/delete")
api.add_resource(EndpointUpdateApi, "/workspaces/current/endpoints/update")
api.add_resource(EndpointEnableApi, "/workspaces/current/endpoints/enable")
api.add_resource(EndpointDisableApi, "/workspaces/current/endpoints/disable")

View File

@ -112,10 +112,10 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
# Load Balancing Config
api.add_resource(
LoadBalancingCredentialsValidateApi,
"/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/credentials-validate",
"/workspaces/current/model-providers/<string:provider>/models/load-balancing-configs/credentials-validate",
)
api.add_resource(
LoadBalancingConfigCredentialsValidateApi,
"/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate",
"/workspaces/current/model-providers/<string:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate",
)

View File

@ -6,6 +6,7 @@ from flask_restful import Resource, abort, marshal_with, reqparse # type: ignor
import services
from configs import dify_config
from controllers.console import api
from controllers.console.error import WorkspaceMembersLimitExceeded
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
@ -17,6 +18,7 @@ from libs.login import login_required
from models.account import Account, TenantAccountRole
from services.account_service import RegisterService, TenantService
from services.errors.account import AccountAlreadyInTenantError
from services.feature_service import FeatureService
class MemberListApi(Resource):
@ -54,6 +56,12 @@ class MemberInviteEmailApi(Resource):
inviter = current_user
invitation_results = []
console_web_url = dify_config.CONSOLE_WEB_URL
workspace_members = FeatureService.get_features(tenant_id=inviter.current_tenant.id).workspace_members
if not workspace_members.is_available(len(invitee_emails)):
raise WorkspaceMembersLimitExceeded()
for invitee_email in invitee_emails:
try:
token = RegisterService.invite_new_member(
@ -71,7 +79,6 @@ class MemberInviteEmailApi(Resource):
invitation_results.append(
{"status": "success", "email": invitee_email, "url": f"{console_web_url}/signin"}
)
break
except Exception as e:
invitation_results.append({"status": "failed", "email": invitee_email, "message": str(e)})

View File

@ -79,7 +79,7 @@ class ModelProviderValidateApi(Resource):
response = {"result": "success" if result else "error"}
if not result:
response["error"] = error or "Unknown error"
response["error"] = error
return response
@ -125,10 +125,9 @@ class ModelProviderIconApi(Resource):
Get model provider icon
"""
def get(self, tenant_id: str, provider: str, icon_type: str, lang: str):
def get(self, provider: str, icon_type: str, lang: str):
model_provider_service = ModelProviderService()
icon, mimetype = model_provider_service.get_model_provider_icon(
tenant_id=tenant_id,
provider=provider,
icon_type=icon_type,
lang=lang,
@ -184,17 +183,53 @@ class ModelProviderPaymentCheckoutUrlApi(Resource):
return data
class ModelProviderFreeQuotaSubmitApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider: str):
model_provider_service = ModelProviderService()
result = model_provider_service.free_quota_submit(tenant_id=current_user.current_tenant_id, provider=provider)
return result
class ModelProviderFreeQuotaQualificationVerifyApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str):
parser = reqparse.RequestParser()
parser.add_argument("token", type=str, required=False, nullable=True, location="args")
args = parser.parse_args()
model_provider_service = ModelProviderService()
result = model_provider_service.free_quota_qualification_verify(
tenant_id=current_user.current_tenant_id, provider=provider, token=args["token"]
)
return result
api.add_resource(ModelProviderListApi, "/workspaces/current/model-providers")
api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-providers/<path:provider>/credentials")
api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers/<path:provider>/credentials/validate")
api.add_resource(ModelProviderApi, "/workspaces/current/model-providers/<path:provider>")
api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-providers/<string:provider>/credentials")
api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers/<string:provider>/credentials/validate")
api.add_resource(ModelProviderApi, "/workspaces/current/model-providers/<string:provider>")
api.add_resource(
ModelProviderIconApi, "/workspaces/current/model-providers/<string:provider>/<string:icon_type>/<string:lang>"
)
api.add_resource(
PreferredProviderTypeUpdateApi, "/workspaces/current/model-providers/<path:provider>/preferred-provider-type"
PreferredProviderTypeUpdateApi, "/workspaces/current/model-providers/<string:provider>/preferred-provider-type"
)
api.add_resource(ModelProviderPaymentCheckoutUrlApi, "/workspaces/current/model-providers/<path:provider>/checkout-url")
api.add_resource(
ModelProviderIconApi,
"/workspaces/<string:tenant_id>/model-providers/<path:provider>/<string:icon_type>/<string:lang>",
ModelProviderPaymentCheckoutUrlApi, "/workspaces/current/model-providers/<string:provider>/checkout-url"
)
api.add_resource(
ModelProviderFreeQuotaSubmitApi, "/workspaces/current/model-providers/<string:provider>/free-quota-submit"
)
api.add_resource(
ModelProviderFreeQuotaQualificationVerifyApi,
"/workspaces/current/model-providers/<string:provider>/free-quota-qualification-verify",
)

View File

@ -325,7 +325,7 @@ class ModelProviderModelValidateApi(Resource):
response = {"result": "success" if result else "error"}
if not result:
response["error"] = error or ""
response["error"] = error
return response
@ -362,26 +362,26 @@ class ModelProviderAvailableModelApi(Resource):
return jsonable_encoder({"data": models})
api.add_resource(ModelProviderModelApi, "/workspaces/current/model-providers/<path:provider>/models")
api.add_resource(ModelProviderModelApi, "/workspaces/current/model-providers/<string:provider>/models")
api.add_resource(
ModelProviderModelEnableApi,
"/workspaces/current/model-providers/<path:provider>/models/enable",
"/workspaces/current/model-providers/<string:provider>/models/enable",
endpoint="model-provider-model-enable",
)
api.add_resource(
ModelProviderModelDisableApi,
"/workspaces/current/model-providers/<path:provider>/models/disable",
"/workspaces/current/model-providers/<string:provider>/models/disable",
endpoint="model-provider-model-disable",
)
api.add_resource(
ModelProviderModelCredentialApi, "/workspaces/current/model-providers/<path:provider>/models/credentials"
ModelProviderModelCredentialApi, "/workspaces/current/model-providers/<string:provider>/models/credentials"
)
api.add_resource(
ModelProviderModelValidateApi, "/workspaces/current/model-providers/<path:provider>/models/credentials/validate"
ModelProviderModelValidateApi, "/workspaces/current/model-providers/<string:provider>/models/credentials/validate"
)
api.add_resource(
ModelProviderModelParameterRuleApi, "/workspaces/current/model-providers/<path:provider>/models/parameter-rules"
ModelProviderModelParameterRuleApi, "/workspaces/current/model-providers/<string:provider>/models/parameter-rules"
)
api.add_resource(ModelProviderAvailableModelApi, "/workspaces/current/models/model-types/<string:model_type>")
api.add_resource(DefaultModelApi, "/workspaces/current/default-model")

View File

@ -1,475 +0,0 @@
import io
from flask import request, send_file
from flask_login import current_user # type: ignore
from flask_restful import Resource, reqparse # type: ignore
from werkzeug.exceptions import Forbidden
from configs import dify_config
from controllers.console import api
from controllers.console.workspace import plugin_permission_required
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.manager.exc import PluginDaemonClientSideError
from libs.login import login_required
from models.account import TenantPluginPermission
from services.plugin.plugin_permission_service import PluginPermissionService
from services.plugin.plugin_service import PluginService
class PluginDebuggingKeyApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
def get(self):
tenant_id = current_user.current_tenant_id
try:
return {
"key": PluginService.get_debugging_key(tenant_id),
"host": dify_config.PLUGIN_REMOTE_INSTALL_HOST,
"port": dify_config.PLUGIN_REMOTE_INSTALL_PORT,
}
except PluginDaemonClientSideError as e:
raise ValueError(e)
class PluginListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
tenant_id = current_user.current_tenant_id
try:
plugins = PluginService.list(tenant_id)
except PluginDaemonClientSideError as e:
raise ValueError(e)
return jsonable_encoder({"plugins": plugins})
class PluginListInstallationsFromIdsApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("plugin_ids", type=list, required=True, location="json")
args = parser.parse_args()
try:
plugins = PluginService.list_installations_from_ids(tenant_id, args["plugin_ids"])
except PluginDaemonClientSideError as e:
raise ValueError(e)
return jsonable_encoder({"plugins": plugins})
class PluginIconApi(Resource):
@setup_required
def get(self):
req = reqparse.RequestParser()
req.add_argument("tenant_id", type=str, required=True, location="args")
req.add_argument("filename", type=str, required=True, location="args")
args = req.parse_args()
try:
icon_bytes, mimetype = PluginService.get_asset(args["tenant_id"], args["filename"])
except PluginDaemonClientSideError as e:
raise ValueError(e)
icon_cache_max_age = dify_config.TOOL_ICON_CACHE_MAX_AGE
return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age)
class PluginUploadFromPkgApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
file = request.files["pkg"]
# check file size
if file.content_length > dify_config.PLUGIN_MAX_PACKAGE_SIZE:
raise ValueError("File size exceeds the maximum allowed size")
content = file.read()
try:
response = PluginService.upload_pkg(tenant_id, content)
except PluginDaemonClientSideError as e:
raise ValueError(e)
return jsonable_encoder(response)
class PluginUploadFromGithubApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("repo", type=str, required=True, location="json")
parser.add_argument("version", type=str, required=True, location="json")
parser.add_argument("package", type=str, required=True, location="json")
args = parser.parse_args()
try:
response = PluginService.upload_pkg_from_github(tenant_id, args["repo"], args["version"], args["package"])
except PluginDaemonClientSideError as e:
raise ValueError(e)
return jsonable_encoder(response)
class PluginUploadFromBundleApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
file = request.files["bundle"]
# check file size
if file.content_length > dify_config.PLUGIN_MAX_BUNDLE_SIZE:
raise ValueError("File size exceeds the maximum allowed size")
content = file.read()
try:
response = PluginService.upload_bundle(tenant_id, content)
except PluginDaemonClientSideError as e:
raise ValueError(e)
return jsonable_encoder(response)
class PluginInstallFromPkgApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("plugin_unique_identifiers", type=list, required=True, location="json")
args = parser.parse_args()
# check if all plugin_unique_identifiers are valid string
for plugin_unique_identifier in args["plugin_unique_identifiers"]:
if not isinstance(plugin_unique_identifier, str):
raise ValueError("Invalid plugin unique identifier")
try:
response = PluginService.install_from_local_pkg(tenant_id, args["plugin_unique_identifiers"])
except PluginDaemonClientSideError as e:
raise ValueError(e)
return jsonable_encoder(response)
class PluginInstallFromGithubApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("repo", type=str, required=True, location="json")
parser.add_argument("version", type=str, required=True, location="json")
parser.add_argument("package", type=str, required=True, location="json")
parser.add_argument("plugin_unique_identifier", type=str, required=True, location="json")
args = parser.parse_args()
try:
response = PluginService.install_from_github(
tenant_id,
args["plugin_unique_identifier"],
args["repo"],
args["version"],
args["package"],
)
except PluginDaemonClientSideError as e:
raise ValueError(e)
return jsonable_encoder(response)
class PluginInstallFromMarketplaceApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("plugin_unique_identifiers", type=list, required=True, location="json")
args = parser.parse_args()
# check if all plugin_unique_identifiers are valid string
for plugin_unique_identifier in args["plugin_unique_identifiers"]:
if not isinstance(plugin_unique_identifier, str):
raise ValueError("Invalid plugin unique identifier")
try:
response = PluginService.install_from_marketplace_pkg(tenant_id, args["plugin_unique_identifiers"])
except PluginDaemonClientSideError as e:
raise ValueError(e)
return jsonable_encoder(response)
class PluginFetchManifestApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
def get(self):
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
args = parser.parse_args()
try:
return jsonable_encoder(
{
"manifest": PluginService.fetch_plugin_manifest(
tenant_id, args["plugin_unique_identifier"]
).model_dump()
}
)
except PluginDaemonClientSideError as e:
raise ValueError(e)
class PluginFetchInstallTasksApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
def get(self):
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("page", type=int, required=True, location="args")
parser.add_argument("page_size", type=int, required=True, location="args")
args = parser.parse_args()
try:
return jsonable_encoder(
{"tasks": PluginService.fetch_install_tasks(tenant_id, args["page"], args["page_size"])}
)
except PluginDaemonClientSideError as e:
raise ValueError(e)
class PluginFetchInstallTaskApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
def get(self, task_id: str):
tenant_id = current_user.current_tenant_id
try:
return jsonable_encoder({"task": PluginService.fetch_install_task(tenant_id, task_id)})
except PluginDaemonClientSideError as e:
raise ValueError(e)
class PluginDeleteInstallTaskApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
def post(self, task_id: str):
tenant_id = current_user.current_tenant_id
try:
return {"success": PluginService.delete_install_task(tenant_id, task_id)}
except PluginDaemonClientSideError as e:
raise ValueError(e)
class PluginDeleteAllInstallTaskItemsApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
try:
return {"success": PluginService.delete_all_install_task_items(tenant_id)}
except PluginDaemonClientSideError as e:
raise ValueError(e)
class PluginDeleteInstallTaskItemApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
def post(self, task_id: str, identifier: str):
tenant_id = current_user.current_tenant_id
try:
return {"success": PluginService.delete_install_task_item(tenant_id, task_id, identifier)}
except PluginDaemonClientSideError as e:
raise ValueError(e)
class PluginUpgradeFromMarketplaceApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
parser.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
args = parser.parse_args()
try:
return jsonable_encoder(
PluginService.upgrade_plugin_with_marketplace(
tenant_id, args["original_plugin_unique_identifier"], args["new_plugin_unique_identifier"]
)
)
except PluginDaemonClientSideError as e:
raise ValueError(e)
class PluginUpgradeFromGithubApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
parser.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
parser.add_argument("repo", type=str, required=True, location="json")
parser.add_argument("version", type=str, required=True, location="json")
parser.add_argument("package", type=str, required=True, location="json")
args = parser.parse_args()
try:
return jsonable_encoder(
PluginService.upgrade_plugin_with_github(
tenant_id,
args["original_plugin_unique_identifier"],
args["new_plugin_unique_identifier"],
args["repo"],
args["version"],
args["package"],
)
)
except PluginDaemonClientSideError as e:
raise ValueError(e)
class PluginUninstallApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
def post(self):
req = reqparse.RequestParser()
req.add_argument("plugin_installation_id", type=str, required=True, location="json")
args = req.parse_args()
tenant_id = current_user.current_tenant_id
try:
return {"success": PluginService.uninstall(tenant_id, args["plugin_installation_id"])}
except PluginDaemonClientSideError as e:
raise ValueError(e)
class PluginChangePermissionApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
user = current_user
if not user.is_admin_or_owner:
raise Forbidden()
req = reqparse.RequestParser()
req.add_argument("install_permission", type=str, required=True, location="json")
req.add_argument("debug_permission", type=str, required=True, location="json")
args = req.parse_args()
install_permission = TenantPluginPermission.InstallPermission(args["install_permission"])
debug_permission = TenantPluginPermission.DebugPermission(args["debug_permission"])
tenant_id = user.current_tenant_id
return {"success": PluginPermissionService.change_permission(tenant_id, install_permission, debug_permission)}
class PluginFetchPermissionApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
tenant_id = current_user.current_tenant_id
permission = PluginPermissionService.get_permission(tenant_id)
if not permission:
return jsonable_encoder(
{
"install_permission": TenantPluginPermission.InstallPermission.EVERYONE,
"debug_permission": TenantPluginPermission.DebugPermission.EVERYONE,
}
)
return jsonable_encoder(
{
"install_permission": permission.install_permission,
"debug_permission": permission.debug_permission,
}
)
api.add_resource(PluginDebuggingKeyApi, "/workspaces/current/plugin/debugging-key")
api.add_resource(PluginListApi, "/workspaces/current/plugin/list")
api.add_resource(PluginListInstallationsFromIdsApi, "/workspaces/current/plugin/list/installations/ids")
api.add_resource(PluginIconApi, "/workspaces/current/plugin/icon")
api.add_resource(PluginUploadFromPkgApi, "/workspaces/current/plugin/upload/pkg")
api.add_resource(PluginUploadFromGithubApi, "/workspaces/current/plugin/upload/github")
api.add_resource(PluginUploadFromBundleApi, "/workspaces/current/plugin/upload/bundle")
api.add_resource(PluginInstallFromPkgApi, "/workspaces/current/plugin/install/pkg")
api.add_resource(PluginInstallFromGithubApi, "/workspaces/current/plugin/install/github")
api.add_resource(PluginUpgradeFromMarketplaceApi, "/workspaces/current/plugin/upgrade/marketplace")
api.add_resource(PluginUpgradeFromGithubApi, "/workspaces/current/plugin/upgrade/github")
api.add_resource(PluginInstallFromMarketplaceApi, "/workspaces/current/plugin/install/marketplace")
api.add_resource(PluginFetchManifestApi, "/workspaces/current/plugin/fetch-manifest")
api.add_resource(PluginFetchInstallTasksApi, "/workspaces/current/plugin/tasks")
api.add_resource(PluginFetchInstallTaskApi, "/workspaces/current/plugin/tasks/<task_id>")
api.add_resource(PluginDeleteInstallTaskApi, "/workspaces/current/plugin/tasks/<task_id>/delete")
api.add_resource(PluginDeleteAllInstallTaskItemsApi, "/workspaces/current/plugin/tasks/delete_all")
api.add_resource(PluginDeleteInstallTaskItemApi, "/workspaces/current/plugin/tasks/<task_id>/delete/<path:identifier>")
api.add_resource(PluginUninstallApi, "/workspaces/current/plugin/uninstall")
api.add_resource(PluginChangePermissionApi, "/workspaces/current/plugin/permission/change")
api.add_resource(PluginFetchPermissionApi, "/workspaces/current/plugin/permission/fetch")

View File

@ -25,10 +25,8 @@ class ToolProviderListApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
user_id = user.id
tenant_id = user.current_tenant_id
user_id = current_user.id
tenant_id = current_user.current_tenant_id
req = reqparse.RequestParser()
req.add_argument(
@ -49,43 +47,28 @@ class ToolBuiltinProviderListToolsApi(Resource):
@login_required
@account_initialization_required
def get(self, provider):
user = current_user
tenant_id = user.current_tenant_id
user_id = current_user.id
tenant_id = current_user.current_tenant_id
return jsonable_encoder(
BuiltinToolManageService.list_builtin_tool_provider_tools(
user_id,
tenant_id,
provider,
)
)
class ToolBuiltinProviderInfoApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
user = current_user
user_id = user.id
tenant_id = user.current_tenant_id
return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(user_id, tenant_id, provider))
class ToolBuiltinProviderDeleteApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider):
user = current_user
if not user.is_admin_or_owner:
if not current_user.is_admin_or_owner:
raise Forbidden()
user_id = user.id
tenant_id = user.current_tenant_id
user_id = current_user.id
tenant_id = current_user.current_tenant_id
return BuiltinToolManageService.delete_builtin_tool_provider(
user_id,
@ -99,13 +82,11 @@ class ToolBuiltinProviderUpdateApi(Resource):
@login_required
@account_initialization_required
def post(self, provider):
user = current_user
if not user.is_admin_or_owner:
if not current_user.is_admin_or_owner:
raise Forbidden()
user_id = user.id
tenant_id = user.current_tenant_id
user_id = current_user.id
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
@ -150,13 +131,11 @@ class ToolApiProviderAddApi(Resource):
@login_required
@account_initialization_required
def post(self):
user = current_user
if not user.is_admin_or_owner:
if not current_user.is_admin_or_owner:
raise Forbidden()
user_id = user.id
tenant_id = user.current_tenant_id
user_id = current_user.id
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
@ -189,11 +168,6 @@ class ToolApiProviderGetRemoteSchemaApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("url", type=str, required=True, nullable=False, location="args")
@ -201,8 +175,8 @@ class ToolApiProviderGetRemoteSchemaApi(Resource):
args = parser.parse_args()
return ApiToolManageService.get_api_tool_provider_remote_schema(
user_id,
tenant_id,
current_user.id,
current_user.current_tenant_id,
args["url"],
)
@ -212,10 +186,8 @@ class ToolApiProviderListToolsApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
user_id = user.id
tenant_id = user.current_tenant_id
user_id = current_user.id
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
@ -237,13 +209,11 @@ class ToolApiProviderUpdateApi(Resource):
@login_required
@account_initialization_required
def post(self):
user = current_user
if not user.is_admin_or_owner:
if not current_user.is_admin_or_owner:
raise Forbidden()
user_id = user.id
tenant_id = user.current_tenant_id
user_id = current_user.id
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
@ -278,13 +248,11 @@ class ToolApiProviderDeleteApi(Resource):
@login_required
@account_initialization_required
def post(self):
user = current_user
if not user.is_admin_or_owner:
if not current_user.is_admin_or_owner:
raise Forbidden()
user_id = user.id
tenant_id = user.current_tenant_id
user_id = current_user.id
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
@ -304,10 +272,8 @@ class ToolApiProviderGetApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
user_id = user.id
tenant_id = user.current_tenant_id
user_id = current_user.id
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
@ -327,11 +293,7 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource):
@login_required
@account_initialization_required
def get(self, provider):
user = current_user
tenant_id = user.current_tenant_id
return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider, tenant_id)
return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider)
class ToolApiProviderSchemaApi(Resource):
@ -382,13 +344,11 @@ class ToolWorkflowProviderCreateApi(Resource):
@login_required
@account_initialization_required
def post(self):
user = current_user
if not user.is_admin_or_owner:
if not current_user.is_admin_or_owner:
raise Forbidden()
user_id = user.id
tenant_id = user.current_tenant_id
user_id = current_user.id
tenant_id = current_user.current_tenant_id
reqparser = reqparse.RequestParser()
reqparser.add_argument("workflow_app_id", type=uuid_value, required=True, nullable=False, location="json")
@ -421,13 +381,11 @@ class ToolWorkflowProviderUpdateApi(Resource):
@login_required
@account_initialization_required
def post(self):
user = current_user
if not user.is_admin_or_owner:
if not current_user.is_admin_or_owner:
raise Forbidden()
user_id = user.id
tenant_id = user.current_tenant_id
user_id = current_user.id
tenant_id = current_user.current_tenant_id
reqparser = reqparse.RequestParser()
reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json")
@ -463,13 +421,11 @@ class ToolWorkflowProviderDeleteApi(Resource):
@login_required
@account_initialization_required
def post(self):
user = current_user
if not user.is_admin_or_owner:
if not current_user.is_admin_or_owner:
raise Forbidden()
user_id = user.id
tenant_id = user.current_tenant_id
user_id = current_user.id
tenant_id = current_user.current_tenant_id
reqparser = reqparse.RequestParser()
reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json")
@ -488,10 +444,8 @@ class ToolWorkflowProviderGetApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
user_id = user.id
tenant_id = user.current_tenant_id
user_id = current_user.id
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("workflow_tool_id", type=uuid_value, required=False, nullable=True, location="args")
@ -522,10 +476,8 @@ class ToolWorkflowProviderListToolApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
user_id = user.id
tenant_id = user.current_tenant_id
user_id = current_user.id
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="args")
@ -546,10 +498,8 @@ class ToolBuiltinListApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
user_id = user.id
tenant_id = user.current_tenant_id
user_id = current_user.id
tenant_id = current_user.current_tenant_id
return jsonable_encoder(
[
@ -567,10 +517,8 @@ class ToolApiListApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
user_id = user.id
tenant_id = user.current_tenant_id
user_id = current_user.id
tenant_id = current_user.current_tenant_id
return jsonable_encoder(
[
@ -588,10 +536,8 @@ class ToolWorkflowListApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
user_id = user.id
tenant_id = user.current_tenant_id
user_id = current_user.id
tenant_id = current_user.current_tenant_id
return jsonable_encoder(
[
@ -617,18 +563,16 @@ class ToolLabelsApi(Resource):
api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers")
# builtin tool provider
api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/tools")
api.add_resource(ToolBuiltinProviderInfoApi, "/workspaces/current/tool-provider/builtin/<path:provider>/info")
api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin/<path:provider>/delete")
api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin/<path:provider>/update")
api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-provider/builtin/<provider>/tools")
api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin/<provider>/delete")
api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin/<provider>/update")
api.add_resource(
ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/credentials"
ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin/<provider>/credentials"
)
api.add_resource(
ToolBuiltinProviderCredentialsSchemaApi,
"/workspaces/current/tool-provider/builtin/<path:provider>/credentials_schema",
ToolBuiltinProviderCredentialsSchemaApi, "/workspaces/current/tool-provider/builtin/<provider>/credentials_schema"
)
api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin/<path:provider>/icon")
api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin/<provider>/icon")
# api tool provider
api.add_resource(ToolApiProviderAddApi, "/workspaces/current/tool-provider/api/add")

View File

@ -7,12 +7,12 @@ from flask_login import current_user # type: ignore
from configs import dify_config
from controllers.console.workspace.error import AccountNotInitializedError
from extensions.ext_database import db
from models.model import DifySetup
from services.feature_service import FeatureService, LicenseStatus
from services.operation_service import OperationService
from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout
from .error import (NotInitValidateError, NotSetupError,
UnauthorizedAndForceLogout)
def account_initialization_required(view):
@ -40,6 +40,28 @@ def only_edition_cloud(view):
return decorated
def only_edition_enterprise(view):
@wraps(view)
def decorated(*args, **kwargs):
if not dify_config.ENTERPRISE_ENABLED:
abort(404)
return view(*args, **kwargs)
return decorated
def only_edition_self_hosted(view):
@wraps(view)
def decorated(*args, **kwargs):
if not dify_config.ENTERPRISE_ENABLED:
abort(404)
return view(*args, **kwargs)
return decorated
def only_edition_self_hosted(view):
@wraps(view)
def decorated(*args, **kwargs):
@ -135,13 +157,9 @@ def setup_required(view):
@wraps(view)
def decorated(*args, **kwargs):
# check setup
if (
dify_config.EDITION == "SELF_HOSTED"
and os.environ.get("INIT_PASSWORD")
and not db.session.query(DifySetup).first()
):
if dify_config.EDITION == "SELF_HOSTED" and os.environ.get("INIT_PASSWORD") and not DifySetup.query.first():
raise NotInitValidateError()
elif dify_config.EDITION == "SELF_HOSTED" and not db.session.query(DifySetup).first():
elif dify_config.EDITION == "SELF_HOSTED" and not DifySetup.query.first():
raise NotSetupError()
return view(*args, **kwargs)
@ -159,3 +177,16 @@ def enterprise_license_required(view):
return view(*args, **kwargs)
return decorated
def email_password_login_enabled(view):
@wraps(view)
def decorated(*args, **kwargs):
features = FeatureService.get_system_features()
if features.enable_email_password_login:
return view(*args, **kwargs)
# otherwise, return 403
abort(403)
return decorated

View File

@ -6,4 +6,4 @@ bp = Blueprint("files", __name__)
api = ExternalApi(bp)
from . import image_preview, tool_files, upload
from . import image_preview, tool_files

View File

@ -1,69 +0,0 @@
from flask import request
from flask_restful import Resource, marshal_with # type: ignore
from werkzeug.exceptions import Forbidden
import services
from controllers.console.wraps import setup_required
from controllers.files import api
from controllers.files.error import UnsupportedFileTypeError
from controllers.inner_api.plugin.wraps import get_user
from controllers.service_api.app.error import FileTooLargeError
from core.file.helpers import verify_plugin_file_signature
from fields.file_fields import file_fields
from services.file_service import FileService
class PluginUploadFileApi(Resource):
@setup_required
@marshal_with(file_fields)
def post(self):
# get file from request
file = request.files["file"]
timestamp = request.args.get("timestamp")
nonce = request.args.get("nonce")
sign = request.args.get("sign")
tenant_id = request.args.get("tenant_id")
if not tenant_id:
raise Forbidden("Invalid request.")
user_id = request.args.get("user_id")
user = get_user(tenant_id, user_id)
filename = file.filename
mimetype = file.mimetype
if not filename or not mimetype:
raise Forbidden("Invalid request.")
if not timestamp or not nonce or not sign:
raise Forbidden("Invalid request.")
if not verify_plugin_file_signature(
filename=filename,
mimetype=mimetype,
tenant_id=tenant_id,
user_id=user_id,
timestamp=timestamp,
nonce=nonce,
sign=sign,
):
raise Forbidden("Invalid request.")
try:
upload_file = FileService.upload_file(
filename=filename,
content=file.read(),
mimetype=mimetype,
user=user,
source=None,
)
except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
return upload_file, 201
api.add_resource(PluginUploadFileApi, "/files/upload/for-plugin")

View File

@ -5,5 +5,5 @@ from libs.external_api import ExternalApi
bp = Blueprint("inner_api", __name__, url_prefix="/inner/api")
api = ExternalApi(bp)
from .plugin import plugin
from . import mail
from .workspace import workspace

View File

@ -0,0 +1,27 @@
from flask_restful import (
Resource, # type: ignore
reqparse,
)
from controllers.console.wraps import setup_required
from controllers.inner_api import api
from controllers.inner_api.wraps import inner_api_only
from services.enterprise.mail_service import DifyMail, EnterpriseMailService
class EnterpriseMail(Resource):
@setup_required
@inner_api_only
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("to", type=str, action="append", required=True)
parser.add_argument("subject", type=str, required=True)
parser.add_argument("body", type=str, required=True)
parser.add_argument("substitutions", type=dict, required=False)
args = parser.parse_args()
EnterpriseMailService.send_mail(DifyMail(**args))
return {"message": "success"}, 200
api.add_resource(EnterpriseMail, "/enterprise/mail")

View File

@ -1,293 +0,0 @@
from flask_restful import Resource # type: ignore
from controllers.console.wraps import setup_required
from controllers.inner_api import api
from controllers.inner_api.plugin.wraps import get_user_tenant, plugin_data
from controllers.inner_api.wraps import plugin_inner_api_only
from core.file.helpers import get_signed_file_url_for_plugin
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.backwards_invocation.app import PluginAppBackwardsInvocation
from core.plugin.backwards_invocation.base import BaseBackwardsInvocationResponse
from core.plugin.backwards_invocation.encrypt import PluginEncrypter
from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocation
from core.plugin.backwards_invocation.node import PluginNodeBackwardsInvocation
from core.plugin.backwards_invocation.tool import PluginToolBackwardsInvocation
from core.plugin.entities.request import (
RequestInvokeApp,
RequestInvokeEncrypt,
RequestInvokeLLM,
RequestInvokeModeration,
RequestInvokeParameterExtractorNode,
RequestInvokeQuestionClassifierNode,
RequestInvokeRerank,
RequestInvokeSpeech2Text,
RequestInvokeSummary,
RequestInvokeTextEmbedding,
RequestInvokeTool,
RequestInvokeTTS,
RequestRequestUploadFile,
)
from core.tools.entities.tool_entities import ToolProviderType
from libs.helper import compact_generate_response
from models.account import Account, Tenant
from models.model import EndUser
class PluginInvokeLLMApi(Resource):
@setup_required
@plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeLLM)
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeLLM):
def generator():
response = PluginModelBackwardsInvocation.invoke_llm(user_model.id, tenant_model, payload)
return PluginModelBackwardsInvocation.convert_to_event_stream(response)
return compact_generate_response(generator())
class PluginInvokeTextEmbeddingApi(Resource):
@setup_required
@plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeTextEmbedding)
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeTextEmbedding):
try:
return jsonable_encoder(
BaseBackwardsInvocationResponse(
data=PluginModelBackwardsInvocation.invoke_text_embedding(
user_id=user_model.id,
tenant=tenant_model,
payload=payload,
)
)
)
except Exception as e:
return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e)))
class PluginInvokeRerankApi(Resource):
@setup_required
@plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeRerank)
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeRerank):
try:
return jsonable_encoder(
BaseBackwardsInvocationResponse(
data=PluginModelBackwardsInvocation.invoke_rerank(
user_id=user_model.id,
tenant=tenant_model,
payload=payload,
)
)
)
except Exception as e:
return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e)))
class PluginInvokeTTSApi(Resource):
@setup_required
@plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeTTS)
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeTTS):
def generator():
response = PluginModelBackwardsInvocation.invoke_tts(
user_id=user_model.id,
tenant=tenant_model,
payload=payload,
)
return PluginModelBackwardsInvocation.convert_to_event_stream(response)
return compact_generate_response(generator())
class PluginInvokeSpeech2TextApi(Resource):
@setup_required
@plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeSpeech2Text)
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeSpeech2Text):
try:
return jsonable_encoder(
BaseBackwardsInvocationResponse(
data=PluginModelBackwardsInvocation.invoke_speech2text(
user_id=user_model.id,
tenant=tenant_model,
payload=payload,
)
)
)
except Exception as e:
return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e)))
class PluginInvokeModerationApi(Resource):
@setup_required
@plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeModeration)
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeModeration):
try:
return jsonable_encoder(
BaseBackwardsInvocationResponse(
data=PluginModelBackwardsInvocation.invoke_moderation(
user_id=user_model.id,
tenant=tenant_model,
payload=payload,
)
)
)
except Exception as e:
return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e)))
class PluginInvokeToolApi(Resource):
@setup_required
@plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeTool)
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeTool):
def generator():
return PluginToolBackwardsInvocation.convert_to_event_stream(
PluginToolBackwardsInvocation.invoke_tool(
tenant_id=tenant_model.id,
user_id=user_model.id,
tool_type=ToolProviderType.value_of(payload.tool_type),
provider=payload.provider,
tool_name=payload.tool,
tool_parameters=payload.tool_parameters,
),
)
return compact_generate_response(generator())
class PluginInvokeParameterExtractorNodeApi(Resource):
@setup_required
@plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeParameterExtractorNode)
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeParameterExtractorNode):
try:
return jsonable_encoder(
BaseBackwardsInvocationResponse(
data=PluginNodeBackwardsInvocation.invoke_parameter_extractor(
tenant_id=tenant_model.id,
user_id=user_model.id,
parameters=payload.parameters,
model_config=payload.model,
instruction=payload.instruction,
query=payload.query,
)
)
)
except Exception as e:
return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e)))
class PluginInvokeQuestionClassifierNodeApi(Resource):
@setup_required
@plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeQuestionClassifierNode)
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeQuestionClassifierNode):
try:
return jsonable_encoder(
BaseBackwardsInvocationResponse(
data=PluginNodeBackwardsInvocation.invoke_question_classifier(
tenant_id=tenant_model.id,
user_id=user_model.id,
query=payload.query,
model_config=payload.model,
classes=payload.classes,
instruction=payload.instruction,
)
)
)
except Exception as e:
return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e)))
class PluginInvokeAppApi(Resource):
@setup_required
@plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeApp)
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeApp):
response = PluginAppBackwardsInvocation.invoke_app(
app_id=payload.app_id,
user_id=user_model.id,
tenant_id=tenant_model.id,
conversation_id=payload.conversation_id,
query=payload.query,
stream=payload.response_mode == "streaming",
inputs=payload.inputs,
files=payload.files,
)
return compact_generate_response(PluginAppBackwardsInvocation.convert_to_event_stream(response))
class PluginInvokeEncryptApi(Resource):
@setup_required
@plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeEncrypt)
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeEncrypt):
"""
encrypt or decrypt data
"""
try:
return BaseBackwardsInvocationResponse(
data=PluginEncrypter.invoke_encrypt(tenant_model, payload)
).model_dump()
except Exception as e:
return BaseBackwardsInvocationResponse(error=str(e)).model_dump()
class PluginInvokeSummaryApi(Resource):
@setup_required
@plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeSummary)
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeSummary):
try:
return BaseBackwardsInvocationResponse(
data={
"summary": PluginModelBackwardsInvocation.invoke_summary(
user_id=user_model.id,
tenant=tenant_model,
payload=payload,
)
}
).model_dump()
except Exception as e:
return BaseBackwardsInvocationResponse(error=str(e)).model_dump()
class PluginUploadFileRequestApi(Resource):
@setup_required
@plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestRequestUploadFile)
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestRequestUploadFile):
# generate signed url
url = get_signed_file_url_for_plugin(payload.filename, payload.mimetype, tenant_model.id, user_model.id)
return BaseBackwardsInvocationResponse(data={"url": url}).model_dump()
api.add_resource(PluginInvokeLLMApi, "/invoke/llm")
api.add_resource(PluginInvokeTextEmbeddingApi, "/invoke/text-embedding")
api.add_resource(PluginInvokeRerankApi, "/invoke/rerank")
api.add_resource(PluginInvokeTTSApi, "/invoke/tts")
api.add_resource(PluginInvokeSpeech2TextApi, "/invoke/speech2text")
api.add_resource(PluginInvokeModerationApi, "/invoke/moderation")
api.add_resource(PluginInvokeToolApi, "/invoke/tool")
api.add_resource(PluginInvokeParameterExtractorNodeApi, "/invoke/parameter-extractor")
api.add_resource(PluginInvokeQuestionClassifierNodeApi, "/invoke/question-classifier")
api.add_resource(PluginInvokeAppApi, "/invoke/app")
api.add_resource(PluginInvokeEncryptApi, "/invoke/encrypt")
api.add_resource(PluginInvokeSummaryApi, "/invoke/summary")
api.add_resource(PluginUploadFileRequestApi, "/upload/file/request")

View File

@ -1,116 +0,0 @@
from collections.abc import Callable
from functools import wraps
from typing import Optional
from flask import request
from flask_restful import reqparse # type: ignore
from pydantic import BaseModel
from sqlalchemy.orm import Session
from extensions.ext_database import db
from models.account import Account, Tenant
from models.model import EndUser
from services.account_service import AccountService
def get_user(tenant_id: str, user_id: str | None) -> Account | EndUser:
try:
with Session(db.engine) as session:
if not user_id:
user_id = "DEFAULT-USER"
if user_id == "DEFAULT-USER":
user_model = session.query(EndUser).filter(EndUser.session_id == "DEFAULT-USER").first()
if not user_model:
user_model = EndUser(
tenant_id=tenant_id,
type="service_api",
is_anonymous=True if user_id == "DEFAULT-USER" else False,
session_id=user_id,
)
session.add(user_model)
session.commit()
else:
user_model = AccountService.load_user(user_id)
if not user_model:
user_model = session.query(EndUser).filter(EndUser.id == user_id).first()
if not user_model:
raise ValueError("user not found")
except Exception:
raise ValueError("user not found")
return user_model
def get_user_tenant(view: Optional[Callable] = None):
def decorator(view_func):
@wraps(view_func)
def decorated_view(*args, **kwargs):
# fetch json body
parser = reqparse.RequestParser()
parser.add_argument("tenant_id", type=str, required=True, location="json")
parser.add_argument("user_id", type=str, required=True, location="json")
kwargs = parser.parse_args()
user_id = kwargs.get("user_id")
tenant_id = kwargs.get("tenant_id")
if not tenant_id:
raise ValueError("tenant_id is required")
if not user_id:
user_id = "DEFAULT-USER"
del kwargs["tenant_id"]
del kwargs["user_id"]
try:
tenant_model = (
db.session.query(Tenant)
.filter(
Tenant.id == tenant_id,
)
.first()
)
except Exception:
raise ValueError("tenant not found")
if not tenant_model:
raise ValueError("tenant not found")
kwargs["tenant_model"] = tenant_model
kwargs["user_model"] = get_user(tenant_id, user_id)
return view_func(*args, **kwargs)
return decorated_view
if view is None:
return decorator
else:
return decorator(view)
def plugin_data(view: Optional[Callable] = None, *, payload_type: type[BaseModel]):
def decorator(view_func):
def decorated_view(*args, **kwargs):
try:
data = request.get_json()
except Exception:
raise ValueError("invalid json")
try:
payload = payload_type(**data)
except Exception as e:
raise ValueError(f"invalid payload: {str(e)}")
kwargs["payload"] = payload
return view_func(*args, **kwargs)
return decorated_view
if view is None:
return decorator
else:
return decorator(view)

View File

@ -4,7 +4,7 @@ from flask_restful import Resource, reqparse # type: ignore
from controllers.console.wraps import setup_required
from controllers.inner_api import api
from controllers.inner_api.wraps import enterprise_inner_api_only
from controllers.inner_api.wraps import inner_api_only
from events.tenant_event import tenant_was_created
from models.account import Account
from services.account_service import TenantService
@ -12,7 +12,7 @@ from services.account_service import TenantService
class EnterpriseWorkspace(Resource):
@setup_required
@enterprise_inner_api_only
@inner_api_only
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, location="json")
@ -33,7 +33,7 @@ class EnterpriseWorkspace(Resource):
class EnterpriseWorkspaceNoOwnerEmail(Resource):
@setup_required
@enterprise_inner_api_only
@inner_api_only
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, location="json")

View File

@ -10,7 +10,7 @@ from extensions.ext_database import db
from models.model import EndUser
def enterprise_inner_api_only(view):
def inner_api_only(view):
@wraps(view)
def decorated(*args, **kwargs):
if not dify_config.INNER_API:
@ -18,7 +18,7 @@ def enterprise_inner_api_only(view):
# get header 'X-Inner-Api-Key'
inner_api_key = request.headers.get("X-Inner-Api-Key")
if not inner_api_key or inner_api_key != dify_config.INNER_API_KEY_FOR_PLUGIN:
if not inner_api_key or inner_api_key != dify_config.INNER_API_KEY:
abort(401)
return view(*args, **kwargs)
@ -26,7 +26,7 @@ def enterprise_inner_api_only(view):
return decorated
def enterprise_inner_api_user_auth(view):
def inner_api_user_auth(view):
@wraps(view)
def decorated(*args, **kwargs):
if not dify_config.INNER_API:
@ -60,19 +60,3 @@ def enterprise_inner_api_user_auth(view):
return view(*args, **kwargs)
return decorated
def plugin_inner_api_only(view):
@wraps(view)
def decorated(*args, **kwargs):
if not dify_config.PLUGIN_DAEMON_KEY:
abort(404)
# get header 'X-Inner-Api-Key'
inner_api_key = request.headers.get("X-Inner-Api-Key")
if not inner_api_key or inner_api_key != dify_config.INNER_API_KEY_FOR_PLUGIN:
abort(404)
return view(*args, **kwargs)
return decorated

View File

@ -15,4 +15,17 @@ api.add_resource(FileApi, "/files/upload")
api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>")
api.add_resource(RemoteFileUploadApi, "/remote-files/upload")
from . import app, audio, completion, conversation, feature, message, passport, saved_message, site, workflow
from . import (
app,
audio,
completion,
conversation,
feature,
forgot_password,
login,
message,
passport,
saved_message,
site,
workflow,
)

View File

@ -1,12 +1,18 @@
from flask_restful import marshal_with # type: ignore
from flask import request
from flask_restful import Resource, marshal_with, reqparse # type: ignore
from controllers.common import fields
from controllers.common import helpers as controller_helpers
from controllers.web import api
from controllers.web.error import AppUnavailableError
from controllers.web.wraps import WebApiResource
from libs.passport import PassportService
from models.model import App, AppMode
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
from services.webapp_auth_service import WebAppAuthService
class AppParameterApi(WebApiResource):
@ -42,5 +48,65 @@ class AppMeta(WebApiResource):
return AppService().get_app_meta(app_model)
class AppAccessMode(Resource):
def get(self):
parser = reqparse.RequestParser()
parser.add_argument("appId", type=str, required=False, location="args")
parser.add_argument("appCode", type=str, required=False, location="args")
args = parser.parse_args()
features = FeatureService.get_system_features()
if not features.webapp_auth.enabled:
return {"accessMode": "public"}
app_id = args.get("appId")
if args.get("appCode"):
app_code = args["appCode"]
app_id = AppService.get_app_id_by_code(app_code)
if not app_id:
raise ValueError("appId or appCode must be provided")
res = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id)
return {"accessMode": res.access_mode}
class AppWebAuthPermission(Resource):
def get(self):
user_id = "visitor"
try:
auth_header = request.headers.get("Authorization")
if auth_header is None:
raise
if " " not in auth_header:
raise
auth_scheme, tk = auth_header.split(None, 1)
auth_scheme = auth_scheme.lower()
if auth_scheme != "bearer":
raise
decoded = PassportService().verify(tk)
user_id = decoded.get("user_id", "visitor")
except Exception as e:
pass
parser = reqparse.RequestParser()
parser.add_argument("appId", type=str, required=True, location="args")
args = parser.parse_args()
app_id = args["appId"]
app_code = AppService.get_app_code_by_id(app_id)
res = True
if WebAppAuthService.is_app_require_permission_check(app_id=app_id):
res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(str(user_id), app_code)
return {"result": res}
api.add_resource(AppParameterApi, "/parameters")
api.add_resource(AppMeta, "/meta")
# webapp auth apis
api.add_resource(AppAccessMode, "/webapp/access-mode")
api.add_resource(AppWebAuthPermission, "/webapp/permission")

View File

@ -121,9 +121,15 @@ class UnsupportedFileTypeError(BaseHTTPException):
code = 415
class WebSSOAuthRequiredError(BaseHTTPException):
class WebAppAuthRequiredError(BaseHTTPException):
error_code = "web_sso_auth_required"
description = "Web SSO authentication required."
description = "Web app authentication required."
code = 401
class WebAppAuthAccessDeniedError(BaseHTTPException):
error_code = "web_app_access_denied"
description = "You do not have permission to access this web app."
code = 401

View File

@ -0,0 +1,147 @@
import base64
import secrets
from flask import request
from flask_restful import Resource, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
from controllers.console.auth.error import (
EmailCodeError,
EmailPasswordResetLimitError,
InvalidEmailError,
InvalidTokenError,
PasswordMismatchError,
)
from controllers.console.error import AccountNotFound, EmailSendIpLimitError
from controllers.console.wraps import email_password_login_enabled, only_edition_enterprise, setup_required
from controllers.web import api
from extensions.ext_database import db
from libs.helper import email, extract_remote_ip
from libs.password import hash_password, valid_password
from models.account import Account
from services.account_service import AccountService
class ForgotPasswordSendEmailApi(Resource):
@only_edition_enterprise
@setup_required
@email_password_login_enabled
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=email, required=True, location="json")
parser.add_argument("language", type=str, required=False, location="json")
args = parser.parse_args()
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError()
if args["language"] is not None and args["language"] == "zh-Hans":
language = "zh-Hans"
else:
language = "en-US"
with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
token = None
if account is None:
raise AccountNotFound()
else:
token = AccountService.send_reset_password_email(account=account, email=args["email"], language=language)
return {"result": "success", "data": token}
class ForgotPasswordCheckApi(Resource):
@only_edition_enterprise
@setup_required
@email_password_login_enabled
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=str, required=True, location="json")
parser.add_argument("code", type=str, required=True, location="json")
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
user_email = args["email"]
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args["email"])
if is_forgot_password_error_rate_limit:
raise EmailPasswordResetLimitError()
token_data = AccountService.get_reset_password_data(args["token"])
if token_data is None:
raise InvalidTokenError()
if user_email != token_data.get("email"):
raise InvalidEmailError()
if args["code"] != token_data.get("code"):
AccountService.add_forgot_password_error_rate_limit(args["email"])
raise EmailCodeError()
# Verified, revoke the first token
AccountService.revoke_reset_password_token(args["token"])
# Refresh token data by generating a new token
_, new_token = AccountService.generate_reset_password_token(
user_email, code=args["code"], additional_data={"phase": "reset"}
)
AccountService.reset_forgot_password_error_rate_limit(args["email"])
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
class ForgotPasswordResetApi(Resource):
@only_edition_enterprise
@setup_required
@email_password_login_enabled
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
parser.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
parser.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
args = parser.parse_args()
# Validate passwords match
if args["new_password"] != args["password_confirm"]:
raise PasswordMismatchError()
# Validate token and get reset data
reset_data = AccountService.get_reset_password_data(args["token"])
if not reset_data:
raise InvalidTokenError()
# Must use token in reset phase
if reset_data.get("phase", "") != "reset":
raise InvalidTokenError()
# Revoke token to prevent reuse
AccountService.revoke_reset_password_token(args["token"])
# Generate secure salt and hash password
salt = secrets.token_bytes(16)
password_hashed = hash_password(args["new_password"], salt)
email = reset_data.get("email", "")
with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
if account:
self._update_existing_account(account, password_hashed, salt, session)
else:
raise AccountNotFound()
return {"result": "success"}
def _update_existing_account(self, account, password_hashed, salt, session):
# Update existing account credentials
account.password = base64.b64encode(password_hashed).decode()
account.password_salt = base64.b64encode(salt).decode()
session.commit()
api.add_resource(ForgotPasswordSendEmailApi, "/forgot-password")
api.add_resource(ForgotPasswordCheckApi, "/forgot-password/validity")
api.add_resource(ForgotPasswordResetApi, "/forgot-password/resets")

View File

@ -0,0 +1,109 @@
import services
from controllers.console.auth.error import (EmailCodeError,
EmailOrPasswordMismatchError,
InvalidEmailError)
from controllers.console.error import AccountBannedError, AccountNotFound
from controllers.console.wraps import only_edition_enterprise, setup_required
from controllers.web import api
from flask_restful import Resource, reqparse
from jwt import InvalidTokenError # type: ignore
from libs.helper import email
from libs.password import valid_password
from services.account_service import AccountService
from services.webapp_auth_service import WebAppAuthService
class LoginApi(Resource):
"""Resource for web app email/password login."""
@setup_required
@only_edition_enterprise
def post(self):
"""Authenticate user and login."""
parser = reqparse.RequestParser()
parser.add_argument("email", type=email, required=True, location="json")
parser.add_argument("password", type=valid_password, required=True, location="json")
args = parser.parse_args()
try:
account = WebAppAuthService.authenticate(args["email"], args["password"])
except services.errors.account.AccountLoginError:
raise AccountBannedError()
except services.errors.account.AccountPasswordError:
raise EmailOrPasswordMismatchError()
except services.errors.account.AccountNotFoundError:
raise AccountNotFound()
token = WebAppAuthService.login(account=account)
return {"result": "success", "data": {"access_token": token}}
# class LogoutApi(Resource):
# @setup_required
# def get(self):
# account = cast(Account, flask_login.current_user)
# if isinstance(account, flask_login.AnonymousUserMixin):
# return {"result": "success"}
# flask_login.logout_user()
# return {"result": "success"}
class EmailCodeLoginSendEmailApi(Resource):
@setup_required
@only_edition_enterprise
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=email, required=True, location="json")
parser.add_argument("language", type=str, required=False, location="json")
args = parser.parse_args()
if args["language"] is not None and args["language"] == "zh-Hans":
language = "zh-Hans"
else:
language = "en-US"
account = WebAppAuthService.get_user_through_email(args["email"])
if account is None:
raise AccountNotFound()
else:
token = WebAppAuthService.send_email_code_login_email(account=account, language=language)
return {"result": "success", "data": token}
class EmailCodeLoginApi(Resource):
@setup_required
@only_edition_enterprise
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=str, required=True, location="json")
parser.add_argument("code", type=str, required=True, location="json")
parser.add_argument("token", type=str, required=True, location="json")
args = parser.parse_args()
user_email = args["email"]
token_data = WebAppAuthService.get_email_code_login_data(args["token"])
if token_data is None:
raise InvalidTokenError()
if token_data["email"] != args["email"]:
raise InvalidEmailError()
if token_data["code"] != args["code"]:
raise EmailCodeError()
WebAppAuthService.revoke_email_code_login_token(args["token"])
account = WebAppAuthService.get_user_through_email(user_email)
if not account:
raise AccountNotFound()
token = WebAppAuthService.login(account=account)
AccountService.reset_login_error_rate_limit(args["email"])
return {"result": "success", "data": {"access_token": token}}
api.add_resource(LoginApi, "/login")
# api.add_resource(LogoutApi, "/logout")
api.add_resource(EmailCodeLoginSendEmailApi, "/email-code-login")
api.add_resource(EmailCodeLoginApi, "/email-code-login/validity")

View File

@ -1,16 +1,18 @@
import uuid
from datetime import UTC, datetime, timedelta
from flask import request
from flask_restful import Resource # type: ignore
from werkzeug.exceptions import NotFound, Unauthorized
from configs import dify_config
from controllers.web import api
from controllers.web.error import WebSSOAuthRequiredError
from controllers.web.error import WebAppAuthRequiredError
from extensions.ext_database import db
from flask import request
from flask_restful import Resource
from libs.passport import PassportService
from models.model import App, EndUser, Site
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
from services.webapp_auth_service import WebAppAuthService, WebAppAuthType
from werkzeug.exceptions import NotFound, Unauthorized
class PassportResource(Resource):
@ -19,13 +21,23 @@ class PassportResource(Resource):
def get(self):
system_features = FeatureService.get_system_features()
app_code = request.headers.get("X-App-Code")
web_app_access_token = request.args.get("web_app_access_token")
if app_code is None:
raise Unauthorized("X-App-Code header is missing.")
if system_features.sso_enforced_for_web:
app_web_sso_enabled = EnterpriseService.get_app_web_sso_enabled(app_code).get("enabled", False)
if app_web_sso_enabled:
raise WebSSOAuthRequiredError()
# exchange token for enterprise logined web user
enterprise_user_decoded = decode_enterprise_webapp_user_id(web_app_access_token)
if enterprise_user_decoded:
# a web user has already logged in, exchange a token for this app without redirecting to the login page
return exchange_token_for_existing_web_user(
app_code=app_code, enterprise_user_decoded=enterprise_user_decoded
)
if system_features.webapp_auth.enabled:
app_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code)
if not app_settings or not app_settings.access_mode == "public":
raise WebAppAuthRequiredError()
# get site from db and check if it is normal
site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first()
@ -65,6 +77,128 @@ class PassportResource(Resource):
api.add_resource(PassportResource, "/passport")
def decode_enterprise_webapp_user_id(jwt_token: str | None):
"""
Decode the enterprise user session from the Authorization header.
"""
if not jwt_token:
return None
decoded = PassportService().verify(jwt_token)
source = decoded.get("token_source")
if not source or source != "webapp_login_token":
raise Unauthorized("Invalid token source. Expected 'webapp_login_token'.")
return decoded
def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: dict):
"""
Exchange a token for an existing web user session.
"""
user_id = enterprise_user_decoded.get("user_id")
end_user_id = enterprise_user_decoded.get("end_user_id")
session_id = enterprise_user_decoded.get("session_id")
user_auth_type = enterprise_user_decoded.get("auth_type")
if not user_auth_type:
raise Unauthorized("Missing auth_type in the token.")
site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first()
if not site:
raise NotFound()
app_model = db.session.query(App).filter(App.id == site.app_id).first()
if not app_model or app_model.status != "normal" or not app_model.enable_site:
raise NotFound()
app_auth_type = WebAppAuthService.get_app_auth_type(app_code=app_code)
if app_auth_type == WebAppAuthType.PUBLIC:
return _exchange_for_public_app_token(app_model, site, enterprise_user_decoded)
elif app_auth_type == WebAppAuthType.EXTERNAL and user_auth_type != "external":
raise WebAppAuthRequiredError("Please login as external user.")
elif app_auth_type == WebAppAuthType.INTERNAL and user_auth_type != "internal":
raise WebAppAuthRequiredError("Please login as internal user.")
end_user = None
if end_user_id:
end_user = db.session.query(EndUser).filter(EndUser.id == end_user_id).first()
if session_id:
end_user = (
db.session.query(EndUser)
.filter(
EndUser.session_id == session_id,
EndUser.tenant_id == app_model.tenant_id,
EndUser.app_id == app_model.id,
)
.first()
)
if not end_user:
if not session_id:
raise NotFound("Missing session_id for existing web user.")
end_user = EndUser(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type="browser",
is_anonymous=True,
session_id=session_id,
)
db.session.add(end_user)
db.session.commit()
exp_dt = datetime.now(UTC) + timedelta(hours=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES * 24)
exp = int(exp_dt.timestamp())
payload = {
"iss": site.id,
"sub": "Web API Passport",
"app_id": site.app_id,
"app_code": site.code,
"user_id": user_id,
"end_user_id": end_user.id,
"auth_type": user_auth_type,
"granted_at": int(datetime.now(UTC).timestamp()),
"token_source": "webapp",
"exp": exp,
}
token: str = PassportService().issue(payload)
return {
"access_token": token,
}
def _exchange_for_public_app_token(app_model, site, token_decoded):
user_id = token_decoded.get("user_id")
end_user = None
if user_id:
end_user = db.session.query(EndUser).filter(
EndUser.app_id == app_model.id, EndUser.session_id == user_id
).first()
if not end_user:
end_user = EndUser(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type="browser",
is_anonymous=True,
session_id=generate_session_id(),
)
db.session.add(end_user)
db.session.commit()
payload = {
"iss": site.app_id,
"sub": "Web API Passport",
"app_id": site.app_id,
"app_code": site.code,
"end_user_id": end_user.id,
}
tk = PassportService().issue(payload)
return {
"access_token": tk,
}
def generate_session_id():
"""
Generate a unique session ID.

View File

@ -1,15 +1,18 @@
from datetime import UTC, datetime
from functools import wraps
from controllers.web.error import (WebAppAuthAccessDeniedError,
WebAppAuthRequiredError)
from extensions.ext_database import db
from flask import request
from flask_restful import Resource # type: ignore
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
from controllers.web.error import WebSSOAuthRequiredError
from extensions.ext_database import db
from libs.passport import PassportService
from models.model import App, EndUser, Site
from services.enterprise.enterprise_service import EnterpriseService
from services.enterprise.enterprise_service import (EnterpriseService,
WebAppSettings)
from services.feature_service import FeatureService
from services.webapp_auth_service import WebAppAuthService
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
def validate_jwt_token(view=None):
@ -45,7 +48,8 @@ def decode_jwt_token():
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
decoded = PassportService().verify(tk)
app_code = decoded.get("app_code")
app_model = db.session.query(App).filter(App.id == decoded["app_id"]).first()
app_id = decoded.get("app_id")
app_model = db.session.query(App).filter(App.id == app_id).first()
site = db.session.query(Site).filter(Site.code == app_code).first()
if not app_model:
raise NotFound()
@ -53,39 +57,90 @@ def decode_jwt_token():
raise BadRequest("Site URL is no longer valid.")
if app_model.enable_site is False:
raise BadRequest("Site is disabled.")
end_user = db.session.query(EndUser).filter(EndUser.id == decoded["end_user_id"]).first()
end_user_id = decoded.get("end_user_id")
end_user = db.session.query(EndUser).filter(EndUser.id == end_user_id).first()
if not end_user:
raise NotFound()
_validate_web_sso_token(decoded, system_features, app_code)
# for enterprise webapp auth
app_web_auth_enabled = False
webapp_settings = None
if system_features.webapp_auth.enabled:
webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code)
if not webapp_settings:
raise NotFound("Web app settings not found.")
app_web_auth_enabled = webapp_settings.access_mode != "public"
_validate_webapp_token(decoded, app_web_auth_enabled, system_features.webapp_auth.enabled)
_validate_user_accessibility(
decoded, app_code, app_web_auth_enabled, system_features.webapp_auth.enabled, webapp_settings
)
return app_model, end_user
except Unauthorized as e:
if system_features.sso_enforced_for_web:
app_web_sso_enabled = EnterpriseService.get_app_web_sso_enabled(app_code).get("enabled", False)
if app_web_sso_enabled:
raise WebSSOAuthRequiredError()
if system_features.webapp_auth.enabled:
if not app_code:
raise Unauthorized("Please re-login to access the web app.")
app_web_auth_enabled = (
EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code).access_mode != "public"
)
if app_web_auth_enabled:
raise WebAppAuthRequiredError()
raise Unauthorized(e.description)
def _validate_web_sso_token(decoded, system_features, app_code):
app_web_sso_enabled = False
# Check if SSO is enforced for web, and if the token source is not SSO, raise an error and redirect to SSO login
if system_features.sso_enforced_for_web:
app_web_sso_enabled = EnterpriseService.get_app_web_sso_enabled(app_code).get("enabled", False)
if app_web_sso_enabled:
source = decoded.get("token_source")
if not source or source != "sso":
raise WebSSOAuthRequiredError()
# Check if SSO is not enforced for web, and if the token source is SSO,
# raise an error and redirect to normal passport login
if not system_features.sso_enforced_for_web or not app_web_sso_enabled:
def _validate_webapp_token(decoded, app_web_auth_enabled: bool, system_webapp_auth_enabled: bool):
# Check if authentication is enforced for web app, and if the token source is not webapp,
# raise an error and redirect to login
if system_webapp_auth_enabled and app_web_auth_enabled:
source = decoded.get("token_source")
if source and source == "sso":
raise Unauthorized("sso token expired.")
if not source or source != "webapp":
raise WebAppAuthRequiredError()
# Check if authentication is not enforced for web, and if the token source is webapp,
# raise an error and redirect to normal passport login
if not system_webapp_auth_enabled or not app_web_auth_enabled:
source = decoded.get("token_source")
if source and source == "webapp":
raise Unauthorized("webapp token expired.")
def _validate_user_accessibility(
decoded,
app_code,
app_web_auth_enabled: bool,
system_webapp_auth_enabled: bool,
webapp_settings: WebAppSettings | None,
):
if system_webapp_auth_enabled and app_web_auth_enabled:
# Check if the user is allowed to access the web app
user_id = decoded.get("user_id")
if not user_id:
raise WebAppAuthRequiredError()
if not webapp_settings:
raise WebAppAuthRequiredError("Web app settings not found.")
if WebAppAuthService.is_app_require_permission_check(access_mode=webapp_settings.access_mode):
if not EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(user_id, app_code=app_code):
raise WebAppAuthAccessDeniedError()
auth_type = decoded.get("auth_type")
granted_at = decoded.get("granted_at")
if not auth_type:
raise WebAppAuthAccessDeniedError("Missing auth_type in the token.")
if not granted_at:
raise WebAppAuthAccessDeniedError("Missing granted_at in the token.")
# check if sso has been updated
if auth_type == "external":
last_update_time = EnterpriseService.get_app_sso_settings_last_update_time()
if granted_at and datetime.fromtimestamp(granted_at, tz=UTC) < last_update_time:
raise WebAppAuthAccessDeniedError("SSO settings have been updated. Please re-login.")
elif auth_type == "internal":
last_update_time = EnterpriseService.get_workspace_sso_settings_last_update_time()
if granted_at and datetime.fromtimestamp(granted_at, tz=UTC) < last_update_time:
raise WebAppAuthAccessDeniedError("SSO settings have been updated. Please re-login.")
class WebApiResource(Resource):

View File

@ -1,6 +1,7 @@
import json
import logging
import uuid
from datetime import UTC, datetime
from typing import Optional, Union, cast
from core.agent.entities import AgentEntity, AgentToolEntity
@ -31,16 +32,19 @@ from core.model_runtime.entities import (
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.utils.extract_thread_messages import extract_thread_messages
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import (
ToolParameter,
ToolRuntimeVariablePool,
)
from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool
from core.tools.tool.tool import Tool
from core.tools.tool_manager import ToolManager
from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool
from extensions.ext_database import db
from factories import file_factory
from models.model import Conversation, Message, MessageAgentThought, MessageFile
from models.tools import ToolConversationVariables
logger = logging.getLogger(__name__)
@ -58,9 +62,11 @@ class BaseAgentRunner(AppRunner):
queue_manager: AppQueueManager,
message: Message,
user_id: str,
model_instance: ModelInstance,
memory: Optional[TokenBufferMemory] = None,
prompt_messages: Optional[list[PromptMessage]] = None,
variables_pool: Optional[ToolRuntimeVariablePool] = None,
db_variables: Optional[ToolConversationVariables] = None,
model_instance: ModelInstance,
) -> None:
self.tenant_id = tenant_id
self.application_generate_entity = application_generate_entity
@ -73,6 +79,8 @@ class BaseAgentRunner(AppRunner):
self.user_id = user_id
self.memory = memory
self.history_prompt_messages = self.organize_agent_history(prompt_messages=prompt_messages or [])
self.variables_pool = variables_pool
self.db_variables_pool = db_variables
self.model_instance = model_instance
# init callback
@ -133,10 +141,11 @@ class BaseAgentRunner(AppRunner):
agent_tool=tool,
invoke_from=self.application_generate_entity.invoke_from,
)
assert tool_entity.entity.description
tool_entity.load_variables(self.variables_pool)
message_tool = PromptMessageTool(
name=tool.tool_name,
description=tool_entity.entity.description.llm,
description=tool_entity.description.llm if tool_entity.description else "",
parameters={
"type": "object",
"properties": {},
@ -144,7 +153,7 @@ class BaseAgentRunner(AppRunner):
},
)
parameters = tool_entity.get_merged_runtime_parameters()
parameters = tool_entity.get_all_runtime_parameters()
for parameter in parameters:
if parameter.form != ToolParameter.ToolParameterForm.LLM:
continue
@ -177,11 +186,9 @@ class BaseAgentRunner(AppRunner):
"""
convert dataset retriever tool to prompt message tool
"""
assert tool.entity.description
prompt_tool = PromptMessageTool(
name=tool.entity.identity.name,
description=tool.entity.description.llm,
name=tool.identity.name if tool.identity else "unknown",
description=tool.description.llm if tool.description else "",
parameters={
"type": "object",
"properties": {},
@ -227,7 +234,8 @@ class BaseAgentRunner(AppRunner):
# save prompt tool
prompt_messages_tools.append(prompt_tool)
# save tool entity
tool_instances[dataset_tool.entity.identity.name] = dataset_tool
if dataset_tool.identity is not None:
tool_instances[dataset_tool.identity.name] = dataset_tool
return tool_instances, prompt_messages_tools
@ -312,23 +320,24 @@ class BaseAgentRunner(AppRunner):
def save_agent_thought(
self,
agent_thought: MessageAgentThought,
tool_name: str | None,
tool_input: Union[str, dict, None],
thought: str | None,
tool_name: str,
tool_input: Union[str, dict],
thought: str,
observation: Union[str, dict, None],
tool_invoke_meta: Union[str, dict, None],
answer: str | None,
answer: str,
messages_ids: list[str],
llm_usage: LLMUsage | None = None,
):
"""
Save agent thought
"""
updated_agent_thought = (
queried_thought = (
db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first()
)
if not updated_agent_thought:
raise ValueError("agent thought not found")
if not queried_thought:
raise ValueError(f"Agent thought {agent_thought.id} not found")
agent_thought = queried_thought
if thought:
agent_thought.thought = thought
@ -340,39 +349,39 @@ class BaseAgentRunner(AppRunner):
if isinstance(tool_input, dict):
try:
tool_input = json.dumps(tool_input, ensure_ascii=False)
except Exception:
except Exception as e:
tool_input = json.dumps(tool_input)
updated_agent_thought.tool_input = tool_input
agent_thought.tool_input = tool_input
if observation:
if isinstance(observation, dict):
try:
observation = json.dumps(observation, ensure_ascii=False)
except Exception:
except Exception as e:
observation = json.dumps(observation)
updated_agent_thought.observation = observation
agent_thought.observation = observation
if answer:
agent_thought.answer = answer
if messages_ids is not None and len(messages_ids) > 0:
updated_agent_thought.message_files = json.dumps(messages_ids)
agent_thought.message_files = json.dumps(messages_ids)
if llm_usage:
updated_agent_thought.message_token = llm_usage.prompt_tokens
updated_agent_thought.message_price_unit = llm_usage.prompt_price_unit
updated_agent_thought.message_unit_price = llm_usage.prompt_unit_price
updated_agent_thought.answer_token = llm_usage.completion_tokens
updated_agent_thought.answer_price_unit = llm_usage.completion_price_unit
updated_agent_thought.answer_unit_price = llm_usage.completion_unit_price
updated_agent_thought.tokens = llm_usage.total_tokens
updated_agent_thought.total_price = llm_usage.total_price
agent_thought.message_token = llm_usage.prompt_tokens
agent_thought.message_price_unit = llm_usage.prompt_price_unit
agent_thought.message_unit_price = llm_usage.prompt_unit_price
agent_thought.answer_token = llm_usage.completion_tokens
agent_thought.answer_price_unit = llm_usage.completion_price_unit
agent_thought.answer_unit_price = llm_usage.completion_unit_price
agent_thought.tokens = llm_usage.total_tokens
agent_thought.total_price = llm_usage.total_price
# check if tool labels is not empty
labels = updated_agent_thought.tool_labels or {}
tools = updated_agent_thought.tool.split(";") if updated_agent_thought.tool else []
labels = agent_thought.tool_labels or {}
tools = agent_thought.tool.split(";") if agent_thought.tool else []
for tool in tools:
if not tool:
continue
@ -383,20 +392,42 @@ class BaseAgentRunner(AppRunner):
else:
labels[tool] = {"en_US": tool, "zh_Hans": tool}
updated_agent_thought.tool_labels_str = json.dumps(labels)
agent_thought.tool_labels_str = json.dumps(labels)
if tool_invoke_meta is not None:
if isinstance(tool_invoke_meta, dict):
try:
tool_invoke_meta = json.dumps(tool_invoke_meta, ensure_ascii=False)
except Exception:
except Exception as e:
tool_invoke_meta = json.dumps(tool_invoke_meta)
updated_agent_thought.tool_meta_str = tool_invoke_meta
agent_thought.tool_meta_str = tool_invoke_meta
db.session.commit()
db.session.close()
def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables):
"""
convert tool variables to db variables
"""
queried_variables = (
db.session.query(ToolConversationVariables)
.filter(
ToolConversationVariables.conversation_id == self.message.conversation_id,
)
.first()
)
if not queried_variables:
return
db_variables = queried_variables
db_variables.updated_at = datetime.now(UTC).replace(tzinfo=None)
db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
db.session.commit()
db.session.close()
def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
Organize agent history
@ -433,11 +464,11 @@ class BaseAgentRunner(AppRunner):
tool_call_response: list[ToolPromptMessage] = []
try:
tool_inputs = json.loads(agent_thought.tool_input)
except Exception:
except Exception as e:
tool_inputs = {tool: {} for tool in tools}
try:
tool_responses = json.loads(agent_thought.observation)
except Exception:
except Exception as e:
tool_responses = dict.fromkeys(tools, agent_thought.observation)
for tool in tools:
@ -484,11 +515,7 @@ class BaseAgentRunner(AppRunner):
files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
if not files:
return UserPromptMessage(content=message.query)
if message.app_model_config:
file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
else:
file_extra_config = None
file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
if not file_extra_config:
return UserPromptMessage(content=message.query)

View File

@ -1,6 +1,6 @@
import json
from abc import ABC, abstractmethod
from collections.abc import Generator, Mapping, Sequence
from collections.abc import Generator, Mapping
from typing import Any, Optional
from core.agent.base_agent_runner import BaseAgentRunner
@ -18,8 +18,8 @@ from core.model_runtime.entities.message_entities import (
)
from core.ops.ops_trace_manager import TraceQueueManager
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool.tool import Tool
from core.tools.tool_engine import ToolEngine
from models.model import Message
@ -27,11 +27,11 @@ from models.model import Message
class CotAgentRunner(BaseAgentRunner, ABC):
_is_first_iteration = True
_ignore_observation_providers = ["wenxin"]
_historic_prompt_messages: list[PromptMessage]
_agent_scratchpad: list[AgentScratchpadUnit]
_instruction: str
_query: str
_prompt_messages_tools: Sequence[PromptMessageTool]
_historic_prompt_messages: list[PromptMessage] | None = None
_agent_scratchpad: list[AgentScratchpadUnit] | None = None
_instruction: str = "" # FIXME this must be str for now
_query: str | None = None
_prompt_messages_tools: list[PromptMessageTool] = []
def run(
self,
@ -42,7 +42,6 @@ class CotAgentRunner(BaseAgentRunner, ABC):
"""
Run Cot agent application
"""
app_generate_entity = self.application_generate_entity
self._repack_app_generate_entity(app_generate_entity)
self._init_react_state(query)
@ -55,19 +54,17 @@ class CotAgentRunner(BaseAgentRunner, ABC):
app_generate_entity.model_conf.stop.append("Observation")
app_config = self.app_config
assert app_config.agent
# init instruction
inputs = inputs or {}
instruction = app_config.prompt_template.simple_prompt_template or ""
self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)
instruction = app_config.prompt_template.simple_prompt_template
self._instruction = self._fill_in_inputs_from_external_data_tools(instruction=instruction or "", inputs=inputs)
iteration_step = 1
max_iteration_steps = min(app_config.agent.max_iteration if app_config.agent else 5, 5) + 1
# convert tools into ModelRuntime Tool format
tool_instances, prompt_messages_tools = self._init_prompt_tools()
self._prompt_messages_tools = prompt_messages_tools
tool_instances, self._prompt_messages_tools = self._init_prompt_tools()
function_call_state = True
llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None}
@ -107,7 +104,6 @@ class CotAgentRunner(BaseAgentRunner, ABC):
# recalc llm max tokens
prompt_messages = self._organize_prompt_messages()
self.recalc_llm_max_tokens(self.model_config, prompt_messages)
# invoke model
chunks = model_instance.invoke_llm(
prompt_messages=prompt_messages,
@ -119,7 +115,14 @@ class CotAgentRunner(BaseAgentRunner, ABC):
callbacks=[],
)
usage_dict: dict[str, Optional[LLMUsage]] = {}
if not isinstance(chunks, Generator):
raise ValueError("Expected streaming response from LLM")
# check llm result
if not chunks:
raise ValueError("failed to invoke llm")
usage_dict: dict[str, Optional[LLMUsage]] = {"usage": None}
react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)
scratchpad = AgentScratchpadUnit(
agent_response="",
@ -139,25 +142,25 @@ class CotAgentRunner(BaseAgentRunner, ABC):
if isinstance(chunk, AgentScratchpadUnit.Action):
action = chunk
# detect action
assert scratchpad.agent_response is not None
scratchpad.agent_response += json.dumps(chunk.model_dump())
if scratchpad.agent_response is not None:
scratchpad.agent_response += json.dumps(chunk.model_dump())
scratchpad.action_str = json.dumps(chunk.model_dump())
scratchpad.action = action
else:
assert scratchpad.agent_response is not None
scratchpad.agent_response += chunk
assert scratchpad.thought is not None
scratchpad.thought += chunk
if scratchpad.agent_response is not None:
scratchpad.agent_response += chunk
if scratchpad.thought is not None:
scratchpad.thought += chunk
yield LLMResultChunk(
model=self.model_config.model,
prompt_messages=prompt_messages,
system_fingerprint="",
delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None),
)
assert scratchpad.thought is not None
scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
self._agent_scratchpad.append(scratchpad)
if scratchpad.thought is not None:
scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
if self._agent_scratchpad is not None:
self._agent_scratchpad.append(scratchpad)
# get llm usage
if "usage" in usage_dict:
@ -252,6 +255,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
answer=final_answer,
messages_ids=[],
)
if self.variables_pool is not None and self.db_variables_pool is not None:
self.update_db_variables(self.variables_pool, self.db_variables_pool)
# publish end event
self.queue_manager.publish(
QueueMessageEndEvent(
@ -269,7 +274,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
def _handle_invoke_action(
self,
action: AgentScratchpadUnit.Action,
tool_instances: Mapping[str, Tool],
tool_instances: dict[str, Tool],
message_file_ids: list[str],
trace_manager: Optional[TraceQueueManager] = None,
) -> tuple[str, ToolInvokeMeta]:
@ -309,7 +314,11 @@ class CotAgentRunner(BaseAgentRunner, ABC):
)
# publish files
for message_file_id in message_files:
for message_file_id, save_as in message_files:
if save_as is not None and self.variables_pool:
# FIXME the save_as type is confusing, it should be a string or not
self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=str(save_as))
# publish message file
self.queue_manager.publish(
QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
@ -332,7 +341,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
for key, value in inputs.items():
try:
instruction = instruction.replace(f"{{{{{key}}}}}", str(value))
except Exception:
except Exception as e:
continue
return instruction
@ -369,7 +378,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
return message
def _organize_historic_prompt_messages(
self, current_session_messages: list[PromptMessage] | None = None
self, current_session_messages: Optional[list[PromptMessage]] = None
) -> list[PromptMessage]:
"""
organize historic prompt messages
@ -381,7 +390,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
for message in self.history_prompt_messages:
if isinstance(message, AssistantPromptMessage):
if not current_scratchpad:
assert isinstance(message.content, str)
if not isinstance(message.content, str | None):
raise NotImplementedError("expected str type")
current_scratchpad = AgentScratchpadUnit(
agent_response=message.content,
thought=message.content or "I am thinking about how to help you",
@ -400,8 +410,9 @@ class CotAgentRunner(BaseAgentRunner, ABC):
except:
pass
elif isinstance(message, ToolPromptMessage):
if current_scratchpad:
assert isinstance(message.content, str)
if not current_scratchpad:
continue
if isinstance(message.content, str):
current_scratchpad.observation = message.content
else:
raise NotImplementedError("expected str type")

View File

@ -19,8 +19,8 @@ class CotChatAgentRunner(CotAgentRunner):
"""
Organize system prompt
"""
assert self.app_config.agent
assert self.app_config.agent.prompt
if not self.app_config.agent:
raise ValueError("Agent configuration is not set")
prompt_entity = self.app_config.agent.prompt
if not prompt_entity:
@ -83,10 +83,8 @@ class CotChatAgentRunner(CotAgentRunner):
assistant_message.content = "" # FIXME: type check tell mypy that assistant_message.content is str
for unit in agent_scratchpad:
if unit.is_final():
assert isinstance(assistant_message.content, str)
assistant_message.content += f"Final Answer: {unit.agent_response}"
else:
assert isinstance(assistant_message.content, str)
assistant_message.content += f"Thought: {unit.thought}\n\n"
if unit.action_str:
assistant_message.content += f"Action: {unit.action_str}\n\n"

View File

@ -1,21 +1,18 @@
from enum import StrEnum
from typing import Any, Optional, Union
from enum import Enum
from typing import Any, Literal, Optional, Union
from pydantic import BaseModel
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
class AgentToolEntity(BaseModel):
"""
Agent Tool Entity.
"""
provider_type: ToolProviderType
provider_type: Literal["builtin", "api", "workflow"]
provider_id: str
tool_name: str
tool_parameters: dict[str, Any] = {}
plugin_unique_identifier: str | None = None
class AgentPromptEntity(BaseModel):
@ -69,7 +66,7 @@ class AgentEntity(BaseModel):
Agent Entity.
"""
class Strategy(StrEnum):
class Strategy(Enum):
"""
Agent Strategy.
"""
@ -81,13 +78,5 @@ class AgentEntity(BaseModel):
model: str
strategy: Strategy
prompt: Optional[AgentPromptEntity] = None
tools: Optional[list[AgentToolEntity]] = None
tools: list[AgentToolEntity] | None = None
max_iteration: int = 5
class AgentInvokeMessage(ToolInvokeMessage):
"""
Agent Invoke Message.
"""
pass

View File

@ -46,20 +46,18 @@ class FunctionCallAgentRunner(BaseAgentRunner):
# convert tools into ModelRuntime Tool format
tool_instances, prompt_messages_tools = self._init_prompt_tools()
assert app_config.agent
iteration_step = 1
max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1
# continue to run until there is not any tool call
function_call_state = True
llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None}
llm_usage: dict[str, LLMUsage] = {"usage": LLMUsage.empty_usage()}
final_answer = ""
# get tracing instance
trace_manager = app_generate_entity.trace_manager
def increase_usage(final_llm_usage_dict: dict[str, Optional[LLMUsage]], usage: LLMUsage):
def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
if not final_llm_usage_dict["usage"]:
final_llm_usage_dict["usage"] = usage
else:
@ -86,7 +84,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):
# recalc llm max tokens
prompt_messages = self._organize_prompt_messages()
self.recalc_llm_max_tokens(self.model_config, prompt_messages)
# invoke model
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
prompt_messages=prompt_messages,
@ -109,7 +106,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
current_llm_usage = None
if isinstance(chunks, Generator):
if self.stream_tool_call and isinstance(chunks, Generator):
is_first_chunk = True
for chunk in chunks:
if is_first_chunk:
@ -126,7 +123,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
tool_call_inputs = json.dumps(
{tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
)
except json.JSONDecodeError:
except json.JSONDecodeError as e:
# ensure ascii to avoid encoding error
tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})
@ -142,7 +139,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
current_llm_usage = chunk.delta.usage
yield chunk
else:
elif not self.stream_tool_call and isinstance(chunks, LLMResult):
result = chunks
# check if there is any tool call
if self.check_blocking_tool_calls(result):
@ -153,7 +150,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
tool_call_inputs = json.dumps(
{tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
)
except json.JSONDecodeError:
except json.JSONDecodeError as e:
# ensure ascii to avoid encoding error
tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})
@ -185,6 +182,8 @@ class FunctionCallAgentRunner(BaseAgentRunner):
usage=result.usage,
),
)
else:
raise RuntimeError(f"invalid chunks type: {type(chunks)}")
assistant_message = AssistantPromptMessage(content="", tool_calls=[])
if tool_calls:
@ -243,12 +242,15 @@ class FunctionCallAgentRunner(BaseAgentRunner):
invoke_from=self.application_generate_entity.invoke_from,
agent_tool_callback=self.agent_callback,
trace_manager=trace_manager,
app_id=self.application_generate_entity.app_config.app_id,
message_id=self.message.id,
conversation_id=self.conversation.id,
)
# publish files
for message_file_id in message_files:
for message_file_id, save_as in message_files:
if save_as:
if self.variables_pool:
self.variables_pool.set_file(
tool_name=tool_call_name, value=message_file_id, name=save_as
)
# publish message file
self.queue_manager.publish(
QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
@ -300,6 +302,8 @@ class FunctionCallAgentRunner(BaseAgentRunner):
iteration_step += 1
if self.variables_pool and self.db_variables_pool:
self.update_db_variables(self.variables_pool, self.db_variables_pool)
# publish end event
self.queue_manager.publish(
QueueMessageEndEvent(
@ -330,7 +334,9 @@ class FunctionCallAgentRunner(BaseAgentRunner):
return True
return False
def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> list[tuple[str, str, dict[str, Any]]]:
def extract_tool_calls(
self, llm_result_chunk: LLMResultChunk
) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
"""
Extract tool calls from llm result chunk
@ -353,7 +359,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
return tool_calls
def extract_blocking_tool_calls(self, llm_result: LLMResult) -> list[tuple[str, str, dict[str, Any]]]:
def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
"""
Extract blocking tool calls from llm result
@ -376,7 +382,9 @@ class FunctionCallAgentRunner(BaseAgentRunner):
return tool_calls
def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
def _init_system_message(
self, prompt_template: str, prompt_messages: Optional[list[PromptMessage]] = None
) -> list[PromptMessage]:
"""
Initialize system message
"""

View File

@ -1,89 +0,0 @@
import enum
from typing import Any, Optional
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
from core.entities.parameter_entities import CommonParameterType
from core.plugin.entities.parameters import (
PluginParameter,
as_normal_type,
cast_parameter_value,
init_frontend_parameter,
)
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import (
ToolIdentity,
ToolProviderIdentity,
)
class AgentStrategyProviderIdentity(ToolProviderIdentity):
"""
Inherits from ToolProviderIdentity, without any additional fields.
"""
pass
class AgentStrategyParameter(PluginParameter):
class AgentStrategyParameterType(enum.StrEnum):
"""
Keep all the types from PluginParameterType
"""
STRING = CommonParameterType.STRING.value
NUMBER = CommonParameterType.NUMBER.value
BOOLEAN = CommonParameterType.BOOLEAN.value
SELECT = CommonParameterType.SELECT.value
SECRET_INPUT = CommonParameterType.SECRET_INPUT.value
FILE = CommonParameterType.FILE.value
FILES = CommonParameterType.FILES.value
APP_SELECTOR = CommonParameterType.APP_SELECTOR.value
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value
# deprecated, should not use.
SYSTEM_FILES = CommonParameterType.SYSTEM_FILES.value
def as_normal_type(self):
return as_normal_type(self)
def cast_value(self, value: Any):
return cast_parameter_value(self, value)
type: AgentStrategyParameterType = Field(..., description="The type of the parameter")
def init_frontend_parameter(self, value: Any):
return init_frontend_parameter(self, self.type, value)
class AgentStrategyProviderEntity(BaseModel):
identity: AgentStrategyProviderIdentity
plugin_id: Optional[str] = Field(None, description="The id of the plugin")
class AgentStrategyIdentity(ToolIdentity):
"""
Inherits from ToolIdentity, without any additional fields.
"""
pass
class AgentStrategyEntity(BaseModel):
identity: AgentStrategyIdentity
parameters: list[AgentStrategyParameter] = Field(default_factory=list)
description: I18nObject = Field(..., description="The description of the agent strategy")
output_schema: Optional[dict] = None
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
@field_validator("parameters", mode="before")
@classmethod
def set_parameters(cls, v, validation_info: ValidationInfo) -> list[AgentStrategyParameter]:
return v or []
class AgentProviderEntityWithPlugin(AgentStrategyProviderEntity):
strategies: list[AgentStrategyEntity] = Field(default_factory=list)

View File

@ -1,42 +0,0 @@
from abc import ABC, abstractmethod
from collections.abc import Generator, Sequence
from typing import Any, Optional
from core.agent.entities import AgentInvokeMessage
from core.agent.plugin_entities import AgentStrategyParameter
class BaseAgentStrategy(ABC):
"""
Agent Strategy
"""
def invoke(
self,
params: dict[str, Any],
user_id: str,
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
) -> Generator[AgentInvokeMessage, None, None]:
"""
Invoke the agent strategy.
"""
yield from self._invoke(params, user_id, conversation_id, app_id, message_id)
def get_parameters(self) -> Sequence[AgentStrategyParameter]:
"""
Get the parameters for the agent strategy.
"""
return []
@abstractmethod
def _invoke(
self,
params: dict[str, Any],
user_id: str,
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
) -> Generator[AgentInvokeMessage, None, None]:
pass

View File

@ -1,59 +0,0 @@
from collections.abc import Generator, Sequence
from typing import Any, Optional
from core.agent.entities import AgentInvokeMessage
from core.agent.plugin_entities import AgentStrategyEntity, AgentStrategyParameter
from core.agent.strategy.base import BaseAgentStrategy
from core.plugin.manager.agent import PluginAgentManager
from core.plugin.utils.converter import convert_parameters_to_plugin_format
class PluginAgentStrategy(BaseAgentStrategy):
"""
Agent Strategy
"""
tenant_id: str
declaration: AgentStrategyEntity
def __init__(self, tenant_id: str, declaration: AgentStrategyEntity):
self.tenant_id = tenant_id
self.declaration = declaration
def get_parameters(self) -> Sequence[AgentStrategyParameter]:
return self.declaration.parameters
def initialize_parameters(self, params: dict[str, Any]) -> dict[str, Any]:
"""
Initialize the parameters for the agent strategy.
"""
for parameter in self.declaration.parameters:
params[parameter.name] = parameter.init_frontend_parameter(params.get(parameter.name))
return params
def _invoke(
self,
params: dict[str, Any],
user_id: str,
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
) -> Generator[AgentInvokeMessage, None, None]:
"""
Invoke the agent strategy.
"""
manager = PluginAgentManager()
initialized_params = self.initialize_parameters(params)
params = convert_parameters_to_plugin_format(initialized_params)
yield from manager.invoke(
tenant_id=self.tenant_id,
user_id=user_id,
agent_provider=self.declaration.identity.provider,
agent_strategy=self.declaration.identity.name,
agent_params=params,
conversation_id=conversation_id,
app_id=app_id,
message_id=message_id,
)

View File

@ -4,8 +4,7 @@ from core.app.app_config.entities import EasyUIBasedAppConfig
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.model_entities import ModelStatus
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.provider_manager import ProviderManager
@ -64,14 +63,14 @@ class ModelConfigConverter:
stop = completion_params["stop"]
del completion_params["stop"]
model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials)
# get model mode
model_mode = model_config.mode
if not model_mode:
model_mode = LLMMode.CHAT.value
if model_schema and model_schema.model_properties.get(ModelPropertyKey.MODE):
model_mode = LLMMode.value_of(model_schema.model_properties[ModelPropertyKey.MODE]).value
mode_enum = model_type_instance.get_model_mode(model=model_config.model, credentials=model_credentials)
model_mode = mode_enum.value
model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials)
if not model_schema:
raise ValueError(f"Model {model_name} not exist.")

View File

@ -2,9 +2,8 @@ from collections.abc import Mapping
from typing import Any
from core.app.app_config.entities import ModelConfigEntity
from core.entities import DEFAULT_PLUGIN_ID
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.model_runtime.model_providers import model_provider_factory
from core.provider_manager import ProviderManager
@ -54,18 +53,9 @@ class ModelConfigManager:
raise ValueError("model must be of object type")
# model.provider
model_provider_factory = ModelProviderFactory(tenant_id)
provider_entities = model_provider_factory.get_providers()
model_provider_names = [provider.provider for provider in provider_entities]
if "provider" not in config["model"]:
raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
if "/" not in config["model"]["provider"]:
config["model"]["provider"] = (
f"{DEFAULT_PLUGIN_ID}/{config['model']['provider']}/{config['model']['provider']}"
)
if config["model"]["provider"] not in model_provider_names:
if "provider" not in config["model"] or config["model"]["provider"] not in model_provider_names:
raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
# model.name

View File

@ -37,6 +37,17 @@ logger = logging.getLogger(__name__)
class AdvancedChatAppGenerator(MessageBasedAppGenerator):
_dialogue_count: int
@overload
def generate(
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True],
) -> Generator[str, None, None]: ...
@overload
def generate(
self,
@ -54,31 +65,20 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping,
invoke_from: InvokeFrom,
streaming: Literal[True],
) -> Generator[Mapping | str, None, None]: ...
@overload
def generate(
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping,
invoke_from: InvokeFrom,
streaming: bool,
) -> Mapping[str, Any] | Generator[str | Mapping, None, None]: ...
def generate(
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
) -> Mapping[str, Any] | Generator[str | Mapping, None, None]:
) -> Union[Mapping[str, Any], Generator[str, None, None]]: ...
def generate(
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
):
"""
Generate App response.
@ -156,8 +156,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
workflow_run_id=workflow_run_id,
)
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
return self._generate(
workflow=workflow,
@ -169,14 +167,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
)
def single_iteration_generate(
self,
app_model: App,
workflow: Workflow,
node_id: str,
user: Account | EndUser,
args: Mapping,
streaming: bool = True,
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
self, app_model: App, workflow: Workflow, node_id: str, user: Account, args: dict, streaming: bool = True
) -> Mapping[str, Any] | Generator[str, None, None]:
"""
Generate App response.
@ -213,8 +205,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
),
)
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
return self._generate(
workflow=workflow,
@ -234,7 +224,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
application_generate_entity: AdvancedChatAppGenerateEntity,
conversation: Optional[Conversation] = None,
stream: bool = True,
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
) -> Mapping[str, Any] | Generator[str, None, None]:
"""
Generate App response.

View File

@ -56,7 +56,7 @@ def _process_future(
class AppGeneratorTTSPublisher:
def __init__(self, tenant_id: str, voice: str, language: Optional[str] = None):
def __init__(self, tenant_id: str, voice: str):
self.logger = logging.getLogger(__name__)
self.tenant_id = tenant_id
self.msg_text = ""
@ -67,7 +67,7 @@ class AppGeneratorTTSPublisher:
self.model_instance = self.model_manager.get_default_model_instance(
tenant_id=self.tenant_id, model_type=ModelType.TTS
)
self.voices = self.model_instance.get_tts_voices(language=language)
self.voices = self.model_instance.get_tts_voices()
values = [voice.get("value") for voice in self.voices]
self.voice = voice
if not voice or voice not in values:

View File

@ -77,7 +77,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=dict(self.application_generate_entity.single_iteration_run.inputs),
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
)
else:
inputs = self.application_generate_entity.inputs

View File

@ -1,3 +1,4 @@
import json
from collections.abc import Generator
from typing import Any, cast
@ -57,7 +58,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, Any, None]:
) -> Generator[str, Any, None]:
"""
Convert stream full response.
:param stream_response: stream response
@ -83,12 +84,12 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.to_dict())
yield response_chunk
yield json.dumps(response_chunk)
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, Any, None]:
) -> Generator[str, Any, None]:
"""
Convert stream simple response.
:param stream_response: stream response
@ -122,4 +123,4 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
else:
response_chunk.update(sub_stream_response.to_dict())
yield response_chunk
yield json.dumps(response_chunk)

View File

@ -17,7 +17,6 @@ from core.app.entities.app_invoke_entities import (
)
from core.app.entities.queue_entities import (
QueueAdvancedChatMessageEndEvent,
QueueAgentLogEvent,
QueueAnnotationReplyEvent,
QueueErrorEvent,
QueueIterationCompletedEvent,
@ -220,9 +219,7 @@ class AdvancedChatAppGenerateTaskPipeline:
and features_dict["text_to_speech"].get("enabled")
and features_dict["text_to_speech"].get("autoPlay") == "enabled"
):
tts_publisher = AppGeneratorTTSPublisher(
tenant_id, features_dict["text_to_speech"].get("voice"), features_dict["text_to_speech"].get("language")
)
tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict["text_to_speech"].get("voice"))
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
while True:
@ -250,7 +247,7 @@ class AdvancedChatAppGenerateTaskPipeline:
else:
start_listener_time = time.time()
yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id)
except Exception:
except Exception as e:
logger.exception(f"Failed to listen audio message, task_id: {task_id}")
break
if tts_publisher:
@ -384,6 +381,7 @@ class AdvancedChatAppGenerateTaskPipeline:
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
session.commit()
if node_finish_resp:
yield node_finish_resp
@ -642,10 +640,6 @@ class AdvancedChatAppGenerateTaskPipeline:
session.commit()
yield self._message_end_to_stream_response()
elif isinstance(event, QueueAgentLogEvent):
yield self._workflow_cycle_manager._handle_agent_log(
task_id=self._application_generate_entity.task_id, event=event
)
else:
continue

View File

@ -1,4 +1,3 @@
import contextvars
import logging
import threading
import uuid
@ -30,6 +29,17 @@ logger = logging.getLogger(__name__)
class AgentChatAppGenerator(MessageBasedAppGenerator):
@overload
def generate(
self,
*,
app_model: App,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True],
) -> Generator[str, None, None]: ...
@overload
def generate(
self,
@ -41,17 +51,6 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
streaming: Literal[False],
) -> Mapping[str, Any]: ...
@overload
def generate(
self,
*,
app_model: App,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True],
) -> Generator[Mapping | str, None, None]: ...
@overload
def generate(
self,
@ -61,7 +60,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool,
) -> Union[Mapping, Generator[Mapping | str, None, None]]: ...
) -> Mapping[str, Any] | Generator[str, None, None]: ...
def generate(
self,
@ -71,7 +70,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
) -> Union[Mapping, Generator[Mapping | str, None, None]]:
):
"""
Generate App response.
@ -183,7 +182,6 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"context": contextvars.copy_context(),
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"conversation_id": conversation.id,
@ -208,7 +206,6 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
def _generate_worker(
self,
flask_app: Flask,
context: contextvars.Context,
application_generate_entity: AgentChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation_id: str,
@ -223,9 +220,6 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
:param message_id: message ID
:return:
"""
for var, val in context.items():
var.set(val)
with flask_app.app_context():
try:
# get conversation and message

View File

@ -8,16 +8,18 @@ from core.agent.fc_agent_runner import FunctionCallAgentRunner
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ModelConfigWithCredentialsEntity
from core.app.entities.queue_entities import QueueAnnotationReplyEvent
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.llm_entities import LLMMode, LLMUsage
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.moderation.base import ModerationError
from core.tools.entities.tool_entities import ToolRuntimeVariablePool
from extensions.ext_database import db
from models.model import App, Conversation, Message
from models.model import App, Conversation, Message, MessageAgentThought
from models.tools import ToolConversationVariables
logger = logging.getLogger(__name__)
@ -53,20 +55,6 @@ class AgentChatAppRunner(AppRunner):
query = application_generate_entity.query
files = application_generate_entity.files
# Pre-calculate the number of tokens of the prompt messages,
# and return the rest number of tokens by model context token size limit and max token size limit.
# If the rest number of tokens is not enough, raise exception.
# Include: prompt template, inputs, query(optional), files(optional)
# Not Include: memory, external data, dataset context
self.get_pre_calculate_rest_tokens(
app_record=app_record,
model_config=application_generate_entity.model_conf,
prompt_template_entity=app_config.prompt_template,
inputs=dict(inputs),
files=list(files),
query=query,
)
memory = None
if application_generate_entity.conversation_id:
# get memory of conversation (read-only)
@ -84,8 +72,8 @@ class AgentChatAppRunner(AppRunner):
app_record=app_record,
model_config=application_generate_entity.model_conf,
prompt_template_entity=app_config.prompt_template,
inputs=dict(inputs),
files=list(files),
inputs=inputs,
files=files,
query=query,
memory=memory,
)
@ -97,8 +85,8 @@ class AgentChatAppRunner(AppRunner):
app_id=app_record.id,
tenant_id=app_config.tenant_id,
app_generate_entity=application_generate_entity,
inputs=dict(inputs),
query=query or "",
inputs=inputs,
query=query,
message_id=message.id,
)
except ModerationError as e:
@ -154,9 +142,9 @@ class AgentChatAppRunner(AppRunner):
app_record=app_record,
model_config=application_generate_entity.model_conf,
prompt_template_entity=app_config.prompt_template,
inputs=dict(inputs),
files=list(files),
query=query or "",
inputs=inputs,
files=files,
query=query,
memory=memory,
)
@ -171,7 +159,16 @@ class AgentChatAppRunner(AppRunner):
return
agent_entity = app_config.agent
assert agent_entity is not None
if not agent_entity:
raise ValueError("Agent entity not found")
# load tool variables
tool_conversation_variables = self._load_tool_variables(
conversation_id=conversation.id, user_id=application_generate_entity.user_id, tenant_id=app_config.tenant_id
)
# convert db variables to tool variables
tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables)
# init model instance
model_instance = ModelInstance(
@ -182,9 +179,9 @@ class AgentChatAppRunner(AppRunner):
app_record=app_record,
model_config=application_generate_entity.model_conf,
prompt_template_entity=app_config.prompt_template,
inputs=dict(inputs),
files=list(files),
query=query or "",
inputs=inputs,
files=files,
query=query,
memory=memory,
)
@ -232,6 +229,8 @@ class AgentChatAppRunner(AppRunner):
user_id=application_generate_entity.user_id,
memory=memory,
prompt_messages=prompt_message,
variables_pool=tool_variables,
db_variables=tool_conversation_variables,
model_instance=model_instance,
)
@ -248,3 +247,73 @@ class AgentChatAppRunner(AppRunner):
stream=application_generate_entity.stream,
agent=True,
)
def _load_tool_variables(self, conversation_id: str, user_id: str, tenant_id: str) -> ToolConversationVariables:
"""
load tool variables from database
"""
tool_variables: ToolConversationVariables | None = (
db.session.query(ToolConversationVariables)
.filter(
ToolConversationVariables.conversation_id == conversation_id,
ToolConversationVariables.tenant_id == tenant_id,
)
.first()
)
if tool_variables:
# save tool variables to session, so that we can update it later
db.session.add(tool_variables)
else:
# create new tool variables
tool_variables = ToolConversationVariables(
conversation_id=conversation_id,
user_id=user_id,
tenant_id=tenant_id,
variables_str="[]",
)
db.session.add(tool_variables)
db.session.commit()
return tool_variables
def _convert_db_variables_to_tool_variables(
self, db_variables: ToolConversationVariables
) -> ToolRuntimeVariablePool:
"""
convert db variables to tool variables
"""
return ToolRuntimeVariablePool(
**{
"conversation_id": db_variables.conversation_id,
"user_id": db_variables.user_id,
"tenant_id": db_variables.tenant_id,
"pool": db_variables.variables,
}
)
def _get_usage_of_all_agent_thoughts(
self, model_config: ModelConfigWithCredentialsEntity, message: Message
) -> LLMUsage:
"""
Get usage of all agent thoughts
:param model_config: model config
:param message: message
:return:
"""
agent_thoughts = (
db.session.query(MessageAgentThought).filter(MessageAgentThought.message_id == message.id).all()
)
all_message_tokens = 0
all_answer_tokens = 0
for agent_thought in agent_thoughts:
all_message_tokens += agent_thought.message_tokens
all_answer_tokens += agent_thought.answer_tokens
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
return model_type_instance._calc_response_usage(
model_config.model, model_config.credentials, all_message_tokens, all_answer_tokens
)

View File

@ -1,9 +1,9 @@
import json
from collections.abc import Generator
from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
AppStreamResponse,
ChatbotAppBlockingResponse,
ChatbotAppStreamResponse,
ErrorStreamResponse,
@ -51,9 +51,10 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
return response
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
def convert_stream_full_response( # type: ignore[override]
cls,
stream_response: Generator[ChatbotAppStreamResponse, None, None],
) -> Generator[str, None, None]:
"""
Convert stream full response.
:param stream_response: stream response
@ -79,12 +80,13 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.to_dict())
yield response_chunk
yield json.dumps(response_chunk)
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
def convert_stream_simple_response( # type: ignore[override]
cls,
stream_response: Generator[ChatbotAppStreamResponse, None, None],
) -> Generator[str, None, None]:
"""
Convert stream simple response.
:param stream_response: stream response
@ -116,4 +118,4 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
else:
response_chunk.update(sub_stream_response.to_dict())
yield response_chunk
yield json.dumps(response_chunk)

View File

@ -14,15 +14,21 @@ class AppGenerateResponseConverter(ABC):
@classmethod
def convert(
cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
cls,
response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]],
invoke_from: InvokeFrom,
) -> Mapping[str, Any] | Generator[str, None, None]:
if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}:
if isinstance(response, AppBlockingResponse):
return cls.convert_blocking_full_response(response)
else:
def _generate_full_response() -> Generator[dict | str, Any, None]:
yield from cls.convert_stream_full_response(response)
def _generate_full_response() -> Generator[str, Any, None]:
for chunk in cls.convert_stream_full_response(response):
if chunk == "ping":
yield f"event: {chunk}\n\n"
else:
yield f"data: {chunk}\n\n"
return _generate_full_response()
else:
@ -30,8 +36,12 @@ class AppGenerateResponseConverter(ABC):
return cls.convert_blocking_simple_response(response)
else:
def _generate_simple_response() -> Generator[dict | str, Any, None]:
yield from cls.convert_stream_simple_response(response)
def _generate_simple_response() -> Generator[str, Any, None]:
for chunk in cls.convert_stream_simple_response(response):
if chunk == "ping":
yield f"event: {chunk}\n\n"
else:
yield f"data: {chunk}\n\n"
return _generate_simple_response()
@ -49,14 +59,14 @@ class AppGenerateResponseConverter(ABC):
@abstractmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
) -> Generator[str, None, None]:
raise NotImplementedError
@classmethod
@abstractmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
) -> Generator[str, None, None]:
raise NotImplementedError
@classmethod

View File

@ -1,6 +1,5 @@
import json
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, Union
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional
from core.app.app_config.entities import VariableEntityType
from core.file import File, FileUploadConfig
@ -139,21 +138,3 @@ class BaseAppGenerator:
if isinstance(value, str):
return value.replace("\x00", "")
return value
@classmethod
def convert_to_event_stream(cls, generator: Union[Mapping, Generator[Mapping | str, None, None]]):
"""
Convert messages into event stream
"""
if isinstance(generator, dict):
return generator
else:
def gen():
for message in generator:
if isinstance(message, (Mapping, dict)):
yield f"data: {json.dumps(message)}\n\n"
else:
yield f"event: {message}\n\n"
return gen()

View File

@ -2,7 +2,7 @@ import queue
import time
from abc import abstractmethod
from enum import Enum
from typing import Any, Optional
from typing import Any
from sqlalchemy.orm import DeclarativeMeta
@ -115,7 +115,7 @@ class AppQueueManager:
Set task stop flag
:return:
"""
result: Optional[Any] = redis_client.get(cls._generate_task_belong_cache_key(task_id))
result = redis_client.get(cls._generate_task_belong_cache_key(task_id))
if result is None:
return

View File

@ -15,10 +15,8 @@ from core.app.features.annotation_reply.annotation_reply import AnnotationReplyF
from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature
from core.external_data_tool.external_data_fetch import ExternalDataFetch
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.moderation.input_moderation import InputModeration
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
@ -31,106 +29,6 @@ if TYPE_CHECKING:
class AppRunner:
def get_pre_calculate_rest_tokens(
self,
app_record: App,
model_config: ModelConfigWithCredentialsEntity,
prompt_template_entity: PromptTemplateEntity,
inputs: Mapping[str, str],
files: Sequence["File"],
query: Optional[str] = None,
) -> int:
"""
Get pre calculate rest tokens
:param app_record: app record
:param model_config: model config entity
:param prompt_template_entity: prompt template entity
:param inputs: inputs
:param files: files
:param query: query
:return:
"""
# Invoke model
model_instance = ModelInstance(
provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
)
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
max_tokens = 0
for parameter_rule in model_config.model_schema.parameter_rules:
if parameter_rule.name == "max_tokens" or (
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
):
max_tokens = (
model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template or "")
) or 0
if model_context_tokens is None:
return -1
if max_tokens is None:
max_tokens = 0
# get prompt messages without memory and context
prompt_messages, stop = self.organize_prompt_messages(
app_record=app_record,
model_config=model_config,
prompt_template_entity=prompt_template_entity,
inputs=inputs,
files=files,
query=query,
)
prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)
rest_tokens: int = model_context_tokens - max_tokens - prompt_tokens
if rest_tokens < 0:
raise InvokeBadRequestError(
"Query or prefix prompt is too long, you can reduce the prefix prompt, "
"or shrink the max token, or switch to a llm with a larger token limit size."
)
return rest_tokens
def recalc_llm_max_tokens(
self, model_config: ModelConfigWithCredentialsEntity, prompt_messages: list[PromptMessage]
):
# recalc max_tokens if sum(prompt_token + max_tokens) over model token limit
model_instance = ModelInstance(
provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
)
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
max_tokens = 0
for parameter_rule in model_config.model_schema.parameter_rules:
if parameter_rule.name == "max_tokens" or (
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
):
max_tokens = (
model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template or "")
) or 0
if model_context_tokens is None:
return -1
if max_tokens is None:
max_tokens = 0
prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)
if prompt_tokens + max_tokens > model_context_tokens:
max_tokens = max(model_context_tokens - prompt_tokens, 16)
for parameter_rule in model_config.model_schema.parameter_rules:
if parameter_rule.name == "max_tokens" or (
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
):
model_config.parameters[parameter_rule.name] = max_tokens
def organize_prompt_messages(
self,
app_record: App,

View File

@ -38,7 +38,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True],
) -> Generator[Mapping | str, None, None]: ...
) -> Generator[str, None, None]: ...
@overload
def generate(
@ -58,7 +58,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool,
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: ...
) -> Union[Mapping[str, Any], Generator[str, None, None]]: ...
def generate(
self,
@ -67,7 +67,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]:
):
"""
Generate App response.

View File

@ -50,20 +50,6 @@ class ChatAppRunner(AppRunner):
query = application_generate_entity.query
files = application_generate_entity.files
# Pre-calculate the number of tokens of the prompt messages,
# and return the rest number of tokens by model context token size limit and max token size limit.
# If the rest number of tokens is not enough, raise exception.
# Include: prompt template, inputs, query(optional), files(optional)
# Not Include: memory, external data, dataset context
self.get_pre_calculate_rest_tokens(
app_record=app_record,
model_config=application_generate_entity.model_conf,
prompt_template_entity=app_config.prompt_template,
inputs=inputs,
files=files,
query=query,
)
memory = None
if application_generate_entity.conversation_id:
# get memory of conversation (read-only)
@ -194,9 +180,6 @@ class ChatAppRunner(AppRunner):
if hosting_moderation_result:
return
# Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit
self.recalc_llm_max_tokens(model_config=application_generate_entity.model_conf, prompt_messages=prompt_messages)
# Invoke model
model_instance = ModelInstance(
provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,

View File

@ -1,9 +1,9 @@
import json
from collections.abc import Generator
from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
AppStreamResponse,
ChatbotAppBlockingResponse,
ChatbotAppStreamResponse,
ErrorStreamResponse,
@ -52,8 +52,9 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
cls,
stream_response: Generator[ChatbotAppStreamResponse, None, None], # type: ignore[override]
) -> Generator[str, None, None]:
"""
Convert stream full response.
:param stream_response: stream response
@ -79,12 +80,13 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.to_dict())
yield response_chunk
yield json.dumps(response_chunk)
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
cls,
stream_response: Generator[ChatbotAppStreamResponse, None, None], # type: ignore[override]
) -> Generator[str, None, None]:
"""
Convert stream simple response.
:param stream_response: stream response
@ -116,4 +118,4 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
else:
response_chunk.update(sub_stream_response.to_dict())
yield response_chunk
yield json.dumps(response_chunk)

View File

@ -37,7 +37,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True],
) -> Generator[str | Mapping[str, Any], None, None]: ...
) -> Generator[str, None, None]: ...
@overload
def generate(
@ -56,8 +56,8 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = False,
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: ...
streaming: bool,
) -> Mapping[str, Any] | Generator[str, None, None]: ...
def generate(
self,
@ -66,7 +66,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
):
"""
Generate App response.
@ -231,7 +231,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
user: Union[Account, EndUser],
invoke_from: InvokeFrom,
stream: bool = True,
) -> Union[Mapping, Generator[Mapping | str, None, None]]:
) -> Union[Mapping[str, Any], Generator[str, None, None]]:
"""
Generate App response.

View File

@ -43,20 +43,6 @@ class CompletionAppRunner(AppRunner):
query = application_generate_entity.query
files = application_generate_entity.files
# Pre-calculate the number of tokens of the prompt messages,
# and return the rest number of tokens by model context token size limit and max token size limit.
# If the rest number of tokens is not enough, raise exception.
# Include: prompt template, inputs, query(optional), files(optional)
# Not Include: memory, external data, dataset context
self.get_pre_calculate_rest_tokens(
app_record=app_record,
model_config=application_generate_entity.model_conf,
prompt_template_entity=app_config.prompt_template,
inputs=inputs,
files=files,
query=query,
)
# organize all inputs and template to prompt messages
# Include: prompt template, inputs, query(optional), files(optional)
prompt_messages, stop = self.organize_prompt_messages(
@ -152,9 +138,6 @@ class CompletionAppRunner(AppRunner):
if hosting_moderation_result:
return
# Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit
self.recalc_llm_max_tokens(model_config=application_generate_entity.model_conf, prompt_messages=prompt_messages)
# Invoke model
model_instance = ModelInstance(
provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,

View File

@ -1,9 +1,9 @@
import json
from collections.abc import Generator
from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
AppStreamResponse,
CompletionAppBlockingResponse,
CompletionAppStreamResponse,
ErrorStreamResponse,
@ -51,8 +51,9 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
cls,
stream_response: Generator[CompletionAppStreamResponse, None, None], # type: ignore[override]
) -> Generator[str, None, None]:
"""
Convert stream full response.
:param stream_response: stream response
@ -77,12 +78,13 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.to_dict())
yield response_chunk
yield json.dumps(response_chunk)
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
cls,
stream_response: Generator[CompletionAppStreamResponse, None, None], # type: ignore[override]
) -> Generator[str, None, None]:
"""
Convert stream simple response.
:param stream_response: stream response
@ -113,4 +115,4 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
else:
response_chunk.update(sub_stream_response.to_dict())
yield response_chunk
yield json.dumps(response_chunk)

View File

@ -36,13 +36,13 @@ class WorkflowAppGenerator(BaseAppGenerator):
*,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True],
call_depth: int,
workflow_thread_pool_id: Optional[str],
) -> Generator[Mapping | str, None, None]: ...
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
) -> Generator[str, None, None]: ...
@overload
def generate(
@ -50,12 +50,12 @@ class WorkflowAppGenerator(BaseAppGenerator):
*,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[False],
call_depth: int,
workflow_thread_pool_id: Optional[str],
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
) -> Mapping[str, Any]: ...
@overload
@ -64,26 +64,26 @@ class WorkflowAppGenerator(BaseAppGenerator):
*,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool,
call_depth: int,
workflow_thread_pool_id: Optional[str],
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...
streaming: bool = True,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
) -> Mapping[str, Any] | Generator[str, None, None]: ...
def generate(
self,
*,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]:
):
files: Sequence[Mapping[str, Any]] = args.get("files") or []
# parse files
@ -124,10 +124,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
trace_manager=trace_manager,
workflow_run_id=workflow_run_id,
)
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
return self._generate(
app_model=app_model,
@ -149,18 +146,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
streaming: bool = True,
workflow_thread_pool_id: Optional[str] = None,
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
"""
Generate App response.
:param app_model: App
:param workflow: Workflow
:param user: account or end user
:param application_generate_entity: application generate entity
:param invoke_from: invoke from source
:param stream: is stream
:param workflow_thread_pool_id: workflow thread pool id
"""
) -> Mapping[str, Any] | Generator[str, None, None]:
# init queue manager
queue_manager = WorkflowAppQueueManager(
task_id=application_generate_entity.task_id,
@ -199,10 +185,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
app_model: App,
workflow: Workflow,
node_id: str,
user: Account | EndUser,
user: Account,
args: Mapping[str, Any],
streaming: bool = True,
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
) -> Mapping[str, Any] | Generator[str, None, None]:
"""
Generate App response.
@ -238,8 +224,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_run_id=str(uuid.uuid4()),
)
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
return self._generate(
app_model=app_model,

View File

@ -1,9 +1,9 @@
import json
from collections.abc import Generator
from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
AppStreamResponse,
ErrorStreamResponse,
NodeFinishStreamResponse,
NodeStartStreamResponse,
@ -36,8 +36,9 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
cls,
stream_response: Generator[WorkflowAppStreamResponse, None, None], # type: ignore[override]
) -> Generator[str, None, None]:
"""
Convert stream full response.
:param stream_response: stream response
@ -61,12 +62,13 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.to_dict())
yield response_chunk
yield json.dumps(response_chunk)
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
cls,
stream_response: Generator[WorkflowAppStreamResponse, None, None], # type: ignore[override]
) -> Generator[str, None, None]:
"""
Convert stream simple response.
:param stream_response: stream response
@ -92,4 +94,4 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
else:
response_chunk.update(sub_stream_response.to_dict())
yield response_chunk
yield json.dumps(response_chunk)

View File

@ -13,7 +13,6 @@ from core.app.entities.app_invoke_entities import (
WorkflowAppGenerateEntity,
)
from core.app.entities.queue_entities import (
QueueAgentLogEvent,
QueueErrorEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
@ -191,9 +190,7 @@ class WorkflowAppGenerateTaskPipeline:
and features_dict["text_to_speech"].get("enabled")
and features_dict["text_to_speech"].get("autoPlay") == "enabled"
):
tts_publisher = AppGeneratorTTSPublisher(
tenant_id, features_dict["text_to_speech"].get("voice"), features_dict["text_to_speech"].get("language")
)
tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict["text_to_speech"].get("voice"))
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
while True:
@ -530,10 +527,6 @@ class WorkflowAppGenerateTaskPipeline:
yield self._text_chunk_to_stream_response(
delta_text, from_variable_selector=event.from_variable_selector
)
elif isinstance(event, QueueAgentLogEvent):
yield self._workflow_cycle_manager._handle_agent_log(
task_id=self._application_generate_entity.task_id, event=event
)
else:
continue

View File

@ -5,7 +5,6 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueAgentLogEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
@ -28,7 +27,6 @@ from core.app.entities.queue_entities import (
from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
AgentLogEvent,
GraphEngineEvent,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
@ -241,7 +239,6 @@ class WorkflowBasedAppRunner(AppRunner):
predecessor_node_id=event.predecessor_node_id,
in_iteration_id=event.in_iteration_id,
parallel_mode_run_id=event.parallel_mode_run_id,
agent_strategy=event.agent_strategy,
)
)
elif isinstance(event, NodeRunSucceededEvent):
@ -376,20 +373,6 @@ class WorkflowBasedAppRunner(AppRunner):
retriever_resources=event.retriever_resources, in_iteration_id=event.in_iteration_id
)
)
elif isinstance(event, AgentLogEvent):
self._publish_event(
QueueAgentLogEvent(
id=event.id,
label=event.label,
node_execution_id=event.node_execution_id,
parent_id=event.parent_id,
error=event.error,
status=event.status,
data=event.data,
metadata=event.metadata,
node_id=event.node_id,
)
)
elif isinstance(event, ParallelBranchRunStartedEvent):
self._publish_event(
QueueParallelBranchRunStartedEvent(

View File

@ -183,7 +183,7 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity):
"""
node_id: str
inputs: Mapping
inputs: dict
single_iteration_run: Optional[SingleIterationRunEntity] = None

View File

@ -6,7 +6,7 @@ from typing import Any, Optional
from pydantic import BaseModel
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunMetadataKey
from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes import NodeType
from core.workflow.nodes.base import BaseNodeData
@ -41,7 +41,6 @@ class QueueEvent(StrEnum):
PARALLEL_BRANCH_RUN_STARTED = "parallel_branch_run_started"
PARALLEL_BRANCH_RUN_SUCCEEDED = "parallel_branch_run_succeeded"
PARALLEL_BRANCH_RUN_FAILED = "parallel_branch_run_failed"
AGENT_LOG = "agent_log"
ERROR = "error"
PING = "ping"
STOP = "stop"
@ -281,7 +280,6 @@ class QueueNodeStartedEvent(AppQueueEvent):
start_at: datetime
parallel_mode_run_id: Optional[str] = None
"""iteratoin run in parallel mode run id"""
agent_strategy: Optional[AgentNodeStrategyInit] = None
class QueueNodeSucceededEvent(AppQueueEvent):
@ -317,23 +315,6 @@ class QueueNodeSucceededEvent(AppQueueEvent):
iteration_duration_map: Optional[dict[str, float]] = None
class QueueAgentLogEvent(AppQueueEvent):
"""
QueueAgentLogEvent entity
"""
event: QueueEvent = QueueEvent.AGENT_LOG
id: str
label: str
node_execution_id: str
parent_id: str | None
error: str | None
status: str
data: Mapping[str, Any]
metadata: Optional[Mapping[str, Any]] = None
node_id: str
class QueueNodeRetryEvent(QueueNodeStartedEvent):
"""QueueNodeRetryEvent entity"""

Some files were not shown because too many files have changed in this diff Show More