Compare commits

...

306 Commits

Author SHA1 Message Date
dd7efb11ff Merge branch 'feat/memory-orchestration-be' into feat/memory-orchestration-be-dev-env 2025-10-16 17:15:31 +08:00
bfba8bec2d fix: workflow fields 2025-10-16 17:15:05 +08:00
8aa4db0c77 Merge branch 'feat/memory-orchestration-be' into feat/memory-orchestration-be-dev-env 2025-10-16 16:33:52 +08:00
65a3646ce7 fix: error handling with model validation 2025-10-16 16:33:33 +08:00
106f0fba0b Merge branch 'feat/memory-orchestration-be' into feat/memory-orchestration-be-dev-env
# Conflicts:
#	api/controllers/console/app/workflow.py
#	api/services/workflow_service.py
2025-10-15 17:00:15 +08:00
cb73335599 chore: run ruff 2025-10-15 16:57:18 +08:00
f4fa57dac9 fix: store memory_blocks in correct field 2025-10-15 16:56:12 +08:00
06f364f2c8 Merge branch 'feat/memory-orchestration-be' into feat/memory-orchestration-be-dev-env 2025-10-15 14:43:45 +08:00
7ca06931ec fix: unify memory variable in VariablePool 2025-10-15 14:39:05 +08:00
db83b54a88 Merge branch 'feat/memory-orchestration-be' into feat/memory-orchestration-be-dev-env 2025-10-11 16:27:58 +08:00
f4567fbf9e fix: fix circular ref 2025-10-11 16:27:40 +08:00
8fd088754a fix: fix circular ref 2025-10-11 16:16:51 +08:00
e7d63a9fa3 fix: fix circular ref 2025-10-11 16:15:32 +08:00
a1e3a72274 chore: add database migration file 2025-10-11 15:38:05 +08:00
1a4600ce77 Merge remote-tracking branch 'origin/deploy/dev' into feat/memory-orchestration-be-dev-env
# Conflicts:
#	api/models/__init__.py
#	api/uv.lock
2025-10-11 15:01:26 +08:00
61d9428064 refactor: fix basedpyright error 2025-10-10 18:47:16 +08:00
f6038a4557 Merge branch 'main' into feat/memory-orchestration-be 2025-10-10 18:43:59 +08:00
c09d205776 Update deploy-dev.yml 2025-10-10 15:55:35 +08:00
b4f9698289 add trial api 2025-10-10 14:15:56 +08:00
c9a2dc0b13 Merge branch 'feat/tax-text' into deploy/dev 2025-10-10 14:03:23 +08:00
b6114266af refactor(billing): enhance pricing footer layout with conditional classnames for better responsiveness 2025-10-10 14:03:02 +08:00
33b576b9d5 add trial api 2025-10-10 13:59:14 +08:00
63de2cc7a0 feat(billing): refactor pricing footer to conditionally display tax information based on category 2025-10-10 12:23:25 +08:00
c9c3aa0810 mr app-trial 2025-10-10 11:30:53 +08:00
dee4399060 feat(billing): refactor pricing footer to conditionally display tax information based on category 2025-10-10 10:49:51 +08:00
d3ea98037e Merge branch 'feat/tax-text' into deploy/dev 2025-10-10 10:01:00 +08:00
2c408445ff Merge remote-tracking branch 'origin/main' into deploy/dev 2025-10-10 10:00:50 +08:00
8d2b5c5464 feat(billing): refactor pricing footer to conditionally display tax information based on category 2025-10-10 10:00:29 +08:00
97b5d4bba1 Merge remote-tracking branch 'origin/main' into feat/tax-text 2025-10-10 09:59:57 +08:00
68c7e43d8c fix user uploaded avatar display incorrect 2025-10-10 09:27:50 +08:00
e0af930acd comment author avatar is the first avatar 2025-10-10 09:27:50 +08:00
dfc8bc4aec only can edit own replies 2025-10-10 09:27:50 +08:00
63b4bca7d8 fix missing i18n 2025-10-10 09:27:50 +08:00
517f8aafdc fix switch to cursor mode comment input still exists 2025-10-10 09:27:50 +08:00
d05ba90779 comment reply auto scroll down to bottom 2025-10-10 09:27:50 +08:00
b6620c1f42 fix comment hover the variable panel 2025-10-10 09:27:50 +08:00
b60e1f4222 Merge branch 'feat/tax-text' into deploy/dev 2025-10-09 18:28:27 +08:00
4df606f439 fix: merge main 2025-10-09 18:26:10 +08:00
2a5a497f15 Merge remote-tracking branch 'origin/main' into feat/tax-text 2025-10-09 18:08:32 +08:00
3f1da39aee feat(billing): add tax information tooltips in pricing footer. 2025-10-09 18:08:02 +08:00
a52edc6cc1 fix version not display 2025-10-09 15:08:28 +08:00
c367f80ec5 Merge branch 'main' into feat/memory-orchestration-be 2025-10-09 15:01:03 +08:00
da353a42da fix import error 2025-10-09 10:56:48 +08:00
9b0f172f91 Merge branch 'p284' into deploy/dev 2025-10-09 09:39:06 +08:00
85dfc013ea fix default comment icon 2025-10-09 09:23:36 +08:00
5a1fae1171 add leader session more check 2025-10-09 09:23:04 +08:00
33d4c95470 can update comment position 2025-10-05 10:17:04 +08:00
659cbc05a9 fix mention-input in the bottom of the browser 2025-10-04 21:24:27 +08:00
6ce65de2cd fix merged main issues 2025-10-04 21:11:59 +08:00
93b2eb3ff6 Merge remote-tracking branch 'myori/main' into p284 2025-10-04 15:28:29 +08:00
bf71300635 improve comment cursor move 2025-10-04 14:36:10 +08:00
37ecd4a0bc fix @ input problem 2025-10-04 13:39:00 +08:00
827a1b181b fix comment icon position 2025-10-04 13:25:59 +08:00
c4e7cb75cd cache the mentioned users 2025-10-04 11:22:02 +08:00
98e4bfcda8 click comment icon not switch to comment mode 2025-10-03 23:36:56 +08:00
ee48ca7671 fix default comment icon 2025-09-30 15:23:43 +08:00
4ba6de1116 add leader session more check 2025-09-29 14:01:42 +08:00
bfbe636555 fix docker file websocket mode 2025-09-29 13:35:10 +08:00
930fdc8fb4 fix plugin detail panel display in tool list 2025-09-29 11:23:44 +08:00
fc90a8fb32 fix plugin detail panel display in tool list 2025-09-29 10:46:55 +08:00
791f33fd0b Merge branch 'main' into feat/memory-orchestration-be 2025-09-28 22:41:24 +08:00
1e0a3b163e refactor: fix ruff 2025-09-28 22:41:07 +08:00
bb1f1a56a5 feat: update MemoryListApi response format with ChatflowConversationMetadata 2025-09-28 22:36:10 +08:00
15be85514d fix: chatflow message visibility from index 2025-09-28 21:20:37 +08:00
80cb85b845 fix docker file websocket mode 2025-09-26 15:18:10 +08:00
0b131f1a8c [autofix.ci] apply automated fixes 2025-09-26 06:44:39 +00:00
e51f2d68cb fix: create paid provider auto 2025-09-26 14:31:53 +08:00
a6d4bf3399 sync children node data 2025-09-26 14:16:38 +08:00
113aa4ae08 fix add child node resize parent node size 2025-09-26 14:16:38 +08:00
5b40bf6d4e http node data sync 2025-09-26 14:16:38 +08:00
06ad8efd89 sync the prompt editor 2025-09-26 14:16:38 +08:00
8513afbcf6 fix opened panel be affected 2025-09-26 14:16:38 +08:00
54ae43ef47 sync children node data 2025-09-26 14:07:34 +08:00
7a74b5ee3e fix add child node resize parent node size 2025-09-26 14:04:50 +08:00
0641773395 fix: add paid quota error for init_anthropic 2025-09-26 13:11:14 +08:00
d745c2e8e3 add paid credit 2025-09-26 12:49:35 +08:00
e974c696f7 add paid credit 2025-09-26 12:49:26 +08:00
2ff280c4bf add credit pool sys 2025-09-26 11:18:28 +08:00
0e9d43d605 http node data sync 2025-09-26 11:13:20 +08:00
cc54363c27 sync the prompt editor 2025-09-26 10:48:00 +08:00
c41140d654 add credit pool sys 2025-09-26 10:45:24 +08:00
89affe3139 fix opened panel be affected 2025-09-26 09:20:33 +08:00
f7fd065bee fix pnpm lock 2025-09-25 17:05:20 +08:00
96c7c86e9d Merge branch 'p284' into deploy/dev 2025-09-25 16:58:45 +08:00
2c4977dbb1 fix bug 2025-09-25 16:56:06 +08:00
e240175116 sync nodes 2025-09-25 16:31:46 +08:00
2398ed6fe8 fix update env api update time error 2025-09-25 16:28:33 +08:00
a8420ac33c add fragment to prevent list missing key 2025-09-25 09:52:08 +08:00
8470be6411 improve delete comment i18n 2025-09-25 09:41:59 +08:00
3d6295c622 refactor delete comment and reply 2025-09-25 09:35:46 +08:00
ff2f7206f3 bump nextjs to 15.5 and turbopack for development mode (#24346)
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: 非法操作 <hjlarry@163.com>
2025-09-25 09:10:09 +08:00
b937fc8978 app online user list 2025-09-24 17:03:33 +08:00
86a9a51952 add comment preview 2025-09-24 12:54:54 +08:00
4188c9a1dd fix dark theme 2025-09-24 10:08:33 +08:00
8833fee232 feat: move version update logic out of save_memory 2025-09-23 23:17:34 +08:00
5bf642c3f9 feat: expose version to MemoryBlock 2025-09-23 23:09:45 +08:00
8c00f89e36 add icon to zoom2fit 2025-09-23 22:22:28 +08:00
9e8ac5c96b refactor cursor and add hide comment 2025-09-23 22:13:02 +08:00
3d7d4182a6 feat: add endpoints to delete memory 2025-09-23 19:07:37 +08:00
75c221038d feat: add endpoints to __init__.py 2025-09-23 18:35:11 +08:00
b7b5b0b8d0 Merge branch 'main' into feat/memory-orchestration-be 2025-09-23 17:43:52 +08:00
05a67f4716 add display/hide collaborator cursors 2025-09-23 17:37:40 +08:00
6eab6a675c feat: add created_by to memory blocks 2025-09-23 17:35:36 +08:00
f49476a206 add show/hide minimap 2025-09-23 17:20:41 +08:00
c1e9c56e25 fix style 2025-09-23 17:19:36 +08:00
d5dd73cacf add i18n for comment 2025-09-23 16:19:04 +08:00
21f7a49b4e fix restore page crash 2025-09-23 15:44:57 +08:00
716ac04e13 add comment shortcut 2025-09-23 15:40:53 +08:00
c28a32fc47 fix handleModeComment 2025-09-23 15:35:28 +08:00
31cba28e8a improve comment cursor icon 2025-09-23 15:28:22 +08:00
48cd7e6481 input comment should not cancel comment mode 2025-09-23 14:48:31 +08:00
47aba1c9f9 fix style 2025-09-23 14:41:34 +08:00
d94e598a89 revert: remove memory database migration 2025-09-23 14:19:40 +08:00
0f3f8bc0d9 make mention input can display name different color 2025-09-23 11:38:38 +08:00
e0df12c212 fix mentioned names color 2025-09-23 11:24:17 +08:00
eb448d9bb8 fix avatar background color 2025-09-23 11:09:02 +08:00
0ba77f13db fix avatar inset 2025-09-23 10:46:18 +08:00
f0a2eb843c fix user cursor should not over the panel 2025-09-23 10:35:16 +08:00
28acb70118 feat: add edited_by_user field 2025-09-22 18:37:54 +08:00
7c35aaa99d refactor: remove MemoryBlockWithVisibility 2025-09-22 18:16:37 +08:00
82fcf1da64 fix pnpm lock 2025-09-22 18:09:00 +08:00
c662b95b80 Merge branch 'p284' into deploy/dev 2025-09-22 18:03:09 +08:00
a8c2a300f6 refactor: make memories API return MemoryBlock 2025-09-22 17:14:07 +08:00
d654d9d8b1 refactor: make ChatflowMemoryVariable.value JSON 2025-09-22 16:46:39 +08:00
eedc0ca6ea fix: add marshal app model to json 2025-09-22 16:25:53 +08:00
e4c3213978 fix: add marshal app model to json 2025-09-22 16:23:55 +08:00
520bc55da5 fix: add marshal app model to json 2025-09-22 16:23:42 +08:00
ebbb82178f fix: add marshal app model to json 2025-09-22 16:16:18 +08:00
d60a9ac63a fix: add marshal site model to json 2025-09-22 16:16:03 +08:00
f1486224e9 fix: add marshal app model to json 2025-09-22 16:14:28 +08:00
88b02d5a07 fix: add marshal site model to json 2025-09-22 15:47:23 +08:00
394b7d09b8 refactor: fix basedpyright/ruff errors 2025-09-22 15:17:19 +08:00
e9313b9c1b Merge branch 'main' into feat/memory-orchestration-be
# Conflicts:
#	api/core/app/apps/advanced_chat/app_runner.py
#	api/core/workflow/constants.py
#	api/core/workflow/entities/variable_pool.py
#	api/core/workflow/nodes/llm/node.py
#	api/models/workflow.py
2025-09-22 14:46:30 +08:00
5cf3d9e4d9 fix nginx config 2025-09-22 14:21:07 +08:00
41958f55cd fix CSP 2025-09-22 14:20:11 +08:00
600ad232e1 fix config 2025-09-22 14:20:11 +08:00
7a3825cfce fix docker config 2025-09-22 14:20:11 +08:00
9519653422 change default ws url 2025-09-22 14:20:11 +08:00
efa2307c73 change default ws url 2025-09-22 14:20:11 +08:00
068fa3d0e3 fix CI 2025-09-22 14:20:11 +08:00
13d8dbd542 fix CI 2025-09-22 14:20:08 +08:00
e59cc3311d add: trial api and trial table 2025-09-22 13:42:22 +08:00
4f4e94f753 add gevent-websocket 2025-09-22 13:33:00 +08:00
258970f489 chore(deps): bump boto3-stubs from 1.40.29 to 1.40.35 in /api (#26014)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-09-22 12:44:14 +08:00
781abe5f8f add: trial api and trial table 2025-09-22 10:44:08 +08:00
b442ba8b2b fix UserAvatarList background color 2025-09-19 12:07:07 +08:00
10e36d2355 add avatar on canvas node 2025-09-19 10:43:28 +08:00
13c53fedad add avatar display on node 2025-09-19 10:07:01 +08:00
4bda1bd884 open node panel not affect others 2025-09-18 17:42:02 +08:00
3abe7850d6 fix migration file 2025-09-18 16:30:40 +08:00
b50284d864 fix merge problem 2025-09-18 15:45:53 +08:00
81c6e52401 Merge remote-tracking branch 'origin/p254' into p284 2025-09-18 15:14:55 +08:00
847d257366 Merge branch 'p254' into p284 2025-09-18 14:50:59 +08:00
687662cf1f comment sync 2025-09-18 13:27:27 +08:00
6432d98469 improve the icon display on canvas 2025-09-18 11:49:43 +08:00
088ccf8b8d add UserAvatarList component 2025-09-18 09:47:07 +08:00
e8683bf957 fix comment cursor position 2025-09-18 09:17:45 +08:00
4653981b6b not display more icon when in edit mode 2025-09-17 20:45:54 +08:00
e2547413d3 fix edit input mouse pos 2025-09-17 20:40:59 +08:00
ea17f41b5b refactor reply code 2025-09-17 20:29:23 +08:00
29178d8adf can edit and delete a reply 2025-09-17 17:44:09 +08:00
7e86ead574 upgrade style 2025-09-17 16:41:10 +08:00
72debcb228 refactor mention input 2025-09-17 16:28:47 +08:00
72737dabc7 fix at can't click bug 2025-09-17 14:50:05 +08:00
f6e5cb4381 improve comment detail 2025-09-17 14:34:36 +08:00
ffad3b5fb1 comment detail window fix height 2025-09-17 13:45:56 +08:00
cba9fc3020 add comment reply 2025-09-17 12:50:42 +08:00
e776accaf3 add top operation buttons of comment detail 2025-09-17 10:45:15 +08:00
3eac26929a sync the comment panel and canvas 2025-09-17 09:13:31 +08:00
4d3adec738 click canvas icon display the active comment detail 2025-09-17 09:01:16 +08:00
ac5dd1f45a refactor: update MemoryApi(Resource) for version 2025-09-16 19:25:17 +08:00
3005cf3282 refactor: update MemoryApi(WebApiResource) for version 2025-09-16 19:12:08 +08:00
54b272206e refactor: add version param to get_session_memories and get_persistent_memories 2025-09-16 18:32:58 +08:00
89bed479e4 improve comment panel 2025-09-16 17:25:51 +08:00
fdd673a3a9 improve comments panel 2025-09-16 13:39:31 +08:00
22f6d285c7 fix comment cursor in panel incorrect 2025-09-16 10:20:12 +08:00
10aa16b471 add workflow comment panel 2025-09-16 09:51:12 +08:00
3d761a3189 refactor: make save_memory and get_memory_by_spec work on latest version 2025-09-15 19:28:22 +08:00
e3903f34e4 refactor: add version field to ChatflowMemoryVariable table 2025-09-15 19:27:41 +08:00
f4f055fb36 refactor: add version field to MemoryBlockWithVisibility 2025-09-15 19:27:17 +08:00
b3838581fd improve mention 2025-09-15 17:13:46 +08:00
affbe7ccdb can mention user in the create comment 2025-09-15 16:42:31 +08:00
8563ae5511 feat: add inference for VersionedMemory type when deserializing 2025-09-15 16:13:07 +08:00
2c765ccfae refactor: use VersionedMemoryVariable in ChatflowMemoryService.get_memory_by_spec 2025-09-15 15:47:02 +08:00
626e7b2211 refactor: use VersionedMemoryVariable in ChatflowMemoryService.save_memory 2025-09-15 15:41:33 +08:00
516b6b0fa8 refactor: use VersionedMemoryVariable in creation of WorkflowDraftVariable instead of StringVariable 2025-09-15 15:39:38 +08:00
613d086f1e refactor: give VersionedMemoryValue a default value 2025-09-15 15:38:20 +08:00
9e0630f012 fix: use correct description from spec 2025-09-15 15:30:08 +08:00
d6d9554954 fix: fix basedpyright errors 2025-09-15 14:20:30 +08:00
dd8577f832 comments display on canvas 2025-09-15 14:16:06 +08:00
2a532ab729 Merge branch 'main' into feat/memory-orchestration-be
# Conflicts:
#	api/core/app/apps/advanced_chat/app_runner.py
#	api/core/prompt/entities/advanced_prompt_entities.py
#	api/core/variables/segments.py
2025-09-15 14:14:56 +08:00
03eef65b25 feat: add VersionedMemorySegment and VersionedMemoryVariable 2025-09-15 14:00:54 +08:00
ad07d63994 feat: add VersionedMemoryValueModel 2025-09-15 14:00:54 +08:00
8685f055ea fix: use model parameters from memory_spec in llm_generator 2025-09-15 14:00:54 +08:00
3b868a1cec feat: integrate VariablePool into memory update process 2025-09-15 14:00:53 +08:00
ab389eaa8e fix: fix ruff 2025-09-15 14:00:53 +08:00
008f778e8f fix: fix mypy 2025-09-15 14:00:53 +08:00
d7f5da5df4 display comments avatar on the canvas 2025-09-15 11:41:06 +08:00
9fda130b3a fix click comment once more then esc not work 2025-09-15 11:11:07 +08:00
72cdbdba0f fix chat input style 2025-09-15 09:20:06 +08:00
b92a153902 refactor code 2025-09-14 13:03:08 +08:00
9f2927979b fix comment cursor icon 2025-09-14 12:50:18 +08:00
75257232c3 add create comment frontend 2025-09-14 12:10:37 +08:00
1721314c62 add frontend comment service 2025-09-13 17:57:19 +08:00
fc230bcc59 add force update workflow to support restore 2025-09-12 16:27:12 +08:00
b4636ddf44 add leader restore workflow 2025-09-12 15:34:41 +08:00
b1140301a4 sync import dsl 2025-09-12 14:46:40 +08:00
58cd785da6 use const for cursor move config 2025-09-11 09:36:22 +08:00
2035186cd2 click avatar to follow user cursor position 2025-09-11 09:26:05 +08:00
53ba6aadff cursor pos transform to canvas 2025-09-11 09:07:03 +08:00
f091868b7c use new get avatar api 2025-09-10 15:15:43 +08:00
89bedae0d3 remove the test code for develop collaboration 2025-09-10 14:27:20 +08:00
c8acc48976 ruff format 2025-09-10 14:25:37 +08:00
21fee59b22 use new features update api 2025-09-10 14:24:38 +08:00
957a8253f8 change user list to conversation var panel left 2025-09-10 09:26:38 +08:00
d5fc3e7bed add new conversation vars update api 2025-09-10 09:24:22 +08:00
ab438b42da use new env variables update api 2025-09-10 09:07:55 +08:00
3867fece4a mcp server update 2025-09-09 15:01:38 +08:00
2b908d4fbe add app state update 2025-09-09 14:24:37 +08:00
8ff062ec8b change user default color 2025-09-09 10:20:02 +08:00
294fc41aec add redo undo manager of CRDT 2025-09-09 09:58:55 +08:00
684f7df158 node data use crdt data 2025-09-08 14:46:28 +08:00
c3287755e3 add request leader to sync graph 2025-09-08 09:00:20 +08:00
9f97f4d79e fix cursor style 2025-09-06 15:54:19 +08:00
34eb421649 add currentUserId is me 2025-09-06 12:27:54 +08:00
850b05573e add dropdown users list 2025-09-06 12:01:49 +08:00
6ec8bfdfee add mouse over avatar display username 2025-09-06 11:29:45 +08:00
81638c248e use one getUserColor func 2025-09-06 11:22:59 +08:00
2e11b1298e add online users avatar 2025-09-06 11:19:47 +08:00
20320f3a27 show online users on the canvas 2025-09-06 00:08:17 +08:00
4019c12d26 fix missing import 2025-09-05 22:20:07 +08:00
cf72184ce4 each browser tab session a ws connected obj 2025-09-05 22:19:16 +08:00
ca8d15bc64 add mention user list api 2025-08-31 13:42:59 +08:00
a91c897fd3 improve code 2025-08-31 00:43:34 +08:00
816bdf0320 add delete comment and reply 2025-08-31 00:28:01 +08:00
d4a6acbd99 add update reply 2025-08-30 23:49:27 +08:00
e421db4005 add resolve comment 2025-08-30 22:37:01 +08:00
6af168cb31 Merge branch 'main' into feat/memory-orchestration-be 2025-08-25 14:54:14 +08:00
29f56cf0cf chore: add database migration 2025-08-22 21:07:54 +08:00
11b6ea742d feat: add index for data tables 2025-08-22 20:43:49 +08:00
05d231ad33 fix: fix bugs check by Claude Code 2025-08-22 19:59:17 +08:00
48f3c69c69 fix: fix bugs check by Claude Code 2025-08-22 17:54:18 +08:00
9067c2a9c1 add update comment 2025-08-22 17:48:14 +08:00
8b68020453 refactor: refactor from ChatflowHistoryService and ChatflowMemoryService 2025-08-22 17:44:27 +08:00
9f7321ca1a add create reply 2025-08-22 17:33:47 +08:00
5fa01132b9 add create and list comment api 2025-08-22 16:47:08 +08:00
4d2fc66a8d feat: refactor: refactor from ChatflowHistoryService and ChatflowMemoryService 2025-08-22 15:33:45 +08:00
f72ed4898c refactor: refactor from ChatflowHistoryService and ChatflowMemoryService 2025-08-22 14:57:27 +08:00
e082b6d599 add workflow comment models 2025-08-22 11:28:26 +08:00
d44be2d835 add leader submit graph data 2025-08-21 17:53:39 +08:00
85a73181cc chore: run ruff 2025-08-21 17:23:24 +08:00
e31e4ab677 feat: add Service API for memory read and modify 2025-08-21 17:22:39 +08:00
0d95c2192e feat: add Web API for memory read and modify 2025-08-21 17:17:08 +08:00
7dc8557033 add Leader election 2025-08-21 16:17:16 +08:00
1fa8b26e55 feat: fetch memory block from WorkflowDraftVariable when debugging single node 2025-08-21 15:17:25 +08:00
4b085d46f6 feat: update variable pool when update memory 2025-08-21 15:15:23 +08:00
72037a1865 improve cursors logic 2025-08-21 14:27:41 +08:00
635c4ed4ce feat: add memory update check in AdvancedChatAppRunner 2025-08-21 14:24:17 +08:00
7ffcf8dd6f feat: add memory update check in AdvancedChatAppRunner 2025-08-21 13:27:00 +08:00
97cd21d3be feat: sync conversation history with chatflow_ tables in chatflow 2025-08-21 13:03:19 +08:00
a13cb7e1c5 feat: init memory block for VariablePool in AdvancedChatAppRunner.run 2025-08-21 11:40:30 +08:00
7b602e9003 feat: wait for sync memory update in AdvancedChatAppRunner.run 2025-08-21 11:32:27 +08:00
5a26ebec8f feat: add _fetch_memory_blocks for AdvancedChatAppRunner 2025-08-21 11:28:47 +08:00
8341b8b1c1 feat: add MemoryBlock config to LLM's memory config 2025-08-20 19:53:44 +08:00
bbb640c9a2 feat: add MemoryBlock to VariablePool 2025-08-20 19:45:18 +08:00
0c97bbf137 chore: run ruff 2025-08-20 19:12:34 +08:00
45fddc70d5 feat: add ChatflowHistoryService and ChatflowMemoryService 2025-08-20 19:11:12 +08:00
f977dc410a feat: add MemorySyncTimeoutError 2025-08-20 17:45:53 +08:00
d535818505 feat: add new_memory_block_variable for WorkflowDraftVariable 2025-08-20 17:41:45 +08:00
fcf4e1f37d feat: add MEMORY_BLOCK_VARIABLE_NODE_ID 2025-08-20 17:41:13 +08:00
38130c8502 feat: add memory_blocks property to workflow's graph for memory block configuration 2025-08-20 17:19:48 +08:00
f284c91988 feat: add data tables for chatflow memory 2025-08-20 17:16:54 +08:00
584b2cefa3 feat: add pydantic models for memory 2025-08-20 17:03:15 +08:00
42091b4a79 feat: add MEMORY_BLOCK in DraftVariableType 2025-08-20 16:51:07 +08:00
2d1621c43d add leader but not review 2025-08-08 14:54:18 +08:00
d1a5db3310 rm useCollaborativeCursors component 2025-08-07 18:03:12 +08:00
ad8fd8fecc clone the node to avoid loro recursive 2025-08-07 17:45:38 +08:00
be74b76079 refactor websocket init 2025-08-07 17:31:12 +08:00
dd64af728f refactor the cursors component 2025-08-07 14:29:23 +08:00
e43b46786d refactor all the frontend code 2025-08-07 10:58:53 +08:00
3f3b37b843 refactor to support multi websocket connections 2025-08-06 17:05:39 +08:00
2ecf9f6ddf add features collaboration 2025-08-06 10:58:32 +08:00
48c069fe68 support env vars collaborate 2025-08-05 15:22:22 +08:00
9c5c597c85 support empty collaboration event data 2025-08-05 15:21:41 +08:00
c2eec8545d collaborate conversation vars 2025-08-05 14:24:51 +08:00
2395d4be26 fix imported updates also broadcast to other clients 2025-08-05 10:21:22 +08:00
9455476705 handle edge delete 2025-08-04 14:17:59 +08:00
494e223706 some operations don't need to broadcast 2025-08-03 14:18:48 +08:00
348fd18230 refactor collaboration 2025-08-03 13:34:07 +08:00
7233b4de55 the initial data to collaboration store 2025-07-31 16:27:01 +08:00
af6df05685 add setNodes and setEdges of collaboration store 2025-07-31 15:25:50 +08:00
965b65db6e use loro for crdt data 2025-07-31 14:02:53 +08:00
4cc01c8aa8 try a lot for yjs, but update data still not work... 2025-07-30 14:36:29 +08:00
41372168b6 refactor code 2025-07-23 10:04:16 +08:00
f4438b0a08 support mouse display 2025-07-22 18:08:35 +08:00
897c842637 ruff format 2025-07-21 16:13:04 +08:00
ee86ceb906 fix gunicorn gevent 2025-07-21 16:09:51 +08:00
e298732499 refactor code 2025-07-21 16:07:22 +08:00
4081937e22 migrate to python-socketio 2025-07-21 14:57:28 +08:00
f9aedb2118 add collaborate event 2025-07-21 11:10:23 +08:00
74b4719af8 support broadcast online users 2025-07-18 15:02:34 +08:00
2f35cc9188 add online users backend api and frontend submit cursor pos 2025-07-18 11:17:08 +08:00
2f966d8c38 fix websocket auth 2025-07-17 17:16:52 +08:00
b0868d9136 fix websocket auth 2025-07-17 17:16:38 +08:00
37440e9416 ruff format 2025-07-17 15:37:13 +08:00
0d7d27ec0b establish websocket connection 2025-07-17 15:36:50 +08:00
164 changed files with 13844 additions and 3532 deletions

View File

@ -1,3 +1,4 @@
import os
import sys
@ -8,10 +9,16 @@ def is_db_command():
# create app
celery = None
flask_app = None
socketio_app = None
if is_db_command():
from app_factory import create_migrations_app
app = create_migrations_app()
socketio_app = app
flask_app = app
else:
# It seems that JetBrains Python debugger does not work well with gevent,
# so we need to disable gevent in debug mode.
@ -33,8 +40,15 @@ else:
from app_factory import create_app
app = create_app()
celery = app.extensions["celery"]
socketio_app, flask_app = create_app()
app = flask_app
celery = flask_app.extensions["celery"]
if __name__ == "__main__":
app.run(host="0.0.0.0", port=5001)
from gevent import pywsgi
from geventwebsocket.handler import WebSocketHandler
host = os.environ.get("HOST", "0.0.0.0")
port = int(os.environ.get("PORT", 5001))
server = pywsgi.WSGIServer((host, port), socketio_app, handler_class=WebSocketHandler)
server.serve_forever()

View File

@ -31,14 +31,22 @@ def create_flask_app_with_configs() -> DifyApp:
return dify_app
def create_app() -> DifyApp:
def create_app() -> tuple[any, DifyApp]:
start_time = time.perf_counter()
app = create_flask_app_with_configs()
initialize_extensions(app)
import socketio
from extensions.ext_socketio import sio
sio.app = app
socketio_app = socketio.WSGIApp(sio, app)
end_time = time.perf_counter()
if dify_config.DEBUG:
logger.info("Finished create_app (%s ms)", round((end_time - start_time) * 1000, 2))
return app
return socketio_app, app
def initialize_extensions(app: DifyApp):

View File

@ -836,6 +836,16 @@ class MailConfig(BaseSettings):
default=None,
)
ENABLE_TRIAL_APP: bool = Field(
description="Enable trial app",
default=False,
)
ENABLE_EXPLORE_BANNER: bool = Field(
description="Enable explore banner",
default=False,
)
class RagEtlConfig(BaseSettings):
"""

View File

@ -8,6 +8,11 @@ class HostedCreditConfig(BaseSettings):
default="",
)
HOSTED_POOL_CREDITS: int = Field(
description="Pool credits for hosted service",
default=200,
)
def get_model_credits(self, model_name: str) -> int:
"""
Get credit value for a specific model name.
@ -70,11 +75,6 @@ class HostedOpenAiConfig(BaseSettings):
"text-davinci-003",
)
HOSTED_OPENAI_QUOTA_LIMIT: NonNegativeInt = Field(
description="Quota limit for hosted OpenAI service usage",
default=200,
)
HOSTED_OPENAI_PAID_ENABLED: bool = Field(
description="Enable paid access to hosted OpenAI service",
default=False,
@ -98,6 +98,129 @@ class HostedOpenAiConfig(BaseSettings):
)
class HostedGeminiConfig(BaseSettings):
"""
Configuration for fetching Gemini service
"""
HOSTED_GEMINI_API_KEY: str | None = Field(
description="API key for hosted Gemini service",
default=None,
)
HOSTED_GEMINI_API_BASE: str | None = Field(
description="Base URL for hosted Gemini API",
default=None,
)
HOSTED_GEMINI_API_ORGANIZATION: str | None = Field(
description="Organization ID for hosted Gemini service",
default=None,
)
HOSTED_GEMINI_TRIAL_ENABLED: bool = Field(
description="Enable trial access to hosted Gemini service",
default=False,
)
HOSTED_GEMINI_TRIAL_MODELS: str = Field(
description="Comma-separated list of available models for trial access",
default="gemini-2.5-flash,gemini-2.0-flash,gemini-2.0-flash-lite,",
)
HOSTED_GEMINI_PAID_ENABLED: bool = Field(
description="Enable paid access to hosted gemini service",
default=False,
)
HOSTED_GEMINI_PAID_MODELS: str = Field(
description="Comma-separated list of available models for paid access",
default="gemini-2.5-flash,gemini-2.0-flash,gemini-2.0-flash-lite,",
)
class HostedXAIConfig(BaseSettings):
"""
Configuration for fetching XAI service
"""
HOSTED_XAI_API_KEY: str | None = Field(
description="API key for hosted XAI service",
default=None,
)
HOSTED_XAI_API_BASE: str | None = Field(
description="Base URL for hosted XAI API",
default=None,
)
HOSTED_XAI_API_ORGANIZATION: str | None = Field(
description="Organization ID for hosted XAI service",
default=None,
)
HOSTED_XAI_TRIAL_ENABLED: bool = Field(
description="Enable trial access to hosted XAI service",
default=False,
)
HOSTED_XAI_TRIAL_MODELS: str = Field(
description="Comma-separated list of available models for trial access",
default="grok-3,grok-3-mini,grok-3-mini-fast",
)
HOSTED_XAI_PAID_ENABLED: bool = Field(
description="Enable paid access to hosted XAI service",
default=False,
)
HOSTED_XAI_PAID_MODELS: str = Field(
description="Comma-separated list of available models for paid access",
default="grok-3,grok-3-mini,grok-3-mini-fast",
)
class HostedDeepseekConfig(BaseSettings):
"""
Configuration for fetching Deepseek service
"""
HOSTED_DEEPSEEK_API_KEY: str | None = Field(
description="API key for hosted Deepseek service",
default=None,
)
HOSTED_DEEPSEEK_API_BASE: str | None = Field(
description="Base URL for hosted Deepseek API",
default=None,
)
HOSTED_DEEPSEEK_API_ORGANIZATION: str | None = Field(
description="Organization ID for hosted Deepseek service",
default=None,
)
HOSTED_DEEPSEEK_TRIAL_ENABLED: bool = Field(
description="Enable trial access to hosted Deepseek service",
default=False,
)
HOSTED_DEEPSEEK_TRIAL_MODELS: str = Field(
description="Comma-separated list of available models for trial access",
default="deepseek-chat,deepseek-reasoner",
)
HOSTED_DEEPSEEK_PAID_ENABLED: bool = Field(
description="Enable paid access to hosted XAI service",
default=False,
)
HOSTED_DEEPSEEK_PAID_MODELS: str = Field(
description="Comma-separated list of available models for paid access",
default="grok-3,grok-3-mini,grok-3-mini-fast",
)
class HostedAzureOpenAiConfig(BaseSettings):
"""
Configuration for hosted Azure OpenAI service
@ -144,16 +267,32 @@ class HostedAnthropicConfig(BaseSettings):
default=False,
)
HOSTED_ANTHROPIC_QUOTA_LIMIT: NonNegativeInt = Field(
description="Quota limit for hosted Anthropic service usage",
default=600000,
)
HOSTED_ANTHROPIC_PAID_ENABLED: bool = Field(
description="Enable paid access to hosted Anthropic service",
default=False,
)
HOSTED_ANTHROPIC_TRIAL_MODELS: str = Field(
description="Comma-separated list of available models for paid access",
default="claude-opus-4-20250514,"
"claude-opus-4-20250514,"
"claude-sonnet-4-20250514,"
"claude-3-5-haiku-20241022,"
"claude-3-opus-20240229,"
"claude-3-7-sonnet-20250219,"
"claude-3-haiku-20240307",
)
HOSTED_ANTHROPIC_PAID_MODELS: str = Field(
description="Comma-separated list of available models for paid access",
default="claude-opus-4-20250514,"
"claude-opus-4-20250514,"
"claude-sonnet-4-20250514,"
"claude-3-5-haiku-20241022,"
"claude-3-opus-20240229,"
"claude-3-7-sonnet-20250219,"
"claude-3-haiku-20240307",
)
class HostedMinmaxConfig(BaseSettings):
"""
@ -250,5 +389,8 @@ class HostedServiceConfig(
HostedModerationConfig,
# credit config
HostedCreditConfig,
HostedGeminiConfig,
HostedXAIConfig,
HostedDeepseekConfig,
):
pass

View File

@ -58,11 +58,13 @@ from .app import (
mcp_server,
message,
model_config,
online_user,
ops_trace,
site,
statistic,
workflow,
workflow_app_log,
workflow_comment,
workflow_draft_variable,
workflow_run,
workflow_statistic,
@ -106,10 +108,12 @@ from .datasets.rag_pipeline import (
# Import explore controllers
from .explore import (
banner,
installed_app,
parameter,
recommended_app,
saved_message,
trial,
)
# Import tag controllers
@ -143,6 +147,7 @@ __all__ = [
"apikey",
"app",
"audio",
"banner",
"billing",
"bp",
"completion",
@ -196,6 +201,7 @@ __all__ = [
"statistic",
"tags",
"tool_providers",
"trial",
"version",
"website",
"workflow",

View File

@ -15,7 +15,7 @@ from constants.languages import supported_language
from controllers.console import api, console_ns
from controllers.console.wraps import only_edition_cloud
from extensions.ext_database import db
from models.model import App, InstalledApp, RecommendedApp
from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp
def admin_required(view: Callable[P, R]):
@ -61,6 +61,8 @@ class InsertExploreAppListApi(Resource):
"language": fields.String(required=True, description="Language code"),
"category": fields.String(required=True, description="App category"),
"position": fields.Integer(required=True, description="Display position"),
"can_trial": fields.Boolean(required=True, description="Can trial"),
"trial_limit": fields.Integer(required=True, description="Trial limit"),
},
)
)
@ -79,6 +81,8 @@ class InsertExploreAppListApi(Resource):
parser.add_argument("language", type=supported_language, required=True, nullable=False, location="json")
parser.add_argument("category", type=str, required=True, nullable=False, location="json")
parser.add_argument("position", type=int, required=True, nullable=False, location="json")
parser.add_argument("can_trial", type=bool, required=True, nullable=False, location="json")
parser.add_argument("trial_limit", type=int, required=True, nullable=False, location="json")
args = parser.parse_args()
app = db.session.execute(select(App).where(App.id == args["app_id"])).scalar_one_or_none()
@ -115,6 +119,20 @@ class InsertExploreAppListApi(Resource):
)
db.session.add(recommended_app)
if args["can_trial"]:
trial_app = db.session.execute(
select(TrialApp).where(TrialApp.app_id == args["app_id"])
).scalar_one_or_none()
if not trial_app:
db.session.add(
TrialApp(
app_id=args["app_id"],
tenant_id=app.tenant_id,
trial_limit=args["trial_limit"],
)
)
else:
trial_app.trial_limit = args["trial_limit"]
app.is_public = True
db.session.commit()
@ -129,6 +147,20 @@ class InsertExploreAppListApi(Resource):
recommended_app.category = args["category"]
recommended_app.position = args["position"]
if args["can_trial"]:
trial_app = db.session.execute(
select(TrialApp).where(TrialApp.app_id == args["app_id"])
).scalar_one_or_none()
if not trial_app:
db.session.add(
TrialApp(
app_id=args["app_id"],
tenant_id=app.tenant_id,
trial_limit=args["trial_limit"],
)
)
else:
trial_app.trial_limit = args["trial_limit"]
app.is_public = True
db.session.commit()
@ -174,7 +206,67 @@ class InsertExploreAppApi(Resource):
for installed_app in installed_apps:
session.delete(installed_app)
trial_app = session.execute(
select(TrialApp).where(TrialApp.app_id == recommended_app.app_id)
).scalar_one_or_none()
if trial_app:
session.delete(trial_app)
db.session.delete(recommended_app)
db.session.commit()
return {"result": "success"}, 204
@console_ns.route("/admin/insert-explore-banner")
class InsertExploreBanner(Resource):
@api.doc("insert_explore_banner")
@api.doc(description="Insert an explore banner")
@api.expect(
api.model(
"InsertExploreBannerRequest",
{
"content": fields.String(required=True, description="Banner content"),
"link": fields.String(required=True, description="Banner link"),
"sort": fields.Integer(required=True, description="Banner sort"),
},
)
)
@api.response(200, "Banner inserted successfully")
@admin_required
@only_edition_cloud
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("content", type=str, required=True, nullable=False, location="json")
parser.add_argument("link", type=str, required=True, nullable=False, location="json")
parser.add_argument("sort", type=int, required=True, nullable=False, location="json")
args = parser.parse_args()
banner = ExporleBanner(
content=args["content"],
link=args["link"],
sort=args["sort"],
)
db.session.add(banner)
db.session.commit()
return {"result": "success"}, 200
@console_ns.route("/admin/delete-explore-banner/<uuid:banner_id>")
class DeleteExploreBanner(Resource):
@api.doc("delete_explore_banner")
@api.doc(description="Delete an explore banner")
@api.response(204, "Banner deleted successfully")
@admin_required
@only_edition_cloud
def delete(self, banner_id):
banner = db.session.execute(select(ExporleBanner).where(ExporleBanner.id == banner_id)).scalar_one_or_none()
if not banner:
raise NotFound(f"Banner '{banner_id}' is not found")
db.session.delete(banner)
db.session.commit()
return {"result": "success"}, 204

View File

@ -0,0 +1,291 @@
import json
import time
from extensions.ext_redis import redis_client
from extensions.ext_socketio import sio
from libs.passport import PassportService
from services.account_service import AccountService
@sio.on("connect")
def socket_connect(sid, environ, auth):
"""
WebSocket connect event, do authentication here.
"""
token = None
if auth and isinstance(auth, dict):
token = auth.get("token")
if not token:
return False
try:
decoded = PassportService().verify(token)
user_id = decoded.get("user_id")
if not user_id:
return False
with sio.app.app_context():
user = AccountService.load_logged_in_account(account_id=user_id)
if not user:
return False
sio.save_session(sid, {"user_id": user.id, "username": user.name, "avatar": user.avatar})
return True
except Exception:
return False
@sio.on("user_connect")
def handle_user_connect(sid, data):
"""
Handle user connect event. Each session (tab) is treated as an independent collaborator.
"""
workflow_id = data.get("workflow_id")
if not workflow_id:
return {"msg": "workflow_id is required"}, 400
session = sio.get_session(sid)
user_id = session.get("user_id")
if not user_id:
return {"msg": "unauthorized"}, 401
# Each session is stored independently with sid as key
session_info = {
"user_id": user_id,
"username": session.get("username", "Unknown"),
"avatar": session.get("avatar", None),
"sid": sid,
"connected_at": int(time.time()), # Add timestamp to differentiate tabs
}
# Store session info with sid as key
redis_client.hset(f"workflow_online_users:{workflow_id}", sid, json.dumps(session_info))
redis_client.set(f"ws_sid_map:{sid}", json.dumps({"workflow_id": workflow_id, "user_id": user_id}))
# Leader election: first session becomes the leader
leader_sid = get_or_set_leader(workflow_id, sid)
is_leader = leader_sid == sid
sio.enter_room(sid, workflow_id)
broadcast_online_users(workflow_id)
# Notify this session of their leader status
sio.emit("status", {"isLeader": is_leader}, room=sid)
return {"msg": "connected", "user_id": user_id, "sid": sid, "isLeader": is_leader}
@sio.on("disconnect")
def handle_disconnect(sid):
"""
Handle session disconnect event. Remove the specific session from online users.
"""
mapping = redis_client.get(f"ws_sid_map:{sid}")
if mapping:
data = json.loads(mapping)
workflow_id = data["workflow_id"]
# Remove this specific session
redis_client.hdel(f"workflow_online_users:{workflow_id}", sid)
redis_client.delete(f"ws_sid_map:{sid}")
# Handle leader re-election if the leader session disconnected
handle_leader_disconnect(workflow_id, sid)
broadcast_online_users(workflow_id)
def _clear_session_state(workflow_id: str, sid: str) -> None:
    """Remove all Redis state for one session: its online-user entry and its sid mapping."""
    redis_client.hdel(f"workflow_online_users:{workflow_id}", sid)
    redis_client.delete(f"ws_sid_map:{sid}")
def _is_session_active(workflow_id: str, sid: str) -> bool:
    """
    Return True only when ``sid`` is still connected to the socket server
    AND fully registered in Redis (online-user entry plus sid mapping).
    """
    if not sid:
        return False
    try:
        connected = sio.manager.is_connected(sid, "/")
    except AttributeError:
        # Manager implementation without is_connected -> treat as inactive.
        return False
    if not connected:
        return False
    if not redis_client.hexists(f"workflow_online_users:{workflow_id}", sid):
        return False
    return bool(redis_client.exists(f"ws_sid_map:{sid}"))
def get_or_set_leader(workflow_id: str, sid: str) -> str:
    """
    Get current leader session or set this session as leader if no valid leader exists.
    Returns the leader session id (sid).
    """
    leader_key = f"workflow_leader:{workflow_id}"
    raw_leader = redis_client.get(leader_key)
    # Redis may return bytes depending on client configuration; normalize to str.
    current_leader = raw_leader.decode("utf-8") if isinstance(raw_leader, bytes) else raw_leader
    leader_replaced = False
    # A recorded leader that is no longer connected/registered is stale:
    # clear its state and fall through to electing this session instead.
    if current_leader and not _is_session_active(workflow_id, current_leader):
        _clear_session_state(workflow_id, current_leader)
        redis_client.delete(leader_key)
        current_leader = None
        leader_replaced = True
    if not current_leader:
        redis_client.set(leader_key, sid, ex=3600)  # Expire in 1 hour
        # Only broadcast when a stale leader was evicted; a brand-new room's
        # session learns its status via the caller's direct "status" emit.
        if leader_replaced:
            broadcast_leader_change(workflow_id, sid)
        return sid
    return current_leader
def handle_leader_disconnect(workflow_id, disconnected_sid):
    """
    Handle leader re-election when a session disconnects.
    If the disconnected session was the leader, elect a new leader from remaining sessions.
    """
    leader_key = f"workflow_leader:{workflow_id}"
    current_leader = redis_client.get(leader_key)
    if current_leader:
        # Normalize possible bytes value returned by Redis.
        current_leader = current_leader.decode("utf-8") if isinstance(current_leader, bytes) else current_leader
        if current_leader == disconnected_sid:
            # Leader session disconnected, elect a new leader
            sessions_json = redis_client.hgetall(f"workflow_online_users:{workflow_id}")
            if sessions_json:
                # Get the first remaining session as new leader
                # (hash iteration order is arbitrary; any remaining session qualifies).
                new_leader_sid = list(sessions_json.keys())[0]
                if isinstance(new_leader_sid, bytes):
                    new_leader_sid = new_leader_sid.decode("utf-8")
                redis_client.set(leader_key, new_leader_sid, ex=3600)
                # Notify all sessions about the new leader
                broadcast_leader_change(workflow_id, new_leader_sid)
            else:
                # No sessions left, remove leader
                redis_client.delete(leader_key)
def broadcast_leader_change(workflow_id, new_leader_sid):
    """
    Broadcast leader change to all sessions in the workflow.

    Every session receives a "status" event telling it whether it is now the leader.
    """
    all_sessions = redis_client.hgetall(f"workflow_online_users:{workflow_id}")
    for raw_sid in all_sessions:
        try:
            member_sid = raw_sid.decode("utf-8") if isinstance(raw_sid, bytes) else raw_sid
            # Emit to each session whether they are the new leader.
            sio.emit("status", {"isLeader": member_sid == new_leader_sid}, room=member_sid)
        except Exception:
            # Best effort: skip sessions that fail to receive the update.
            continue
def get_current_leader(workflow_id):
    """Return the current leader sid for ``workflow_id``, or a falsy value when none is set."""
    raw = redis_client.get(f"workflow_leader:{workflow_id}")
    if raw and isinstance(raw, bytes):
        # Normalize the bytes value Redis may return.
        return raw.decode("utf-8")
    return raw
def broadcast_online_users(workflow_id):
    """
    Broadcast online users to the workflow room.

    Every session appears as its own entry (the same person with multiple
    tabs shows up once per tab), sorted by connection time.
    """
    raw_sessions = redis_client.hgetall(f"workflow_online_users:{workflow_id}")
    online = []
    for _raw_sid, raw_info in raw_sessions.items():
        try:
            info = json.loads(raw_info)
            online.append(
                {
                    "user_id": info["user_id"],
                    "username": info["username"],
                    "avatar": info.get("avatar"),
                    "sid": info["sid"],
                    "connected_at": info.get("connected_at"),
                }
            )
        except Exception:
            # Skip entries that fail to deserialize or lack required keys.
            continue
    # Stable ordering keeps the collaborator list consistent across clients.
    online.sort(key=lambda entry: entry.get("connected_at") or 0)
    current_leader = get_current_leader(workflow_id)
    sio.emit(
        "online_users",
        {"workflow_id": workflow_id, "users": online, "leader": current_leader},
        room=workflow_id,
    )
@sio.on("collaboration_event")
def handle_collaboration_event(sid, data):
"""
Handle general collaboration events, include:
1. mouseMove
2. varsAndFeaturesUpdate
3. syncRequest(ask leader to update graph)
4. appStateUpdate
5. mcpServerUpdate
"""
mapping = redis_client.get(f"ws_sid_map:{sid}")
if not mapping:
return {"msg": "unauthorized"}, 401
mapping_data = json.loads(mapping)
workflow_id = mapping_data["workflow_id"]
user_id = mapping_data["user_id"]
event_type = data.get("type")
event_data = data.get("data")
timestamp = data.get("timestamp", int(time.time()))
if not event_type:
return {"msg": "invalid event type"}, 400
sio.emit(
"collaboration_update",
{"type": event_type, "userId": user_id, "data": event_data, "timestamp": timestamp},
room=workflow_id,
skip_sid=sid,
)
return {"msg": "event_broadcasted"}
@sio.on("graph_event")
def handle_graph_event(sid, data):
"""
Handle graph events - simple broadcast relay.
"""
mapping = redis_client.get(f"ws_sid_map:{sid}")
if not mapping:
return {"msg": "unauthorized"}, 401
mapping_data = json.loads(mapping)
workflow_id = mapping_data["workflow_id"]
sio.emit("graph_update", data, room=workflow_id, skip_sid=sid)
return {"msg": "graph_update_broadcasted"}

View File

@ -5,6 +5,7 @@ from typing import cast
from flask import abort, request
from flask_restx import Resource, fields, inputs, marshal_with, reqparse
from pydantic_core import ValidationError
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
@ -21,7 +22,9 @@ from core.file.models import File
from core.helper.trace_id_helper import get_external_trace_id
from core.workflow.graph_engine.manager import GraphEngineManager
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from factories import file_factory, variable_factory
from fields.online_user_fields import online_user_list_fields
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
from fields.workflow_run_fields import workflow_run_node_execution_fields
from libs import helper
@ -103,6 +106,7 @@ class DraftWorkflowApi(Resource):
"hash": fields.String(description="Workflow hash for validation"),
"environment_variables": fields.List(fields.Raw, required=True, description="Environment variables"),
"conversation_variables": fields.List(fields.Raw, description="Conversation variables"),
"memory_blocks": fields.List(fields.Raw, description="Memory blocks"),
},
)
)
@ -127,6 +131,8 @@ class DraftWorkflowApi(Resource):
parser.add_argument("hash", type=str, required=False, location="json")
parser.add_argument("environment_variables", type=list, required=True, location="json")
parser.add_argument("conversation_variables", type=list, required=False, location="json")
parser.add_argument("force_upload", type=bool, required=False, default=False, location="json")
parser.add_argument("memory_blocks", type=list, required=False, location="json")
args = parser.parse_args()
elif "text/plain" in content_type:
try:
@ -143,6 +149,8 @@ class DraftWorkflowApi(Resource):
"hash": data.get("hash"),
"environment_variables": data.get("environment_variables"),
"conversation_variables": data.get("conversation_variables"),
"memory_blocks": data.get("memory_blocks"),
"force_upload": data.get("force_upload", False),
}
except json.JSONDecodeError:
return {"message": "Invalid JSON data"}, 400
@ -163,6 +171,11 @@ class DraftWorkflowApi(Resource):
conversation_variables = [
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
]
memory_blocks_list = args.get("memory_blocks") or []
from core.memory.entities import MemoryBlockSpec
memory_blocks = [
MemoryBlockSpec.model_validate(obj) for obj in memory_blocks_list
]
workflow = workflow_service.sync_draft_workflow(
app_model=app_model,
graph=args["graph"],
@ -171,9 +184,13 @@ class DraftWorkflowApi(Resource):
account=current_user,
environment_variables=environment_variables,
conversation_variables=conversation_variables,
force_upload=args.get("force_upload", False),
memory_blocks=memory_blocks,
)
except WorkflowHashNotEqualError:
raise DraftWorkflowNotSync()
except ValidationError as e:
return {"message": str(e)}, 400
return {
"result": "success",
@ -796,6 +813,45 @@ class ConvertToWorkflowApi(Resource):
}
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/config")
class WorkflowConfigApi(Resource):
"""Resource for workflow configuration."""
@api.doc("get_workflow_config")
@api.doc(description="Get workflow configuration")
@api.doc(params={"app_id": "Application ID"})
@api.response(200, "Workflow configuration retrieved successfully")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def get(self, app_model: App):
return {
"parallel_depth_limit": dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT,
}
class WorkflowFeaturesApi(Resource):
    """Update draft workflow features."""

    @setup_required
    @login_required
    @account_initialization_required
    @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
    def post(self, app_model: App):
        """Persist the posted "features" dict onto the app's draft workflow."""
        parser = reqparse.RequestParser()
        parser.add_argument("features", type=dict, required=True, location="json")
        parsed = parser.parse_args()

        WorkflowService().update_draft_workflow_features(
            app_model=app_model,
            features=parsed.get("features"),
            account=current_user,
        )
        return {"result": "success"}
@console_ns.route("/apps/<uuid:app_id>/workflows")
class PublishedAllWorkflowApi(Resource):
@api.doc("get_all_published_workflows")
@ -985,3 +1041,105 @@ class DraftWorkflowNodeLastRunApi(Resource):
if node_exec is None:
raise NotFound("last run not found")
return node_exec
class WorkflowOnlineUsersApi(Resource):
    """Report the collaborators currently connected to one or more workflows."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(online_user_list_fields)
    def get(self):
        """Return the online-user list for each requested workflow id."""
        parser = reqparse.RequestParser()
        parser.add_argument("workflow_ids", type=str, required=True, location="args")
        parsed = parser.parse_args()
        # "workflow_ids" arrives as a comma-separated query parameter.
        requested_ids = [wf_id.strip() for wf_id in parsed["workflow_ids"].split(",")]

        results = []
        for wf_id in requested_ids:
            sessions = redis_client.hgetall(f"workflow_online_users:{wf_id}")
            users = []
            for raw_info in sessions.values():
                try:
                    users.append(json.loads(raw_info))
                except Exception:
                    # Skip entries that fail to deserialize.
                    continue
            results.append({"workflow_id": wf_id, "users": users})
        return {"data": results}
# Route registrations for the console draft-workflow APIs.
api.add_resource(
    DraftWorkflowApi,
    "/apps/<uuid:app_id>/workflows/draft",
)
api.add_resource(
    WorkflowConfigApi,
    "/apps/<uuid:app_id>/workflows/draft/config",
)
api.add_resource(
    WorkflowFeaturesApi,
    "/apps/<uuid:app_id>/workflows/draft/features",
)
# Run/stop endpoints for draft workflows and individual nodes.
api.add_resource(
    AdvancedChatDraftWorkflowRunApi,
    "/apps/<uuid:app_id>/advanced-chat/workflows/draft/run",
)
api.add_resource(
    DraftWorkflowRunApi,
    "/apps/<uuid:app_id>/workflows/draft/run",
)
api.add_resource(
    WorkflowTaskStopApi,
    "/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop",
)
api.add_resource(
    DraftWorkflowNodeRunApi,
    "/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run",
)
api.add_resource(
    AdvancedChatDraftRunIterationNodeApi,
    "/apps/<uuid:app_id>/advanced-chat/workflows/draft/iteration/nodes/<string:node_id>/run",
)
api.add_resource(
    WorkflowDraftRunIterationNodeApi,
    "/apps/<uuid:app_id>/workflows/draft/iteration/nodes/<string:node_id>/run",
)
api.add_resource(
    AdvancedChatDraftRunLoopNodeApi,
    "/apps/<uuid:app_id>/advanced-chat/workflows/draft/loop/nodes/<string:node_id>/run",
)
api.add_resource(
    WorkflowDraftRunLoopNodeApi,
    "/apps/<uuid:app_id>/workflows/draft/loop/nodes/<string:node_id>/run",
)
# Publishing and configuration lookup.
api.add_resource(
    PublishedWorkflowApi,
    "/apps/<uuid:app_id>/workflows/publish",
)
api.add_resource(
    PublishedAllWorkflowApi,
    "/apps/<uuid:app_id>/workflows",
)
api.add_resource(
    DefaultBlockConfigsApi,
    "/apps/<uuid:app_id>/workflows/default-workflow-block-configs",
)
api.add_resource(
    DefaultBlockConfigApi,
    "/apps/<uuid:app_id>/workflows/default-workflow-block-configs/<string:block_type>",
)
api.add_resource(
    ConvertToWorkflowApi,
    "/apps/<uuid:app_id>/convert-to-workflow",
)
api.add_resource(
    WorkflowByIdApi,
    "/apps/<uuid:app_id>/workflows/<string:workflow_id>",
)
api.add_resource(
    DraftWorkflowNodeLastRunApi,
    "/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/last-run",
)
# Collaboration presence endpoint (not scoped to a single app id).
api.add_resource(WorkflowOnlineUsersApi, "/apps/workflows/online-users")

View File

@ -0,0 +1,240 @@
import logging
from flask_restx import Resource, fields, marshal_with, reqparse
from controllers.console import api
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from fields.member_fields import account_with_role_fields
from fields.workflow_comment_fields import (
workflow_comment_basic_fields,
workflow_comment_create_fields,
workflow_comment_detail_fields,
workflow_comment_reply_create_fields,
workflow_comment_reply_update_fields,
workflow_comment_resolve_fields,
workflow_comment_update_fields,
)
from libs.login import current_user, login_required
from models import App
from services.account_service import TenantService
from services.workflow_comment_service import WorkflowCommentService
logger = logging.getLogger(__name__)
class WorkflowCommentListApi(Resource):
    """API for listing and creating workflow comments."""

    # Decorator order fixed to match the console convention used elsewhere
    # (setup check runs before the login check).
    @setup_required
    @login_required
    @account_initialization_required
    @get_app_model
    @marshal_with(workflow_comment_basic_fields, envelope="data")
    def get(self, app_model: App):
        """Get all comments for a workflow."""
        comments = WorkflowCommentService.get_comments(tenant_id=current_user.current_tenant_id, app_id=app_model.id)
        return comments

    @setup_required
    @login_required
    @account_initialization_required
    @get_app_model
    @marshal_with(workflow_comment_create_fields)
    def post(self, app_model: App):
        """Create a new workflow comment anchored at (position_x, position_y)."""
        parser = reqparse.RequestParser()
        parser.add_argument("position_x", type=float, required=True, location="json")
        parser.add_argument("position_y", type=float, required=True, location="json")
        parser.add_argument("content", type=str, required=True, location="json")
        parser.add_argument("mentioned_user_ids", type=list, location="json", default=[])
        args = parser.parse_args()

        result = WorkflowCommentService.create_comment(
            tenant_id=current_user.current_tenant_id,
            app_id=app_model.id,
            created_by=current_user.id,
            content=args.content,
            position_x=args.position_x,
            position_y=args.position_y,
            mentioned_user_ids=args.mentioned_user_ids,
        )
        return result, 201
class WorkflowCommentDetailApi(Resource):
    """API for managing individual workflow comments."""

    # Decorator order fixed to match the console convention used elsewhere
    # (setup check runs before the login check).
    @setup_required
    @login_required
    @account_initialization_required
    @get_app_model
    @marshal_with(workflow_comment_detail_fields)
    def get(self, app_model: App, comment_id: str):
        """Get a specific workflow comment."""
        comment = WorkflowCommentService.get_comment(
            tenant_id=current_user.current_tenant_id, app_id=app_model.id, comment_id=comment_id
        )
        return comment

    @setup_required
    @login_required
    @account_initialization_required
    @get_app_model
    @marshal_with(workflow_comment_update_fields)
    def put(self, app_model: App, comment_id: str):
        """Update a workflow comment's content, position, and mentions."""
        parser = reqparse.RequestParser()
        parser.add_argument("content", type=str, required=True, location="json")
        parser.add_argument("position_x", type=float, required=False, location="json")
        parser.add_argument("position_y", type=float, required=False, location="json")
        parser.add_argument("mentioned_user_ids", type=list, location="json", default=[])
        args = parser.parse_args()

        result = WorkflowCommentService.update_comment(
            tenant_id=current_user.current_tenant_id,
            app_id=app_model.id,
            comment_id=comment_id,
            user_id=current_user.id,
            content=args.content,
            position_x=args.position_x,
            position_y=args.position_y,
            mentioned_user_ids=args.mentioned_user_ids,
        )
        return result

    @setup_required
    @login_required
    @account_initialization_required
    @get_app_model
    def delete(self, app_model: App, comment_id: str):
        """Delete a workflow comment."""
        WorkflowCommentService.delete_comment(
            tenant_id=current_user.current_tenant_id,
            app_id=app_model.id,
            comment_id=comment_id,
            user_id=current_user.id,
        )
        return {"result": "success"}, 204
class WorkflowCommentResolveApi(Resource):
    """API for resolving and reopening workflow comments."""

    # Decorator order fixed to match the console convention used elsewhere
    # (setup check runs before the login check).
    @setup_required
    @login_required
    @account_initialization_required
    @get_app_model
    @marshal_with(workflow_comment_resolve_fields)
    def post(self, app_model: App, comment_id: str):
        """Resolve a workflow comment."""
        comment = WorkflowCommentService.resolve_comment(
            tenant_id=current_user.current_tenant_id,
            app_id=app_model.id,
            comment_id=comment_id,
            user_id=current_user.id,
        )
        return comment
class WorkflowCommentReplyApi(Resource):
    """API for managing comment replies."""

    # Decorator order fixed to match the console convention used elsewhere
    # (setup check runs before the login check).
    @setup_required
    @login_required
    @account_initialization_required
    @get_app_model
    @marshal_with(workflow_comment_reply_create_fields)
    def post(self, app_model: App, comment_id: str):
        """Add a reply to a workflow comment."""
        # Validate comment access first
        WorkflowCommentService.validate_comment_access(
            comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
        )

        parser = reqparse.RequestParser()
        parser.add_argument("content", type=str, required=True, location="json")
        parser.add_argument("mentioned_user_ids", type=list, location="json", default=[])
        args = parser.parse_args()

        result = WorkflowCommentService.create_reply(
            comment_id=comment_id,
            content=args.content,
            created_by=current_user.id,
            mentioned_user_ids=args.mentioned_user_ids,
        )
        return result, 201
class WorkflowCommentReplyDetailApi(Resource):
    """API for managing individual comment replies."""

    # Decorator order fixed to match the console convention used elsewhere
    # (setup check runs before the login check).
    @setup_required
    @login_required
    @account_initialization_required
    @get_app_model
    @marshal_with(workflow_comment_reply_update_fields)
    def put(self, app_model: App, comment_id: str, reply_id: str):
        """Update a comment reply."""
        # Validate comment access first
        WorkflowCommentService.validate_comment_access(
            comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
        )

        parser = reqparse.RequestParser()
        parser.add_argument("content", type=str, required=True, location="json")
        parser.add_argument("mentioned_user_ids", type=list, location="json", default=[])
        args = parser.parse_args()

        reply = WorkflowCommentService.update_reply(
            reply_id=reply_id, user_id=current_user.id, content=args.content, mentioned_user_ids=args.mentioned_user_ids
        )
        return reply

    @setup_required
    @login_required
    @account_initialization_required
    @get_app_model
    def delete(self, app_model: App, comment_id: str, reply_id: str):
        """Delete a comment reply."""
        # Validate comment access first
        WorkflowCommentService.validate_comment_access(
            comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
        )
        WorkflowCommentService.delete_reply(reply_id=reply_id, user_id=current_user.id)
        return {"result": "success"}, 204
class WorkflowCommentMentionUsersApi(Resource):
    """API for getting mentionable users for workflow comments."""

    # Decorator order fixed to match the console convention used elsewhere
    # (setup check runs before the login check).
    @setup_required
    @login_required
    @account_initialization_required
    @get_app_model
    @marshal_with({"users": fields.List(fields.Nested(account_with_role_fields))})
    def get(self, app_model: App):
        """Get all users in current tenant for mentions."""
        members = TenantService.get_tenant_members(current_user.current_tenant)
        return {"users": members}
# Register API routes — all scoped under a single app's workflow.
api.add_resource(WorkflowCommentListApi, "/apps/<uuid:app_id>/workflow/comments")
api.add_resource(WorkflowCommentDetailApi, "/apps/<uuid:app_id>/workflow/comments/<string:comment_id>")
api.add_resource(WorkflowCommentResolveApi, "/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/resolve")
api.add_resource(WorkflowCommentReplyApi, "/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/replies")
api.add_resource(
    WorkflowCommentReplyDetailApi, "/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/replies/<string:reply_id>"
)
api.add_resource(WorkflowCommentMentionUsersApi, "/apps/<uuid:app_id>/workflow/comments/mention-users")

View File

@ -19,8 +19,8 @@ from core.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.variables.types import SegmentType
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from extensions.ext_database import db
from factories import variable_factory
from factories.file_factory import build_from_mapping, build_from_mappings
from factories.variable_factory import build_segment_with_type
from libs.login import current_user, login_required
from models import App, AppMode
from models.account import Account
@ -353,7 +353,7 @@ class VariableApi(Resource):
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
raw_value = build_from_mappings(mappings=raw_value, tenant_id=app_model.tenant_id)
new_value = build_segment_with_type(variable.value_type, raw_value)
new_value = variable_factory.build_segment_with_type(variable.value_type, raw_value)
draft_var_srv.update_variable(variable, name=new_name, value=new_value)
db.session.commit()
return variable
@ -446,8 +446,35 @@ class ConversationVariableCollectionApi(Resource):
db.session.commit()
return _get_variable_list(app_model, CONVERSATION_VARIABLE_NODE_ID)
    @setup_required
    @login_required
    @account_initialization_required
    @get_app_model(mode=AppMode.ADVANCED_CHAT)
    def post(self, app_model: App):
        """
        Replace the draft workflow's conversation variables with the posted list.

        Only admins/owners/editors may modify them; raises Forbidden otherwise.
        """
        # The role of the current user in the ta table must be admin, owner, or editor
        if not current_user.is_editor:
            raise Forbidden()

        parser = reqparse.RequestParser()
        parser.add_argument("conversation_variables", type=list, required=True, location="json")
        args = parser.parse_args()

        workflow_service = WorkflowService()
        conversation_variables_list = args.get("conversation_variables") or []
        # Convert raw JSON mappings into typed variable objects before persisting.
        conversation_variables = [
            variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
        ]
        workflow_service.update_draft_workflow_conversation_variables(
            app_model=app_model,
            account=current_user,
            conversation_variables=conversation_variables,
        )
        return {"result": "success"}
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/system-variables")
class SystemVariableCollectionApi(Resource):
@api.doc("get_system_variables")
@api.doc(description="Get system variables for workflow")
@ -497,3 +524,44 @@ class EnvironmentVariableCollectionApi(Resource):
)
return {"items": env_vars_list}
    @setup_required
    @login_required
    @account_initialization_required
    @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
    def post(self, app_model: App):
        """
        Replace the draft workflow's environment variables with the posted list.

        Only admins/owners/editors may modify them; raises Forbidden otherwise.
        """
        # The role of the current user in the ta table must be admin, owner, or editor
        if not current_user.is_editor:
            raise Forbidden()

        parser = reqparse.RequestParser()
        parser.add_argument("environment_variables", type=list, required=True, location="json")
        args = parser.parse_args()

        workflow_service = WorkflowService()
        environment_variables_list = args.get("environment_variables") or []
        # Convert raw JSON mappings into typed variable objects before persisting.
        environment_variables = [
            variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
        ]
        workflow_service.update_draft_workflow_environment_variables(
            app_model=app_model,
            account=current_user,
            environment_variables=environment_variables,
        )
        return {"result": "success"}
# Draft-workflow variable management routes.
api.add_resource(
    WorkflowVariableCollectionApi,
    "/apps/<uuid:app_id>/workflows/draft/variables",
)
api.add_resource(NodeVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/variables")
api.add_resource(VariableApi, "/apps/<uuid:app_id>/workflows/draft/variables/<uuid:variable_id>")
api.add_resource(VariableResetApi, "/apps/<uuid:app_id>/workflows/draft/variables/<uuid:variable_id>/reset")
api.add_resource(ConversationVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/conversation-variables")
api.add_resource(SystemVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/system-variables")
api.add_resource(EnvironmentVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/environment-variables")

View File

@ -0,0 +1,34 @@
from flask_restx import Resource
from controllers.console import api
from controllers.console.explore.wraps import explore_banner_enabled
from extensions.ext_database import db
from models.model import ExporleBanner
class BannerApi(Resource):
    """Resource for banner list."""

    @explore_banner_enabled
    def get(self):
        """Return enabled explore banners, ordered by their "sort" value."""
        enabled_banners = (
            db.session.query(ExporleBanner).where(ExporleBanner.status == "enabled").order_by(ExporleBanner.sort).all()
        )
        serialized = []
        for item in enabled_banners:
            serialized.append(
                {
                    # presumably JSON content parsed by the ORM column type — verify against the model
                    "content": item.content,
                    "link": item.link,
                    "sort": item.sort,
                    "status": item.status,
                    "created_at": item.created_at.isoformat() if item.created_at else None,
                }
            )
        return serialized
# Public explore endpoint: GET /explore/banners
api.add_resource(BannerApi, "/explore/banners")

View File

@ -29,3 +29,25 @@ class AppAccessDeniedError(BaseHTTPException):
error_code = "access_denied"
description = "App access denied."
code = 403
class TrialAppNotAllowed(BaseHTTPException):
    """*403* `Trial App Not Allowed`

    Raise if the app does not permit trial access.
    """

    error_code = "trial_app_not_allowed"
    code = 403
    # Fixed grammar/capitalization of the user-facing message
    # (was: "the app is not allowed to be trial.").
    description = "The app is not allowed to be trialed."
class TrialAppLimitExceeded(BaseHTTPException):
    """*403* `Trial App Limit Exceeded`

    Raise if the user has exceeded the trial app limit.
    """

    # Machine-readable error identifier returned to clients.
    error_code = "trial_app_limit_exceeded"
    # HTTP status code.
    code = 403
    description = "The user has exceeded the trial app limit."

View File

@ -27,6 +27,7 @@ recommended_app_fields = {
"category": fields.String,
"position": fields.Integer,
"is_listed": fields.Boolean,
"can_trial": fields.Boolean,
}
recommended_app_list_fields = {

View File

@ -0,0 +1,375 @@
import logging
from flask import request
from flask_restx import Resource, marshal_with, reqparse
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
from controllers.common import fields
from controllers.common.fields import build_site_model
from controllers.console import api
from controllers.console.app.error import (
AppUnavailableError,
AudioTooLargeError,
CompletionRequestError,
ConversationCompletedError,
NoAudioUploadedError,
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderNotSupportSpeechToTextError,
ProviderQuotaExceededError,
UnsupportedAudioTypeError,
)
from controllers.console.app.wraps import get_app_model
from controllers.console.explore.error import (
AppSuggestedQuestionsAfterAnswerDisabledError,
NotChatAppError,
NotCompletionAppError,
)
from controllers.console.explore.wraps import TrialAppResource, trial_feature_enable
from controllers.service_api import service_api_ns
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import (
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db
from fields.app_fields import app_detail_fields_with_site
from libs import helper
from libs.helper import uuid_value
from libs.login import current_user
from models import Account
from models.account import TenantStatus
from models.model import AppMode, Site
from services.app_generate_service import AppGenerateService
from services.app_service import AppService
from services.audio_service import AudioService
from services.errors.audio import (
AudioTooLargeServiceError,
NoAudioUploadedServiceError,
ProviderNotSupportSpeechToTextServiceError,
UnsupportedAudioTypeServiceError,
)
from services.errors.conversation import ConversationNotExistsError
from services.errors.llm import InvokeRateLimitError
from services.errors.message import (
MessageNotExistsError,
SuggestedQuestionsAfterAnswerDisabledError,
)
from services.message_service import MessageService
from services.recommended_app_service import RecommendedAppService
logger = logging.getLogger(__name__)
class TrialChatApi(TrialAppResource):
    @trial_feature_enable
    def post(self, trial_app):
        """Run a streaming chat completion against a trial app and record trial usage."""
        app_model = trial_app
        app_mode = AppMode.value_of(app_model.mode)
        # Only chat-style app modes may use this endpoint.
        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
            raise NotChatAppError()

        parser = reqparse.RequestParser()
        parser.add_argument("inputs", type=dict, required=True, location="json")
        parser.add_argument("query", type=str, required=True, location="json")
        parser.add_argument("files", type=list, required=False, location="json")
        parser.add_argument("conversation_id", type=uuid_value, location="json")
        parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
        parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
        args = parser.parse_args()

        # Trial conversations never auto-generate a name.
        args["auto_generate_name"] = False

        try:
            if not isinstance(current_user, Account):
                raise ValueError("current_user must be an Account instance")
            # Responses are always streamed for trial chat.
            response = AppGenerateService.generate(
                app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
            )
            # Count this invocation against the account's trial quota.
            RecommendedAppService.add_trial_app_record(app_model.id, current_user.id)
            return helper.compact_generate_response(response)
        # Map service-layer errors onto the console API's HTTP exceptions.
        except services.errors.conversation.ConversationNotExistsError:
            raise NotFound("Conversation Not Exists.")
        except services.errors.conversation.ConversationCompletedError:
            raise ConversationCompletedError()
        except services.errors.app_model_config.AppModelConfigBrokenError:
            logger.exception("App model config broken.")
            raise AppUnavailableError()
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description)
        except QuotaExceededError:
            raise ProviderQuotaExceededError()
        except ModelCurrentlyNotSupportError:
            raise ProviderModelCurrentlyNotSupportError()
        except InvokeError as e:
            raise CompletionRequestError(e.description)
        except InvokeRateLimitError as ex:
            raise InvokeRateLimitHttpError(ex.description)
        except ValueError as e:
            # Propagate validation errors (e.g. bad current_user) unchanged.
            raise e
        except Exception:
            logger.exception("internal server error.")
            raise InternalServerError()
class TrialMessageSuggestedQuestionApi(TrialAppResource):
    @trial_feature_enable
    def get(self, trial_app, message_id):
        """Return suggested follow-up questions for a message in a trial chat app."""
        app_model = trial_app
        app_mode = AppMode.value_of(app_model.mode)
        # Only chat-style app modes support suggested questions.
        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
            raise NotChatAppError()
        message_id = str(message_id)
        try:
            if not isinstance(current_user, Account):
                raise ValueError("current_user must be an Account instance")
            questions = MessageService.get_suggested_questions_after_answer(
                app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE
            )
        # Map service-layer errors onto the console API's HTTP exceptions.
        except MessageNotExistsError:
            raise NotFound("Message not found")
        except ConversationNotExistsError:
            raise NotFound("Conversation not found")
        except SuggestedQuestionsAfterAnswerDisabledError:
            raise AppSuggestedQuestionsAfterAnswerDisabledError()
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description)
        except QuotaExceededError:
            raise ProviderQuotaExceededError()
        except ModelCurrentlyNotSupportError:
            raise ProviderModelCurrentlyNotSupportError()
        except InvokeError as e:
            raise CompletionRequestError(e.description)
        except Exception:
            logger.exception("internal server error.")
            raise InternalServerError()
        return {"data": questions}
class TrialChatAudioApi(TrialAppResource):
    @trial_feature_enable
    def post(self, trial_app):
        """Transcribe an uploaded audio file to text for a trial app."""
        app_model = trial_app
        file = request.files["file"]
        try:
            if not isinstance(current_user, Account):
                raise ValueError("current_user must be an Account instance")
            response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=None)
            # Count this invocation against the account's trial quota.
            RecommendedAppService.add_trial_app_record(app_model.id, current_user.id)
            return response
        # Map service-layer errors onto the console API's HTTP exceptions.
        except services.errors.app_model_config.AppModelConfigBrokenError:
            logger.exception("App model config broken.")
            raise AppUnavailableError()
        except NoAudioUploadedServiceError:
            raise NoAudioUploadedError()
        except AudioTooLargeServiceError as e:
            raise AudioTooLargeError(str(e))
        except UnsupportedAudioTypeServiceError:
            raise UnsupportedAudioTypeError()
        except ProviderNotSupportSpeechToTextServiceError:
            raise ProviderNotSupportSpeechToTextError()
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description)
        except QuotaExceededError:
            raise ProviderQuotaExceededError()
        except ModelCurrentlyNotSupportError:
            raise ProviderModelCurrentlyNotSupportError()
        except InvokeError as e:
            raise CompletionRequestError(e.description)
        except ValueError as e:
            # Propagate validation errors (e.g. bad current_user) unchanged.
            raise e
        except Exception:
            # Unused `as e` binding removed: logger.exception already records the
            # active exception, matching the sibling trial endpoints.
            logger.exception("internal server error.")
            raise InternalServerError()
class TrialChatTextApi(TrialAppResource):
    @trial_feature_enable
    def post(self, trial_app):
        """Synthesize text to speech for a trial app and record trial usage."""
        app_model = trial_app
        try:
            parser = reqparse.RequestParser()
            parser.add_argument("message_id", type=str, required=False, location="json")
            parser.add_argument("voice", type=str, location="json")
            parser.add_argument("text", type=str, location="json")
            parser.add_argument("streaming", type=bool, location="json")
            args = parser.parse_args()

            message_id = args.get("message_id", None)
            text = args.get("text", None)
            voice = args.get("voice", None)
            if not isinstance(current_user, Account):
                raise ValueError("current_user must be an Account instance")
            response = AudioService.transcript_tts(app_model=app_model, text=text, voice=voice, message_id=message_id)
            # Count this invocation against the account's trial quota.
            RecommendedAppService.add_trial_app_record(app_model.id, current_user.id)
            return response
        # Map service-layer errors onto the console API's HTTP exceptions.
        except services.errors.app_model_config.AppModelConfigBrokenError:
            logger.exception("App model config broken.")
            raise AppUnavailableError()
        except NoAudioUploadedServiceError:
            raise NoAudioUploadedError()
        except AudioTooLargeServiceError as e:
            raise AudioTooLargeError(str(e))
        except UnsupportedAudioTypeServiceError:
            raise UnsupportedAudioTypeError()
        except ProviderNotSupportSpeechToTextServiceError:
            raise ProviderNotSupportSpeechToTextError()
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description)
        except QuotaExceededError:
            raise ProviderQuotaExceededError()
        except ModelCurrentlyNotSupportError:
            raise ProviderModelCurrentlyNotSupportError()
        except InvokeError as e:
            raise CompletionRequestError(e.description)
        except ValueError as e:
            # Propagate validation errors (e.g. bad current_user) unchanged.
            raise e
        except Exception:
            # Unused `as e` binding removed: logger.exception already records the
            # active exception, matching the sibling trial endpoints.
            logger.exception("internal server error.")
            raise InternalServerError()
class TrialCompletionApi(TrialAppResource):
    @trial_feature_enable
    def post(self, trial_app):
        """Run a completion request against a trial completion-mode app."""
        app_model = trial_app
        # Only completion-mode apps may use this endpoint.
        if app_model.mode != "completion":
            raise NotCompletionAppError()

        parser = reqparse.RequestParser()
        parser.add_argument("inputs", type=dict, required=True, location="json")
        parser.add_argument("query", type=str, location="json", default="")
        parser.add_argument("files", type=list, required=False, location="json")
        parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
        parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
        args = parser.parse_args()

        # Streaming is opt-in via response_mode, unlike the chat endpoint.
        streaming = args["response_mode"] == "streaming"
        args["auto_generate_name"] = False

        try:
            if not isinstance(current_user, Account):
                raise ValueError("current_user must be an Account instance")
            response = AppGenerateService.generate(
                app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=streaming
            )
            # Count this invocation against the account's trial quota.
            RecommendedAppService.add_trial_app_record(app_model.id, current_user.id)
            return helper.compact_generate_response(response)
        # Map service-layer errors onto the console API's HTTP exceptions.
        except services.errors.conversation.ConversationNotExistsError:
            raise NotFound("Conversation Not Exists.")
        except services.errors.conversation.ConversationCompletedError:
            raise ConversationCompletedError()
        except services.errors.app_model_config.AppModelConfigBrokenError:
            logger.exception("App model config broken.")
            raise AppUnavailableError()
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description)
        except QuotaExceededError:
            raise ProviderQuotaExceededError()
        except ModelCurrentlyNotSupportError:
            raise ProviderModelCurrentlyNotSupportError()
        except InvokeError as e:
            raise CompletionRequestError(e.description)
        except ValueError as e:
            # Propagate validation errors (e.g. bad current_user) unchanged.
            raise e
        except Exception:
            logger.exception("internal server error.")
            raise InternalServerError()
class TrialSitApi(Resource):
    """Resource for trial app sites.

    NOTE(review): class name looks like a typo for ``TrialSiteApi``; renaming
    would change the generated endpoint name, so it is left as-is here.
    """

    @trial_feature_enable
    @get_app_model
    @service_api_ns.marshal_with(build_site_model(service_api_ns))
    def get(self, app_model):
        """Retrieve app site info.

        Returns the site configuration for the application including theme, icons, and text.
        """
        site = db.session.query(Site).where(Site.app_id == app_model.id).first()
        if not site:
            raise Forbidden()
        # A trial app must belong to a live (non-archived) tenant.
        assert app_model.tenant
        if app_model.tenant.status == TenantStatus.ARCHIVE:
            raise Forbidden()
        return site
class TrialAppParameterApi(Resource):
    """Resource for app variables."""

    @trial_feature_enable
    @get_app_model
    @marshal_with(fields.parameters_fields)
    def get(self, app_model):
        """Retrieve app parameters."""
        if app_model is None:
            raise AppUnavailableError()

        if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
            # Workflow-based apps read features and inputs from their workflow.
            workflow = app_model.workflow
            if workflow is None:
                raise AppUnavailableError()
            features = workflow.features_dict
            input_form = workflow.user_input_form(to_old_structure=True)
        else:
            # Other modes read everything from the app model configuration.
            config = app_model.app_model_config
            if config is None:
                raise AppUnavailableError()
            features = config.to_dict()
            input_form = features.get("user_input_form", [])

        return get_parameters_from_feature_dict(features_dict=features, user_input_form=input_form)
class AppApi(Resource):
    @trial_feature_enable
    @get_app_model
    @marshal_with(app_detail_fields_with_site)
    def get(self, app_model):
        """Get app detail"""
        # Delegate to AppService so trial apps share the standard detail shape.
        return AppService().get_app(app_model)
# Trial-app routes: chat, suggested questions, speech-to-text, text-to-speech,
# completion, site, parameters and app detail, all keyed by the app's uuid.
api.add_resource(TrialChatApi, "/trial-apps/<uuid:app_id>/chat-messages", endpoint="trial_app_chat_completion")
api.add_resource(
    TrialMessageSuggestedQuestionApi,
    "/trial-apps/<uuid:app_id>/messages/<uuid:message_id>/suggested-questions",
    endpoint="trial_app_suggested_question",
)
api.add_resource(TrialChatAudioApi, "/trial-apps/<uuid:app_id>/audio-to-text", endpoint="trial_app_audio")
api.add_resource(TrialChatTextApi, "/trial-apps/<uuid:app_id>/text-to-audio", endpoint="trial_app_text")
api.add_resource(TrialCompletionApi, "/trial-apps/<uuid:app_id>/completion-messages", endpoint="trial_app_completion")
api.add_resource(TrialSitApi, "/trial-apps/<uuid:app_id>/site")
api.add_resource(TrialAppParameterApi, "/trial-apps/<uuid:app_id>/parameters", endpoint="trial_app_parameters")
api.add_resource(AppApi, "/trial-apps/<uuid:app_id>", endpoint="trial_app")

View File

@ -2,15 +2,16 @@ from collections.abc import Callable
from functools import wraps
from typing import Concatenate, ParamSpec, TypeVar
from flask import abort
from flask_login import current_user
from flask_restx import Resource
from werkzeug.exceptions import NotFound
from controllers.console.explore.error import AppAccessDeniedError
from controllers.console.explore.error import AppAccessDeniedError, TrialAppLimitExceeded, TrialAppNotAllowed
from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db
from libs.login import login_required
from models import InstalledApp
from models import AccountTrialAppRecord, App, InstalledApp, TrialApp
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
@ -74,6 +75,59 @@ def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] |
return decorator
def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None):
    """Decorator resolving ``app_id`` to a trial-enabled App and enforcing the trial limit.

    The resolved :class:`App` is passed to the wrapped view as its first
    positional argument.

    Raises:
        TrialAppNotAllowed: if the app is not registered for trial or no longer exists.
        TrialAppLimitExceeded: if the current account has used up its trial quota.
    """

    def decorator(view: Callable[Concatenate[App, P], R]):
        @wraps(view)
        def decorated(app_id: str, *args: P.args, **kwargs: P.kwargs):
            trial_app = db.session.query(TrialApp).where(TrialApp.app_id == str(app_id)).first()
            if trial_app is None:
                raise TrialAppNotAllowed()
            app = trial_app.app
            if app is None:
                raise TrialAppNotAllowed()
            # app_id may arrive as a UUID object from the URL converter; compare
            # as str for consistency with the TrialApp lookup above (the original
            # compared the raw value here, which can miss string-typed columns).
            account_trial_app_record = (
                db.session.query(AccountTrialAppRecord)
                .where(
                    AccountTrialAppRecord.account_id == current_user.id,
                    AccountTrialAppRecord.app_id == str(app_id),
                )
                .first()
            )
            if account_trial_app_record:
                if account_trial_app_record.count >= trial_app.trial_limit:
                    raise TrialAppLimitExceeded()
            return view(app, *args, **kwargs)

        return decorated

    if view:
        return decorator(view)
    return decorator
def trial_feature_enable(view: Callable[..., R]) -> Callable[..., R]:
    """Reject the request with 403 unless the trial-app feature flag is enabled."""

    @wraps(view)
    def decorated(*args, **kwargs):
        if not FeatureService.get_system_features().enable_trial_app:
            abort(403, "Trial app feature is not enabled.")
        return view(*args, **kwargs)

    return decorated
def explore_banner_enabled(view: Callable[..., R]) -> Callable[..., R]:
    """Reject the request with 403 unless the explore-banner feature flag is enabled."""

    @wraps(view)
    def decorated(*args, **kwargs):
        if not FeatureService.get_system_features().enable_explore_banner:
            abort(403, "Explore banner feature is not enabled.")
        return view(*args, **kwargs)

    return decorated
class InstalledAppResource(Resource):
# must be reversed if there are multiple decorators
@ -83,3 +137,13 @@ class InstalledAppResource(Resource):
account_initialization_required,
login_required,
]
class TrialAppResource(Resource):
    """Base resource for trial-app endpoints.

    Views receive the resolved App (from trial_app_required) as their first
    positional argument.
    """

    # must be reversed if there are multiple decorators
    method_decorators = [
        trial_app_required,
        account_initialization_required,
        login_required,
    ]

View File

@ -33,6 +33,7 @@ from controllers.console.wraps import (
only_edition_cloud,
setup_required,
)
from core.file import helpers as file_helpers
from extensions.ext_database import db
from fields.member_fields import account_fields
from libs.datetime_utils import naive_utc_now
@ -135,6 +136,17 @@ class AccountNameApi(Resource):
@console_ns.route("/account/avatar")
class AccountAvatarApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
    """Return a signed URL for the avatar file key supplied in the query string."""
    parser = reqparse.RequestParser()
    parser.add_argument("avatar", type=str, required=True, location="args")
    args = parser.parse_args()
    # The avatar value is a stored file key; sign it so the client can fetch it directly.
    avatar_url = file_helpers.get_signed_file_url(args["avatar"])
    return {"avatar_url": avatar_url}
@setup_required
@login_required
@account_initialization_required

View File

@ -51,6 +51,8 @@ tenant_fields = {
"in_trial": fields.Boolean,
"trial_end_reason": fields.String,
"custom_config": fields.Raw(attribute="custom_config"),
"trial_credits": fields.Integer,
"trial_credits_used": fields.Integer,
}
tenants_fields = {

View File

@ -19,6 +19,7 @@ from .app import (
annotation,
app,
audio,
chatflow_memory,
completion,
conversation,
file,
@ -40,6 +41,7 @@ __all__ = [
"annotation",
"app",
"audio",
"chatflow_memory",
"completion",
"conversation",
"dataset",

View File

@ -0,0 +1,124 @@
from flask_restx import Resource, reqparse
from controllers.service_api import api
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.memory.entities import MemoryBlock, MemoryCreatedBy
from core.workflow.entities.variable_pool import VariablePool
from models import App, EndUser
from services.chatflow_memory_service import ChatflowMemoryService
from services.workflow_service import WorkflowService
class MemoryListApi(Resource):
    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
    def get(self, app_model: App, end_user: EndUser):
        """List the end-user-visible memories for this app.

        Optional filters: ``conversation_id`` (also include that conversation's
        session memories), ``memory_id`` (single memory), ``version``.
        """
        parser = reqparse.RequestParser()
        # Fixed: reqparse `type` must be a callable converter; unions such as
        # `str | None` / `int | None` are not callable and caused a 400 whenever
        # the argument was actually supplied.
        parser.add_argument("conversation_id", required=False, type=str, default=None)
        parser.add_argument("memory_id", required=False, type=str, default=None)
        parser.add_argument("version", required=False, type=int, default=None)
        args = parser.parse_args()

        conversation_id: str | None = args.get("conversation_id")
        memory_id = args.get("memory_id")
        version = args.get("version")
        created_by = MemoryCreatedBy(end_user_id=end_user.id)

        if conversation_id:
            # Conversation scope: persistent memories plus session memories
            # recorded for that conversation.
            result = ChatflowMemoryService.get_persistent_memories_with_conversation(
                app_model, created_by, conversation_id, version
            )
            session_memories = ChatflowMemoryService.get_session_memories_with_conversation(
                app_model, created_by, conversation_id, version
            )
            result = [*result, *session_memories]
        else:
            result = ChatflowMemoryService.get_persistent_memories(app_model, created_by, version)

        if memory_id:
            result = [it for it in result if it.spec.id == memory_id]
        # Only expose memories marked visible to end users.
        return [it for it in result if it.spec.end_user_visible]
class MemoryEditApi(Resource):
    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
    def put(self, app_model: App, end_user: EndUser):
        """Update the value of a memory block on behalf of an end user."""
        parser = reqparse.RequestParser()
        # Fixed: reqparse `type` must be a callable converter; unions such as
        # `str | None` are not callable and caused a 400 whenever the argument
        # was actually supplied.
        parser.add_argument("id", type=str, required=True)
        parser.add_argument("conversation_id", type=str, required=False, default=None)
        parser.add_argument("node_id", type=str, required=False, default=None)
        parser.add_argument("update", type=str, required=True)
        args = parser.parse_args()

        workflow = WorkflowService().get_published_workflow(app_model)
        update = args.get("update")
        conversation_id = args.get("conversation_id")
        node_id = args.get("node_id")

        if not isinstance(update, str):
            return {"error": "Invalid update"}, 400
        if not workflow:
            return {"error": "Workflow not found"}, 404

        memory_spec = next((it for it in workflow.memory_blocks if it.id == args["id"]), None)
        if not memory_spec:
            return {"error": "Memory not found"}, 404
        # NOTE(review): unlike the web API variant, this endpoint does not check
        # memory_spec.end_user_editable — confirm whether that is intentional.

        # First get existing memory
        existing_memory = ChatflowMemoryService.get_memory_by_spec(
            spec=memory_spec,
            tenant_id=app_model.tenant_id,
            app_id=app_model.id,
            created_by=MemoryCreatedBy(end_user_id=end_user.id),
            conversation_id=conversation_id,
            node_id=node_id,
            is_draft=False,
        )
        # Create updated memory instance with incremented version
        updated_memory = MemoryBlock(
            spec=existing_memory.spec,
            tenant_id=existing_memory.tenant_id,
            app_id=existing_memory.app_id,
            conversation_id=existing_memory.conversation_id,
            node_id=existing_memory.node_id,
            value=update,  # New value
            version=existing_memory.version + 1,  # Increment version for update
            edited_by_user=True,
            created_by=existing_memory.created_by,
        )
        ChatflowMemoryService.save_memory(updated_memory, VariablePool(), False)
        return "", 204
class MemoryDeleteApi(Resource):
    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
    def delete(self, app_model: App, end_user: EndUser):
        """Delete one memory by id, or every memory of this end user when no id is given."""
        parser = reqparse.RequestParser()
        parser.add_argument("id", type=str, required=False, default=None)
        memory_id = parser.parse_args().get("id")
        created_by = MemoryCreatedBy(end_user_id=end_user.id)

        if not memory_id:
            # NOTE(review): bulk delete returns 200 while single delete returns
            # 204 — presumably intentional, but worth confirming.
            ChatflowMemoryService.delete_all_user_memories(app_model, created_by)
            return "", 200

        ChatflowMemoryService.delete_memory(app_model, memory_id, created_by)
        return "", 204
# MemoryListApi (GET) and MemoryDeleteApi (DELETE) are both mounted at
# /memories; werkzeug dispatches by HTTP method, so the rules do not collide.
api.add_resource(MemoryListApi, '/memories')
api.add_resource(MemoryEditApi, '/memory-edit')
api.add_resource(MemoryDeleteApi, '/memories')

View File

@ -18,6 +18,7 @@ web_ns = Namespace("web", description="Web application API operations", path="/"
from . import (
app,
audio,
chatflow_memory,
completion,
conversation,
feature,
@ -39,6 +40,7 @@ __all__ = [
"app",
"audio",
"bp",
"chatflow_memory",
"completion",
"conversation",
"feature",

View File

@ -0,0 +1,123 @@
from flask_restx import reqparse
from controllers.web import api
from controllers.web.wraps import WebApiResource
from core.memory.entities import MemoryBlock, MemoryCreatedBy
from core.workflow.entities.variable_pool import VariablePool
from models import App, EndUser
from services.chatflow_memory_service import ChatflowMemoryService
from services.workflow_service import WorkflowService
class MemoryListApi(WebApiResource):
    def get(self, app_model: App, end_user: EndUser):
        """List the end-user-visible memories for this app.

        Optional filters: ``conversation_id`` (also include that conversation's
        session memories), ``memory_id`` (single memory), ``version``.
        """
        parser = reqparse.RequestParser()
        # Fixed: reqparse `type` must be a callable converter; unions such as
        # `str | None` / `int | None` are not callable and caused a 400 whenever
        # the argument was actually supplied.
        parser.add_argument("conversation_id", required=False, type=str, default=None)
        parser.add_argument("memory_id", required=False, type=str, default=None)
        parser.add_argument("version", required=False, type=int, default=None)
        args = parser.parse_args()

        conversation_id: str | None = args.get("conversation_id")
        memory_id = args.get("memory_id")
        version = args.get("version")
        created_by = MemoryCreatedBy(end_user_id=end_user.id)

        if conversation_id:
            # Conversation scope: persistent memories plus session memories
            # recorded for that conversation.
            result = ChatflowMemoryService.get_persistent_memories_with_conversation(
                app_model, created_by, conversation_id, version
            )
            session_memories = ChatflowMemoryService.get_session_memories_with_conversation(
                app_model, created_by, conversation_id, version
            )
            result = [*result, *session_memories]
        else:
            result = ChatflowMemoryService.get_persistent_memories(app_model, created_by, version)

        if memory_id:
            result = [it for it in result if it.spec.id == memory_id]
        # Only expose memories marked visible to end users.
        return [it for it in result if it.spec.end_user_visible]
class MemoryEditApi(WebApiResource):
    def put(self, app_model: App, end_user: EndUser):
        """Update the value of an end-user-editable memory block."""
        parser = reqparse.RequestParser()
        # Fixed: reqparse `type` must be a callable converter; unions such as
        # `str | None` are not callable and caused a 400 whenever the argument
        # was actually supplied.
        parser.add_argument("id", type=str, required=True)
        parser.add_argument("conversation_id", type=str, required=False, default=None)
        parser.add_argument("node_id", type=str, required=False, default=None)
        parser.add_argument("update", type=str, required=True)
        args = parser.parse_args()

        workflow = WorkflowService().get_published_workflow(app_model)
        update = args.get("update")
        conversation_id = args.get("conversation_id")
        node_id = args.get("node_id")

        if not isinstance(update, str):
            return {"error": "Update must be a string"}, 400
        if not workflow:
            return {"error": "Workflow not found"}, 404

        memory_spec = next((it for it in workflow.memory_blocks if it.id == args["id"]), None)
        if not memory_spec:
            return {"error": "Memory not found"}, 404
        if not memory_spec.end_user_editable:
            return {"error": "Memory not editable"}, 403

        # First get existing memory
        existing_memory = ChatflowMemoryService.get_memory_by_spec(
            spec=memory_spec,
            tenant_id=app_model.tenant_id,
            app_id=app_model.id,
            created_by=MemoryCreatedBy(end_user_id=end_user.id),
            conversation_id=conversation_id,
            node_id=node_id,
            is_draft=False,
        )
        # Create updated memory instance with incremented version
        updated_memory = MemoryBlock(
            spec=existing_memory.spec,
            tenant_id=existing_memory.tenant_id,
            app_id=existing_memory.app_id,
            conversation_id=existing_memory.conversation_id,
            node_id=existing_memory.node_id,
            value=update,  # New value
            version=existing_memory.version + 1,  # Increment version for update
            edited_by_user=True,
            created_by=existing_memory.created_by,
        )
        ChatflowMemoryService.save_memory(updated_memory, VariablePool(), False)
        return "", 204
class MemoryDeleteApi(WebApiResource):
    def delete(self, app_model: App, end_user: EndUser):
        """Delete one memory by id, or every memory of this end user when no id is given."""
        parser = reqparse.RequestParser()
        parser.add_argument("id", type=str, required=False, default=None)
        memory_id = parser.parse_args().get("id")
        created_by = MemoryCreatedBy(end_user_id=end_user.id)

        if not memory_id:
            # NOTE(review): bulk delete returns 200 while single delete returns
            # 204 — presumably intentional, but worth confirming.
            ChatflowMemoryService.delete_all_user_memories(app_model, created_by)
            return "", 200

        ChatflowMemoryService.delete_memory(app_model, memory_id, created_by)
        return "", 204
# MemoryListApi (GET) and MemoryDeleteApi (DELETE) are both mounted at
# /memories; werkzeug dispatches by HTTP method, so the rules do not collide.
api.add_resource(MemoryListApi, '/memories')
api.add_resource(MemoryEditApi, '/memory-edit')
api.add_resource(MemoryDeleteApi, '/memories')

View File

@ -1,10 +1,11 @@
import logging
import time
from collections.abc import Mapping
from collections.abc import Mapping, MutableMapping
from typing import Any, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from typing_extensions import override
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager
@ -20,11 +21,14 @@ from core.app.entities.queue_entities import (
QueueTextChunkEvent,
)
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
from core.memory.entities import MemoryCreatedBy, MemoryScope
from core.model_runtime.entities import AssistantPromptMessage, UserPromptMessage
from core.moderation.base import ModerationError
from core.moderation.input_moderation import InputModeration
from core.variables.variables import VariableUnion
from core.workflow.entities import GraphRuntimeState, VariablePool
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_events import GraphRunSucceededEvent
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry
@ -34,6 +38,8 @@ from models import Workflow
from models.enums import UserFrom
from models.model import App, Conversation, Message, MessageAnnotation
from models.workflow import ConversationVariable
from services.chatflow_history_service import ChatflowHistoryService
from services.chatflow_memory_service import ChatflowMemoryService
logger = logging.getLogger(__name__)
@ -70,6 +76,11 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
self._app = app
def run(self):
ChatflowMemoryService.wait_for_sync_memory_completion(
workflow=self._workflow,
conversation_id=self.conversation.id
)
app_config = self.application_generate_entity.app_config
app_config = cast(AdvancedChatAppConfig, app_config)
@ -133,6 +144,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
# Based on the definition of `VariableUnion`,
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
conversation_variables=conversation_variables,
memory_blocks=self._fetch_memory_blocks(),
)
# init graph
@ -177,6 +189,31 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
for event in generator:
self._handle_event(workflow_entry, event)
try:
self._check_app_memory_updates(variable_pool)
except Exception as e:
logger.exception("Failed to check app memory updates", exc_info=e)
@override
def _handle_event(self, workflow_entry: WorkflowEntry, event: Any) -> None:
    """Forward events to the base handler, then persist chat history on a successful run."""
    super()._handle_event(workflow_entry, event)
    if not isinstance(event, GraphRunSucceededEvent):
        return

    workflow_outputs = event.outputs
    if not workflow_outputs:
        logger.warning("Chatflow output is empty.")
        return

    assistant_message = workflow_outputs.get("answer")
    if not assistant_message:
        logger.warning("Chatflow output does not contain 'answer'.")
        return
    if not isinstance(assistant_message, str):
        logger.warning("Chatflow output 'answer' is not a string.")
        return

    try:
        self._sync_conversation_to_chatflow_tables(assistant_message)
    except Exception:
        # logger.exception already records the active exception; the redundant
        # exc_info=e argument was dropped. History sync is best-effort and must
        # not fail the run.
        logger.exception("Failed to sync conversation to memory tables")
def handle_input_moderation(
self,
app_record: App,
@ -374,3 +411,67 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
# Return combined list
return existing_variables + new_variables
def _fetch_memory_blocks(self) -> Mapping[str, str]:
    """fetch all memory blocks for current app"""
    is_draft = self.application_generate_entity.invoke_from == InvokeFrom.DEBUGGER
    # Get runtime memory values
    memories = ChatflowMemoryService.get_memories_by_specs(
        memory_block_specs=self._workflow.memory_blocks,
        tenant_id=self._workflow.tenant_id,
        app_id=self._workflow.app_id,
        node_id=None,
        conversation_id=self.conversation.id,
        is_draft=is_draft,
        created_by=self._get_created_by(),
    )

    # Build memory_id -> value mapping
    blocks: MutableMapping[str, str] = {}
    for memory in memories:
        if memory.spec.scope == MemoryScope.APP:
            # App level: use memory_id directly
            blocks[memory.spec.id] = memory.value
            continue
        # NODE scope: key is "<node_id>.<memory_id>"
        node_id = memory.node_id
        if not node_id:
            logger.warning("Memory block %s has no node_id, skip.", memory.spec.id)
            continue
        blocks[f"{node_id}.{memory.spec.id}"] = memory.value
    return blocks
def _sync_conversation_to_chatflow_tables(self, assistant_message: str):
    """Persist the user query and assistant answer into the chatflow history tables."""
    common_kwargs = {
        "conversation_id": self.conversation.id,
        "app_id": self._workflow.app_id,
        "tenant_id": self._workflow.tenant_id,
    }
    ChatflowHistoryService.save_app_message(
        prompt_message=UserPromptMessage(content=self.application_generate_entity.query),
        **common_kwargs,
    )
    ChatflowHistoryService.save_app_message(
        prompt_message=AssistantPromptMessage(content=assistant_message),
        **common_kwargs,
    )
def _check_app_memory_updates(self, variable_pool: VariablePool):
    """Trigger app-level memory updates after the workflow run, if needed."""
    ChatflowMemoryService.update_app_memory_if_needed(
        workflow=self._workflow,
        conversation_id=self.conversation.id,
        variable_pool=variable_pool,
        is_draft=self.application_generate_entity.invoke_from == InvokeFrom.DEBUGGER,
        created_by=self._get_created_by(),
    )
def _get_created_by(self) -> MemoryCreatedBy:
    """Attribute memory ownership: accounts for debugger/explore invocations, end users otherwise."""
    invoke_from = self.application_generate_entity.invoke_from
    user_id = self.application_generate_entity.user_id
    if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE}:
        return MemoryCreatedBy(account_id=user_id)
    return MemoryCreatedBy(end_user_id=user_id)

View File

@ -56,6 +56,9 @@ class HostingConfiguration:
self.provider_map[f"{DEFAULT_PLUGIN_ID}/minimax/minimax"] = self.init_minimax()
self.provider_map[f"{DEFAULT_PLUGIN_ID}/spark/spark"] = self.init_spark()
self.provider_map[f"{DEFAULT_PLUGIN_ID}/zhipuai/zhipuai"] = self.init_zhipuai()
self.provider_map[f"{DEFAULT_PLUGIN_ID}/gemini/google"] = self.init_gemini()
self.provider_map[f"{DEFAULT_PLUGIN_ID}/x/x"] = self.init_xai()
self.provider_map[f"{DEFAULT_PLUGIN_ID}/deepseek/deepseek"] = self.init_deepseek()
self.moderation_config = self.init_moderation_config()
@ -128,7 +131,7 @@ class HostingConfiguration:
quotas: list[HostingQuota] = []
if dify_config.HOSTED_OPENAI_TRIAL_ENABLED:
hosted_quota_limit = dify_config.HOSTED_OPENAI_QUOTA_LIMIT
hosted_quota_limit = 0
trial_models = self.parse_restrict_models_from_env("HOSTED_OPENAI_TRIAL_MODELS")
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trial_models)
quotas.append(trial_quota)
@ -156,18 +159,49 @@ class HostingConfiguration:
quota_unit=quota_unit,
)
@staticmethod
def init_anthropic() -> HostingProvider:
quota_unit = QuotaUnit.TOKENS
def init_gemini(self) -> HostingProvider:
    """Build the hosted Gemini provider configuration from environment settings."""
    quota_unit = QuotaUnit.CREDITS
    quotas: list[HostingQuota] = []
    if dify_config.HOSTED_GEMINI_TRIAL_ENABLED:
        # Trial usage is metered via the shared credit pool, so the local limit stays 0.
        trial_models = self.parse_restrict_models_from_env("HOSTED_GEMINI_TRIAL_MODELS")
        quotas.append(TrialHostingQuota(quota_limit=0, restrict_models=trial_models))
    if dify_config.HOSTED_GEMINI_PAID_ENABLED:
        paid_models = self.parse_restrict_models_from_env("HOSTED_GEMINI_PAID_MODELS")
        quotas.append(PaidHostingQuota(restrict_models=paid_models))
    if not quotas:
        # Neither trial nor paid quota enabled: expose the provider as disabled.
        return HostingProvider(
            enabled=False,
            quota_unit=quota_unit,
        )
    credentials = {
        "google_api_key": dify_config.HOSTED_GEMINI_API_KEY,
    }
    if dify_config.HOSTED_GEMINI_API_BASE:
        credentials["google_base_url"] = dify_config.HOSTED_GEMINI_API_BASE
    return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas)
def init_anthropic(self) -> HostingProvider:
quota_unit = QuotaUnit.CREDITS
quotas: list[HostingQuota] = []
if dify_config.HOSTED_ANTHROPIC_TRIAL_ENABLED:
hosted_quota_limit = dify_config.HOSTED_ANTHROPIC_QUOTA_LIMIT
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit)
hosted_quota_limit = 0
trail_models = self.parse_restrict_models_from_env("HOSTED_ANTHROPIC_TRIAL_MODELS")
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trail_models)
quotas.append(trial_quota)
if dify_config.HOSTED_ANTHROPIC_PAID_ENABLED:
paid_quota = PaidHostingQuota()
paid_models = self.parse_restrict_models_from_env("HOSTED_ANTHROPIC_PAID_MODELS")
paid_quota = PaidHostingQuota(restrict_models=paid_models)
quotas.append(paid_quota)
if len(quotas) > 0:
@ -185,6 +219,66 @@ class HostingConfiguration:
quota_unit=quota_unit,
)
def init_xai(self) -> HostingProvider:
    """Build the hosted xAI provider configuration from environment settings."""
    quota_unit = QuotaUnit.CREDITS
    quotas: list[HostingQuota] = []
    if dify_config.HOSTED_XAI_TRIAL_ENABLED:
        # Trial usage is metered via the shared credit pool, so the local limit stays 0.
        trial_models = self.parse_restrict_models_from_env("HOSTED_XAI_TRIAL_MODELS")
        quotas.append(TrialHostingQuota(quota_limit=0, restrict_models=trial_models))
    if dify_config.HOSTED_XAI_PAID_ENABLED:
        paid_models = self.parse_restrict_models_from_env("HOSTED_XAI_PAID_MODELS")
        quotas.append(PaidHostingQuota(restrict_models=paid_models))
    if not quotas:
        # Neither trial nor paid quota enabled: expose the provider as disabled.
        return HostingProvider(
            enabled=False,
            quota_unit=quota_unit,
        )
    credentials = {
        "api_key": dify_config.HOSTED_XAI_API_KEY,
    }
    if dify_config.HOSTED_XAI_API_BASE:
        credentials["endpoint_url"] = dify_config.HOSTED_XAI_API_BASE
    return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas)
def init_deepseek(self) -> HostingProvider:
    """Build the hosted DeepSeek provider configuration from environment settings."""
    quota_unit = QuotaUnit.CREDITS
    quotas: list[HostingQuota] = []
    if dify_config.HOSTED_DEEPSEEK_TRIAL_ENABLED:
        # Trial usage is metered via the shared credit pool, so the local limit stays 0.
        trial_models = self.parse_restrict_models_from_env("HOSTED_DEEPSEEK_TRIAL_MODELS")
        quotas.append(TrialHostingQuota(quota_limit=0, restrict_models=trial_models))
    if dify_config.HOSTED_DEEPSEEK_PAID_ENABLED:
        paid_models = self.parse_restrict_models_from_env("HOSTED_DEEPSEEK_PAID_MODELS")
        quotas.append(PaidHostingQuota(restrict_models=paid_models))
    if not quotas:
        # Neither trial nor paid quota enabled: expose the provider as disabled.
        return HostingProvider(
            enabled=False,
            quota_unit=quota_unit,
        )
    credentials = {
        "api_key": dify_config.HOSTED_DEEPSEEK_API_KEY,
    }
    if dify_config.HOSTED_DEEPSEEK_API_BASE:
        credentials["endpoint_url"] = dify_config.HOSTED_DEEPSEEK_API_BASE
    return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas)
@staticmethod
def init_minimax() -> HostingProvider:
quota_unit = QuotaUnit.TOKENS

View File

@ -14,10 +14,12 @@ from core.llm_generator.prompts import (
JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE,
LLM_MODIFY_CODE_SYSTEM,
LLM_MODIFY_PROMPT_SYSTEM,
MEMORY_UPDATE_PROMPT,
PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE,
SYSTEM_STRUCTURED_OUTPUT_GENERATE,
WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE,
)
from core.memory.entities import MemoryBlock, MemoryBlockSpec
from core.model_manager import ModelManager
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
@ -27,6 +29,7 @@ from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.ops.utils import measure_time
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from extensions.ext_database import db
from extensions.ext_storage import storage
@ -560,3 +563,35 @@ class LLMGenerator:
"Failed to invoke LLM model, model: %s", json.dumps(model_config.get("name")), exc_info=True
)
return {"error": f"An unexpected error occurred: {str(e)}"}
@staticmethod
def update_memory_block(
    tenant_id: str,
    visible_history: Sequence[tuple[str, str]],
    variable_pool: VariablePool,
    memory_block: MemoryBlock,
    memory_spec: MemoryBlockSpec
) -> str:
    """Ask the configured LLM to rewrite a memory block from recent history.

    Returns the model's answer text, which becomes the new memory value.
    """
    model_instance = ModelManager().get_model_instance(
        tenant_id=tenant_id,
        provider=memory_spec.model.provider,
        model=memory_spec.model.name,
        model_type=ModelType.LLM,
    )
    # Render one "<sender>: <message>" line per turn for the prompt.
    formatted_history = "".join(f"{sender}: {message}\n" for sender, message in visible_history)
    # The update instruction may itself contain workflow variable templates.
    filled_instruction = variable_pool.convert_template(memory_spec.instruction).text
    formatted_prompt = PromptTemplateParser(MEMORY_UPDATE_PROMPT).format(
        inputs={
            "formatted_history": formatted_history,
            "current_value": memory_block.value,
            "instruction": filled_instruction,
        }
    )
    # Non-streaming call: the full result is needed before storing the memory.
    llm_result = model_instance.invoke_llm(
        prompt_messages=[UserPromptMessage(content=formatted_prompt)],
        model_parameters=memory_spec.model.completion_params,
        stream=False,
    )
    return llm_result.message.get_text_content()

View File

@ -422,3 +422,18 @@ INSTRUCTION_GENERATE_TEMPLATE_PROMPT = """The output of this prompt is not as ex
You should edit the prompt according to the IDEAL OUTPUT."""
INSTRUCTION_GENERATE_TEMPLATE_CODE = """Please fix the errors in the {{#error_message#}}."""
# Template used by LLMGenerator.update_memory_block. The three {{...}} placeholders
# are filled via PromptTemplateParser: formatted_history, current_value, instruction.
MEMORY_UPDATE_PROMPT = """
Based on the following conversation history, update the memory content:
Conversation history:
{{formatted_history}}
Current memory:
{{current_value}}
Update instruction:
{{instruction}}
Please output only the updated memory content, no other text like greeting:
"""

119
api/core/memory/entities.py Normal file
View File

@ -0,0 +1,119 @@
from __future__ import annotations
from enum import StrEnum
from typing import TYPE_CHECKING, Optional
from uuid import uuid4
from pydantic import BaseModel, Field
if TYPE_CHECKING:
from core.app.app_config.entities import ModelConfig
class MemoryScope(StrEnum):
    """Memory scope determined by node_id field"""
    APP = "app"  # node_id is None
    NODE = "node"  # node_id is not None
class MemoryTerm(StrEnum):
    """Memory term determined by conversation_id field"""
    SESSION = "session"  # conversation_id is not None
    PERSISTENT = "persistent"  # conversation_id is None
class MemoryStrategy(StrEnum):
    """Trigger strategy for memory updates."""
    ON_TURNS = "on_turns"  # update every N conversation turns (see MemoryBlockSpec.update_turns)
class MemoryScheduleMode(StrEnum):
    """Whether a memory update blocks the run (sync) or happens in the background (async)."""
    SYNC = "sync"
    ASYNC = "async"
class MemoryBlockSpec(BaseModel):
    """Memory block specification for workflow configuration.

    This is the design-time definition stored on the workflow; runtime state
    lives in MemoryBlock, which embeds this spec.
    """
    id: str = Field(
        default_factory=lambda: str(uuid4()),
        description="Unique identifier for the memory block",
    )
    name: str = Field(description="Display name of the memory block")
    description: str = Field(default="", description="Description of the memory block")
    template: str = Field(description="Initial template content for the memory")
    instruction: str = Field(description="Instructions for updating the memory")
    scope: MemoryScope = Field(description="Scope of the memory (app or node level)")
    term: MemoryTerm = Field(description="Term of the memory (session or persistent)")
    strategy: MemoryStrategy = Field(description="Update strategy for the memory")
    update_turns: int = Field(gt=0, description="Number of turns between updates")
    preserved_turns: int = Field(gt=0, description="Number of conversation turns to preserve")
    schedule_mode: MemoryScheduleMode = Field(description="Synchronous or asynchronous update mode")
    # ModelConfig is only imported under TYPE_CHECKING (circular-import workaround);
    # the forward reference is resolved when the model is (re)built.
    model: ModelConfig = Field(description="Model configuration for memory updates")
    end_user_visible: bool = Field(default=False, description="Whether memory is visible to end users")
    end_user_editable: bool = Field(default=False, description="Whether memory is editable by end users")
class MemoryCreatedBy(BaseModel):
    """Identity that created/updated a memory record: a console account or an app end user."""
    # NOTE(review): exactly one of these is presumably expected to be set, but no
    # validator enforces it — confirm callers never set both or neither.
    end_user_id: str | None = None
    account_id: str | None = None
class MemoryBlock(BaseModel):
    """Runtime memory block instance
    Design Rules:
    - app_id = None: Global memory (future feature, not implemented yet)
    - app_id = str: App-specific memory
    - conversation_id = None: Persistent memory (cross-conversation)
    - conversation_id = str: Session memory (conversation-specific)
    - node_id = None: App-level scope
    - node_id = str: Node-level scope
    These rules implicitly determine scope and term without redundant storage.
    """
    # Design-time definition this instance was created from.
    spec: MemoryBlockSpec
    tenant_id: str
    # Current memory content (initially derived from spec.template).
    value: str
    app_id: str
    conversation_id: Optional[str] = None
    node_id: Optional[str] = None
    # True once an end user manually edited the value.
    edited_by_user: bool = False
    created_by: MemoryCreatedBy
    version: int = Field(description="Memory block version number")
class MemoryValueData(BaseModel):
    """Payload for reading/writing a memory block's value."""
    value: str
    # Marks values that were manually edited by a user rather than generated.
    edited_by_user: bool = False
class ChatflowConversationMetadata(BaseModel):
    """Metadata for chatflow conversation with visible message count"""
    type: str = "mutable_visible_window"
    visible_count: int = Field(gt=0, description="Number of visible messages to keep")
class MemoryBlockWithConversation(MemoryBlock):
    """MemoryBlock with optional conversation metadata for session memories"""
    conversation_metadata: ChatflowConversationMetadata = Field(
        description="Conversation metadata, only present for session memories"
    )
    @classmethod
    def from_memory_block(
        cls,
        memory_block: MemoryBlock,
        conversation_metadata: ChatflowConversationMetadata
    ) -> MemoryBlockWithConversation:
        """Create MemoryBlockWithConversation from MemoryBlock.

        Copies every MemoryBlock field explicitly and attaches the metadata;
        the source instance is not mutated.
        """
        return cls(
            spec=memory_block.spec,
            tenant_id=memory_block.tenant_id,
            value=memory_block.value,
            app_id=memory_block.app_id,
            conversation_id=memory_block.conversation_id,
            node_id=memory_block.node_id,
            edited_by_user=memory_block.edited_by_user,
            created_by=memory_block.created_by,
            version=memory_block.version,
            conversation_metadata=conversation_metadata
        )

View File

@ -0,0 +1,6 @@
class MemorySyncTimeoutError(Exception):
    """Raised when waiting for memory synchronization exceeds the allowed time.

    The app and conversation identifiers are kept on the exception so handlers
    can log or report context without parsing the message.
    """

    def __init__(self, app_id: str, conversation_id: str, timeout_seconds: int = 50):
        self.app_id = app_id
        self.conversation_id = conversation_id
        # Previously hard-coded to 50s; parameterized with the same default so
        # existing call sites keep the exact original message.
        self.timeout_seconds = timeout_seconds
        self.message = f"Memory synchronization timeout after {timeout_seconds} seconds"
        super().__init__(self.message)

View File

@ -45,6 +45,12 @@ class MemoryConfig(BaseModel):
enabled: bool
size: int | None = None
mode: Literal["linear", "block"] | None = "linear"
block_id: list[str] | None = None
role_prefix: RolePrefix | None = None
window: WindowConfig
query_prompt_template: str | None = None
@property
def is_block_mode(self) -> bool:
    # Block mode is only effective when mode == "block" AND at least one block id is configured.
    return self.mode == "block" and bool(self.block_id)

View File

@ -618,9 +618,9 @@ class ProviderManager:
)
for quota in configuration.quotas:
if quota.quota_type == ProviderQuotaType.TRIAL:
if quota.quota_type in (ProviderQuotaType.TRIAL, ProviderQuotaType.PAID):
# Init trial provider records if not exists
if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict:
if quota.quota_type not in provider_quota_to_provider_record_dict:
try:
# FIXME ignore the type error, only TrialHostingQuota has limit need to change the logic
new_provider_record = Provider(
@ -628,8 +628,8 @@ class ProviderManager:
# TODO: Use provider name with prefix after the data migration.
provider_name=ModelProviderID(provider_name).provider_name,
provider_type=ProviderType.SYSTEM.value,
quota_type=ProviderQuotaType.TRIAL.value,
quota_limit=quota.quota_limit, # type: ignore
quota_type=quota.quota_type,
quota_limit=0, # type: ignore
quota_used=0,
is_valid=True,
)
@ -642,7 +642,7 @@ class ProviderManager:
Provider.tenant_id == tenant_id,
Provider.provider_name == ModelProviderID(provider_name).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == ProviderQuotaType.TRIAL.value,
Provider.quota_type == quota.quota_type,
)
existed_provider_record = db.session.scalar(stmt)
if not existed_provider_record:
@ -652,7 +652,7 @@ class ProviderManager:
existed_provider_record.is_valid = True
db.session.commit()
provider_name_to_provider_records_dict[provider_name].append(existed_provider_record)
provider_name_to_provider_records_dict[provider_name].append(existed_provider_record)
return provider_name_to_provider_records_dict
@ -912,6 +912,22 @@ class ProviderManager:
provider_record
)
quota_configurations = []
if dify_config.EDITION == "CLOUD":
from services.credit_pool_service import CreditPoolService
trail_pool = CreditPoolService.get_pool(
tenant_id=tenant_id,
pool_type=ProviderQuotaType.TRIAL.value,
)
paid_pool = CreditPoolService.get_pool(
tenant_id=tenant_id,
pool_type=ProviderQuotaType.PAID.value,
)
else:
trail_pool = None
paid_pool = None
for provider_quota in provider_hosting_configuration.quotas:
if provider_quota.quota_type not in quota_type_to_provider_records_dict:
if provider_quota.quota_type == ProviderQuotaType.FREE:
@ -932,16 +948,36 @@ class ProviderManager:
raise ValueError("quota_used is None")
if provider_record.quota_limit is None:
raise ValueError("quota_limit is None")
if provider_quota.quota_type == ProviderQuotaType.TRIAL and trail_pool is not None:
quota_configuration = QuotaConfiguration(
quota_type=provider_quota.quota_type,
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
quota_used=trail_pool.quota_used,
quota_limit=trail_pool.quota_limit,
is_valid=trail_pool.quota_limit > trail_pool.quota_used or trail_pool.quota_limit == -1,
restrict_models=provider_quota.restrict_models,
)
elif provider_quota.quota_type == ProviderQuotaType.PAID and paid_pool is not None:
quota_configuration = QuotaConfiguration(
quota_type=provider_quota.quota_type,
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
quota_used=paid_pool.quota_used,
quota_limit=paid_pool.quota_limit,
is_valid=paid_pool.quota_limit > paid_pool.quota_used or paid_pool.quota_limit == -1,
restrict_models=provider_quota.restrict_models,
)
quota_configuration = QuotaConfiguration(
quota_type=provider_quota.quota_type,
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
quota_used=provider_record.quota_used,
quota_limit=provider_record.quota_limit,
is_valid=provider_record.quota_limit > provider_record.quota_used
or provider_record.quota_limit == -1,
restrict_models=provider_quota.restrict_models,
)
else:
quota_configuration = QuotaConfiguration(
quota_type=provider_quota.quota_type,
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
quota_used=provider_record.quota_used,
quota_limit=provider_record.quota_limit,
is_valid=provider_record.quota_limit > provider_record.quota_used
or provider_record.quota_limit == -1,
restrict_models=provider_quota.restrict_models,
)
quota_configurations.append(quota_configuration)

View File

@ -203,6 +203,48 @@ class ArrayFileSegment(ArraySegment):
return ""
class VersionedMemoryValue(BaseModel):
    """Immutable memory value with a map of named historical versions."""

    current_value: str = None  # type: ignore
    versions: Mapping[str, str] = {}
    # Frozen: instances are never mutated; add_version returns a new object.
    model_config = ConfigDict(frozen=True)

    def add_version(
        self,
        new_value: str,
        version_name: str | None = None
    ) -> "VersionedMemoryValue":
        """Return a copy with `new_value` as current, recorded under `version_name`.

        Auto-generated names are sequential numbers starting at "1".
        Raises ValueError if the version name already exists.
        """
        if version_name is None:
            version_name = str(len(self.versions) + 1)
        if version_name in self.versions:
            raise ValueError(f"Version '{version_name}' already exists.")
        # Bug fix: the original assigned `self.current_value = new_value` here,
        # which raises on a frozen pydantic model and was redundant anyway —
        # the freshly built instance already carries the new value.
        return VersionedMemoryValue(
            current_value=new_value,
            versions={
                version_name: new_value,
                **self.versions,
            }
        )
class VersionedMemorySegment(Segment):
    """Segment wrapping a VersionedMemoryValue; every textual view exposes the current value."""
    value_type: SegmentType = SegmentType.VERSIONED_MEMORY
    value: VersionedMemoryValue = None  # type: ignore
    @property
    def text(self) -> str:
        return self.value.current_value
    @property
    def log(self) -> str:
        return self.value.current_value
    @property
    def markdown(self) -> str:
        return self.value.current_value
class ArrayBooleanSegment(ArraySegment):
    """Segment holding a sequence of booleans (SegmentType.ARRAY_BOOLEAN)."""
    value_type: SegmentType = SegmentType.ARRAY_BOOLEAN
    value: Sequence[bool] = None  # type: ignore
@ -248,6 +290,7 @@ SegmentUnion: TypeAlias = Annotated[
| Annotated[ArrayObjectSegment, Tag(SegmentType.ARRAY_OBJECT)]
| Annotated[ArrayFileSegment, Tag(SegmentType.ARRAY_FILE)]
| Annotated[ArrayBooleanSegment, Tag(SegmentType.ARRAY_BOOLEAN)]
| Annotated[VersionedMemorySegment, Tag(SegmentType.VERSIONED_MEMORY)]
),
Discriminator(get_segment_discriminator),
]

View File

@ -41,6 +41,8 @@ class SegmentType(StrEnum):
ARRAY_FILE = "array[file]"
ARRAY_BOOLEAN = "array[boolean]"
VERSIONED_MEMORY = "versioned_memory"
NONE = "none"
GROUP = "group"

View File

@ -22,6 +22,7 @@ from .segments import (
ObjectSegment,
Segment,
StringSegment,
VersionedMemorySegment,
get_segment_discriminator,
)
from .types import SegmentType
@ -106,6 +107,10 @@ class ArrayFileVariable(ArrayFileSegment, ArrayVariable):
pass
# Variable flavours of the corresponding segments; behavior comes entirely from the bases.
class VersionedMemoryVariable(VersionedMemorySegment, Variable):
    pass
class ArrayBooleanVariable(ArrayBooleanSegment, ArrayVariable):
    pass
@ -161,6 +166,7 @@ VariableUnion: TypeAlias = Annotated[
| Annotated[ArrayFileVariable, Tag(SegmentType.ARRAY_FILE)]
| Annotated[ArrayBooleanVariable, Tag(SegmentType.ARRAY_BOOLEAN)]
| Annotated[SecretVariable, Tag(SegmentType.SECRET)]
| Annotated[VersionedMemoryVariable, Tag(SegmentType.VERSIONED_MEMORY)]
),
Discriminator(get_segment_discriminator),
]

View File

@ -1,4 +1,5 @@
# Reserved pseudo-node ids used as the first element of variable-pool selectors.
SYSTEM_VARIABLE_NODE_ID = "sys"
ENVIRONMENT_VARIABLE_NODE_ID = "env"
CONVERSATION_VARIABLE_NODE_ID = "conversation"
MEMORY_BLOCK_VARIABLE_NODE_ID = "memory_block"  # selector prefix for memory-block variables
RAG_PIPELINE_VARIABLE_NODE_ID = "rag"

View File

@ -8,11 +8,12 @@ from pydantic import BaseModel, Field
from core.file import File, FileAttribute, file_manager
from core.variables import Segment, SegmentGroup, Variable
from core.variables.consts import SELECTORS_LENGTH
from core.variables.segments import FileSegment, ObjectSegment
from core.variables.variables import RAGPipelineVariableInput, VariableUnion
from core.variables.segments import FileSegment, ObjectSegment, VersionedMemoryValue
from core.variables.variables import RAGPipelineVariableInput, VariableUnion, VersionedMemoryVariable
from core.workflow.constants import (
CONVERSATION_VARIABLE_NODE_ID,
ENVIRONMENT_VARIABLE_NODE_ID,
MEMORY_BLOCK_VARIABLE_NODE_ID,
RAG_PIPELINE_VARIABLE_NODE_ID,
SYSTEM_VARIABLE_NODE_ID,
)
@ -55,6 +56,10 @@ class VariablePool(BaseModel):
description="RAG pipeline variables.",
default_factory=list,
)
memory_blocks: Mapping[str, str] = Field(
description="Memory blocks.",
default_factory=dict,
)
def model_post_init(self, context: Any, /):
# Create a mapping from field names to SystemVariableKey enum values
@ -75,6 +80,18 @@ class VariablePool(BaseModel):
rag_pipeline_variables_map[node_id][key] = value
for key, value in rag_pipeline_variables_map.items():
self.add((RAG_PIPELINE_VARIABLE_NODE_ID, key), value)
# Add memory blocks to the variable pool
for memory_id, memory_value in self.memory_blocks.items():
self.add(
[MEMORY_BLOCK_VARIABLE_NODE_ID, memory_id],
VersionedMemoryVariable(
value=VersionedMemoryValue(
current_value=memory_value,
versions={"1": memory_value},
),
name=memory_id,
)
)
def add(self, selector: Sequence[str], value: Any, /):
"""

View File

@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
from configs import dify_config
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.provider_entities import QuotaUnit
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
from core.file.models import File
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
@ -136,21 +136,36 @@ def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUs
used_quota = 1
if used_quota is not None and system_configuration.current_quota_type is not None:
with Session(db.engine) as session:
stmt = (
update(Provider)
.where(
Provider.tenant_id == tenant_id,
# TODO: Use provider name with prefix after the data migration.
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == system_configuration.current_quota_type.value,
Provider.quota_limit > Provider.quota_used,
)
.values(
quota_used=Provider.quota_used + used_quota,
last_used=naive_utc_now(),
)
if system_configuration.current_quota_type == ProviderQuotaType.TRIAL:
from services.credit_pool_service import CreditPoolService
CreditPoolService.check_and_deduct_credits(
tenant_id=tenant_id,
credits_required=used_quota,
)
session.execute(stmt)
session.commit()
elif system_configuration.current_quota_type == ProviderQuotaType.PAID:
from services.credit_pool_service import CreditPoolService
CreditPoolService.check_and_deduct_credits(
tenant_id=tenant_id,
credits_required=used_quota,
pool_type="paid",
)
else:
with Session(db.engine) as session:
stmt = (
update(Provider)
.where(
Provider.tenant_id == tenant_id,
# TODO: Use provider name with prefix after the data migration.
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == system_configuration.current_quota_type.value,
Provider.quota_limit > Provider.quota_used,
)
.values(
quota_used=Provider.quota_used + used_quota,
last_used=naive_utc_now(),
)
)
session.execute(stmt)
session.commit()

View File

@ -6,11 +6,15 @@ import re
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
from core.file import FileType, file_manager
from core.helper.code_executor import CodeExecutor, CodeLanguage
from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
from core.memory.entities import MemoryCreatedBy, MemoryScope
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities import (
@ -71,6 +75,8 @@ from core.workflow.node_events import (
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig, VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from models import UserFrom, Workflow
from models.engine import db
from . import llm_utils
from .entities import (
@ -315,6 +321,11 @@ class LLMNode(Node):
if self._file_outputs:
outputs["files"] = ArrayFileSegment(value=self._file_outputs)
try:
self._handle_chatflow_memory(result_text, variable_pool)
except Exception as e:
logger.warning("Memory orchestration failed for node %s: %s", self.node_id, str(e))
# Send final chunk event to indicate streaming is complete
yield StreamChunkEvent(
selector=[self._node_id, "text"],
@ -1184,6 +1195,79 @@ class LLMNode(Node):
def retry(self) -> bool:
return self._node_data.retry_config.retry_enabled
def _handle_chatflow_memory(self, llm_output: str, variable_pool: VariablePool):
    """Persist this turn's messages to node history and trigger node-scoped memory updates.

    Only runs when the node's memory config is in "block" mode. Raises ValueError
    when required system variables or the workflow record are missing.
    """
    if not self._node_data.memory or self._node_data.memory.mode != "block":
        return
    # Both system variables are mandatory for memory bookkeeping.
    conversation_id_segment = variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.CONVERSATION_ID))
    if not conversation_id_segment:
        raise ValueError("Conversation ID not found in variable pool.")
    conversation_id = conversation_id_segment.text
    user_query_segment = variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
    if not user_query_segment:
        raise ValueError("User query not found in variable pool.")
    user_query = user_query_segment.text
    # Local imports to avoid circular dependencies at module load time.
    from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage
    from services.chatflow_history_service import ChatflowHistoryService
    # Record the user turn, then the assistant turn, against this node's history.
    ChatflowHistoryService.save_node_message(
        prompt_message=(UserPromptMessage(content=user_query)),
        node_id=self.node_id,
        conversation_id=conversation_id,
        app_id=self.app_id,
        tenant_id=self.tenant_id
    )
    ChatflowHistoryService.save_node_message(
        prompt_message=(AssistantPromptMessage(content=llm_output)),
        node_id=self.node_id,
        conversation_id=conversation_id,
        app_id=self.app_id,
        tenant_id=self.tenant_id
    )
    memory_config = self._node_data.memory
    if not memory_config:
        return
    block_ids = memory_config.block_id
    if not block_ids:
        return
    # FIXME: This is dirty workaround and may cause incorrect resolution for workflow version
    with Session(db.engine) as session:
        stmt = select(Workflow).where(
            Workflow.tenant_id == self.tenant_id,
            Workflow.app_id == self.app_id
        )
        workflow = session.scalars(stmt).first()
        if not workflow:
            raise ValueError("Workflow not found.")
        # Materialize specs before the session closes.
        memory_blocks = workflow.memory_blocks
    for block_id in block_ids:
        # Only node-scoped blocks are updated here; app-scoped ones are handled elsewhere.
        memory_block_spec = next((block for block in memory_blocks if block.id == block_id), None)
        if memory_block_spec and memory_block_spec.scope == MemoryScope.NODE:
            is_draft = (self.invoke_from == InvokeFrom.DEBUGGER)
            from services.chatflow_memory_service import ChatflowMemoryService
            ChatflowMemoryService.update_node_memory_if_needed(
                tenant_id=self.tenant_id,
                app_id=self.app_id,
                # NOTE(review): `self.id` here vs `self.node_id` used for history saves
                # above — confirm these refer to the same identifier.
                node_id=self.id,
                conversation_id=conversation_id,
                memory_block_spec=memory_block_spec,
                variable_pool=variable_pool,
                is_draft=is_draft,
                created_by=self._get_user_from_context()
            )
def _get_user_from_context(self) -> MemoryCreatedBy:
    """Attribute the memory write to an account (console user) or an end user."""
    if self.user_from != UserFrom.ACCOUNT:
        return MemoryCreatedBy(end_user_id=self.user_id)
    return MemoryCreatedBy(account_id=self.user_id)
def _combine_message_content_with_role(
*, contents: str | list[PromptMessageContentUnionTypes] | None = None, role: PromptMessageRole

View File

@ -38,14 +38,16 @@ elif [[ "${MODE}" == "beat" ]]; then
exec celery -A app.celery beat --loglevel ${LOG_LEVEL:-INFO}
else
if [[ "${DEBUG}" == "true" ]]; then
exec flask run --host=${DIFY_BIND_ADDRESS:-0.0.0.0} --port=${DIFY_PORT:-5001} --debug
export HOST=${DIFY_BIND_ADDRESS:-0.0.0.0}
export PORT=${DIFY_PORT:-5001}
exec python -m app
else
exec gunicorn \
--bind "${DIFY_BIND_ADDRESS:-0.0.0.0}:${DIFY_PORT:-5001}" \
--workers ${SERVER_WORKER_AMOUNT:-1} \
--worker-class ${SERVER_WORKER_CLASS:-gevent} \
--worker-class ${SERVER_WORKER_CLASS:-geventwebsocket.gunicorn.workers.GeventWebSocketWorker} \
--worker-connections ${SERVER_WORKER_CONNECTIONS:-10} \
--timeout ${GUNICORN_TIMEOUT:-200} \
app:app
app:socketio_app
fi
fi

View File

@ -9,7 +9,7 @@ from sqlalchemy.orm import Session
from configs import dify_config
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity
from core.entities.provider_entities import QuotaUnit, SystemConfiguration
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit, SystemConfiguration
from events.message_event import message_was_created
from extensions.ext_database import db
from extensions.ext_redis import redis_client, redis_fallback
@ -133,22 +133,38 @@ def handle(sender: Message, **kwargs):
system_configuration=system_configuration,
model_name=model_config.model,
)
if used_quota is not None:
quota_update = _ProviderUpdateOperation(
filters=_ProviderUpdateFilters(
if provider_configuration.system_configuration.current_quota_type == ProviderQuotaType.TRIAL:
from services.credit_pool_service import CreditPoolService
CreditPoolService.check_and_deduct_credits(
tenant_id=tenant_id,
provider_name=ModelProviderID(model_config.provider).provider_name,
provider_type=ProviderType.SYSTEM.value,
quota_type=provider_configuration.system_configuration.current_quota_type.value,
),
values=_ProviderUpdateValues(quota_used=Provider.quota_used + used_quota, last_used=current_time),
additional_filters=_ProviderUpdateAdditionalFilters(
quota_limit_check=True # Provider.quota_limit > Provider.quota_used
),
description="quota_deduction_update",
)
updates_to_perform.append(quota_update)
credits_required=used_quota,
pool_type="trial",
)
elif provider_configuration.system_configuration.current_quota_type == ProviderQuotaType.PAID:
from services.credit_pool_service import CreditPoolService
CreditPoolService.check_and_deduct_credits(
tenant_id=tenant_id,
credits_required=used_quota,
pool_type="paid",
)
else:
quota_update = _ProviderUpdateOperation(
filters=_ProviderUpdateFilters(
tenant_id=tenant_id,
provider_name=ModelProviderID(model_config.provider).provider_name,
provider_type=ProviderType.SYSTEM.value,
quota_type=provider_configuration.system_configuration.current_quota_type.value,
),
values=_ProviderUpdateValues(quota_used=Provider.quota_used + used_quota, last_used=current_time),
additional_filters=_ProviderUpdateAdditionalFilters(
quota_limit_check=True # Provider.quota_limit > Provider.quota_used
),
description="quota_deduction_update",
)
updates_to_perform.append(quota_update)
# Execute all updates
start_time = time_module.perf_counter()

View File

@ -0,0 +1,3 @@
import socketio
# Shared Socket.IO server instance. async_mode="gevent" matches the gunicorn
# GeventWebSocketWorker used in the entrypoint script.
# NOTE(review): cors_allowed_origins="*" accepts any origin — confirm this is
# intended outside development.
sio = socketio.Server(async_mode="gevent", cors_allowed_origins="*")

View File

@ -21,6 +21,8 @@ from core.variables.segments import (
ObjectSegment,
Segment,
StringSegment,
VersionedMemorySegment,
VersionedMemoryValue,
)
from core.variables.types import SegmentType
from core.variables.variables import (
@ -39,6 +41,7 @@ from core.variables.variables import (
SecretVariable,
StringVariable,
Variable,
VersionedMemoryVariable,
)
from core.workflow.constants import (
CONVERSATION_VARIABLE_NODE_ID,
@ -69,6 +72,7 @@ SEGMENT_TO_VARIABLE_MAP = {
NoneSegment: NoneVariable,
ObjectSegment: ObjectVariable,
StringSegment: StringVariable,
VersionedMemorySegment: VersionedMemoryVariable
}
@ -193,6 +197,7 @@ _segment_factory: Mapping[SegmentType, type[Segment]] = {
SegmentType.FILE: FileSegment,
SegmentType.BOOLEAN: BooleanSegment,
SegmentType.OBJECT: ObjectSegment,
SegmentType.VERSIONED_MEMORY: VersionedMemorySegment,
# Array types
SegmentType.ARRAY_ANY: ArrayAnySegment,
SegmentType.ARRAY_STRING: ArrayStringSegment,
@ -259,6 +264,12 @@ def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment:
else:
raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got empty list")
if segment_type == SegmentType.VERSIONED_MEMORY:
return VersionedMemorySegment(
value_type=segment_type,
value=VersionedMemoryValue.model_validate(value)
)
inferred_type = SegmentType.infer_segment_type(value)
# Type compatibility checking
if inferred_type is None:

View File

@ -0,0 +1,17 @@
from flask_restx import fields
# Serialization of one connected user; "sid" is the Socket.IO session id.
online_user_partial_fields = {
    "user_id": fields.String,
    "username": fields.String,
    "avatar": fields.String,
    "sid": fields.String,
}
# Users currently present on a single workflow.
workflow_online_users_fields = {
    "workflow_id": fields.String,
    "users": fields.List(fields.Nested(online_user_partial_fields)),
}
# Top-level list response: one entry per workflow with online users.
online_user_list_fields = {
    "data": fields.List(fields.Nested(workflow_online_users_fields)),
}

View File

@ -0,0 +1,96 @@
from flask_restx import fields

from libs.helper import AvatarUrlField, TimestampField

# Minimal account projection embedded in comment payloads.
account_fields = dict(
    id=fields.String,
    name=fields.String,
    email=fields.String,
    avatar_url=AvatarUrlField,
)

# A user mentioned in a comment (optionally tied to a specific reply).
workflow_comment_mention_fields = dict(
    mentioned_user_id=fields.String,
    mentioned_user_account=fields.Nested(account_fields, allow_null=True),
    reply_id=fields.String,
)

# A single reply in a comment thread.
workflow_comment_reply_fields = dict(
    id=fields.String,
    content=fields.String,
    created_by=fields.String,
    created_by_account=fields.Nested(account_fields, allow_null=True),
    created_at=TimestampField,
)

# Compact comment representation used by list endpoints: counts and
# participants instead of the full reply/mention payloads.
workflow_comment_basic_fields = dict(
    id=fields.String,
    position_x=fields.Float,
    position_y=fields.Float,
    content=fields.String,
    created_by=fields.String,
    created_by_account=fields.Nested(account_fields, allow_null=True),
    created_at=TimestampField,
    updated_at=TimestampField,
    resolved=fields.Boolean,
    resolved_at=TimestampField,
    resolved_by=fields.String,
    resolved_by_account=fields.Nested(account_fields, allow_null=True),
    reply_count=fields.Integer,
    mention_count=fields.Integer,
    participants=fields.List(fields.Nested(account_fields)),
)

# Full comment representation for single-comment endpoints: includes the
# complete reply and mention lists.
workflow_comment_detail_fields = dict(
    id=fields.String,
    position_x=fields.Float,
    position_y=fields.Float,
    content=fields.String,
    created_by=fields.String,
    created_by_account=fields.Nested(account_fields, allow_null=True),
    created_at=TimestampField,
    updated_at=TimestampField,
    resolved=fields.Boolean,
    resolved_at=TimestampField,
    resolved_by=fields.String,
    resolved_by_account=fields.Nested(account_fields, allow_null=True),
    replies=fields.List(fields.Nested(workflow_comment_reply_fields)),
    mentions=fields.List(fields.Nested(workflow_comment_mention_fields)),
)

# Slim response for comment creation.
workflow_comment_create_fields = dict(
    id=fields.String,
    created_at=TimestampField,
)

# Slim response for comment updates.
workflow_comment_update_fields = dict(
    id=fields.String,
    updated_at=TimestampField,
)

# Response for resolving / unresolving a comment.
workflow_comment_resolve_fields = dict(
    id=fields.String,
    resolved=fields.Boolean,
    resolved_at=TimestampField,
    resolved_by=fields.String,
)

# Slim response for reply creation.
workflow_comment_reply_create_fields = dict(
    id=fields.String,
    created_at=TimestampField,
)

# Slim response for reply updates.
workflow_comment_reply_update_fields = dict(
    id=fields.String,
    updated_at=TimestampField,
)

View File

@ -49,6 +49,23 @@ conversation_variable_fields = {
"description": fields.String,
}
# Marshalling schema for a workflow memory-block definition.
# Key order is preserved so the serialized output is unchanged.
memory_block_fields = dict(
    id=fields.String,
    name=fields.String,
    description=fields.String,
    template=fields.String,
    instruction=fields.String,
    scope=fields.String,
    term=fields.String,
    strategy=fields.String,
    update_turns=fields.Integer,
    preserved_turns=fields.Integer,
    schedule_mode=fields.String,
    model=fields.Raw,  # free-form model config object, passed through as-is
    end_user_visible=fields.Boolean,
    end_user_editable=fields.Boolean,
)
pipeline_variable_fields = {
"label": fields.String,
"variable": fields.String,
@ -81,6 +98,7 @@ workflow_fields = {
"tool_published": fields.Boolean,
"environment_variables": fields.List(EnvironmentVariableField()),
"conversation_variables": fields.List(fields.Nested(conversation_variable_fields)),
"memory_blocks": fields.List(fields.Nested(memory_block_fields)),
"rag_pipeline_variables": fields.List(fields.Nested(pipeline_variable_fields)),
}

View File

@ -0,0 +1,90 @@
"""Add workflow comments table
Revision ID: 227822d22895
Revises: 68519ad5cd18
Create Date: 2025-08-22 17:26:15.255980
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '227822d22895'
down_revision = '68519ad5cd18'
branch_labels = None
depends_on = None
def upgrade():
    """Create the workflow commenting tables: comments, replies, mentions."""
    # ### commands auto generated by Alembic - please adjust! ###
    # Root comment threads, anchored to a canvas position on an app.
    op.create_table('workflow_comments',
        sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
        sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
        sa.Column('app_id', models.types.StringUUID(), nullable=False),
        sa.Column('position_x', sa.Float(), nullable=False),
        sa.Column('position_y', sa.Float(), nullable=False),
        sa.Column('content', sa.Text(), nullable=False),
        sa.Column('created_by', models.types.StringUUID(), nullable=False),
        sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
        sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
        sa.Column('resolved', sa.Boolean(), server_default=sa.text('false'), nullable=False),
        sa.Column('resolved_at', sa.DateTime(), nullable=True),
        sa.Column('resolved_by', models.types.StringUUID(), nullable=True),
        sa.PrimaryKeyConstraint('id', name='workflow_comments_pkey')
    )
    with op.batch_alter_table('workflow_comments', schema=None) as batch_op:
        batch_op.create_index('workflow_comments_app_idx', ['tenant_id', 'app_id'], unique=False)
        batch_op.create_index('workflow_comments_created_at_idx', ['created_at'], unique=False)
    # Replies; cascade-deleted with their parent comment.
    op.create_table('workflow_comment_replies',
        sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
        sa.Column('comment_id', models.types.StringUUID(), nullable=False),
        sa.Column('content', sa.Text(), nullable=False),
        sa.Column('created_by', models.types.StringUUID(), nullable=False),
        sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
        sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
        sa.ForeignKeyConstraint(['comment_id'], ['workflow_comments.id'], name=op.f('workflow_comment_replies_comment_id_fkey'), ondelete='CASCADE'),
        sa.PrimaryKeyConstraint('id', name='workflow_comment_replies_pkey')
    )
    with op.batch_alter_table('workflow_comment_replies', schema=None) as batch_op:
        batch_op.create_index('comment_replies_comment_idx', ['comment_id'], unique=False)
        batch_op.create_index('comment_replies_created_at_idx', ['created_at'], unique=False)
    # Mentions; reply_id is nullable so a mention can target either the
    # comment itself or one specific reply. Both FKs cascade on delete.
    op.create_table('workflow_comment_mentions',
        sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
        sa.Column('comment_id', models.types.StringUUID(), nullable=False),
        sa.Column('reply_id', models.types.StringUUID(), nullable=True),
        sa.Column('mentioned_user_id', models.types.StringUUID(), nullable=False),
        sa.ForeignKeyConstraint(['comment_id'], ['workflow_comments.id'], name=op.f('workflow_comment_mentions_comment_id_fkey'), ondelete='CASCADE'),
        sa.ForeignKeyConstraint(['reply_id'], ['workflow_comment_replies.id'], name=op.f('workflow_comment_mentions_reply_id_fkey'), ondelete='CASCADE'),
        sa.PrimaryKeyConstraint('id', name='workflow_comment_mentions_pkey')
    )
    with op.batch_alter_table('workflow_comment_mentions', schema=None) as batch_op:
        batch_op.create_index('comment_mentions_comment_idx', ['comment_id'], unique=False)
        batch_op.create_index('comment_mentions_reply_idx', ['reply_id'], unique=False)
        batch_op.create_index('comment_mentions_user_idx', ['mentioned_user_id'], unique=False)
    # ### end Alembic commands ###
def downgrade():
    """Drop the commenting tables in reverse dependency order
    (mentions -> replies -> comments), indexes first."""
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('workflow_comment_mentions', schema=None) as batch_op:
        batch_op.drop_index('comment_mentions_user_idx')
        batch_op.drop_index('comment_mentions_reply_idx')
        batch_op.drop_index('comment_mentions_comment_idx')
    op.drop_table('workflow_comment_mentions')
    with op.batch_alter_table('workflow_comment_replies', schema=None) as batch_op:
        batch_op.drop_index('comment_replies_created_at_idx')
        batch_op.drop_index('comment_replies_comment_idx')
    op.drop_table('workflow_comment_replies')
    with op.batch_alter_table('workflow_comments', schema=None) as batch_op:
        batch_op.drop_index('workflow_comments_created_at_idx')
        batch_op.drop_index('workflow_comments_app_idx')
    op.drop_table('workflow_comments')
    # ### end Alembic commands ###

View File

@ -0,0 +1,79 @@
"""add table explore banner and trial
Revision ID: 1b435d90db42
Revises: cf7c38a32b2d
Create Date: 2025-09-19 14:42:58.416649
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '1b435d90db42'
down_revision = 'cf7c38a32b2d'
branch_labels = None
depends_on = None
def upgrade():
    """Create the explore-banner and app-trial tables; drop
    providers.credential_status."""
    # ### commands auto generated by Alembic - please adjust! ###
    # Per-account usage counter for trialing a specific app.
    op.create_table('account_trial_app_records',
        sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
        sa.Column('account_id', models.types.StringUUID(), nullable=False),
        sa.Column('app_id', models.types.StringUUID(), nullable=False),
        sa.Column('count', sa.Integer(), nullable=False),
        sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
        sa.PrimaryKeyConstraint('id', name='user_trial_app_pkey'),
        # One record per (account, app) pair.
        sa.UniqueConstraint('account_id', 'app_id', name='unique_account_trial_app_record')
    )
    with op.batch_alter_table('account_trial_app_records', schema=None) as batch_op:
        batch_op.create_index('account_trial_app_record_account_id_idx', ['account_id'], unique=False)
        batch_op.create_index('account_trial_app_record_app_id_idx', ['app_id'], unique=False)
    # NOTE(review): 'exporle_banners' / 'exporler_banner_pkey' look like
    # misspellings of "explore"; the ORM model is also named ExporleBanner,
    # so renaming would require a coordinated model + data migration.
    op.create_table('exporle_banners',
        sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
        sa.Column('content', sa.JSON(), nullable=False),
        sa.Column('link', sa.String(length=255), nullable=False),
        sa.Column('sort', sa.Integer(), nullable=False),
        sa.Column('status', sa.String(length=255), server_default=sa.text("'enabled'::character varying"), nullable=False),
        sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
        sa.PrimaryKeyConstraint('id', name='exporler_banner_pkey')
    )
    # Apps opened up for trial, each with its own usage limit.
    op.create_table('trial_apps',
        sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
        sa.Column('app_id', models.types.StringUUID(), nullable=False),
        sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
        sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
        sa.Column('trial_limit', sa.Integer(), nullable=False),
        sa.PrimaryKeyConstraint('id', name='trial_app_pkey'),
        # NOTE(review): 'unique_trail_app_id' — "trail" is likely a typo for
        # "trial"; the name must stay in sync with the TrialApp model.
        sa.UniqueConstraint('app_id', name='unique_trail_app_id')
    )
    with op.batch_alter_table('trial_apps', schema=None) as batch_op:
        batch_op.create_index('trial_app_app_id_idx', ['app_id'], unique=False)
        batch_op.create_index('trial_app_tenant_id_idx', ['tenant_id'], unique=False)
    # NOTE(review): credential_status is also dropped by migration
    # d00b2b40ea3e, which shares an ancestor — confirm both branches are
    # never applied to the same database, or one drop will fail.
    with op.batch_alter_table('providers', schema=None) as batch_op:
        batch_op.drop_column('credential_status')
    # ### end Alembic commands ###
def downgrade():
    """Restore providers.credential_status and drop the trial/banner tables."""
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('providers', schema=None) as batch_op:
        batch_op.add_column(sa.Column('credential_status', sa.VARCHAR(length=20), server_default=sa.text("'active'::character varying"), autoincrement=False, nullable=True))
    with op.batch_alter_table('trial_apps', schema=None) as batch_op:
        batch_op.drop_index('trial_app_tenant_id_idx')
        batch_op.drop_index('trial_app_app_id_idx')
    op.drop_table('trial_apps')
    op.drop_table('exporle_banners')
    with op.batch_alter_table('account_trial_app_records', schema=None) as batch_op:
        batch_op.drop_index('account_trial_app_record_app_id_idx')
        batch_op.drop_index('account_trial_app_record_account_id_idx')
    op.drop_table('account_trial_app_records')
    # ### end Alembic commands ###

View File

@ -0,0 +1,104 @@
"""add table credit pool
Revision ID: 58a70d22fdbd
Revises: 68519ad5cd18
Create Date: 2025-09-25 15:20:40.367078
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '58a70d22fdbd'
down_revision = '68519ad5cd18'
branch_labels = None
depends_on = None
def upgrade():
    """Create tenant_credit_pools and backfill it from legacy provider quotas."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table('tenant_credit_pools',
        sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
        sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
        sa.Column('pool_type', sa.String(length=40), nullable=False),
        sa.Column('quota_limit', sa.BigInteger(), nullable=False),
        sa.Column('quota_used', sa.BigInteger(), nullable=False),
        sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
        sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
        sa.PrimaryKeyConstraint('id', name='tenant_credit_pool_pkey')
    )
    with op.batch_alter_table('tenant_credit_pools', schema=None) as batch_op:
        batch_op.create_index('tenant_credit_pool_pool_type_idx', ['pool_type'], unique=False)
        batch_op.create_index('tenant_credit_pool_tenant_id_idx', ['tenant_id'], unique=False)
    # Data migration: Move quota data from providers to tenant_credit_pools
    migrate_quota_data()
    # ### end Alembic commands ###
def downgrade():
    """Drop tenant_credit_pools (backfilled quota data is discarded;
    the source rows in providers are untouched by upgrade)."""
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('tenant_credit_pools', schema=None) as batch_op:
        batch_op.drop_index('tenant_credit_pool_tenant_id_idx')
        batch_op.drop_index('tenant_credit_pool_pool_type_idx')
    op.drop_table('tenant_credit_pools')
    # ### end Alembic commands ###
def migrate_quota_data():
    """
    Migrate quota data from providers table to tenant_credit_pools table
    for providers with quota_type='trial' or 'paid', provider_name='openai',
    provider_type='system'.

    The provider's quota_type becomes the pool_type. The backfill is
    idempotent: tenants that already own a pool of a given type are skipped.
    """
    bind = op.get_bind()

    # The three statements are loop-invariant; prepare them once instead of
    # rebuilding sa.text() objects on every iteration of the nested loops.
    select_sql = sa.text("""
        SELECT tenant_id, quota_limit, quota_used
        FROM providers
        WHERE quota_type = :quota_type
        AND provider_name = 'openai'
        AND provider_type = 'system'
        AND quota_limit IS NOT NULL
    """)
    check_sql = sa.text("""
        SELECT COUNT(*)
        FROM tenant_credit_pools
        WHERE tenant_id = :tenant_id AND pool_type = :pool_type
    """)
    insert_sql = sa.text("""
        INSERT INTO tenant_credit_pools (tenant_id, pool_type, quota_limit, quota_used, created_at, updated_at)
        VALUES (:tenant_id, :pool_type, :quota_limit, :quota_used, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
    """)

    for quota_type in ("trial", "paid"):
        rows = bind.execute(select_sql, {"quota_type": quota_type}).fetchall()
        for tenant_id, quota_limit, quota_used in rows:
            # Skip tenants that already have a pool of this type so the
            # migration can be re-run safely.
            existing_count = bind.execute(
                check_sql, {"tenant_id": tenant_id, "pool_type": quota_type}
            ).scalar()
            if existing_count == 0:
                bind.execute(insert_sql, {
                    "tenant_id": tenant_id,
                    "pool_type": quota_type,
                    # NULLs become 0 so the NOT NULL columns accept them.
                    "quota_limit": quota_limit or 0,
                    "quota_used": quota_used or 0,
                })

View File

@ -0,0 +1,104 @@
"""add_chatflow_memory_tables
Revision ID: d00b2b40ea3e
Revises: 68519ad5cd18
Create Date: 2025-10-11 15:29:20.244675
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'd00b2b40ea3e'
down_revision = '68519ad5cd18'
branch_labels = None
depends_on = None
def upgrade():
    """Create the chatflow memory tables (conversations, memory variables,
    messages) and apply unrelated provider/datasource column tweaks."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table('chatflow_conversations',
        sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
        sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
        sa.Column('app_id', models.types.StringUUID(), nullable=False),
        sa.Column('node_id', sa.Text(), nullable=True),
        sa.Column('original_conversation_id', models.types.StringUUID(), nullable=True),
        sa.Column('conversation_metadata', sa.Text(), nullable=False),
        sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
        sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
        sa.PrimaryKeyConstraint('id', name='chatflow_conversations_pkey')
    )
    with op.batch_alter_table('chatflow_conversations', schema=None) as batch_op:
        batch_op.create_index('chatflow_conversations_original_conversation_id_idx', ['tenant_id', 'app_id', 'node_id', 'original_conversation_id'], unique=False)
    op.create_table('chatflow_memory_variables',
        sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
        sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
        sa.Column('app_id', models.types.StringUUID(), nullable=True),
        sa.Column('conversation_id', models.types.StringUUID(), nullable=True),
        sa.Column('node_id', sa.Text(), nullable=True),
        sa.Column('memory_id', sa.Text(), nullable=False),
        sa.Column('value', sa.Text(), nullable=False),
        sa.Column('name', sa.Text(), nullable=False),
        sa.Column('scope', sa.String(length=10), nullable=False),
        sa.Column('term', sa.String(length=20), nullable=False),
        sa.Column('version', sa.Integer(), nullable=False),
        sa.Column('created_by_role', sa.String(length=20), nullable=False),
        sa.Column('created_by', models.types.StringUUID(), nullable=False),
        sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
        sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
        sa.PrimaryKeyConstraint('id', name='chatflow_memory_variables_pkey')
    )
    with op.batch_alter_table('chatflow_memory_variables', schema=None) as batch_op:
        batch_op.create_index('chatflow_memory_variables_memory_id_idx', ['tenant_id', 'app_id', 'node_id', 'memory_id'], unique=False)
    op.create_table('chatflow_messages',
        sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
        sa.Column('conversation_id', models.types.StringUUID(), nullable=False),
        sa.Column('index', sa.Integer(), nullable=False),
        sa.Column('version', sa.Integer(), nullable=False),
        sa.Column('data', sa.Text(), nullable=False),
        sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
        # Fix: the ChatflowMessage ORM model maps an updated_at column;
        # without it here, inserts through the model would fail against a
        # freshly migrated database.
        sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
        sa.PrimaryKeyConstraint('id', name='chatflow_messages_pkey')
    )
    with op.batch_alter_table('chatflow_messages', schema=None) as batch_op:
        batch_op.create_index('chatflow_messages_version_idx', ['conversation_id', 'index', 'version'], unique=False)
    with op.batch_alter_table('datasource_providers', schema=None) as batch_op:
        batch_op.alter_column('avatar_url',
               existing_type=sa.TEXT(),
               type_=sa.String(length=255),
               existing_nullable=True)
    # NOTE(review): credential_status is also dropped by the 1b435d90db42
    # migration branch — confirm the two branches never apply to the same
    # database, or one of the drops will fail.
    with op.batch_alter_table('providers', schema=None) as batch_op:
        batch_op.drop_column('credential_status')
    # ### end Alembic commands ###
def downgrade():
    """Reverse the chatflow memory migration: restore provider/datasource
    columns, then drop the three chatflow tables (indexes first)."""
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('providers', schema=None) as batch_op:
        batch_op.add_column(sa.Column('credential_status', sa.VARCHAR(length=20), server_default=sa.text("'active'::character varying"), autoincrement=False, nullable=True))
    with op.batch_alter_table('datasource_providers', schema=None) as batch_op:
        batch_op.alter_column('avatar_url',
               existing_type=sa.String(length=255),
               type_=sa.TEXT(),
               existing_nullable=True)
    with op.batch_alter_table('chatflow_messages', schema=None) as batch_op:
        batch_op.drop_index('chatflow_messages_version_idx')
    op.drop_table('chatflow_messages')
    with op.batch_alter_table('chatflow_memory_variables', schema=None) as batch_op:
        batch_op.drop_index('chatflow_memory_variables_memory_id_idx')
    op.drop_table('chatflow_memory_variables')
    with op.batch_alter_table('chatflow_conversations', schema=None) as batch_op:
        batch_op.drop_index('chatflow_conversations_original_conversation_id_idx')
    op.drop_table('chatflow_conversations')
    # ### end Alembic commands ###

View File

@ -9,6 +9,12 @@ from .account import (
TenantStatus,
)
from .api_based_extension import APIBasedExtension, APIBasedExtensionPoint
from .chatflow_memory import ChatflowConversation, ChatflowMemoryVariable, ChatflowMessage
from .comment import (
WorkflowComment,
WorkflowCommentMention,
WorkflowCommentReply,
)
from .dataset import (
AppDatasetJoin,
Dataset,
@ -28,6 +34,7 @@ from .dataset import (
)
from .enums import CreatorUserRole, UserFrom, WorkflowRunTriggeredFrom
from .model import (
AccountTrialAppRecord,
ApiRequest,
ApiToken,
App,
@ -40,6 +47,7 @@ from .model import (
DatasetRetrieverResource,
DifySetup,
EndUser,
ExporleBanner,
IconType,
InstalledApp,
Message,
@ -53,7 +61,9 @@ from .model import (
Site,
Tag,
TagBinding,
TenantCreditPool,
TraceAppConfig,
TrialApp,
UploadFile,
)
from .oauth import DatasourceOauthParamConfig, DatasourceProvider
@ -98,6 +108,7 @@ __all__ = [
"Account",
"AccountIntegrate",
"AccountStatus",
"AccountTrialAppRecord",
"ApiRequest",
"ApiToken",
"ApiToolProvider",
@ -111,6 +122,9 @@ __all__ = [
"BuiltinToolProvider",
"CeleryTask",
"CeleryTaskSet",
"ChatflowConversation",
"ChatflowMemoryVariable",
"ChatflowMessage",
"Conversation",
"ConversationVariable",
"CreatorUserRole",
@ -131,6 +145,7 @@ __all__ = [
"DocumentSegment",
"Embedding",
"EndUser",
"ExporleBanner",
"ExternalKnowledgeApis",
"ExternalKnowledgeBindings",
"IconType",
@ -159,6 +174,7 @@ __all__ = [
"Tenant",
"TenantAccountJoin",
"TenantAccountRole",
"TenantCreditPool",
"TenantDefaultModel",
"TenantPreferredModelProvider",
"TenantStatus",
@ -168,12 +184,16 @@ __all__ = [
"ToolLabelBinding",
"ToolModelInvoke",
"TraceAppConfig",
"TrialApp",
"UploadFile",
"UserFrom",
"Whitelist",
"Workflow",
"WorkflowAppLog",
"WorkflowAppLogCreatedFrom",
"WorkflowComment",
"WorkflowCommentMention",
"WorkflowCommentReply",
"WorkflowNodeExecutionModel",
"WorkflowNodeExecutionOffload",
"WorkflowNodeExecutionTriggeredFrom",

View File

@ -0,0 +1,76 @@
from datetime import datetime
import sqlalchemy as sa
from sqlalchemy import DateTime, func
from sqlalchemy.orm import Mapped, mapped_column
from .base import Base
from .types import StringUUID
class ChatflowMemoryVariable(Base):
    """A stored memory value for a chatflow, addressed by a logical
    memory_id within a tenant/app/node scope.

    `value` holds the serialized memory content; `version` defaults to 1
    (presumably incremented on updates — TODO confirm against the writer).
    """

    __tablename__ = "chatflow_memory_variables"
    __table_args__ = (
        sa.PrimaryKeyConstraint("id", name="chatflow_memory_variables_pkey"),
        # Lookup path for resolving a memory by its logical id within scope.
        sa.Index("chatflow_memory_variables_memory_id_idx", "tenant_id", "app_id", "node_id", "memory_id"),
    )

    id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=sa.text("uuid_generate_v4()"))
    tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    app_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
    conversation_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
    node_id: Mapped[str | None] = mapped_column(sa.Text, nullable=True)
    memory_id: Mapped[str] = mapped_column(sa.Text, nullable=False)
    value: Mapped[str] = mapped_column(sa.Text, nullable=False)
    name: Mapped[str] = mapped_column(sa.Text, nullable=False)
    scope: Mapped[str] = mapped_column(sa.String(10), nullable=False)  # 'app' or 'node'
    term: Mapped[str] = mapped_column(sa.String(20), nullable=False)  # 'session' or 'persistent'
    version: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=1)
    created_by_role: Mapped[str] = mapped_column(sa.String(20))  # 'end_user' or 'account'
    created_by: Mapped[str] = mapped_column(StringUUID)
    created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
    updated_at: Mapped[datetime] = mapped_column(
        DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
    )
class ChatflowConversation(Base):
    """A memory-scoped conversation for a chatflow, optionally linked back
    to the original conversation it was derived from.

    `conversation_metadata` stores a JSON payload as text.
    """

    __tablename__ = "chatflow_conversations"
    __table_args__ = (
        sa.PrimaryKeyConstraint("id", name="chatflow_conversations_pkey"),
        # Lookup path from the originating conversation within scope.
        sa.Index(
            "chatflow_conversations_original_conversation_id_idx",
            "tenant_id", "app_id", "node_id", "original_conversation_id"
        ),
    )

    id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=sa.text("uuid_generate_v4()"))
    tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    node_id: Mapped[str | None] = mapped_column(sa.Text, nullable=True)
    original_conversation_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
    conversation_metadata: Mapped[str] = mapped_column(sa.Text, nullable=False)  # JSON
    created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
    updated_at: Mapped[datetime] = mapped_column(
        DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
    )
class ChatflowMessage(Base):
    """A single message of a chatflow conversation, versioned per index.

    `data` holds a serialized PromptMessage as JSON text.
    NOTE(review): this model maps an updated_at column — verify the
    chatflow_messages DDL migration actually creates it.
    """

    __tablename__ = "chatflow_messages"
    __table_args__ = (
        sa.PrimaryKeyConstraint("id", name="chatflow_messages_pkey"),
        # Supports "latest version of message N in conversation C" queries.
        sa.Index("chatflow_messages_version_idx", "conversation_id", "index", "version"),
    )

    id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=sa.text("uuid_generate_v4()"))
    conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    index: Mapped[int] = mapped_column(sa.Integer, nullable=False)  # This index starts from 0
    version: Mapped[int] = mapped_column(sa.Integer, nullable=False)
    data: Mapped[str] = mapped_column(sa.Text, nullable=False)  # Serialized PromptMessage JSON
    created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
    updated_at: Mapped[datetime] = mapped_column(
        DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
    )

189
api/models/comment.py Normal file
View File

@ -0,0 +1,189 @@
"""Workflow comment models."""
from datetime import datetime
from typing import TYPE_CHECKING, Optional
from sqlalchemy import Index, func
from sqlalchemy.orm import Mapped, mapped_column, relationship
from .account import Account
from .base import Base
from .engine import db
from .types import StringUUID
if TYPE_CHECKING:
pass
class WorkflowComment(Base):
    """Workflow comment model for canvas commenting functionality.

    Comments are associated with apps rather than specific workflow versions,
    since an app has only one draft workflow at a time and comments should persist
    across workflow version changes.

    Attributes:
        id: Comment ID
        tenant_id: Workspace ID
        app_id: App ID (primary association, comments belong to apps)
        position_x: X coordinate on canvas
        position_y: Y coordinate on canvas
        content: Comment content
        created_by: Creator account ID
        created_at: Creation time
        updated_at: Last update time
        resolved: Whether comment is resolved
        resolved_at: Resolution time
        resolved_by: Resolver account ID
    """

    __tablename__ = "workflow_comments"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="workflow_comments_pkey"),
        Index("workflow_comments_app_idx", "tenant_id", "app_id"),
        Index("workflow_comments_created_at_idx", "created_at"),
    )

    id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
    tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    position_x: Mapped[float] = mapped_column(db.Float)
    position_y: Mapped[float] = mapped_column(db.Float)
    content: Mapped[str] = mapped_column(db.Text, nullable=False)
    created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
    created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    updated_at: Mapped[datetime] = mapped_column(
        db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
    )
    resolved: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
    resolved_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime)
    resolved_by: Mapped[Optional[str]] = mapped_column(StringUUID)

    # Relationships (replies/mentions are cascade-deleted with the comment)
    replies: Mapped[list["WorkflowCommentReply"]] = relationship(
        "WorkflowCommentReply", back_populates="comment", cascade="all, delete-orphan"
    )
    mentions: Mapped[list["WorkflowCommentMention"]] = relationship(
        "WorkflowCommentMention", back_populates="comment", cascade="all, delete-orphan"
    )

    @property
    def created_by_account(self):
        """Get creator account (None if the account no longer exists)."""
        return db.session.get(Account, self.created_by)

    @property
    def resolved_by_account(self):
        """Get resolver account, or None if unresolved."""
        if self.resolved_by:
            return db.session.get(Account, self.resolved_by)
        return None

    @property
    def reply_count(self):
        """Get reply count (loads the replies collection)."""
        return len(self.replies)

    @property
    def mention_count(self):
        """Get mention count (loads the mentions collection)."""
        return len(self.mentions)

    @property
    def participants(self):
        """Get all participants (creator + repliers + mentioned users).

        Accounts that no longer exist are silently omitted.
        """
        participant_ids: set[str] = {self.created_by}
        participant_ids.update(reply.created_by for reply in self.replies)
        participant_ids.update(mention.mentioned_user_id for mention in self.mentions)
        # Fix: fetch all accounts in a single IN query instead of one
        # session.get() round trip per participant (N+1).
        return db.session.query(Account).where(Account.id.in_(participant_ids)).all()
class WorkflowCommentReply(Base):
    """Workflow comment reply model.

    Rows are cascade-deleted (via the DB foreign key and the parent's
    relationship cascade) when the parent comment is removed.

    Attributes:
        id: Reply ID
        comment_id: Parent comment ID
        content: Reply content
        created_by: Creator account ID
        created_at: Creation time
    """

    __tablename__ = "workflow_comment_replies"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="workflow_comment_replies_pkey"),
        Index("comment_replies_comment_idx", "comment_id"),
        Index("comment_replies_created_at_idx", "created_at"),
    )

    id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
    comment_id: Mapped[str] = mapped_column(
        StringUUID, db.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False
    )
    content: Mapped[str] = mapped_column(db.Text, nullable=False)
    created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
    created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    updated_at: Mapped[datetime] = mapped_column(
        db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
    )

    # Relationships
    comment: Mapped["WorkflowComment"] = relationship("WorkflowComment", back_populates="replies")

    @property
    def created_by_account(self):
        """Get creator account (None if the account no longer exists)."""
        return db.session.get(Account, self.created_by)
class WorkflowCommentMention(Base):
    """Workflow comment mention model.

    Mentions are only for internal accounts since end users
    cannot access workflow canvas and commenting features.
    A mention belongs to a comment and may additionally point at one
    specific reply (reply_id is nullable).

    Attributes:
        id: Mention ID
        comment_id: Parent comment ID
        mentioned_user_id: Mentioned account ID
    """

    __tablename__ = "workflow_comment_mentions"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="workflow_comment_mentions_pkey"),
        Index("comment_mentions_comment_idx", "comment_id"),
        Index("comment_mentions_reply_idx", "reply_id"),
        Index("comment_mentions_user_idx", "mentioned_user_id"),
    )

    id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
    comment_id: Mapped[str] = mapped_column(
        StringUUID, db.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False
    )
    reply_id: Mapped[Optional[str]] = mapped_column(
        StringUUID, db.ForeignKey("workflow_comment_replies.id", ondelete="CASCADE"), nullable=True
    )
    mentioned_user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)

    # Relationships
    comment: Mapped["WorkflowComment"] = relationship("WorkflowComment", back_populates="mentions")
    reply: Mapped[Optional["WorkflowCommentReply"]] = relationship("WorkflowCommentReply")

    @property
    def mentioned_user_account(self):
        """Get mentioned account (None if the account no longer exists)."""
        return db.session.get(Account, self.mentioned_user_id)

View File

@ -23,6 +23,7 @@ class DraftVariableType(StrEnum):
NODE = "node"
SYS = "sys"
CONVERSATION = "conversation"
MEMORY_BLOCK = "memory_block"
class MessageStatus(StrEnum):

View File

@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, cast
import sqlalchemy as sa
from flask import request
from flask_login import UserMixin # type: ignore[import-untyped]
from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, exists, func, select, text
from sqlalchemy import BigInteger, Float, Index, PrimaryKeyConstraint, String, exists, func, select, text
from sqlalchemy.orm import Mapped, Session, mapped_column
from configs import dify_config
@ -581,6 +581,63 @@ class InstalledApp(Base):
return tenant
class TrialApp(Base):
    """Tenant-scoped registration of an app available for trial runs.

    One row per app (unique constraint on app_id). trial_limit presumably
    caps trial runs per account (compared against AccountTrialAppRecord.count
    by the caller) — TODO confirm at the usage site.
    """

    __tablename__ = "trial_apps"
    __table_args__ = (
        sa.PrimaryKeyConstraint("id", name="trial_app_pkey"),
        sa.Index("trial_app_app_id_idx", "app_id"),
        sa.Index("trial_app_tenant_id_idx", "tenant_id"),
        # NOTE(review): constraint name says "trail" instead of "trial"; it
        # must match the migration, so it is not renamed here.
        sa.UniqueConstraint("app_id", name="unique_trail_app_id"),
    )

    id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
    app_id = mapped_column(StringUUID, nullable=False)
    tenant_id = mapped_column(StringUUID, nullable=False)
    created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
    # Client-side default only (no server_default): 3 trial runs.
    trial_limit = mapped_column(sa.Integer, nullable=False, default=3)

    @property
    def app(self) -> App | None:
        """Resolve the associated App, or None if it no longer exists."""
        app = db.session.query(App).where(App.id == self.app_id).first()
        return app
class AccountTrialAppRecord(Base):
    """Per-account usage record for a trial app.

    One row per (account_id, app_id) pair; `count` tracks consumed trial
    runs — presumably checked against TrialApp.trial_limit by the caller;
    TODO confirm at the usage site.
    """

    __tablename__ = "account_trial_app_records"
    __table_args__ = (
        sa.PrimaryKeyConstraint("id", name="user_trial_app_pkey"),
        sa.Index("account_trial_app_record_account_id_idx", "account_id"),
        sa.Index("account_trial_app_record_app_id_idx", "app_id"),
        sa.UniqueConstraint("account_id", "app_id", name="unique_account_trial_app_record"),
    )

    id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
    account_id = mapped_column(StringUUID, nullable=False)
    app_id = mapped_column(StringUUID, nullable=False)
    # Client-side default only; new rows start at zero consumed runs.
    count = mapped_column(sa.Integer, nullable=False, default=0)
    created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())

    @property
    def app(self) -> App | None:
        """Resolve the associated App, or None if it no longer exists."""
        app = db.session.query(App).where(App.id == self.app_id).first()
        return app

    @property
    def user(self) -> Account | None:
        """Resolve the owning Account, or None if it no longer exists."""
        user = db.session.query(Account).where(Account.id == self.account_id).first()
        return user
class ExporleBanner(Base):
    """Banner record (JSON content + link) with an ordering and enabled flag.

    NOTE(review): "Exporle"/"exporler" look like misspellings of "Explore";
    the class, table, and constraint names are kept as-is because they must
    match the existing migration — renaming needs a coordinated DB change.
    """

    __tablename__ = "exporle_banners"
    __table_args__ = (sa.PrimaryKeyConstraint("id", name="exporler_banner_pkey"),)

    id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
    content = mapped_column(sa.JSON, nullable=False)
    link = mapped_column(String(255), nullable=False)
    # Display order; presumably lower values first — TODO confirm at the query site.
    sort = mapped_column(sa.Integer, nullable=False)
    status = mapped_column(sa.String(255), nullable=False, server_default=sa.text("'enabled'::character varying"))
    created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
class OAuthProviderApp(Base):
"""
Globally shared OAuth provider app information.
@ -1944,3 +2001,29 @@ class TraceAppConfig(Base):
"created_at": str(self.created_at) if self.created_at else None,
"updated_at": str(self.updated_at) if self.updated_at else None,
}
class TenantCreditPool(Base):
    """Per-tenant credit quota pool (quota_limit vs. quota_used)."""

    __tablename__ = "tenant_credit_pools"
    __table_args__ = (
        sa.PrimaryKeyConstraint("id", name="tenant_credit_pool_pkey"),
        sa.Index("tenant_credit_pool_tenant_id_idx", "tenant_id"),
        sa.Index("tenant_credit_pool_pool_type_idx", "pool_type"),
    )

    id = mapped_column(StringUUID, primary_key=True, server_default=text("uuid_generate_v4()"))
    tenant_id = mapped_column(StringUUID, nullable=False)
    # Pool category; "trial" is the only default visible in this code.
    pool_type = mapped_column(String(40), nullable=False, default="trial", server_default="trial")
    quota_limit = mapped_column(BigInteger, nullable=False, default=0)
    quota_used = mapped_column(BigInteger, nullable=False, default=0)
    created_at = mapped_column(sa.DateTime, nullable=False, server_default=text("CURRENT_TIMESTAMP"))
    updated_at = mapped_column(
        sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
    )

    @property
    def remaining_credits(self) -> int:
        # Clamped at zero so over-consumption never reads as negative.
        return max(0, self.quota_limit - self.quota_used)

    def has_sufficient_credits(self, required_credits: int) -> bool:
        """True if the pool can cover required_credits."""
        return self.remaining_credits >= required_credits

View File

@ -1,5 +1,6 @@
import json
import logging
import uuid
from collections.abc import Mapping, Sequence
from datetime import datetime
from enum import StrEnum
@ -11,9 +12,15 @@ from sqlalchemy import DateTime, Select, exists, orm, select
from core.file.constants import maybe_file_object
from core.file.models import File
from core.memory.entities import MemoryBlockSpec
from core.variables import utils as variable_utils
from core.variables.segments import VersionedMemoryValue
from core.variables.variables import FloatVariable, IntegerVariable, StringVariable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from core.workflow.constants import (
CONVERSATION_VARIABLE_NODE_ID,
MEMORY_BLOCK_VARIABLE_NODE_ID,
SYSTEM_VARIABLE_NODE_ID,
)
from core.workflow.enums import NodeType
from extensions.ext_storage import Storage
from factories.variable_factory import TypeMismatchError, build_segment_with_type
@ -149,6 +156,9 @@ class Workflow(Base):
_rag_pipeline_variables: Mapped[str] = mapped_column(
"rag_pipeline_variables", db.Text, nullable=False, server_default="{}"
)
_memory_blocks: Mapped[str] = mapped_column(
"memory_blocks", sa.Text, nullable=False, server_default="[]"
)
VERSION_DRAFT = "draft"
@ -166,6 +176,7 @@ class Workflow(Base):
environment_variables: Sequence[Variable],
conversation_variables: Sequence[Variable],
rag_pipeline_variables: list[dict],
memory_blocks: Sequence[MemoryBlockSpec] | None = None,
marked_name: str = "",
marked_comment: str = "",
) -> "Workflow":
@ -181,6 +192,7 @@ class Workflow(Base):
workflow.environment_variables = environment_variables or []
workflow.conversation_variables = conversation_variables or []
workflow.rag_pipeline_variables = rag_pipeline_variables or []
workflow.memory_blocks = memory_blocks or []
workflow.marked_name = marked_name
workflow.marked_comment = marked_comment
workflow.created_at = naive_utc_now()
@ -335,7 +347,7 @@ class Workflow(Base):
:return: hash
"""
entity = {"graph": self.graph_dict, "features": self.features_dict}
entity = {"graph": self.graph_dict}
return helper.generate_text_hash(json.dumps(entity, sort_keys=True))
@ -440,7 +452,7 @@ class Workflow(Base):
"features": self.features_dict,
"environment_variables": [var.model_dump(mode="json") for var in environment_variables],
"conversation_variables": [var.model_dump(mode="json") for var in self.conversation_variables],
"rag_pipeline_variables": self.rag_pipeline_variables,
"memory_blocks": [block.model_dump(mode="json") for block in self.memory_blocks],
}
return result
@ -478,6 +490,27 @@ class Workflow(Base):
ensure_ascii=False,
)
@property
def memory_blocks(self) -> Sequence[MemoryBlockSpec]:
    """Memory block specs deserialized from the JSON `memory_blocks` column.

    Returns [] when the column is NULL or empty. The previous revision
    wrote "[]" back into the mapped attribute on read, which is a side
    effect that can mark the ORM row dirty on a pure read; this version
    uses a local fallback instead and never mutates state.
    """
    raw = self._memory_blocks or "[]"
    memory_blocks_list: list[dict[str, Any]] = json.loads(raw)
    return [MemoryBlockSpec.model_validate(config) for config in memory_blocks_list]
@memory_blocks.setter
def memory_blocks(self, value: Sequence[MemoryBlockSpec]):
    """Serialize the given specs into the JSON-backed column ("[]" when empty)."""
    if value:
        payload = [block.model_dump() for block in value]
        self._memory_blocks = json.dumps(payload, ensure_ascii=False)
    else:
        self._memory_blocks = "[]"
@staticmethod
def version_from_datetime(d: datetime) -> str:
return str(d)
@ -1489,6 +1522,31 @@ class WorkflowDraftVariable(Base):
variable.editable = editable
return variable
@staticmethod
def new_memory_block_variable(
    *,
    app_id: str,
    node_id: str | None = None,
    memory_id: str,
    name: str,
    value: VersionedMemoryValue,
    description: str = "",
) -> "WorkflowDraftVariable":
    """Create a new memory block draft variable.

    NOTE(review): with a node_id the selector here has three elements
    ([MEMORY_BLOCK_VARIABLE_NODE_ID, memory_id, node_id]) while other
    writers use a two-element "{node_id}.{memory_id}" key — confirm the
    intended canonical selector format.
    """
    selector = [MEMORY_BLOCK_VARIABLE_NODE_ID, memory_id]
    if node_id is not None:
        selector.append(node_id)
    return WorkflowDraftVariable(
        id=str(uuid.uuid4()),
        app_id=app_id,
        node_id=MEMORY_BLOCK_VARIABLE_NODE_ID,
        name=name,
        value=value.model_dump_json(),
        description=description,
        selector=selector,
        value_type=SegmentType.VERSIONED_MEMORY,
        visible=True,
        editable=True,
    )
@property
def edited(self):
return self.last_edited_at is not None

View File

@ -20,6 +20,7 @@ dependencies = [
"flask-orjson~=2.0.0",
"flask-sqlalchemy~=3.1.1",
"gevent~=25.9.1",
"gevent-websocket~=0.10.1",
"gmpy2~=2.2.1",
"google-api-core==2.18.0",
"google-api-python-client==2.90.0",
@ -68,6 +69,7 @@ dependencies = [
"pypdfium2==4.30.0",
"python-docx~=1.1.0",
"python-dotenv==1.0.1",
"python-socketio~=5.13.0",
"pyyaml~=6.0.1",
"readabilipy~=0.3.0",
"redis[hiredis]~=6.1.0",
@ -86,6 +88,7 @@ dependencies = [
"sendgrid~=6.12.3",
"flask-restx~=1.3.0",
"packaging~=23.2",
# NOTE: duplicate "gevent-websocket" entry removed — already pinned above as "gevent-websocket~=0.10.1"
]
# Before adding new dependency, consider place it in
# alphabet order (a-z) and suitable group.

View File

@ -995,6 +995,11 @@ class TenantService:
tenant.encrypt_public_key = generate_key_pair(tenant.id)
db.session.commit()
from services.credit_pool_service import CreditPoolService
CreditPoolService.create_default_pool(tenant.id)
return tenant
@staticmethod

View File

@ -0,0 +1,237 @@
import json
from collections.abc import MutableMapping, Sequence
from typing import Literal, Optional, overload
from sqlalchemy import Row, Select, and_, func, select
from sqlalchemy.orm import Session
from core.memory.entities import ChatflowConversationMetadata
from core.model_runtime.entities.message_entities import (
PromptMessage,
)
from extensions.ext_database import db
from models.chatflow_memory import ChatflowConversation, ChatflowMessage
class ChatflowHistoryService:
    """Append-only chat history for chatflow conversations.

    Each message is stored as an (index, version) row; for a given index the
    highest version wins. A per-conversation metadata blob tracks
    ``visible_count`` — how many trailing messages remain visible to prompt
    assembly.
    """

    @staticmethod
    def get_visible_chat_history(
        conversation_id: str,
        app_id: str,
        tenant_id: str,
        node_id: Optional[str] = None,
        max_visible_count: Optional[int] = None
    ) -> Sequence[PromptMessage]:
        """Return the trailing visible messages of a conversation.

        Returns [] when the conversation does not exist or nothing is
        visible. ``max_visible_count`` overrides the stored count when given
        (a passed-in 0 is falsy and falls back to the stored count).
        """
        with Session(db.engine) as session:
            chatflow_conv = ChatflowHistoryService._get_or_create_chatflow_conversation(
                session, conversation_id, app_id, tenant_id, node_id, create_if_missing=False
            )
            if not chatflow_conv:
                return []
            metadata = ChatflowConversationMetadata.model_validate_json(chatflow_conv.conversation_metadata)
            visible_count: int = max_visible_count or metadata.visible_count
            stmt = select(ChatflowMessage).where(
                ChatflowMessage.conversation_id == chatflow_conv.id
            ).order_by(ChatflowMessage.index.asc(), ChatflowMessage.version.desc())
            raw_messages: Sequence[Row[tuple[ChatflowMessage]]] = session.execute(stmt).all()
            sorted_messages = ChatflowHistoryService._filter_latest_messages(
                [it[0] for it in raw_messages]
            )
            visible_count = min(visible_count, len(sorted_messages))
            # Guard zero explicitly: sorted_messages[-0:] would return the
            # WHOLE list, leaking hidden history when visible_count == 0.
            visible_messages = sorted_messages[-visible_count:] if visible_count > 0 else []
            return [PromptMessage.model_validate_json(it.data) for it in visible_messages]

    @staticmethod
    def save_message(
        prompt_message: PromptMessage,
        conversation_id: str,
        app_id: str,
        tenant_id: str,
        node_id: Optional[str] = None
    ) -> None:
        """Append one message and grow visible_count by one, in one commit."""
        with Session(db.engine) as session:
            chatflow_conv = ChatflowHistoryService._get_or_create_chatflow_conversation(
                session, conversation_id, app_id, tenant_id, node_id, create_if_missing=True
            )
            # Next append-only index. Avoid `scalar() or -1`: a current max
            # index of 0 is falsy and would restart numbering at 0.
            max_index = session.execute(
                select(func.max(ChatflowMessage.index)).where(
                    ChatflowMessage.conversation_id == chatflow_conv.id
                )
            ).scalar()
            next_index = 0 if max_index is None else max_index + 1
            # PromptMessage is a pydantic model: json.dumps() on it raises
            # TypeError. Serialize via the model so readers can use
            # PromptMessage.model_validate_json().
            new_message = ChatflowMessage(
                conversation_id=chatflow_conv.id,
                index=next_index,
                version=1,
                data=prompt_message.model_dump_json()
            )
            session.add(new_message)
            # Simple growth policy: bump visible_count once per saved message,
            # committed together with the message. (Previously the commit ran
            # first and the metadata update was discarded on session close.)
            current_metadata = ChatflowConversationMetadata.model_validate_json(chatflow_conv.conversation_metadata)
            new_metadata = ChatflowConversationMetadata(visible_count=current_metadata.visible_count + 1)
            chatflow_conv.conversation_metadata = new_metadata.model_dump_json()
            session.commit()

    @staticmethod
    def save_app_message(
        prompt_message: PromptMessage,
        conversation_id: str,
        app_id: str,
        tenant_id: str
    ) -> None:
        """Save PromptMessage to app-level chatflow conversation."""
        ChatflowHistoryService.save_message(
            prompt_message=prompt_message,
            conversation_id=conversation_id,
            app_id=app_id,
            tenant_id=tenant_id,
            node_id=None
        )

    @staticmethod
    def save_node_message(
        prompt_message: PromptMessage,
        node_id: str,
        conversation_id: str,
        app_id: str,
        tenant_id: str
    ) -> None:
        """Save PromptMessage to a node-level chatflow conversation."""
        ChatflowHistoryService.save_message(
            prompt_message=prompt_message,
            conversation_id=conversation_id,
            app_id=app_id,
            tenant_id=tenant_id,
            node_id=node_id
        )

    @staticmethod
    def update_visible_count(
        conversation_id: str,
        node_id: Optional[str],
        new_visible_count: int,
        app_id: str,
        tenant_id: str
    ) -> None:
        """Set visible_count in metadata; message rows are never deleted."""
        with Session(db.engine) as session:
            chatflow_conv = ChatflowHistoryService._get_or_create_chatflow_conversation(
                session, conversation_id, app_id, tenant_id, node_id, create_if_missing=True
            )
            # Only update visible_count in metadata, do not delete any data
            new_metadata = ChatflowConversationMetadata(visible_count=new_visible_count)
            chatflow_conv.conversation_metadata = new_metadata.model_dump_json()
            session.commit()

    @staticmethod
    def get_conversation_metadata(
        tenant_id: str,
        app_id: str,
        conversation_id: str,
        node_id: Optional[str]
    ) -> ChatflowConversationMetadata:
        """Load conversation metadata; raises ValueError if the conversation is missing."""
        with Session(db.engine) as session:
            chatflow_conv = ChatflowHistoryService._get_or_create_chatflow_conversation(
                session, conversation_id, app_id, tenant_id, node_id, create_if_missing=False
            )
            if not chatflow_conv:
                raise ValueError(f"Conversation not found: {conversation_id}")
            return ChatflowConversationMetadata.model_validate_json(chatflow_conv.conversation_metadata)

    @staticmethod
    def _filter_latest_messages(raw_messages: Sequence[ChatflowMessage]) -> Sequence[ChatflowMessage]:
        """Collapse versions: keep the highest-version row per index, sorted by index."""
        index_to_message: MutableMapping[int, ChatflowMessage] = {}
        for msg in raw_messages:
            index = msg.index
            if index not in index_to_message or msg.version > index_to_message[index].version:
                index_to_message[index] = msg
        sorted_messages = sorted(index_to_message.values(), key=lambda m: m.index)
        return sorted_messages

    @overload
    @staticmethod
    def _get_or_create_chatflow_conversation(
        session: Session,
        conversation_id: str,
        app_id: str,
        tenant_id: str,
        node_id: Optional[str] = None,
        create_if_missing: Literal[True] = True
    ) -> ChatflowConversation: ...

    @overload
    @staticmethod
    def _get_or_create_chatflow_conversation(
        session: Session,
        conversation_id: str,
        app_id: str,
        tenant_id: str,
        node_id: Optional[str] = None,
        create_if_missing: Literal[False] = False
    ) -> Optional[ChatflowConversation]: ...

    @overload
    @staticmethod
    def _get_or_create_chatflow_conversation(
        session: Session,
        conversation_id: str,
        app_id: str,
        tenant_id: str,
        node_id: Optional[str] = None,
        create_if_missing: bool = False
    ) -> Optional[ChatflowConversation]: ...

    @staticmethod
    def _get_or_create_chatflow_conversation(
        session: Session,
        conversation_id: str,
        app_id: str,
        tenant_id: str,
        node_id: Optional[str] = None,
        create_if_missing: bool = False
    ) -> Optional[ChatflowConversation]:
        """Get existing chatflow conversation or optionally create new one"""
        stmt: Select[tuple[ChatflowConversation]] = select(ChatflowConversation).where(
            and_(
                ChatflowConversation.original_conversation_id == conversation_id,
                ChatflowConversation.tenant_id == tenant_id,
                ChatflowConversation.app_id == app_id
            )
        )
        # App-level history rows carry node_id IS NULL; node-level rows match node_id.
        if node_id:
            stmt = stmt.where(ChatflowConversation.node_id == node_id)
        else:
            stmt = stmt.where(ChatflowConversation.node_id.is_(None))
        chatflow_conv: Row[tuple[ChatflowConversation]] | None = session.execute(stmt).first()
        if chatflow_conv:
            result: ChatflowConversation = chatflow_conv[0]  # Extract the ChatflowConversation object
            return result
        if create_if_missing:
            # Create a new chatflow conversation; caller is responsible for commit.
            default_metadata = ChatflowConversationMetadata(visible_count=0)
            new_chatflow_conv = ChatflowConversation(
                tenant_id=tenant_id,
                app_id=app_id,
                node_id=node_id,
                original_conversation_id=conversation_id,
                conversation_metadata=default_metadata.model_dump_json(),
            )
            session.add(new_chatflow_conv)
            session.flush()  # Obtain ID
            return new_chatflow_conv
        return None

View File

@ -0,0 +1,680 @@
import logging
import threading
import time
from collections.abc import Sequence
from typing import Optional
from sqlalchemy import and_, delete, select
from sqlalchemy.orm import Session
from core.llm_generator.llm_generator import LLMGenerator
from core.memory.entities import (
MemoryBlock,
MemoryBlockSpec,
MemoryBlockWithConversation,
MemoryCreatedBy,
MemoryScheduleMode,
MemoryScope,
MemoryTerm,
MemoryValueData,
)
from core.memory.errors import MemorySyncTimeoutError
from core.model_runtime.entities.message_entities import PromptMessage
from core.variables.segments import VersionedMemoryValue
from core.workflow.constants import MEMORY_BLOCK_VARIABLE_NODE_ID
from core.workflow.entities.variable_pool import VariablePool
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models import App, CreatorUserRole
from models.chatflow_memory import ChatflowMemoryVariable
from models.workflow import Workflow, WorkflowDraftVariable
from services.chatflow_history_service import ChatflowHistoryService
from services.workflow_draft_variable_service import WorkflowDraftVariableService
from services.workflow_service import WorkflowService
logger = logging.getLogger(__name__)
class ChatflowMemoryService:
@staticmethod
def get_persistent_memories(
    app: App,
    created_by: MemoryCreatedBy,
    version: int | None = None
) -> Sequence[MemoryBlock]:
    """Load app-level persistent memories (conversation_id IS NULL) for a user.

    When version is None the newest version of each memory_id is returned
    via PostgreSQL DISTINCT ON; otherwise only rows at exactly that version.
    """
    # Resolve which identity column the rows were keyed on.
    if created_by.account_id:
        created_by_role = CreatorUserRole.ACCOUNT
        created_by_id = created_by.account_id
    else:
        created_by_role = CreatorUserRole.END_USER
        created_by_id = created_by.id
    if version is None:
        # DISTINCT ON (memory_id) requires ORDER BY to START with memory_id;
        # ordering by version alone is rejected by PostgreSQL. Also use
        # .is_(None): `== None` relies on operator overloading and trips
        # linters (E711).
        stmt = select(ChatflowMemoryVariable).distinct(ChatflowMemoryVariable.memory_id).where(
            and_(
                ChatflowMemoryVariable.tenant_id == app.tenant_id,
                ChatflowMemoryVariable.app_id == app.id,
                ChatflowMemoryVariable.conversation_id.is_(None),
                ChatflowMemoryVariable.created_by_role == created_by_role,
                ChatflowMemoryVariable.created_by == created_by_id,
            )
        ).order_by(ChatflowMemoryVariable.memory_id, ChatflowMemoryVariable.version.desc())
    else:
        stmt = select(ChatflowMemoryVariable).where(
            and_(
                ChatflowMemoryVariable.tenant_id == app.tenant_id,
                ChatflowMemoryVariable.app_id == app.id,
                ChatflowMemoryVariable.conversation_id.is_(None),
                ChatflowMemoryVariable.created_by_role == created_by_role,
                ChatflowMemoryVariable.created_by == created_by_id,
                ChatflowMemoryVariable.version == version
            )
        )
    with Session(db.engine) as session:
        db_results = session.execute(stmt).all()
        return ChatflowMemoryService._convert_to_memory_blocks(app, created_by, [result[0] for result in db_results])
@staticmethod
def get_session_memories(
    app: App,
    created_by: MemoryCreatedBy,
    conversation_id: str,
    version: int | None = None
) -> Sequence[MemoryBlock]:
    """Load session-scoped memories for one conversation.

    When version is None the newest version of each memory_id is returned
    via PostgreSQL DISTINCT ON; otherwise only rows at exactly that version.
    Note: unlike get_persistent_memories, rows are not filtered by creator.
    """
    if version is None:
        # DISTINCT ON (memory_id) requires ORDER BY to START with memory_id;
        # ordering by version alone is rejected by PostgreSQL.
        stmt = select(ChatflowMemoryVariable).distinct(ChatflowMemoryVariable.memory_id).where(
            and_(
                ChatflowMemoryVariable.tenant_id == app.tenant_id,
                ChatflowMemoryVariable.app_id == app.id,
                ChatflowMemoryVariable.conversation_id == conversation_id
            )
        ).order_by(ChatflowMemoryVariable.memory_id, ChatflowMemoryVariable.version.desc())
    else:
        stmt = select(ChatflowMemoryVariable).where(
            and_(
                ChatflowMemoryVariable.tenant_id == app.tenant_id,
                ChatflowMemoryVariable.app_id == app.id,
                ChatflowMemoryVariable.conversation_id == conversation_id,
                ChatflowMemoryVariable.version == version
            )
        )
    with Session(db.engine) as session:
        db_results = session.execute(stmt).all()
        return ChatflowMemoryService._convert_to_memory_blocks(app, created_by, [result[0] for result in db_results])
@staticmethod
def save_memory(memory: MemoryBlock, variable_pool: VariablePool, is_draft: bool) -> None:
    """Persist a memory block to the variable pool, the versioned DB table,
    and — in draft mode — the workflow draft-variable store."""
    # Expose the value to the running workflow under the memory node id.
    key = f"{memory.node_id}.{memory.spec.id}" if memory.node_id else memory.spec.id
    variable_pool.add([MEMORY_BLOCK_VARIABLE_NODE_ID, key], memory.value)
    # Resolve which identity column to stamp on the row.
    if memory.created_by.account_id:
        created_by_role = CreatorUserRole.ACCOUNT
        created_by = memory.created_by.account_id
    else:
        created_by_role = CreatorUserRole.END_USER
        created_by = memory.created_by.id
    # Append a new versioned row (append-only history; no update-in-place).
    with Session(db.engine) as session:
        session.add(
            ChatflowMemoryVariable(
                memory_id=memory.spec.id,
                tenant_id=memory.tenant_id,
                app_id=memory.app_id,
                node_id=memory.node_id,
                conversation_id=memory.conversation_id,
                name=memory.spec.name,
                value=MemoryValueData(
                    value=memory.value,
                    edited_by_user=memory.edited_by_user
                ).model_dump_json(),
                term=memory.spec.term,
                scope=memory.spec.scope,
                version=memory.version,  # Use version from MemoryBlock directly
                created_by_role=created_by_role,
                created_by=created_by,
            )
        )
        session.commit()
    if is_draft:
        with Session(bind=db.engine) as session:
            draft_var_service = WorkflowDraftVariableService(session)
            # Draft selector key: "{node_id}.{memory_id}" for node-scoped
            # memories, plain memory_id otherwise.
            memory_selector = memory.spec.id if not memory.node_id else f"{memory.node_id}.{memory.spec.id}"
            existing_vars = draft_var_service.get_draft_variables_by_selectors(
                app_id=memory.app_id,
                selectors=[[MEMORY_BLOCK_VARIABLE_NODE_ID, memory_selector]]
            )
            if existing_vars:
                # Append a new version to the existing draft variable's history.
                draft_var = existing_vars[0]
                draft_var.value = VersionedMemoryValue.model_validate_json(draft_var.value)\
                    .add_version(memory.value)\
                    .model_dump_json()
            else:
                # NOTE(review): memory.node_id is NOT forwarded here, so a
                # node-scoped memory gets an app-level draft selector on
                # first save — confirm whether node_id should be passed.
                draft_var = WorkflowDraftVariable.new_memory_block_variable(
                    app_id=memory.app_id,
                    memory_id=memory.spec.id,
                    name=memory.spec.name,
                    value=VersionedMemoryValue().add_version(memory.value),
                    description=memory.spec.description
                )
                session.add(draft_var)
            session.commit()
@staticmethod
def get_memories_by_specs(
    memory_block_specs: Sequence[MemoryBlockSpec],
    tenant_id: str, app_id: str,
    created_by: MemoryCreatedBy,
    conversation_id: Optional[str],
    node_id: Optional[str],
    is_draft: bool
) -> Sequence[MemoryBlock]:
    """Resolve each spec to a MemoryBlock via get_memory_by_spec, preserving order."""
    blocks: list[MemoryBlock] = []
    for spec in memory_block_specs:
        block = ChatflowMemoryService.get_memory_by_spec(
            spec, tenant_id, app_id, created_by, conversation_id, node_id, is_draft
        )
        blocks.append(block)
    return blocks
@staticmethod
def get_memory_by_spec(
    spec: MemoryBlockSpec,
    tenant_id: str,
    app_id: str,
    created_by: MemoryCreatedBy,
    conversation_id: Optional[str],
    node_id: Optional[str],
    is_draft: bool
) -> MemoryBlock:
    """Resolve one memory block for a spec.

    Lookup order: draft variable (draft mode only) → latest versioned DB
    row → a fresh block initialized from the spec's template.
    """
    with Session(db.engine) as session:
        if is_draft:
            draft_var_service = WorkflowDraftVariableService(session)
            # Key format must match what save_memory() writes:
            # "{node_id}.{memory_id}" for node-scoped memories. The previous
            # "{memory_id}.{node_id}" order was reversed and never matched a
            # saved draft variable.
            selector = [MEMORY_BLOCK_VARIABLE_NODE_ID, f"{node_id}.{spec.id}"] \
                if node_id else [MEMORY_BLOCK_VARIABLE_NODE_ID, spec.id]
            draft_vars = draft_var_service.get_draft_variables_by_selectors(
                app_id=app_id,
                selectors=[selector]
            )
            if draft_vars:
                draft_var = draft_vars[0]
                return MemoryBlock(
                    value=draft_var.get_value().text,
                    tenant_id=tenant_id,
                    app_id=app_id,
                    conversation_id=conversation_id,
                    node_id=node_id,
                    spec=spec,
                    created_by=created_by,
                    version=1,
                )
        # Latest stored version; scope/term decide whether node and
        # conversation filters apply or must be NULL.
        stmt = select(ChatflowMemoryVariable).where(
            and_(
                ChatflowMemoryVariable.memory_id == spec.id,
                ChatflowMemoryVariable.tenant_id == tenant_id,
                ChatflowMemoryVariable.app_id == app_id,
                ChatflowMemoryVariable.node_id ==
                (node_id if spec.scope == MemoryScope.NODE else None),
                ChatflowMemoryVariable.conversation_id ==
                (conversation_id if spec.term == MemoryTerm.SESSION else None),
            )
        ).order_by(ChatflowMemoryVariable.version.desc()).limit(1)
        result = session.execute(stmt).scalar()
        if result:
            memory_value_data = MemoryValueData.model_validate_json(result.value)
            return MemoryBlock(
                value=memory_value_data.value,
                tenant_id=tenant_id,
                app_id=app_id,
                conversation_id=conversation_id,
                node_id=node_id,
                spec=spec,
                edited_by_user=memory_value_data.edited_by_user,
                created_by=created_by,
                version=result.version,
            )
        # Nothing stored yet: fall back to the spec's template content.
        return MemoryBlock(
            tenant_id=tenant_id,
            value=spec.template,
            app_id=app_id,
            conversation_id=conversation_id,
            node_id=node_id,
            spec=spec,
            created_by=created_by,
            version=1,
        )
@staticmethod
def update_app_memory_if_needed(
    workflow: Workflow,
    conversation_id: str,
    variable_pool: VariablePool,
    created_by: MemoryCreatedBy,
    is_draft: bool
):
    """Check every app-scoped memory block and schedule updates for those due.

    Sync-mode blocks are batched behind a Redis lock (which
    wait_for_sync_memory_completion polls); async-mode blocks each get
    their own background thread.
    """
    # App-level visible history (node_id=None).
    visible_messages = ChatflowHistoryService.get_visible_chat_history(
        conversation_id=conversation_id,
        app_id=workflow.app_id,
        tenant_id=workflow.tenant_id,
        node_id=None,
    )
    # Partition due blocks by their schedule mode.
    sync_blocks: list[MemoryBlock] = []
    async_blocks: list[MemoryBlock] = []
    for memory_spec in workflow.memory_blocks:
        if memory_spec.scope == MemoryScope.APP:
            memory = ChatflowMemoryService.get_memory_by_spec(
                spec=memory_spec,
                tenant_id=workflow.tenant_id,
                app_id=workflow.app_id,
                conversation_id=conversation_id,
                node_id=None,
                is_draft=is_draft,
                created_by=created_by,
            )
            # Due when visible history has reached the spec's turn threshold.
            if ChatflowMemoryService._should_update_memory(memory, visible_messages):
                if memory.spec.schedule_mode == MemoryScheduleMode.SYNC:
                    sync_blocks.append(memory)
                else:
                    async_blocks.append(memory)
    if not sync_blocks and not async_blocks:
        return
    # async mode: submit individual async tasks directly
    for memory_block in async_blocks:
        ChatflowMemoryService._app_submit_async_memory_update(
            block=memory_block,
            is_draft=is_draft,
            variable_pool=variable_pool,
            visible_messages=visible_messages,
            conversation_id=conversation_id,
        )
    # sync mode: submit a batch update task
    if sync_blocks:
        ChatflowMemoryService._app_submit_sync_memory_batch_update(
            sync_blocks=sync_blocks,
            is_draft=is_draft,
            conversation_id=conversation_id,
            app_id=workflow.app_id,
            visible_messages=visible_messages,
            variable_pool=variable_pool
        )
@staticmethod
def update_node_memory_if_needed(
    tenant_id: str,
    app_id: str,
    node_id: str,
    created_by: MemoryCreatedBy,
    conversation_id: str,
    memory_block_spec: MemoryBlockSpec,
    variable_pool: VariablePool,
    is_draft: bool
) -> bool:
    """Update one node-scoped memory block if its turn threshold is met.

    Returns True when an update was performed or scheduled, False when the
    visible history has not yet reached the spec's update_turns. Sync mode
    blocks the caller; async mode runs on a daemon thread.
    """
    # Node-level visible history for this node's conversation slice.
    visible_messages = ChatflowHistoryService.get_visible_chat_history(
        conversation_id=conversation_id,
        app_id=app_id,
        tenant_id=tenant_id,
        node_id=node_id,
    )
    memory_block = ChatflowMemoryService.get_memory_by_spec(
        spec=memory_block_spec,
        tenant_id=tenant_id,
        app_id=app_id,
        conversation_id=conversation_id,
        node_id=node_id,
        is_draft=is_draft,
        created_by=created_by,
    )
    if not ChatflowMemoryService._should_update_memory(
        memory_block=memory_block,
        visible_history=visible_messages
    ):
        return False
    if memory_block_spec.schedule_mode == MemoryScheduleMode.SYNC:
        # Node-level sync: blocking execution
        ChatflowMemoryService._update_node_memory_sync(
            visible_messages=visible_messages,
            memory_block=memory_block,
            variable_pool=variable_pool,
            is_draft=is_draft,
            conversation_id=conversation_id
        )
    else:
        # Node-level async: execute asynchronously
        ChatflowMemoryService._update_node_memory_async(
            memory_block=memory_block,
            visible_messages=visible_messages,
            variable_pool=variable_pool,
            is_draft=is_draft,
            conversation_id=conversation_id
        )
    return True
@staticmethod
def wait_for_sync_memory_completion(workflow: Workflow, conversation_id: str):
    """Wait for sync memory update to complete, maximum 50 seconds.

    Polls the Redis lock held by the sync batch updater; returns as soon
    as the lock is gone, raises MemorySyncTimeoutError after the final
    failed check.
    """
    has_sync_app_blocks = any(
        block.scope == MemoryScope.APP and block.schedule_mode == MemoryScheduleMode.SYNC
        for block in workflow.memory_blocks
    )
    if not has_sync_app_blocks:
        return

    lock_key = _get_memory_sync_lock_key(workflow.app_id, conversation_id)
    # Poll up to 10 times at 5-second intervals: 50 seconds total.
    max_retries = 10
    retry_interval = 5
    for attempt in range(max_retries):
        if not redis_client.exists(lock_key):
            # Lock released — the sync update finished.
            return
        if attempt == max_retries - 1:
            # Out of retries: give up loudly.
            raise MemorySyncTimeoutError(
                app_id=workflow.app_id,
                conversation_id=conversation_id
            )
        time.sleep(retry_interval)
@staticmethod
def _convert_to_memory_blocks(
    app: App,
    created_by: MemoryCreatedBy,
    raw_results: Sequence[ChatflowMemoryVariable]
) -> Sequence[MemoryBlock]:
    """Join stored rows with their specs from the published workflow.

    Rows whose memory_id no longer appears in the workflow's memory_blocks,
    or that lack an app_id, are silently dropped. Returns [] when the app
    has no published workflow.
    """
    workflow = WorkflowService().get_published_workflow(app)
    if not workflow:
        return []
    results = []
    for chatflow_memory_variable in raw_results:
        # Find the spec this stored row belongs to.
        spec = next(
            (spec for spec in workflow.memory_blocks if spec.id == chatflow_memory_variable.memory_id),
            None
        )
        if spec and chatflow_memory_variable.app_id:
            memory_value_data = MemoryValueData.model_validate_json(chatflow_memory_variable.value)
            results.append(
                MemoryBlock(
                    spec=spec,
                    tenant_id=chatflow_memory_variable.tenant_id,
                    value=memory_value_data.value,
                    app_id=chatflow_memory_variable.app_id,
                    conversation_id=chatflow_memory_variable.conversation_id,
                    node_id=chatflow_memory_variable.node_id,
                    edited_by_user=memory_value_data.edited_by_user,
                    created_by=created_by,
                    version=chatflow_memory_variable.version,
                )
            )
    return results
@staticmethod
def _should_update_memory(
    memory_block: MemoryBlock,
    visible_history: Sequence[PromptMessage]
) -> bool:
    """True once visible history has reached the spec's update_turns threshold."""
    turn_threshold = memory_block.spec.update_turns
    return len(visible_history) >= turn_threshold
@staticmethod
def _app_submit_async_memory_update(
    block: MemoryBlock,
    visible_messages: Sequence[PromptMessage],
    variable_pool: VariablePool,
    conversation_id: str,
    is_draft: bool
):
    """Fire-and-forget update of one app-level async memory block."""
    worker_kwargs = {
        'memory_block': block,
        'visible_messages': visible_messages,
        'variable_pool': variable_pool,
        'is_draft': is_draft,
        'conversation_id': conversation_id
    }
    # NOTE(review): unlike _update_node_memory_async this thread is not a
    # daemon — confirm whether it should block interpreter shutdown.
    worker = threading.Thread(
        target=ChatflowMemoryService._perform_memory_update,
        kwargs=worker_kwargs,
    )
    worker.start()
@staticmethod
def _app_submit_sync_memory_batch_update(
    sync_blocks: Sequence[MemoryBlock],
    app_id: str,
    conversation_id: str,
    visible_messages: Sequence[PromptMessage],
    variable_pool: VariablePool,
    is_draft: bool
):
    """Submit sync memory batch update task"""
    batch_kwargs = {
        'sync_blocks': sync_blocks,
        'app_id': app_id,
        'conversation_id': conversation_id,
        'visible_messages': visible_messages,
        'variable_pool': variable_pool,
        'is_draft': is_draft
    }
    # One thread runs the whole batch; it acquires the Redis sync lock itself.
    worker = threading.Thread(
        target=ChatflowMemoryService._batch_update_sync_memory,
        kwargs=batch_kwargs,
    )
    worker.start()
@staticmethod
def _batch_update_sync_memory(
    sync_blocks: Sequence[MemoryBlock],
    app_id: str,
    conversation_id: str,
    visible_messages: Sequence[PromptMessage],
    variable_pool: VariablePool,
    is_draft: bool
):
    """Update all sync-mode blocks in parallel under a per-conversation Redis lock.

    The lock (120 s timeout) is what wait_for_sync_memory_completion polls,
    so it is held until every worker thread has joined. Errors are logged,
    never raised — memory updates are best-effort.
    """
    try:
        lock_key = _get_memory_sync_lock_key(app_id, conversation_id)
        with redis_client.lock(lock_key, timeout=120):
            # One worker per block; start them all, then wait for all.
            threads = []
            for block in sync_blocks:
                thread = threading.Thread(
                    target=ChatflowMemoryService._perform_memory_update,
                    kwargs={
                        'memory_block': block,
                        'visible_messages': visible_messages,
                        'variable_pool': variable_pool,
                        'is_draft': is_draft,
                        'conversation_id': conversation_id,
                    },
                )
                threads.append(thread)
            for thread in threads:
                thread.start()
            for thread in threads:
                thread.join()
    except Exception:
        # logger.exception() already attaches the active exception's
        # traceback; the previous explicit exc_info=e was redundant.
        logger.exception("Error batch updating memory")
@staticmethod
def _update_node_memory_sync(
    memory_block: MemoryBlock,
    visible_messages: Sequence[PromptMessage],
    variable_pool: VariablePool,
    conversation_id: str,
    is_draft: bool
):
    """Blocking (caller-thread) update of a node-level sync memory block."""
    # Thin wrapper: runs the shared update routine directly on this thread.
    ChatflowMemoryService._perform_memory_update(
        memory_block=memory_block,
        variable_pool=variable_pool,
        conversation_id=conversation_id,
        visible_messages=visible_messages,
        is_draft=is_draft,
    )
@staticmethod
def _update_node_memory_async(
    memory_block: MemoryBlock,
    visible_messages: Sequence[PromptMessage],
    variable_pool: VariablePool,
    conversation_id: str,
    is_draft: bool = False
):
    """Run a node-level memory update on a background daemon thread."""
    update_kwargs = {
        'memory_block': memory_block,
        'visible_messages': visible_messages,
        'variable_pool': variable_pool,
        'is_draft': is_draft,
        'conversation_id': conversation_id,
    }
    # Daemon thread: a pending node update must not keep the process alive.
    worker = threading.Thread(
        target=ChatflowMemoryService._perform_memory_update,
        kwargs=update_kwargs,
        daemon=True,
    )
    worker.start()
@staticmethod
def _perform_memory_update(
    memory_block: MemoryBlock,
    variable_pool: VariablePool,
    conversation_id: str,
    visible_messages: Sequence[PromptMessage],
    is_draft: bool
):
    """Regenerate one memory block via the LLM, persist it, and reset visibility.

    Runs on worker threads; see the submit helpers above for scheduling.
    """
    # Ask the LLM for updated memory content based on the visible history.
    updated_value = LLMGenerator.update_memory_block(
        tenant_id=memory_block.tenant_id,
        visible_history=ChatflowMemoryService._format_chat_history(visible_messages),
        variable_pool=variable_pool,
        memory_block=memory_block,
        memory_spec=memory_block.spec,
    )
    updated_memory = MemoryBlock(
        tenant_id=memory_block.tenant_id,
        value=updated_value,
        spec=memory_block.spec,
        app_id=memory_block.app_id,
        conversation_id=conversation_id,
        node_id=memory_block.node_id,
        edited_by_user=False,  # model-generated content supersedes manual edits
        created_by=memory_block.created_by,
        version=memory_block.version + 1,  # Increment version for business logic update
    )
    ChatflowMemoryService.save_memory(updated_memory, variable_pool, is_draft)
    # Shrink the visible window back to the spec's preserved_turns.
    ChatflowHistoryService.update_visible_count(
        conversation_id=conversation_id,
        node_id=memory_block.node_id,
        new_visible_count=memory_block.spec.preserved_turns,
        app_id=memory_block.app_id,
        tenant_id=memory_block.tenant_id
    )
@staticmethod
def delete_memory(app: App, memory_id: str, created_by: MemoryCreatedBy):
    """Delete every stored version of one end-user-editable memory block.

    Raises ValueError when the app has no published workflow, or the memory
    id is unknown / not editable by end users.
    """
    workflow = WorkflowService().get_published_workflow(app)
    if not workflow:
        raise ValueError("Workflow not found")
    memory_spec = next((it for it in workflow.memory_blocks if it.id == memory_id), None)
    if not memory_spec or not memory_spec.end_user_editable:
        raise ValueError("Memory not found or not deletable")

    # Resolve which identity the rows were keyed on.
    if created_by.account_id:
        creator_role = CreatorUserRole.ACCOUNT
        creator_id = created_by.account_id
    else:
        creator_role = CreatorUserRole.END_USER
        creator_id = created_by.id

    with Session(db.engine) as session:
        session.execute(
            delete(ChatflowMemoryVariable).where(
                and_(
                    ChatflowMemoryVariable.tenant_id == app.tenant_id,
                    ChatflowMemoryVariable.app_id == app.id,
                    ChatflowMemoryVariable.memory_id == memory_id,
                    ChatflowMemoryVariable.created_by_role == creator_role,
                    ChatflowMemoryVariable.created_by == creator_id
                )
            )
        )
        session.commit()
@staticmethod
def delete_all_user_memories(app: App, created_by: MemoryCreatedBy):
    """Delete every memory variable the given creator owns for this app."""
    # Resolve the owning identity: console account vs. web-app end user.
    if created_by.account_id:
        role, owner_id = CreatorUserRole.ACCOUNT, created_by.account_id
    else:
        role, owner_id = CreatorUserRole.END_USER, created_by.id

    with Session(db.engine) as session:
        session.execute(
            delete(ChatflowMemoryVariable).where(
                and_(
                    ChatflowMemoryVariable.tenant_id == app.tenant_id,
                    ChatflowMemoryVariable.app_id == app.id,
                    ChatflowMemoryVariable.created_by_role == role,
                    ChatflowMemoryVariable.created_by == owner_id,
                )
            )
        )
        session.commit()
@staticmethod
def get_persistent_memories_with_conversation(
    app: App,
    created_by: MemoryCreatedBy,
    conversation_id: str,
    version: int | None = None
) -> Sequence[MemoryBlockWithConversation]:
    """Get persistent memories with conversation metadata (always None for persistent)"""
    blocks = ChatflowMemoryService.get_persistent_memories(app, created_by, version)
    enriched: list[MemoryBlockWithConversation] = []
    for block in blocks:
        metadata = ChatflowHistoryService.get_conversation_metadata(
            app.tenant_id, app.id, conversation_id, block.node_id
        )
        enriched.append(MemoryBlockWithConversation.from_memory_block(block, metadata))
    return enriched
@staticmethod
def get_session_memories_with_conversation(
    app: App,
    created_by: MemoryCreatedBy,
    conversation_id: str,
    version: int | None = None
) -> Sequence[MemoryBlockWithConversation]:
    """Get session memories with conversation metadata"""
    blocks = ChatflowMemoryService.get_session_memories(app, created_by, conversation_id, version)
    enriched: list[MemoryBlockWithConversation] = []
    for block in blocks:
        metadata = ChatflowHistoryService.get_conversation_metadata(
            app.tenant_id, app.id, conversation_id, block.node_id
        )
        enriched.append(MemoryBlockWithConversation.from_memory_block(block, metadata))
    return enriched
@staticmethod
def _format_chat_history(messages: Sequence[PromptMessage]) -> Sequence[tuple[str, str]]:
    """Flatten prompt messages into (role, text) pairs for the LLM prompt."""
    return [(str(message.role.value), message.get_text_content()) for message in messages]
def _get_memory_sync_lock_key(app_id: str, conversation_id: str) -> str:
"""Generate Redis lock key for memory sync updates
Args:
app_id: Application ID
conversation_id: Conversation ID
Returns:
Formatted lock key
"""
return f"memory_sync_update:{app_id}:{conversation_id}"

View File

@ -0,0 +1,68 @@
import logging
from typing import Optional
from sqlalchemy import update
from sqlalchemy.orm import Session
from configs import dify_config
from core.errors.error import QuotaExceededError
from extensions.ext_database import db
from models import TenantCreditPool
logger = logging.getLogger(__name__)
class CreditPoolService:
    """Manage tenant credit pools: creation, lookup, and atomic deduction."""

    @classmethod
    def create_default_pool(cls, tenant_id: str) -> TenantCreditPool:
        """Create the default trial credit pool for a new tenant.

        Args:
            tenant_id: Tenant to create the pool for.

        Returns:
            The newly persisted pool.
        """
        credit_pool = TenantCreditPool(
            tenant_id=tenant_id, quota_limit=dify_config.HOSTED_POOL_CREDITS, quota_used=0, pool_type="trial"
        )
        db.session.add(credit_pool)
        db.session.commit()
        return credit_pool

    @classmethod
    def get_pool(cls, tenant_id: str, pool_type: str = "trial") -> Optional[TenantCreditPool]:
        """Return the tenant's credit pool of the given type, or None."""
        return (
            db.session.query(TenantCreditPool)
            .filter_by(
                tenant_id=tenant_id,
                pool_type=pool_type,
            )
            .first()
        )

    @classmethod
    def check_and_deduct_credits(
        cls,
        tenant_id: str,
        credits_required: int,
        pool_type: str = "trial",
    ):
        """Atomically deduct credits from the tenant's pool.

        Performs a pre-check on a fresh read, then a guarded UPDATE whose
        WHERE clause re-validates the limit in SQL, so concurrent deductions
        cannot overdraw the pool.

        Raises:
            QuotaExceededError: if the pool is missing, has insufficient
                credits (either at pre-check or at the guarded UPDATE), or
                the deduction fails at the database layer.
        """
        pool = cls.get_pool(tenant_id, pool_type)
        if not pool:
            raise QuotaExceededError("Credit pool not found")
        if pool.remaining_credits < credits_required:
            raise QuotaExceededError(
                f"Insufficient credits. Required: {credits_required}, Available: {pool.remaining_credits}"
            )
        try:
            with Session(db.engine) as session:
                # Use a SQL-side expression for the new quota_used (not the
                # stale Python-side value) so the increment is atomic, and
                # guard against exceeding the limit in the same statement.
                stmt = (
                    update(TenantCreditPool)
                    .where(
                        TenantCreditPool.pool_type == pool_type,
                        TenantCreditPool.tenant_id == tenant_id,
                        TenantCreditPool.quota_used + credits_required <= TenantCreditPool.quota_limit,
                    )
                    .values(quota_used=TenantCreditPool.quota_used + credits_required)
                )
                result = session.execute(stmt)
                session.commit()
        except Exception as e:
            # Preserve the root cause for operators instead of discarding it.
            logger.exception("Failed to deduct credits for tenant %s", tenant_id)
            raise QuotaExceededError("Failed to deduct credits") from e
        # If the guarded UPDATE matched no row, a concurrent deduction won
        # the race and the pool can no longer cover this request.
        if result.rowcount == 0:
            raise QuotaExceededError(
                f"Insufficient credits. Required: {credits_required}, Available: {pool.remaining_credits}"
            )

View File

@ -160,6 +160,8 @@ class SystemFeatureModel(BaseModel):
plugin_installation_permission: PluginInstallationPermissionModel = PluginInstallationPermissionModel()
enable_change_email: bool = True
plugin_manager: PluginManagerModel = PluginManagerModel()
enable_trial_app: bool = False
enable_explore_banner: bool = False
class FeatureService:
@ -214,6 +216,8 @@ class FeatureService:
system_features.is_allow_register = dify_config.ALLOW_REGISTER
system_features.is_allow_create_workspace = dify_config.ALLOW_CREATE_WORKSPACE
system_features.is_email_setup = dify_config.MAIL_TYPE is not None and dify_config.MAIL_TYPE != ""
system_features.enable_trial_app = dify_config.ENABLE_TRIAL_APP
system_features.enable_explore_banner = dify_config.ENABLE_EXPLORE_BANNER
@classmethod
def _fulfill_params_from_env(cls, features: FeatureModel):

View File

@ -1,4 +1,9 @@
from sqlalchemy.orm import Session
from configs import dify_config
from extensions.ext_database import db
from models.model import AccountTrialAppRecord, TrialApp
from services.feature_service import FeatureService
from services.recommend_app.recommend_app_factory import RecommendAppRetrievalFactory
@ -20,6 +25,15 @@ class RecommendedAppService:
)
)
if FeatureService.get_system_features().enable_trial_app:
apps = result["recommended_apps"]
for app in apps:
app_id = app["app_id"]
trial_app_model = db.session.query(TrialApp).where(TrialApp.app_id == app_id).first()
if trial_app_model:
app["can_trial"] = True
else:
app["can_trial"] = False
return result
@classmethod
@ -32,4 +46,27 @@ class RecommendedAppService:
mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE
retrieval_instance = RecommendAppRetrievalFactory.get_recommend_app_factory(mode)()
result: dict = retrieval_instance.get_recommend_app_detail(app_id)
if FeatureService.get_system_features().enable_trial_app:
app_id = result["id"]
trial_app_model = db.session.query(TrialApp).where(TrialApp.app_id == app_id).first()
if trial_app_model:
result["can_trial"] = True
else:
result["can_trial"] = False
return result
@classmethod
def add_trial_app_record(cls, app_id: str, account_id: str):
    """
    Record one trial use of an app by an account.

    Increments the per-account counter if a record exists, otherwise
    creates it with count=1.

    :param app_id: app id
    :param account_id: account id
    :return:
    """
    with Session(db.engine) as session:
        # Fix: the lookup must filter on AccountTrialAppRecord (not TrialApp)
        # and must scope to the account, otherwise counts are shared across
        # accounts and the filter joins against the wrong table.
        account_trial_app_record = (
            session.query(AccountTrialAppRecord)
            .where(
                AccountTrialAppRecord.app_id == app_id,
                AccountTrialAppRecord.account_id == account_id,
            )
            .first()
        )
        if account_trial_app_record:
            account_trial_app_record.count += 1
        else:
            session.add(AccountTrialAppRecord(app_id=app_id, count=1, account_id=account_id))
        session.commit()

View File

@ -0,0 +1,311 @@
import logging
from typing import Optional
from sqlalchemy import desc, select
from sqlalchemy.orm import Session, selectinload
from werkzeug.exceptions import Forbidden, NotFound
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from libs.helper import uuid_value
from models import WorkflowComment, WorkflowCommentMention, WorkflowCommentReply
logger = logging.getLogger(__name__)
class WorkflowCommentService:
    """Service for managing workflow comments, replies, and @-mentions."""

    @staticmethod
    def _validate_content(content: str) -> None:
        """Reject empty or oversized comment/reply content.

        Raises:
            ValueError: if content is blank or longer than 1000 characters.
        """
        if len(content.strip()) == 0:
            raise ValueError("Comment content cannot be empty")
        if len(content) > 1000:
            raise ValueError("Comment content cannot exceed 1000 characters")

    @staticmethod
    def get_comments(tenant_id: str, app_id: str) -> list[WorkflowComment]:
        """Get all comments for a workflow, newest first, with replies and
        mentions eagerly loaded."""
        with Session(db.engine) as session:
            stmt = (
                select(WorkflowComment)
                .options(selectinload(WorkflowComment.replies), selectinload(WorkflowComment.mentions))
                .where(WorkflowComment.tenant_id == tenant_id, WorkflowComment.app_id == app_id)
                .order_by(desc(WorkflowComment.created_at))
            )
            comments = session.scalars(stmt).all()
            # .all() returns a Sequence in SQLAlchemy 2.0; materialize to
            # match the declared list return type.
            return list(comments)

    @staticmethod
    def get_comment(
        tenant_id: str, app_id: str, comment_id: str, session: Optional[Session] = None
    ) -> WorkflowComment:
        """Get a specific comment, scoped to tenant and app.

        Args:
            session: optional externally managed session; when omitted a
                short-lived session (expire_on_commit=False) is used so the
                returned object stays usable after it closes.

        Raises:
            NotFound: if no matching comment exists.
        """

        def _get_comment(session: Session) -> WorkflowComment:
            stmt = (
                select(WorkflowComment)
                .options(selectinload(WorkflowComment.replies), selectinload(WorkflowComment.mentions))
                .where(
                    WorkflowComment.id == comment_id,
                    WorkflowComment.tenant_id == tenant_id,
                    WorkflowComment.app_id == app_id,
                )
            )
            comment = session.scalar(stmt)
            if not comment:
                raise NotFound("Comment not found")
            return comment

        if session is not None:
            return _get_comment(session)
        else:
            with Session(db.engine, expire_on_commit=False) as session:
                return _get_comment(session)

    @staticmethod
    def create_comment(
        tenant_id: str,
        app_id: str,
        created_by: str,
        content: str,
        position_x: float,
        position_y: float,
        mentioned_user_ids: Optional[list[str]] = None,
    ) -> dict:
        """Create a new workflow comment with optional mentions.

        Returns:
            dict with the new comment's ``id`` and ``created_at``.

        Raises:
            ValueError: if content fails validation.
        """
        WorkflowCommentService._validate_content(content)

        with Session(db.engine) as session:
            comment = WorkflowComment(
                tenant_id=tenant_id,
                app_id=app_id,
                position_x=position_x,
                position_y=position_y,
                content=content,
                created_by=created_by,
            )
            session.add(comment)
            session.flush()  # Get the comment ID for mentions

            # Create mentions if specified; only well-formed UUIDs are kept.
            mentioned_user_ids = mentioned_user_ids or []
            for user_id in mentioned_user_ids:
                if isinstance(user_id, str) and uuid_value(user_id):
                    mention = WorkflowCommentMention(
                        comment_id=comment.id,
                        reply_id=None,  # This is a comment mention, not reply mention
                        mentioned_user_id=user_id,
                    )
                    session.add(mention)

            session.commit()
            # Return only what we need - id and created_at
            return {"id": comment.id, "created_at": comment.created_at}

    @staticmethod
    def update_comment(
        tenant_id: str,
        app_id: str,
        comment_id: str,
        user_id: str,
        content: str,
        position_x: Optional[float] = None,
        position_y: Optional[float] = None,
        mentioned_user_ids: Optional[list[str]] = None,
    ) -> dict:
        """Update a workflow comment's content, position, and mentions.

        Only the creator may update. Comment-level mentions are replaced
        wholesale; reply-level mentions are untouched.

        Raises:
            NotFound: if the comment does not exist.
            Forbidden: if ``user_id`` is not the creator.
        """
        WorkflowCommentService._validate_content(content)

        with Session(db.engine, expire_on_commit=False) as session:
            stmt = select(WorkflowComment).where(
                WorkflowComment.id == comment_id,
                WorkflowComment.tenant_id == tenant_id,
                WorkflowComment.app_id == app_id,
            )
            comment = session.scalar(stmt)
            if not comment:
                raise NotFound("Comment not found")

            # Only the creator can update the comment
            if comment.created_by != user_id:
                raise Forbidden("Only the comment creator can update it")

            comment.content = content
            if position_x is not None:
                comment.position_x = position_x
            if position_y is not None:
                comment.position_y = position_y

            # Replace mentions: remove existing comment-level mentions only
            # (reply_id IS NULL keeps reply mentions intact).
            existing_mentions = session.scalars(
                select(WorkflowCommentMention).where(
                    WorkflowCommentMention.comment_id == comment.id,
                    WorkflowCommentMention.reply_id.is_(None),
                )
            ).all()
            for mention in existing_mentions:
                session.delete(mention)

            mentioned_user_ids = mentioned_user_ids or []
            for user_id_str in mentioned_user_ids:
                if isinstance(user_id_str, str) and uuid_value(user_id_str):
                    mention = WorkflowCommentMention(
                        comment_id=comment.id,
                        reply_id=None,  # This is a comment mention
                        mentioned_user_id=user_id_str,
                    )
                    session.add(mention)

            session.commit()
            return {"id": comment.id, "updated_at": comment.updated_at}

    @staticmethod
    def delete_comment(tenant_id: str, app_id: str, comment_id: str, user_id: str) -> None:
        """Delete a workflow comment plus all of its mentions and replies.

        Raises:
            NotFound: if the comment does not exist.
            Forbidden: if ``user_id`` is not the creator.
        """
        with Session(db.engine, expire_on_commit=False) as session:
            comment = WorkflowCommentService.get_comment(tenant_id, app_id, comment_id, session)

            # Only the creator can delete the comment
            if comment.created_by != user_id:
                raise Forbidden("Only the comment creator can delete it")

            # Delete associated mentions (both comment and reply mentions)
            mentions = session.scalars(
                select(WorkflowCommentMention).where(WorkflowCommentMention.comment_id == comment_id)
            ).all()
            for mention in mentions:
                session.delete(mention)

            # Delete associated replies
            replies = session.scalars(
                select(WorkflowCommentReply).where(WorkflowCommentReply.comment_id == comment_id)
            ).all()
            for reply in replies:
                session.delete(reply)

            session.delete(comment)
            session.commit()

    @staticmethod
    def resolve_comment(tenant_id: str, app_id: str, comment_id: str, user_id: str) -> WorkflowComment:
        """Mark a comment as resolved; idempotent if already resolved.

        Raises:
            NotFound: if the comment does not exist.
        """
        with Session(db.engine, expire_on_commit=False) as session:
            comment = WorkflowCommentService.get_comment(tenant_id, app_id, comment_id, session)

            if comment.resolved:
                return comment

            comment.resolved = True
            comment.resolved_at = naive_utc_now()
            comment.resolved_by = user_id
            session.commit()
            return comment

    @staticmethod
    def create_reply(
        comment_id: str, content: str, created_by: str, mentioned_user_ids: Optional[list[str]] = None
    ) -> dict:
        """Add a reply (with optional mentions) to an existing comment.

        Returns:
            dict with the new reply's ``id`` and ``created_at``.

        Raises:
            NotFound: if the parent comment does not exist.
        """
        WorkflowCommentService._validate_content(content)

        with Session(db.engine, expire_on_commit=False) as session:
            comment = session.get(WorkflowComment, comment_id)
            if not comment:
                raise NotFound("Comment not found")

            reply = WorkflowCommentReply(comment_id=comment_id, content=content, created_by=created_by)
            session.add(reply)
            session.flush()  # Get the reply ID for mentions

            mentioned_user_ids = mentioned_user_ids or []
            for user_id in mentioned_user_ids:
                if isinstance(user_id, str) and uuid_value(user_id):
                    # Create mention linking to specific reply
                    mention = WorkflowCommentMention(
                        comment_id=comment_id, reply_id=reply.id, mentioned_user_id=user_id
                    )
                    session.add(mention)

            session.commit()
            return {"id": reply.id, "created_at": reply.created_at}

    @staticmethod
    def update_reply(
        reply_id: str, user_id: str, content: str, mentioned_user_ids: Optional[list[str]] = None
    ) -> dict:
        """Update a comment reply's content and mentions.

        Only the creator may update. Existing reply-level mentions are
        replaced wholesale.

        Raises:
            NotFound: if the reply does not exist.
            Forbidden: if ``user_id`` is not the creator.
        """
        WorkflowCommentService._validate_content(content)

        with Session(db.engine, expire_on_commit=False) as session:
            reply = session.get(WorkflowCommentReply, reply_id)
            if not reply:
                raise NotFound("Reply not found")

            # Only the creator can update the reply
            if reply.created_by != user_id:
                raise Forbidden("Only the reply creator can update it")

            reply.content = content

            # Update mentions - first remove existing mentions for this reply
            existing_mentions = session.scalars(
                select(WorkflowCommentMention).where(WorkflowCommentMention.reply_id == reply.id)
            ).all()
            for mention in existing_mentions:
                session.delete(mention)

            mentioned_user_ids = mentioned_user_ids or []
            for user_id_str in mentioned_user_ids:
                if isinstance(user_id_str, str) and uuid_value(user_id_str):
                    mention = WorkflowCommentMention(
                        comment_id=reply.comment_id, reply_id=reply.id, mentioned_user_id=user_id_str
                    )
                    session.add(mention)

            session.commit()
            session.refresh(reply)  # Refresh to get updated timestamp
            return {"id": reply.id, "updated_at": reply.updated_at}

    @staticmethod
    def delete_reply(reply_id: str, user_id: str) -> None:
        """Delete a comment reply and its mentions.

        Raises:
            NotFound: if the reply does not exist.
            Forbidden: if ``user_id`` is not the creator.
        """
        with Session(db.engine, expire_on_commit=False) as session:
            reply = session.get(WorkflowCommentReply, reply_id)
            if not reply:
                raise NotFound("Reply not found")

            # Only the creator can delete the reply
            if reply.created_by != user_id:
                raise Forbidden("Only the reply creator can delete it")

            # Delete associated mentions first
            mentions = session.scalars(
                select(WorkflowCommentMention).where(WorkflowCommentMention.reply_id == reply_id)
            ).all()
            for mention in mentions:
                session.delete(mention)

            session.delete(reply)
            session.commit()

    @staticmethod
    def validate_comment_access(comment_id: str, tenant_id: str, app_id: str) -> WorkflowComment:
        """Validate that a comment belongs to the specified tenant and app."""
        return WorkflowCommentService.get_comment(tenant_id, app_id, comment_id)

View File

@ -11,6 +11,7 @@ from core.app.app_config.entities import VariableEntityType
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.file import File
from core.memory.entities import MemoryBlockSpec, MemoryCreatedBy, MemoryScope
from core.repositories import DifyCoreRepositoryFactory
from core.variables import Variable
from core.variables.variables import VariableUnion
@ -196,15 +197,18 @@ class WorkflowService:
account: Account,
environment_variables: Sequence[Variable],
conversation_variables: Sequence[Variable],
memory_blocks: Sequence[MemoryBlockSpec] | None = None,
force_upload: bool = False,
) -> Workflow:
"""
Sync draft workflow
:param force_upload: Skip hash validation when True (for restore operations)
:raises WorkflowHashNotEqualError
"""
# fetch draft workflow by app_model
workflow = self.get_draft_workflow(app_model=app_model)
if workflow and workflow.unique_hash != unique_hash:
if workflow and workflow.unique_hash != unique_hash and not force_upload:
raise WorkflowHashNotEqualError()
# validate features structure
@ -223,6 +227,7 @@ class WorkflowService:
environment_variables=environment_variables,
conversation_variables=conversation_variables,
)
workflow.memory_blocks = memory_blocks or []
db.session.add(workflow)
# update draft workflow if found
else:
@ -232,6 +237,7 @@ class WorkflowService:
workflow.updated_at = naive_utc_now()
workflow.environment_variables = environment_variables
workflow.conversation_variables = conversation_variables
workflow.memory_blocks = memory_blocks or []
# commit db session changes
db.session.commit()
@ -242,6 +248,78 @@ class WorkflowService:
# return draft workflow
return workflow
def update_draft_workflow_environment_variables(
    self,
    *,
    app_model: App,
    environment_variables: Sequence[Variable],
    account: Account,
):
    """Replace the draft workflow's environment variables and stamp the editor.

    :raises ValueError: if the app has no draft workflow.
    """
    draft = self.get_draft_workflow(app_model=app_model)
    if draft is None:
        raise ValueError("No draft workflow found.")

    draft.environment_variables = environment_variables
    draft.updated_by = account.id
    draft.updated_at = naive_utc_now()
    db.session.commit()
def update_draft_workflow_conversation_variables(
    self,
    *,
    app_model: App,
    conversation_variables: Sequence[Variable],
    account: Account,
):
    """Replace the draft workflow's conversation variables and stamp the editor.

    :raises ValueError: if the app has no draft workflow.
    """
    draft = self.get_draft_workflow(app_model=app_model)
    if draft is None:
        raise ValueError("No draft workflow found.")

    draft.conversation_variables = conversation_variables
    draft.updated_by = account.id
    draft.updated_at = naive_utc_now()
    db.session.commit()
def update_draft_workflow_features(
    self,
    *,
    app_model: App,
    features: dict,
    account: Account,
):
    """Validate and store the draft workflow's features payload as JSON.

    :raises ValueError: if the app has no draft workflow.
    """
    draft = self.get_draft_workflow(app_model=app_model)
    if draft is None:
        raise ValueError("No draft workflow found.")

    # Reject malformed payloads before touching the row.
    self.validate_features_structure(app_model=app_model, features=features)

    draft.features = json.dumps(features)
    draft.updated_by = account.id
    draft.updated_at = naive_utc_now()
    db.session.commit()
def publish_workflow(
self,
*,
@ -279,6 +357,7 @@ class WorkflowService:
marked_name=marked_name,
marked_comment=marked_comment,
rag_pipeline_variables=draft_workflow.rag_pipeline_variables,
memory_blocks=draft_workflow.memory_blocks,
features=draft_workflow.features,
)
@ -635,17 +714,10 @@ class WorkflowService:
tenant_id=draft_workflow.tenant_id, start_node_data=start_data, user_inputs=user_inputs
)
# init variable pool
variable_pool = _setup_variable_pool(
query=query,
files=files or [],
user_id=account.id,
user_inputs=user_inputs,
workflow=draft_workflow,
# NOTE(QuantumGhost): We rely on `DraftVarLoader` to load conversation variables.
conversation_variables=[],
node_type=node_type,
conversation_id=conversation_id,
)
variable_pool = _setup_variable_pool(query=query, files=files or [], user_id=account.id,
user_inputs=user_inputs, workflow=draft_workflow,
node_type=node_type, conversation_id=conversation_id,
conversation_variables=[], is_draft=True)
else:
variable_pool = VariablePool(
@ -994,6 +1066,7 @@ def _setup_variable_pool(
node_type: NodeType,
conversation_id: str,
conversation_variables: list[Variable],
is_draft: bool
):
# Only inject system variables for START node type.
if node_type == NodeType.START:
@ -1012,7 +1085,6 @@ def _setup_variable_pool(
system_variable.dialogue_count = 1
else:
system_variable = SystemVariable.empty()
# init variable pool
variable_pool = VariablePool(
system_variables=system_variable,
@ -1021,6 +1093,12 @@ def _setup_variable_pool(
# Based on the definition of `VariableUnion`,
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
conversation_variables=cast(list[VariableUnion], conversation_variables), #
memory_blocks=_fetch_memory_blocks(
workflow,
MemoryCreatedBy(account_id=user_id),
conversation_id,
is_draft=is_draft
),
)
return variable_pool
@ -1057,3 +1135,30 @@ def _rebuild_single_file(tenant_id: str, value: Any, variable_entity_type: Varia
return build_from_mappings(mappings=value, tenant_id=tenant_id)
else:
raise Exception("unreachable")
def _fetch_memory_blocks(
    workflow: Workflow,
    created_by: MemoryCreatedBy,
    conversation_id: str,
    is_draft: bool
) -> Mapping[str, str]:
    """Load the workflow's memory values keyed for the variable pool.

    APP-scoped memories are keyed by their spec id; NODE-scoped memories are
    keyed as "<node_id>.<spec_id>".
    """
    # Imported locally to avoid a circular module dependency.
    from services.chatflow_memory_service import ChatflowMemoryService

    memories = ChatflowMemoryService.get_memories_by_specs(
        memory_block_specs=workflow.memory_blocks,
        tenant_id=workflow.tenant_id,
        app_id=workflow.app_id,
        node_id=None,
        conversation_id=conversation_id,
        is_draft=is_draft,
        created_by=created_by,
    )

    def _pool_key(memory) -> str:
        if memory.spec.scope == MemoryScope.APP:
            return memory.spec.id
        return f"{memory.node_id}.{memory.spec.id}"  # NODE scope

    return {_pool_key(memory): memory.value for memory in memories}

View File

@ -46,5 +46,17 @@ class WorkspaceService:
"remove_webapp_brand": remove_webapp_brand,
"replace_webapp_logo": replace_webapp_logo,
}
if dify_config.EDITION == "CLOUD":
from services.credit_pool_service import CreditPoolService
paid_pool = CreditPoolService.get_pool(tenant_id=tenant.id, pool_type="paid")
if paid_pool:
tenant_info["trial_credits"] = paid_pool.quota_limit
tenant_info["trial_credits_used"] = paid_pool.quota_used
else:
trial_pool = CreditPoolService.get_pool(tenant_id=tenant.id, pool_type="trial")
if trial_pool:
tenant_info["trial_credits"] = trial_pool.quota_limit
tenant_info["trial_credits_used"] = trial_pool.quota_used
return tenant_info

4803
api/uv.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -14,6 +14,14 @@ server {
include proxy.conf;
}
location /socket.io/ {
proxy_pass http://api:5001;
include proxy.conf;
proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection "upgrade";
proxy_cache_bypass $http_upgrade;
}
location /v1 {
proxy_pass http://api:5001;
include proxy.conf;

View File

@ -5,7 +5,7 @@ proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
proxy_set_header X-Forwarded-Port $server_port;
proxy_http_version 1.1;
proxy_set_header Connection "";
# proxy_set_header Connection "";
proxy_buffering off;
proxy_read_timeout ${NGINX_PROXY_READ_TIMEOUT};
proxy_send_timeout ${NGINX_PROXY_SEND_TIMEOUT};

View File

@ -1,6 +1,6 @@
'use client'
import type { FC } from 'react'
import React from 'react'
import React, { useEffect } from 'react'
import { useTranslation } from 'react-i18next'
import { useContext } from 'use-context-selector'
import AppCard from '@/app/components/app/overview/app-card'
@ -19,6 +19,8 @@ import { asyncRunSafe } from '@/utils'
import { NEED_REFRESH_APP_LIST_KEY } from '@/config'
import type { IAppCardProps } from '@/app/components/app/overview/app-card'
import { useStore as useAppStore } from '@/app/components/app/store'
import { webSocketClient } from '@/app/components/workflow/collaboration/core/websocket-manager'
import { collaborationManager } from '@/app/components/workflow/collaboration/core/collaboration-manager'
export type ICardViewProps = {
appId: string
@ -47,15 +49,44 @@ const CardView: FC<ICardViewProps> = ({ appId, isInPanel, className }) => {
message ||= (type === 'success' ? 'modifiedSuccessfully' : 'modifiedUnsuccessfully')
if (type === 'success')
if (type === 'success') {
updateAppDetail()
// Emit collaboration event to notify other clients of app state changes
const socket = webSocketClient.getSocket(appId)
if (socket) {
socket.emit('collaboration_event', {
type: 'appStateUpdate',
data: { timestamp: Date.now() },
timestamp: Date.now(),
})
}
}
notify({
type,
message: t(`common.actionMsg.${message}`),
})
}
// Listen for collaborative app state updates from other clients
useEffect(() => {
if (!appId) return
const unsubscribe = collaborationManager.onAppStateUpdate(async (update: any) => {
try {
console.log('Received app state update from collaboration:', update)
// Update app detail when other clients modify app state
await updateAppDetail()
}
catch (error) {
console.error('app state update failed:', error)
}
})
return unsubscribe
}, [appId])
const onChangeSiteStatus = async (value: boolean) => {
const [err] = await asyncRunSafe<App>(
updateAppSiteStatus({

View File

@ -32,6 +32,8 @@ import { useGlobalPublicStore } from '@/context/global-public-context'
import { formatTime } from '@/utils/time'
import { useGetUserCanAccessApp } from '@/service/access-control'
import dynamic from 'next/dynamic'
import { UserAvatarList } from '@/app/components/base/user-avatar-list'
import type { WorkflowOnlineUser } from '@/models/app'
const EditAppModal = dynamic(() => import('@/app/components/explore/create-app-modal'), {
ssr: false,
@ -55,9 +57,10 @@ const AccessControl = dynamic(() => import('@/app/components/app/app-access-cont
export type AppCardProps = {
app: App
onRefresh?: () => void
onlineUsers?: WorkflowOnlineUser[]
}
const AppCard = ({ app, onRefresh }: AppCardProps) => {
const AppCard = ({ app, onRefresh, onlineUsers = [] }: AppCardProps) => {
const { t } = useTranslation()
const { notify } = useContext(ToastContext)
const systemFeatures = useGlobalPublicStore(s => s.systemFeatures)
@ -333,6 +336,19 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
return `${t('datasetDocuments.segment.editedAt')} ${timeText}`
}, [app.updated_at, app.created_at])
const onlineUserAvatars = useMemo(() => {
if (!onlineUsers.length)
return []
return onlineUsers
.map(user => ({
id: user.user_id || user.sid || '',
name: user.username || 'User',
avatar_url: user.avatar || undefined,
}))
.filter(user => !!user.id)
}, [onlineUsers])
return (
<>
<div
@ -377,6 +393,11 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
<RiVerifiedBadgeLine className='h-4 w-4 text-text-quaternary' />
</Tooltip>}
</div>
<div>
{onlineUserAvatars.length > 0 && (
<UserAvatarList users={onlineUserAvatars} maxVisible={3} size={20} />
)}
</div>
</div>
<div className='title-wrapper h-[90px] px-[14px] text-xs leading-normal text-text-tertiary'>
<div

View File

@ -1,10 +1,11 @@
'use client'
import { useCallback, useEffect, useRef, useState } from 'react'
import { useCallback, useEffect, useMemo, useRef, useState } from 'react'
import {
useRouter,
} from 'next/navigation'
import useSWRInfinite from 'swr/infinite'
import useSWR from 'swr'
import { useTranslation } from 'react-i18next'
import { useDebounceFn } from 'ahooks'
import {
@ -19,8 +20,8 @@ import AppCard from './app-card'
import NewAppCard from './new-app-card'
import useAppsQueryState from './hooks/use-apps-query-state'
import { useDSLDragDrop } from './hooks/use-dsl-drag-drop'
import type { AppListResponse } from '@/models/app'
import { fetchAppList } from '@/service/apps'
import type { AppListResponse, WorkflowOnlineUser } from '@/models/app'
import { fetchAppList, fetchWorkflowOnlineUsers } from '@/service/apps'
import { useAppContext } from '@/context/app-context'
import { NEED_REFRESH_APP_LIST_KEY } from '@/config'
import { CheckModal } from '@/hooks/use-pay'
@ -112,6 +113,37 @@ const List = () => {
},
)
const apps = useMemo(() => data?.flatMap(page => page.data) ?? [], [data])
const workflowIds = useMemo(() => {
const ids = new Set<string>()
apps.forEach((appItem) => {
const workflowId = appItem.id
if (!workflowId)
return
if (appItem.mode === 'workflow' || appItem.mode === 'advanced-chat')
ids.add(workflowId)
})
return Array.from(ids)
}, [apps])
const { data: onlineUsersByWorkflow, mutate: refreshOnlineUsers } = useSWR<Record<string, WorkflowOnlineUser[]>>(
workflowIds.length ? { workflowIds } : null,
fetchWorkflowOnlineUsers,
)
useEffect(() => {
if (!workflowIds.length)
return
const timer = window.setInterval(() => {
refreshOnlineUsers()
}, 10000)
return () => window.clearInterval(timer)
}, [workflowIds.join(','), refreshOnlineUsers])
const anchorRef = useRef<HTMLDivElement>(null)
const options = [
{ value: 'all', text: t('app.types.all'), icon: <RiApps2Line className='mr-1 h-[14px] w-[14px]' /> },
@ -213,7 +245,12 @@ const List = () => {
{isCurrentWorkspaceEditor
&& <NewAppCard ref={newAppCardRef} onSuccess={mutate} selectedAppType={activeTab} />}
{data.map(({ data: apps }) => apps.map(app => (
<AppCard key={app.id} app={app} onRefresh={mutate} />
<AppCard
key={app.id}
app={app}
onRefresh={mutate}
onlineUsers={onlineUsersByWorkflow?.[app.id] ?? []}
/>
)))}
</div>
: <div className='relative grid grow grid-cols-1 content-start gap-4 overflow-hidden px-12 pt-2 sm:grid-cols-1 md:grid-cols-2 xl:grid-cols-4 2xl:grid-cols-5 2k:grid-cols-6'>

View File

@ -9,6 +9,7 @@ export type AvatarProps = {
className?: string
textClassName?: string
onError?: (x: boolean) => void
backgroundColor?: string
}
const Avatar = ({
name,
@ -17,9 +18,18 @@ const Avatar = ({
className,
textClassName,
onError,
backgroundColor,
}: AvatarProps) => {
const avatarClassName = 'shrink-0 flex items-center rounded-full bg-primary-600'
const style = { width: `${size}px`, height: `${size}px`, fontSize: `${size}px`, lineHeight: `${size}px` }
const avatarClassName = backgroundColor
? 'shrink-0 flex items-center rounded-full'
: 'shrink-0 flex items-center rounded-full bg-primary-600'
const style = {
width: `${size}px`,
height: `${size}px`,
fontSize: `${size}px`,
lineHeight: `${size}px`,
...(backgroundColor && !avatar ? { backgroundColor } : {}),
}
const [imgError, setImgError] = useState(false)
const handleError = () => {
@ -35,14 +45,18 @@ const Avatar = ({
if (avatar && !imgError) {
return (
<img
<span
className={cn(avatarClassName, className)}
style={style}
alt={name}
src={avatar}
onError={handleError}
onLoad={() => onError?.(false)}
/>
>
<img
className='h-full w-full rounded-full object-cover'
alt={name}
src={avatar}
onError={handleError}
onLoad={() => onError?.(false)}
/>
</span>
)
}

View File

@ -0,0 +1,3 @@
<svg xmlns="http://www.w3.org/2000/svg" width="14" height="12" viewBox="0 0 14 12" fill="none">
<path d="M12.3334 4C12.3334 2.52725 11.1395 1.33333 9.66671 1.33333H4.33337C2.86062 1.33333 1.66671 2.52724 1.66671 4V10.6667H9.66671C11.1395 10.6667 12.3334 9.47274 12.3334 8V4ZM7.66671 6.66667V8H4.33337V6.66667H7.66671ZM9.66671 4V5.33333H4.33337V4H9.66671ZM13.6667 8C13.6667 10.2091 11.8758 12 9.66671 12H0.333374V4C0.333374 1.79086 2.12424 0 4.33337 0H9.66671C11.8758 0 13.6667 1.79086 13.6667 4V8Z" fill="currentColor"/>
</svg>

After

Width:  |  Height:  |  Size: 527 B

View File

@ -0,0 +1,26 @@
{
"icon": {
"type": "element",
"isRootNode": true,
"name": "svg",
"attributes": {
"xmlns": "http://www.w3.org/2000/svg",
"width": "14",
"height": "12",
"viewBox": "0 0 14 12",
"fill": "none"
},
"children": [
{
"type": "element",
"name": "path",
"attributes": {
"d": "M12.3334 4C12.3334 2.52725 11.1395 1.33333 9.66671 1.33333H4.33337C2.86062 1.33333 1.66671 2.52724 1.66671 4V10.6667H9.66671C11.1395 10.6667 12.3334 9.47274 12.3334 8V4ZM7.66671 6.66667V8H4.33337V6.66667H7.66671ZM9.66671 4V5.33333H4.33337V4H9.66671ZM13.6667 8C13.6667 10.2091 11.8758 12 9.66671 12H0.333374V4C0.333374 1.79086 2.12424 0 4.33337 0H9.66671C11.8758 0 13.6667 1.79086 13.6667 4V8Z",
"fill": "currentColor"
},
"children": []
}
]
},
"name": "Comment"
}

View File

@ -0,0 +1,20 @@
// GENERATED BY script
// DO NOT EDIT IT MANUALLY
import * as React from 'react'
import data from './Comment.json'
import IconBase from '@/app/components/base/icons/IconBase'
import type { IconData } from '@/app/components/base/icons/IconBase'
// Renders the "Comment" icon (14x12 speech-bubble SVG from Comment.json)
// through the shared IconBase; all SVG props and `ref` are forwarded.
const Icon = (
  {
    ref,
    ...props
  }: React.SVGProps<SVGSVGElement> & {
    // NOTE(review): the doubly-nested RefObject type looks like a generator
    // artifact — a plain React.Ref<SVGSVGElement> is likely intended; confirm
    // against the icon-generation script before changing.
    ref?: React.RefObject<React.RefObject<HTMLOrSVGElement>>;
  },
) => <IconBase {...props} ref={ref} data={data as IconData} />
Icon.displayName = 'Comment'
export default Icon

View File

@ -1,4 +1,5 @@
export { default as Icon3Dots } from './Icon3Dots'
export { default as Comment } from './Comment'
export { default as DefaultToolIcon } from './DefaultToolIcon'
export { default as Message3Fill } from './Message3Fill'
export { default as RowStruct } from './RowStruct'

View File

@ -2,6 +2,7 @@
import type { FC } from 'react'
import React, { useEffect } from 'react'
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'
import type {
EditorState,
} from 'lexical'
@ -80,6 +81,29 @@ import {
import { useEventEmitterContextContext } from '@/context/event-emitter'
import cn from '@/utils/classnames'
// Keeps the Lexical editor content in sync with the externally-controlled
// `value` prop: when `value` changes and differs from the editor's current
// plain text, the whole editor state is replaced.
const ValueSyncPlugin: FC<{ value?: string }> = ({ value }) => {
  const [editor] = useLexicalComposerContext()
  useEffect(() => {
    // undefined means "uncontrolled" — leave the editor alone.
    if (value === undefined)
      return
    const incomingValue = value ?? ''
    // Compare against the editor's current text (root children joined by
    // newline) so state — and the user's selection — is not reset when the
    // prop merely echoes what the editor already contains.
    const shouldUpdate = editor.getEditorState().read(() => {
      const currentText = $getRoot().getChildren().map(node => node.getTextContent()).join('\n')
      return currentText !== incomingValue
    })
    if (!shouldUpdate)
      return
    const editorState = editor.parseEditorState(textToEditorState(incomingValue))
    editor.setEditorState(editorState)
  }, [editor, value])
  return null
}
export type PromptEditorProps = {
instanceId?: string
compact?: boolean
@ -293,6 +317,7 @@ const PromptEditor: FC<PromptEditorProps> = ({
<VariableValueBlock />
)
}
<ValueSyncPlugin value={value} />
<OnChangePlugin onChange={handleEditorChange} />
<OnBlurBlock onBlur={onBlur} onFocus={onFocus} />
<UpdateBlock instanceId={instanceId} />

View File

@ -0,0 +1,77 @@
import type { FC } from 'react'
import { memo } from 'react'
import { getUserColor } from '@/app/components/workflow/collaboration/utils/user-color'
import { useAppContext } from '@/context/app-context'
import Avatar from '@/app/components/base/avatar'
type User = {
id: string
name: string
avatar_url?: string | null
}
type UserAvatarListProps = {
users: User[]
maxVisible?: number
size?: number
className?: string
showCount?: boolean
}
/**
 * Renders a compact, overlapping row of user avatars.
 * When there are more users than `maxVisible` (and `showCount` is enabled),
 * the last slot is given to a "+N" overflow badge instead of an avatar.
 */
export const UserAvatarList: FC<UserAvatarListProps> = memo(({
  users,
  maxVisible = 3,
  size = 24,
  className = '',
  showCount = true,
}) => {
  const { userProfile } = useAppContext()
  if (!users.length) return null
  const hasOverflowBadge = showCount && users.length > maxVisible
  // Reserve one slot for the badge, but always keep at least one avatar visible.
  const visibleCount = hasOverflowBadge ? Math.max(1, maxVisible - 1) : maxVisible
  const displayedUsers = users.slice(0, visibleCount)
  const hiddenCount = users.length - visibleCount
  const selfId = userProfile?.id
  const renderAvatar = (user: User, index: number) => {
    // The current user keeps the default avatar background; collaborators
    // get a deterministic per-user color.
    const color = user.id === selfId ? undefined : getUserColor(user.id)
    return (
      <div
        key={`${user.id}-${index}`}
        className='relative'
        style={{ zIndex: displayedUsers.length - index }}
      >
        <Avatar
          name={user.name}
          avatar={user.avatar_url || null}
          size={size}
          className='ring-2 ring-components-panel-bg'
          backgroundColor={color}
        />
      </div>
    )
  }
  return (
    <div className={`flex items-center -space-x-1 ${className}`}>
      {displayedUsers.map(renderAvatar)}
      {hasOverflowBadge && hiddenCount > 0 && (
        <div
          className={'flex items-center justify-center rounded-full bg-gray-500 text-[10px] leading-none text-white ring-2 ring-components-panel-bg'}
          style={{
            zIndex: 0,
            width: size,
            height: size,
          }}
        >
          +{hiddenCount}
        </div>
      )}
    </div>
  )
})
UserAvatarList.displayName = 'UserAvatarList'

View File

@ -26,6 +26,8 @@ import {
import { BlockEnum } from '@/app/components/workflow/types'
import cn from '@/utils/classnames'
import { fetchAppDetail } from '@/service/apps'
import { webSocketClient } from '@/app/components/workflow/collaboration/core/websocket-manager'
import { collaborationManager } from '@/app/components/workflow/collaboration/core/collaboration-manager'
export type IAppCardProps = {
appInfo: AppDetailResponse & Partial<AppSSO>
@ -90,6 +92,19 @@ function MCPServiceCard({
const onGenCode = async () => {
await refreshMCPServerCode(detail?.id || '')
invalidateMCPServerDetail(appId)
// Emit collaboration event to notify other clients of MCP server changes
const socket = webSocketClient.getSocket(appId)
if (socket) {
socket.emit('collaboration_event', {
type: 'mcpServerUpdate',
data: {
action: 'codeRegenerated',
timestamp: Date.now(),
},
timestamp: Date.now(),
})
}
}
const onChangeStatus = async (state: boolean) => {
@ -119,6 +134,20 @@ function MCPServiceCard({
})
invalidateMCPServerDetail(appId)
}
// Emit collaboration event to notify other clients of MCP server status change
const socket = webSocketClient.getSocket(appId)
if (socket) {
socket.emit('collaboration_event', {
type: 'mcpServerUpdate',
data: {
action: 'statusChanged',
status: state ? 'active' : 'inactive',
timestamp: Date.now(),
},
timestamp: Date.now(),
})
}
}
const handleServerModalHide = () => {
@ -131,6 +160,23 @@ function MCPServiceCard({
setActivated(serverActivated)
}, [serverActivated])
// Listen for collaborative MCP server updates from other clients
useEffect(() => {
if (!appId) return
const unsubscribe = collaborationManager.onMcpServerUpdate(async (update: any) => {
try {
console.log('Received MCP server update from collaboration:', update)
invalidateMCPServerDetail(appId)
}
catch (error) {
console.error('MCP server update failed:', error)
}
})
return unsubscribe
}, [appId, invalidateMCPServerDetail])
if (!currentWorkflow && isAdvancedApp)
return null

View File

@ -1,11 +1,18 @@
import {
useCallback,
useEffect,
useMemo,
useRef,
useState,
} from 'react'
import { useFeaturesStore } from '@/app/components/base/features/hooks'
import type { Features as FeaturesData } from '@/app/components/base/features/types'
import { SupportUploadFileTypes } from '@/app/components/workflow/types'
import { FILE_EXTS } from '@/app/components/base/prompt-editor/constants'
import { WorkflowWithInnerContext } from '@/app/components/workflow'
import type { WorkflowProps } from '@/app/components/workflow'
import WorkflowChildren from './workflow-children'
import {
useAvailableNodesMetaData,
useConfigsMap,
@ -18,7 +25,12 @@ import {
useWorkflowRun,
useWorkflowStartRun,
} from '../hooks'
import { useWorkflowStore } from '@/app/components/workflow/store'
import { useWorkflowUpdate } from '@/app/components/workflow/hooks/use-workflow-interactions'
import { useStore, useWorkflowStore } from '@/app/components/workflow/store'
import { useCollaboration } from '@/app/components/workflow/collaboration'
import { collaborationManager } from '@/app/components/workflow/collaboration'
import { fetchWorkflowDraft } from '@/service/workflow'
import { useReactFlow, useStoreApi } from 'reactflow'
type WorkflowMainProps = Pick<WorkflowProps, 'nodes' | 'edges' | 'viewport'>
const WorkflowMain = ({
@ -28,6 +40,31 @@ const WorkflowMain = ({
}: WorkflowMainProps) => {
const featuresStore = useFeaturesStore()
const workflowStore = useWorkflowStore()
const appId = useStore(s => s.appId)
const containerRef = useRef<HTMLDivElement>(null)
const reactFlow = useReactFlow()
const store = useStoreApi()
const { startCursorTracking, stopCursorTracking, onlineUsers, cursors, isConnected } = useCollaboration(appId || '', store)
const [myUserId, setMyUserId] = useState<string | null>(null)
useEffect(() => {
if (isConnected)
setMyUserId('current-user')
}, [isConnected])
const filteredCursors = Object.fromEntries(
Object.entries(cursors).filter(([userId]) => userId !== myUserId),
)
useEffect(() => {
if (containerRef.current)
startCursorTracking(containerRef as React.RefObject<HTMLElement>, reactFlow)
return () => {
stopCursorTracking()
}
}, [startCursorTracking, stopCursorTracking, reactFlow])
const handleWorkflowDataUpdate = useCallback((payload: any) => {
const {
@ -38,7 +75,33 @@ const WorkflowMain = ({
if (features && featuresStore) {
const { setFeatures } = featuresStore.getState()
setFeatures(features)
const transformedFeatures: FeaturesData = {
file: {
image: {
enabled: !!features.file_upload?.image?.enabled,
number_limits: features.file_upload?.image?.number_limits || 3,
transfer_methods: features.file_upload?.image?.transfer_methods || ['local_file', 'remote_url'],
},
enabled: !!(features.file_upload?.enabled || features.file_upload?.image?.enabled),
allowed_file_types: features.file_upload?.allowed_file_types || [SupportUploadFileTypes.image],
allowed_file_extensions: features.file_upload?.allowed_file_extensions || FILE_EXTS[SupportUploadFileTypes.image].map(ext => `.${ext}`),
allowed_file_upload_methods: features.file_upload?.allowed_file_upload_methods || features.file_upload?.image?.transfer_methods || ['local_file', 'remote_url'],
number_limits: features.file_upload?.number_limits || features.file_upload?.image?.number_limits || 3,
},
opening: {
enabled: !!features.opening_statement,
opening_statement: features.opening_statement,
suggested_questions: features.suggested_questions,
},
suggested: features.suggested_questions_after_answer || { enabled: false },
speech2text: features.speech_to_text || { enabled: false },
text2speech: features.text_to_speech || { enabled: false },
citation: features.retriever_resource || { enabled: false },
moderation: features.sensitive_word_avoidance || { enabled: false },
annotationReply: features.annotation_reply || { enabled: false },
}
setFeatures(transformedFeatures)
}
if (conversation_variables) {
const { setConversationVariables } = workflowStore.getState()
@ -55,6 +118,7 @@ const WorkflowMain = ({
syncWorkflowDraftWhenPageClose,
} = useNodesSyncDraft()
const { handleRefreshWorkflowDraft } = useWorkflowRefreshDraft()
const { handleUpdateWorkflowCanvas } = useWorkflowUpdate()
const {
handleBackupDraft,
handleLoadBackupDraft,
@ -62,6 +126,63 @@ const WorkflowMain = ({
handleRun,
handleStopRun,
} = useWorkflowRun()
useEffect(() => {
if (!appId) return
const unsubscribe = collaborationManager.onVarsAndFeaturesUpdate(async (update: any) => {
try {
const response = await fetchWorkflowDraft(`/apps/${appId}/workflows/draft`)
handleWorkflowDataUpdate(response)
}
catch (error) {
console.error('workflow vars and features update failed:', error)
}
})
return unsubscribe
}, [appId, handleWorkflowDataUpdate])
// Listen for workflow updates from other users
useEffect(() => {
if (!appId) return
const unsubscribe = collaborationManager.onWorkflowUpdate(async () => {
console.log('Received workflow update from collaborator, fetching latest workflow data')
try {
const response = await fetchWorkflowDraft(`/apps/${appId}/workflows/draft`)
// Handle features, variables etc.
handleWorkflowDataUpdate(response)
// Update workflow canvas (nodes, edges, viewport)
if (response.graph) {
handleUpdateWorkflowCanvas({
nodes: response.graph.nodes || [],
edges: response.graph.edges || [],
viewport: response.graph.viewport || { x: 0, y: 0, zoom: 1 },
})
}
}
catch (error) {
console.error('Failed to fetch updated workflow:', error)
}
})
return unsubscribe
}, [appId, handleWorkflowDataUpdate, handleUpdateWorkflowCanvas])
// Listen for sync requests from other users (only processed by leader)
useEffect(() => {
if (!appId) return
const unsubscribe = collaborationManager.onSyncRequest(() => {
console.log('Leader received sync request, performing sync')
doSyncWorkflowDraft()
})
return unsubscribe
}, [appId, doSyncWorkflowDraft])
const {
handleStartWorkflowRun,
handleWorkflowStartRunInChatflow,
@ -75,6 +196,7 @@ const WorkflowMain = ({
} = useDSL()
const configsMap = useConfigsMap()
const { fetchInspectVars } = useSetWorkflowVarsWithValue({
...configsMap,
})
@ -164,15 +286,23 @@ const WorkflowMain = ({
])
return (
<WorkflowWithInnerContext
nodes={nodes}
edges={edges}
viewport={viewport}
onWorkflowDataUpdate={handleWorkflowDataUpdate}
hooksStore={hooksStore as any}
<div
ref={containerRef}
className="relative h-full w-full"
>
<WorkflowChildren />
</WorkflowWithInnerContext>
<WorkflowWithInnerContext
nodes={nodes}
edges={edges}
viewport={viewport}
onWorkflowDataUpdate={handleWorkflowDataUpdate}
hooksStore={hooksStore as any}
cursors={filteredCursors}
myUserId={myUserId}
onlineUsers={onlineUsers}
>
<WorkflowChildren />
</WorkflowWithInnerContext>
</div>
)
}

View File

@ -7,6 +7,7 @@ import { useStore } from '@/app/components/workflow/store'
import {
useIsChatMode,
} from '../hooks'
import CommentsPanel from '@/app/components/workflow/panel/comments-panel'
import { useStore as useAppStore } from '@/app/components/app/store'
import type { PanelProps } from '@/app/components/workflow/panel'
import Panel from '@/app/components/workflow/panel'
@ -67,6 +68,7 @@ const WorkflowPanelOnRight = () => {
const showDebugAndPreviewPanel = useStore(s => s.showDebugAndPreviewPanel)
const showChatVariablePanel = useStore(s => s.showChatVariablePanel)
const showGlobalVariablePanel = useStore(s => s.showGlobalVariablePanel)
const controlMode = useStore(s => s.controlMode)
return (
<>
@ -100,6 +102,7 @@ const WorkflowPanelOnRight = () => {
<GlobalVariablePanel />
)
}
{controlMode === 'comment' && <CommentsPanel />}
</>
)
}

View File

@ -13,6 +13,7 @@ import { syncWorkflowDraft } from '@/service/workflow'
import { useFeaturesStore } from '@/app/components/base/features/hooks'
import { API_PREFIX } from '@/config'
import { useWorkflowRefreshDraft } from '.'
import { collaborationManager } from '@/app/components/workflow/collaboration/core/collaboration-manager'
export const useNodesSyncDraft = () => {
const store = useStoreApi()
@ -85,6 +86,7 @@ export const useNodesSyncDraft = () => {
environment_variables: environmentVariables,
conversation_variables: conversationVariables,
hash: syncWorkflowDraftHash,
_is_collaborative: true,
},
}
}
@ -93,9 +95,20 @@ export const useNodesSyncDraft = () => {
const syncWorkflowDraftWhenPageClose = useCallback(() => {
if (getNodesReadOnly())
return
// Check leader status at sync time
const currentIsLeader = collaborationManager.getIsLeader()
// Only allow leader to sync data
if (!currentIsLeader) {
console.log('Not leader, skipping sync on page close')
return
}
const postParams = getPostParams()
if (postParams) {
console.log('Leader syncing workflow draft on page close')
navigator.sendBeacon(
`${API_PREFIX}/apps/${params.appId}/workflows/draft?_token=${localStorage.getItem('console_token')}`,
JSON.stringify(postParams.params),
@ -110,9 +123,23 @@ export const useNodesSyncDraft = () => {
onError?: () => void
onSettled?: () => void
},
forceUpload?: boolean,
) => {
if (getNodesReadOnly())
return
// Check leader status at sync time
const currentIsLeader = collaborationManager.getIsLeader()
// If not leader and not forcing upload, request the leader to sync
if (!currentIsLeader && !forceUpload) {
console.log('Not leader, requesting leader to sync workflow draft')
collaborationManager.emitSyncRequest()
callback?.onSettled?.()
return
}
console.log(forceUpload ? 'Force uploading workflow draft' : 'Leader performing workflow draft sync')
const postParams = getPostParams()
if (postParams) {
@ -120,17 +147,30 @@ export const useNodesSyncDraft = () => {
setSyncWorkflowDraftHash,
setDraftUpdatedAt,
} = workflowStore.getState()
// Add force_upload parameter if needed
const finalParams = {
...postParams.params,
...(forceUpload && { force_upload: true }),
}
try {
const res = await syncWorkflowDraft(postParams)
const res = await syncWorkflowDraft({
url: postParams.url,
params: finalParams,
})
setSyncWorkflowDraftHash(res.hash)
setDraftUpdatedAt(res.updated_at)
callback?.onSuccess?.()
}
catch (error: any) {
console.error('Leader failed to sync workflow draft:', error)
if (error && error.json && !error.bodyUsed) {
error.json().then((err: any) => {
if (err.code === 'draft_workflow_not_sync' && !notRefreshWhenSyncError)
if (err.code === 'draft_workflow_not_sync' && !notRefreshWhenSyncError) {
console.error('draft_workflow_not_sync', err)
handleRefreshWorkflowDraft()
}
})
}
callback?.onError?.()

View File

@ -25,6 +25,7 @@ import {
import type { InjectWorkflowStoreSliceFn } from '@/app/components/workflow/store'
import { createWorkflowSlice } from './store/workflow/workflow-slice'
import WorkflowAppMain from './components/workflow-main'
import { collaborationManager } from '@/app/components/workflow/collaboration/core/collaboration-manager'
const WorkflowAppWithAdditionalContext = () => {
const {
@ -35,15 +36,20 @@ const WorkflowAppWithAdditionalContext = () => {
const { isLoadingCurrentWorkspace, currentWorkspace } = useAppContext()
const nodesData = useMemo(() => {
if (data)
return initialNodes(data.graph.nodes, data.graph.edges)
if (data) {
const processedNodes = initialNodes(data.graph.nodes, data.graph.edges)
collaborationManager.setNodes([], processedNodes)
return processedNodes
}
return []
}, [data])
const edgesData = useMemo(() => {
if (data)
return initialEdges(data.graph.edges, data.graph.nodes)
const edgesData = useMemo(() => {
if (data) {
const processedEdges = initialEdges(data.graph.edges, data.graph.nodes)
collaborationManager.setEdges([], processedEdges)
return processedEdges
}
return []
}, [data])

View File

@ -4,7 +4,6 @@ import {
import produce from 'immer'
import {
useReactFlow,
useStoreApi,
useViewport,
} from 'reactflow'
import { useEventListener } from 'ahooks'
@ -19,9 +18,9 @@ import CustomNode from './nodes'
import CustomNoteNode from './note-node'
import { CUSTOM_NOTE_NODE } from './note-node/constants'
import { BlockEnum } from './types'
import { useCollaborativeWorkflow } from '@/app/components/workflow/hooks/use-collaborative-workflow'
const CandidateNode = () => {
const store = useStoreApi()
const reactflow = useReactFlow()
const workflowStore = useWorkflowStore()
const candidateNode = useStore(s => s.candidateNode)
@ -29,18 +28,15 @@ const CandidateNode = () => {
const { zoom } = useViewport()
const { handleNodeSelect } = useNodesInteractions()
const { saveStateToHistory } = useWorkflowHistory()
const collaborativeWorkflow = useCollaborativeWorkflow()
useEventListener('click', (e) => {
const { candidateNode, mousePosition } = workflowStore.getState()
if (candidateNode) {
e.preventDefault()
const {
getNodes,
setNodes,
} = store.getState()
const { nodes, setNodes } = collaborativeWorkflow.getState()
const { screenToFlowPosition } = reactflow
const nodes = getNodes()
const { x, y } = screenToFlowPosition({ x: mousePosition.pageX, y: mousePosition.pageY })
const newNodes = produce(nodes, (draft) => {
draft.push({

View File

@ -0,0 +1,78 @@
import type { FC } from 'react'
import { useViewport } from 'reactflow'
import type { CursorPosition, OnlineUser } from '@/app/components/workflow/collaboration/types'
import { getUserColor } from '../utils/user-color'
type UserCursorsProps = {
  cursors: Record<string, CursorPosition>
  myUserId: string | null
  onlineUsers: OnlineUser[]
}
/**
 * Overlays the live cursors of other collaborators on the canvas.
 * World-space cursor coordinates are projected into screen space using the
 * current ReactFlow viewport; the local user's own cursor is never drawn.
 */
const UserCursors: FC<UserCursorsProps> = ({
  cursors,
  myUserId,
  onlineUsers,
}) => {
  const viewport = useViewport()
  // Project a world-space point into screen space (apply zoom, then pan).
  const toScreen = ({ x, y }: CursorPosition) => ({
    x: x * viewport.zoom + viewport.x,
    y: y * viewport.zoom + viewport.y,
  })
  return (
    <>
      {Object.entries(cursors || {}).map(([userId, cursor]) => {
        if (userId === myUserId)
          return null
        const owner = onlineUsers.find(user => user.user_id === userId)
        // Fall back to the id's last 4 chars when the roster has no name yet.
        const label = owner?.username || `User ${userId.slice(-4)}`
        const color = getUserColor(userId)
        const position = toScreen(cursor)
        return (
          <div
            key={userId}
            className="pointer-events-none absolute z-[8] transition-all duration-150 ease-out"
            style={{
              left: position.x,
              top: position.y,
            }}
          >
            {/* Cursor arrow tinted with the user's color */}
            <svg
              width="20"
              height="20"
              viewBox="0 0 20 20"
              fill="none"
              xmlns="http://www.w3.org/2000/svg"
              className="drop-shadow-md"
            >
              <path
                d="M5 3L5 15L8 11.5L11 16L13 15L10 10.5L14 10.5L5 3Z"
                fill={color}
                stroke="white"
                strokeWidth="1.5"
                strokeLinejoin="round"
              />
            </svg>
            {/* Name tag anchored below-right of the arrow */}
            <div
              className="absolute left-4 top-4 max-w-[120px] overflow-hidden text-ellipsis whitespace-nowrap rounded px-1.5 py-0.5 text-[11px] font-medium text-white shadow-sm"
              style={{
                backgroundColor: color,
              }}
            >
              {label}
            </div>
          </div>
        )
      })}
    </>
  )
}
export default UserCursors

View File

@ -0,0 +1,924 @@
import { LoroDoc, UndoManager } from 'loro-crdt'
import { isEqual } from 'lodash-es'
import { webSocketClient } from './websocket-manager'
import { CRDTProvider } from './crdt-provider'
import { EventEmitter } from './event-emitter'
import type { Edge, Node } from '../../types'
import type {
CollaborationState,
CursorPosition,
NodePanelPresenceMap,
NodePanelPresenceUser,
OnlineUser,
} from '../types/collaboration'
type NodePanelPresenceEventData = {
nodeId: string
action: 'open' | 'close'
user: NodePanelPresenceUser
clientId: string
timestamp?: number
}
export class CollaborationManager {
private doc: LoroDoc | null = null
private undoManager: UndoManager | null = null
private provider: CRDTProvider | null = null
private nodesMap: any = null
private edgesMap: any = null
private eventEmitter = new EventEmitter()
private currentAppId: string | null = null
private reactFlowStore: any = null
private isLeader = false
private leaderId: string | null = null
private cursors: Record<string, CursorPosition> = {}
private nodePanelPresence: NodePanelPresenceMap = {}
private activeConnections = new Set<string>()
private isUndoRedoInProgress = false
// Returns a two-level shallow copy of the presence map so listeners cannot
// mutate internal state through the emitted payload.
private getNodePanelPresenceSnapshot(): NodePanelPresenceMap {
  return Object.fromEntries(
    Object.entries(this.nodePanelPresence).map(([nodeId, viewers]) => [nodeId, { ...viewers }]),
  )
}
// Applies a node-panel open/close presence event to the local presence map
// and notifies subscribers with a defensive snapshot.
private applyNodePanelPresenceUpdate(update: NodePanelPresenceEventData): void {
  const { nodeId, action, clientId, user, timestamp } = update
  if (action === 'open') {
    // ensure a client only appears on a single node at a time
    Object.entries(this.nodePanelPresence).forEach(([id, viewers]) => {
      if (viewers[clientId]) {
        delete viewers[clientId]
        // Drop nodes with no remaining viewers so the map stays compact.
        if (Object.keys(viewers).length === 0)
          delete this.nodePanelPresence[id]
      }
    })
    if (!this.nodePanelPresence[nodeId])
      this.nodePanelPresence[nodeId] = {}
    this.nodePanelPresence[nodeId][clientId] = {
      ...user,
      clientId,
      timestamp: timestamp || Date.now(),
    }
  }
  else {
    // 'close': remove this client from the node it was viewing, if any.
    const viewers = this.nodePanelPresence[nodeId]
    if (viewers) {
      delete viewers[clientId]
      if (Object.keys(viewers).length === 0)
        delete this.nodePanelPresence[nodeId]
    }
  }
  this.eventEmitter.emit('nodePanelPresence', this.getNodePanelPresenceSnapshot())
}
// Prunes presence entries whose client AND user are both no longer active.
// A viewer survives if either its socket client id or its user id is still
// present in the given active sets.
private cleanupNodePanelPresence(activeClientIds: Set<string>, activeUserIds: Set<string>): void {
  let hasChanges = false
  Object.entries(this.nodePanelPresence).forEach(([nodeId, viewers]) => {
    Object.keys(viewers).forEach((clientId) => {
      const viewer = viewers[clientId]
      const clientActive = activeClientIds.has(clientId)
      const userActive = viewer?.userId ? activeUserIds.has(viewer.userId) : false
      if (!clientActive && !userActive) {
        delete viewers[clientId]
        hasChanges = true
      }
    })
    // Drop nodes with no remaining viewers so the map stays compact.
    if (Object.keys(viewers).length === 0)
      delete this.nodePanelPresence[nodeId]
  })
  // Notify subscribers only when something was actually removed.
  if (hasChanges)
    this.eventEmitter.emit('nodePanelPresence', this.getNodePanelPresenceSnapshot())
}
// Convenience entry point: connects immediately when a ReactFlow store is
// available, otherwise defers wiring until connect() is called with one.
init = (appId: string, reactFlowStore: any): void => {
  if (reactFlowStore) {
    this.connect(appId, reactFlowStore)
    return
  }
  console.warn('CollaborationManager.init called without reactFlowStore, deferring to connect()')
}
// Mirrors a React-side nodes change into the CRDT doc and commits it.
// No-ops before connect() (no doc yet) and while an undo/redo replay is
// applying state, which would otherwise re-record the replayed change.
setNodes = (oldNodes: Node[], newNodes: Node[]): void => {
  if (!this.doc) return
  // Don't track operations during undo/redo to prevent loops
  if (this.isUndoRedoInProgress) {
    console.log('Skipping setNodes during undo/redo')
    return
  }
  console.log('Setting nodes with tracking')
  this.syncNodes(oldNodes, newNodes)
  this.doc.commit()
}
// Mirrors a React-side edges change into the CRDT doc and commits it.
// Same guards as setNodes: requires a live doc and skips undo/redo replays.
setEdges = (oldEdges: Edge[], newEdges: Edge[]): void => {
  if (!this.doc) return
  // Don't track operations during undo/redo to prevent loops
  if (this.isUndoRedoInProgress) {
    console.log('Skipping setEdges during undo/redo')
    return
  }
  console.log('Setting edges with tracking')
  this.syncEdges(oldEdges, newEdges)
  this.doc.commit()
}
// Public teardown hook; delegates to disconnect() (called without a
// connection id, so teardown only happens when no connections remain).
destroy = (): void => {
  this.disconnect()
}
// Establishes (or reuses) the collaboration session for an app: websocket,
// Loro CRDT doc with nodes/edges maps, collaborative UndoManager, and the
// CRDT provider. Returns an opaque connection id that must be handed back to
// disconnect() so the session is only torn down by the last consumer.
async connect(appId: string, reactFlowStore?: any): Promise<string> {
  const connectionId = Math.random().toString(36).substring(2, 11)
  this.activeConnections.add(connectionId)
  if (this.currentAppId === appId && this.doc) {
    // Already connected to the same app, only update store if provided and we don't have one
    if (reactFlowStore && !this.reactFlowStore)
      this.reactFlowStore = reactFlowStore
    return connectionId
  }
  // Only disconnect if switching to a different app
  if (this.currentAppId && this.currentAppId !== appId)
    this.forceDisconnect()
  this.currentAppId = appId
  // Only set store if provided
  if (reactFlowStore)
    this.reactFlowStore = reactFlowStore
  const socket = webSocketClient.connect(appId)
  // Setup event listeners BEFORE any other operations
  this.setupSocketEventListeners(socket)
  this.doc = new LoroDoc()
  this.nodesMap = this.doc.getMap('nodes')
  this.edgesMap = this.doc.getMap('edges')
  // Initialize UndoManager for collaborative undo/redo
  this.undoManager = new UndoManager(this.doc, {
    maxUndoSteps: 100,
    mergeInterval: 500, // Merge operations within 500ms
    excludeOriginPrefixes: [], // Don't exclude anything - let UndoManager track all local operations
    onPush: (isUndo, range, event) => {
      console.log('UndoManager onPush:', { isUndo, range, event })
      // Store current selection state when an operation is pushed
      const selectedNode = this.reactFlowStore?.getState().getNodes().find((n: Node) => n.data?.selected)
      // Emit event to update UI button states when new operation is pushed
      // (deferred a tick so the UndoManager has finished recording).
      setTimeout(() => {
        this.eventEmitter.emit('undoRedoStateChange', {
          canUndo: this.undoManager?.canUndo() || false,
          canRedo: this.undoManager?.canRedo() || false,
        })
      }, 0)
      return {
        value: {
          selectedNodeId: selectedNode?.id || null,
          timestamp: Date.now(),
        },
        cursors: [],
      }
    },
    onPop: (isUndo, value, counterRange) => {
      console.log('UndoManager onPop:', { isUndo, value, counterRange })
      // Restore selection state when undoing/redoing
      if (value?.value && typeof value.value === 'object' && 'selectedNodeId' in value.value && this.reactFlowStore) {
        const selectedNodeId = (value.value as any).selectedNodeId
        if (selectedNodeId) {
          const { setNodes } = this.reactFlowStore.getState()
          const nodes = this.reactFlowStore.getState().getNodes()
          // Re-select the node that was selected when the op was recorded.
          const newNodes = nodes.map((n: Node) => ({
            ...n,
            data: {
              ...n.data,
              selected: n.id === selectedNodeId,
            },
          }))
          setNodes(newNodes)
        }
      }
    },
  })
  this.provider = new CRDTProvider(socket, this.doc)
  this.setupSubscriptions()
  // Force user_connect if already connected
  if (socket.connected)
    socket.emit('user_connect', { workflow_id: appId })
  return connectionId
}
// Releases one logical connection. The session is only fully torn down once
// every connection id handed out by connect() has been released.
disconnect = (connectionId?: string): void => {
  if (connectionId)
    this.activeConnections.delete(connectionId)
  const stillInUse = this.activeConnections.size > 0
  if (!stillInUse)
    this.forceDisconnect()
}
// Unconditional teardown: closes the socket, destroys the CRDT provider and
// document, and clears all cached session state and listeners.
private forceDisconnect = (): void => {
  if (this.currentAppId)
    webSocketClient.disconnect(this.currentAppId)
  this.provider?.destroy()
  this.undoManager = null
  this.doc = null
  this.provider = null
  this.nodesMap = null
  this.edgesMap = null
  this.currentAppId = null
  this.reactFlowStore = null
  this.cursors = {}
  this.nodePanelPresence = {}
  this.isUndoRedoInProgress = false
  // Only reset leader status when actually disconnecting
  const wasLeader = this.isLeader
  this.isLeader = false
  this.leaderId = null
  // Announce the leadership loss before listeners are removed below.
  if (wasLeader)
    this.eventEmitter.emit('leaderChange', false)
  this.activeConnections.clear()
  this.eventEmitter.removeAllListeners()
}
// Whether the websocket for the current app is live.
isConnected(): boolean {
  if (!this.currentAppId)
    return false
  return webSocketClient.isConnected(this.currentAppId)
}
// Snapshot of the CRDT nodes map (empty before connect()).
getNodes(): Node[] {
  if (!this.nodesMap)
    return []
  return Array.from(this.nodesMap.values())
}
// Snapshot of the CRDT edges map (empty before connect()).
getEdges(): Edge[] {
  if (!this.edgesMap)
    return []
  return Array.from(this.edgesMap.values())
}
// Broadcasts the local cursor's world-space position to other clients.
emitCursorMove(position: CursorPosition): void {
  if (!this.currentAppId || !webSocketClient.isConnected(this.currentAppId))
    return
  const socket = webSocketClient.getSocket(this.currentAppId)
  if (!socket)
    return
  socket.emit('collaboration_event', {
    type: 'mouseMove',
    userId: socket.id,
    data: { x: position.x, y: position.y },
    timestamp: Date.now(),
  })
}
// Asks the current leader client to persist the draft on our behalf.
emitSyncRequest(): void {
  if (!this.currentAppId || !webSocketClient.isConnected(this.currentAppId))
    return
  const socket = webSocketClient.getSocket(this.currentAppId)
  if (!socket)
    return
  console.log('Emitting sync request to leader')
  socket.emit('collaboration_event', {
    type: 'syncRequest',
    data: { timestamp: Date.now() },
    timestamp: Date.now(),
  })
}
// Notifies other clients that the workflow changed so they can refetch it.
emitWorkflowUpdate(appId: string): void {
  if (!this.currentAppId || !webSocketClient.isConnected(this.currentAppId))
    return
  const socket = webSocketClient.getSocket(this.currentAppId)
  if (!socket)
    return
  console.log('Emitting Workflow update event')
  socket.emit('collaboration_event', {
    type: 'workflowUpdate',
    data: { appId, timestamp: Date.now() },
    timestamp: Date.now(),
  })
}
// Broadcasts that the local client opened/closed a node's config panel, and
// applies the same update locally so the sender's own UI reflects it
// immediately without waiting for a round trip.
emitNodePanelPresence(nodeId: string, isOpen: boolean, user: NodePanelPresenceUser): void {
  if (!this.currentAppId || !webSocketClient.isConnected(this.currentAppId)) return
  const socket = webSocketClient.getSocket(this.currentAppId)
  // Require a socket, a target node and an identified user before emitting.
  if (!socket || !nodeId || !user?.userId) return
  const payload: NodePanelPresenceEventData = {
    nodeId,
    action: isOpen ? 'open' : 'close',
    user,
    clientId: socket.id as string,
    timestamp: Date.now(),
  }
  socket.emit('collaboration_event', {
    type: 'nodePanelPresence',
    data: payload,
    timestamp: payload.timestamp,
  })
  this.applyNodePanelPresenceUpdate(payload)
}
// Subscribe to leader-initiated sync requests. Returns an unsubscribe function
// (as do all the on* registration helpers below).
onSyncRequest(callback: () => void): () => void {
return this.eventEmitter.on('syncRequest', callback)
}
// Subscribe to partial connection-state changes (isConnected, error, ...).
onStateChange(callback: (state: Partial<CollaborationState>) => void): () => void {
return this.eventEmitter.on('stateChange', callback)
}
// Subscribe to remote-cursor position snapshots keyed by user id.
onCursorUpdate(callback: (cursors: Record<string, CursorPosition>) => void): () => void {
return this.eventEmitter.on('cursors', callback)
}
// Subscribe to online-user roster updates.
onOnlineUsersUpdate(callback: (users: OnlineUser[]) => void): () => void {
return this.eventEmitter.on('onlineUsers', callback)
}
// Subscribe to workflow-changed notifications from other collaborators.
onWorkflowUpdate(callback: (update: { appId: string; timestamp: number }) => void): () => void {
return this.eventEmitter.on('workflowUpdate', callback)
}
// Subscribe to variables/features change notifications.
onVarsAndFeaturesUpdate(callback: (update: any) => void): () => void {
return this.eventEmitter.on('varsAndFeaturesUpdate', callback)
}
// Subscribe to app-state change notifications.
onAppStateUpdate(callback: (update: any) => void): () => void {
return this.eventEmitter.on('appStateUpdate', callback)
}
// Subscribe to MCP server change notifications.
onMcpServerUpdate(callback: (update: any) => void): () => void {
return this.eventEmitter.on('mcpServerUpdate', callback)
}
// Subscribe to node-panel presence changes. Unlike the other helpers, the
// callback is invoked immediately with the current snapshot so new
// subscribers do not have to wait for the next remote event.
onNodePanelPresenceUpdate(callback: (presence: NodePanelPresenceMap) => void): () => void {
const off = this.eventEmitter.on('nodePanelPresence', callback)
callback(this.getNodePanelPresenceSnapshot())
return off
}
// Subscribe to leader-election changes for this client.
onLeaderChange(callback: (isLeader: boolean) => void): () => void {
return this.eventEmitter.on('leaderChange', callback)
}
// Subscribe to comment-list change notifications.
onCommentsUpdate(callback: (update: { appId: string; timestamp: number }) => void): () => void {
return this.eventEmitter.on('commentsUpdate', callback)
}
emitCommentsUpdate(appId: string): void {
if (!this.currentAppId || !webSocketClient.isConnected(this.currentAppId)) return
const socket = webSocketClient.getSocket(this.currentAppId)
if (socket) {
console.log('Emitting Comments update event')
socket.emit('collaboration_event', {
type: 'commentsUpdate',
data: { appId, timestamp: Date.now() },
timestamp: Date.now(),
})
}
}
// Subscribe to undo/redo availability changes (emitted after an undo/redo
// has been applied to the CRDT).
onUndoRedoStateChange(callback: (state: { canUndo: boolean; canRedo: boolean }) => void): () => void {
return this.eventEmitter.on('undoRedoStateChange', callback)
}
// The id of the current collaboration leader as last reported by the server,
// or null when unknown.
getLeaderId(): string | null {
return this.leaderId
}
// Whether this client is currently the elected leader.
getIsLeader(): boolean {
return this.isLeader
}
// Collaborative undo/redo methods
undo(): boolean {
if (!this.undoManager) {
console.log('UndoManager not initialized')
return false
}
const canUndo = this.undoManager.canUndo()
console.log('Can undo:', canUndo)
if (canUndo) {
this.isUndoRedoInProgress = true
const result = this.undoManager.undo()
// After undo, manually update React state from CRDT without triggering collaboration
if (result && this.reactFlowStore) {
requestAnimationFrame(() => {
// Get ReactFlow's native setters, not the collaborative ones
const state = this.reactFlowStore.getState()
const updatedNodes = Array.from(this.nodesMap.values())
const updatedEdges = Array.from(this.edgesMap.values())
console.log('Manually updating React state after undo')
// Call ReactFlow's native setters directly to avoid triggering collaboration
state.setNodes(updatedNodes)
state.setEdges(updatedEdges)
this.isUndoRedoInProgress = false
// Emit event to update UI button states
this.eventEmitter.emit('undoRedoStateChange', {
canUndo: this.undoManager?.canUndo() || false,
canRedo: this.undoManager?.canRedo() || false,
})
})
}
else {
this.isUndoRedoInProgress = false
}
console.log('Undo result:', result)
return result
}
return false
}
redo(): boolean {
if (!this.undoManager) {
console.log('RedoManager not initialized')
return false
}
const canRedo = this.undoManager.canRedo()
console.log('Can redo:', canRedo)
if (canRedo) {
this.isUndoRedoInProgress = true
const result = this.undoManager.redo()
// After redo, manually update React state from CRDT without triggering collaboration
if (result && this.reactFlowStore) {
requestAnimationFrame(() => {
// Get ReactFlow's native setters, not the collaborative ones
const state = this.reactFlowStore.getState()
const updatedNodes = Array.from(this.nodesMap.values())
const updatedEdges = Array.from(this.edgesMap.values())
console.log('Manually updating React state after redo')
// Call ReactFlow's native setters directly to avoid triggering collaboration
state.setNodes(updatedNodes)
state.setEdges(updatedEdges)
this.isUndoRedoInProgress = false
// Emit event to update UI button states
this.eventEmitter.emit('undoRedoStateChange', {
canUndo: this.undoManager?.canUndo() || false,
canRedo: this.undoManager?.canRedo() || false,
})
})
}
else {
this.isUndoRedoInProgress = false
}
console.log('Redo result:', result)
return result
}
return false
}
/** Whether a collaborative undo step is currently available. */
canUndo(): boolean {
  return this.undoManager ? this.undoManager.canUndo() : false
}

/** Whether a collaborative redo step is currently available. */
canRedo(): boolean {
  return this.undoManager ? this.undoManager.canRedo() : false
}

/** Drop all recorded undo/redo history (e.g. after a full graph reload). */
clearUndoStack(): void {
  this.undoManager?.clear()
}
debugLeaderStatus(): void {
console.log('=== Leader Status Debug ===')
console.log('Current leader status:', this.isLeader)
console.log('Current leader ID:', this.leaderId)
console.log('Active connections:', this.activeConnections.size)
console.log('Connected:', this.isConnected())
console.log('Current app ID:', this.currentAppId)
console.log('Has ReactFlow store:', !!this.reactFlowStore)
console.log('========================')
}
// Mirror the React node array into the CRDT nodes map using fine-grained,
// per-property diffing so concurrent editors only conflict on fields they
// actually changed. Data keys starting with '_' are treated as local-only
// and never synced — except '_children', which is explicitly allow-listed —
// and 'selected' is per-user UI state that is never synced.
private syncNodes(oldNodes: Node[], newNodes: Node[]): void {
if (!this.nodesMap || !this.doc) return
const oldNodesMap = new Map(oldNodes.map(node => [node.id, node]))
const newNodesMap = new Map(newNodes.map(node => [node.id, node]))
const syncDataAllowList = new Set(['_children'])
const shouldSyncDataKey = (key: string) => (syncDataAllowList.has(key) || !key.startsWith('_')) && key !== 'selected'
// Delete removed nodes
oldNodes.forEach((oldNode) => {
if (!newNodesMap.has(oldNode.id))
this.nodesMap.delete(oldNode.id)
})
// Add or update nodes with fine-grained sync for data properties.
// Helper: copy ReactFlow's optional node props onto `target`, deleting any
// that became undefined, and deep-cloning object values via JSON so the
// CRDT never aliases React state.
const copyOptionalNodeProps = (source: Node, target: any) => {
const optionalProps: Array<keyof Node | keyof any> = [
'parentId',
'positionAbsolute',
'extent',
'zIndex',
'draggable',
'selectable',
'dragHandle',
'dragging',
'connectable',
'expandParent',
'focusable',
'hidden',
'style',
'className',
'ariaLabel',
'markerStart',
'markerEnd',
'resizing',
'deletable',
]
optionalProps.forEach((prop) => {
const value = (source as any)[prop]
if (value === undefined) {
// Prop was cleared locally — remove it from the synced copy too.
if (prop in target)
delete target[prop]
return
}
if (value !== null && typeof value === 'object')
target[prop as string] = JSON.parse(JSON.stringify(value))
else
target[prop as string] = value
})
}
newNodes.forEach((newNode) => {
const oldNode = oldNodesMap.get(newNode.id)
if (!oldNode) {
// New node - create as nested structure
const nodeData: any = {
id: newNode.id,
type: newNode.type,
position: { ...newNode.position },
width: newNode.width,
height: newNode.height,
sourcePosition: newNode.sourcePosition,
targetPosition: newNode.targetPosition,
data: {},
}
copyOptionalNodeProps(newNode, nodeData)
// Clone data properties, excluding private ones
Object.entries(newNode.data).forEach(([key, value]) => {
if (shouldSyncDataKey(key) && value !== undefined)
nodeData.data[key] = JSON.parse(JSON.stringify(value))
})
this.nodesMap.set(newNode.id, nodeData)
}
else {
// Get existing node from CRDT
const existingNode = this.nodesMap.get(newNode.id)
if (existingNode) {
// Create a deep copy to modify
const updatedNode = JSON.parse(JSON.stringify(existingNode))
// Update position only if changed
if (oldNode.position.x !== newNode.position.x || oldNode.position.y !== newNode.position.y)
updatedNode.position = { ...newNode.position }
// Update dimensions only if changed
if (oldNode.width !== newNode.width)
updatedNode.width = newNode.width
if (oldNode.height !== newNode.height)
updatedNode.height = newNode.height
// Ensure optional node props stay in sync
copyOptionalNodeProps(newNode, updatedNode)
// Ensure data object exists
if (!updatedNode.data)
updatedNode.data = {}
// Fine-grained update of data properties: diff against the OLD React
// node (not the CRDT copy) so remote-only changes are not clobbered.
const oldData = oldNode.data || {}
const newData = newNode.data || {}
// Only update changed properties in data
Object.entries(newData).forEach(([key, value]) => {
if (shouldSyncDataKey(key)) {
const oldValue = (oldData as any)[key]
if (!isEqual(oldValue, value))
updatedNode.data[key] = JSON.parse(JSON.stringify(value))
}
})
// Remove deleted properties from data
Object.keys(oldData).forEach((key) => {
if (shouldSyncDataKey(key) && !(key in newData))
delete updatedNode.data[key]
})
// Only update in CRDT if something actually changed (avoids no-op
// CRDT ops that would still broadcast to peers).
if (!isEqual(existingNode, updatedNode))
this.nodesMap.set(newNode.id, updatedNode)
}
else {
// Node exists locally but not in CRDT yet — same creation path as a
// brand-new node.
const nodeData: any = {
id: newNode.id,
type: newNode.type,
position: { ...newNode.position },
width: newNode.width,
height: newNode.height,
sourcePosition: newNode.sourcePosition,
targetPosition: newNode.targetPosition,
data: {},
}
copyOptionalNodeProps(newNode, nodeData)
Object.entries(newNode.data).forEach(([key, value]) => {
if (shouldSyncDataKey(key) && value !== undefined)
nodeData.data[key] = JSON.parse(JSON.stringify(value))
})
this.nodesMap.set(newNode.id, nodeData)
}
}
})
}
private syncEdges(oldEdges: Edge[], newEdges: Edge[]): void {
if (!this.edgesMap) return
const oldEdgesMap = new Map(oldEdges.map(edge => [edge.id, edge]))
const newEdgesMap = new Map(newEdges.map(edge => [edge.id, edge]))
oldEdges.forEach((oldEdge) => {
if (!newEdgesMap.has(oldEdge.id))
this.edgesMap.delete(oldEdge.id)
})
newEdges.forEach((newEdge) => {
const oldEdge = oldEdgesMap.get(newEdge.id)
if (!oldEdge) {
const clonedEdge = JSON.parse(JSON.stringify(newEdge))
this.edgesMap.set(newEdge.id, clonedEdge)
}
else if (!isEqual(oldEdge, newEdge)) {
const clonedEdge = JSON.parse(JSON.stringify(newEdge))
this.edgesMap.set(newEdge.id, clonedEdge)
}
})
}
// Subscribe to CRDT map changes and push imported (remote) updates into
// ReactFlow's native state setters. Only events with by === 'import' are
// handled — locally-originated changes already live in React state — and
// updates are skipped entirely while an undo/redo is in flight to prevent
// feedback loops.
private setupSubscriptions(): void {
this.nodesMap?.subscribe((event: any) => {
console.log('nodesMap subscription event:', event)
if (event.by === 'import' && this.reactFlowStore) {
// Don't update React nodes during undo/redo to prevent loops
if (this.isUndoRedoInProgress) {
console.log('Skipping nodes subscription update during undo/redo')
return
}
requestAnimationFrame(() => {
// Get ReactFlow's native setters, not the collaborative ones
const state = this.reactFlowStore.getState()
const previousNodes: Node[] = state.getNodes()
// Selection is per-user UI state that is never synced (see syncNodes),
// so re-apply the local selection on top of the imported nodes.
const selectedIds = new Set(
previousNodes
.filter(node => node.data?.selected)
.map(node => node.id),
)
const updatedNodes = Array
.from(this.nodesMap.values())
.map((node: Node) => {
// Shallow-clone node and data so React state never aliases the CRDT.
const clonedNode: Node = {
...node,
data: {
...(node.data || {}),
},
}
if (selectedIds.has(clonedNode.id))
clonedNode.data.selected = true
return clonedNode
})
console.log('Updating React nodes from subscription')
// Call ReactFlow's native setter directly to avoid triggering collaboration
state.setNodes(updatedNodes)
})
}
})
this.edgesMap?.subscribe((event: any) => {
console.log('edgesMap subscription event:', event)
if (event.by === 'import' && this.reactFlowStore) {
// Don't update React edges during undo/redo to prevent loops
if (this.isUndoRedoInProgress) {
console.log('Skipping edges subscription update during undo/redo')
return
}
requestAnimationFrame(() => {
// Get ReactFlow's native setters, not the collaborative ones
const state = this.reactFlowStore.getState()
const updatedEdges = Array.from(this.edgesMap.values())
console.log('Updating React edges from subscription')
// Call ReactFlow's native setter directly to avoid triggering collaboration
state.setEdges(updatedEdges)
})
}
})
}
/**
 * Attach all collaboration listeners to a freshly-created socket: peer
 * broadcast events, the online-user roster, leader status, and the
 * connection lifecycle.
 */
private setupSocketEventListeners(socket: any): void {
  console.log('Setting up socket event listeners for collaboration')
  // Peer broadcasts fan out to the matching local event channel.
  socket.on('collaboration_update', (update: any) => {
    if (update.type === 'mouseMove') {
      // Track the latest cursor position per user and notify subscribers.
      this.cursors[update.userId] = {
        x: update.data.x,
        y: update.data.y,
        userId: update.userId,
        timestamp: update.timestamp,
      }
      this.eventEmitter.emit('cursors', { ...this.cursors })
    }
    else if (update.type === 'varsAndFeaturesUpdate') {
      console.log('Processing varsAndFeaturesUpdate event:', update)
      this.eventEmitter.emit('varsAndFeaturesUpdate', update)
    }
    else if (update.type === 'appStateUpdate') {
      console.log('Processing appStateUpdate event:', update)
      this.eventEmitter.emit('appStateUpdate', update)
    }
    else if (update.type === 'mcpServerUpdate') {
      console.log('Processing mcpServerUpdate event:', update)
      this.eventEmitter.emit('mcpServerUpdate', update)
    }
    else if (update.type === 'workflowUpdate') {
      console.log('Processing workflowUpdate event:', update)
      this.eventEmitter.emit('workflowUpdate', update.data)
    }
    else if (update.type === 'commentsUpdate') {
      console.log('Processing commentsUpdate event:', update)
      this.eventEmitter.emit('commentsUpdate', update.data)
    }
    else if (update.type === 'nodePanelPresence') {
      console.log('Processing nodePanelPresence event:', update)
      this.applyNodePanelPresenceUpdate(update.data as NodePanelPresenceEventData)
    }
    else if (update.type === 'syncRequest') {
      console.log('Received sync request from another user')
      // Only the elected leader answers sync requests.
      if (this.isLeader) {
        console.log('Leader received sync request, triggering sync')
        this.eventEmitter.emit('syncRequest', {})
      }
    }
  })
  socket.on('online_users', (data: { users: OnlineUser[]; leader?: string }) => {
    try {
      if (!data || !Array.isArray(data.users)) {
        console.warn('Invalid online_users data structure:', data)
        return
      }
      const onlineUserIds = new Set(data.users.map((user: OnlineUser) => user.user_id))
      const onlineClientIds = new Set(
        data.users
          .map((user: OnlineUser) => user.sid)
          .filter((sid): sid is string => typeof sid === 'string' && sid.length > 0),
      )
      // Remove cursors for users that went offline.
      Object.keys(this.cursors).forEach((userId) => {
        if (!onlineUserIds.has(userId))
          delete this.cursors[userId]
      })
      this.cleanupNodePanelPresence(onlineClientIds, onlineUserIds)
      // Update leader information when the server reports one.
      if (data.leader && typeof data.leader === 'string')
        this.leaderId = data.leader
      this.eventEmitter.emit('onlineUsers', data.users)
      this.eventEmitter.emit('cursors', { ...this.cursors })
    }
    catch (error) {
      console.error('Error processing online_users update:', error)
    }
  })
  // Fix: three duplicate 'status' handlers were previously registered; the
  // later two could never fire their leaderChange emit because the first
  // handler had already updated this.isLeader. Keep one validated handler.
  socket.on('status', (data: any) => {
    try {
      if (!data || typeof data.isLeader !== 'boolean') {
        console.warn('Invalid status data:', data)
        return
      }
      const wasLeader = this.isLeader
      this.isLeader = data.isLeader
      if (wasLeader !== this.isLeader) {
        console.log(`Collaboration: I am now the ${this.isLeader ? 'Leader' : 'Follower'}.`)
        this.eventEmitter.emit('leaderChange', this.isLeader)
      }
    }
    catch (error) {
      console.error('Error processing status update:', error)
    }
  })
  socket.on('connect', () => {
    console.log('WebSocket connected successfully')
    this.eventEmitter.emit('stateChange', { isConnected: true })
  })
  socket.on('disconnect', (reason: string) => {
    console.log('WebSocket disconnected:', reason)
    // Reset per-connection state; leadership must be re-negotiated on reconnect.
    this.cursors = {}
    this.isLeader = false
    this.leaderId = null
    this.eventEmitter.emit('stateChange', { isConnected: false })
    this.eventEmitter.emit('cursors', {})
  })
  socket.on('connect_error', (error: any) => {
    console.error('WebSocket connection error:', error)
    this.eventEmitter.emit('stateChange', { isConnected: false, error: error.message })
  })
  socket.on('error', (error: any) => {
    console.error('WebSocket error:', error)
  })
}
}
// Shared singleton — all workflow editors in this tab use one manager instance.
export const collaborationManager = new CollaborationManager()

View File

@ -0,0 +1,36 @@
import type { LoroDoc } from 'loro-crdt'
import type { Socket } from 'socket.io-client'
/**
 * Bridges a LoroDoc to a Socket.IO connection: locally-originated CRDT
 * changes are exported and emitted as 'graph_event', and incoming
 * 'graph_update' payloads are imported into the doc.
 */
export class CRDTProvider {
  private doc: LoroDoc
  private socket: Socket
  // Unsubscribe handle returned by doc.subscribe(); called in destroy() so a
  // torn-down provider stops exporting updates for its old socket.
  private unsubscribeDoc: (() => void) | null = null

  constructor(socket: Socket, doc: LoroDoc) {
    this.socket = socket
    this.doc = doc
    this.setupEventListeners()
  }

  private setupEventListeners(): void {
    // Only re-broadcast changes that originated locally; imported remote
    // updates must not echo back to the server.
    this.unsubscribeDoc = this.doc.subscribe((event: any) => {
      if (event.by === 'local') {
        const update = this.doc.export({ mode: 'update' })
        this.socket.emit('graph_event', update)
      }
    })
    this.socket.on('graph_update', (updateData: Uint8Array) => {
      try {
        // Socket.IO may deliver an ArrayBuffer/Buffer — normalize first.
        const data = new Uint8Array(updateData)
        this.doc.import(data)
      }
      catch (error) {
        console.error('Error importing graph update:', error)
      }
    })
  }

  destroy(): void {
    this.socket.off('graph_update')
    // Fix: the doc subscription previously leaked after destroy(), so a dead
    // provider kept emitting 'graph_event' on its old socket.
    this.unsubscribeDoc?.()
    this.unsubscribeDoc = null
  }
}

View File

@ -0,0 +1,49 @@
export type EventHandler<T = any> = (data: T) => void

/**
 * Minimal typed pub/sub bus. Handlers are kept per event name in a Set, so
 * registering the same function twice for one event is a no-op.
 */
export class EventEmitter {
  private events: Map<string, Set<EventHandler>> = new Map()

  /** Register `handler` for `event`; returns a function that unregisters it. */
  on<T = any>(event: string, handler: EventHandler<T>): () => void {
    let handlers = this.events.get(event)
    if (!handlers) {
      handlers = new Set()
      this.events.set(event, handlers)
    }
    handlers.add(handler)
    return () => this.off(event, handler)
  }

  /** Remove one handler, or every handler for `event` when none is given. */
  off<T = any>(event: string, handler?: EventHandler<T>): void {
    const handlers = this.events.get(event)
    if (!handlers) return
    if (handler)
      handlers.delete(handler)
    else
      handlers.clear()
    // Garbage-collect empty event buckets.
    if (handlers.size === 0)
      this.events.delete(event)
  }

  /** Invoke every handler for `event`; a throwing handler is logged, not propagated. */
  emit<T = any>(event: string, data: T): void {
    const handlers = this.events.get(event)
    if (!handlers) return
    for (const handler of handlers) {
      try {
        handler(data)
      }
      catch (error) {
        console.error(`Error in event handler for ${event}:`, error)
      }
    }
  }

  /** Drop every handler for every event. */
  removeAllListeners(): void {
    this.events.clear()
  }

  /** Number of handlers currently registered for `event`. */
  getListenerCount(event: string): number {
    return this.events.get(event)?.size || 0
  }
}

View File

@ -0,0 +1,125 @@
import type { Socket } from 'socket.io-client'
import { io } from 'socket.io-client'
import type { DebugInfo, WebSocketConfig } from '../types/websocket'
/**
 * Maintains one Socket.IO connection per app id, tracking which apps are
 * still mid-handshake so repeated connect() calls never stack sockets.
 */
export class WebSocketClient {
  private connections: Map<string, Socket> = new Map()
  private connecting: Set<string> = new Set()
  private config: WebSocketConfig

  constructor(config: WebSocketConfig = {}) {
    // Derive a ws(s):// URL from the page origin when none is configured;
    // fall back to localhost during SSR where `window` is unavailable.
    const inferUrl = () => {
      if (typeof window === 'undefined')
        return 'ws://localhost:5001'
      const scheme = window.location.protocol === 'https:' ? 'wss:' : 'ws:'
      return `${scheme}//${window.location.host}`
    }
    this.config = {
      url: config.url || process.env.NEXT_PUBLIC_SOCKET_URL || inferUrl(),
      transports: config.transports || ['websocket'],
      withCredentials: config.withCredentials !== false,
      ...config,
    }
  }

  /**
   * Return the socket for `appId`, reusing a connected or still-connecting
   * one when possible; otherwise dispose any stale socket and reconnect.
   */
  connect(appId: string): Socket {
    const existingSocket = this.connections.get(appId)
    if (existingSocket?.connected)
      return existingSocket
    if (this.connecting.has(appId)) {
      const pendingSocket = this.connections.get(appId)
      if (pendingSocket)
        return pendingSocket
    }
    // A stale, disconnected socket must be disposed before reconnecting.
    if (existingSocket && !existingSocket.connected) {
      existingSocket.disconnect()
      this.connections.delete(appId)
    }
    this.connecting.add(appId)
    // Fix: guard localStorage so this code path cannot throw during SSR,
    // matching the `typeof window` guard used for URL inference above.
    const authToken = typeof localStorage === 'undefined' ? null : localStorage.getItem('console_token')
    const socket = io(this.config.url!, {
      path: '/socket.io',
      transports: this.config.transports,
      auth: { token: authToken },
      withCredentials: this.config.withCredentials,
    })
    this.connections.set(appId, socket)
    this.setupBaseEventListeners(socket, appId)
    return socket
  }

  /** Disconnect one app's socket, or every socket when no id is given. */
  disconnect(appId?: string): void {
    if (appId) {
      const socket = this.connections.get(appId)
      if (socket) {
        socket.disconnect()
        this.connections.delete(appId)
        this.connecting.delete(appId)
      }
    }
    else {
      this.connections.forEach(socket => socket.disconnect())
      this.connections.clear()
      this.connecting.clear()
    }
  }

  /** The socket for `appId`, or null when none exists (connected or not). */
  getSocket(appId: string): Socket | null {
    return this.connections.get(appId) || null
  }

  /** Whether a fully-established connection exists for `appId`. */
  isConnected(appId: string): boolean {
    return this.connections.get(appId)?.connected || false
  }

  /** Ids of all apps with a currently-connected socket. */
  getConnectedApps(): string[] {
    const connectedApps: string[] = []
    this.connections.forEach((socket, appId) => {
      if (socket.connected)
        connectedApps.push(appId)
    })
    return connectedApps
  }

  /** Per-app connection snapshot for debugging. */
  getDebugInfo(): DebugInfo {
    const info: DebugInfo = {}
    this.connections.forEach((socket, appId) => {
      info[appId] = {
        connected: socket.connected,
        connecting: this.connecting.has(appId),
        socketId: socket.id,
      }
    })
    return info
  }

  /** Bookkeeping listeners: clear the connecting flag and announce the user. */
  private setupBaseEventListeners(socket: Socket, appId: string): void {
    socket.on('connect', () => {
      this.connecting.delete(appId)
      socket.emit('user_connect', { workflow_id: appId })
    })
    socket.on('disconnect', () => {
      this.connecting.delete(appId)
    })
    socket.on('connect_error', () => {
      this.connecting.delete(appId)
    })
  }
}
// Shared singleton used by the collaboration manager and the public index.
export const webSocketClient = new WebSocketClient()
/**
 * Fetch the online-user roster for the given app ids.
 * @param appIds app ids to query; sent comma-joined as `app_ids`.
 * @returns the parsed JSON response body.
 * @throws Error when the endpoint responds with a non-2xx status.
 */
export const fetchAppsOnlineUsers = async (appIds: string[]) => {
  const query = new URLSearchParams({ app_ids: appIds.join(',') })
  const response = await fetch(`/api/online-users?${query}`)
  // Fix: surface HTTP failures instead of attempting to JSON-parse an error page.
  if (!response.ok)
    throw new Error(`Failed to fetch online users: ${response.status}`)
  return response.json()
}

View File

@ -0,0 +1,92 @@
import { useEffect, useRef, useState } from 'react'
import type { ReactFlowInstance } from 'reactflow'
import { collaborationManager } from '../core/collaboration-manager'
import { CursorService } from '../services/cursor-service'
import type { CollaborationState } from '../types/collaboration'
/**
 * React hook wiring a workflow editor to the collaboration layer for `appId`:
 * connects on mount, mirrors connection/user/cursor/presence/leader state
 * into local React state, and exposes cursor-tracking controls.
 */
export function useCollaboration(appId: string, reactFlowStore?: any) {
  const [state, setState] = useState<Partial<CollaborationState & { isLeader: boolean }>>({
    isConnected: false,
    onlineUsers: [],
    cursors: {},
    nodePanelPresence: {},
    isLeader: false,
  })
  const cursorServiceRef = useRef<CursorService | null>(null)
  useEffect(() => {
    if (!appId) return
    let connectionId: string | null = null
    // Fix: if the component unmounts while connect() is still awaiting, the
    // previous code never released the late-arriving connection.
    let cancelled = false
    if (!cursorServiceRef.current)
      cursorServiceRef.current = new CursorService()
    const initCollaboration = async () => {
      const id = await collaborationManager.connect(appId, reactFlowStore)
      if (cancelled) {
        // Effect already cleaned up — release the connection immediately.
        collaborationManager.disconnect(id)
        return
      }
      connectionId = id
      setState((prev: any) => ({ ...prev, appId, isConnected: collaborationManager.isConnected() }))
    }
    initCollaboration()
    const unsubscribeStateChange = collaborationManager.onStateChange((newState: any) => {
      console.log('Collaboration state change:', newState)
      setState((prev: any) => ({ ...prev, ...newState }))
    })
    const unsubscribeCursors = collaborationManager.onCursorUpdate((cursors: any) => {
      setState((prev: any) => ({ ...prev, cursors }))
    })
    const unsubscribeUsers = collaborationManager.onOnlineUsersUpdate((users: any) => {
      console.log('Online users update:', users)
      setState((prev: any) => ({ ...prev, onlineUsers: users }))
    })
    const unsubscribeNodePanelPresence = collaborationManager.onNodePanelPresenceUpdate((presence) => {
      setState((prev: any) => ({ ...prev, nodePanelPresence: presence }))
    })
    const unsubscribeLeaderChange = collaborationManager.onLeaderChange((isLeader: boolean) => {
      console.log('Leader status changed:', isLeader)
      setState((prev: any) => ({ ...prev, isLeader }))
    })
    return () => {
      cancelled = true
      unsubscribeStateChange()
      unsubscribeCursors()
      unsubscribeUsers()
      unsubscribeNodePanelPresence()
      unsubscribeLeaderChange()
      cursorServiceRef.current?.stopTracking()
      if (connectionId)
        collaborationManager.disconnect(connectionId)
    }
  }, [appId, reactFlowStore])
  /** Begin broadcasting this user's cursor from `containerRef`'s area. */
  const startCursorTracking = (containerRef: React.RefObject<HTMLElement>, reactFlowInstance?: ReactFlowInstance) => {
    if (cursorServiceRef.current) {
      cursorServiceRef.current.startTracking(containerRef, (position) => {
        collaborationManager.emitCursorMove(position)
      }, reactFlowInstance)
    }
  }
  /** Stop broadcasting this user's cursor. */
  const stopCursorTracking = () => {
    cursorServiceRef.current?.stopTracking()
  }
  return {
    isConnected: state.isConnected || false,
    onlineUsers: state.onlineUsers || [],
    cursors: state.cursors || {},
    nodePanelPresence: state.nodePanelPresence || {},
    isLeader: state.isLeader || false,
    leaderId: collaborationManager.getLeaderId(),
    startCursorTracking,
    stopCursorTracking,
  }
}

View File

@ -0,0 +1,5 @@
// Public entry point of the collaboration package: the manager and websocket
// singletons, the cursor service, the React hook, and all shared types.
export { collaborationManager } from './core/collaboration-manager'
export { webSocketClient, fetchAppsOnlineUsers } from './core/websocket-manager'
export { CursorService } from './services/cursor-service'
export { useCollaboration } from './hooks/use-collaboration'
export * from './types'

View File

@ -0,0 +1,88 @@
import type { RefObject } from 'react'
import type { CursorPosition } from '../types/collaboration'
import type { ReactFlowInstance } from 'reactflow'
// Minimum pointer travel (screen px at zoom 1; scaled by zoom below) before
// a new cursor position is broadcast.
const CURSOR_MIN_MOVE_DISTANCE = 10
// Minimum interval between cursor broadcasts.
const CURSOR_THROTTLE_MS = 500
// Tracks the local user's mouse inside a container element, converts the
// position to ReactFlow world coordinates when an instance is provided, and
// throttles emission by BOTH time and distance.
export class CursorService {
private containerRef: RefObject<HTMLElement> | null = null
private reactFlowInstance: ReactFlowInstance | null = null
private isTracking = false
private onCursorUpdate: ((cursors: Record<string, CursorPosition>) => void) | null = null
private onEmitPosition: ((position: CursorPosition) => void) | null = null
private lastEmitTime = 0
private lastPosition: { x: number; y: number } | null = null
// Begin listening for mousemove on the container; restarts cleanly when
// tracking was already active.
startTracking(
containerRef: RefObject<HTMLElement>,
onEmitPosition: (position: CursorPosition) => void,
reactFlowInstance?: ReactFlowInstance,
): void {
if (this.isTracking) this.stopTracking()
this.containerRef = containerRef
this.onEmitPosition = onEmitPosition
this.reactFlowInstance = reactFlowInstance || null
this.isTracking = true
// NOTE(review): if containerRef.current is null here, isTracking is set but
// no listener is attached — confirm callers only start after the ref mounts.
if (containerRef.current)
containerRef.current.addEventListener('mousemove', this.handleMouseMove)
}
// Detach the listener and reset all tracking state.
stopTracking(): void {
if (this.containerRef?.current)
this.containerRef.current.removeEventListener('mousemove', this.handleMouseMove)
this.containerRef = null
this.reactFlowInstance = null
this.onEmitPosition = null
this.isTracking = false
this.lastPosition = null
}
// Register the callback invoked by updateCursors().
setCursorUpdateHandler(handler: (cursors: Record<string, CursorPosition>) => void): void {
this.onCursorUpdate = handler
}
// Forward a remote-cursor snapshot to the registered handler, if any.
updateCursors(cursors: Record<string, CursorPosition>): void {
if (this.onCursorUpdate)
this.onCursorUpdate(cursors)
}
// Arrow-function property so add/removeEventListener always receive the
// same function reference across start/stop cycles.
private handleMouseMove = (event: MouseEvent): void => {
if (!this.containerRef?.current || !this.onEmitPosition) return
const rect = this.containerRef.current.getBoundingClientRect()
let x = event.clientX - rect.left
let y = event.clientY - rect.top
// Transform coordinates to ReactFlow world coordinates if ReactFlow instance is available
if (this.reactFlowInstance) {
const viewport = this.reactFlowInstance.getViewport()
// Convert screen coordinates to world coordinates
// World coordinates = (screen coordinates - viewport translation) / zoom
x = (x - viewport.x) / viewport.zoom
y = (y - viewport.y) / viewport.zoom
}
// Always emit cursor position (remove boundary check since world coordinates can be negative)
const now = Date.now()
const timeThrottled = now - this.lastEmitTime > CURSOR_THROTTLE_MS
// Distance threshold is in world units, so divide by zoom to keep the
// on-screen sensitivity constant.
const minDistance = CURSOR_MIN_MOVE_DISTANCE / (this.reactFlowInstance?.getZoom() || 1)
const distanceThrottled = !this.lastPosition
|| (Math.abs(x - this.lastPosition.x) > minDistance)
|| (Math.abs(y - this.lastPosition.y) > minDistance)
// Emit only when both the time AND the movement thresholds are exceeded.
if (timeThrottled && distanceThrottled) {
this.lastPosition = { x, y }
this.lastEmitTime = now
this.onEmitPosition({
x,
y,
// userId is sent empty here — presumably filled in by the collaboration
// layer before broadcast; verify against the server handler.
userId: '',
timestamp: now,
})
}
}
}

View File

@ -0,0 +1,57 @@
import type { Edge, Node } from '../../types'
// A collaborator as reported by the server's online_users event.
export type OnlineUser = {
user_id: string
username: string
avatar: string
sid: string
}
// Online users grouped by the workflow they are editing.
export type WorkflowOnlineUsers = {
workflow_id: string
users: OnlineUser[]
}
// Response envelope for the online-users listing endpoint.
export type OnlineUserListResponse = {
data: WorkflowOnlineUsers[]
}
// A user's cursor position (ReactFlow world coordinates when available).
export type CursorPosition = {
x: number
y: number
userId: string
timestamp: number
}
// Identity of a user shown in node-panel presence indicators.
export type NodePanelPresenceUser = {
userId: string
username: string
avatar?: string | null
}
// Presence entry including the originating socket client and event time.
export type NodePanelPresenceInfo = NodePanelPresenceUser & {
clientId: string
timestamp: number
}
// nodeId -> clientId -> presence info for everyone with that panel open.
export type NodePanelPresenceMap = Record<string, Record<string, NodePanelPresenceInfo>>
// Aggregate collaboration state mirrored into React by useCollaboration.
export type CollaborationState = {
appId: string
isConnected: boolean
onlineUsers: OnlineUser[]
cursors: Record<string, CursorPosition>
nodePanelPresence: NodePanelPresenceMap
}
// Full graph payload exchanged during a sync.
export type GraphSyncData = {
nodes: Node[]
edges: Edge[]
}
// Generic peer-to-peer collaboration message envelope.
export type CollaborationUpdate = {
type: 'mouseMove' | 'graphUpdate' | 'userJoin' | 'userLeave'
userId: string
data: any
timestamp: number
}

View File

@ -0,0 +1,38 @@
// Base envelope for every event exchanged over the collaboration socket.
export type CollaborationEvent = {
type: string
data: any
timestamp: number
}
// Binary CRDT update payload.
export type GraphUpdateEvent = {
type: 'graph_update'
data: Uint8Array
} & CollaborationEvent
// A user's cursor moved.
export type CursorMoveEvent = {
type: 'cursor_move'
data: {
x: number
y: number
userId: string
}
} & CollaborationEvent
// Sent after connecting to announce which workflow this client is editing.
export type UserConnectEvent = {
type: 'user_connect'
data: {
workflow_id: string
}
} & CollaborationEvent
// Server broadcast of the current collaborator roster.
export type OnlineUsersEvent = {
type: 'online_users'
data: {
users: Array<{
user_id: string
username: string
avatar: string
sid: string
}>
}
} & CollaborationEvent

View File

@ -0,0 +1,3 @@
// Barrel file: re-export all collaboration-package type declarations.
export * from './websocket'
export * from './collaboration'
export * from './events'

View File

@ -0,0 +1,16 @@
// Options accepted by WebSocketClient; all fields have inferred defaults.
export type WebSocketConfig = {
url?: string
token?: string
transports?: string[]
withCredentials?: boolean
}
// Snapshot of one socket's state, used for debugging.
export type ConnectionInfo = {
connected: boolean
connecting: boolean
socketId?: string
}
// Per-app connection snapshots keyed by app id.
export type DebugInfo = {
[appId: string]: ConnectionInfo
}

View File

@ -0,0 +1,12 @@
/**
 * Generate a consistent color for a user based on their ID.
 * Used for cursor colors and avatar backgrounds — the same ID always maps
 * to the same palette entry.
 */
export const getUserColor = (id: string): string => {
  const palette = ['#155AEF', '#0BA5EC', '#444CE7', '#7839EE', '#4CA30D', '#0E9384', '#DD2590', '#FF4405', '#D92D20', '#F79009', '#828DAD']
  // Java-String-style rolling hash: h = h*31 + code (via (h << 5) - h),
  // clamped to a 32-bit integer with the `h & h` trick. Iterates UTF-16
  // code units, matching String#charCodeAt.
  let hash = 0
  for (let i = 0; i < id.length; i++) {
    hash = ((hash << 5) - hash) + id.charCodeAt(i)
    hash &= hash
  }
  return palette[Math.abs(hash) % palette.length]
}

View File

@ -0,0 +1,31 @@
import { useEventListener } from 'ahooks'
import { useWorkflowStore } from './store'
import { useWorkflowComment } from './hooks/use-workflow-comment'
/**
 * Headless component: while the workflow is in 'comment' control mode, a
 * click landing on the empty React Flow pane creates a comment at the
 * current mouse position. Clicks inside the comment input or its mention
 * dropdown are left alone.
 */
const CommentManager = () => {
  const workflowStore = useWorkflowStore()
  const { handleCreateComment } = useWorkflowComment()
  useEventListener('click', (e) => {
    const { controlMode, mousePosition } = workflowStore.getState()
    if (controlMode !== 'comment')
      return
    const target = e.target as HTMLElement
    // Ignore clicks inside the comment input or its mention dropdown.
    if (target.closest('[data-mention-dropdown]') || target.closest('[data-comment-input]'))
      return
    // Only clicks on the canvas background (the React Flow pane) count.
    if (!target.closest('.react-flow__pane'))
      return
    e.preventDefault()
    e.stopPropagation()
    handleCreateComment(mousePosition)
  })
  return null
}
export default CommentManager

View File

@ -0,0 +1,239 @@
'use client'
import type { FC, PointerEvent as ReactPointerEvent } from 'react'
import { memo, useCallback, useMemo, useRef, useState } from 'react'
import { useReactFlow, useViewport } from 'reactflow'
import { UserAvatarList } from '@/app/components/base/user-avatar-list'
import CommentPreview from './comment-preview'
import type { WorkflowCommentList } from '@/service/workflow-comment'
type CommentIconProps = {
comment: WorkflowCommentList
onClick: () => void
isActive?: boolean
onPositionUpdate?: (position: { x: number; y: number }) => void
}
export const CommentIcon: FC<CommentIconProps> = memo(({ comment, onClick, isActive = false, onPositionUpdate }) => {
const { flowToScreenPosition, screenToFlowPosition } = useReactFlow()
const viewport = useViewport()
const [showPreview, setShowPreview] = useState(false)
const [dragPosition, setDragPosition] = useState<{ x: number; y: number } | null>(null)
const [isDragging, setIsDragging] = useState(false)
const dragStateRef = useRef<{
offsetX: number
offsetY: number
startX: number
startY: number
hasMoved: boolean
} | null>(null)
const screenPosition = useMemo(() => {
return flowToScreenPosition({
x: comment.position_x,
y: comment.position_y,
})
}, [comment.position_x, comment.position_y, viewport.x, viewport.y, viewport.zoom, flowToScreenPosition])
const effectivePosition = dragPosition ?? screenPosition
const handlePointerDown = useCallback((event: ReactPointerEvent<HTMLDivElement>) => {
if (event.button !== 0)
return
event.stopPropagation()
event.preventDefault()
dragStateRef.current = {
offsetX: event.clientX - screenPosition.x,
offsetY: event.clientY - screenPosition.y,
startX: event.clientX,
startY: event.clientY,
hasMoved: false,
}
setDragPosition(screenPosition)
setIsDragging(false)
if (event.currentTarget.dataset.role !== 'comment-preview')
setShowPreview(false)
if (event.currentTarget.setPointerCapture)
event.currentTarget.setPointerCapture(event.pointerId)
}, [screenPosition])
const handlePointerMove = useCallback((event: ReactPointerEvent<HTMLDivElement>) => {
const dragState = dragStateRef.current
if (!dragState)
return
event.stopPropagation()
event.preventDefault()
const nextX = event.clientX - dragState.offsetX
const nextY = event.clientY - dragState.offsetY
if (!dragState.hasMoved) {
const distance = Math.hypot(event.clientX - dragState.startX, event.clientY - dragState.startY)
if (distance > 4) {
dragState.hasMoved = true
setIsDragging(true)
}
}
setDragPosition({ x: nextX, y: nextY })
}, [])
const finishDrag = useCallback((event: ReactPointerEvent<HTMLDivElement>) => {
const dragState = dragStateRef.current
if (!dragState)
return false
if (event.currentTarget.hasPointerCapture?.(event.pointerId))
event.currentTarget.releasePointerCapture(event.pointerId)
dragStateRef.current = null
setDragPosition(null)
setIsDragging(false)
return dragState.hasMoved
}, [])
const handlePointerUp = useCallback((event: ReactPointerEvent<HTMLDivElement>) => {
event.stopPropagation()
event.preventDefault()
const finalScreenPosition = dragPosition ?? screenPosition
const didDrag = finishDrag(event)
setShowPreview(false)
if (didDrag) {
if (onPositionUpdate) {
const flowPosition = screenToFlowPosition({
x: finalScreenPosition.x,
y: finalScreenPosition.y,
})
onPositionUpdate(flowPosition)
}
}
else if (!isActive) {
onClick()
}
}, [dragPosition, finishDrag, isActive, onClick, onPositionUpdate, screenPosition, screenToFlowPosition])
const handlePointerCancel = useCallback((event: ReactPointerEvent<HTMLDivElement>) => {
event.stopPropagation()
event.preventDefault()
finishDrag(event)
}, [finishDrag])
const handleMouseEnter = useCallback(() => {
if (isActive || isDragging)
return
setShowPreview(true)
}, [isActive, isDragging])
const handleMouseLeave = useCallback(() => {
setShowPreview(false)
}, [])
const participants = useMemo(() => {
const list = comment.participants ?? []
const author = comment.created_by_account
if (!author)
return [...list]
const rest = list.filter(user => user.id !== author.id)
return [author, ...rest]
}, [comment.created_by_account, comment.participants])
// Calculate dynamic width based on number of participants
const participantCount = participants.length
const maxVisible = Math.min(3, participantCount)
const showCount = participantCount > 3
const avatarSize = 24
const avatarSpacing = 4 // -space-x-1 is about 4px overlap
// Width calculation: first avatar + (additional avatars * (size - spacing)) + padding
const dynamicWidth = Math.max(40, // minimum width
8 + avatarSize + Math.max(0, (showCount ? 2 : maxVisible - 1)) * (avatarSize - avatarSpacing) + 8,
)
const pointerEventHandlers = useMemo(() => ({
onPointerDown: handlePointerDown,
onPointerMove: handlePointerMove,
onPointerUp: handlePointerUp,
onPointerCancel: handlePointerCancel,
}), [handlePointerCancel, handlePointerDown, handlePointerMove, handlePointerUp])
return (
<>
<div
className="absolute z-10"
style={{
left: effectivePosition.x,
top: effectivePosition.y,
transform: 'translate(-50%, -50%)',
}}
data-role='comment-marker'
{...pointerEventHandlers}
>
<div
className={isActive ? (isDragging ? 'cursor-grabbing' : '') : isDragging ? 'cursor-grabbing' : 'cursor-pointer'}
onMouseEnter={handleMouseEnter}
onMouseLeave={handleMouseLeave}
>
<div
className={'relative h-10 overflow-hidden rounded-br-full rounded-tl-full rounded-tr-full'}
style={{ width: dynamicWidth }}
>
<div className={`absolute inset-[6px] overflow-hidden rounded-br-full rounded-tl-full rounded-tr-full border ${
isActive
? 'border-2 border-primary-500 bg-components-panel-bg'
: 'border-components-panel-border bg-components-panel-bg'
}`}>
<div className="flex h-full w-full items-center justify-center px-1">
<UserAvatarList
users={participants}
maxVisible={3}
size={24}
/>
</div>
</div>
</div>
</div>
</div>
{/* Preview panel */}
{showPreview && !isActive && (
<div
className="absolute z-20"
style={{
left: (dragPosition ?? screenPosition).x - dynamicWidth / 2,
top: (dragPosition ?? screenPosition).y + 20,
transform: 'translateY(-100%)',
}}
data-role='comment-preview'
{...pointerEventHandlers}
onMouseEnter={() => setShowPreview(true)}
onMouseLeave={() => setShowPreview(false)}
>
<CommentPreview comment={comment} onClick={() => {
setShowPreview(false)
onClick()
}} />
</div>
)}
</>
)
}, (prevProps, nextProps) => {
return (
prevProps.comment.id === nextProps.comment.id
&& prevProps.comment.position_x === nextProps.comment.position_x
&& prevProps.comment.position_y === nextProps.comment.position_y
&& prevProps.onClick === nextProps.onClick
&& prevProps.isActive === nextProps.isActive
&& prevProps.onPositionUpdate === nextProps.onPositionUpdate
)
})
CommentIcon.displayName = 'CommentIcon'

View File

@ -0,0 +1,87 @@
import type { FC } from 'react'
import { memo, useCallback, useEffect, useState } from 'react'
import { useTranslation } from 'react-i18next'
import Avatar from '@/app/components/base/avatar'
import { useAppContext } from '@/context/app-context'
import { MentionInput } from './mention-input'
import cn from '@/utils/classnames'
type CommentInputProps = {
  position: { x: number; y: number }
  onSubmit: (content: string, mentionedUserIds: string[]) => void
  onCancel: () => void
}
// Floating editor anchored at a canvas position for creating a new workflow comment.
// Escape (captured at document level) cancels; a successful submit clears the draft.
export const CommentInput: FC<CommentInputProps> = memo(({ position, onSubmit, onCancel }) => {
  const [draft, setDraft] = useState('')
  const { t } = useTranslation()
  const { userProfile } = useAppContext()
  // Listen in the capture phase so Escape wins over nested handlers
  // that might stop propagation before it bubbles back up.
  useEffect(() => {
    const onEscapeKey = (e: KeyboardEvent) => {
      if (e.key !== 'Escape')
        return
      e.preventDefault()
      e.stopPropagation()
      onCancel()
    }
    document.addEventListener('keydown', onEscapeKey, true)
    return () => {
      document.removeEventListener('keydown', onEscapeKey, true)
    }
  }, [onCancel])
  // Forward the submission upstream, then reset the draft for the next comment.
  const submitAndReset = useCallback((text: string, mentionedUserIds: string[]) => {
    onSubmit(text, mentionedUserIds)
    setDraft('')
  }, [onSubmit])
  return (
    <div
      className="absolute z-50 w-96"
      style={{ left: position.x, top: position.y }}
      data-comment-input
    >
      <div className="flex items-center gap-3">
        {/* Current user's avatar inside the teardrop-shaped comment marker */}
        <div className="relative shrink-0">
          <div className="relative h-8 w-8 overflow-hidden rounded-br-full rounded-tl-full rounded-tr-full bg-primary-500">
            <div className="absolute inset-[2px] overflow-hidden rounded-br-full rounded-tl-full rounded-tr-full bg-white">
              <div className="flex h-full w-full items-center justify-center">
                <div className="h-6 w-6 overflow-hidden rounded-full">
                  <Avatar
                    avatar={userProfile.avatar_url}
                    name={userProfile.name}
                    size={24}
                    className="h-full w-full"
                  />
                </div>
              </div>
            </div>
          </div>
        </div>
        {/* Mention-aware text field */}
        <div
          className={cn(
            'relative z-10 flex-1 rounded-xl border border-components-chat-input-border bg-components-panel-bg-blur pb-[4px] shadow-md',
          )}
        >
          <div className='relative px-[9px] pt-[4px]'>
            <MentionInput
              value={draft}
              onChange={setDraft}
              onSubmit={submitAndReset}
              placeholder={t('workflow.comments.placeholder.add')}
              autoFocus
              className="relative"
            />
          </div>
        </div>
      </div>
    </div>
  )
})
CommentInput.displayName = 'CommentInput'

View File

@ -0,0 +1,52 @@
'use client'
import type { FC } from 'react'
import { memo, useMemo } from 'react'
import { UserAvatarList } from '@/app/components/base/user-avatar-list'
import type { WorkflowCommentList } from '@/service/workflow-comment'
import { useFormatTimeFromNow } from '@/hooks/use-format-time-from-now'
type CommentPreviewProps = {
  comment: WorkflowCommentList
  onClick?: () => void
}
/**
 * Hover card for a workflow comment marker: shows participant avatars,
 * the author's name, a relative timestamp and the comment body.
 * Clicking the card invokes `onClick` (typically opening the thread).
 */
const CommentPreview: FC<CommentPreviewProps> = ({ comment, onClick }) => {
  const { formatTimeFromNow } = useFormatTimeFromNow()
  // Author first, then the remaining participants (author deduplicated by id).
  // The !author branch implies created_by_account can be absent.
  const participants = useMemo(() => {
    const list = comment.participants ?? []
    const author = comment.created_by_account
    if (!author)
      return [...list]
    const rest = list.filter(user => user.id !== author.id)
    return [author, ...rest]
  }, [comment.created_by_account, comment.participants])
  return (
    <div
      className="w-80 cursor-pointer rounded-br-xl rounded-tl-xl rounded-tr-xl border border-components-panel-border bg-components-panel-bg p-4 shadow-lg transition-colors hover:bg-components-panel-on-panel-item-bg-hover"
      onClick={onClick}
    >
      <div className="mb-3 flex items-center justify-between">
        <UserAvatarList
          users={participants}
          maxVisible={3}
          size={24}
        />
      </div>
      <div className="mb-2 flex items-start">
        <div className="flex min-w-0 items-center gap-2">
          {/* Optional chaining: created_by_account may be missing (see participants memo),
              previously this dereferenced .name unguarded and could crash the preview. */}
          <div className="system-sm-medium truncate text-text-primary">{comment.created_by_account?.name}</div>
          <div className="system-2xs-regular shrink-0 text-text-tertiary">
            {/* updated_at is epoch seconds; formatTimeFromNow expects milliseconds */}
            {formatTimeFromNow(comment.updated_at * 1000)}
          </div>
        </div>
      </div>
      <div className="system-sm-regular break-words text-text-secondary">{comment.content}</div>
    </div>
  )
}
export default memo(CommentPreview)

View File

@ -0,0 +1,28 @@
import type { FC } from 'react'
import { memo } from 'react'
import { useStore } from '../store'
import { ControlMode } from '../types'
import { Comment } from '@/app/components/base/icons/src/public/other'
// Custom cursor that tracks the mouse while the canvas is in comment mode.
// Renders nothing in any other control mode.
export const CommentCursor: FC = memo(() => {
  // Both selectors run unconditionally so the hook order stays stable.
  const controlMode = useStore(s => s.controlMode)
  const mousePosition = useStore(s => s.mousePosition)
  const isCommentMode = controlMode === ControlMode.Comment
  if (!isCommentMode)
    return null
  const { elementX, elementY } = mousePosition
  return (
    <div
      className="pointer-events-none absolute z-50 flex h-6 w-6 items-center justify-center"
      style={{
        left: elementX,
        top: elementY,
        transform: 'translate(-50%, -50%)',
      }}
    >
      <Comment className="text-text-primary" />
    </div>
  )
})
CommentCursor.displayName = 'CommentCursor'

Some files were not shown because too many files have changed in this diff Show More