Compare commits

...

143 Commits

Author SHA1 Message Date
1cf788c43b Merge branch 'main' into feat/queue-based-graph-engine 2025-09-17 12:46:08 +08:00
73a7756350 feat(graph_engine): support dumps and loads for RSC 2025-09-17 12:45:51 +08:00
02d15ebd5a feat(graph_engine): support dumps and loads in GraphExecution 2025-09-16 19:38:10 +08:00
976b3b5e83 Merge branch 'main' into feat/queue-based-graph-engine 2025-09-16 15:21:36 +08:00
b5684f1992 refactor(graph_engine): remove unused parameters from Engine 2025-09-16 14:11:42 +08:00
bd13cf05eb Merge branch 'main' into feat/queue-based-graph-engine 2025-09-16 12:59:26 +08:00
5f263147f9 fix: make mypy happy 2025-09-16 12:51:11 +08:00
b68afdfa64 Merge remote-tracking branch 'origin/main' into feat/queue-based-graph-engine 2025-09-16 12:32:16 +08:00
da87fce751 feat(graph_engine): dump and load ready queue 2025-09-16 04:19:46 +08:00
d5342927d0 chore: change _outputs type to dict[str, object] 2025-09-16 01:53:25 +08:00
754d790c89 [autofix.ci] apply automated fixes (attempt 2/3) 2025-09-15 07:58:44 +00:00
a099a35e51 [autofix.ci] apply automated fixes 2025-09-15 07:56:51 +00:00
2dd893e60d Merge remote-tracking branch 'origin/main' into feat/queue-based-graph-engine 2025-09-15 15:54:42 +08:00
b8ee1d4697 Merge branch 'main' into feat/queue-based-graph-engine 2025-09-15 12:21:18 +08:00
b4ef1de30f feat(graph_engine): add ready_queue state persistence to GraphRuntimeState
- Add ReadyQueueState TypedDict for type-safe queue serialization
- Add ready_queue attribute to GraphRuntimeState for initializing with pre-existing queue state
- Update GraphEngine to load ready_queue from GraphRuntimeState on initialization
- Implement proper type hints using ReadyQueueState for better type safety
- Add comprehensive tests for ready_queue loading functionality

The ready_queue is read-only after initialization and allows resuming workflow
execution with a pre-populated queue of nodes ready to execute.
2025-09-15 03:05:10 +08:00
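The commit above gives the intent but not the serialized shape. A minimal sketch of the idea is below; only `ReadyQueueState`, `GraphRuntimeState`, and the read-only `ready_queue` attribute come from the commit message, while the TypedDict fields are assumptions for illustration.

```python
from typing import TypedDict


class ReadyQueueState(TypedDict):
    # Field names are assumptions; the commit only says the queue is
    # serialized through a TypedDict for type safety.
    type: str          # queue implementation identifier
    items: list[str]   # IDs of nodes that are ready to execute


class GraphRuntimeState:
    def __init__(self, ready_queue: ReadyQueueState | None = None) -> None:
        self._ready_queue = ready_queue  # read-only after initialization

    @property
    def ready_queue(self) -> ReadyQueueState | None:
        return self._ready_queue
```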
0f15a2baca [autofix.ci] apply automated fixes 2025-09-13 20:20:53 +00:00
4cdc19fd05 feat(graph_engine): add abstract layer and dump / load methods for ready queue. 2025-09-14 04:19:24 +08:00
efa5f35277 Merge remote-tracking branch 'origin/main' into feat/queue-based-graph-engine 2025-09-14 01:48:06 +08:00
766fda395b Merge branch 'main' into feat/queue-based-graph-engine 2025-09-13 19:37:52 +08:00
b0e815c3c7 Merge remote-tracking branch 'origin/main' into feat/queue-based-graph-engine
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-09-13 01:31:17 +08:00
462ba354a4 Merge remote-tracking branch 'origin/main' into feat/queue-based-graph-engine 2025-09-12 00:21:06 +08:00
3c668e4a5c fix: update test assertions for ToolProviderApiEntity validation
- Fixed test_repack_provider_entity_no_dark_icon to use empty string instead of None for icon_dark field
- Updated test_builtin_provider_to_user_provider_no_credentials assertion to match actual implementation behavior where masked_credentials always contains empty strings for schema fields
2025-09-11 16:41:10 +08:00
872cff7bab chore(iteration_node): convert some Any to object
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-09-11 15:40:12 +08:00
8fb69429f9 feat(graph_engine): support parallel mode in iteration node
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-09-11 15:37:46 +08:00
85064bd8cf Merge remote-tracking branch 'origin/main' into feat/queue-based-graph-engine 2025-09-11 15:13:31 +08:00
ba5df3612b fix: tests
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-09-11 15:13:18 +08:00
a923ab1ab8 fix: type errors
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-09-11 15:01:16 +08:00
b4c1766932 fix: type errors
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-09-10 21:48:05 +08:00
00a1af8506 refactor(graph_engine): use singledispatch in Node
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-09-10 20:59:34 +08:00
f56fccee9d fix: workflow knowledge query raise error (#25465) 2025-09-10 13:47:47 +08:00
b6b98a2c8e Merge branch 'feat/dispatch-method' into feat/queue-based-graph-engine 2025-09-10 03:12:59 +08:00
7e69403dda refactor(graph_engine): use singledispatchmethod in event_handler
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-09-10 03:12:33 +08:00
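For readers unfamiliar with the pattern, this is roughly what `functools.singledispatchmethod` dispatch on event types looks like; the event classes here are illustrative, not the repository's actual event types.

```python
from dataclasses import dataclass
from functools import singledispatchmethod


@dataclass
class NodeRunSucceededEvent:
    node_id: str


@dataclass
class NodeRunFailedEvent:
    node_id: str
    error: str


class EventHandler:
    @singledispatchmethod
    def handle(self, event: object) -> None:
        # Fallback for event types without a registered handler.
        raise NotImplementedError(f"unhandled event: {type(event).__name__}")

    @handle.register
    def _(self, event: NodeRunSucceededEvent) -> None:
        print(f"node {event.node_id} succeeded")

    @handle.register
    def _(self, event: NodeRunFailedEvent) -> None:
        print(f"node {event.node_id} failed: {event.error}")
```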
9796cede72 fix: add missing type field to node configurations in integration tests
- Added 'type' field to all node data configurations in test files
- Fixed test_code.py: added 'type: code' to all code node configs
- Fixed test_http.py: added 'type: http-request' to all HTTP node configs
- Fixed test_template_transform.py: added 'type: template-transform' to template node config
- Fixed test_tool.py: added 'type: tool' to all tool node configs
- Added setup_code_executor_mock fixture to test_execute_code_scientific_notation

These changes fix the ValueError: 'Node X missing or invalid type information' errors
that were occurring due to changes in the node factory validation requirements.
2025-09-10 02:54:01 +08:00
836ed1f380 refactor(graph_engine): Move ErrorHandler into a single file package
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-09-10 02:35:05 +08:00
80f39963f1 chore: add import lint to CI
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-09-10 02:32:24 +08:00
9cf2b2b231 fix: type errors
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-09-10 02:22:58 +08:00
2a97a69825 Merge remote-tracking branch 'origin/main' into feat/queue-based-graph-engine 2025-09-10 02:03:45 +08:00
f17c71e08a refactor(graph_engine): Move GraphStateManager to single file package.
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-09-10 01:55:30 +08:00
d52621fce3 refactor(graph_engine): Merge error strategies into error_handler.py
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-09-10 01:49:46 +08:00
e060d7c28c refactor(graph_engine): remove Optional
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-09-10 01:49:15 +08:00
ea5dfe41d5 chore: ignore comment
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-09-10 01:36:11 +08:00
a23c8fcb1a refactor: move execution limits from engine core to layer
Remove max_execution_time and max_execution_steps from ExecutionContext and GraphEngine since these limits are now handled by ExecutionLimitsLayer. This follows the separation of concerns principle by keeping execution limits as a cross-cutting concern handled by layers rather than embedded in core engine components.

Changes:
- Remove max_execution_time and max_execution_steps from ExecutionContext
- Remove these parameters from GraphEngine.__init__()
- Remove max_execution_time from Dispatcher
- Update workflow_entry.py to no longer pass these parameters
- Update all tests to remove these parameters
2025-09-10 01:32:45 +08:00
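A minimal sketch of the layer idea, assuming hook names like `on_graph_start` and `on_event`; the real `ExecutionLimitsLayer` interface may differ.

```python
import time


class ExecutionLimitsLayer:
    def __init__(self, max_steps: int, max_seconds: float) -> None:
        self._max_steps = max_steps
        self._max_seconds = max_seconds
        self._steps = 0
        self._started_at = 0.0

    def on_graph_start(self) -> None:
        self._started_at = time.monotonic()

    def on_event(self, event: object) -> None:
        # Count every engine event as a step and enforce both limits here,
        # keeping the engine core free of limit bookkeeping.
        self._steps += 1
        if self._steps > self._max_steps:
            raise RuntimeError("max execution steps exceeded")
        if time.monotonic() - self._started_at > self._max_seconds:
            raise RuntimeError("max execution time exceeded")
```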
e0e82fbfaa refactor: extract _run method into smaller focused methods in IterationNode
- Extract iterator variable retrieval and validation logic
- Separate empty iteration handling
- Create dedicated methods for iteration execution and result handling
- Improve type hints and use modern Python syntax
- Enhance code readability and maintainability
2025-09-10 01:15:36 +08:00
1c9f40f92a Merge remote-tracking branch 'origin/main' into feat/queue-based-graph-engine 2025-09-09 22:16:59 +08:00
6ffa2ebabf feat: improve error handling in graph node creation
- Replace ValueError catch with generic Exception
- Use logger.exception for automatic traceback logging
- Abort on node creation failure instead of continuing
2025-09-09 22:16:42 +08:00
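The pattern described in the bullets, sketched with hypothetical `node_factory` and config names:

```python
import logging

logger = logging.getLogger(__name__)


def build_nodes(node_factory, node_configs: list[dict]) -> list:
    nodes = []
    for config in node_configs:
        try:
            nodes.append(node_factory.create(config))
        except Exception:
            # logger.exception records the full traceback automatically.
            logger.exception("Failed to create node %s", config.get("id"))
            raise  # abort instead of continuing with a partial graph
    return nodes
```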
95dc1e2fe8 Merge remote-tracking branch 'origin/main' into feat/queue-based-graph-engine 2025-09-09 17:13:16 +08:00
6fe7cf5ebf Merge remote-tracking branch 'origin/main' into feat/queue-based-graph-engine 2025-09-09 17:11:46 +08:00
a1e8ac4c96 Merge remote-tracking branch 'origin/main' into feat/queue-based-graph-engine 2025-09-09 15:49:09 +08:00
b46858d87d Merge branch 'main' into feat/queue-based-graph-engine 2025-09-09 13:33:17 +08:00
5ab6838849 Merge remote-tracking branch 'origin/main' into feat/queue-based-graph-engine 2025-09-08 19:55:43 +08:00
ef974e484b fix: handle None env vars
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-09-08 16:43:47 +08:00
299141ae01 Merge remote-tracking branch 'origin/main' into feat/queue-based-graph-engine 2025-09-08 13:56:45 +08:00
cc1d437dc1 fix: correct indentation in TokenBufferMemory get_history_prompt_messages method 2025-09-07 12:48:50 +08:00
7aef0b54e5 Merge remote-tracking branch 'origin/main' into feat/queue-based-graph-engine 2025-09-07 12:34:54 +08:00
3c28936796 fix: test
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-09-06 16:21:28 +08:00
81fdc7c54b fix: type errors
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-09-06 16:09:59 +08:00
abb53f11ad Merge remote-tracking branch 'origin/main' into feat/queue-based-graph-engine 2025-09-06 16:05:13 +08:00
d9aa0ec046 fix: resolve mypy type errors in http_request and list_operator nodes
- Fix str | bytes union type handling in http_request executor
- Add type guard for boolean filter value in list_operator node
2025-09-05 21:17:18 +08:00
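A hedged illustration of both fixes; the function names are invented for the example, only the union-narrowing and bool type-guard techniques come from the commit.

```python
def ensure_text(data: str | bytes) -> str:
    # Narrow the str | bytes union before doing string operations.
    if isinstance(data, bytes):
        return data.decode("utf-8")
    return data


def apply_boolean_filter(values: list[object], expected: object) -> list[object]:
    # Type guard: only treat the filter value as a bool once proven.
    if not isinstance(expected, bool):
        raise TypeError("boolean filter requires a bool value")
    return [v for v in values if v is expected]
```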
6c3302a192 Merge remote-tracking branch 'origin/main' into feat/queue-based-graph-engine 2025-09-05 21:13:07 +08:00
7ba1f0a046 chore: improve typing
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-09-05 20:57:11 +08:00
2adf5d0eee docs: remove outdated document 2025-09-05 02:09:53 +08:00
103a9a4e67 fix(graph_engine): add type hint for workers_to_remove 2025-09-05 01:59:11 +08:00
15b3443e9e fix(debug_logging_layer): remove access for variable pool 2025-09-05 01:52:19 +08:00
81e9d6f63a fix: correct type checking for None values in code node output validation
- Fixed isinstance() checks to properly handle None values by checking None separately
- Fixed typo in STRING type validation where 'output_name' was hardcoded as string instead of variable
- Updated error message format to be consistent and more informative
- Updated test assertion to match new error message format
2025-09-04 20:39:37 +08:00
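A sketch of the corrected validation, with `output_name` interpolated rather than hardcoded; the exact error wording is assumed.

```python
def validate_string_output(output_name: str, value: object) -> None:
    # None is checked separately from the type, so the two failure modes
    # produce distinct, informative errors.
    if value is None:
        raise ValueError(f"Output '{output_name}' must not be None")
    if not isinstance(value, str):
        raise ValueError(
            f"Output '{output_name}' must be a string, got {type(value).__name__}"
        )
```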
9c2943183e test: fix code node
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-09-04 20:17:28 +08:00
f6a2a09815 test: fix code node
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-09-04 20:04:29 +08:00
e229510e73 perf: eliminate lock contention in worker pool by removing callbacks
Remove worker idle/active callbacks that caused severe lock contention.
Instead, use sampling-based monitoring where worker states are queried
on-demand during scaling decisions. This eliminates the performance
bottleneck caused by workers acquiring locks 10+ times per second.

Changes:
- Remove callback parameters from Worker class
- Add properties to expose worker idle state directly
- Update WorkerPool to query worker states without callbacks
- Maintain scaling functionality with better performance
2025-09-04 19:37:31 +08:00
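A rough sketch of the sampling approach with invented names (`idle_seconds`, `scale_down_candidates`); the point is that the pool reads worker state on demand instead of workers pushing callbacks through a shared lock.

```python
import queue
import threading
import time


class Worker(threading.Thread):
    def __init__(self, tasks: "queue.Queue[object]") -> None:
        super().__init__(daemon=True)
        self._tasks = tasks
        self._last_active = time.monotonic()

    @property
    def idle_seconds(self) -> float:
        # State is exposed as a property and sampled on demand, so workers
        # never grab a shared lock just to report that they are idle.
        return time.monotonic() - self._last_active

    def run(self) -> None:
        while True:
            task = self._tasks.get()
            self._last_active = time.monotonic()
            _ = task  # execute the task here


class WorkerPool:
    def __init__(self, workers: list[Worker], idle_limit: float) -> None:
        self._workers = workers
        self._idle_limit = idle_limit

    def scale_down_candidates(self) -> list[Worker]:
        # Sampling-based monitoring: query worker state only at the moment
        # a scaling decision is made.
        return [w for w in self._workers if w.idle_seconds > self._idle_limit]
```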
36048d1526 feat(graph_engine): allow scaling down without a lock
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-09-04 19:32:07 +08:00
aff7ca12b8 fix(code_node): type checking bypass
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-09-04 19:25:08 +08:00
ad9eed2551 fix: disable scaling for performance
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-09-04 19:11:22 +08:00
07109846e0 Merge remote-tracking branch 'origin/main' into feat/queue-based-graph-engine 2025-09-04 17:48:08 +08:00
2aeaefccec test: fix test 2025-09-04 17:47:36 +08:00
4d63bd2083 refactor(graph_engine): rename SimpleWorkerPool to WorkerPool 2025-09-04 17:47:13 +08:00
226f14a20f feat(graph_engine): implement scale down worker
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-09-04 15:35:20 +08:00
2b28aed4e2 [autofix.ci] apply automated fixes 2025-09-04 04:50:21 +00:00
938a845852 Merge remote-tracking branch 'origin/main' into feat/queue-based-graph-engine 2025-09-04 12:48:58 +08:00
ead8568bfc fix: some errors reported by basedpyright
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-09-04 11:58:54 +08:00
ed22d04ea0 test: remove outdated test case 2025-09-04 02:42:36 +08:00
04bbf540d9 chore: code format
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-09-04 02:33:53 +08:00
657c27ec75 feat(graph_engine): make runtime state read-only in layer
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-09-04 02:30:40 +08:00
16e9cd5ac5 feat(graph_runtime_state): prevent setting the variable pool after initialization. 2025-09-04 02:20:19 +08:00
61c79b0013 test: correct imported name 2025-09-04 02:15:46 +08:00
8332472944 refactor(graph_engine): rename Layer to GraphEngineLayer
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-09-04 02:11:31 +08:00
fe3f03e50a feat: add property-based access control to GraphRuntimeState
- Replace direct field access with private attributes and property decorators
- Implement deep copy protection for mutable objects (dict, LLMUsage)
- Add helper methods: set_output(), get_output(), update_outputs()
- Add increment_node_run_steps() and add_tokens() convenience methods
- Update loop_node and event_handlers to use new accessor methods
- Add comprehensive unit tests for immutability and validation
- Ensure backward compatibility with existing property access patterns
2025-09-04 02:08:58 +08:00
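A condensed sketch of the accessor pattern; only `set_output`, `get_output`, `update_outputs`, and the deep-copy protection come from the commit message, the rest is illustrative.

```python
from copy import deepcopy


class GraphRuntimeState:
    def __init__(self) -> None:
        self._outputs: dict[str, object] = {}

    @property
    def outputs(self) -> dict[str, object]:
        # Deep-copy on read so callers cannot mutate internal state.
        return deepcopy(self._outputs)

    def set_output(self, key: str, value: object) -> None:
        self._outputs[key] = value

    def get_output(self, key: str) -> object | None:
        return deepcopy(self._outputs.get(key))

    def update_outputs(self, updates: dict[str, object]) -> None:
        self._outputs.update(deepcopy(updates))
```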
9c96b23d55 Merge remote-tracking branch 'origin/main' into feat/queue-based-graph-engine 2025-09-04 00:27:08 +08:00
8c97937cae Merge remote-tracking branch 'origin/main' into feat/queue-based-graph-engine 2025-09-03 13:53:43 +08:00
f6acff4cce chore: remove unused variables
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-09-03 12:12:27 +08:00
3fa48cb5cf chore: remove ty-check from Python style check.
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-09-03 12:05:41 +08:00
b81745aed8 Merge remote-tracking branch 'origin/main' into feat/queue-based-graph-engine 2025-09-03 11:56:05 +08:00
8c41d95d03 Merge remote-tracking branch 'origin/main' into feat/queue-based-graph-engine 2025-09-03 11:06:42 +08:00
9d004a0971 test: fix test
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-09-03 02:11:37 +08:00
02fcd08c08 [autofix.ci] apply automated fixes 2025-09-02 17:34:07 +00:00
77a9a73d0d Merge remote-tracking branch 'origin/main' into feat/queue-based-graph-engine 2025-09-03 01:33:17 +08:00
1770b93e5b chore(graph_engine): Add a TODO comment in _update_response_outputs in event_handlers
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-09-02 15:20:03 +08:00
d8ff4aa9ba feat(graph_engine): Handle NodeRunAgentLogEvent
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-09-02 15:02:07 +08:00
9f8f21bf87 chore: remove backup files
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-09-02 15:01:58 +08:00
0b0dc63f29 Merge remote-tracking branch 'origin/main' into feat/queue-based-graph-engine 2025-09-02 11:52:25 +08:00
8433cf4437 refactor(graph_engine): Merge event_collector and event_emitter into event_manager
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-09-01 13:15:58 +08:00
bb5d52539c refactor(graph_engine): Merge branch_handler into edge_processor
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-09-01 12:53:06 +08:00
88622f70fb refactor(graph_engine): Move setup methods into __init__
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-09-01 12:08:03 +08:00
0fdb1b2bc9 refactor(graph_engine): Correct private attributes and private methods naming
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-09-01 04:37:23 +08:00
a5cb9d2b73 refactor(graph_engine): inline output_registry into response_coordinator
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-09-01 03:59:53 +08:00
64c1234724 refactor(graph_engine): Merge worker management into one WorkerPool
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-09-01 03:23:47 +08:00
202fdfcb81 refactor(graph_engine): Remove backward compatibility code
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-09-01 02:41:16 +08:00
e2f4c9ba8d refactor(graph_engine): Merge state managers into unified_state_manager
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-09-01 02:08:08 +08:00
546d75d84d Merge remote-tracking branch 'origin/main' into feat/queue-based-graph-engine 2025-09-01 00:29:28 +08:00
a8fe4ea802 Merge remote-tracking branch 'origin/main' into feat/queue-based-graph-engine 2025-08-30 16:36:10 +08:00
82193580de chore: improve typing
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-08-30 16:35:57 +08:00
1fd27cf3ad Merge remote-tracking branch 'origin/main' into feat/queue-based-graph-engine 2025-08-30 00:13:45 +08:00
11d32ca87d test: fix web test
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-08-29 23:20:28 +08:00
5415d0c6d1 Merge remote-tracking branch 'origin/main' into feat/queue-based-graph-engine 2025-08-29 23:17:30 +08:00
d8af8ae4e6 fix: update workflow service tests for new graph engine
- Update method calls from _handle_node_run_result to _handle_single_step_result
- Add required fields (id, node_id, node_type, start_at) to graph events
- Use proper NodeType enum values instead of strings
- Fix imports to use correct modules (Node instead of BaseNode)
- Ensure event generators return proper generator objects

These tests were failing because the internal implementation changed
with the new graph engine architecture.
2025-08-29 23:04:33 +08:00
04e5d4692f Merge remote-tracking branch 'origin/main' into feat/queue-based-graph-engine 2025-08-29 22:34:47 +08:00
3aa48efd0a test(test_workflow_service): Use new engine's method.
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-08-29 22:06:10 +08:00
8eb78c04b2 chore(token_buffer_memory): code format
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-08-29 17:02:51 +08:00
22ee318cf8 Merge remote-tracking branch 'origin/main' into feat/queue-based-graph-engine 2025-08-29 17:01:42 +08:00
f2bc4f5d87 fix: resolve type error in node_factory by using type guard for node_type_str 2025-08-29 16:16:58 +08:00
d7d456349d Merge remote-tracking branch 'origin/main' into feat/queue-based-graph-engine 2025-08-29 16:14:04 +08:00
dce4d0ff80 Merge remote-tracking branch 'origin/main' into feat/queue-based-graph-engine 2025-08-29 13:22:13 +08:00
3dee8064ba feat: enhance typing 2025-08-29 13:17:02 +08:00
bfbb36756a feat(graph_engine): Add NodeExecutionType.ROOT and auto mark skipped in Graph.init
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-08-28 16:41:51 +08:00
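Read literally, "auto mark skipped" suggests that Graph initialization walks the graph from the root and marks unreachable nodes as skipped. That reading is an inference from the commit subject; a sketch of the idea under that assumption:

```python
from collections import deque


def unreachable_nodes(root: str, nodes: set[str], edges: dict[str, list[str]]) -> set[str]:
    # BFS from the root; anything never reached would be marked skipped.
    reached = {root}
    frontier = deque([root])
    while frontier:
        current = frontier.popleft()
        for nxt in edges.get(current, []):
            if nxt not in reached:
                reached.add(nxt)
                frontier.append(nxt)
    return nodes - reached
```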
d7e0c5f759 chore: use 'XXX | None' instead of Optional[XXX] in graph.py 2025-08-28 15:45:22 +08:00
c396788128 chore(graph_engine): add final mark to classes
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-08-28 15:38:35 +08:00
e3a7b1f691 fix: type hints
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-08-28 05:24:18 +08:00
8aab7f49c3 chore(graph_engine): Use XXX | None instead of Optional[XXX] 2025-08-28 05:09:33 +08:00
1e12c1cbf2 [autofix.ci] apply automated fixes 2025-08-27 21:00:36 +00:00
affedd6ce4 chore(graph_engine): Use XXX | None instead of Optional[XXX] 2025-08-28 04:59:49 +08:00
ef21097774 refactor(graph_engine): Remove unnecessary check from SkipPropagator
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-08-28 04:45:26 +08:00
1d377fe994 refactor(graph_engine): Use _ to mark unused variable in BranchHandler
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-08-28 04:44:45 +08:00
c82697f267 refactor(graph_engine): Remove node_id from SkipPropagator.skip_branch_paths
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-08-28 04:43:56 +08:00
98b25c0bbc refactor(graph_engine): Convert attrs to private in error_handler
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-08-28 04:42:37 +08:00
1cd0792606 chore(graph_events): Improve type hints
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-08-28 04:41:48 +08:00
7cbf4093f4 chore(graph_engine): Use TYPE | None instead of Optional
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-08-28 04:30:50 +08:00
8129ca7c05 chore(graph_engine): Move error_strategy.py to protocols/
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-08-28 04:29:32 +08:00
65617f000d feat(event_collector): Update to use ReadWriteLock 2025-08-28 03:26:42 +08:00
635eff2e25 test(graph_engine): remove outdated tests
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-08-28 02:53:19 +08:00
55085a9ca2 chore(graph_engine): add type hint for event_queue
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-08-28 02:38:56 +08:00
9dc1e9724e Merge remote-tracking branch 'origin/main' into feat/queue-based-graph-engine 2025-08-28 02:26:40 +08:00
c3f66e2901 Merge remote-tracking branch 'origin/main' into feat/queue-based-graph-engine 2025-08-27 18:05:35 +08:00
86e7cb713c [autofix.ci] apply automated fixes 2025-08-27 07:38:26 +00:00
0f29244459 fix: test
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-08-27 15:37:37 +08:00
48cbf4c78f [autofix.ci] apply automated fixes 2025-08-27 15:33:30 +08:00
8c35663220 feat: queue-based graph engine
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-08-27 15:33:28 +08:00
368 changed files with 22582 additions and 11962 deletions

View File

@ -12,7 +12,6 @@ permissions:
statuses: write
contents: read
jobs:
python-style:
name: Python Style
@ -44,6 +43,10 @@ jobs:
if: steps.changed-files.outputs.any_changed == 'true'
run: uv sync --project api --dev
- name: Run Import Linter
if: steps.changed-files.outputs.any_changed == 'true'
run: uv run --directory api --dev lint-imports
- name: Run Basedpyright Checks
if: steps.changed-files.outputs.any_changed == 'true'
run: dev/basedpyright-check

View File

@ -461,6 +461,16 @@ WORKFLOW_CALL_MAX_DEPTH=5
WORKFLOW_PARALLEL_DEPTH_LIMIT=3
MAX_VARIABLE_SIZE=204800
# GraphEngine Worker Pool Configuration
# Minimum number of workers per GraphEngine instance (default: 1)
GRAPH_ENGINE_MIN_WORKERS=1
# Maximum number of workers per GraphEngine instance (default: 10)
GRAPH_ENGINE_MAX_WORKERS=10
# Queue depth threshold that triggers worker scale up (default: 3)
GRAPH_ENGINE_SCALE_UP_THRESHOLD=3
# Seconds of idle time before scaling down workers (default: 5.0)
GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME=5.0
# Workflow storage configuration
# Options: rdbms, hybrid
# rdbms: Use only the relational database (default)

api/.importlinter Normal file (105 lines)
View File

@ -0,0 +1,105 @@
[importlinter]
root_packages =
    core
    configs
    controllers
    models
    tasks
    services

[importlinter:contract:workflow]
name = Workflow
type = layers
layers =
    graph_engine
    graph_events
    graph
    nodes
    node_events
    entities
containers =
    core.workflow
ignore_imports =
    core.workflow.nodes.base.node -> core.workflow.graph_events
    core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_events
    core.workflow.nodes.loop.loop_node -> core.workflow.graph_events
    core.workflow.nodes.node_factory -> core.workflow.graph
    core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine
    core.workflow.nodes.iteration.iteration_node -> core.workflow.graph
    core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine.command_channels
    core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine
    core.workflow.nodes.loop.loop_node -> core.workflow.graph
    core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine.command_channels

[importlinter:contract:rsc]
name = RSC
type = layers
layers =
    graph_engine
    response_coordinator
containers =
    core.workflow.graph_engine

[importlinter:contract:worker]
name = Worker
type = layers
layers =
    graph_engine
    worker
containers =
    core.workflow.graph_engine

[importlinter:contract:graph-engine-architecture]
name = Graph Engine Architecture
type = layers
layers =
    graph_engine
    orchestration
    command_processing
    event_management
    error_handler
    graph_traversal
    graph_state_manager
    worker_management
    domain
containers =
    core.workflow.graph_engine

[importlinter:contract:domain-isolation]
name = Domain Model Isolation
type = forbidden
source_modules =
    core.workflow.graph_engine.domain
forbidden_modules =
    core.workflow.graph_engine.worker_management
    core.workflow.graph_engine.command_channels
    core.workflow.graph_engine.layers
    core.workflow.graph_engine.protocols

[importlinter:contract:worker-management]
name = Worker Management
type = forbidden
source_modules =
    core.workflow.graph_engine.worker_management
forbidden_modules =
    core.workflow.graph_engine.orchestration
    core.workflow.graph_engine.command_processing
    core.workflow.graph_engine.event_management

[importlinter:contract:graph-traversal-components]
name = Graph Traversal Components
type = layers
layers =
    edge_processor
    skip_propagator
containers =
    core.workflow.graph_engine.graph_traversal

[importlinter:contract:command-channels]
name = Command Channels Independence
type = independence
modules =
    core.workflow.graph_engine.command_channels.in_memory_channel
    core.workflow.graph_engine.command_channels.redis_channel

View File

@ -14,7 +14,6 @@ from sqlalchemy.exc import SQLAlchemyError
from configs import dify_config
from constants.languages import languages
from core.plugin.entities.plugin import ToolProviderID
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.index_processor.constant.built_in_field import BuiltInField
@ -32,6 +31,7 @@ from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, D
from models.dataset import Document as DatasetDocument
from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation
from models.provider import Provider, ProviderModel
from models.provider_ids import ToolProviderID
from models.tools import ToolOAuthSystemClient
from services.account_service import AccountService, RegisterService, TenantService
from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpiredLogs

View File

@ -535,6 +535,28 @@ class WorkflowConfig(BaseSettings):
        default=200 * 1024,
    )

    # GraphEngine Worker Pool Configuration
    GRAPH_ENGINE_MIN_WORKERS: PositiveInt = Field(
        description="Minimum number of workers per GraphEngine instance",
        default=1,
    )
    GRAPH_ENGINE_MAX_WORKERS: PositiveInt = Field(
        description="Maximum number of workers per GraphEngine instance",
        default=10,
    )
    GRAPH_ENGINE_SCALE_UP_THRESHOLD: PositiveInt = Field(
        description="Queue depth threshold that triggers worker scale up",
        default=3,
    )
    GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME: float = Field(
        description="Seconds of idle time before scaling down workers",
        default=5.0,
        ge=0.1,
    )
class WorkflowNodeExecutionConfig(BaseSettings):
"""

View File

@ -29,7 +29,7 @@ def no_key_cache_key(namespace: str, key: str) -> str:
# Returns whether the obtained value is obtained, and None if it does not
def get_value_from_dict(namespace_cache: dict[str, Any] | None, key: str) -> Any | None:
def get_value_from_dict(namespace_cache: dict[str, Any] | None, key: str) -> Any:
if namespace_cache:
kv_data = namespace_cache.get(CONFIGURATIONS)
if kv_data is None:

View File

@ -16,7 +16,10 @@ from core.helper.code_executor.javascript.javascript_code_provider import Javasc
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
from core.llm_generator.llm_generator import LLMGenerator
from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db
from libs.login import login_required
from models import App
from services.workflow_service import WorkflowService
@console_ns.route("/rule-generate")
@ -205,9 +208,6 @@ class InstructionGenerateApi(Resource):
try:
# Generate from nothing for a workflow node
if (args["current"] == code_template or args["current"] == "") and args["node_id"] != "":
from models import App, db
from services.workflow_service import WorkflowService
app = db.session.query(App).where(App.id == args["flow_id"]).first()
if not app:
return {"error": f"app {args['flow_id']} not found"}, 400

View File

@ -20,6 +20,7 @@ from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File
from core.helper.trace_id_helper import get_external_trace_id
from core.workflow.graph_engine.manager import GraphEngineManager
from extensions.ext_database import db
from factories import file_factory, variable_factory
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
@ -536,7 +537,12 @@ class WorkflowTaskStopApi(Resource):
if not current_user.has_edit_permission:
raise Forbidden()
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
# Stop using both mechanisms for backward compatibility
# Legacy stop flag mechanism (without user check)
AppQueueManager.set_stop_flag_no_user_check(task_id)
# New graph engine command channel mechanism
GraphEngineManager.send_stop_command(task_id)
return {"result": "success"}
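This dual-stop pattern recurs in several controllers below. A minimal sketch of what a Redis-backed command channel could look like, assuming a JSON-encoded command list under the `workflow:{task_id}:commands` key shown in the runner changes; the real `RedisChannel` and `GraphEngineManager` APIs may differ.

```python
import json


class RedisChannel:
    def __init__(self, redis_client, key: str) -> None:
        self._redis = redis_client
        self._key = key

    def send_command(self, command: dict[str, str]) -> None:
        # Producers (e.g. a stop endpoint) push commands onto a Redis list.
        self._redis.rpush(self._key, json.dumps(command))

    def fetch_commands(self) -> list[dict[str, str]]:
        # The engine drains pending commands between scheduling steps.
        commands = []
        while (raw := self._redis.lpop(self._key)) is not None:
            commands.append(json.loads(raw))
        return commands


# A stop request might then reduce to something like:
# RedisChannel(redis_client, f"workflow:{task_id}:commands").send_command({"type": "stop"})
```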

View File

@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
from controllers.console import api, console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus
from core.workflow.enums import WorkflowExecutionStatus
from extensions.ext_database import db
from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
from libs.login import login_required

View File

@ -17,10 +17,11 @@ from core.variables.segment_group import SegmentGroup
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.variables.types import SegmentType
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from extensions.ext_database import db
from factories.file_factory import build_from_mapping, build_from_mappings
from factories.variable_factory import build_segment_with_type
from libs.login import current_user, login_required
from models import App, AppMode, db
from models import App, AppMode
from models.account import Account
from models.workflow import WorkflowDraftVariable
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService

View File

@ -20,7 +20,6 @@ from controllers.console.wraps import (
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.indexing_runner import IndexingRunner
from core.model_runtime.entities.model_entities import ModelType
from core.plugin.entities.plugin import ModelProviderID
from core.provider_manager import ProviderManager
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.extractor.entity.datasource_type import DatasourceType
@ -33,6 +32,7 @@ from fields.document_fields import document_status_fields
from libs.login import login_required
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
from models.dataset import DatasetPermissionEnum
from models.provider_ids import ModelProviderID
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService

View File

@ -20,6 +20,7 @@ from core.errors.error import (
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from core.workflow.graph_engine.manager import GraphEngineManager
from libs import helper
from libs.login import current_user
from models.model import AppMode, InstalledApp
@ -82,6 +83,11 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource):
raise NotWorkflowAppError()
assert current_user is not None
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
# Stop using both mechanisms for backward compatibility
# Legacy stop flag mechanism (without user check)
AppQueueManager.set_stop_flag_no_user_check(task_id)
# New graph engine command channel mechanism
GraphEngineManager.send_stop_command(task_id)
return {"result": "success"}

View File

@ -21,11 +21,11 @@ from core.mcp.auth.auth_provider import OAuthClientProvider
from core.mcp.error import MCPAuthError, MCPError
from core.mcp.mcp_client import MCPClient
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin import ToolProviderID
from core.plugin.impl.oauth import OAuthHandler
from core.tools.entities.tool_entities import CredentialType
from libs.helper import StrLen, alphanumeric, uuid_value
from libs.login import login_required
from models.provider_ids import ToolProviderID
from services.plugin.oauth_service import OAuthProxyService
from services.tools.api_tools_manage_service import ApiToolManageService
from services.tools.builtin_tools_manage_service import BuiltinToolManageService

View File

@ -8,7 +8,7 @@ from controllers.common.errors import UnsupportedFileTypeError
from controllers.files import files_ns
from core.tools.signature import verify_tool_file_signature
from core.tools.tool_file_manager import ToolFileManager
from models import db as global_db
from extensions.ext_database import db as global_db
@files_ns.route("/tools/<uuid:file_id>.<string:extension>")

View File

@ -26,7 +26,8 @@ from core.errors.error import (
)
from core.helper.trace_id_helper import get_external_trace_id
from core.model_runtime.errors.invoke import InvokeError
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus
from core.workflow.enums import WorkflowExecutionStatus
from core.workflow.graph_engine.manager import GraphEngineManager
from extensions.ext_database import db
from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model
from libs import helper
@ -262,7 +263,12 @@ class WorkflowTaskStopApi(Resource):
if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError()
AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
# Stop using both mechanisms for backward compatibility
# Legacy stop flag mechanism (without user check)
AppQueueManager.set_stop_flag_no_user_check(task_id)
# New graph engine command channel mechanism
GraphEngineManager.send_stop_command(task_id)
return {"result": "success"}

View File

@ -13,13 +13,13 @@ from controllers.service_api.wraps import (
validate_dataset_token,
)
from core.model_runtime.entities.model_entities import ModelType
from core.plugin.entities.plugin import ModelProviderID
from core.provider_manager import ProviderManager
from fields.dataset_fields import dataset_detail_fields
from fields.tag_fields import build_dataset_tag_fields
from libs.login import current_user
from models.account import Account
from models.dataset import Dataset, DatasetPermissionEnum
from models.provider_ids import ModelProviderID
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
from services.tag_service import TagService

View File

@ -21,6 +21,7 @@ from core.errors.error import (
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from core.workflow.graph_engine.manager import GraphEngineManager
from libs import helper
from models.model import App, AppMode, EndUser
from services.app_generate_service import AppGenerateService
@ -112,6 +113,11 @@ class WorkflowTaskStopApi(WebApiResource):
if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError()
AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)
# Stop using both mechanisms for backward compatibility
# Legacy stop flag mechanism (without user check)
AppQueueManager.set_stop_flag_no_user_check(task_id)
# New graph engine command channel mechanism
GraphEngineManager.send_stop_command(task_id)
return {"result": "success"}

View File

@ -4,8 +4,8 @@ from typing import Any
from core.app.app_config.entities import ModelConfigEntity
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.plugin.entities.plugin import ModelProviderID
from core.provider_manager import ProviderManager
from models.provider_ids import ModelProviderID
class ModelConfigManager:

View File

@ -1,11 +1,11 @@
import logging
import time
from collections.abc import Mapping
from typing import Any, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from configs import dify_config
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
@ -23,16 +23,17 @@ from core.app.features.annotation_reply.annotation_reply import AnnotationReplyF
from core.moderation.base import ModerationError
from core.moderation.input_moderation import InputModeration
from core.variables.variables import VariableUnion
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities import GraphRuntimeState, VariablePool
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models import Workflow
from models.enums import UserFrom
from models.model import App, Conversation, Message, MessageAnnotation
from models.workflow import ConversationVariable, WorkflowType
from models.workflow import ConversationVariable
logger = logging.getLogger(__name__)
@ -78,23 +79,29 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
if not app_record:
raise ValueError("App not found")
workflow_callbacks: list[WorkflowCallback] = []
if dify_config.DEBUG:
workflow_callbacks.append(WorkflowLoggingCallback())
if self.application_generate_entity.single_iteration_run:
# if only single iteration run is requested
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool.empty(),
start_at=time.time(),
)
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=self._workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=dict(self.application_generate_entity.single_iteration_run.inputs),
graph_runtime_state=graph_runtime_state,
)
elif self.application_generate_entity.single_loop_run:
# if only single loop run is requested
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool.empty(),
start_at=time.time(),
)
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
workflow=self._workflow,
node_id=self.application_generate_entity.single_loop_run.node_id,
user_inputs=dict(self.application_generate_entity.single_loop_run.inputs),
graph_runtime_state=graph_runtime_state,
)
else:
inputs = self.application_generate_entity.inputs
@ -146,16 +153,27 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
)
# init graph
graph = self._init_graph(graph_config=self._workflow.graph_dict)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.time())
graph = self._init_graph(
graph_config=self._workflow.graph_dict,
graph_runtime_state=graph_runtime_state,
workflow_id=self._workflow.id,
tenant_id=self._workflow.tenant_id,
user_id=self.application_generate_entity.user_id,
)
db.session.close()
# RUN WORKFLOW
# Create Redis command channel for this workflow execution
task_id = self.application_generate_entity.task_id
channel_key = f"workflow:{task_id}:commands"
command_channel = RedisChannel(redis_client, channel_key)
workflow_entry = WorkflowEntry(
tenant_id=self._workflow.tenant_id,
app_id=self._workflow.app_id,
workflow_id=self._workflow.id,
workflow_type=WorkflowType.value_of(self._workflow.type),
graph=graph,
graph_config=self._workflow.graph_dict,
user_id=self.application_generate_entity.user_id,
@ -167,11 +185,11 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
invoke_from=self.application_generate_entity.invoke_from,
call_depth=self.application_generate_entity.call_depth,
variable_pool=variable_pool,
graph_runtime_state=graph_runtime_state,
command_channel=command_channel,
)
generator = workflow_entry.run(
callbacks=workflow_callbacks,
)
generator = workflow_entry.run()
for event in generator:
self._handle_event(workflow_entry, event)

View File

@ -31,14 +31,9 @@ from core.app.entities.queue_entities import (
QueueMessageReplaceEvent,
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeInLoopFailedEvent,
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
QueueParallelBranchRunStartedEvent,
QueueParallelBranchRunSucceededEvent,
QueuePingEvent,
QueueRetrieverResourcesEvent,
QueueStopEvent,
@ -65,8 +60,8 @@ from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.model_runtime.entities.llm_entities import LLMUsage
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus, WorkflowType
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.entities import GraphRuntimeState
from core.workflow.enums import WorkflowExecutionStatus, WorkflowType
from core.workflow.nodes import NodeType
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
@ -387,9 +382,7 @@ class AdvancedChatAppGenerateTaskPipeline:
def _handle_node_failed_events(
self,
event: Union[
QueueNodeFailedEvent, QueueNodeInIterationFailedEvent, QueueNodeInLoopFailedEvent, QueueNodeExceptionEvent
],
event: Union[QueueNodeFailedEvent, QueueNodeExceptionEvent],
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle various node failure events."""
@ -434,32 +427,6 @@ class AdvancedChatAppGenerateTaskPipeline:
answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
)
def _handle_parallel_branch_started_event(
self, event: QueueParallelBranchRunStartedEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle parallel branch started events."""
self._ensure_workflow_initialized()
parallel_start_resp = self._workflow_response_converter.workflow_parallel_branch_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield parallel_start_resp
def _handle_parallel_branch_finished_events(
self, event: Union[QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent], **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle parallel branch finished events."""
self._ensure_workflow_initialized()
parallel_finish_resp = self._workflow_response_converter.workflow_parallel_branch_finished_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield parallel_finish_resp
def _handle_iteration_start_event(
self, event: QueueIterationStartEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
@ -751,8 +718,6 @@ class AdvancedChatAppGenerateTaskPipeline:
QueueNodeRetryEvent: self._handle_node_retry_event,
QueueNodeStartedEvent: self._handle_node_started_event,
QueueNodeSucceededEvent: self._handle_node_succeeded_event,
# Parallel branch events
QueueParallelBranchRunStartedEvent: self._handle_parallel_branch_started_event,
# Iteration events
QueueIterationStartEvent: self._handle_iteration_start_event,
QueueIterationNextEvent: self._handle_iteration_next_event,
@ -800,8 +765,6 @@ class AdvancedChatAppGenerateTaskPipeline:
event,
(
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeInLoopFailedEvent,
QueueNodeExceptionEvent,
),
):
@ -814,17 +777,6 @@ class AdvancedChatAppGenerateTaskPipeline:
)
return
# Handle parallel branch finished events with isinstance check
if isinstance(event, (QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent)):
yield from self._handle_parallel_branch_finished_events(
event,
graph_runtime_state=graph_runtime_state,
tts_publisher=tts_publisher,
trace_manager=trace_manager,
queue_message=queue_message,
)
return
# For unhandled events, we continue (original behavior)
return
@ -848,11 +800,6 @@ class AdvancedChatAppGenerateTaskPipeline:
graph_runtime_state = event.graph_runtime_state
yield from self._handle_workflow_started_event(event)
case QueueTextChunkEvent():
yield from self._handle_text_chunk_event(
event, tts_publisher=tts_publisher, queue_message=queue_message
)
case QueueErrorEvent():
yield from self._handle_error_event(event)
break

View File

@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
from core.app.app_config.entities import VariableEntityType
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file import File, FileUploadConfig
from core.workflow.nodes.enums import NodeType
from core.workflow.enums import NodeType
from core.workflow.repositories.draft_variable_repository import (
DraftVariableSaver,
DraftVariableSaverFactory,

View File

@ -127,6 +127,21 @@ class AppQueueManager:
        stopped_cache_key = cls._generate_stopped_cache_key(task_id)
        redis_client.setex(stopped_cache_key, 600, 1)

    @classmethod
    def set_stop_flag_no_user_check(cls, task_id: str) -> None:
        """
        Set task stop flag without user permission check.

        This method allows stopping workflows without user context.

        :param task_id: The task ID to stop
        :return:
        """
        if not task_id:
            return

        stopped_cache_key = cls._generate_stopped_cache_key(task_id)
        redis_client.setex(stopped_cache_key, 600, 1)

    def _is_stopped(self) -> bool:
        """
        Check if task is stopped

View File

@ -1,7 +1,7 @@
import time
from collections.abc import Mapping, Sequence
from datetime import UTC, datetime
from typing import Any, Union, cast
from typing import Any, Union
from sqlalchemy.orm import Session
@ -16,14 +16,9 @@ from core.app.entities.queue_entities import (
QueueLoopStartEvent,
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeInLoopFailedEvent,
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
QueueParallelBranchRunStartedEvent,
QueueParallelBranchRunSucceededEvent,
)
from core.app.entities.task_entities import (
AgentLogStreamResponse,
@ -36,18 +31,16 @@ from core.app.entities.task_entities import (
NodeFinishStreamResponse,
NodeRetryStreamResponse,
NodeStartStreamResponse,
ParallelBranchFinishedStreamResponse,
ParallelBranchStartStreamResponse,
WorkflowFinishStreamResponse,
WorkflowStartStreamResponse,
)
from core.file import FILE_MODEL_IDENTITY, File
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.tool_manager import ToolManager
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.workflow.entities.workflow_execution import WorkflowExecution
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus
from core.workflow.entities import WorkflowExecution, WorkflowNodeExecution
from core.workflow.enums import WorkflowNodeExecutionStatus
from core.workflow.nodes import NodeType
from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from libs.datetime_utils import naive_utc_now
from models import (
@ -171,11 +164,10 @@ class WorkflowResponseConverter:
# extras logic
if event.node_type == NodeType.TOOL:
node_data = cast(ToolNodeData, event.node_data)
response.data.extras["icon"] = ToolManager.get_tool_icon(
tenant_id=self._application_generate_entity.app_config.tenant_id,
provider_type=node_data.provider_type,
provider_id=node_data.provider_id,
provider_type=ToolProviderType(event.provider_type),
provider_id=event.provider_id,
)
return response
@ -183,11 +175,7 @@ class WorkflowResponseConverter:
def workflow_node_finish_to_stream_response(
self,
*,
event: QueueNodeSucceededEvent
| QueueNodeFailedEvent
| QueueNodeInIterationFailedEvent
| QueueNodeInLoopFailedEvent
| QueueNodeExceptionEvent,
event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeExceptionEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution,
) -> NodeFinishStreamResponse | None:
@ -221,9 +209,6 @@ class WorkflowResponseConverter:
finished_at=int(workflow_node_execution.finished_at.timestamp()),
files=self.fetch_files_from_node_outputs(workflow_node_execution.outputs or {}),
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
iteration_id=event.in_iteration_id,
loop_id=event.in_loop_id,
),
@ -275,50 +260,6 @@ class WorkflowResponseConverter:
),
)
def workflow_parallel_branch_start_to_stream_response(
self,
*,
task_id: str,
workflow_execution_id: str,
event: QueueParallelBranchRunStartedEvent,
) -> ParallelBranchStartStreamResponse:
return ParallelBranchStartStreamResponse(
task_id=task_id,
workflow_run_id=workflow_execution_id,
data=ParallelBranchStartStreamResponse.Data(
parallel_id=event.parallel_id,
parallel_branch_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
iteration_id=event.in_iteration_id,
loop_id=event.in_loop_id,
created_at=int(time.time()),
),
)
def workflow_parallel_branch_finished_to_stream_response(
self,
*,
task_id: str,
workflow_execution_id: str,
event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent,
) -> ParallelBranchFinishedStreamResponse:
return ParallelBranchFinishedStreamResponse(
task_id=task_id,
workflow_run_id=workflow_execution_id,
data=ParallelBranchFinishedStreamResponse.Data(
parallel_id=event.parallel_id,
parallel_branch_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
iteration_id=event.in_iteration_id,
loop_id=event.in_loop_id,
status="succeeded" if isinstance(event, QueueParallelBranchRunSucceededEvent) else "failed",
error=event.error if isinstance(event, QueueParallelBranchRunFailedEvent) else None,
created_at=int(time.time()),
),
)
def workflow_iteration_start_to_stream_response(
self,
*,
@ -333,13 +274,11 @@ class WorkflowResponseConverter:
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_data.title,
title=event.node_title,
created_at=int(time.time()),
extras={},
inputs=event.inputs or {},
metadata=event.metadata or {},
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
),
)
@ -357,15 +296,10 @@ class WorkflowResponseConverter:
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_data.title,
title=event.node_title,
index=event.index,
pre_iteration_output=event.output,
created_at=int(time.time()),
extras={},
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parallel_mode_run_id=event.parallel_mode_run_id,
duration=event.duration,
),
)
@ -384,8 +318,8 @@ class WorkflowResponseConverter:
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_data.title,
outputs=json_converter.to_json_encodable(event.outputs),
title=event.node_title,
outputs=json_converter.to_json_encodable(event.outputs) or {},
created_at=int(time.time()),
extras={},
inputs=event.inputs or {},
@ -394,12 +328,10 @@ class WorkflowResponseConverter:
else WorkflowNodeExecutionStatus.FAILED,
error=None,
elapsed_time=(naive_utc_now() - event.start_at).total_seconds(),
total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0,
total_tokens=(lambda x: x if isinstance(x, int) else 0)(event.metadata.get("total_tokens", 0)),
execution_metadata=event.metadata,
finished_at=int(time.time()),
steps=event.steps,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
),
)
@ -413,7 +345,7 @@ class WorkflowResponseConverter:
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_data.title,
title=event.node_title,
created_at=int(time.time()),
extras={},
inputs=event.inputs or {},
@ -437,7 +369,7 @@ class WorkflowResponseConverter:
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_data.title,
title=event.node_title,
index=event.index,
pre_loop_output=event.output,
created_at=int(time.time()),
@ -445,7 +377,6 @@ class WorkflowResponseConverter:
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parallel_mode_run_id=event.parallel_mode_run_id,
duration=event.duration,
),
)
@ -463,8 +394,8 @@ class WorkflowResponseConverter:
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_data.title,
outputs=WorkflowRuntimeTypeConverter().to_json_encodable(event.outputs),
title=event.node_title,
outputs=WorkflowRuntimeTypeConverter().to_json_encodable(event.outputs) or {},
created_at=int(time.time()),
extras={},
inputs=event.inputs or {},
@ -473,7 +404,7 @@ class WorkflowResponseConverter:
else WorkflowNodeExecutionStatus.FAILED,
error=None,
elapsed_time=(naive_utc_now() - event.start_at).total_seconds(),
total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0,
total_tokens=(lambda x: x if isinstance(x, int) else 0)(event.metadata.get("total_tokens", 0)),
execution_metadata=event.metadata,
finished_at=int(time.time()),
steps=event.steps,

View File

@ -53,7 +53,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
streaming: Literal[True],
call_depth: int,
workflow_thread_pool_id: str | None,
) -> Generator[Mapping | str, None, None]: ...
@overload
@ -67,7 +66,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
streaming: Literal[False],
call_depth: int,
workflow_thread_pool_id: str | None,
) -> Mapping[str, Any]: ...
@overload
@ -81,7 +79,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
streaming: bool,
call_depth: int,
workflow_thread_pool_id: str | None,
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...
def generate(
@ -94,7 +91,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
streaming: bool = True,
call_depth: int = 0,
workflow_thread_pool_id: str | None = None,
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]:
files: Sequence[Mapping[str, Any]] = args.get("files") or []
@ -186,7 +182,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
workflow_thread_pool_id=workflow_thread_pool_id,
)
def _generate(
@ -200,7 +195,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
streaming: bool = True,
workflow_thread_pool_id: str | None = None,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
"""
@ -214,7 +208,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
:param workflow_execution_repository: repository for workflow execution
:param workflow_node_execution_repository: repository for workflow node execution
:param streaming: is stream
:param workflow_thread_pool_id: workflow thread pool id
"""
# init queue manager
queue_manager = WorkflowAppQueueManager(
@ -237,7 +230,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"context": context,
"workflow_thread_pool_id": workflow_thread_pool_id,
"variable_loader": variable_loader,
},
)
@ -434,17 +426,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
queue_manager: AppQueueManager,
context: contextvars.Context,
variable_loader: VariableLoader,
workflow_thread_pool_id: str | None = None,
):
"""
Generate worker in a new thread.
:param flask_app: Flask app
:param application_generate_entity: application generate entity
:param queue_manager: queue manager
:param workflow_thread_pool_id: workflow thread pool id
:return:
"""
) -> None:
with preserve_flask_contexts(flask_app, context_vars=context):
with Session(db.engine, expire_on_commit=False) as session:
workflow = session.scalar(
@ -474,7 +456,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
runner = WorkflowAppRunner(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
workflow_thread_pool_id=workflow_thread_pool_id,
variable_loader=variable_loader,
workflow=workflow,
system_user_id=system_user_id,

View File

@ -1,7 +1,7 @@
import logging
import time
from typing import cast
from configs import dify_config
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
@ -9,13 +9,14 @@ from core.app.entities.app_invoke_entities import (
InvokeFrom,
WorkflowAppGenerateEntity,
)
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities import GraphRuntimeState, VariablePool
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_redis import redis_client
from models.enums import UserFrom
from models.workflow import Workflow, WorkflowType
from models.workflow import Workflow
logger = logging.getLogger(__name__)
@ -31,7 +32,6 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
application_generate_entity: WorkflowAppGenerateEntity,
queue_manager: AppQueueManager,
variable_loader: VariableLoader,
workflow_thread_pool_id: str | None = None,
workflow: Workflow,
system_user_id: str,
):
@ -41,7 +41,6 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
app_id=application_generate_entity.app_config.app_id,
)
self.application_generate_entity = application_generate_entity
self.workflow_thread_pool_id = workflow_thread_pool_id
self._workflow = workflow
self._sys_user_id = system_user_id
@ -52,24 +51,30 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
app_config = self.application_generate_entity.app_config
app_config = cast(WorkflowAppConfig, app_config)
workflow_callbacks: list[WorkflowCallback] = []
if dify_config.DEBUG:
workflow_callbacks.append(WorkflowLoggingCallback())
# if only single iteration run is requested
if self.application_generate_entity.single_iteration_run:
# if only single iteration run is requested
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool.empty(),
start_at=time.time(),
)
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=self._workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
graph_runtime_state=graph_runtime_state,
)
elif self.application_generate_entity.single_loop_run:
# if only single loop run is requested
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool.empty(),
start_at=time.time(),
)
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
workflow=self._workflow,
node_id=self.application_generate_entity.single_loop_run.node_id,
user_inputs=self.application_generate_entity.single_loop_run.inputs,
graph_runtime_state=graph_runtime_state,
)
else:
inputs = self.application_generate_entity.inputs
@ -92,15 +97,27 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
conversation_variables=[],
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
# init graph
graph = self._init_graph(graph_config=self._workflow.graph_dict)
graph = self._init_graph(
graph_config=self._workflow.graph_dict,
graph_runtime_state=graph_runtime_state,
workflow_id=self._workflow.id,
tenant_id=self._workflow.tenant_id,
user_id=self.application_generate_entity.user_id,
)
# RUN WORKFLOW
# Create Redis command channel for this workflow execution
task_id = self.application_generate_entity.task_id
channel_key = f"workflow:{task_id}:commands"
command_channel = RedisChannel(redis_client, channel_key)
workflow_entry = WorkflowEntry(
tenant_id=self._workflow.tenant_id,
app_id=self._workflow.app_id,
workflow_id=self._workflow.id,
workflow_type=WorkflowType.value_of(self._workflow.type),
graph=graph,
graph_config=self._workflow.graph_dict,
user_id=self.application_generate_entity.user_id,
@ -112,10 +129,11 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
invoke_from=self.application_generate_entity.invoke_from,
call_depth=self.application_generate_entity.call_depth,
variable_pool=variable_pool,
thread_pool_id=self.workflow_thread_pool_id,
graph_runtime_state=graph_runtime_state,
command_channel=command_channel,
)
generator = workflow_entry.run(callbacks=workflow_callbacks)
generator = workflow_entry.run()
for event in generator:
self._handle_event(workflow_entry, event)
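Editor's note: the runner now builds a `GraphRuntimeState` up front, constructs the graph through `_init_graph`, and replaces the callback/thread-pool plumbing with a Redis-backed command channel keyed by task id. A hedged sketch of how another process could address a running workflow over that channel; the key format is taken from the diff, while the list push and payload shape are assumptions about `RedisChannel`'s transport:

```python
import json

task_id = "some-task-id"  # hypothetical
channel_key = f"workflow:{task_id}:commands"  # key format from the diff
# Assumption: RedisChannel reads commands from a Redis list at this key.
redis_client.lpush(channel_key, json.dumps({"command": "stop"}))
```

Because the channel is addressed purely by `task_id`, a stop request no longer needs a handle on the worker thread, which is what allowed `workflow_thread_pool_id` to be deleted in the previous file.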

View File

@ -2,7 +2,7 @@ import logging
import time
from collections.abc import Callable, Generator
from contextlib import contextmanager
from typing import Any, Union
from typing import Union
from sqlalchemy.orm import Session
@ -14,6 +14,7 @@ from core.app.entities.app_invoke_entities import (
WorkflowAppGenerateEntity,
)
from core.app.entities.queue_entities import (
AppQueueEvent,
MessageQueueMessage,
QueueAgentLogEvent,
QueueErrorEvent,
@ -25,14 +26,9 @@ from core.app.entities.queue_entities import (
QueueLoopStartEvent,
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeInLoopFailedEvent,
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
QueueParallelBranchRunStartedEvent,
QueueParallelBranchRunSucceededEvent,
QueuePingEvent,
QueueStopEvent,
QueueTextChunkEvent,
@ -57,8 +53,8 @@ from core.app.entities.task_entities import (
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.entities import GraphRuntimeState, WorkflowExecution
from core.workflow.enums import WorkflowExecutionStatus, WorkflowType
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
@ -349,9 +345,7 @@ class WorkflowAppGenerateTaskPipeline:
def _handle_node_failed_events(
self,
event: Union[
QueueNodeFailedEvent, QueueNodeInIterationFailedEvent, QueueNodeInLoopFailedEvent, QueueNodeExceptionEvent
],
event: Union[QueueNodeFailedEvent, QueueNodeExceptionEvent],
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle various node failure events."""
@ -370,32 +364,6 @@ class WorkflowAppGenerateTaskPipeline:
if node_failed_response:
yield node_failed_response
def _handle_parallel_branch_started_event(
self, event: QueueParallelBranchRunStartedEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle parallel branch started events."""
self._ensure_workflow_initialized()
parallel_start_resp = self._workflow_response_converter.workflow_parallel_branch_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield parallel_start_resp
def _handle_parallel_branch_finished_events(
self, event: Union[QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent], **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle parallel branch finished events."""
self._ensure_workflow_initialized()
parallel_finish_resp = self._workflow_response_converter.workflow_parallel_branch_finished_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield parallel_finish_resp
def _handle_iteration_start_event(
self, event: QueueIterationStartEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
@ -617,8 +585,6 @@ class WorkflowAppGenerateTaskPipeline:
QueueNodeRetryEvent: self._handle_node_retry_event,
QueueNodeStartedEvent: self._handle_node_started_event,
QueueNodeSucceededEvent: self._handle_node_succeeded_event,
# Parallel branch events
QueueParallelBranchRunStartedEvent: self._handle_parallel_branch_started_event,
# Iteration events
QueueIterationStartEvent: self._handle_iteration_start_event,
QueueIterationNextEvent: self._handle_iteration_next_event,
@ -633,7 +599,7 @@ class WorkflowAppGenerateTaskPipeline:
def _dispatch_event(
self,
event: Any,
event: AppQueueEvent,
*,
graph_runtime_state: GraphRuntimeState | None = None,
tts_publisher: AppGeneratorTTSPublisher | None = None,
@ -660,8 +626,6 @@ class WorkflowAppGenerateTaskPipeline:
event,
(
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeInLoopFailedEvent,
QueueNodeExceptionEvent,
),
):
@ -674,17 +638,6 @@ class WorkflowAppGenerateTaskPipeline:
)
return
# Handle parallel branch finished events with isinstance check
if isinstance(event, (QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent)):
yield from self._handle_parallel_branch_finished_events(
event,
graph_runtime_state=graph_runtime_state,
tts_publisher=tts_publisher,
trace_manager=trace_manager,
queue_message=queue_message,
)
return
# Handle workflow failed and stop events with isinstance check
if isinstance(event, (QueueWorkflowFailedEvent, QueueStopEvent)):
yield from self._handle_workflow_failed_and_stop_events(

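Editor's note: with the parallel-branch handlers removed, `_dispatch_event` narrows from `Any` to `AppQueueEvent` and the node-failure union shrinks to two classes. A minimal sketch of the resulting dispatch shape, assuming hypothetical free functions in place of the pipeline's bound methods:

```python
from collections.abc import Callable, Generator

# handle_node_failed / handle_node_started / handle_node_succeeded are
# hypothetical stand-ins for the pipeline's handler methods.
def dispatch(event: AppQueueEvent) -> Generator[StreamResponse, None, None]:
    # Unions still need isinstance checks; the dict lookup below is exact-type only.
    if isinstance(event, (QueueNodeFailedEvent, QueueNodeExceptionEvent)):
        yield from handle_node_failed(event)
        return
    handlers: dict[type, Callable[..., Generator[StreamResponse, None, None]]] = {
        QueueNodeStartedEvent: handle_node_started,
        QueueNodeSucceededEvent: handle_node_succeeded,
    }
    handler = handlers.get(type(event))
    if handler is not None:
        yield from handler(event)
```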
View File

@ -2,6 +2,7 @@ from collections.abc import Mapping
from typing import Any, cast
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueAgentLogEvent,
@ -13,14 +14,9 @@ from core.app.entities.queue_entities import (
QueueLoopStartEvent,
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeInLoopFailedEvent,
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
QueueParallelBranchRunStartedEvent,
QueueParallelBranchRunSucceededEvent,
QueueRetrieverResourcesEvent,
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
@ -28,42 +24,39 @@ from core.app.entities.queue_entities import (
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from core.workflow.graph_engine.entities.event import (
AgentLogEvent,
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.graph import Graph
from core.workflow.graph_events import (
GraphEngineEvent,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
IterationRunFailedEvent,
IterationRunNextEvent,
IterationRunStartedEvent,
IterationRunSucceededEvent,
LoopRunFailedEvent,
LoopRunNextEvent,
LoopRunStartedEvent,
LoopRunSucceededEvent,
NodeInIterationFailedEvent,
NodeInLoopFailedEvent,
NodeRunAgentLogEvent,
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunIterationFailedEvent,
NodeRunIterationNextEvent,
NodeRunIterationStartedEvent,
NodeRunIterationSucceededEvent,
NodeRunLoopFailedEvent,
NodeRunLoopNextEvent,
NodeRunLoopStartedEvent,
NodeRunLoopSucceededEvent,
NodeRunRetrieverResourceEvent,
NodeRunRetryEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
ParallelBranchRunFailedEvent,
ParallelBranchRunStartedEvent,
ParallelBranchRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_events.graph import GraphRunAbortedEvent
from core.workflow.nodes import NodeType
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
from core.workflow.workflow_entry import WorkflowEntry
from models.enums import UserFrom
from models.workflow import Workflow
@ -79,7 +72,14 @@ class WorkflowBasedAppRunner:
self._variable_loader = variable_loader
self._app_id = app_id
def _init_graph(self, graph_config: Mapping[str, Any]) -> Graph:
def _init_graph(
self,
graph_config: Mapping[str, Any],
graph_runtime_state: GraphRuntimeState,
workflow_id: str = "",
tenant_id: str = "",
user_id: str = "",
) -> Graph:
"""
Init graph
"""
@ -91,8 +91,28 @@ class WorkflowBasedAppRunner:
if not isinstance(graph_config.get("edges"), list):
raise ValueError("edges in workflow graph must be a list")
# Create required parameters for Graph.init
graph_init_params = GraphInitParams(
tenant_id=tenant_id or "",
app_id=self._app_id,
workflow_id=workflow_id,
graph_config=graph_config,
user_id=user_id,
user_from=UserFrom.ACCOUNT.value,
invoke_from=InvokeFrom.SERVICE_API.value,
call_depth=0,
)
# Use the provided graph_runtime_state for consistent state management
node_factory = DifyNodeFactory(
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
# init graph
graph = Graph.init(graph_config=graph_config)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
if not graph:
raise ValueError("graph not found in workflow")
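Editor's note: `_init_graph` now takes a `GraphRuntimeState` plus identity fields because node instances are created eagerly by `DifyNodeFactory` rather than later by the engine. A condensed sketch of the construction order that all three hunks in this file share, using only names visible in the diff (string values are placeholders):

```python
import time

graph_runtime_state = GraphRuntimeState(
    variable_pool=VariablePool.empty(),
    start_at=time.perf_counter(),
)
graph_init_params = GraphInitParams(
    tenant_id="tenant-id",      # placeholder
    app_id="app-id",            # placeholder
    workflow_id="workflow-id",  # placeholder
    graph_config=graph_config,
    user_id="user-id",          # placeholder
    user_from=UserFrom.ACCOUNT.value,
    invoke_from=InvokeFrom.SERVICE_API.value,
    call_depth=0,
)
node_factory = DifyNodeFactory(
    graph_init_params=graph_init_params,
    graph_runtime_state=graph_runtime_state,  # every node shares one state object
)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
```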
@ -104,6 +124,7 @@ class WorkflowBasedAppRunner:
workflow: Workflow,
node_id: str,
user_inputs: dict,
graph_runtime_state: GraphRuntimeState,
) -> tuple[Graph, VariablePool]:
"""
Get variable pool of single iteration
@ -145,8 +166,25 @@ class WorkflowBasedAppRunner:
graph_config["edges"] = edge_configs
# Create required parameters for Graph.init
graph_init_params = GraphInitParams(
tenant_id=workflow.tenant_id,
app_id=self._app_id,
workflow_id=workflow.id,
graph_config=graph_config,
user_id="",
user_from=UserFrom.ACCOUNT.value,
invoke_from=InvokeFrom.SERVICE_API.value,
call_depth=0,
)
node_factory = DifyNodeFactory(
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
# init graph
graph = Graph.init(graph_config=graph_config, root_node_id=node_id)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=node_id)
if not graph:
raise ValueError("graph not found in workflow")
@ -201,6 +239,7 @@ class WorkflowBasedAppRunner:
workflow: Workflow,
node_id: str,
user_inputs: dict,
graph_runtime_state: GraphRuntimeState,
) -> tuple[Graph, VariablePool]:
"""
Get variable pool of single loop
@ -242,8 +281,25 @@ class WorkflowBasedAppRunner:
graph_config["edges"] = edge_configs
# Create required parameters for Graph.init
graph_init_params = GraphInitParams(
tenant_id=workflow.tenant_id,
app_id=self._app_id,
workflow_id=workflow.id,
graph_config=graph_config,
user_id="",
user_from=UserFrom.ACCOUNT.value,
invoke_from=InvokeFrom.SERVICE_API.value,
call_depth=0,
)
node_factory = DifyNodeFactory(
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
# init graph
graph = Graph.init(graph_config=graph_config, root_node_id=node_id)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=node_id)
if not graph:
raise ValueError("graph not found in workflow")
@ -310,39 +366,32 @@ class WorkflowBasedAppRunner:
)
elif isinstance(event, GraphRunFailedEvent):
self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count))
elif isinstance(event, GraphRunAbortedEvent):
self._publish_event(QueueWorkflowFailedEvent(error=event.reason or "Unknown error", exceptions_count=0))
elif isinstance(event, NodeRunRetryEvent):
node_run_result = event.route_node_state.node_run_result
inputs: Mapping[str, Any] | None = {}
process_data: Mapping[str, Any] | None = {}
outputs: Mapping[str, Any] | None = {}
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = {}
if node_run_result:
inputs = node_run_result.inputs
process_data = node_run_result.process_data
outputs = node_run_result.outputs
execution_metadata = node_run_result.metadata
node_run_result = event.node_run_result
inputs = node_run_result.inputs
process_data = node_run_result.process_data
outputs = node_run_result.outputs
execution_metadata = node_run_result.metadata
self._publish_event(
QueueNodeRetryEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_title=event.node_title,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.start_at,
node_run_index=event.route_node_state.index,
predecessor_node_id=event.predecessor_node_id,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
parallel_mode_run_id=event.parallel_mode_run_id,
inputs=inputs,
process_data=process_data,
outputs=outputs,
error=event.error,
execution_metadata=execution_metadata,
retry_index=event.retry_index,
provider_type=event.provider_type,
provider_id=event.provider_id,
)
)
elif isinstance(event, NodeRunStartedEvent):
@ -350,44 +399,29 @@ class WorkflowBasedAppRunner:
QueueNodeStartedEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_title=event.node_title,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at,
node_run_index=event.route_node_state.index,
start_at=event.start_at,
predecessor_node_id=event.predecessor_node_id,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
parallel_mode_run_id=event.parallel_mode_run_id,
agent_strategy=event.agent_strategy,
provider_type=event.provider_type,
provider_id=event.provider_id,
)
)
elif isinstance(event, NodeRunSucceededEvent):
node_run_result = event.route_node_state.node_run_result
if node_run_result:
inputs = node_run_result.inputs
process_data = node_run_result.process_data
outputs = node_run_result.outputs
execution_metadata = node_run_result.metadata
else:
inputs = {}
process_data = {}
outputs = {}
execution_metadata = {}
node_run_result = event.node_run_result
inputs = node_run_result.inputs
process_data = node_run_result.process_data
outputs = node_run_result.outputs
execution_metadata = node_run_result.metadata
self._publish_event(
QueueNodeSucceededEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at,
start_at=event.start_at,
inputs=inputs,
process_data=process_data,
outputs=outputs,
@ -396,34 +430,18 @@ class WorkflowBasedAppRunner:
in_loop_id=event.in_loop_id,
)
)
elif isinstance(event, NodeRunFailedEvent):
self._publish_event(
QueueNodeFailedEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at,
inputs=event.route_node_state.node_run_result.inputs
if event.route_node_state.node_run_result
else {},
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result
else {},
outputs=event.route_node_state.node_run_result.outputs or {}
if event.route_node_state.node_run_result
else {},
error=event.route_node_state.node_run_result.error
if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error
else "Unknown error",
execution_metadata=event.route_node_state.node_run_result.metadata
if event.route_node_state.node_run_result
else {},
start_at=event.start_at,
inputs=event.node_run_result.inputs,
process_data=event.node_run_result.process_data,
outputs=event.node_run_result.outputs,
error=event.node_run_result.error or "Unknown error",
execution_metadata=event.node_run_result.metadata,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
)
@ -434,93 +452,21 @@ class WorkflowBasedAppRunner:
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at,
inputs=event.route_node_state.node_run_result.inputs
if event.route_node_state.node_run_result
else {},
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result
else {},
outputs=event.route_node_state.node_run_result.outputs
if event.route_node_state.node_run_result
else {},
error=event.route_node_state.node_run_result.error
if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error
else "Unknown error",
execution_metadata=event.route_node_state.node_run_result.metadata
if event.route_node_state.node_run_result
else {},
start_at=event.start_at,
inputs=event.node_run_result.inputs,
process_data=event.node_run_result.process_data,
outputs=event.node_run_result.outputs,
error=event.node_run_result.error or "Unknown error",
execution_metadata=event.node_run_result.metadata,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
)
)
elif isinstance(event, NodeInIterationFailedEvent):
self._publish_event(
QueueNodeInIterationFailedEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at,
inputs=event.route_node_state.node_run_result.inputs
if event.route_node_state.node_run_result
else {},
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result
else {},
outputs=event.route_node_state.node_run_result.outputs or {}
if event.route_node_state.node_run_result
else {},
execution_metadata=event.route_node_state.node_run_result.metadata
if event.route_node_state.node_run_result
else {},
in_iteration_id=event.in_iteration_id,
error=event.error,
)
)
elif isinstance(event, NodeInLoopFailedEvent):
self._publish_event(
QueueNodeInLoopFailedEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at,
inputs=event.route_node_state.node_run_result.inputs
if event.route_node_state.node_run_result
else {},
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result
else {},
outputs=event.route_node_state.node_run_result.outputs or {}
if event.route_node_state.node_run_result
else {},
execution_metadata=event.route_node_state.node_run_result.metadata
if event.route_node_state.node_run_result
else {},
in_loop_id=event.in_loop_id,
error=event.error,
)
)
elif isinstance(event, NodeRunStreamChunkEvent):
self._publish_event(
QueueTextChunkEvent(
text=event.chunk_content,
from_variable_selector=event.from_variable_selector,
text=event.chunk,
from_variable_selector=list(event.selector),
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
)
@ -533,10 +479,10 @@ class WorkflowBasedAppRunner:
in_loop_id=event.in_loop_id,
)
)
elif isinstance(event, AgentLogEvent):
elif isinstance(event, NodeRunAgentLogEvent):
self._publish_event(
QueueAgentLogEvent(
id=event.id,
id=event.message_id,
label=event.label,
node_execution_id=event.node_execution_id,
parent_id=event.parent_id,
@ -547,51 +493,13 @@ class WorkflowBasedAppRunner:
node_id=event.node_id,
)
)
elif isinstance(event, ParallelBranchRunStartedEvent):
self._publish_event(
QueueParallelBranchRunStartedEvent(
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
)
)
elif isinstance(event, ParallelBranchRunSucceededEvent):
self._publish_event(
QueueParallelBranchRunSucceededEvent(
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
)
)
elif isinstance(event, ParallelBranchRunFailedEvent):
self._publish_event(
QueueParallelBranchRunFailedEvent(
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
error=event.error,
)
)
elif isinstance(event, IterationRunStartedEvent):
elif isinstance(event, NodeRunIterationStartedEvent):
self._publish_event(
QueueIterationStartEvent(
node_execution_id=event.iteration_id,
node_id=event.iteration_node_id,
node_type=event.iteration_node_type,
node_data=event.iteration_node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_title,
start_at=event.start_at,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
inputs=event.inputs,
@ -599,55 +507,41 @@ class WorkflowBasedAppRunner:
metadata=event.metadata,
)
)
elif isinstance(event, IterationRunNextEvent):
elif isinstance(event, NodeRunIterationNextEvent):
self._publish_event(
QueueIterationNextEvent(
node_execution_id=event.iteration_id,
node_id=event.iteration_node_id,
node_type=event.iteration_node_type,
node_data=event.iteration_node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_title,
index=event.index,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
output=event.pre_iteration_output,
parallel_mode_run_id=event.parallel_mode_run_id,
duration=event.duration,
)
)
elif isinstance(event, (IterationRunSucceededEvent | IterationRunFailedEvent)):
elif isinstance(event, (NodeRunIterationSucceededEvent | NodeRunIterationFailedEvent)):
self._publish_event(
QueueIterationCompletedEvent(
node_execution_id=event.iteration_id,
node_id=event.iteration_node_id,
node_type=event.iteration_node_type,
node_data=event.iteration_node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_title,
start_at=event.start_at,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
inputs=event.inputs,
outputs=event.outputs,
metadata=event.metadata,
steps=event.steps,
error=event.error if isinstance(event, IterationRunFailedEvent) else None,
error=event.error if isinstance(event, NodeRunIterationFailedEvent) else None,
)
)
elif isinstance(event, LoopRunStartedEvent):
elif isinstance(event, NodeRunLoopStartedEvent):
self._publish_event(
QueueLoopStartEvent(
node_execution_id=event.loop_id,
node_id=event.loop_node_id,
node_type=event.loop_node_type,
node_data=event.loop_node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_title,
start_at=event.start_at,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
inputs=event.inputs,
@ -655,42 +549,32 @@ class WorkflowBasedAppRunner:
metadata=event.metadata,
)
)
elif isinstance(event, LoopRunNextEvent):
elif isinstance(event, NodeRunLoopNextEvent):
self._publish_event(
QueueLoopNextEvent(
node_execution_id=event.loop_id,
node_id=event.loop_node_id,
node_type=event.loop_node_type,
node_data=event.loop_node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_title,
index=event.index,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
output=event.pre_loop_output,
parallel_mode_run_id=event.parallel_mode_run_id,
duration=event.duration,
)
)
elif isinstance(event, (LoopRunSucceededEvent | LoopRunFailedEvent)):
elif isinstance(event, (NodeRunLoopSucceededEvent | NodeRunLoopFailedEvent)):
self._publish_event(
QueueLoopCompletedEvent(
node_execution_id=event.loop_id,
node_id=event.loop_node_id,
node_type=event.loop_node_type,
node_data=event.loop_node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_title,
start_at=event.start_at,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
inputs=event.inputs,
outputs=event.outputs,
metadata=event.metadata,
steps=event.steps,
error=event.error if isinstance(event, LoopRunFailedEvent) else None,
error=event.error if isinstance(event, NodeRunLoopFailedEvent) else None,
)
)
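Editor's note: the event-translation rewrite above follows one pattern throughout — engine events are renamed to `NodeRun*`, flat attributes (`event.id`, `event.node_id`, `event.node_title`, `event.start_at`) replace the old `route_node_state` lookups, and `event.node_run_result` is guaranteed present, eliminating the None-check boilerplate. A minimal sketch of one such translation, using only fields present in the diff:

```python
def to_queue_event(event: NodeRunSucceededEvent) -> QueueNodeSucceededEvent:
    result = event.node_run_result  # guaranteed non-None in the new engine
    return QueueNodeSucceededEvent(
        node_execution_id=event.id,
        node_id=event.node_id,
        node_type=event.node_type,
        start_at=event.start_at,
        inputs=result.inputs,
        process_data=result.process_data,
        outputs=result.outputs,
        execution_metadata=result.metadata,
        in_iteration_id=event.in_iteration_id,
        in_loop_id=event.in_loop_id,
    )
```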

View File

@ -3,15 +3,13 @@ from datetime import datetime
from enum import StrEnum, auto
from typing import Any
from pydantic import BaseModel
from pydantic import BaseModel, Field
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities.node_entities import AgentNodeStrategyInit
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.entities import AgentNodeStrategyInit, GraphRuntimeState
from core.workflow.enums import WorkflowNodeExecutionMetadataKey
from core.workflow.nodes import NodeType
from core.workflow.nodes.base import BaseNodeData
class QueueEvent(StrEnum):
@ -43,9 +41,6 @@ class QueueEvent(StrEnum):
ANNOTATION_REPLY = "annotation_reply"
AGENT_THOUGHT = "agent_thought"
MESSAGE_FILE = "message_file"
PARALLEL_BRANCH_RUN_STARTED = "parallel_branch_run_started"
PARALLEL_BRANCH_RUN_SUCCEEDED = "parallel_branch_run_succeeded"
PARALLEL_BRANCH_RUN_FAILED = "parallel_branch_run_failed"
AGENT_LOG = "agent_log"
ERROR = "error"
PING = "ping"
@ -80,21 +75,13 @@ class QueueIterationStartEvent(AppQueueEvent):
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: str | None = None
"""parallel id if node is in parallel"""
parallel_start_node_id: str | None = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: str | None = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: str | None = None
"""parent parallel start node id if node is in parallel"""
node_title: str
start_at: datetime
node_run_index: int
inputs: Mapping[str, Any] | None = None
inputs: Mapping[str, object] = Field(default_factory=dict)
predecessor_node_id: str | None = None
metadata: Mapping[str, Any] | None = None
metadata: Mapping[str, object] = Field(default_factory=dict)
class QueueIterationNextEvent(AppQueueEvent):
@ -108,20 +95,9 @@ class QueueIterationNextEvent(AppQueueEvent):
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: str | None = None
"""parallel id if node is in parallel"""
parallel_start_node_id: str | None = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: str | None = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: str | None = None
"""parent parallel start node id if node is in parallel"""
parallel_mode_run_id: str | None = None
"""iteration run in parallel mode run id"""
node_title: str
node_run_index: int
output: Any | None = None # output for the current iteration
duration: float | None = None
output: Any = None # output for the current iteration
class QueueIterationCompletedEvent(AppQueueEvent):
@ -134,21 +110,13 @@ class QueueIterationCompletedEvent(AppQueueEvent):
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: str | None = None
"""parallel id if node is in parallel"""
parallel_start_node_id: str | None = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: str | None = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: str | None = None
"""parent parallel start node id if node is in parallel"""
node_title: str
start_at: datetime
node_run_index: int
inputs: Mapping[str, Any] | None = None
outputs: Mapping[str, Any] | None = None
metadata: Mapping[str, Any] | None = None
inputs: Mapping[str, object] = Field(default_factory=dict)
outputs: Mapping[str, object] = Field(default_factory=dict)
metadata: Mapping[str, object] = Field(default_factory=dict)
steps: int = 0
error: str | None = None
@ -163,7 +131,7 @@ class QueueLoopStartEvent(AppQueueEvent):
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
node_title: str
parallel_id: str | None = None
"""parallel id if node is in parallel"""
parallel_start_node_id: str | None = None
@ -175,9 +143,9 @@ class QueueLoopStartEvent(AppQueueEvent):
start_at: datetime
node_run_index: int
inputs: Mapping[str, Any] | None = None
inputs: Mapping[str, object] = Field(default_factory=dict)
predecessor_node_id: str | None = None
metadata: Mapping[str, Any] | None = None
metadata: Mapping[str, object] = Field(default_factory=dict)
class QueueLoopNextEvent(AppQueueEvent):
@ -191,7 +159,7 @@ class QueueLoopNextEvent(AppQueueEvent):
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
node_title: str
parallel_id: str | None = None
"""parallel id if node is in parallel"""
parallel_start_node_id: str | None = None
@ -203,8 +171,7 @@ class QueueLoopNextEvent(AppQueueEvent):
parallel_mode_run_id: str | None = None
"""iteration run in parallel mode run id"""
node_run_index: int
output: Any | None = None # output for the current loop
duration: float | None = None
output: Any = None # output for the current loop
class QueueLoopCompletedEvent(AppQueueEvent):
@ -217,7 +184,7 @@ class QueueLoopCompletedEvent(AppQueueEvent):
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
node_title: str
parallel_id: str | None = None
"""parallel id if node is in parallel"""
parallel_start_node_id: str | None = None
@ -229,9 +196,9 @@ class QueueLoopCompletedEvent(AppQueueEvent):
start_at: datetime
node_run_index: int
inputs: Mapping[str, Any] | None = None
outputs: Mapping[str, Any] | None = None
metadata: Mapping[str, Any] | None = None
inputs: Mapping[str, object] = Field(default_factory=dict)
outputs: Mapping[str, object] = Field(default_factory=dict)
metadata: Mapping[str, object] = Field(default_factory=dict)
steps: int = 0
error: str | None = None
@ -332,7 +299,7 @@ class QueueWorkflowSucceededEvent(AppQueueEvent):
"""
event: QueueEvent = QueueEvent.WORKFLOW_SUCCEEDED
outputs: dict[str, Any] | None = None
outputs: Mapping[str, object] = Field(default_factory=dict)
class QueueWorkflowFailedEvent(AppQueueEvent):
@ -352,7 +319,7 @@ class QueueWorkflowPartialSuccessEvent(AppQueueEvent):
event: QueueEvent = QueueEvent.WORKFLOW_PARTIAL_SUCCEEDED
exceptions_count: int
outputs: dict[str, Any] | None = None
outputs: Mapping[str, object] = Field(default_factory=dict)
class QueueNodeStartedEvent(AppQueueEvent):
@ -364,27 +331,24 @@ class QueueNodeStartedEvent(AppQueueEvent):
node_execution_id: str
node_id: str
node_title: str
node_type: NodeType
node_data: BaseNodeData
node_run_index: int = 1
node_run_index: int = 1 # FIXME(-LAN-): may not be used
predecessor_node_id: str | None = None
parallel_id: str | None = None
"""parallel id if node is in parallel"""
parallel_start_node_id: str | None = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: str | None = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: str | None = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: str | None = None
"""iteration id if node is in iteration"""
in_loop_id: str | None = None
"""loop id if node is in loop"""
start_at: datetime
parallel_mode_run_id: str | None = None
"""iteration run in parallel mode run id"""
agent_strategy: AgentNodeStrategyInit | None = None
# FIXME(-LAN-): only for ToolNode, need to refactor
provider_type: str # should be a core.tools.entities.tool_entities.ToolProviderType
provider_id: str
class QueueNodeSucceededEvent(AppQueueEvent):
"""
@ -396,7 +360,6 @@ class QueueNodeSucceededEvent(AppQueueEvent):
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: str | None = None
"""parallel id if node is in parallel"""
parallel_start_node_id: str | None = None
@ -411,16 +374,12 @@ class QueueNodeSucceededEvent(AppQueueEvent):
"""loop id if node is in loop"""
start_at: datetime
inputs: Mapping[str, Any] | None = None
process_data: Mapping[str, Any] | None = None
outputs: Mapping[str, Any] | None = None
inputs: Mapping[str, object] = Field(default_factory=dict)
process_data: Mapping[str, object] = Field(default_factory=dict)
outputs: Mapping[str, object] = Field(default_factory=dict)
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
error: str | None = None
"""single iteration duration map"""
iteration_duration_map: dict[str, float] | None = None
"""single loop duration map"""
loop_duration_map: dict[str, float] | None = None
class QueueAgentLogEvent(AppQueueEvent):
@ -436,7 +395,7 @@ class QueueAgentLogEvent(AppQueueEvent):
error: str | None = None
status: str
data: Mapping[str, Any]
metadata: Mapping[str, Any] | None = None
metadata: Mapping[str, object] = Field(default_factory=dict)
node_id: str
@ -445,81 +404,15 @@ class QueueNodeRetryEvent(QueueNodeStartedEvent):
event: QueueEvent = QueueEvent.RETRY
inputs: Mapping[str, Any] | None = None
process_data: Mapping[str, Any] | None = None
outputs: Mapping[str, Any] | None = None
inputs: Mapping[str, object] = Field(default_factory=dict)
process_data: Mapping[str, object] = Field(default_factory=dict)
outputs: Mapping[str, object] = Field(default_factory=dict)
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
error: str
retry_index: int # retry index
class QueueNodeInIterationFailedEvent(AppQueueEvent):
"""
QueueNodeInIterationFailedEvent entity
"""
event: QueueEvent = QueueEvent.NODE_FAILED
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: str | None = None
"""parallel id if node is in parallel"""
parallel_start_node_id: str | None = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: str | None = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: str | None = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: str | None = None
"""iteration id if node is in iteration"""
in_loop_id: str | None = None
"""loop id if node is in loop"""
start_at: datetime
inputs: Mapping[str, Any] | None = None
process_data: Mapping[str, Any] | None = None
outputs: Mapping[str, Any] | None = None
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
error: str
class QueueNodeInLoopFailedEvent(AppQueueEvent):
"""
QueueNodeInLoopFailedEvent entity
"""
event: QueueEvent = QueueEvent.NODE_FAILED
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: str | None = None
"""parallel id if node is in parallel"""
parallel_start_node_id: str | None = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: str | None = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: str | None = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: str | None = None
"""iteration id if node is in iteration"""
in_loop_id: str | None = None
"""loop id if node is in loop"""
start_at: datetime
inputs: Mapping[str, Any] | None = None
process_data: Mapping[str, Any] | None = None
outputs: Mapping[str, Any] | None = None
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
error: str
class QueueNodeExceptionEvent(AppQueueEvent):
"""
QueueNodeExceptionEvent entity
@ -530,7 +423,6 @@ class QueueNodeExceptionEvent(AppQueueEvent):
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: str | None = None
"""parallel id if node is in parallel"""
parallel_start_node_id: str | None = None
@ -545,9 +437,9 @@ class QueueNodeExceptionEvent(AppQueueEvent):
"""loop id if node is in loop"""
start_at: datetime
inputs: Mapping[str, Any] | None = None
process_data: Mapping[str, Any] | None = None
outputs: Mapping[str, Any] | None = None
inputs: Mapping[str, object] = Field(default_factory=dict)
process_data: Mapping[str, object] = Field(default_factory=dict)
outputs: Mapping[str, object] = Field(default_factory=dict)
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
error: str
@ -563,24 +455,16 @@ class QueueNodeFailedEvent(AppQueueEvent):
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: str | None = None
"""parallel id if node is in parallel"""
parallel_start_node_id: str | None = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: str | None = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: str | None = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: str | None = None
"""iteration id if node is in iteration"""
in_loop_id: str | None = None
"""loop id if node is in loop"""
start_at: datetime
inputs: Mapping[str, Any] | None = None
process_data: Mapping[str, Any] | None = None
outputs: Mapping[str, Any] | None = None
inputs: Mapping[str, object] = Field(default_factory=dict)
process_data: Mapping[str, object] = Field(default_factory=dict)
outputs: Mapping[str, object] = Field(default_factory=dict)
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
error: str
@ -610,7 +494,7 @@ class QueueErrorEvent(AppQueueEvent):
"""
event: QueueEvent = QueueEvent.ERROR
error: Any | None = None
error: Any = None
class QueuePingEvent(AppQueueEvent):
@ -678,61 +562,3 @@ class WorkflowQueueMessage(QueueMessage):
"""
pass
class QueueParallelBranchRunStartedEvent(AppQueueEvent):
"""
QueueParallelBranchRunStartedEvent entity
"""
event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_STARTED
parallel_id: str
parallel_start_node_id: str
parent_parallel_id: str | None = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: str | None = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: str | None = None
"""iteration id if node is in iteration"""
in_loop_id: str | None = None
"""loop id if node is in loop"""
class QueueParallelBranchRunSucceededEvent(AppQueueEvent):
"""
QueueParallelBranchRunSucceededEvent entity
"""
event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_SUCCEEDED
parallel_id: str
parallel_start_node_id: str
parent_parallel_id: str | None = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: str | None = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: str | None = None
"""iteration id if node is in iteration"""
in_loop_id: str | None = None
"""loop id if node is in loop"""
class QueueParallelBranchRunFailedEvent(AppQueueEvent):
"""
QueueParallelBranchRunFailedEvent entity
"""
event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_FAILED
parallel_id: str
parallel_start_node_id: str
parent_parallel_id: str | None = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: str | None = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: str | None = None
"""iteration id if node is in iteration"""
in_loop_id: str | None = None
"""loop id if node is in loop"""
error: str
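Editor's note: the queue-entity changes above are largely one typing move applied uniformly — `Mapping[str, Any] | None = None` fields become `Mapping[str, object] = Field(default_factory=dict)` — plus the outright removal of the per-branch parallel events. A minimal before/after sketch of the field pattern:

```python
from collections.abc import Mapping
from typing import Any
from pydantic import BaseModel, Field

class Before(BaseModel):
    outputs: Mapping[str, Any] | None = None  # every reader must None-check

class After(BaseModel):
    outputs: Mapping[str, object] = Field(default_factory=dict)  # always iterable

assert After().outputs == {}  # no None guard needed downstream
```

Using `default_factory` gives each instance a fresh empty dict, avoiding the shared-mutable-default trap while keeping the field non-optional.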

View File

@ -6,8 +6,8 @@ from pydantic import BaseModel, ConfigDict, Field
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities.node_entities import AgentNodeStrategyInit
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.entities import AgentNodeStrategyInit
from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
class AnnotationReplyAccount(BaseModel):
@ -138,7 +138,7 @@ class MessageEndStreamResponse(StreamResponse):
event: StreamEvent = StreamEvent.MESSAGE_END
id: str
metadata: dict = Field(default_factory=dict)
metadata: Mapping[str, object] = Field(default_factory=dict)
files: Sequence[Mapping[str, Any]] | None = None
@ -175,7 +175,7 @@ class AgentThoughtStreamResponse(StreamResponse):
thought: str | None = None
observation: str | None = None
tool: str | None = None
tool_labels: dict | None = None
tool_labels: Mapping[str, object] = Field(default_factory=dict)
tool_input: str | None = None
message_files: list[str] | None = None
@ -228,7 +228,7 @@ class WorkflowFinishStreamResponse(StreamResponse):
elapsed_time: float
total_tokens: int
total_steps: int
created_by: dict | None = None
created_by: Mapping[str, object] = Field(default_factory=dict)
created_at: int
finished_at: int
exceptions_count: int | None = 0
@ -257,7 +257,7 @@ class NodeStartStreamResponse(StreamResponse):
predecessor_node_id: str | None = None
inputs: Mapping[str, Any] | None = None
created_at: int
extras: dict = Field(default_factory=dict)
extras: dict[str, object] = Field(default_factory=dict)
parallel_id: str | None = None
parallel_start_node_id: str | None = None
parent_parallel_id: str | None = None
@ -436,54 +436,6 @@ class NodeRetryStreamResponse(StreamResponse):
}
class ParallelBranchStartStreamResponse(StreamResponse):
"""
ParallelBranchStartStreamResponse entity
"""
class Data(BaseModel):
"""
Data entity
"""
parallel_id: str
parallel_branch_id: str
parent_parallel_id: str | None = None
parent_parallel_start_node_id: str | None = None
iteration_id: str | None = None
loop_id: str | None = None
created_at: int
event: StreamEvent = StreamEvent.PARALLEL_BRANCH_STARTED
workflow_run_id: str
data: Data
class ParallelBranchFinishedStreamResponse(StreamResponse):
"""
ParallelBranchFinishedStreamResponse entity
"""
class Data(BaseModel):
"""
Data entity
"""
parallel_id: str
parallel_branch_id: str
parent_parallel_id: str | None = None
parent_parallel_start_node_id: str | None = None
iteration_id: str | None = None
loop_id: str | None = None
status: str
error: str | None = None
created_at: int
event: StreamEvent = StreamEvent.PARALLEL_BRANCH_FINISHED
workflow_run_id: str
data: Data
class IterationNodeStartStreamResponse(StreamResponse):
"""
NodeStartStreamResponse entity
@ -499,11 +451,9 @@ class IterationNodeStartStreamResponse(StreamResponse):
node_type: str
title: str
created_at: int
extras: dict = Field(default_factory=dict)
metadata: Mapping = {}
inputs: Mapping = {}
parallel_id: str | None = None
parallel_start_node_id: str | None = None
extras: Mapping[str, object] = Field(default_factory=dict)
metadata: Mapping[str, object] = Field(default_factory=dict)
inputs: Mapping[str, object] = Field(default_factory=dict)
event: StreamEvent = StreamEvent.ITERATION_STARTED
workflow_run_id: str
@ -526,12 +476,7 @@ class IterationNodeNextStreamResponse(StreamResponse):
title: str
index: int
created_at: int
pre_iteration_output: Any | None = None
extras: dict = Field(default_factory=dict)
parallel_id: str | None = None
parallel_start_node_id: str | None = None
parallel_mode_run_id: str | None = None
duration: float | None = None
extras: Mapping[str, object] = Field(default_factory=dict)
event: StreamEvent = StreamEvent.ITERATION_NEXT
workflow_run_id: str
@ -552,19 +497,17 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
node_id: str
node_type: str
title: str
outputs: Mapping | None = None
outputs: Mapping[str, object] = Field(default_factory=dict)
created_at: int
extras: dict | None = None
inputs: Mapping | None = None
extras: Mapping[str, object] = Field(default_factory=dict)
inputs: Mapping[str, object] = Field(default_factory=dict)
status: WorkflowNodeExecutionStatus
error: str | None = None
elapsed_time: float
total_tokens: int
execution_metadata: Mapping | None = None
execution_metadata: Mapping[str, object] = Field(default_factory=dict)
finished_at: int
steps: int
parallel_id: str | None = None
parallel_start_node_id: str | None = None
event: StreamEvent = StreamEvent.ITERATION_COMPLETED
workflow_run_id: str
@ -586,9 +529,9 @@ class LoopNodeStartStreamResponse(StreamResponse):
node_type: str
title: str
created_at: int
extras: dict = Field(default_factory=dict)
metadata: Mapping = {}
inputs: Mapping = {}
extras: Mapping[str, object] = Field(default_factory=dict)
metadata: Mapping[str, object] = Field(default_factory=dict)
inputs: Mapping[str, object] = Field(default_factory=dict)
parallel_id: str | None = None
parallel_start_node_id: str | None = None
@ -613,12 +556,11 @@ class LoopNodeNextStreamResponse(StreamResponse):
title: str
index: int
created_at: int
pre_loop_output: Any | None = None
extras: dict = Field(default_factory=dict)
pre_loop_output: Any = None
extras: Mapping[str, object] = Field(default_factory=dict)
parallel_id: str | None = None
parallel_start_node_id: str | None = None
parallel_mode_run_id: str | None = None
duration: float | None = None
event: StreamEvent = StreamEvent.LOOP_NEXT
workflow_run_id: str
@ -639,15 +581,15 @@ class LoopNodeCompletedStreamResponse(StreamResponse):
node_id: str
node_type: str
title: str
outputs: Mapping | None = None
outputs: Mapping[str, object] = Field(default_factory=dict)
created_at: int
extras: dict | None = None
inputs: Mapping | None = None
extras: Mapping[str, object] = Field(default_factory=dict)
inputs: Mapping[str, object] = Field(default_factory=dict)
status: WorkflowNodeExecutionStatus
error: str | None = None
elapsed_time: float
total_tokens: int
execution_metadata: Mapping | None = None
execution_metadata: Mapping[str, object] = Field(default_factory=dict)
finished_at: int
steps: int
parallel_id: str | None = None
@ -757,7 +699,7 @@ class ChatbotAppBlockingResponse(AppBlockingResponse):
conversation_id: str
message_id: str
answer: str
metadata: dict = Field(default_factory=dict)
metadata: Mapping[str, object] = Field(default_factory=dict)
created_at: int
data: Data
@ -777,7 +719,7 @@ class CompletionAppBlockingResponse(AppBlockingResponse):
mode: str
message_id: str
answer: str
metadata: dict = Field(default_factory=dict)
metadata: Mapping[str, object] = Field(default_factory=dict)
created_at: int
data: Data
@ -825,7 +767,7 @@ class AgentLogStreamResponse(StreamResponse):
error: str | None = None
status: str
data: Mapping[str, Any]
metadata: Mapping[str, Any] | None = None
metadata: Mapping[str, object] = Field(default_factory=dict)
node_id: str
event: StreamEvent = StreamEvent.AGENT_LOG
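Editor's note: the same default-factory move lands in the stream responses above, with bare `dict`/`Mapping` annotations also tightened to `Mapping[str, object]`, and the two parallel-branch stream responses removed. A sketch with a hypothetical model of why `Mapping` is the right surface type here:

```python
from collections.abc import Mapping
from pydantic import BaseModel, Field

class NodeData(BaseModel):  # hypothetical model illustrating the pattern
    # Mapping (not dict) signals read-only intent to consumers while still
    # accepting any mapping implementation at construction time.
    extras: Mapping[str, object] = Field(default_factory=dict)

nd = NodeData(extras={"tool": "search"})
assert nd.extras["tool"] == "search"
assert NodeData().extras == {}  # serializes as {} rather than null
```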

View File

@ -109,7 +109,9 @@ class AppGeneratorTTSPublisher:
elif isinstance(message.event, QueueNodeSucceededEvent):
if message.event.outputs is None:
continue
self.msg_text += message.event.outputs.get("output", "")
output = message.event.outputs.get("output", "")
if isinstance(output, str):
self.msg_text += output
self.last_message = message
sentence_arr, text_tmp = self._extract_sentence(self.msg_text)
if len(sentence_arr) >= min(self.max_sentence, 7):
@ -119,7 +121,7 @@ class AppGeneratorTTSPublisher:
_invoice_tts, text_content, self.model_instance, self.tenant_id, self.voice
)
future_queue.put(futures_result)
if text_tmp:
if isinstance(text_tmp, str):
self.msg_text = text_tmp
else:
self.msg_text = ""
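Editor's note: since `outputs` values are now typed `object`, the TTS publisher narrows with `isinstance` before concatenating. A minimal standalone sketch of the guard:

```python
buffer = ""
outputs: dict[str, object] = {"output": 42}  # non-string outputs are possible now

value = outputs.get("output", "")
if isinstance(value, str):  # mypy narrows object -> str here
    buffer += value         # non-strings are skipped instead of raising TypeError
```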

View File

@ -28,7 +28,6 @@ from core.model_runtime.entities.provider_entities import (
)
from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.plugin.entities.plugin import ModelProviderID
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.provider import (
@ -41,6 +40,7 @@ from models.provider import (
ProviderType,
TenantPreferredModelProvider,
)
from models.provider_ids import ModelProviderID
from services.enterprise.plugin_manager_service import PluginCredentialType
logger = logging.getLogger(__name__)
@ -704,6 +704,7 @@ class ProviderConfiguration(BaseModel):
Get custom model credentials.
"""
# get provider model
model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
@ -1203,6 +1204,7 @@ class ProviderConfiguration(BaseModel):
"""
Get provider model setting.
"""
model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
@ -1286,6 +1288,7 @@ class ProviderConfiguration(BaseModel):
:param model: model name
:return:
"""
model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():

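Editor's note: `ModelProviderID` moves from `core.plugin.entities.plugin` to `models.provider_ids`, and each lookup site now constructs it locally before widening the provider-name search. A sketch using only names from the diff; the provider string is a placeholder and the widening body is elided because the hunks truncate it:

```python
from models.provider_ids import ModelProviderID  # new home of the class

model_provider_id = ModelProviderID("langgenius/openai/openai")  # placeholder value
provider_names = ["langgenius/openai/openai"]
if model_provider_id.is_langgenius():
    # widen the DB query to match both namespaced and bare provider names,
    # as the surrounding methods do (exact logic not shown in the hunks)
    ...
```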
View File

@ -1,9 +1,33 @@
from abc import abstractmethod
from abc import ABC, abstractmethod
from collections.abc import Mapping, Sequence
from typing import TypedDict
from pydantic import BaseModel
class CodeNodeProvider(BaseModel):
class VariableConfig(TypedDict):
variable: str
value_selector: Sequence[str | int]
class OutputConfig(TypedDict):
type: str
children: None
class CodeConfig(TypedDict):
variables: Sequence[VariableConfig]
code_language: str
code: str
outputs: Mapping[str, OutputConfig]
class DefaultConfig(TypedDict):
type: str
config: CodeConfig
class CodeNodeProvider(BaseModel, ABC):
@staticmethod
@abstractmethod
def get_language() -> str:
@ -22,11 +46,14 @@ class CodeNodeProvider(BaseModel):
pass
@classmethod
def get_default_config(cls):
def get_default_config(cls) -> DefaultConfig:
return {
"type": "code",
"config": {
"variables": [{"variable": "arg1", "value_selector": []}, {"variable": "arg2", "value_selector": []}],
"variables": [
{"variable": "arg1", "value_selector": []},
{"variable": "arg2", "value_selector": []},
],
"code_language": cls.get_language(),
"code": cls.get_default_code(),
"outputs": {"result": {"type": "string", "children": None}},

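Editor's note: `CodeNodeProvider` becomes an explicit ABC and its default config gains `TypedDict` shapes, so `get_default_config` is now statically checkable end to end. A sketch of a concrete provider against the new interface; `Python3CodeProvider` and its code body are hypothetical, and only the statics visible or referenced in the hunks are implemented:

```python
class Python3CodeProvider(CodeNodeProvider):  # hypothetical subclass
    @staticmethod
    def get_language() -> str:
        return "python3"

    @staticmethod
    def get_default_code() -> str:
        return "def main(arg1, arg2):\n    return {'result': str(arg1) + str(arg2)}\n"

config: DefaultConfig = Python3CodeProvider.get_default_config()
assert config["config"]["code_language"] == "python3"
assert config["config"]["outputs"]["result"]["type"] == "string"
```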
View File

@ -16,8 +16,8 @@ def full_mask_token(token_length=20):
def encrypt_token(tenant_id: str, token: str):
from extensions.ext_database import db
from models.account import Tenant
from models.engine import db
if not (tenant := db.session.query(Tenant).where(Tenant.id == tenant_id).first()):
raise ValueError(f"Tenant with id {tenant_id} not found")

View File

@ -28,8 +28,9 @@ from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.ops.utils import measure_time
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from core.workflow.graph_engine.entities.event import AgentLogEvent
from models import App, Message, WorkflowNodeExecutionModel, db
from core.workflow.node_events import AgentLogEvent
from extensions.ext_database import db
from models import App, Message, WorkflowNodeExecutionModel
logger = logging.getLogger(__name__)

View File

@ -160,7 +160,7 @@ class ErrorData(BaseModel):
sentence.
"""
data: Any | None = None
data: Any = None
"""
Additional information about the error. The value of this member is defined by the
sender (e.g. detailed error information, nested errors etc.).

View File

@ -23,8 +23,7 @@ from core.model_runtime.errors.invoke import (
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.plugin.entities.plugin_daemon import PluginDaemonInnerError, PluginModelProviderEntity
from core.plugin.impl.model import PluginModelClient
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
class AIModel(BaseModel):
@ -52,6 +51,8 @@ class AIModel(BaseModel):
:return: Invoke error mapping
"""
from core.plugin.entities.plugin_daemon import PluginDaemonInnerError
return {
InvokeConnectionError: [InvokeConnectionError],
InvokeServerUnavailableError: [InvokeServerUnavailableError],
@ -139,6 +140,8 @@ class AIModel(BaseModel):
:param credentials: model credentials
:return: model schema
"""
from core.plugin.impl.model import PluginModelClient
plugin_model_manager = PluginModelClient()
cache_key = f"{self.tenant_id}:{self.plugin_id}:{self.provider_name}:{self.model_type.value}:{model}"
# sort credentials

View File

@ -22,7 +22,6 @@ from core.model_runtime.entities.model_entities import (
PriceType,
)
from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.plugin.impl.model import PluginModelClient
logger = logging.getLogger(__name__)
@ -142,6 +141,8 @@ class LargeLanguageModel(AIModel):
result: Union[LLMResult, Generator[LLMResultChunk, None, None]]
try:
from core.plugin.impl.model import PluginModelClient
plugin_model_manager = PluginModelClient()
result = plugin_model_manager.invoke_llm(
tenant_id=self.tenant_id,
@ -340,6 +341,8 @@ class LargeLanguageModel(AIModel):
:return:
"""
if dify_config.PLUGIN_BASED_TOKEN_COUNTING_ENABLED:
from core.plugin.impl.model import PluginModelClient
plugin_model_manager = PluginModelClient()
return plugin_model_manager.get_llm_num_tokens(
tenant_id=self.tenant_id,

View File

@ -4,7 +4,6 @@ from pydantic import ConfigDict
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.plugin.impl.model import PluginModelClient
class ModerationModel(AIModel):
@ -30,6 +29,8 @@ class ModerationModel(AIModel):
self.started_at = time.perf_counter()
try:
from core.plugin.impl.model import PluginModelClient
plugin_model_manager = PluginModelClient()
return plugin_model_manager.invoke_moderation(
tenant_id=self.tenant_id,

View File

@ -1,7 +1,6 @@
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.rerank_entities import RerankResult
from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.plugin.impl.model import PluginModelClient
class RerankModel(AIModel):
@ -34,6 +33,8 @@ class RerankModel(AIModel):
:return: rerank result
"""
try:
from core.plugin.impl.model import PluginModelClient
plugin_model_manager = PluginModelClient()
return plugin_model_manager.invoke_rerank(
tenant_id=self.tenant_id,

View File

@ -4,7 +4,6 @@ from pydantic import ConfigDict
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.plugin.impl.model import PluginModelClient
class Speech2TextModel(AIModel):
@ -28,6 +27,8 @@ class Speech2TextModel(AIModel):
:return: text for given audio file
"""
try:
from core.plugin.impl.model import PluginModelClient
plugin_model_manager = PluginModelClient()
return plugin_model_manager.invoke_speech_to_text(
tenant_id=self.tenant_id,

View File

@ -4,7 +4,6 @@ from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.plugin.impl.model import PluginModelClient
class TextEmbeddingModel(AIModel):
@ -35,6 +34,8 @@ class TextEmbeddingModel(AIModel):
:param input_type: input type
:return: embeddings result
"""
from core.plugin.impl.model import PluginModelClient
try:
plugin_model_manager = PluginModelClient()
return plugin_model_manager.invoke_text_embedding(
@ -59,6 +60,8 @@ class TextEmbeddingModel(AIModel):
:param texts: texts to embed
:return:
"""
from core.plugin.impl.model import PluginModelClient
plugin_model_manager = PluginModelClient()
return plugin_model_manager.get_text_embedding_num_tokens(
tenant_id=self.tenant_id,

View File

@ -5,7 +5,6 @@ from pydantic import ConfigDict
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.plugin.impl.model import PluginModelClient
logger = logging.getLogger(__name__)
@ -41,6 +40,8 @@ class TTSModel(AIModel):
:return: translated audio file
"""
try:
from core.plugin.impl.model import PluginModelClient
plugin_model_manager = PluginModelClient()
return plugin_model_manager.invoke_tts(
tenant_id=self.tenant_id,
@ -64,6 +65,8 @@ class TTSModel(AIModel):
:param credentials: The credentials required to access the TTS model.
:return: A list of voices supported by the TTS model.
"""
from core.plugin.impl.model import PluginModelClient
plugin_model_manager = PluginModelClient()
return plugin_model_manager.get_tts_model_voices(
tenant_id=self.tenant_id,

View File

@ -15,16 +15,16 @@ from core.model_runtime.model_providers.__base.text_embedding_model import TextE
from core.model_runtime.model_providers.__base.tts_model import TTSModel
from core.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator
from core.model_runtime.schema_validators.provider_credential_schema_validator import ProviderCredentialSchemaValidator
from core.plugin.entities.plugin import ModelProviderID
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from core.plugin.impl.asset import PluginAssetManager
from core.plugin.impl.model import PluginModelClient
from models.provider_ids import ModelProviderID
logger = logging.getLogger(__name__)
class ModelProviderFactory:
def __init__(self, tenant_id: str):
from core.plugin.impl.model import PluginModelClient
self.tenant_id = tenant_id
self.plugin_model_manager = PluginModelClient()
@ -38,7 +38,7 @@ class ModelProviderFactory:
plugin_providers = self.get_plugin_model_providers()
return [provider.declaration for provider in plugin_providers]
def get_plugin_model_providers(self) -> Sequence[PluginModelProviderEntity]:
def get_plugin_model_providers(self) -> Sequence["PluginModelProviderEntity"]:
"""
Get all plugin model providers
:return: list of plugin model providers
@ -76,7 +76,7 @@ class ModelProviderFactory:
plugin_model_provider_entity = self.get_plugin_model_provider(provider=provider)
return plugin_model_provider_entity.declaration
def get_plugin_model_provider(self, provider: str) -> PluginModelProviderEntity:
def get_plugin_model_provider(self, provider: str) -> "PluginModelProviderEntity":
"""
Get plugin model provider
:param provider: provider name
@ -331,6 +331,8 @@ class ModelProviderFactory:
mime_type = image_mime_types.get(extension, "image/png")
# get icon bytes from plugin asset manager
from core.plugin.impl.asset import PluginAssetManager
plugin_asset_manager = PluginAssetManager()
return plugin_asset_manager.fetch_asset(tenant_id=self.tenant_id, id=file_name), mime_type
@ -340,5 +342,6 @@ class ModelProviderFactory:
:param provider: provider name
:return: plugin id and provider name
"""
provider_id = ModelProviderID(provider)
return provider_id.plugin_id, provider_id.provider_name

View File

@ -54,13 +54,10 @@ from core.ops.entities.trace_entity import (
)
from core.rag.models.document import Document
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.workflow.entities.workflow_node_execution import (
WorkflowNodeExecution,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.nodes import NodeType
from models import Account, App, EndUser, TenantAccountJoin, WorkflowNodeExecutionTriggeredFrom, db
from core.workflow.entities import WorkflowNodeExecution
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from extensions.ext_database import db
from models import Account, App, EndUser, TenantAccountJoin, WorkflowNodeExecutionTriggeredFrom
logger = logging.getLogger(__name__)

View File

@ -28,7 +28,7 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
)
from core.ops.utils import filter_none_values
from core.repositories import DifyCoreRepositoryFactory
from core.workflow.nodes.enums import NodeType
from core.workflow.enums import NodeType
from extensions.ext_database import db
from models import EndUser, WorkflowNodeExecutionTriggeredFrom
from models.enums import MessageStatus

View File

@ -28,8 +28,7 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
)
from core.ops.utils import filter_none_values, generate_dotted_order
from core.repositories import DifyCoreRepositoryFactory
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from core.workflow.nodes.enums import NodeType
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey
from extensions.ext_database import db
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom

View File

@ -22,8 +22,7 @@ from core.ops.entities.trace_entity import (
WorkflowTraceInfo,
)
from core.repositories import DifyCoreRepositoryFactory
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from core.workflow.nodes.enums import NodeType
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey
from extensions.ext_database import db
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom

View File

@ -6,7 +6,7 @@ import queue
import threading
import time
from datetime import timedelta
from typing import Any, Union
from typing import TYPE_CHECKING, Any, Optional, Union
from uuid import UUID, uuid4
from cachetools import LRUCache
@ -31,13 +31,15 @@ from core.ops.entities.trace_entity import (
WorkflowTraceInfo,
)
from core.ops.utils import get_message_data
from core.workflow.entities.workflow_execution import WorkflowExecution
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig
from models.workflow import WorkflowAppLog, WorkflowRun
from tasks.ops_trace_task import process_trace_tasks
if TYPE_CHECKING:
from core.workflow.entities import WorkflowExecution
logger = logging.getLogger(__name__)
@ -407,7 +409,7 @@ class TraceTask:
self,
trace_type: Any,
message_id: str | None = None,
workflow_execution: WorkflowExecution | None = None,
workflow_execution: Optional["WorkflowExecution"] = None,
conversation_id: str | None = None,
user_id: str | None = None,
timer: Any | None = None,

View File

@ -23,8 +23,7 @@ from core.ops.entities.trace_entity import (
)
from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel
from core.repositories import DifyCoreRepositoryFactory
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from core.workflow.nodes.enums import NodeType
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey
from extensions.ext_database import db
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom

View File

@ -167,7 +167,6 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
invoke_from=InvokeFrom.SERVICE_API,
streaming=stream,
call_depth=1,
workflow_thread_pool_id=None,
)
@classmethod

View File

@ -1,5 +1,5 @@
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
from core.workflow.nodes.enums import NodeType
from core.workflow.enums import NodeType
from core.workflow.nodes.parameter_extractor.entities import (
ModelConfig as ParameterExtractorModelConfig,
)

View File

@ -6,7 +6,6 @@ from pydantic import BaseModel, Field, field_validator
from core.entities.parameter_entities import CommonParameterType
from core.tools.entities.common_entities import I18nObject
from core.workflow.nodes.base.entities import NumberType
class PluginParameterOption(BaseModel):
@ -153,7 +152,7 @@ def cast_parameter_value(typ: StrEnum, value: Any, /):
raise ValueError("The tools selector must be a list.")
return value
case PluginParameterType.ANY:
if value and not isinstance(value, str | dict | list | NumberType):
if value and not isinstance(value, str | dict | list | int | float):
raise ValueError("The var selector must be a string, dictionary, list or number.")
return value
case PluginParameterType.ARRAY:

View File

@ -1,12 +1,10 @@
import datetime
import re
from collections.abc import Mapping
from enum import StrEnum, auto
from typing import Any
from packaging.version import InvalidVersion, Version
from pydantic import BaseModel, Field, field_validator, model_validator
from werkzeug.exceptions import NotFound
from core.agent.plugin_entities import AgentStrategyProviderEntity
from core.model_runtime.entities.provider_entities import ProviderEntity
@ -156,55 +154,6 @@ class PluginEntity(PluginInstallation):
return self
class GenericProviderID:
organization: str
plugin_name: str
provider_name: str
is_hardcoded: bool
def to_string(self) -> str:
return str(self)
def __str__(self) -> str:
return f"{self.organization}/{self.plugin_name}/{self.provider_name}"
def __init__(self, value: str, is_hardcoded: bool = False):
if not value:
raise NotFound("plugin not found, please add plugin")
# check if the value is a valid plugin id with format: $organization/$plugin_name/$provider_name
if not re.match(r"^[a-z0-9_-]+\/[a-z0-9_-]+\/[a-z0-9_-]+$", value):
# check if matches [a-z0-9_-]+, if yes, append with langgenius/$value/$value
if re.match(r"^[a-z0-9_-]+$", value):
value = f"langgenius/{value}/{value}"
else:
raise ValueError(f"Invalid plugin id {value}")
self.organization, self.plugin_name, self.provider_name = value.split("/")
self.is_hardcoded = is_hardcoded
def is_langgenius(self) -> bool:
return self.organization == "langgenius"
@property
def plugin_id(self) -> str:
return f"{self.organization}/{self.plugin_name}"
class ModelProviderID(GenericProviderID):
def __init__(self, value: str, is_hardcoded: bool = False):
super().__init__(value, is_hardcoded)
if self.organization == "langgenius" and self.provider_name == "google":
self.plugin_name = "gemini"
class ToolProviderID(GenericProviderID):
def __init__(self, value: str, is_hardcoded: bool = False):
super().__init__(value, is_hardcoded)
if self.organization == "langgenius":
if self.provider_name in ["jina", "siliconflow", "stepfun", "gitee_ai"]:
self.plugin_name = f"{self.provider_name}_tool"
class PluginDependency(BaseModel):
class Type(StrEnum):
Github = PluginInstallationSource.Github

View File

@ -2,13 +2,13 @@ from collections.abc import Generator
from typing import Any
from core.agent.entities import AgentInvokeMessage
from core.plugin.entities.plugin import GenericProviderID
from core.plugin.entities.plugin_daemon import (
PluginAgentProviderEntity,
)
from core.plugin.entities.request import PluginInvokeContext
from core.plugin.impl.base import BasePluginClient
from core.plugin.utils.chunk_merger import merge_blob_chunks
from models.provider_ids import GenericProviderID
class PluginAgentClient(BasePluginClient):

View File

@ -1,9 +1,9 @@
from collections.abc import Mapping
from typing import Any
from core.plugin.entities.plugin import GenericProviderID
from core.plugin.entities.plugin_daemon import PluginDynamicSelectOptionsResponse
from core.plugin.impl.base import BasePluginClient
from models.provider_ids import GenericProviderID
class DynamicSelectClient(BasePluginClient):

View File

@ -2,7 +2,6 @@ from collections.abc import Sequence
from core.plugin.entities.bundle import PluginBundleDependency
from core.plugin.entities.plugin import (
GenericProviderID,
MissingPluginDependency,
PluginDeclaration,
PluginEntity,
@ -16,6 +15,7 @@ from core.plugin.entities.plugin_daemon import (
PluginListResponse,
)
from core.plugin.impl.base import BasePluginClient
from models.provider_ids import GenericProviderID
class PluginInstaller(BasePluginClient):

View File

@ -3,11 +3,11 @@ from typing import Any
from pydantic import BaseModel
from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
from core.plugin.entities.plugin_daemon import PluginBasicBooleanResponse, PluginToolProviderEntity
from core.plugin.impl.base import BasePluginClient
from core.plugin.utils.chunk_merger import merge_blob_chunks
from core.tools.entities.tool_entities import CredentialType, ToolInvokeMessage, ToolParameter
from models.provider_ids import GenericProviderID, ToolProviderID
class PluginToolManager(BasePluginClient):

View File

@ -36,7 +36,6 @@ from core.model_runtime.entities.provider_entities import (
ProviderEntity,
)
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.plugin.entities.plugin import ModelProviderID
from extensions import ext_hosting_provider
from extensions.ext_database import db
from extensions.ext_redis import redis_client
@ -51,6 +50,7 @@ from models.provider import (
TenantDefaultModel,
TenantPreferredModelProvider,
)
from models.provider_ids import ModelProviderID
from services.feature_service import FeatureService

View File

@ -1,9 +1,9 @@
"""Abstract interface for document loader implementations."""
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Optional
from configs import dify_config
from core.model_manager import ModelInstance
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.models.document import Document
from core.rag.splitter.fixed_text_splitter import (
@ -13,6 +13,9 @@ from core.rag.splitter.fixed_text_splitter import (
from core.rag.splitter.text_splitter import TextSplitter
from models.dataset import Dataset, DatasetProcessRule
if TYPE_CHECKING:
from core.model_manager import ModelInstance
class BaseIndexProcessor(ABC):
"""Interface for extract files."""
@ -51,7 +54,7 @@ class BaseIndexProcessor(ABC):
max_tokens: int,
chunk_overlap: int,
separator: str,
embedding_model_instance: ModelInstance | None,
embedding_model_instance: Optional["ModelInstance"],
) -> TextSplitter:
"""
Get the NodeParser object according to the processing rule.

View File

@ -9,11 +9,8 @@ from typing import Union
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
from core.workflow.entities.workflow_execution import (
WorkflowExecution,
WorkflowExecutionStatus,
WorkflowType,
)
from core.workflow.entities import WorkflowExecution
from core.workflow.enums import WorkflowExecutionStatus, WorkflowType
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from libs.helper import extract_tenant_id
@ -203,5 +200,4 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
session.commit()
# Update the in-memory cache for faster subsequent lookups
logger.debug("Updating cache for execution_id: %s", db_model.id)
self._execution_cache[db_model.id] = db_model

View File

@ -15,12 +15,8 @@ from sqlalchemy.orm import sessionmaker
from tenacity import before_sleep_log, retry, retry_if_exception, stop_after_attempt
from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.entities.workflow_node_execution import (
WorkflowNodeExecution,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.nodes.enums import NodeType
from core.workflow.entities import WorkflowNodeExecution
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from libs.helper import extract_tenant_id
@ -245,7 +241,6 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
# Update the in-memory cache after successful save
if db_model.node_execution_id:
logger.debug("Updating cache for node_execution_id: %s", db_model.node_execution_id)
self._node_execution_cache[db_model.node_execution_id] = db_model
except Exception:

View File

@ -1,5 +1,6 @@
from collections.abc import Mapping
from datetime import datetime
from typing import Any, Literal, Optional
from typing import Any, Literal
from pydantic import BaseModel, Field, field_validator
@ -16,10 +17,10 @@ class ToolApiEntity(BaseModel):
description: I18nObject
parameters: list[ToolParameter] | None = None
labels: list[str] = Field(default_factory=list)
output_schema: dict | None = None
output_schema: Mapping[str, object] = Field(default_factory=dict)
ToolProviderTypeApiLiteral = Optional[Literal["builtin", "api", "workflow", "mcp"]]
ToolProviderTypeApiLiteral = Literal["builtin", "api", "workflow", "mcp"] | None
class ToolProviderApiEntity(BaseModel):
@ -27,17 +28,17 @@ class ToolProviderApiEntity(BaseModel):
author: str
name: str # identifier
description: I18nObject
icon: str | dict
icon_dark: str | dict | None = Field(default=None, description="The dark icon of the tool")
icon: str | Mapping[str, str]
icon_dark: str | Mapping[str, str] = ""
label: I18nObject # label
type: ToolProviderType
masked_credentials: dict | None = None
original_credentials: dict | None = None
masked_credentials: Mapping[str, object] = Field(default_factory=dict)
original_credentials: Mapping[str, object] = Field(default_factory=dict)
is_team_authorization: bool = False
allow_delete: bool = True
plugin_id: str | None = Field(default="", description="The plugin id of the tool")
plugin_unique_identifier: str | None = Field(default="", description="The unique identifier of the tool")
tools: list[ToolApiEntity] = Field(default_factory=list)
tools: list[ToolApiEntity] = Field(default_factory=list[ToolApiEntity])
labels: list[str] = Field(default_factory=list)
# MCP
server_url: str | None = Field(default="", description="The server url of the tool")
@ -105,7 +106,7 @@ class ToolProviderCredentialApiEntity(BaseModel):
is_default: bool = Field(
default=False, description="Whether the credential is the default credential for the provider in the workspace"
)
credentials: dict = Field(description="The credentials of the provider")
credentials: Mapping[str, object] = Field(description="The credentials of the provider", default_factory=dict)
class ToolProviderCredentialInfoApiEntity(BaseModel):

View File

@ -186,7 +186,7 @@ class ToolInvokeMessage(BaseModel):
error: str | None = Field(default=None, description="The error message")
status: LogStatus = Field(..., description="The status of the log")
data: Mapping[str, Any] = Field(..., description="Detailed log data")
metadata: Mapping[str, Any] | None = Field(default=None, description="The metadata of the log")
metadata: Mapping[str, Any] = Field(default_factory=dict, description="The metadata of the log")
class RetrieverResourceMessage(BaseModel):
retriever_resources: list[RetrievalSourceMetadata] = Field(..., description="retriever resources")
@ -362,9 +362,9 @@ class ToolDescription(BaseModel):
class ToolEntity(BaseModel):
identity: ToolIdentity
parameters: list[ToolParameter] = Field(default_factory=list)
parameters: list[ToolParameter] = Field(default_factory=list[ToolParameter])
description: ToolDescription | None = None
output_schema: dict | None = None
output_schema: Mapping[str, object] = Field(default_factory=dict)
has_runtime_parameters: bool = Field(default=False, description="Whether the tool has runtime parameters")
# pydantic configs
@ -377,21 +377,23 @@ class ToolEntity(BaseModel):
class OAuthSchema(BaseModel):
client_schema: list[ProviderConfig] = Field(default_factory=list, description="The schema of the OAuth client")
client_schema: list[ProviderConfig] = Field(
default_factory=list[ProviderConfig], description="The schema of the OAuth client"
)
credentials_schema: list[ProviderConfig] = Field(
default_factory=list, description="The schema of the OAuth credentials"
default_factory=list[ProviderConfig], description="The schema of the OAuth credentials"
)
class ToolProviderEntity(BaseModel):
identity: ToolProviderIdentity
plugin_id: str | None = None
credentials_schema: list[ProviderConfig] = Field(default_factory=list)
credentials_schema: list[ProviderConfig] = Field(default_factory=list[ProviderConfig])
oauth_schema: OAuthSchema | None = None
class ToolProviderEntityWithPlugin(ToolProviderEntity):
tools: list[ToolEntity] = Field(default_factory=list)
tools: list[ToolEntity] = Field(default_factory=list[ToolEntity])
class WorkflowToolParameterConfiguration(BaseModel):

View File

@ -72,7 +72,6 @@ class MCPToolProviderController(ToolProviderController):
),
llm=remote_mcp_tool.description or "",
),
output_schema=None,
has_runtime_parameters=len(remote_mcp_tool.inputSchema) > 0,
)
for remote_mcp_tool in remote_mcp_tools

View File

@ -152,7 +152,6 @@ class ToolEngine:
user_id: str,
workflow_tool_callback: DifyWorkflowCallbackHandler,
workflow_call_depth: int,
thread_pool_id: str | None = None,
conversation_id: str | None = None,
app_id: str | None = None,
message_id: str | None = None,
@ -166,7 +165,6 @@ class ToolEngine:
if isinstance(tool, WorkflowTool):
tool.workflow_call_depth = workflow_call_depth + 1
tool.thread_pool_id = thread_pool_id
if tool.runtime and tool.runtime.runtime_parameters:
tool_parameters = {**tool.runtime.runtime_parameters, **tool_parameters}

View File

@ -5,7 +5,7 @@ import time
from collections.abc import Generator, Mapping
from os import listdir, path
from threading import Lock
from typing import TYPE_CHECKING, Any, Literal, Union, cast
from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
import sqlalchemy as sa
from pydantic import TypeAdapter
@ -14,32 +14,17 @@ from sqlalchemy.orm import Session
from yarl import URL
import contexts
from core.helper.provider_cache import ToolProviderCredentialsCache
from core.plugin.entities.plugin import ToolProviderID
from core.plugin.impl.oauth import OAuthHandler
from core.plugin.impl.tool import PluginToolManager
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.mcp_tool.provider import MCPToolProviderController
from core.tools.mcp_tool.tool import MCPTool
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.plugin_tool.tool import PluginTool
from core.tools.utils.uuid_utils import is_valid_uuid
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from core.workflow.entities.variable_pool import VariablePool
from services.enterprise.plugin_manager_service import PluginCredentialType
from services.tools.mcp_tools_manage_service import MCPToolManageService
if TYPE_CHECKING:
from core.workflow.nodes.tool.entities import ToolEntity
from configs import dify_config
from core.agent.entities import AgentToolEntity
from core.app.entities.app_invoke_entities import InvokeFrom
from core.helper.module_import_helper import load_single_subclass_from_source
from core.helper.position_helper import is_filtered
from core.helper.provider_cache import ToolProviderCredentialsCache
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.tool import PluginToolManager
from core.tools.__base.tool import Tool
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
from core.tools.builtin_tool.tool import BuiltinTool
@ -55,14 +40,28 @@ from core.tools.entities.tool_entities import (
ToolProviderType,
)
from core.tools.errors import ToolProviderNotFoundError
from core.tools.mcp_tool.provider import MCPToolProviderController
from core.tools.mcp_tool.tool import MCPTool
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.plugin_tool.tool import PluginTool
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.configuration import ToolParameterConfigurationManager
from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter
from core.tools.utils.uuid_utils import is_valid_uuid
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from core.tools.workflow_as_tool.tool import WorkflowTool
from core.workflow.entities.variable_pool import VariablePool
from extensions.ext_database import db
from models.provider_ids import ToolProviderID
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
from services.enterprise.plugin_manager_service import PluginCredentialType
from services.tools.mcp_tools_manage_service import MCPToolManageService
from services.tools.tools_transform_service import ToolTransformService
if TYPE_CHECKING:
from core.workflow.entities import VariablePool
from core.workflow.nodes.tool.entities import ToolEntity
logger = logging.getLogger(__name__)
@ -117,6 +116,7 @@ class ToolManager:
get the plugin provider
"""
# check if context is set
try:
contexts.plugin_tool_providers.get()
except LookupError:
@ -172,6 +172,7 @@ class ToolManager:
:return: the tool
"""
if provider_type == ToolProviderType.BUILT_IN:
# check if the builtin tool needs credentials
provider_controller = cls.get_builtin_provider(provider_id, tenant_id)
@ -213,16 +214,16 @@ class ToolManager:
# fallback to the default provider
if builtin_provider is None:
# use the default provider
builtin_provider = (
db.session.query(BuiltinToolProvider)
.where(
BuiltinToolProvider.tenant_id == tenant_id,
(BuiltinToolProvider.provider == str(provider_id_entity))
| (BuiltinToolProvider.provider == provider_id_entity.provider_name),
with Session(db.engine) as session:
builtin_provider = session.scalar(
sa.select(BuiltinToolProvider)
.where(
BuiltinToolProvider.tenant_id == tenant_id,
(BuiltinToolProvider.provider == str(provider_id_entity))
| (BuiltinToolProvider.provider == provider_id_entity.provider_name),
)
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
)
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
.first()
)
if builtin_provider is None:
raise ToolProviderNotFoundError(f"no default provider for {provider_id}")
else:
@ -263,6 +264,7 @@ class ToolManager:
# check if the credentials is expired
if builtin_provider.expires_at != -1 and (builtin_provider.expires_at - 60) < int(time.time()):
# TODO: circular import
from core.plugin.impl.oauth import OAuthHandler
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
# refresh the credentials
@ -270,6 +272,7 @@ class ToolManager:
provider_name = tool_provider.provider_name
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/tool/callback"
system_credentials = BuiltinToolManageService.get_oauth_client(tenant_id, provider_id)
oauth_handler = OAuthHandler()
# refresh the credentials
refreshed_credentials = oauth_handler.refresh_credentials(
@ -358,7 +361,7 @@ class ToolManager:
app_id: str,
agent_tool: AgentToolEntity,
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
variable_pool: VariablePool | None = None,
variable_pool: Optional["VariablePool"] = None,
) -> Tool:
"""
get the agent tool runtime
@ -400,7 +403,7 @@ class ToolManager:
node_id: str,
workflow_tool: "ToolEntity",
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
variable_pool: VariablePool | None = None,
variable_pool: Optional["VariablePool"] = None,
) -> Tool:
"""
get the workflow tool runtime
@ -516,6 +519,7 @@ class ToolManager:
"""
list all the plugin providers
"""
manager = PluginToolManager()
provider_entities = manager.fetch_tool_providers(tenant_id)
return [
@ -882,7 +886,7 @@ class ToolManager:
)
@classmethod
def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str):
def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str) -> Mapping[str, str]:
try:
workflow_provider: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider)
@ -893,13 +897,13 @@ class ToolManager:
if workflow_provider is None:
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
icon: dict = json.loads(workflow_provider.icon)
icon = json.loads(workflow_provider.icon)
return icon
except Exception:
return {"background": "#252525", "content": "\ud83d\ude01"}
@classmethod
def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str):
def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str) -> Mapping[str, str]:
try:
api_provider: ApiToolProvider | None = (
db.session.query(ApiToolProvider)
@ -910,13 +914,13 @@ class ToolManager:
if api_provider is None:
raise ToolProviderNotFoundError(f"api provider {provider_id} not found")
icon: dict = json.loads(api_provider.icon)
icon = json.loads(api_provider.icon)
return icon
except Exception:
return {"background": "#252525", "content": "\ud83d\ude01"}
@classmethod
def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> dict[str, str] | str:
def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> Mapping[str, str] | str:
try:
mcp_provider: MCPToolProvider | None = (
db.session.query(MCPToolProvider)
@ -937,7 +941,7 @@ class ToolManager:
tenant_id: str,
provider_type: ToolProviderType,
provider_id: str,
) -> Union[str, dict[str, Any]]:
) -> str | Mapping[str, str]:
"""
get the tool icon
@ -962,11 +966,10 @@ class ToolManager:
return cls.generate_workflow_tool_icon_url(tenant_id, provider_id)
elif provider_type == ToolProviderType.PLUGIN:
provider = ToolManager.get_plugin_provider(provider_id, tenant_id)
if isinstance(provider, PluginToolProviderController):
try:
return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon)
except Exception:
return {"background": "#252525", "content": "\ud83d\ude01"}
try:
return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon)
except Exception:
return {"background": "#252525", "content": "\ud83d\ude01"}
raise ValueError(f"plugin provider {provider_id} not found")
elif provider_type == ToolProviderType.MCP:
return cls.generate_mcp_tool_icon_url(tenant_id, provider_id)
@ -977,7 +980,7 @@ class ToolManager:
def _convert_tool_parameters_type(
cls,
parameters: list[ToolParameter],
variable_pool: VariablePool | None,
variable_pool: Optional["VariablePool"],
tool_configurations: dict[str, Any],
typ: Literal["agent", "workflow", "tool"] = "workflow",
) -> dict[str, Any]:

View File

@ -39,14 +39,12 @@ class WorkflowTool(Tool):
entity: ToolEntity,
runtime: ToolRuntime,
label: str = "Workflow",
thread_pool_id: str | None = None,
):
self.workflow_app_id = workflow_app_id
self.workflow_as_tool_id = workflow_as_tool_id
self.version = version
self.workflow_entities = workflow_entities
self.workflow_call_depth = workflow_call_depth
self.thread_pool_id = thread_pool_id
self.label = label
super().__init__(entity=entity, runtime=runtime)
@ -90,7 +88,6 @@ class WorkflowTool(Tool):
invoke_from=self.runtime.invoke_from,
streaming=False,
call_depth=self.workflow_call_depth + 1,
workflow_thread_pool_id=self.thread_pool_id,
)
assert isinstance(result, dict)
data = result.get("data", {})

View File

@ -130,7 +130,7 @@ class ArraySegment(Segment):
def markdown(self) -> str:
items = []
for item in self.value:
items.append(str(item))
items.append(f"- {item}")
return "\n".join(items)

api/core/workflow/README.md Normal file
View File

@ -0,0 +1,132 @@
# Workflow
## Project Overview
This is the workflow graph engine module of Dify, implementing a queue-based distributed workflow execution system. The engine handles agentic AI workflows with support for parallel execution, node iteration, conditional logic, and external command control.
## Architecture
### Core Components
The graph engine follows a layered architecture with strict dependency rules:
1. **Graph Engine** (`graph_engine/`) - Orchestrates workflow execution
- **Manager** - External control interface for stop/pause/resume commands
- **Worker** - Node execution runtime
- **Command Processing** - Handles control commands (abort, pause, resume)
- **Event Management** - Event propagation and layer notifications
- **Graph Traversal** - Edge processing and skip propagation
- **Response Coordinator** - Path tracking and session management
- **Layers** - Pluggable middleware (debug logging, execution limits)
- **Command Channels** - Communication channels (InMemory, Redis)
1. **Graph** (`graph/`) - Graph structure and runtime state
- **Graph Template** - Workflow definition
- **Edge** - Node connections with conditions
- **Runtime State Protocol** - State management interface
1. **Nodes** (`nodes/`) - Node implementations
- **Base** - Abstract node classes and variable parsing
- **Specific Nodes** - LLM, Agent, Code, HTTP Request, Iteration, Loop, etc.
1. **Events** (`node_events/`) - Event system
- **Base** - Event protocols
- **Node Events** - Node lifecycle events
1. **Entities** (`entities/`) - Domain models
- **Variable Pool** - Variable storage
- **Graph Init Params** - Initialization configuration
## Key Design Patterns
### Command Channel Pattern
External workflow control via Redis or in-memory channels:
```python
# Send stop command to running workflow
channel = RedisChannel(redis_client, f"workflow:{task_id}:commands")
channel.send_command(AbortCommand(reason="User requested"))
```
### Layer System
Extensible middleware for cross-cutting concerns:
```python
engine = GraphEngine(graph)
engine.add_layer(DebugLoggingLayer(level="INFO"))
engine.add_layer(ExecutionLimitsLayer(max_nodes=100))
```
### Event-Driven Architecture
All node executions emit events for monitoring and integration (a minimal handler sketch follows this list):
- `NodeRunStartedEvent` - Node execution begins
- `NodeRunSucceededEvent` - Node completes successfully
- `NodeRunFailedEvent` - Node encounters error
- `GraphRunStartedEvent/GraphRunCompletedEvent` - Workflow lifecycle
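A minimal consumer dispatches on the concrete event type. The sketch below is illustrative only: the import path and the `node_id` / `error` attributes are assumptions and may not match the current event classes exactly.
```python
# Sketch only: the import path and event attributes are assumptions,
# since the event modules moved during refactoring.
from core.workflow.graph_engine.entities.event import (
    GraphRunFailedEvent,
    GraphRunStartedEvent,
    NodeRunStartedEvent,
)

def handle_event(event):
    # Dispatch on the concrete event type emitted by the engine.
    if isinstance(event, GraphRunStartedEvent):
        print("workflow started")
    elif isinstance(event, NodeRunStartedEvent):
        print(f"node started: {event.node_id}")
    elif isinstance(event, GraphRunFailedEvent):
        print(f"workflow failed: {event.error}")
```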
### Variable Pool
Centralized variable storage with namespace isolation:
```python
# Variables scoped by node_id
pool.add(["node1", "output"], value)
result = pool.get(["node1", "output"])
```
## Import Architecture Rules
The codebase enforces strict layering via import-linter:
1. **Workflow Layers** (top to bottom):
- graph_engine → graph_events → graph → nodes → node_events → entities
1. **Graph Engine Internal Layers**:
- orchestration → command_processing → event_management → graph_traversal → domain
1. **Domain Isolation**:
- Domain models cannot import from infrastructure layers
1. **Command Channel Independence**:
- InMemory and Redis channels must remain independent
## Common Tasks
### Adding a New Node Type
1. Create node class in `nodes/<node_type>/`
1. Inherit from `BaseNode` or appropriate base class
1. Implement the `_run()` method (see the sketch after this list)
1. Register in `nodes/node_mapping.py`
1. Add tests in `tests/unit_tests/core/workflow/nodes/`
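A skeletal sketch of these steps follows. Only `BaseNode` and `_run()` come from the steps above; the constructor, input handling, and return type follow the real `BaseNode` contract in `nodes/base` and are deliberately left abstract here.
```python
# Hypothetical skeleton: everything beyond the names used in the steps
# above is an assumption, not the actual BaseNode signature.
class WordCountNode(BaseNode):  # step 2: inherit from BaseNode
    def _run(self):  # step 3: implement _run()
        # Read inputs (e.g. from the variable pool), do the node's work,
        # and return outputs in the result type BaseNode expects.
        raise NotImplementedError

# Step 4 then registers the node's type string against this class in
# nodes/node_mapping.py.
```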
### Implementing a Custom Layer
1. Create class inheriting from `Layer` base
1. Override lifecycle methods: `on_graph_start()`, `on_event()`, `on_graph_end()`
1. Add to engine via `engine.add_layer()` (see the sketch below)
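A minimal timing layer might look like the following sketch. The `Layer` base and the three hook names come from the steps above, while the exact hook signatures (e.g. whether `on_graph_end` receives an error) are assumptions.
```python
import time

class TimingLayer(Layer):  # step 1: inherit from the Layer base
    def on_graph_start(self):
        self._t0 = time.perf_counter()

    def on_event(self, event):
        pass  # observe or record events here

    def on_graph_end(self):
        print(f"graph run took {time.perf_counter() - self._t0:.2f}s")

engine.add_layer(TimingLayer())  # step 3
```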
### Debugging Workflow Execution
Enable the debug logging layer:
```python
debug_layer = DebugLoggingLayer(
level="DEBUG",
include_inputs=True,
include_outputs=True
)
```
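The layer only takes effect once attached, e.g. `engine.add_layer(debug_layer)`, as in the Layer System example above.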

View File

@ -1,7 +0,0 @@
from .base_workflow_callback import WorkflowCallback
from .workflow_logging_callback import WorkflowLoggingCallback
__all__ = [
"WorkflowCallback",
"WorkflowLoggingCallback",
]

View File

@ -1,12 +0,0 @@
from abc import ABC, abstractmethod
from core.workflow.graph_engine.entities.event import GraphEngineEvent
class WorkflowCallback(ABC):
@abstractmethod
def on_event(self, event: GraphEngineEvent):
"""
Published event
"""
raise NotImplementedError

View File

@ -1,259 +0,0 @@
from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
IterationRunFailedEvent,
IterationRunNextEvent,
IterationRunStartedEvent,
IterationRunSucceededEvent,
LoopRunFailedEvent,
LoopRunNextEvent,
LoopRunStartedEvent,
LoopRunSucceededEvent,
NodeRunFailedEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
ParallelBranchRunFailedEvent,
ParallelBranchRunStartedEvent,
ParallelBranchRunSucceededEvent,
)
from .base_workflow_callback import WorkflowCallback
_TEXT_COLOR_MAPPING = {
"blue": "36;1",
"yellow": "33;1",
"pink": "38;5;200",
"green": "32;1",
"red": "31;1",
}
class WorkflowLoggingCallback(WorkflowCallback):
def __init__(self):
self.current_node_id: str | None = None
def on_event(self, event: GraphEngineEvent):
if isinstance(event, GraphRunStartedEvent):
self.print_text("\n[GraphRunStartedEvent]", color="pink")
elif isinstance(event, GraphRunSucceededEvent):
self.print_text("\n[GraphRunSucceededEvent]", color="green")
elif isinstance(event, GraphRunPartialSucceededEvent):
self.print_text("\n[GraphRunPartialSucceededEvent]", color="pink")
elif isinstance(event, GraphRunFailedEvent):
self.print_text(f"\n[GraphRunFailedEvent] reason: {event.error}", color="red")
elif isinstance(event, NodeRunStartedEvent):
self.on_workflow_node_execute_started(event=event)
elif isinstance(event, NodeRunSucceededEvent):
self.on_workflow_node_execute_succeeded(event=event)
elif isinstance(event, NodeRunFailedEvent):
self.on_workflow_node_execute_failed(event=event)
elif isinstance(event, NodeRunStreamChunkEvent):
self.on_node_text_chunk(event=event)
elif isinstance(event, ParallelBranchRunStartedEvent):
self.on_workflow_parallel_started(event=event)
elif isinstance(event, ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent):
self.on_workflow_parallel_completed(event=event)
elif isinstance(event, IterationRunStartedEvent):
self.on_workflow_iteration_started(event=event)
elif isinstance(event, IterationRunNextEvent):
self.on_workflow_iteration_next(event=event)
elif isinstance(event, IterationRunSucceededEvent | IterationRunFailedEvent):
self.on_workflow_iteration_completed(event=event)
elif isinstance(event, LoopRunStartedEvent):
self.on_workflow_loop_started(event=event)
elif isinstance(event, LoopRunNextEvent):
self.on_workflow_loop_next(event=event)
elif isinstance(event, LoopRunSucceededEvent | LoopRunFailedEvent):
self.on_workflow_loop_completed(event=event)
else:
self.print_text(f"\n[{event.__class__.__name__}]", color="blue")
def on_workflow_node_execute_started(self, event: NodeRunStartedEvent):
"""
Workflow node execute started
"""
self.print_text("\n[NodeRunStartedEvent]", color="yellow")
self.print_text(f"Node ID: {event.node_id}", color="yellow")
self.print_text(f"Node Title: {event.node_data.title}", color="yellow")
self.print_text(f"Type: {event.node_type.value}", color="yellow")
def on_workflow_node_execute_succeeded(self, event: NodeRunSucceededEvent):
"""
Workflow node execute succeeded
"""
route_node_state = event.route_node_state
self.print_text("\n[NodeRunSucceededEvent]", color="green")
self.print_text(f"Node ID: {event.node_id}", color="green")
self.print_text(f"Node Title: {event.node_data.title}", color="green")
self.print_text(f"Type: {event.node_type.value}", color="green")
if route_node_state.node_run_result:
node_run_result = route_node_state.node_run_result
self.print_text(
f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}",
color="green",
)
self.print_text(
f"Process Data: "
f"{jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}",
color="green",
)
self.print_text(
f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}",
color="green",
)
self.print_text(
f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}",
color="green",
)
def on_workflow_node_execute_failed(self, event: NodeRunFailedEvent):
"""
Workflow node execute failed
"""
route_node_state = event.route_node_state
self.print_text("\n[NodeRunFailedEvent]", color="red")
self.print_text(f"Node ID: {event.node_id}", color="red")
self.print_text(f"Node Title: {event.node_data.title}", color="red")
self.print_text(f"Type: {event.node_type.value}", color="red")
if route_node_state.node_run_result:
node_run_result = route_node_state.node_run_result
self.print_text(f"Error: {node_run_result.error}", color="red")
self.print_text(
f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}",
color="red",
)
self.print_text(
f"Process Data: "
f"{jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}",
color="red",
)
self.print_text(
f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}",
color="red",
)
def on_node_text_chunk(self, event: NodeRunStreamChunkEvent):
"""
Publish text chunk
"""
route_node_state = event.route_node_state
if not self.current_node_id or self.current_node_id != route_node_state.node_id:
self.current_node_id = route_node_state.node_id
self.print_text("\n[NodeRunStreamChunkEvent]")
self.print_text(f"Node ID: {route_node_state.node_id}")
node_run_result = route_node_state.node_run_result
if node_run_result:
self.print_text(
f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}"
)
self.print_text(event.chunk_content, color="pink", end="")
def on_workflow_parallel_started(self, event: ParallelBranchRunStartedEvent):
"""
Publish parallel started
"""
self.print_text("\n[ParallelBranchRunStartedEvent]", color="blue")
self.print_text(f"Parallel ID: {event.parallel_id}", color="blue")
self.print_text(f"Branch ID: {event.parallel_start_node_id}", color="blue")
if event.in_iteration_id:
self.print_text(f"Iteration ID: {event.in_iteration_id}", color="blue")
if event.in_loop_id:
self.print_text(f"Loop ID: {event.in_loop_id}", color="blue")
def on_workflow_parallel_completed(self, event: ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent):
"""
Publish parallel completed
"""
if isinstance(event, ParallelBranchRunSucceededEvent):
color = "blue"
elif isinstance(event, ParallelBranchRunFailedEvent):
color = "red"
self.print_text(
"\n[ParallelBranchRunSucceededEvent]"
if isinstance(event, ParallelBranchRunSucceededEvent)
else "\n[ParallelBranchRunFailedEvent]",
color=color,
)
self.print_text(f"Parallel ID: {event.parallel_id}", color=color)
self.print_text(f"Branch ID: {event.parallel_start_node_id}", color=color)
if event.in_iteration_id:
self.print_text(f"Iteration ID: {event.in_iteration_id}", color=color)
if event.in_loop_id:
self.print_text(f"Loop ID: {event.in_loop_id}", color=color)
if isinstance(event, ParallelBranchRunFailedEvent):
self.print_text(f"Error: {event.error}", color=color)
def on_workflow_iteration_started(self, event: IterationRunStartedEvent):
"""
Publish iteration started
"""
self.print_text("\n[IterationRunStartedEvent]", color="blue")
self.print_text(f"Iteration Node ID: {event.iteration_id}", color="blue")
def on_workflow_iteration_next(self, event: IterationRunNextEvent):
"""
Publish iteration next
"""
self.print_text("\n[IterationRunNextEvent]", color="blue")
self.print_text(f"Iteration Node ID: {event.iteration_id}", color="blue")
self.print_text(f"Iteration Index: {event.index}", color="blue")
def on_workflow_iteration_completed(self, event: IterationRunSucceededEvent | IterationRunFailedEvent):
"""
Publish iteration completed
"""
self.print_text(
"\n[IterationRunSucceededEvent]"
if isinstance(event, IterationRunSucceededEvent)
else "\n[IterationRunFailedEvent]",
color="blue",
)
self.print_text(f"Node ID: {event.iteration_id}", color="blue")
def on_workflow_loop_started(self, event: LoopRunStartedEvent):
"""
Publish loop started
"""
self.print_text("\n[LoopRunStartedEvent]", color="blue")
self.print_text(f"Loop Node ID: {event.loop_node_id}", color="blue")
def on_workflow_loop_next(self, event: LoopRunNextEvent):
"""
Publish loop next
"""
self.print_text("\n[LoopRunNextEvent]", color="blue")
self.print_text(f"Loop Node ID: {event.loop_node_id}", color="blue")
self.print_text(f"Loop Index: {event.index}", color="blue")
def on_workflow_loop_completed(self, event: LoopRunSucceededEvent | LoopRunFailedEvent):
"""
Publish loop completed
"""
self.print_text(
"\n[LoopRunSucceededEvent]" if isinstance(event, LoopRunSucceededEvent) else "\n[LoopRunFailedEvent]",
color="blue",
)
self.print_text(f"Loop Node ID: {event.loop_node_id}", color="blue")
def print_text(self, text: str, color: str | None = None, end: str = "\n"):
"""Print text with highlighting and no end characters."""
text_to_print = self._get_colored_text(text, color) if color else text
print(f"{text_to_print}", end=end)
def _get_colored_text(self, text: str, color: str) -> str:
"""Get colored text."""
color_str = _TEXT_COLOR_MAPPING[color]
return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m"

View File

@ -0,0 +1,18 @@
from .agent import AgentNodeStrategyInit
from .graph_init_params import GraphInitParams
from .graph_runtime_state import GraphRuntimeState
from .run_condition import RunCondition
from .variable_pool import VariablePool, VariableValue
from .workflow_execution import WorkflowExecution
from .workflow_node_execution import WorkflowNodeExecution
__all__ = [
"AgentNodeStrategyInit",
"GraphInitParams",
"GraphRuntimeState",
"RunCondition",
"VariablePool",
"VariableValue",
"WorkflowExecution",
"WorkflowNodeExecution",
]

View File

@ -0,0 +1,8 @@
from pydantic import BaseModel
class AgentNodeStrategyInit(BaseModel):
"""Agent node strategy initialization data."""
name: str
icon: str | None = None

View File

@ -3,19 +3,18 @@ from typing import Any
from pydantic import BaseModel, Field
from core.app.entities.app_invoke_entities import InvokeFrom
from models.enums import UserFrom
from models.workflow import WorkflowType
class GraphInitParams(BaseModel):
# init params
tenant_id: str = Field(..., description="tenant / workspace id")
app_id: str = Field(..., description="app id")
workflow_type: WorkflowType = Field(..., description="workflow type")
workflow_id: str = Field(..., description="workflow id")
graph_config: Mapping[str, Any] = Field(..., description="graph config")
user_id: str = Field(..., description="user id")
user_from: UserFrom = Field(..., description="user from, account or end-user")
invoke_from: InvokeFrom = Field(..., description="invoke from, service-api, web-app, explore or debugger")
user_from: str = Field(
..., description="user from, account or end-user"
) # Should be UserFrom enum: 'account' | 'end-user'
invoke_from: str = Field(
..., description="invoke from, service-api, web-app, explore or debugger"
) # Should be InvokeFrom enum: 'service-api' | 'web-app' | 'explore' | 'debugger'
call_depth: int = Field(..., description="call depth")

View File

@ -0,0 +1,160 @@
from copy import deepcopy
from pydantic import BaseModel, PrivateAttr
from core.model_runtime.entities.llm_entities import LLMUsage
from .variable_pool import VariablePool
class GraphRuntimeState(BaseModel):
# Private attributes to prevent direct modification
_variable_pool: VariablePool = PrivateAttr()
_start_at: float = PrivateAttr()
_total_tokens: int = PrivateAttr(default=0)
_llm_usage: LLMUsage = PrivateAttr(default_factory=LLMUsage.empty_usage)
_outputs: dict[str, object] = PrivateAttr(default_factory=dict[str, object])
_node_run_steps: int = PrivateAttr(default=0)
_ready_queue_json: str = PrivateAttr()
_graph_execution_json: str = PrivateAttr()
_response_coordinator_json: str = PrivateAttr()
def __init__(
self,
*,
variable_pool: VariablePool,
start_at: float,
total_tokens: int = 0,
llm_usage: LLMUsage | None = None,
outputs: dict[str, object] | None = None,
node_run_steps: int = 0,
ready_queue_json: str = "",
graph_execution_json: str = "",
response_coordinator_json: str = "",
**kwargs: object,
):
"""Initialize the GraphRuntimeState with validation."""
super().__init__(**kwargs)
# Initialize private attributes with validation
self._variable_pool = variable_pool
self._start_at = start_at
if total_tokens < 0:
raise ValueError("total_tokens must be non-negative")
self._total_tokens = total_tokens
if llm_usage is None:
llm_usage = LLMUsage.empty_usage()
self._llm_usage = llm_usage
if outputs is None:
outputs = {}
self._outputs = deepcopy(outputs)
if node_run_steps < 0:
raise ValueError("node_run_steps must be non-negative")
self._node_run_steps = node_run_steps
self._ready_queue_json = ready_queue_json
self._graph_execution_json = graph_execution_json
self._response_coordinator_json = response_coordinator_json
@property
def variable_pool(self) -> VariablePool:
"""Get the variable pool."""
return self._variable_pool
@property
def start_at(self) -> float:
"""Get the start time."""
return self._start_at
@start_at.setter
def start_at(self, value: float) -> None:
"""Set the start time."""
self._start_at = value
@property
def total_tokens(self) -> int:
"""Get the total tokens count."""
return self._total_tokens
@total_tokens.setter
def total_tokens(self, value: int):
"""Set the total tokens count."""
if value < 0:
raise ValueError("total_tokens must be non-negative")
self._total_tokens = value
@property
def llm_usage(self) -> LLMUsage:
"""Get the LLM usage info."""
# Return a copy to prevent external modification
return self._llm_usage.model_copy()
@llm_usage.setter
def llm_usage(self, value: LLMUsage):
"""Set the LLM usage info."""
self._llm_usage = value.model_copy()
@property
def outputs(self) -> dict[str, object]:
"""Get a copy of the outputs dictionary."""
return deepcopy(self._outputs)
@outputs.setter
def outputs(self, value: dict[str, object]) -> None:
"""Set the outputs dictionary."""
self._outputs = deepcopy(value)
def set_output(self, key: str, value: object) -> None:
"""Set a single output value."""
self._outputs[key] = deepcopy(value)
def get_output(self, key: str, default: object = None) -> object:
"""Get a single output value."""
return deepcopy(self._outputs.get(key, default))
def update_outputs(self, updates: dict[str, object]) -> None:
"""Update multiple output values."""
for key, value in updates.items():
self._outputs[key] = deepcopy(value)
@property
def node_run_steps(self) -> int:
"""Get the node run steps count."""
return self._node_run_steps
@node_run_steps.setter
def node_run_steps(self, value: int) -> None:
"""Set the node run steps count."""
if value < 0:
raise ValueError("node_run_steps must be non-negative")
self._node_run_steps = value
def increment_node_run_steps(self) -> None:
"""Increment the node run steps by 1."""
self._node_run_steps += 1
def add_tokens(self, tokens: int) -> None:
"""Add tokens to the total count."""
if tokens < 0:
raise ValueError("tokens must be non-negative")
self._total_tokens += tokens
@property
def ready_queue_json(self) -> str:
"""Get a copy of the ready queue state."""
return self._ready_queue_json
@property
def graph_execution_json(self) -> str:
"""Get a copy of the serialized graph execution state."""
return self._graph_execution_json
@property
def response_coordinator_json(self) -> str:
"""Get a copy of the serialized response coordinator state."""
return self._response_coordinator_json

View File

@ -1,34 +0,0 @@
from collections.abc import Mapping
from typing import Any
from pydantic import BaseModel
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
class NodeRunResult(BaseModel):
"""
Node Run Result.
"""
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING
inputs: Mapping[str, Any] | None = None # node inputs
process_data: Mapping[str, Any] | None = None # process data
outputs: Mapping[str, Any] | None = None # node outputs
metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None # node metadata
llm_usage: LLMUsage | None = None # llm usage
edge_source_handle: str | None = None # source handle id of node with multiple branches
error: str | None = None # error message if status is failed
error_type: str | None = None # error type if status is failed
# single step node run retry
retry_index: int = 0
class AgentNodeStrategyInit(BaseModel):
name: str
icon: str | None = None

View File

@ -1,12 +0,0 @@
from collections.abc import Sequence
from pydantic import BaseModel
class VariableSelector(BaseModel):
"""
Variable Selector.
"""
variable: str
value_selector: Sequence[str]

View File

@ -14,7 +14,7 @@ from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_V
from core.workflow.system_variable import SystemVariable
from factories import variable_factory
VariableValue = Union[str, int, float, dict, list, File]
VariableValue = Union[str, int, float, dict[str, object], list[object], File]
VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}")
@ -40,11 +40,11 @@ class VariablePool(BaseModel):
)
environment_variables: Sequence[VariableUnion] = Field(
description="Environment variables.",
default_factory=list,
default_factory=list[VariableUnion],
)
conversation_variables: Sequence[VariableUnion] = Field(
description="Conversation variables.",
default_factory=list,
default_factory=list[VariableUnion],
)
def model_post_init(self, context: Any, /):
@ -191,7 +191,7 @@ class VariablePool(BaseModel):
def convert_template(self, template: str, /):
parts = VARIABLE_PATTERN.split(template)
segments = []
segments: list[Segment] = []
for part in filter(lambda x: x, parts):
if "." in part and (variable := self.get(part.split("."))):
segments.append(variable)
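As a standalone illustration of what the newly typed `segments` list receives, the same VARIABLE_PATTERN splits a template into literal text and captured selectors (the real `convert_template` then resolves each selector through the pool):

```python
import re

VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}")

parts = VARIABLE_PATTERN.split("Hi {{#sys.query#}}, see {{#node1.output#}}.")
# re.split with a capture group keeps the captured selectors in the result
assert parts == ["Hi ", "sys.query", ", see ", "node1.output", "."]
```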

View File

@ -7,31 +7,14 @@ implementation details like tenant_id, app_id, etc.
from collections.abc import Mapping
from datetime import datetime
from enum import StrEnum
from typing import Any
from pydantic import BaseModel, Field
from core.workflow.enums import WorkflowExecutionStatus, WorkflowType
from libs.datetime_utils import naive_utc_now
class WorkflowType(StrEnum):
"""
Workflow Type Enum for domain layer
"""
WORKFLOW = "workflow"
CHAT = "chat"
class WorkflowExecutionStatus(StrEnum):
RUNNING = "running"
SUCCEEDED = "succeeded"
FAILED = "failed"
STOPPED = "stopped"
PARTIAL_SUCCEEDED = "partial-succeeded"
class WorkflowExecution(BaseModel):
"""
Domain model for workflow execution based on WorkflowRun but without

View File

@ -8,49 +8,11 @@ and don't contain implementation details like tenant_id, app_id, etc.
from collections.abc import Mapping
from datetime import datetime
from enum import StrEnum
from typing import Any
from pydantic import BaseModel, Field
from core.workflow.nodes.enums import NodeType
class WorkflowNodeExecutionMetadataKey(StrEnum):
"""
Node Run Metadata Key.
"""
TOTAL_TOKENS = "total_tokens"
TOTAL_PRICE = "total_price"
CURRENCY = "currency"
TOOL_INFO = "tool_info"
AGENT_LOG = "agent_log"
ITERATION_ID = "iteration_id"
ITERATION_INDEX = "iteration_index"
LOOP_ID = "loop_id"
LOOP_INDEX = "loop_index"
PARALLEL_ID = "parallel_id"
PARALLEL_START_NODE_ID = "parallel_start_node_id"
PARENT_PARALLEL_ID = "parent_parallel_id"
PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id"
PARALLEL_MODE_RUN_ID = "parallel_mode_run_id"
ITERATION_DURATION_MAP = "iteration_duration_map" # single iteration duration if iteration node runs
LOOP_DURATION_MAP = "loop_duration_map" # single loop duration if loop node runs
ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field
LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output
class WorkflowNodeExecutionStatus(StrEnum):
"""
Node Execution Status Enum.
"""
RUNNING = "running"
SUCCEEDED = "succeeded"
FAILED = "failed"
EXCEPTION = "exception"
RETRY = "retry"
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
class WorkflowNodeExecution(BaseModel):

View File

@ -1,4 +1,12 @@
from enum import StrEnum
from enum import Enum, StrEnum
class NodeState(Enum):
"""State of a node or edge during workflow execution."""
UNKNOWN = "unknown"
TAKEN = "taken"
SKIPPED = "skipped"
class SystemVariableKey(StrEnum):
@ -14,3 +22,104 @@ class SystemVariableKey(StrEnum):
APP_ID = "app_id"
WORKFLOW_ID = "workflow_id"
WORKFLOW_EXECUTION_ID = "workflow_run_id"
class NodeType(StrEnum):
START = "start"
END = "end"
ANSWER = "answer"
LLM = "llm"
KNOWLEDGE_RETRIEVAL = "knowledge-retrieval"
IF_ELSE = "if-else"
CODE = "code"
TEMPLATE_TRANSFORM = "template-transform"
QUESTION_CLASSIFIER = "question-classifier"
HTTP_REQUEST = "http-request"
TOOL = "tool"
VARIABLE_AGGREGATOR = "variable-aggregator"
LEGACY_VARIABLE_AGGREGATOR = "variable-assigner" # TODO: Merge this into VARIABLE_AGGREGATOR in the database.
LOOP = "loop"
LOOP_START = "loop-start"
LOOP_END = "loop-end"
ITERATION = "iteration"
ITERATION_START = "iteration-start" # Fake start node for iteration.
PARAMETER_EXTRACTOR = "parameter-extractor"
VARIABLE_ASSIGNER = "assigner"
DOCUMENT_EXTRACTOR = "document-extractor"
LIST_OPERATOR = "list-operator"
AGENT = "agent"
class NodeExecutionType(StrEnum):
"""Node execution type classification."""
EXECUTABLE = "executable" # Regular nodes that execute and produce outputs
RESPONSE = "response" # Response nodes that stream outputs (Answer, End)
BRANCH = "branch" # Nodes that can choose different branches (if-else, question-classifier)
CONTAINER = "container" # Container nodes that manage subgraphs (iteration, loop, graph)
ROOT = "root" # Nodes that can serve as execution entry points
class ErrorStrategy(StrEnum):
FAIL_BRANCH = "fail-branch"
DEFAULT_VALUE = "default-value"
class FailBranchSourceHandle(StrEnum):
FAILED = "fail-branch"
SUCCESS = "success-branch"
class WorkflowType(StrEnum):
"""
Workflow Type Enum for domain layer
"""
WORKFLOW = "workflow"
CHAT = "chat"
class WorkflowExecutionStatus(StrEnum):
RUNNING = "running"
SUCCEEDED = "succeeded"
FAILED = "failed"
STOPPED = "stopped"
PARTIAL_SUCCEEDED = "partial-succeeded"
class WorkflowNodeExecutionMetadataKey(StrEnum):
"""
Node Run Metadata Key.
"""
TOTAL_TOKENS = "total_tokens"
TOTAL_PRICE = "total_price"
CURRENCY = "currency"
TOOL_INFO = "tool_info"
AGENT_LOG = "agent_log"
ITERATION_ID = "iteration_id"
ITERATION_INDEX = "iteration_index"
LOOP_ID = "loop_id"
LOOP_INDEX = "loop_index"
PARALLEL_ID = "parallel_id"
PARALLEL_START_NODE_ID = "parallel_start_node_id"
PARENT_PARALLEL_ID = "parent_parallel_id"
PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id"
PARALLEL_MODE_RUN_ID = "parallel_mode_run_id"
ITERATION_DURATION_MAP = "iteration_duration_map" # duration of each iteration when an iteration node runs
LOOP_DURATION_MAP = "loop_duration_map" # duration of each round when a loop node runs
ERROR_STRATEGY = "error_strategy" # returned by nodes running in continue-on-error mode
LOOP_VARIABLE_MAP = "loop_variable_map" # variable outputs of a single loop round
class WorkflowNodeExecutionStatus(StrEnum):
PENDING = "pending" # Node is scheduled but not yet executing
RUNNING = "running"
SUCCEEDED = "succeeded"
FAILED = "failed"
EXCEPTION = "exception"
STOPPED = "stopped"
PAUSED = "paused"
# Legacy statuses - kept for backward compatibility
RETRY = "retry" # Legacy: replaced by retry mechanism in error handling

View File

@ -1,8 +1,8 @@
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.base.node import Node
class WorkflowNodeRunFailedError(Exception):
def __init__(self, node: BaseNode, err_msg: str):
def __init__(self, node: Node, err_msg: str):
self.node = node
self.error = err_msg
super().__init__(f"Node {node.title} run failed: {err_msg}")

View File

@ -0,0 +1,16 @@
from .edge import Edge
from .graph import Graph, NodeFactory
from .graph_runtime_state_protocol import ReadOnlyGraphRuntimeState, ReadOnlyVariablePool
from .graph_template import GraphTemplate
from .read_only_state_wrapper import ReadOnlyGraphRuntimeStateWrapper, ReadOnlyVariablePoolWrapper
__all__ = [
"Edge",
"Graph",
"GraphTemplate",
"NodeFactory",
"ReadOnlyGraphRuntimeState",
"ReadOnlyGraphRuntimeStateWrapper",
"ReadOnlyVariablePool",
"ReadOnlyVariablePoolWrapper",
]

View File

@ -0,0 +1,15 @@
import uuid
from dataclasses import dataclass, field
from core.workflow.enums import NodeState
@dataclass
class Edge:
"""Edge connecting two nodes in a workflow graph."""
id: str = field(default_factory=lambda: str(uuid.uuid4()))
tail: str = "" # tail node id (source)
head: str = "" # head node id (target)
source_handle: str = "source" # source handle for conditional branching
state: NodeState = field(default=NodeState.UNKNOWN) # edge execution state
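A construction sketch, runnable inside the repo (the `core.workflow.graph.edge` import path is an assumption based on this diff's layout):

```python
from core.workflow.enums import NodeState
from core.workflow.graph.edge import Edge  # path assumed

e = Edge(tail="start", head="llm")
assert e.state is NodeState.UNKNOWN  # edges begin unvisited
e.state = NodeState.TAKEN            # the engine marks an edge once it is traversed
```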

View File

@ -0,0 +1,346 @@
import logging
from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import Protocol, cast, final
from core.workflow.enums import NodeExecutionType, NodeState, NodeType
from core.workflow.nodes.base.node import Node
from libs.typing import is_str, is_str_dict
from .edge import Edge
logger = logging.getLogger(__name__)
class NodeFactory(Protocol):
"""
Protocol for creating Node instances from node data dictionaries.
This protocol decouples the Graph class from specific node mapping implementations,
allowing for different node creation strategies while maintaining type safety.
"""
def create_node(self, node_config: dict[str, object]) -> Node:
"""
Create a Node instance from node configuration data.
:param node_config: node configuration dictionary containing type and other data
:return: initialized Node instance
:raises ValueError: if node type is unknown or configuration is invalid
"""
...
@final
class Graph:
"""Graph representation with nodes and edges for workflow execution."""
def __init__(
self,
*,
nodes: dict[str, Node] | None = None,
edges: dict[str, Edge] | None = None,
in_edges: dict[str, list[str]] | None = None,
out_edges: dict[str, list[str]] | None = None,
root_node: Node,
):
"""
Initialize Graph instance.
:param nodes: graph nodes mapping (node id: node object)
:param edges: graph edges mapping (edge id: edge object)
:param in_edges: incoming edges mapping (node id: list of edge ids)
:param out_edges: outgoing edges mapping (node id: list of edge ids)
:param root_node: root node object
"""
self.nodes = nodes or {}
self.edges = edges or {}
self.in_edges = in_edges or {}
self.out_edges = out_edges or {}
self.root_node = root_node
@classmethod
def _parse_node_configs(cls, node_configs: list[dict[str, object]]) -> dict[str, dict[str, object]]:
"""
Parse node configurations and build a mapping of node IDs to configs.
:param node_configs: list of node configuration dictionaries
:return: mapping of node ID to node config
"""
node_configs_map: dict[str, dict[str, object]] = {}
for node_config in node_configs:
node_id = node_config.get("id")
if not node_id or not isinstance(node_id, str):
continue
node_configs_map[node_id] = node_config
return node_configs_map
@classmethod
def _find_root_node_id(
cls,
node_configs_map: Mapping[str, Mapping[str, object]],
edge_configs: Sequence[Mapping[str, object]],
root_node_id: str | None = None,
) -> str:
"""
Find the root node ID if not specified.
:param node_configs_map: mapping of node ID to node config
:param edge_configs: list of edge configurations
:param root_node_id: explicitly specified root node ID
:return: determined root node ID
"""
if root_node_id:
if root_node_id not in node_configs_map:
raise ValueError(f"Root node id {root_node_id} not found in the graph")
return root_node_id
# Find nodes with no incoming edges
nodes_with_incoming: set[str] = set()
for edge_config in edge_configs:
target = edge_config.get("target")
if isinstance(target, str):
nodes_with_incoming.add(target)
root_candidates = [nid for nid in node_configs_map if nid not in nodes_with_incoming]
# Prefer START node if available
start_node_id = None
for nid in root_candidates:
node_data = node_configs_map[nid].get("data")
if not is_str_dict(node_data):
continue
node_type = node_data.get("type")
if not isinstance(node_type, str):
continue
if node_type == NodeType.START:
start_node_id = nid
break
root_node_id = start_node_id or (root_candidates[0] if root_candidates else None)
if not root_node_id:
raise ValueError("Unable to determine root node ID")
return root_node_id
@classmethod
def _build_edges(
cls, edge_configs: list[dict[str, object]]
) -> tuple[dict[str, Edge], dict[str, list[str]], dict[str, list[str]]]:
"""
Build edge objects and mappings from edge configurations.
:param edge_configs: list of edge configurations
:return: tuple of (edges dict, in_edges dict, out_edges dict)
"""
edges: dict[str, Edge] = {}
in_edges: dict[str, list[str]] = defaultdict(list)
out_edges: dict[str, list[str]] = defaultdict(list)
edge_counter = 0
for edge_config in edge_configs:
source = edge_config.get("source")
target = edge_config.get("target")
if not is_str(source) or not is_str(target):
continue
# Create edge
edge_id = f"edge_{edge_counter}"
edge_counter += 1
source_handle = edge_config.get("sourceHandle", "source")
if not is_str(source_handle):
continue
edge = Edge(
id=edge_id,
tail=source,
head=target,
source_handle=source_handle,
)
edges[edge_id] = edge
out_edges[source].append(edge_id)
in_edges[target].append(edge_id)
return edges, dict(in_edges), dict(out_edges)
@classmethod
def _create_node_instances(
cls,
node_configs_map: dict[str, dict[str, object]],
node_factory: "NodeFactory",
) -> dict[str, Node]:
"""
Create node instances from configurations using the node factory.
:param node_configs_map: mapping of node ID to node config
:param node_factory: factory for creating node instances
:return: mapping of node ID to node instance
"""
nodes: dict[str, Node] = {}
for node_id, node_config in node_configs_map.items():
try:
node_instance = node_factory.create_node(node_config)
except Exception:
logger.exception("Failed to create node instance for node_id %s", node_id)
raise
nodes[node_id] = node_instance
return nodes
@classmethod
def _mark_inactive_root_branches(
cls,
nodes: dict[str, Node],
edges: dict[str, Edge],
in_edges: dict[str, list[str]],
out_edges: dict[str, list[str]],
active_root_id: str,
) -> None:
"""
Mark nodes and edges from inactive root branches as skipped.
Algorithm:
1. Mark inactive root nodes as skipped
2. For skipped nodes, mark all their outgoing edges as skipped
3. For each edge marked as skipped, check its target node:
- If ALL incoming edges are skipped, mark the node as skipped
- Otherwise, leave the node state unchanged
:param nodes: mapping of node ID to node instance
:param edges: mapping of edge ID to edge instance
:param in_edges: mapping of node ID to incoming edge IDs
:param out_edges: mapping of node ID to outgoing edge IDs
:param active_root_id: ID of the active root node
"""
# Find all top-level root nodes (nodes with the ROOT execution type)
top_level_roots: list[str] = [
node.id for node in nodes.values() if node.execution_type == NodeExecutionType.ROOT
]
# If there's only one root or the active root is not a top-level root, no marking needed
if len(top_level_roots) <= 1 or active_root_id not in top_level_roots:
return
# Mark inactive root nodes as skipped
inactive_roots: list[str] = [root_id for root_id in top_level_roots if root_id != active_root_id]
for root_id in inactive_roots:
if root_id in nodes:
nodes[root_id].state = NodeState.SKIPPED
# Recursively mark downstream nodes and edges
def mark_downstream(node_id: str) -> None:
"""Recursively mark downstream nodes and edges as skipped."""
if nodes[node_id].state != NodeState.SKIPPED:
return
# If this node is skipped, mark all its outgoing edges as skipped
out_edge_ids = out_edges.get(node_id, [])
for edge_id in out_edge_ids:
edge = edges[edge_id]
edge.state = NodeState.SKIPPED
# Check the target node of this edge
target_node = nodes[edge.head]
in_edge_ids = in_edges.get(target_node.id, [])
in_edge_states = [edges[eid].state for eid in in_edge_ids]
# If all incoming edges are skipped, mark the node as skipped
if all(state == NodeState.SKIPPED for state in in_edge_states):
target_node.state = NodeState.SKIPPED
# Recursively process downstream nodes
mark_downstream(target_node.id)
# Process each inactive root and its downstream nodes
for root_id in inactive_roots:
mark_downstream(root_id)
@classmethod
def init(
cls,
*,
graph_config: Mapping[str, object],
node_factory: "NodeFactory",
root_node_id: str | None = None,
) -> "Graph":
"""
Initialize graph
:param graph_config: graph config containing nodes and edges
:param node_factory: factory for creating node instances from config data
:param root_node_id: root node id
:return: graph instance
"""
# Parse configs
edge_configs = graph_config.get("edges", [])
node_configs = graph_config.get("nodes", [])
edge_configs = cast(list[dict[str, object]], edge_configs)
node_configs = cast(list[dict[str, object]], node_configs)
if not node_configs:
raise ValueError("Graph must have at least one node")
node_configs = [node_config for node_config in node_configs if node_config.get("type", "") != "custom-note"]
# Parse node configurations
node_configs_map = cls._parse_node_configs(node_configs)
# Find root node
root_node_id = cls._find_root_node_id(node_configs_map, edge_configs, root_node_id)
# Build edges
edges, in_edges, out_edges = cls._build_edges(edge_configs)
# Create node instances
nodes = cls._create_node_instances(node_configs_map, node_factory)
# Get root node instance
root_node = nodes[root_node_id]
# Mark inactive root branches as skipped
cls._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, root_node_id)
# Create and return the graph
return cls(
nodes=nodes,
edges=edges,
in_edges=in_edges,
out_edges=out_edges,
root_node=root_node,
)
@property
def node_ids(self) -> list[str]:
"""
Get list of node IDs (compatibility property for existing code)
:return: list of node IDs
"""
return list(self.nodes.keys())
def get_outgoing_edges(self, node_id: str) -> list[Edge]:
"""
Get all outgoing edges from a node (V2 method)
:param node_id: node id
:return: list of outgoing edges
"""
edge_ids = self.out_edges.get(node_id, [])
return [self.edges[eid] for eid in edge_ids if eid in self.edges]
def get_incoming_edges(self, node_id: str) -> list[Edge]:
"""
Get all incoming edges to a node (V2 method)
:param node_id: node id
:return: list of incoming edges
"""
edge_ids = self.in_edges.get(node_id, [])
return [self.edges[eid] for eid in edge_ids if eid in self.edges]
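A minimal sketch of driving `Graph.init` with a stub factory, runnable inside the repo. The `FakeNode` attributes (`id`, `title`, `state`, `execution_type`) mirror only what `Graph` itself touches, and the `core.workflow.graph` import path is an assumption:

```python
from core.workflow.enums import NodeExecutionType, NodeState
from core.workflow.graph import Graph  # path assumed

class FakeNode:
    def __init__(self, node_id: str) -> None:
        self.id = node_id
        self.title = node_id
        self.state = NodeState.UNKNOWN
        self.execution_type = (
            NodeExecutionType.ROOT if node_id == "start" else NodeExecutionType.EXECUTABLE
        )

class FakeFactory:
    def create_node(self, node_config: dict[str, object]) -> FakeNode:
        return FakeNode(str(node_config["id"]))

graph = Graph.init(
    graph_config={
        "nodes": [
            {"id": "start", "data": {"type": "start"}},
            {"id": "llm", "data": {"type": "llm"}},
        ],
        "edges": [{"source": "start", "target": "llm", "sourceHandle": "source"}],
    },
    node_factory=FakeFactory(),
)
assert graph.root_node.id == "start"
assert [e.head for e in graph.get_outgoing_edges("start")] == ["llm"]
```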

View File

@ -0,0 +1,61 @@
from collections.abc import Mapping
from typing import Any, Protocol
from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables.segments import Segment
class ReadOnlyVariablePool(Protocol):
"""Read-only interface for VariablePool."""
def get(self, node_id: str, variable_key: str) -> Segment | None:
"""Get a variable value (read-only)."""
...
def get_all_by_node(self, node_id: str) -> Mapping[str, object]:
"""Get all variables for a node (read-only)."""
...
class ReadOnlyGraphRuntimeState(Protocol):
"""
Read-only view of GraphRuntimeState for layers.
This protocol defines a read-only interface that prevents layers from
modifying the graph runtime state while still allowing observation.
All methods return defensive copies to ensure immutability.
"""
@property
def variable_pool(self) -> ReadOnlyVariablePool:
"""Get read-only access to the variable pool."""
...
@property
def start_at(self) -> float:
"""Get the start time (read-only)."""
...
@property
def total_tokens(self) -> int:
"""Get the total tokens count (read-only)."""
...
@property
def llm_usage(self) -> LLMUsage:
"""Get a copy of LLM usage info (read-only)."""
...
@property
def outputs(self) -> dict[str, Any]:
"""Get a defensive copy of outputs (read-only)."""
...
@property
def node_run_steps(self) -> int:
"""Get the node run steps count (read-only)."""
...
def get_output(self, key: str, default: Any = None) -> Any:
"""Get a single output value (returns a copy)."""
...
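The payoff of the protocol is structural typing: a layer can be written against the read-only view and accept either the real state or the wrapper defined later in this diff, with no mutation-capable surface in sight. A sketch:

```python
def describe(state: ReadOnlyGraphRuntimeState) -> str:
    """Works with anything satisfying the protocol; never mutates it."""
    return f"steps={state.node_run_steps} tokens={state.total_tokens}"
```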

View File

@ -0,0 +1,20 @@
from typing import Any
from pydantic import BaseModel, Field
class GraphTemplate(BaseModel):
"""
Graph Template for container nodes and subgraph expansion
According to GraphEngine V2 spec, GraphTemplate contains:
- nodes: mapping of node definitions
- edges: mapping of edge definitions
- root_ids: list of root node IDs
- output_selectors: list of output selectors for the template
"""
nodes: dict[str, dict[str, Any]] = Field(default_factory=dict, description="node definitions mapping")
edges: dict[str, dict[str, Any]] = Field(default_factory=dict, description="edge definitions mapping")
root_ids: list[str] = Field(default_factory=list, description="root node IDs")
output_selectors: list[str] = Field(default_factory=list, description="output selectors")
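A construction sketch, runnable inside the repo (the import path is an assumption); every field defaults to empty, so a template can be assembled incrementally:

```python
from core.workflow.graph.graph_template import GraphTemplate  # path assumed

template = GraphTemplate(root_ids=["start"])
template.nodes["start"] = {"data": {"type": "start"}}
assert template.edges == {} and template.output_selectors == []
```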

View File

@ -0,0 +1,77 @@
from collections.abc import Mapping
from copy import deepcopy
from typing import Any
from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables.segments import Segment
from core.workflow.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.entities.variable_pool import VariablePool
class ReadOnlyVariablePoolWrapper:
"""Wrapper that provides read-only access to VariablePool."""
def __init__(self, variable_pool: VariablePool):
self._variable_pool = variable_pool
def get(self, node_id: str, variable_key: str) -> Segment | None:
"""Get a variable value (returns a defensive copy)."""
value = self._variable_pool.get([node_id, variable_key])
return deepcopy(value) if value is not None else None
def get_all_by_node(self, node_id: str) -> Mapping[str, object]:
"""Get all variables for a node (returns defensive copies)."""
variables: dict[str, object] = {}
if node_id in self._variable_pool.variable_dictionary:
for key, var in self._variable_pool.variable_dictionary[node_id].items():
# Variables have a value property that contains the actual data
variables[key] = deepcopy(var.value)
return variables
class ReadOnlyGraphRuntimeStateWrapper:
"""
Wrapper that provides read-only access to GraphRuntimeState.
This wrapper ensures that layers can observe the state without
modifying it. All returned values are defensive copies.
"""
def __init__(self, state: GraphRuntimeState):
self._state = state
self._variable_pool_wrapper = ReadOnlyVariablePoolWrapper(state.variable_pool)
@property
def variable_pool(self) -> ReadOnlyVariablePoolWrapper:
"""Get read-only access to the variable pool."""
return self._variable_pool_wrapper
@property
def start_at(self) -> float:
"""Get the start time (read-only)."""
return self._state.start_at
@property
def total_tokens(self) -> int:
"""Get the total tokens count (read-only)."""
return self._state.total_tokens
@property
def llm_usage(self) -> LLMUsage:
"""Get a copy of LLM usage info (read-only)."""
# Return a copy to prevent modification
return self._state.llm_usage.model_copy()
@property
def outputs(self) -> dict[str, Any]:
"""Get a defensive copy of outputs (read-only)."""
return deepcopy(self._state.outputs)
@property
def node_run_steps(self) -> int:
"""Get the node run steps count (read-only)."""
return self._state.node_run_steps
def get_output(self, key: str, default: Any = None) -> Any:
"""Get a single output value (returns a copy)."""
return self._state.get_output(key, default)

View File

@ -1,4 +1,3 @@
from .entities import Graph, GraphInitParams, GraphRuntimeState, RuntimeRouteState
from .graph_engine import GraphEngine
__all__ = ["Graph", "GraphEngine", "GraphInitParams", "GraphRuntimeState", "RuntimeRouteState"]
__all__ = ["GraphEngine"]

View File

@ -0,0 +1,33 @@
# Command Channels
Channel implementations for external workflow control.
## Components
### InMemoryChannel
Thread-safe in-memory queue for single-process deployments.
- `fetch_commands()` - Get pending commands
- `send_command()` - Add command to queue
### RedisChannel
Redis-based queue for distributed deployments.
- `fetch_commands()` - Get commands with JSON deserialization
- `send_command()` - Store commands with TTL
## Usage
```python
from core.workflow.graph_engine.command_channels import InMemoryChannel, RedisChannel
from core.workflow.graph_engine.entities.commands import AbortCommand

# Local execution
channel = InMemoryChannel()
channel.send_command(AbortCommand(graph_id="workflow-123"))

# Distributed execution
redis_channel = RedisChannel(
    redis_client=redis_client,
    channel_key="workflow:123:commands",
)
```
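On the receiving side, both channels expose the same two-method surface, so the engine's polling loop is channel-agnostic (a sketch; `handle` stands in for whatever dispatch the caller wires up):

```python
for command in channel.fetch_commands():  # drains everything pending
    handle(command)  # illustrative dispatch
```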

View File

@ -0,0 +1,6 @@
"""Command channel implementations for GraphEngine."""
from .in_memory_channel import InMemoryChannel
from .redis_channel import RedisChannel
__all__ = ["InMemoryChannel", "RedisChannel"]

View File

@ -0,0 +1,53 @@
"""
In-memory implementation of CommandChannel for local/testing scenarios.
This implementation uses a thread-safe queue for command communication
within a single process. Each instance handles commands for one workflow execution.
"""
from queue import Empty, Queue
from typing import final
from ..entities.commands import GraphEngineCommand
@final
class InMemoryChannel:
"""
In-memory command channel implementation using a thread-safe queue.
Each instance is dedicated to a single GraphEngine/workflow execution.
Suitable for local development, testing, and single-instance deployments.
"""
def __init__(self) -> None:
"""Initialize the in-memory channel with a single queue."""
self._queue: Queue[GraphEngineCommand] = Queue()
def fetch_commands(self) -> list[GraphEngineCommand]:
"""
Fetch all pending commands from the queue.
Returns:
List of pending commands (drains the queue)
"""
commands: list[GraphEngineCommand] = []
# Drain all available commands from the queue
while not self._queue.empty():
try:
command = self._queue.get_nowait()
commands.append(command)
except Empty:
break
return commands
def send_command(self, command: GraphEngineCommand) -> None:
"""
Send a command to this channel's queue.
Args:
command: The command to send
"""
self._queue.put(command)
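A drain-semantics sketch, runnable inside the repo (`AbortCommand(graph_id=...)` follows the README example above; the field is otherwise an assumption):

```python
from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.graph_engine.entities.commands import AbortCommand

channel = InMemoryChannel()
channel.send_command(AbortCommand(graph_id="demo"))  # graph_id assumed from the README
assert len(channel.fetch_commands()) == 1  # first fetch drains the queue
assert channel.fetch_commands() == []      # nothing left on a second fetch
```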

View File

@ -0,0 +1,114 @@
"""
Redis-based implementation of CommandChannel for distributed scenarios.
This implementation uses Redis lists for command queuing, supporting
multi-instance deployments and cross-server communication.
Each instance uses a unique key for its command queue.
"""
import json
from typing import TYPE_CHECKING, Any, final
from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand
if TYPE_CHECKING:
from extensions.ext_redis import RedisClientWrapper
@final
class RedisChannel:
"""
Redis-based command channel implementation for distributed systems.
Each instance uses a unique Redis key for its command queue.
Commands are JSON-serialized for transport.
"""
def __init__(
self,
redis_client: "RedisClientWrapper",
channel_key: str,
command_ttl: int = 3600,
) -> None:
"""
Initialize the Redis channel.
Args:
redis_client: Redis client instance
channel_key: Unique key for this channel's command queue
command_ttl: TTL for command keys in seconds (default: 3600)
"""
self._redis = redis_client
self._key = channel_key
self._command_ttl = command_ttl
def fetch_commands(self) -> list[GraphEngineCommand]:
"""
Fetch all pending commands from Redis.
Returns:
List of pending commands (drains the Redis list)
"""
commands: list[GraphEngineCommand] = []
# Use pipeline for atomic operations
with self._redis.pipeline() as pipe:
# Get all commands and clear the list atomically
pipe.lrange(self._key, 0, -1)
pipe.delete(self._key)
results = pipe.execute()
# Parse commands from JSON
if results[0]:
for command_json in results[0]:
try:
command_data = json.loads(command_json)
command = self._deserialize_command(command_data)
if command:
commands.append(command)
except (json.JSONDecodeError, ValueError):
# Skip invalid commands
continue
return commands
def send_command(self, command: GraphEngineCommand) -> None:
"""
Send a command to Redis.
Args:
command: The command to send
"""
command_json = json.dumps(command.model_dump())
# Push to list and set expiry
with self._redis.pipeline() as pipe:
pipe.rpush(self._key, command_json)
pipe.expire(self._key, self._command_ttl)
pipe.execute()
def _deserialize_command(self, data: dict[str, Any]) -> GraphEngineCommand | None:
"""
Deserialize a command from dictionary data.
Args:
data: Command data dictionary
Returns:
Deserialized command or None if invalid
"""
command_type_value = data.get("command_type")
if not isinstance(command_type_value, str):
return None
try:
command_type = CommandType(command_type_value)
if command_type == CommandType.ABORT:
return AbortCommand(**data)
else:
# For other command types, use base class
return GraphEngineCommand(**data)
except (ValueError, TypeError):
return None

View File

@ -0,0 +1,14 @@
"""
Command processing subsystem for graph engine.
This package handles external commands sent to the engine
during execution.
"""
from .command_handlers import AbortCommandHandler
from .command_processor import CommandProcessor
__all__ = [
"AbortCommandHandler",
"CommandProcessor",
]

View File

@ -0,0 +1,32 @@
"""
Command handler implementations.
"""
import logging
from typing import final
from typing_extensions import override
from ..domain.graph_execution import GraphExecution
from ..entities.commands import AbortCommand, GraphEngineCommand
from .command_processor import CommandHandler
logger = logging.getLogger(__name__)
@final
class AbortCommandHandler(CommandHandler):
"""Handles abort commands."""
@override
def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None:
"""
Handle an abort command.
Args:
command: The abort command
execution: Graph execution to abort
"""
assert isinstance(command, AbortCommand)
logger.debug("Aborting workflow %s: %s", execution.workflow_id, command.reason)
execution.abort(command.reason or "User requested abort")

View File

@ -0,0 +1,79 @@
"""
Main command processor for handling external commands.
"""
import logging
from typing import Protocol, final
from ..domain.graph_execution import GraphExecution
from ..entities.commands import GraphEngineCommand
from ..protocols.command_channel import CommandChannel
logger = logging.getLogger(__name__)
class CommandHandler(Protocol):
"""Protocol for command handlers."""
def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None: ...
@final
class CommandProcessor:
"""
Processes external commands sent to the engine.
This polls the command channel and dispatches commands to
appropriate handlers.
"""
def __init__(
self,
command_channel: CommandChannel,
graph_execution: GraphExecution,
) -> None:
"""
Initialize the command processor.
Args:
command_channel: Channel for receiving commands
graph_execution: Graph execution aggregate
"""
self._command_channel = command_channel
self._graph_execution = graph_execution
self._handlers: dict[type[GraphEngineCommand], CommandHandler] = {}
def register_handler(self, command_type: type[GraphEngineCommand], handler: CommandHandler) -> None:
"""
Register a handler for a command type.
Args:
command_type: Type of command to handle
handler: Handler for the command
"""
self._handlers[command_type] = handler
def process_commands(self) -> None:
"""Check for and process any pending commands."""
try:
commands = self._command_channel.fetch_commands()
for command in commands:
self._handle_command(command)
except Exception as e:
logger.warning("Error processing commands: %s", e)
def _handle_command(self, command: GraphEngineCommand) -> None:
"""
Handle a single command.
Args:
command: The command to handle
"""
handler = self._handlers.get(type(command))
if handler:
try:
handler.handle(command, self._graph_execution)
except Exception:
logger.exception("Error handling command %s", command.__class__.__name__)
else:
logger.warning("No handler registered for command: %s", command.__class__.__name__)

View File

@ -1,25 +0,0 @@
from abc import ABC, abstractmethod
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.run_condition import RunCondition
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
class RunConditionHandler(ABC):
def __init__(self, init_params: GraphInitParams, graph: Graph, condition: RunCondition):
self.init_params = init_params
self.graph = graph
self.condition = condition
@abstractmethod
def check(self, graph_runtime_state: GraphRuntimeState, previous_route_node_state: RouteNodeState) -> bool:
"""
Check if the condition can be executed
:param graph_runtime_state: graph runtime state
:param previous_route_node_state: previous route node state
:return: bool
"""
raise NotImplementedError

Some files were not shown because too many files have changed in this diff.