Compare commits

..

130 Commits

Author SHA1 Message Date
dbceb3067e refactor(api): migrate console tag responses from marshal_with to BaseModel (#35208)
Co-authored-by: ai-hpc <ai-hpc@users.noreply.github.com>
2026-04-15 09:57:27 +00:00
425457cb16 test: remove legacy workflow draft variable api test (#35226) 2026-04-15 09:53:53 +00:00
e5bd18132c test: remove legacy jinja2 code executor integration test (#35222) 2026-04-15 09:53:46 +00:00
2f33867d07 test: remove legacy trigger provider permissions test (#35227) 2026-04-15 09:53:42 +00:00
fd71c56f16 test: remove legacy storage key loader integration test (#35225) 2026-04-15 09:52:31 +00:00
e3c2116501 fix: remove enable for get (#35245)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Joel <iamjoel007@gmail.com>
2026-04-15 09:18:29 +00:00
fb17339d89 feat(web): unify create_app tracking and persist external attribution (#35241)
Co-authored-by: CodingOnStar <hanxujiang@dify.com>
2026-04-15 08:59:31 +00:00
9fd196642d feat: tidb endpoint (#35158)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-15 08:46:25 +00:00
98897a5379 test: migrate webhook service additional mock tests to testcontainers (#35199)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-15 08:14:36 +00:00
5542329554 fix(dataset): fix dataset list overlay issue (#35244)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: yyh <yuanyouhuilyz@gmail.com>
2026-04-15 08:03:02 +00:00
79332c0e5e fix: Change 'commit' to 'flush' to prevent subsequent transaction fai… (#35243) 2026-04-15 07:50:52 +00:00
yyh
50a55513d4 refactor(ui): decouple CSS dependencies and improve test quality (#35242)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-15 07:03:40 +00:00
3bccdd6c9a test: migrate Service API site controller tests to Testcontainers (#32454) (#35183)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-04-15 05:51:34 +00:00
76af80e332 fix: open restore version panel raise 500 (#35240) 2026-04-15 04:43:38 +00:00
7a880ae60c fix: import DSL and copy app not work (#35239)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-15 04:43:22 +00:00
5bc0f9513b test: remove legacy python3 code executor integration test (#35223) 2026-04-15 02:20:06 +00:00
b77801ece9 test: remove legacy code executor integration test (#35224) 2026-04-15 02:19:42 +00:00
7de92c598f test: migrate schedule service mock tests to testcontainers (#35196)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-14 20:34:53 +00:00
693080aa12 test: migrate dataset service dataset mock tests to testcontainers (#35194)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-14 19:52:31 +00:00
25c388d0db refactor(api): migrate console workflow-trigger responses to BaseModel (#35200)
Co-authored-by: ai-hpc <ai-hpc@users.noreply.github.com>
2026-04-14 19:52:17 +00:00
b1722c8af9 refactor(api): migrate console conversation variables response model to BaseModel (#35193)
Co-authored-by: ai-hpc <ai-hpc@users.noreply.github.com>
2026-04-14 19:51:33 +00:00
b65a5fcd97 refactor(api): migrate service api workflow responses from marshal_with to BaseModel (#35195)
Co-authored-by: ai-hpc <ai-hpc@users.noreply.github.com>
2026-04-14 19:50:59 +00:00
1c3cba281a refactor(api): migrate console message responses from marshal_with to BaseModel (#35204)
Co-authored-by: ai-hpc <ai-hpc@users.noreply.github.com>
2026-04-14 19:49:33 +00:00
800954f8ce refactor(api): migrate service conversation-variable responses to BaseModel (#35205)
Co-authored-by: ai-hpc <ai-hpc@users.noreply.github.com>
2026-04-14 19:49:21 +00:00
f66a3c49c4 refactor(api): migrate console recommended-app response to BaseModel (#35206)
Co-authored-by: ai-hpc <ai-hpc@users.noreply.github.com>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
2026-04-14 19:48:43 +00:00
ef396ac84e refactor(api): migrate workspace current response from marshal_with to BaseModel (#35207)
Co-authored-by: ai-hpc <ai-hpc@users.noreply.github.com>
2026-04-14 19:48:09 +00:00
7e7b27fdec refactor: replace bare dict with dict[str, Any] in response converter… (#35212)
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-14 19:45:04 +00:00
9c90c1c455 refactor: replace bare dict with dict[str, Any] in services and hosti… (#35211)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-14 19:44:40 +00:00
b1df52b8ff refactor(api): migrate console workflow app-log responses to BaseModel (#35201)
Co-authored-by: ai-hpc <ai-hpc@users.noreply.github.com>
2026-04-14 18:43:09 +00:00
e527b7c5f1 refactor(api): migrate console installed-app list response to BaseModel (#35202)
Co-authored-by: ai-hpc <ai-hpc@users.noreply.github.com>
2026-04-14 18:40:40 +00:00
149b9d4c0f refactor: replace bare dict with dict[str, Any] in services unit test helpers (#35182) 2026-04-14 18:37:59 +00:00
ef28a63ad3 refactor(api): add null safety to extractor_processor and firecrawl (#35209)
Co-authored-by: tmimmanuel <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>
2026-04-14 18:23:20 +00:00
e78558bc06 refactor(api): migrate dataset hit-testing response model to BaseModel (#35192)
Co-authored-by: ai-hpc <ai-hpc@users.noreply.github.com>
2026-04-14 18:12:40 +00:00
f63d7c4121 test: remove document service status mock tests superseded by testcontainers (#35197) 2026-04-14 18:07:00 +00:00
ef062fb397 refactor(api): migrate console extension endpoint from api.model to BaseModel (#35189)
Co-authored-by: ai-hpc <ai-hpc@users.noreply.github.com>
2026-04-14 17:58:33 +00:00
a2ea7ca039 refactor(api): migrate workspace account marshal_with responses to BaseModel (#35190)
Co-authored-by: ai-hpc <ai-hpc@users.noreply.github.com>
2026-04-14 17:57:34 +00:00
6876cd787b test: migrate web site controller tests to Testcontainers (#32454) (#35180)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-14 17:52:13 +00:00
50a6892c3a refactor: replace bare dict with dict[str, Any] in controller and core unit tests (#35181) 2026-04-14 17:51:49 +00:00
1bcc7f78c7 refactor: replace bare dict with dict[str, Any] in models, providers, extensions, libs, and test utilities (#35186) 2026-04-14 17:51:25 +00:00
2fd5b76ac1 refactor: replace bare dict with dict[str, Any] in enterprise telemetry, external data tool, and moderation tests (#35185)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-14 17:51:02 +00:00
62f42b3f24 refactor: replace bare dict with dict[str, Any] in RAG and service unit tests (#35184)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-14 17:50:43 +00:00
2c58b424a1 refactor(web): migrate confirm dialogs to base/ui/alert-dialog (#35127)
Co-authored-by: CodingOnStar <hanxujiang@dify.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: yyh <yuanyouhuilyz@gmail.com>
2026-04-14 14:46:26 +00:00
381c518b23 test: migrate conversation read timestamp SQL test to Testcontainers (#32454) (#35177) 2026-04-14 14:19:19 +00:00
yyh
ebf741114d refactor(web): replace Button destructive boolean with tone semantic axis (#35176) 2026-04-14 14:16:39 +00:00
648dde5e96 ci: Fix path in coverage markdown rendering step (#35136) 2026-04-14 13:49:03 +00:00
a3042e6332 test: migrate clean_notion_document integration tests to SQLAlchemy 2… (#35147)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-14 13:31:42 +00:00
e5fd3133f4 test: migrate task integration tests to SQLAlchemy 2.0 query APIs (#35170)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-14 13:27:39 +00:00
yyh
e1bbe57f9c refactor(web): re-design button api (#35166)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-14 13:22:23 +00:00
d4783e8c14 chore: url in tool description support clicking jump directly (#35163) 2026-04-14 09:55:55 +00:00
736880e046 feat: support configurable redis key prefix (#35139) 2026-04-14 09:31:41 +00:00
bd7a9b5fcf refactor: replace bare dict with dict[str, Any] in model provider service and core modules (#35122)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
2026-04-14 09:18:30 +00:00
9a47bb2f80 fix: doc modal hidden by config modal (#35157) 2026-04-14 08:16:19 +00:00
d7ad2baf79 chore: clarify tracing error copy to direct users to the Tracing tab (#35153) 2026-04-14 08:15:07 +00:00
a951cc996b test: migrate document indexing task tests to SQLAlchemy 2.0 select API (#35145)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-14 07:56:11 +00:00
173e0d6f35 test: migrate clean_dataset integration tests to SQLAlchemy 2.0 APIs (#35146)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-14 07:56:07 +00:00
62bb830338 refactor: convert InvokeFrom if/elif to match/case (#35143) 2026-04-14 07:46:58 +00:00
f7c6270f74 refactor: use sessionmaker in tool_label_manager.py (#34895)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-14 07:23:29 +00:00
711fe6ba2c refactor: convert plugin permission if/elif to match/case (#30001) (#35140) 2026-04-14 07:03:53 +00:00
fbedb60371 refactor: replace bare dict with typed annotations in core rag module (#35097) 2026-04-14 06:16:16 +00:00
974d2f1627 refactor: replace bare dict with typed annotations in llm_generator and prompt (#35100) 2026-04-14 06:15:52 +00:00
ed401728eb refactor: replace bare dict with typed annotations in app_config/extension/provider (#35099) 2026-04-14 06:11:00 +00:00
fc389a54c5 refactor: replace bare dict with typed annotations in core tools module (#35098) 2026-04-14 06:09:55 +00:00
c8b372dba0 chore(deps): update pyarrow to version 23.0.1 and add override deps (#35137) 2026-04-14 06:02:43 +00:00
2333d75c56 chore: allow disabling app-level PostgreSQL timezone injection (#35129) 2026-04-14 05:57:27 +00:00
2ef9a8a769 chore(i18n): sync translations with en-US (#35134)
Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com>
Co-authored-by: yyh <92089059+lyzno1@users.noreply.github.com>
2026-04-14 04:23:37 +00:00
yyh
21ab9b9d8c refactor(web): remove highPriority modal stacking (#35132) 2026-04-14 04:22:25 +00:00
yyh
79c1473378 refactor(web): align tooltip content class props (#35135) 2026-04-14 04:21:55 +00:00
93b8a74351 chore(deps): bump pillow from 12.1.1 to 12.2.0 in /api (#35119)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-04-14 03:56:14 +00:00
99
28185170b0 test: split merged API test modules and remove F811 ignore (#35105) 2026-04-14 03:54:30 +00:00
99
178883b4cc chore: remove unused Ruff ignore rules (#35102) 2026-04-14 03:53:20 +00:00
e9f9041b25 chore: Add global fetch mock in vitest.setup.ts to suppress happy-dom ECONNREFUSED errors (#35131)
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2026-04-14 03:47:01 +00:00
175290fa04 feat(goto-anything): recent items, /go navigation command, deep app sub-sections (#35078)
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-14 03:45:58 +00:00
b0c4d8c541 fix: Compatibility issues with the summary index feature when using the weaviate vector database (#35052)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
2026-04-14 03:44:49 +00:00
0f643bca76 refactor: replace bare dict with dict[str, Any] in core tools and runtime (#35111) 2026-04-14 03:03:13 +00:00
eeebedcfe8 refactor: replace bare dict with dict[str, Any] in core provider services and misc modules (#35124) 2026-04-14 03:03:08 +00:00
2f682780fa refactor: replace bare dict with dict[str, Any] in rag_pipeline and datasource_provider services (#35107) 2026-04-14 03:02:41 +00:00
ed83f5369e refactor: replace bare dict with dict[str, Any] in entities, workflow nodes, and tasks (#35109) 2026-04-14 03:02:39 +00:00
4ee1bd5f32 refactor: replace bare dict with dict[str, Any] in VDB providers (#35110) 2026-04-14 03:02:36 +00:00
1c2bbed405 refactor: replace bare dict with dict[str, Any] in services grab-bag (#35112) 2026-04-14 03:02:34 +00:00
d573fc0e65 refactor: replace bare dict with dict[str, Any] in VDB providers and libs (#35123)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-14 03:02:29 +00:00
f8b249e649 fix(web): handle IME composition in DelimiterInput (#34660)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-14 02:49:37 +00:00
yyh
fbcab757d5 test(e2e): improve auth coverage and authoring support (#34920) 2026-04-14 02:22:34 +00:00
c0e998ef6e chore: update deps (#35066)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: yyh <yuanyouhuilyz@gmail.com>
2026-04-14 02:19:29 +00:00
84f25807db test: migrate mail_human_input_delivery cleanup fixture to SQLAlchemy 2.0 delete API (#35090) 2026-04-13 19:31:11 +00:00
83b242be7b refactor: replace bare dict with typed annotations in core plugin module (#35096) 2026-04-13 19:23:21 +00:00
a12d740a5d test: migrate mail and segment indexing integration tests to SQLAlchemy 2.0 APIs (#35091) 2026-04-13 19:22:34 +00:00
3bbb014dc7 test: migrate remove_app_and_related_data integration tests to SQLAlchemy 2.0 APIs (#35092)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-13 19:22:07 +00:00
f040733e28 refactor(api): type _jsonify_form_definition payload with FormDefinitionPayload TypedDict (#35094)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-13 19:21:19 +00:00
b0bf7ca486 refactor: replace bare dict with typed annotations in controllers (#35095) 2026-04-13 19:19:52 +00:00
14d83c8bac test: migrate trigger integration tests to SQLAlchemy 2.0 select API (#35081) 2026-04-13 17:19:34 +00:00
8b506dfa42 refactor: replace bare dict with dict[str, Any] in openai_moderation (#35079) 2026-04-13 17:19:04 +00:00
ac2258c2dc refactor: replace bare dict with dict[str, Any] in app_config managers (#35087) 2026-04-13 17:14:39 +00:00
3c279edcf2 refactor: replace bare dict with dict[str, Any] in app task_entities … (#35084) 2026-04-13 17:14:23 +00:00
9ed8a5ed73 refactor: replace bare dict with dict[str, Any] in model_manager and … (#35083) 2026-04-13 17:14:08 +00:00
3d4ddf4a6f refactor: replace bare dict with dict[str, Any] in ops trace providers (#35082) 2026-04-13 17:13:46 +00:00
4e0273bb28 refactor: replace bare dict with dict[str, Any] in provider entities and plugin client (#35077) 2026-04-13 17:09:25 +00:00
7056d2ae99 refactor: replace bare dict with dict[str, Any] in moderation module (#35076) 2026-04-13 17:09:06 +00:00
d8fbc00cb9 refactor: replace bare dict with dict[str, Any] in dataset and external_knowledge services (#35073) 2026-04-13 17:08:45 +00:00
57c5f0ec87 refactor: replace bare dict with dict[str, Any] in tools manage services (#35075) 2026-04-13 17:08:31 +00:00
e5bd80c719 refactor: replace bare dict with dict[str, Any] in website_service (#35074) 2026-04-13 17:07:59 +00:00
Sam
25a33a454c fix: handle URL construction error when switching to Visual Editor (#35004)
Co-authored-by: Sami Rusani <sr@samirusani>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-13 15:11:16 +00:00
yyh
bd30784b1d chore(web): upgrade @base-ui/react to v1.4.0 (#35048) 2026-04-13 14:48:29 +00:00
28fce0a890 fix: click empty http node value may cause blur (#35051) 2026-04-13 14:48:18 +00:00
e1eb582bea refactor: replace bare dict with dict[str, Any] in pipeline_template module (#35071) 2026-04-13 14:05:59 +00:00
2042ee453b refactor: replace bare dict with dict[str, Any] in helper cache modules (#35067) 2026-04-13 14:05:50 +00:00
33c4e512f1 refactor: replace bare dict with dict[str, Any] in tools message_transformer (#35069) 2026-04-13 14:05:39 +00:00
253e8a3f98 refactor: replace bare dict with dict[str, Any] in ops_trace_manager (#35070) 2026-04-13 14:05:29 +00:00
06b63d65d1 refactor: replace bare dict with dict[str, Any] in rag extractors (#35068) 2026-04-13 14:05:21 +00:00
08f3133414 fix: db session expired issue (#35049) 2026-04-13 13:06:13 +00:00
d412cddf39 refactor: replace bare dict with UtmInfo TypedDict in operation_service (#35055) 2026-04-13 13:05:47 +00:00
671c5cdd84 refactor: replace bare dict with WorkflowRunListArgs TypedDict (#35057) 2026-04-13 13:05:39 +00:00
554f060092 refactor: replace bare dict with AdvancedPromptTemplateArgs TypedDict (#35056) 2026-04-13 13:05:23 +00:00
e243e8d8a3 refactor: replace bare dict with dict[str, Any] in datasource_entities (#35062) 2026-04-13 13:01:50 +00:00
1b935a367f refactor: replace bare dict with dict[str, Any] in watercrawl client (#35063) 2026-04-13 13:01:32 +00:00
2edd083a71 refactor: replace bare dict with dict[str, Any] in OpenAPI tools parser (#35061) 2026-04-13 13:01:21 +00:00
dd50a68bf2 refactor: replace bare dict with dict[str, Any] in ops_service tracin… (#35064) 2026-04-13 13:01:00 +00:00
e8dd3461e8 refactor: replace bare dict with dict[str, Any] in plugin endpoint_service (#35065) 2026-04-13 13:00:27 +00:00
8dd4473432 refactor(auth): standardize failed login audit logging (#35054) 2026-04-13 12:26:13 +00:00
b5bbbdd840 chore: revert react-i18next update (#35058) 2026-04-13 11:56:34 +00:00
f0266e13c5 refactor: improve type annotations in HitTestingService (#27838)
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
2026-04-13 10:31:31 +00:00
ae898652b2 refactor: move vdb implementations to workspaces (#34900)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: wangxiaolei <fatelei@gmail.com>
2026-04-13 08:56:43 +00:00
c34f67495c refactor(api): type WorkflowRun.to_dict with WorkflowRunDict TypedDict (#35047)
Co-authored-by: Ke Wang <ke@pika.art>
2026-04-13 08:30:28 +00:00
815c536e05 fix: optimize trigger long running read transactions (#35046) 2026-04-13 08:22:54 +00:00
fc64427ae1 fix: fix qdrant delete size is too large (#35042)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-13 07:59:06 +00:00
11c518478e test: migrate AudioService TTS message-ID lookup tests to Testcontainers integration tests (#34992)
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
2026-04-13 06:26:43 +00:00
e823635ce1 test: migrate app_dsl_service tests to testcontainers (#34429)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-13 06:25:28 +00:00
98e74c8fde refactor: migrate MessageAnnotation to TypeBase (#34807) 2026-04-13 06:22:43 +00:00
29bfa33d59 feat: support ttft report to langfuse (#33344)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-04-13 06:21:58 +00:00
3ead0beeb1 fix: correct typo submision to submission (#33435)
Co-authored-by: LincolnBurrows2017 <lincoln@example.com>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
2026-04-13 05:59:21 +00:00
2108c44c8b refactor(api): consolidate duplicate RerankingModelConfig and WeightedScoreConfig definitions (#34747)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-13 05:53:45 +00:00
1278 changed files with 19905 additions and 25509 deletions

View File

@ -0,0 +1,79 @@
---
name: e2e-cucumber-playwright
description: Write, update, or review Dify end-to-end tests under `e2e/` that use Cucumber, Gherkin, and Playwright. Use when the task involves `.feature` files, `features/step-definitions/`, `features/support/`, `DifyWorld`, scenario tags, locator/assertion choices, or E2E testing best practices for this repository.
---
# Dify E2E Cucumber + Playwright
Use this skill for Dify's repository-level E2E suite in `e2e/`. Use [`e2e/AGENTS.md`](../../../e2e/AGENTS.md) as the canonical guide for local architecture and conventions, then apply Playwright/Cucumber best practices only where they fit the current suite.
## Scope
- Use this skill for `.feature` files, Cucumber step definitions, `DifyWorld`, hooks, tags, and E2E review work under `e2e/`.
- Do not use this skill for Vitest or React Testing Library work under `web/`; use `frontend-testing` instead.
- Do not use this skill for backend test or API review tasks under `api/`.
## Read Order
1. Read [`e2e/AGENTS.md`](../../../e2e/AGENTS.md) first.
2. Read only the files directly involved in the task:
- target `.feature` files under `e2e/features/`
- related step files under `e2e/features/step-definitions/`
- `e2e/features/support/hooks.ts` and `e2e/features/support/world.ts` when session lifecycle or shared state matters
- `e2e/scripts/run-cucumber.ts` and `e2e/cucumber.config.ts` when tags or execution flow matter
3. Read [`references/playwright-best-practices.md`](references/playwright-best-practices.md) only when locator, assertion, isolation, or waiting choices are involved.
4. Read [`references/cucumber-best-practices.md`](references/cucumber-best-practices.md) only when scenario wording, step granularity, tags, or expression design are involved.
5. Re-check official docs with Context7 before introducing a new Playwright or Cucumber pattern.
## Local Rules
- `e2e/` uses Cucumber for scenarios and Playwright as the browser layer.
- `DifyWorld` is the per-scenario context object. Type `this` as `DifyWorld` and use `async function`, not arrow functions.
- Keep glue organized by capability under `e2e/features/step-definitions/`; use `common/` only for broadly reusable steps.
- Browser session behavior comes from `features/support/hooks.ts`:
- default: authenticated session with shared storage state
- `@unauthenticated`: clean browser context
- `@authenticated`: readability/selective-run tag only unless implementation changes
- `@fresh`: only for `e2e:full*` flows
- Do not import Playwright Test runner patterns that bypass the current Cucumber + `DifyWorld` architecture unless the task is explicitly about changing that architecture.
## Workflow
1. Rebuild local context.
- Inspect the target feature area.
- Reuse an existing step when wording and behavior already match.
- Add a new step only for a genuinely new user action or assertion.
- Keep edits close to the current capability folder unless the step is broadly reusable.
2. Write behavior-first scenarios.
- Describe user-observable behavior, not DOM mechanics.
- Keep each scenario focused on one workflow or outcome.
- Keep scenarios independent and re-runnable.
3. Write step definitions in the local style.
- Keep one step to one user-visible action or one assertion.
- Prefer Cucumber Expressions such as `{string}` and `{int}`.
- Scope locators to stable containers when the page has repeated elements.
- Avoid page-object layers or extra helper abstractions unless repeated complexity clearly justifies them.
4. Use Playwright in the local style.
- Prefer user-facing locators: `getByRole`, `getByLabel`, `getByPlaceholder`, `getByText`, then `getByTestId` for explicit contracts.
- Use web-first `expect(...)` assertions.
- Do not use `waitForTimeout`, manual polling, or raw visibility checks when a locator action or retrying assertion already expresses the behavior.
5. Validate narrowly.
- Run the narrowest tagged scenario or flow that exercises the change.
- Run `pnpm -C e2e check`.
- Broaden verification only when the change affects hooks, tags, setup, or shared step semantics.
## Review Checklist
- Does the scenario describe behavior rather than implementation?
- Does it fit the current session model, tags, and `DifyWorld` usage?
- Should an existing step be reused instead of adding a new one?
- Are locators user-facing and assertions web-first?
- Does the change introduce hidden coupling across scenarios, tags, or instance state?
- Does it document or implement behavior that differs from the real hooks or configuration?
Lead findings with correctness, flake risk, and architecture drift.
## References
- [`references/playwright-best-practices.md`](references/playwright-best-practices.md)
- [`references/cucumber-best-practices.md`](references/cucumber-best-practices.md)

View File

@ -0,0 +1,4 @@
interface:
display_name: "E2E Cucumber + Playwright"
short_description: "Write and review Dify E2E scenarios."
default_prompt: "Use $e2e-cucumber-playwright to write or review a Dify E2E scenario under e2e/."

View File

@ -0,0 +1,93 @@
# Cucumber Best Practices For Dify E2E
Use this reference when writing or reviewing Gherkin scenarios, step definitions, parameter expressions, and step reuse in Dify's `e2e/` suite.
Official sources:
- https://cucumber.io/docs/guides/10-minute-tutorial/
- https://cucumber.io/docs/cucumber/step-definitions/
- https://cucumber.io/docs/cucumber/cucumber-expressions/
## What Matters Most
### 1. Treat scenarios as executable specifications
Cucumber scenarios should describe examples of behavior, not test implementation recipes.
Apply it like this:
- write what the user does and what should happen
- avoid UI-internal wording such as selector details, DOM structure, or component names
- keep language concrete enough that the scenario reads like living documentation
### 2. Keep scenarios focused
A scenario should usually prove one workflow or business outcome. If a scenario wanders across several unrelated behaviors, split it.
In Dify's suite, this means:
- one capability-focused scenario per feature path
- no long setup chains when existing bootstrap or reusable steps already cover them
- no hidden dependency on another scenario's side effects
### 3. Reuse steps, but only when behavior really matches
Good reuse reduces duplication. Bad reuse hides meaning.
Prefer reuse when:
- the user action is genuinely the same
- the expected outcome is genuinely the same
- the wording stays natural across features
Write a new step when:
- the behavior is materially different
- reusing the old wording would make the scenario misleading
- a supposedly generic step would become an implementation-detail wrapper
### 4. Prefer Cucumber Expressions
Use Cucumber Expressions for parameters unless regex is clearly necessary.
Common examples:
- `{string}` for labels, names, and visible text
- `{int}` for counts
- `{float}` for decimal values
- `{word}` only when the value is truly a single token
Keep expressions readable. If a step needs complicated parsing logic, first ask whether the scenario wording should be simpler.
### 5. Keep step definitions thin and meaningful
Step definitions are glue between Gherkin and automation, not a second abstraction language.
For Dify:
- type `this` as `DifyWorld`
- use `async function`
- keep each step to one user-visible action or assertion
- rely on `DifyWorld` and existing support code for shared context
- avoid leaking cross-scenario state
### 6. Use tags intentionally
Tags should communicate run scope or session semantics, not become ad hoc metadata.
In Dify's current suite:
- capability tags group related scenarios
- `@unauthenticated` changes session behavior
- `@authenticated` is descriptive/selective, not a behavior switch by itself
- `@fresh` belongs to reset/full-install flows only
If a proposed tag implies behavior, verify that hooks or runner configuration actually implement it.
## Review Questions
- Does the scenario read like a real example of product behavior?
- Are the steps behavior-oriented instead of implementation-oriented?
- Is a reused step still truthful in this feature?
- Is a new tag documenting real behavior, or inventing semantics that the suite does not implement?
- Would a new reader understand the outcome without opening the step-definition file?

View File

@ -0,0 +1,96 @@
# Playwright Best Practices For Dify E2E
Use this reference when writing or reviewing locator, assertion, isolation, or synchronization logic for Dify's Cucumber-based E2E suite.
Official sources:
- https://playwright.dev/docs/best-practices
- https://playwright.dev/docs/locators
- https://playwright.dev/docs/test-assertions
- https://playwright.dev/docs/browser-contexts
## What Matters Most
### 1. Keep scenarios isolated
Playwright's model is built around clean browser contexts so one test does not leak into another. In Dify's suite, that principle maps to per-scenario session setup in `features/support/hooks.ts` and `DifyWorld`.
Apply it like this:
- do not depend on another scenario having run first
- do not persist ad hoc scenario state outside `DifyWorld`
- do not couple ordinary scenarios to `@fresh` behavior
- when a flow needs special auth/session semantics, express that through the existing tag model or explicit hook changes
### 2. Prefer user-facing locators
Playwright recommends built-in locators that reflect what users perceive on the page.
Preferred order in this repository:
1. `getByRole`
2. `getByLabel`
3. `getByPlaceholder`
4. `getByText`
5. `getByTestId` when an explicit test contract is the most stable option
Avoid raw CSS/XPath selectors unless no stable user-facing contract exists and adding one is not practical.
Also remember:
- repeated content usually needs scoping to a stable container
- exact text matching is often too brittle when role/name or label already exists
- `getByTestId` is acceptable when semantics are weak but the contract is intentional
### 3. Use web-first assertions
Playwright assertions auto-wait and retry. Prefer them over manual state inspection.
Prefer:
- `await expect(page).toHaveURL(...)`
- `await expect(locator).toBeVisible()`
- `await expect(locator).toBeHidden()`
- `await expect(locator).toBeEnabled()`
- `await expect(locator).toHaveText(...)`
Avoid:
- `expect(await locator.isVisible()).toBe(true)`
- custom polling loops for DOM state
- `waitForTimeout` as synchronization
If a condition genuinely needs custom retry logic, use Playwright's polling/assertion tools deliberately and keep that choice local and explicit.
### 4. Let actions wait for actionability
Locator actions already wait for the element to be actionable. Do not preface every click/fill with extra timing logic unless the action needs a specific visible/ready assertion for clarity.
Good pattern:
- assert a meaningful visible state when that is part of the behavior
- then click/fill/select via locator APIs
Bad pattern:
- stack arbitrary waits before every action
- wait on unstable implementation details instead of the visible state the user cares about
### 5. Match debugging to the current suite
Playwright's wider ecosystem supports traces and rich debugging tools. Dify's current suite already captures:
- full-page screenshots
- page HTML
- console errors
- page errors
Use the existing artifact flow by default. If a task is specifically about improving diagnostics, confirm the change fits the current Cucumber architecture before importing broader Playwright tooling.
## Review Questions
- Would this locator survive DOM refactors that do not change user-visible behavior?
- Is this assertion using Playwright's retrying semantics?
- Is any explicit wait masking a real readiness problem?
- Does this code preserve per-scenario isolation?
- Is a new abstraction really needed, or does it bypass the existing `DifyWorld` + step-definition model?

View File

@ -0,0 +1 @@
../../.agents/skills/e2e-cucumber-playwright

View File

@ -6,14 +6,7 @@ on:
- "main"
paths:
- api/Dockerfile
- web/docker/**
- web/Dockerfile
- packages/**
- package.json
- pnpm-lock.yaml
- pnpm-workspace.yaml
- .npmrc
- .nvmrc
concurrency:
group: docker-build-${{ github.head_ref || github.run_id }}

View File

@ -92,6 +92,7 @@ jobs:
vdb:
- 'api/core/rag/datasource/**'
- 'api/tests/integration_tests/vdb/**'
- 'api/providers/vdb/*/tests/**'
- '.github/workflows/vdb-tests.yml'
- '.github/workflows/expose_service_ports.sh'
- 'docker/.env.example'

View File

@ -62,7 +62,7 @@ jobs:
- name: Render coverage markdown from structured data
id: render
run: |
comment_body="$(uv run --directory api python api/libs/pyrefly_type_coverage.py \
comment_body="$(uv run --directory api python libs/pyrefly_type_coverage.py \
--base base_report.json \
< pr_report.json)"

View File

@ -89,7 +89,7 @@ jobs:
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
# - name: Check VDB Ready (TiDB)
# run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
# run: uv run --project api python api/providers/vdb/tidb-vector/tests/integration_tests/check_tiflash_ready.py
- name: Test Vector Stores
run: uv run --project api bash dev/pytest/pytest_vdb.sh

View File

@ -81,12 +81,12 @@ jobs:
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
# - name: Check VDB Ready (TiDB)
# run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
# run: uv run --project api python api/providers/vdb/tidb-vector/tests/integration_tests/check_tiflash_ready.py
- name: Test Vector Stores
run: |
uv run --project api pytest --timeout "${PYTEST_TIMEOUT:-180}" \
api/tests/integration_tests/vdb/chroma \
api/tests/integration_tests/vdb/pgvector \
api/tests/integration_tests/vdb/qdrant \
api/tests/integration_tests/vdb/weaviate
api/providers/vdb/vdb-chroma/tests/integration_tests \
api/providers/vdb/vdb-pgvector/tests/integration_tests \
api/providers/vdb/vdb-qdrant/tests/integration_tests \
api/providers/vdb/vdb-weaviate/tests/integration_tests

View File

@ -57,6 +57,9 @@ REDIS_SSL_CERTFILE=
REDIS_SSL_KEYFILE=
# Path to client private key file for SSL authentication
REDIS_DB=0
# Optional global prefix for Redis keys, topics, streams, and Celery Redis transport artifacts.
# Leave empty to preserve current unprefixed behavior.
REDIS_KEY_PREFIX=
# redis Sentinel configuration.
REDIS_USE_SENTINEL=false

View File

@ -69,8 +69,6 @@ ignore = [
"FURB152", # math-constant
"UP007", # non-pep604-annotation
"UP032", # f-string
"UP045", # non-pep604-annotation-optional
"B005", # strip-with-multi-characters
"B006", # mutable-argument-default
"B007", # unused-loop-control-variable
"B026", # star-arg-unpacking-after-keyword-arg
@ -84,7 +82,6 @@ ignore = [
"SIM102", # collapsible-if
"SIM103", # needless-bool
"SIM105", # suppressible-exception
"SIM107", # return-in-try-except-finally
"SIM108", # if-else-block-instead-of-if-exp
"SIM113", # enumerate-for-loop
"SIM117", # multiple-with-statements
@ -93,29 +90,16 @@ ignore = [
]
[lint.per-file-ignores]
"__init__.py" = [
"F401", # unused-import
"F811", # redefined-while-unused
]
"configs/*" = [
"N802", # invalid-function-name
]
"graphon/model_runtime/callbacks/base_callback.py" = ["T201"]
"core/workflow/callbacks/workflow_logging_callback.py" = ["T201"]
"libs/gmpy2_pkcs10aep_cipher.py" = [
"N803", # invalid-argument-name
]
"tests/*" = [
"F811", # redefined-while-unused
"T201", # allow print in tests,
"S110", # allow ignoring exceptions in tests code (currently)
]
"controllers/console/explore/trial.py" = ["TID251"]
"controllers/console/human_input_form.py" = ["TID251"]
"controllers/web/human_input_form.py" = ["TID251"]
[lint.flake8-tidy-imports]
[lint.flake8-tidy-imports.banned-api."flask_restx.reqparse"]
msg = "Use Pydantic payload/query models instead of reqparse."

View File

@ -21,8 +21,9 @@ RUN apt-get update \
# for building gmpy2
libmpfr-dev libmpc-dev
# Install Python dependencies
# Install Python dependencies (workspace members under providers/vdb/)
COPY pyproject.toml uv.lock ./
COPY providers ./providers
RUN uv sync --locked --no-dev
# production stage

View File

@ -341,11 +341,10 @@ def add_qdrant_index(field: str):
click.echo(click.style("No dataset collection bindings found.", fg="red"))
return
import qdrant_client
from dify_vdb_qdrant.qdrant_vector import PathQdrantParams, QdrantConfig
from qdrant_client.http.exceptions import UnexpectedResponse
from qdrant_client.http.models import PayloadSchemaType
from core.rag.datasource.vdb.qdrant.qdrant_vector import PathQdrantParams, QdrantConfig
for binding in bindings:
if dify_config.QDRANT_URL is None:
raise ValueError("Qdrant URL is required.")

View File

@ -287,27 +287,6 @@ class MarketplaceConfig(BaseSettings):
)
class CreatorsPlatformConfig(BaseSettings):
"""
Configuration for creators platform
"""
CREATORS_PLATFORM_FEATURES_ENABLED: bool = Field(
description="Enable or disable creators platform features",
default=True,
)
CREATORS_PLATFORM_API_URL: HttpUrl = Field(
description="Creators Platform API URL",
default=HttpUrl("https://creators.dify.ai"),
)
CREATORS_PLATFORM_OAUTH_CLIENT_ID: str = Field(
description="OAuth client_id for the Creators Platform app registered in Dify",
default="",
)
class EndpointConfig(BaseSettings):
"""
Configuration for various application endpoints and URLs
@ -362,15 +341,6 @@ class FileAccessConfig(BaseSettings):
default="",
)
FILES_API_URL: str = Field(
description="Base URL for storage file ticket API endpoints."
" Used by sandbox containers (internal or external like e2b) that need"
" an absolute, routable address to upload/download files via the API."
" For all-in-one Docker deployments, set to http://localhost."
" For public sandbox environments, set to a public domain or IP.",
default="",
)
FILES_ACCESS_TIMEOUT: int = Field(
description="Expiration time in seconds for file access URLs",
default=300,
@ -1304,52 +1274,6 @@ class PositionConfig(BaseSettings):
return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(",") if item.strip() != ""}
class CollaborationConfig(BaseSettings):
ENABLE_COLLABORATION_MODE: bool = Field(
description="Whether to enable collaboration mode features across the workspace",
default=False,
)
class SandboxExpiredRecordsCleanConfig(BaseSettings):
SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD: NonNegativeInt = Field(
description="Graceful period in days for sandbox records clean after subscription expiration",
default=21,
)
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE: PositiveInt = Field(
description="Maximum number of records to process in each batch",
default=1000,
)
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL: PositiveInt = Field(
description="Maximum interval in milliseconds between batches",
default=200,
)
SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS: PositiveInt = Field(
description="Retention days for sandbox expired workflow_run records and message records",
default=30,
)
SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL: PositiveInt = Field(
description="Lock TTL for sandbox expired records clean task in seconds",
default=90000,
)
class AgentV2UpgradeConfig(BaseSettings):
"""Feature flags for transparent Agent V2 upgrade."""
AGENT_V2_TRANSPARENT_UPGRADE: bool = Field(
description="Transparently run old apps (chat/completion/agent-chat) through the Agent V2 workflow engine. "
"When enabled, old apps synthesize a virtual workflow at runtime instead of using legacy runners.",
default=False,
)
AGENT_V2_REPLACES_LLM: bool = Field(
description="Transparently replace LLM nodes in workflows with Agent V2 nodes at runtime. "
"LLMNodeData is remapped to AgentV2NodeData with tools=[] (identical behavior).",
default=False,
)
class LoginConfig(BaseSettings):
ENABLE_EMAIL_CODE_LOGIN: bool = Field(
description="whether to enable email code login",
@ -1419,6 +1343,29 @@ class TenantIsolatedTaskQueueConfig(BaseSettings):
)
class SandboxExpiredRecordsCleanConfig(BaseSettings):
SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD: NonNegativeInt = Field(
description="Graceful period in days for sandbox records clean after subscription expiration",
default=21,
)
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE: PositiveInt = Field(
description="Maximum number of records to process in each batch",
default=1000,
)
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL: PositiveInt = Field(
description="Maximum interval in milliseconds between batches",
default=200,
)
SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS: PositiveInt = Field(
description="Retention days for sandbox expired workflow_run records and message records",
default=30,
)
SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL: PositiveInt = Field(
description="Lock TTL for sandbox expired records clean task in seconds",
default=90000,
)
class FeatureConfig(
# place the configs in alphabet order
AppExecutionConfig,
@ -1429,7 +1376,6 @@ class FeatureConfig(
AsyncWorkflowConfig,
PluginConfig,
MarketplaceConfig,
CreatorsPlatformConfig,
DataSetConfig,
EndpointConfig,
FileAccessConfig,
@ -1445,6 +1391,7 @@ class FeatureConfig(
PositionConfig,
RagEtlConfig,
RepositoryConfig,
SandboxExpiredRecordsCleanConfig,
SecurityConfig,
TenantIsolatedTaskQueueConfig,
ToolConfig,
@ -1452,9 +1399,6 @@ class FeatureConfig(
WorkflowConfig,
WorkflowNodeExecutionConfig,
WorkspaceConfig,
CollaborationConfig,
AgentV2UpgradeConfig,
SandboxExpiredRecordsCleanConfig,
LoginConfig,
AccountConfig,
SwaggerUIConfig,

View File

@ -160,6 +160,16 @@ class DatabaseConfig(BaseSettings):
default="",
)
DB_SESSION_TIMEZONE_OVERRIDE: str = Field(
description=(
"PostgreSQL session timezone override injected via startup options."
" Default is 'UTC' for out-of-the-box consistency."
" Set to empty string to disable app-level timezone injection, for example when using RDS Proxy"
" together with a database-side default timezone."
),
default="UTC",
)
@computed_field # type: ignore[prop-decorator]
@property
def SQLALCHEMY_DATABASE_URI_SCHEME(self) -> str:
@ -227,12 +237,13 @@ class DatabaseConfig(BaseSettings):
connect_args: dict[str, str] = {}
# Use the dynamic SQLALCHEMY_DATABASE_URI_SCHEME property
if self.SQLALCHEMY_DATABASE_URI_SCHEME.startswith("postgresql"):
timezone_opt = "-c timezone=UTC"
if options:
merged_options = f"{options} {timezone_opt}"
else:
merged_options = timezone_opt
connect_args = {"options": merged_options}
merged_options = options.strip()
session_timezone_override = self.DB_SESSION_TIMEZONE_OVERRIDE.strip()
if session_timezone_override:
timezone_opt = f"-c timezone={session_timezone_override}"
merged_options = f"{merged_options} {timezone_opt}".strip() if merged_options else timezone_opt
if merged_options:
connect_args = {"options": merged_options}
result: SQLAlchemyEngineOptionsDict = {
"pool_size": self.SQLALCHEMY_POOL_SIZE,

View File

@ -32,6 +32,11 @@ class RedisConfig(BaseSettings):
default=0,
)
REDIS_KEY_PREFIX: str = Field(
description="Optional global prefix for Redis keys, topics, and transport artifacts",
default="",
)
REDIS_USE_SSL: bool = Field(
description="Enable SSL/TLS for the Redis connection",
default=False,

View File

@ -1,4 +1,3 @@
from holo_search_sdk.types import BaseQuantizationType, DistanceType, TokenizerType
from pydantic import Field
from pydantic_settings import BaseSettings
@ -42,17 +41,17 @@ class HologresConfig(BaseSettings):
default="public",
)
HOLOGRES_TOKENIZER: TokenizerType = Field(
HOLOGRES_TOKENIZER: str = Field(
description="Tokenizer for full-text search index (e.g., 'jieba', 'ik', 'standard', 'simple').",
default="jieba",
)
HOLOGRES_DISTANCE_METHOD: DistanceType = Field(
HOLOGRES_DISTANCE_METHOD: str = Field(
description="Distance method for vector index (e.g., 'Cosine', 'Euclidean', 'InnerProduct').",
default="Cosine",
)
HOLOGRES_BASE_QUANTIZATION_TYPE: BaseQuantizationType = Field(
HOLOGRES_BASE_QUANTIZATION_TYPE: str = Field(
description="Base quantization type for vector index (e.g., 'rabitq', 'sq8', 'fp16', 'fp32').",
default="rabitq",
)

View File

@ -1,5 +1,7 @@
"""Configuration for InterSystems IRIS vector database."""
from typing import Any
from pydantic import Field, PositiveInt, model_validator
from pydantic_settings import BaseSettings
@ -64,7 +66,7 @@ class IrisVectorConfig(BaseSettings):
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
def validate_config(cls, values: dict[str, Any]) -> dict[str, Any]:
"""Validate IRIS configuration values.
Args:

View File

@ -81,20 +81,4 @@ default_app_templates: Mapping[AppMode, Mapping] = {
},
},
},
# agent default mode (new agent backed by single-node workflow)
AppMode.AGENT: {
"app": {
"mode": AppMode.AGENT,
"enable_site": True,
"enable_api": True,
},
"model_config": {
"model": {
"provider": "openai",
"name": "gpt-4o",
"mode": "chat",
"completion_params": {},
},
},
},
}

View File

@ -5,7 +5,7 @@ from pydantic import BaseModel, Field
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from libs.login import login_required
from services.advanced_prompt_template_service import AdvancedPromptTemplateService
from services.advanced_prompt_template_service import AdvancedPromptTemplateArgs, AdvancedPromptTemplateService
class AdvancedPromptTemplateQuery(BaseModel):
@ -35,5 +35,10 @@ class AdvancedPromptTemplateList(Resource):
@account_initialization_required
def get(self):
args = AdvancedPromptTemplateQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
return AdvancedPromptTemplateService.get_prompt(args.model_dump())
prompt_args: AdvancedPromptTemplateArgs = {
"app_mode": args.app_mode,
"model_mode": args.model_mode,
"model_name": args.model_name,
"has_context": args.has_context,
}
return AdvancedPromptTemplateService.get_prompt(prompt_args)

View File

@ -6,10 +6,9 @@ from typing import Any, Literal
from flask import request
from flask_restx import Resource
from graphon.enums import WorkflowExecutionStatus
from graphon.file import helpers as file_helpers
from pydantic import AliasChoices, BaseModel, Field, computed_field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest
from controllers.common.helpers import FileInfo
@ -31,13 +30,14 @@ from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.trigger.constants import TRIGGER_NODE_TYPES
from extensions.ext_database import db
from fields.base import ResponseModel
from libs.helper import build_icon_url
from libs.login import current_account_with_tenant, login_required
from models import App, DatasetPermissionEnum, Workflow
from models.model import IconType
from services.app_dsl_service import AppDslService
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
from services.entities.dsl_entities import ImportMode
from services.entities.dsl_entities import ImportMode, ImportStatus
from services.entities.knowledge_entities.knowledge_entities import (
DataSource,
InfoList,
@ -52,7 +52,7 @@ from services.entities.knowledge_entities.knowledge_entities import (
)
from services.feature_service import FeatureService
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion", "agent"]
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
register_enum_models(console_ns, IconType)
@ -62,7 +62,7 @@ _logger = logging.getLogger(__name__)
class AppListQuery(BaseModel):
page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)")
limit: int = Field(default=20, ge=1, le=100, description="Page size (1-100)")
mode: Literal["completion", "chat", "advanced-chat", "workflow", "agent-chat", "agent", "channel", "all"] = Field(
mode: Literal["completion", "chat", "advanced-chat", "workflow", "agent-chat", "channel", "all"] = Field(
default="all", description="App mode filter"
)
name: str | None = Field(default=None, description="Filter by app name")
@ -94,9 +94,7 @@ class AppListQuery(BaseModel):
class CreateAppPayload(BaseModel):
name: str = Field(..., min_length=1, description="App name")
description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion", "agent"] = Field(
..., description="App mode"
)
mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion"] = Field(..., description="App mode")
icon_type: IconType | None = Field(default=None, description="Icon type")
icon: str | None = Field(default=None, description="Icon")
icon_background: str | None = Field(default=None, description="Icon background color")
@ -163,15 +161,6 @@ def _to_timestamp(value: datetime | int | None) -> int | None:
return value
def _build_icon_url(icon_type: str | IconType | None, icon: str | None) -> str | None:
if icon is None or icon_type is None:
return None
icon_type_value = icon_type.value if isinstance(icon_type, IconType) else str(icon_type)
if icon_type_value.lower() != IconType.IMAGE:
return None
return file_helpers.get_signed_file_url(icon)
class Tag(ResponseModel):
id: str
name: str
@ -294,7 +283,7 @@ class Site(ResponseModel):
@computed_field(return_type=str | None) # type: ignore
@property
def icon_url(self) -> str | None:
return _build_icon_url(self.icon_type, self.icon)
return build_icon_url(self.icon_type, self.icon)
@field_validator("icon_type", mode="before")
@classmethod
@ -344,7 +333,7 @@ class AppPartial(ResponseModel):
@computed_field(return_type=str | None) # type: ignore
@property
def icon_url(self) -> str | None:
return _build_icon_url(self.icon_type, self.icon)
return build_icon_url(self.icon_type, self.icon)
@field_validator("created_at", "updated_at", mode="before")
@classmethod
@ -392,7 +381,7 @@ class AppDetailWithSite(AppDetail):
@computed_field(return_type=str | None) # type: ignore
@property
def icon_url(self) -> str | None:
return _build_icon_url(self.icon_type, self.icon)
return build_icon_url(self.icon_type, self.icon)
class AppPagination(ResponseModel):
@ -634,7 +623,7 @@ class AppCopyApi(Resource):
args = CopyAppPayload.model_validate(console_ns.payload or {})
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
with Session(db.engine, expire_on_commit=False) as session:
import_service = AppDslService(session)
yaml_content = import_service.export_dsl(app_model=app_model, include_secret=True)
result = import_service.import_app(
@ -647,6 +636,13 @@ class AppCopyApi(Resource):
icon=args.icon,
icon_background=args.icon_background,
)
if result.status == ImportStatus.FAILED:
session.rollback()
return result.model_dump(mode="json"), 400
if result.status == ImportStatus.PENDING:
session.rollback()
return result.model_dump(mode="json"), 202
session.commit()
# Inherit web app permission from original app
if result.app_id and FeatureService.get_system_features().webapp_auth.enabled:

View File

@ -1,6 +1,6 @@
from flask_restx import Resource
from pydantic import BaseModel, Field
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session
from controllers.common.schema import register_schema_models
from controllers.console.app.wraps import get_app_model
@ -52,8 +52,9 @@ class AppImportApi(Resource):
current_user, _ = current_account_with_tenant()
args = AppImportPayload.model_validate(console_ns.payload)
# Create service with session
with sessionmaker(db.engine).begin() as session:
# AppDslService performs internal commits for some creation paths, so use a plain
# Session here instead of nesting it inside sessionmaker(...).begin().
with Session(db.engine, expire_on_commit=False) as session:
import_service = AppDslService(session)
# Import app
account = current_user
@ -69,6 +70,10 @@ class AppImportApi(Resource):
icon_background=args.icon_background,
app_id=args.app_id,
)
if result.status == ImportStatus.FAILED:
session.rollback()
else:
session.commit()
if result.app_id and FeatureService.get_system_features().webapp_auth.enabled:
# update web app setting as private
EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, "private")
@ -95,12 +100,15 @@ class AppImportConfirmApi(Resource):
# Check user role first
current_user, _ = current_account_with_tenant()
# Create service with session
with sessionmaker(db.engine).begin() as session:
with Session(db.engine, expire_on_commit=False) as session:
import_service = AppDslService(session)
# Confirm import
account = current_user
result = import_service.confirm_import(import_id=import_id, account=account)
if result.status == ImportStatus.FAILED:
session.rollback()
else:
session.commit()
# Return appropriate status code based on result
if result.status == ImportStatus.FAILED:
@ -117,7 +125,7 @@ class AppImportCheckDependenciesApi(Resource):
@account_initialization_required
@edit_permission_required
def get(self, app_model: App):
with sessionmaker(db.engine).begin() as session:
with Session(db.engine, expire_on_commit=False) as session:
import_service = AppDslService(session)
result = import_service.check_dependencies(app_model=app_model)

View File

@ -161,7 +161,7 @@ class ChatMessageApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
@edit_permission_required
def post(self, app_model):
args_model = ChatMessagePayload.model_validate(console_ns.payload)
@ -215,7 +215,7 @@ class ChatMessageStopApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
def post(self, app_model, task_id):
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")

View File

@ -1,44 +1,86 @@
from __future__ import annotations
from datetime import datetime
from typing import Any
from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_database import db
from fields.conversation_variable_fields import (
conversation_variable_fields,
paginated_conversation_variable_fields,
)
from fields._value_type_serializer import serialize_value_type
from fields.base import ResponseModel
from libs.login import login_required
from models import ConversationVariable
from models.model import AppMode
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class ConversationVariablesQuery(BaseModel):
conversation_id: str = Field(..., description="Conversation ID to filter variables")
console_ns.schema_model(
ConversationVariablesQuery.__name__,
ConversationVariablesQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
def _to_timestamp(value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
# Register models for flask_restx to avoid dict type issues in Swagger
# Register base model first
conversation_variable_model = console_ns.model("ConversationVariable", conversation_variable_fields)
# For nested models, need to replace nested dict with registered model
paginated_conversation_variable_fields_copy = paginated_conversation_variable_fields.copy()
paginated_conversation_variable_fields_copy["data"] = fields.List(
fields.Nested(conversation_variable_model), attribute="data"
)
paginated_conversation_variable_model = console_ns.model(
"PaginatedConversationVariable", paginated_conversation_variable_fields_copy
class ConversationVariableResponse(ResponseModel):
id: str
name: str
value_type: str
value: str | None = None
description: str | None = None
created_at: int | None = None
updated_at: int | None = None
@field_validator("value_type", mode="before")
@classmethod
def _normalize_value_type(cls, value: Any) -> str:
exposed_type = getattr(value, "exposed_type", None)
if callable(exposed_type):
return str(exposed_type().value)
if isinstance(value, str):
return value
try:
return serialize_value_type(value)
except Exception:
return serialize_value_type({"value_type": value})
@field_validator("value", mode="before")
@classmethod
def _normalize_value(cls, value: Any | None) -> str | None:
if value is None:
return None
if isinstance(value, str):
return value
return str(value)
@field_validator("created_at", "updated_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
class PaginatedConversationVariableResponse(ResponseModel):
page: int
limit: int
total: int
has_more: bool
data: list[ConversationVariableResponse]
register_schema_models(
console_ns,
ConversationVariablesQuery,
ConversationVariableResponse,
PaginatedConversationVariableResponse,
)
@ -48,12 +90,15 @@ class ConversationVariablesApi(Resource):
@console_ns.doc(description="Get conversation variables for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[ConversationVariablesQuery.__name__])
@console_ns.response(200, "Conversation variables retrieved successfully", paginated_conversation_variable_model)
@console_ns.response(
200,
"Conversation variables retrieved successfully",
console_ns.models[PaginatedConversationVariableResponse.__name__],
)
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.ADVANCED_CHAT)
@marshal_with(paginated_conversation_variable_model)
def get(self, app_model):
args = ConversationVariablesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
@ -72,17 +117,22 @@ class ConversationVariablesApi(Resource):
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
rows = session.scalars(stmt).all()
return {
"page": page,
"limit": page_size,
"total": len(rows),
"has_more": False,
"data": [
{
"created_at": row.created_at,
"updated_at": row.updated_at,
**row.to_variable().model_dump(),
}
for row in rows
],
}
response = PaginatedConversationVariableResponse.model_validate(
{
"page": page,
"limit": page_size,
"total": len(rows),
"has_more": False,
"data": [
ConversationVariableResponse.model_validate(
{
"created_at": row.created_at,
"updated_at": row.updated_at,
**row.to_variable().model_dump(),
}
)
for row in rows
],
}
)
return response.model_dump(mode="json")

View File

@ -26,13 +26,13 @@ def _to_timestamp(value: datetime | int | None) -> int | None:
class MCPServerCreatePayload(BaseModel):
description: str | None = Field(default=None, description="Server description")
parameters: dict = Field(..., description="Server parameters configuration")
parameters: dict[str, Any] = Field(..., description="Server parameters configuration")
class MCPServerUpdatePayload(BaseModel):
id: str = Field(..., description="Server ID")
description: str | None = Field(default=None, description="Server description")
parameters: dict = Field(..., description="Server parameters configuration")
parameters: dict[str, Any] = Field(..., description="Server parameters configuration")
status: str | None = Field(default=None, description="Server status")

View File

@ -1,8 +1,9 @@
import logging
from datetime import datetime
from typing import Literal
from flask import request
from flask_restx import Resource, fields, marshal_with
from flask_restx import Resource
from graphon.model_runtime.errors.invoke import InvokeError
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import exists, func, select
@ -25,10 +26,21 @@ from controllers.console.wraps import (
setup_required,
)
from core.app.entities.app_invoke_entities import InvokeFrom
from core.entities.execution_extra_content import ExecutionExtraContentDomainModel
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from extensions.ext_database import db
from fields.raws import FilesContainedField
from libs.helper import TimestampField, uuid_value
from fields.base import ResponseModel
from fields.conversation_fields import (
AgentThought,
ConversationAnnotation,
ConversationAnnotationHitHistory,
Feedback,
JSONValue,
MessageFile,
format_files_contained,
to_timestamp,
)
from libs.helper import uuid_value
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from libs.login import current_account_with_tenant, login_required
from models.enums import FeedbackFromSource, FeedbackRating
@ -98,6 +110,51 @@ class SuggestedQuestionsResponse(BaseModel):
data: list[str] = Field(description="Suggested question")
class MessageDetailResponse(ResponseModel):
id: str
conversation_id: str
inputs: dict[str, JSONValue]
query: str
message: JSONValue | None = None
message_tokens: int | None = None
answer: str = Field(validation_alias="re_sign_file_url_answer")
answer_tokens: int | None = None
provider_response_latency: float | None = None
from_source: str
from_end_user_id: str | None = None
from_account_id: str | None = None
feedbacks: list[Feedback] = Field(default_factory=list)
workflow_run_id: str | None = None
annotation: ConversationAnnotation | None = None
annotation_hit_history: ConversationAnnotationHitHistory | None = None
created_at: int | None = None
agent_thoughts: list[AgentThought] = Field(default_factory=list)
message_files: list[MessageFile] = Field(default_factory=list)
extra_contents: list[ExecutionExtraContentDomainModel] = Field(default_factory=list)
metadata: JSONValue | None = Field(default=None, validation_alias="message_metadata_dict")
status: str
error: str | None = None
parent_message_id: str | None = None
@field_validator("inputs", mode="before")
@classmethod
def _normalize_inputs(cls, value: JSONValue) -> JSONValue:
return format_files_contained(value)
@field_validator("created_at", mode="before")
@classmethod
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return to_timestamp(value)
return value
class MessageInfiniteScrollPaginationResponse(ResponseModel):
limit: int
has_more: bool
data: list[MessageDetailResponse]
register_schema_models(
console_ns,
ChatMessagesQuery,
@ -105,124 +162,8 @@ register_schema_models(
FeedbackExportQuery,
AnnotationCountResponse,
SuggestedQuestionsResponse,
)
# Register models for flask_restx to avoid dict type issues in Swagger
# Register in dependency order: base models first, then dependent models
# Base models
simple_account_model = console_ns.model(
"SimpleAccount",
{
"id": fields.String,
"name": fields.String,
"email": fields.String,
},
)
message_file_model = console_ns.model(
"MessageFile",
{
"id": fields.String,
"filename": fields.String,
"type": fields.String,
"url": fields.String,
"mime_type": fields.String,
"size": fields.Integer,
"transfer_method": fields.String,
"belongs_to": fields.String(default="user"),
"upload_file_id": fields.String(default=None),
},
)
agent_thought_model = console_ns.model(
"AgentThought",
{
"id": fields.String,
"chain_id": fields.String,
"message_id": fields.String,
"position": fields.Integer,
"thought": fields.String,
"tool": fields.String,
"tool_labels": fields.Raw,
"tool_input": fields.String,
"created_at": TimestampField,
"observation": fields.String,
"files": fields.List(fields.String),
},
)
# Models that depend on simple_account_model
feedback_model = console_ns.model(
"Feedback",
{
"rating": fields.String,
"content": fields.String,
"from_source": fields.String,
"from_end_user_id": fields.String,
"from_account": fields.Nested(simple_account_model, allow_null=True),
},
)
annotation_model = console_ns.model(
"Annotation",
{
"id": fields.String,
"question": fields.String,
"content": fields.String,
"account": fields.Nested(simple_account_model, allow_null=True),
"created_at": TimestampField,
},
)
annotation_hit_history_model = console_ns.model(
"AnnotationHitHistory",
{
"annotation_id": fields.String(attribute="id"),
"annotation_create_account": fields.Nested(simple_account_model, allow_null=True),
"created_at": TimestampField,
},
)
# Message detail model that depends on multiple models
message_detail_model = console_ns.model(
"MessageDetail",
{
"id": fields.String,
"conversation_id": fields.String,
"inputs": FilesContainedField,
"query": fields.String,
"message": fields.Raw,
"message_tokens": fields.Integer,
"answer": fields.String(attribute="re_sign_file_url_answer"),
"answer_tokens": fields.Integer,
"provider_response_latency": fields.Float,
"from_source": fields.String,
"from_end_user_id": fields.String,
"from_account_id": fields.String,
"feedbacks": fields.List(fields.Nested(feedback_model)),
"workflow_run_id": fields.String,
"annotation": fields.Nested(annotation_model, allow_null=True),
"annotation_hit_history": fields.Nested(annotation_hit_history_model, allow_null=True),
"created_at": TimestampField,
"agent_thoughts": fields.List(fields.Nested(agent_thought_model)),
"message_files": fields.List(fields.Nested(message_file_model)),
"extra_contents": fields.List(fields.Raw),
"metadata": fields.Raw(attribute="message_metadata_dict"),
"status": fields.String,
"error": fields.String,
"parent_message_id": fields.String,
},
)
# Message infinite scroll pagination model
message_infinite_scroll_pagination_model = console_ns.model(
"MessageInfiniteScrollPagination",
{
"limit": fields.Integer,
"has_more": fields.Boolean,
"data": fields.List(fields.Nested(message_detail_model)),
},
MessageDetailResponse,
MessageInfiniteScrollPaginationResponse,
)
@ -232,13 +173,12 @@ class ChatMessageListApi(Resource):
@console_ns.doc(description="Get chat messages for a conversation with pagination")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[ChatMessagesQuery.__name__])
@console_ns.response(200, "Success", message_infinite_scroll_pagination_model)
@console_ns.response(200, "Success", console_ns.models[MessageInfiniteScrollPaginationResponse.__name__])
@console_ns.response(404, "Conversation not found")
@login_required
@account_initialization_required
@setup_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
@marshal_with(message_infinite_scroll_pagination_model)
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@edit_permission_required
def get(self, app_model):
args = ChatMessagesQuery.model_validate(request.args.to_dict())
@ -298,7 +238,10 @@ class ChatMessageListApi(Resource):
history_messages = list(reversed(history_messages))
attach_message_extra_contents(history_messages)
return InfiniteScrollPagination(data=history_messages, limit=args.limit, has_more=has_more)
return MessageInfiniteScrollPaginationResponse.model_validate(
InfiniteScrollPagination(data=history_messages, limit=args.limit, has_more=has_more),
from_attributes=True,
).model_dump(mode="json")
@console_ns.route("/apps/<uuid:app_id>/feedbacks")
@ -393,7 +336,7 @@ class MessageSuggestedQuestionApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
def get(self, app_model, message_id):
current_user, _ = current_account_with_tenant()
message_id = str(message_id)
@ -468,13 +411,12 @@ class MessageApi(Resource):
@console_ns.doc("get_message")
@console_ns.doc(description="Get message details by ID")
@console_ns.doc(params={"app_id": "Application ID", "message_id": "Message ID"})
@console_ns.response(200, "Message retrieved successfully", message_detail_model)
@console_ns.response(200, "Message retrieved successfully", console_ns.models[MessageDetailResponse.__name__])
@console_ns.response(404, "Message not found")
@get_app_model
@setup_required
@login_required
@account_initialization_required
@marshal_with(message_detail_model)
def get(self, app_model, message_id: str):
message_id = str(message_id)
@ -486,4 +428,4 @@ class MessageApi(Resource):
raise NotFound("Message Not Exists.")
attach_message_extra_contents([message])
return message
return MessageDetailResponse.model_validate(message, from_attributes=True).model_dump(mode="json")

View File

@ -4,7 +4,7 @@ from collections.abc import Sequence
from typing import Any
from flask import abort, request
from flask_restx import Resource, fields, marshal_with
from flask_restx import Resource, fields, marshal, marshal_with
from graphon.enums import NodeType
from graphon.file import File
from graphon.graph_engine.manager import GraphEngineManager
@ -206,7 +206,7 @@ class DraftWorkflowApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_model)
@edit_permission_required
def get(self, app_model: App):
@ -226,7 +226,7 @@ class DraftWorkflowApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@console_ns.doc("sync_draft_workflow")
@console_ns.doc(description="Sync draft workflow configuration")
@console_ns.expect(console_ns.models[SyncDraftWorkflowPayload.__name__])
@ -310,7 +310,7 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.AGENT])
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
@edit_permission_required
def post(self, app_model: App):
"""
@ -356,7 +356,7 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.AGENT])
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
@edit_permission_required
def post(self, app_model: App, node_id: str):
"""
@ -432,7 +432,7 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.AGENT])
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
@edit_permission_required
def post(self, app_model: App, node_id: str):
"""
@ -534,7 +534,7 @@ class AdvancedChatDraftHumanInputFormPreviewApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.AGENT])
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
@edit_permission_required
def post(self, app_model: App, node_id: str):
"""
@ -563,7 +563,7 @@ class AdvancedChatDraftHumanInputFormRunApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.AGENT])
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
@edit_permission_required
def post(self, app_model: App, node_id: str):
"""
@ -718,7 +718,7 @@ class WorkflowTaskStopApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@edit_permission_required
def post(self, app_model: App, task_id: str):
"""
@ -746,7 +746,7 @@ class DraftWorkflowNodeRunApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_run_node_execution_model)
@edit_permission_required
def post(self, app_model: App, node_id: str):
@ -792,7 +792,7 @@ class PublishedWorkflowApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_model)
@edit_permission_required
def get(self, app_model: App):
@ -810,7 +810,7 @@ class PublishedWorkflowApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@edit_permission_required
def post(self, app_model: App):
"""
@ -854,7 +854,7 @@ class DefaultBlockConfigsApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@edit_permission_required
def get(self, app_model: App):
"""
@ -876,7 +876,7 @@ class DefaultBlockConfigApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@edit_permission_required
def get(self, app_model: App, block_type: str):
"""
@ -941,8 +941,7 @@ class PublishedAllWorkflowApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
@marshal_with(workflow_pagination_model)
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@edit_permission_required
def get(self, app_model: App):
"""
@ -970,9 +969,10 @@ class PublishedAllWorkflowApi(Resource):
user_id=user_id,
named_only=named_only,
)
serialized_workflows = marshal(workflows, workflow_fields_copy)
return {
"items": workflows,
"items": serialized_workflows,
"page": page,
"limit": limit,
"has_more": has_more,
@ -990,7 +990,7 @@ class DraftWorkflowRestoreApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@edit_permission_required
def post(self, app_model: App, workflow_id: str):
current_user, _ = current_account_with_tenant()
@ -1028,7 +1028,7 @@ class WorkflowByIdApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_model)
@edit_permission_required
def patch(self, app_model: App, workflow_id: str):
@ -1068,7 +1068,7 @@ class WorkflowByIdApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@edit_permission_required
def delete(self, app_model: App, workflow_id: str):
"""
@ -1103,7 +1103,7 @@ class DraftWorkflowNodeLastRunApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_run_node_execution_model)
def get(self, app_model: App, node_id: str):
srv = WorkflowService()

View File

@ -1,27 +1,26 @@
from datetime import datetime
from typing import Any
from dateutil.parser import isoparse
from flask import request
from flask_restx import Resource, marshal_with
from flask_restx import Resource
from graphon.enums import WorkflowExecutionStatus
from pydantic import BaseModel, Field, field_validator
from sqlalchemy.orm import sessionmaker
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_database import db
from fields.workflow_app_log_fields import (
build_workflow_app_log_pagination_model,
build_workflow_archived_log_pagination_model,
)
from fields.base import ResponseModel
from fields.end_user_fields import SimpleEndUser
from fields.member_fields import SimpleAccount
from libs.login import login_required
from models import App
from models.model import AppMode
from services.workflow_app_service import WorkflowAppService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class WorkflowAppLogQuery(BaseModel):
keyword: str | None = Field(default=None, description="Search keyword for filtering logs")
@ -58,13 +57,113 @@ class WorkflowAppLogQuery(BaseModel):
raise ValueError("Invalid boolean value for detail")
console_ns.schema_model(
WorkflowAppLogQuery.__name__, WorkflowAppLogQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
class WorkflowRunForLogResponse(ResponseModel):
id: str
version: str | None = None
status: str | None = None
triggered_from: str | None = None
error: str | None = None
elapsed_time: float | None = None
total_tokens: int | None = None
total_steps: int | None = None
created_at: int | None = None
finished_at: int | None = None
exceptions_count: int | None = None
# Register model for flask_restx to avoid dict type issues in Swagger
workflow_app_log_pagination_model = build_workflow_app_log_pagination_model(console_ns)
workflow_archived_log_pagination_model = build_workflow_archived_log_pagination_model(console_ns)
@field_validator("status", mode="before")
@classmethod
def _normalize_status(cls, value: Any) -> str | None:
if value is None:
return None
if isinstance(value, str):
return value
return str(getattr(value, "value", value))
@field_validator("created_at", "finished_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
class WorkflowRunForArchivedLogResponse(ResponseModel):
id: str
status: str | None = None
triggered_from: str | None = None
elapsed_time: float | None = None
total_tokens: int | None = None
@field_validator("status", mode="before")
@classmethod
def _normalize_status(cls, value: Any) -> str | None:
if value is None:
return None
if isinstance(value, str):
return value
return str(getattr(value, "value", value))
class WorkflowAppLogPartialResponse(ResponseModel):
id: str
workflow_run: WorkflowRunForLogResponse | None = None
details: Any = None
created_from: str | None = None
created_by_role: str | None = None
created_by_account: SimpleAccount | None = None
created_by_end_user: SimpleEndUser | None = None
created_at: int | None = None
@field_validator("created_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
class WorkflowArchivedLogPartialResponse(ResponseModel):
id: str
workflow_run: WorkflowRunForArchivedLogResponse | None = None
trigger_metadata: Any = None
created_by_account: SimpleAccount | None = None
created_by_end_user: SimpleEndUser | None = None
created_at: int | None = None
@field_validator("created_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
class WorkflowAppLogPaginationResponse(ResponseModel):
page: int
limit: int
total: int
has_more: bool
data: list[WorkflowAppLogPartialResponse]
class WorkflowArchivedLogPaginationResponse(ResponseModel):
page: int
limit: int
total: int
has_more: bool
data: list[WorkflowArchivedLogPartialResponse]
register_schema_models(
console_ns,
WorkflowAppLogQuery,
WorkflowRunForLogResponse,
WorkflowRunForArchivedLogResponse,
WorkflowAppLogPartialResponse,
WorkflowArchivedLogPartialResponse,
WorkflowAppLogPaginationResponse,
WorkflowArchivedLogPaginationResponse,
)
@console_ns.route("/apps/<uuid:app_id>/workflow-app-logs")
@ -73,12 +172,15 @@ class WorkflowAppLogApi(Resource):
@console_ns.doc(description="Get workflow application execution logs")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[WorkflowAppLogQuery.__name__])
@console_ns.response(200, "Workflow app logs retrieved successfully", workflow_app_log_pagination_model)
@console_ns.response(
200,
"Workflow app logs retrieved successfully",
console_ns.models[WorkflowAppLogPaginationResponse.__name__],
)
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
@marshal_with(workflow_app_log_pagination_model)
def get(self, app_model: App):
"""
Get workflow app logs
@ -87,7 +189,7 @@ class WorkflowAppLogApi(Resource):
# get paginate workflow app logs
workflow_app_service = WorkflowAppService()
with sessionmaker(db.engine).begin() as session:
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs(
session=session,
app_model=app_model,
@ -102,7 +204,9 @@ class WorkflowAppLogApi(Resource):
created_by_account=args.created_by_account,
)
return workflow_app_log_pagination
return WorkflowAppLogPaginationResponse.model_validate(
workflow_app_log_pagination, from_attributes=True
).model_dump(mode="json")
@console_ns.route("/apps/<uuid:app_id>/workflow-archived-logs")
@ -111,12 +215,15 @@ class WorkflowArchivedLogApi(Resource):
@console_ns.doc(description="Get workflow archived execution logs")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[WorkflowAppLogQuery.__name__])
@console_ns.response(200, "Workflow archived logs retrieved successfully", workflow_archived_log_pagination_model)
@console_ns.response(
200,
"Workflow archived logs retrieved successfully",
console_ns.models[WorkflowArchivedLogPaginationResponse.__name__],
)
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
@marshal_with(workflow_archived_log_pagination_model)
def get(self, app_model: App):
"""
Get workflow archived logs
@ -124,7 +231,7 @@ class WorkflowArchivedLogApi(Resource):
args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
workflow_app_service = WorkflowAppService()
with sessionmaker(db.engine).begin() as session:
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_archive_logs(
session=session,
app_model=app_model,
@ -132,4 +239,6 @@ class WorkflowArchivedLogApi(Resource):
limit=args.limit,
)
return workflow_app_log_pagination
return WorkflowArchivedLogPaginationResponse.model_validate(
workflow_app_log_pagination, from_attributes=True
).model_dump(mode="json")

View File

@ -1,322 +0,0 @@
import logging
from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field, TypeAdapter
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from fields.member_fields import AccountWithRole
from fields.workflow_comment_fields import (
workflow_comment_basic_fields,
workflow_comment_create_fields,
workflow_comment_detail_fields,
workflow_comment_reply_create_fields,
workflow_comment_reply_update_fields,
workflow_comment_resolve_fields,
workflow_comment_update_fields,
)
from libs.login import current_user, login_required
from models import App
from services.account_service import TenantService
from services.workflow_comment_service import WorkflowCommentService
logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class WorkflowCommentCreatePayload(BaseModel):
position_x: float = Field(..., description="Comment X position")
position_y: float = Field(..., description="Comment Y position")
content: str = Field(..., description="Comment content")
mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs")
class WorkflowCommentUpdatePayload(BaseModel):
content: str = Field(..., description="Comment content")
position_x: float | None = Field(default=None, description="Comment X position")
position_y: float | None = Field(default=None, description="Comment Y position")
mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs")
class WorkflowCommentReplyCreatePayload(BaseModel):
content: str = Field(..., description="Reply content")
mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs")
class WorkflowCommentReplyUpdatePayload(BaseModel):
content: str = Field(..., description="Reply content")
mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs")
class WorkflowCommentMentionUsersResponse(BaseModel):
users: list[AccountWithRole] = Field(description="Mentionable users")
for model in (
WorkflowCommentCreatePayload,
WorkflowCommentUpdatePayload,
WorkflowCommentReplyCreatePayload,
WorkflowCommentReplyUpdatePayload,
):
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
for model in (AccountWithRole, WorkflowCommentMentionUsersResponse):
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
workflow_comment_basic_model = console_ns.model("WorkflowCommentBasic", workflow_comment_basic_fields)
workflow_comment_detail_model = console_ns.model("WorkflowCommentDetail", workflow_comment_detail_fields)
workflow_comment_create_model = console_ns.model("WorkflowCommentCreate", workflow_comment_create_fields)
workflow_comment_update_model = console_ns.model("WorkflowCommentUpdate", workflow_comment_update_fields)
workflow_comment_resolve_model = console_ns.model("WorkflowCommentResolve", workflow_comment_resolve_fields)
workflow_comment_reply_create_model = console_ns.model(
"WorkflowCommentReplyCreate", workflow_comment_reply_create_fields
)
workflow_comment_reply_update_model = console_ns.model(
"WorkflowCommentReplyUpdate", workflow_comment_reply_update_fields
)
workflow_comment_mention_users_model = console_ns.models[WorkflowCommentMentionUsersResponse.__name__]
@console_ns.route("/apps/<uuid:app_id>/workflow/comments")
class WorkflowCommentListApi(Resource):
"""API for listing and creating workflow comments."""
@console_ns.doc("list_workflow_comments")
@console_ns.doc(description="Get all comments for a workflow")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.response(200, "Comments retrieved successfully", workflow_comment_basic_model)
@login_required
@setup_required
@account_initialization_required
@get_app_model()
@marshal_with(workflow_comment_basic_model, envelope="data")
def get(self, app_model: App):
"""Get all comments for a workflow."""
comments = WorkflowCommentService.get_comments(tenant_id=current_user.current_tenant_id, app_id=app_model.id)
return comments
@console_ns.doc("create_workflow_comment")
@console_ns.doc(description="Create a new workflow comment")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[WorkflowCommentCreatePayload.__name__])
@console_ns.response(201, "Comment created successfully", workflow_comment_create_model)
@login_required
@setup_required
@account_initialization_required
@get_app_model()
@marshal_with(workflow_comment_create_model)
def post(self, app_model: App):
"""Create a new workflow comment."""
payload = WorkflowCommentCreatePayload.model_validate(console_ns.payload or {})
result = WorkflowCommentService.create_comment(
tenant_id=current_user.current_tenant_id,
app_id=app_model.id,
created_by=current_user.id,
content=payload.content,
position_x=payload.position_x,
position_y=payload.position_y,
mentioned_user_ids=payload.mentioned_user_ids,
)
return result, 201
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>")
class WorkflowCommentDetailApi(Resource):
"""API for managing individual workflow comments."""
@console_ns.doc("get_workflow_comment")
@console_ns.doc(description="Get a specific workflow comment")
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
@console_ns.response(200, "Comment retrieved successfully", workflow_comment_detail_model)
@login_required
@setup_required
@account_initialization_required
@get_app_model()
@marshal_with(workflow_comment_detail_model)
def get(self, app_model: App, comment_id: str):
"""Get a specific workflow comment."""
comment = WorkflowCommentService.get_comment(
tenant_id=current_user.current_tenant_id, app_id=app_model.id, comment_id=comment_id
)
return comment
@console_ns.doc("update_workflow_comment")
@console_ns.doc(description="Update a workflow comment")
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
@console_ns.expect(console_ns.models[WorkflowCommentUpdatePayload.__name__])
@console_ns.response(200, "Comment updated successfully", workflow_comment_update_model)
@login_required
@setup_required
@account_initialization_required
@get_app_model()
@marshal_with(workflow_comment_update_model)
def put(self, app_model: App, comment_id: str):
"""Update a workflow comment."""
payload = WorkflowCommentUpdatePayload.model_validate(console_ns.payload or {})
result = WorkflowCommentService.update_comment(
tenant_id=current_user.current_tenant_id,
app_id=app_model.id,
comment_id=comment_id,
user_id=current_user.id,
content=payload.content,
position_x=payload.position_x,
position_y=payload.position_y,
mentioned_user_ids=payload.mentioned_user_ids,
)
return result
@console_ns.doc("delete_workflow_comment")
@console_ns.doc(description="Delete a workflow comment")
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
@console_ns.response(204, "Comment deleted successfully")
@login_required
@setup_required
@account_initialization_required
@get_app_model()
def delete(self, app_model: App, comment_id: str):
"""Delete a workflow comment."""
WorkflowCommentService.delete_comment(
tenant_id=current_user.current_tenant_id,
app_id=app_model.id,
comment_id=comment_id,
user_id=current_user.id,
)
return {"result": "success"}, 204
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/resolve")
class WorkflowCommentResolveApi(Resource):
"""API for resolving and reopening workflow comments."""
@console_ns.doc("resolve_workflow_comment")
@console_ns.doc(description="Resolve a workflow comment")
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
@console_ns.response(200, "Comment resolved successfully", workflow_comment_resolve_model)
@login_required
@setup_required
@account_initialization_required
@get_app_model()
@marshal_with(workflow_comment_resolve_model)
def post(self, app_model: App, comment_id: str):
"""Resolve a workflow comment."""
comment = WorkflowCommentService.resolve_comment(
tenant_id=current_user.current_tenant_id,
app_id=app_model.id,
comment_id=comment_id,
user_id=current_user.id,
)
return comment
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/replies")
class WorkflowCommentReplyApi(Resource):
"""API for managing comment replies."""
@console_ns.doc("create_workflow_comment_reply")
@console_ns.doc(description="Add a reply to a workflow comment")
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
@console_ns.expect(console_ns.models[WorkflowCommentReplyCreatePayload.__name__])
@console_ns.response(201, "Reply created successfully", workflow_comment_reply_create_model)
@login_required
@setup_required
@account_initialization_required
@get_app_model()
@marshal_with(workflow_comment_reply_create_model)
def post(self, app_model: App, comment_id: str):
"""Add a reply to a workflow comment."""
# Validate comment access first
WorkflowCommentService.validate_comment_access(
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
)
payload = WorkflowCommentReplyCreatePayload.model_validate(console_ns.payload or {})
result = WorkflowCommentService.create_reply(
comment_id=comment_id,
content=payload.content,
created_by=current_user.id,
mentioned_user_ids=payload.mentioned_user_ids,
)
return result, 201
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/replies/<string:reply_id>")
class WorkflowCommentReplyDetailApi(Resource):
"""API for managing individual comment replies."""
@console_ns.doc("update_workflow_comment_reply")
@console_ns.doc(description="Update a comment reply")
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID", "reply_id": "Reply ID"})
@console_ns.expect(console_ns.models[WorkflowCommentReplyUpdatePayload.__name__])
@console_ns.response(200, "Reply updated successfully", workflow_comment_reply_update_model)
@login_required
@setup_required
@account_initialization_required
@get_app_model()
@marshal_with(workflow_comment_reply_update_model)
def put(self, app_model: App, comment_id: str, reply_id: str):
"""Update a comment reply."""
# Validate comment access first
WorkflowCommentService.validate_comment_access(
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
)
payload = WorkflowCommentReplyUpdatePayload.model_validate(console_ns.payload or {})
reply = WorkflowCommentService.update_reply(
reply_id=reply_id,
user_id=current_user.id,
content=payload.content,
mentioned_user_ids=payload.mentioned_user_ids,
)
return reply
@console_ns.doc("delete_workflow_comment_reply")
@console_ns.doc(description="Delete a comment reply")
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID", "reply_id": "Reply ID"})
@console_ns.response(204, "Reply deleted successfully")
@login_required
@setup_required
@account_initialization_required
@get_app_model()
def delete(self, app_model: App, comment_id: str, reply_id: str):
"""Delete a comment reply."""
# Validate comment access first
WorkflowCommentService.validate_comment_access(
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
)
WorkflowCommentService.delete_reply(reply_id=reply_id, user_id=current_user.id)
return {"result": "success"}, 204
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/mention-users")
class WorkflowCommentMentionUsersApi(Resource):
"""API for getting mentionable users for workflow comments."""
@console_ns.doc("workflow_comment_mention_users")
@console_ns.doc(description="Get all users in current tenant for mentions")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.response(200, "Mentionable users retrieved successfully", workflow_comment_mention_users_model)
@login_required
@setup_required
@account_initialization_required
@get_app_model()
def get(self, app_model: App):
"""Get all users in current tenant for mentions."""
members = TenantService.get_tenant_members(current_user.current_tenant)
member_models = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True)
response = WorkflowCommentMentionUsersResponse(users=member_models)
return response.model_dump(mode="json"), 200

View File

@ -216,7 +216,7 @@ def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
@login_required
@account_initialization_required
@edit_permission_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@wraps(f)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | Response:
return f(*args, **kwargs)

View File

@ -36,7 +36,7 @@ from models import Account, App, AppMode, EndUser, WorkflowArchiveLog, WorkflowR
from models.workflow import WorkflowRun
from repositories.factory import DifyAPIRepositoryFactory
from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME
from services.workflow_run_service import WorkflowRunService
from services.workflow_run_service import WorkflowRunListArgs, WorkflowRunService
def _build_backstage_input_url(form_token: str | None) -> str | None:
@ -207,14 +207,18 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.AGENT])
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
@marshal_with(advanced_chat_workflow_run_pagination_model)
def get(self, app_model: App):
"""
Get advanced chat app workflow run list
"""
args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = args_model.model_dump(exclude_none=True)
args: WorkflowRunListArgs = {"limit": args_model.limit}
if args_model.last_id is not None:
args["last_id"] = args_model.last_id
if args_model.status is not None:
args["status"] = args_model.status
# Default to DEBUGGING if not specified
triggered_from = (
@ -305,7 +309,7 @@ class AdvancedChatAppWorkflowRunCountApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.AGENT])
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
@marshal_with(workflow_run_count_model)
def get(self, app_model: App):
"""
@ -349,14 +353,18 @@ class WorkflowRunListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_run_pagination_model)
def get(self, app_model: App):
"""
Get workflow run list
"""
args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = args_model.model_dump(exclude_none=True)
args: WorkflowRunListArgs = {"limit": args_model.limit}
if args_model.last_id is not None:
args["last_id"] = args_model.last_id
if args_model.status is not None:
args["status"] = args_model.status
# Default to DEBUGGING for workflow if not specified (backward compatibility)
triggered_from = (
@ -397,7 +405,7 @@ class WorkflowRunCountApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_run_count_model)
def get(self, app_model: App):
"""
@ -434,7 +442,7 @@ class WorkflowRunDetailApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_run_detail_model)
def get(self, app_model: App, run_id):
"""
@ -458,7 +466,7 @@ class WorkflowRunNodeExecutionListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_run_node_execution_list_model)
def get(self, app_model: App, run_id):
"""

View File

@ -1,16 +1,17 @@
import logging
from datetime import datetime
from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel
from flask_restx import Resource
from pydantic import BaseModel, field_validator
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import NotFound
from configs import dify_config
from controllers.common.schema import get_or_create_model
from controllers.common.schema import register_schema_models
from extensions.ext_database import db
from fields.workflow_trigger_fields import trigger_fields, triggers_list_fields, webhook_trigger_fields
from fields.base import ResponseModel
from libs.login import current_user, login_required
from models.enums import AppTriggerStatus
from models.model import Account, App, AppMode
@ -21,15 +22,6 @@ from ..app.wraps import get_app_model
from ..wraps import account_initialization_required, edit_permission_required, setup_required
logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
trigger_model = get_or_create_model("WorkflowTrigger", trigger_fields)
triggers_list_fields_copy = triggers_list_fields.copy()
triggers_list_fields_copy["data"] = fields.List(fields.Nested(trigger_model))
triggers_list_model = get_or_create_model("WorkflowTriggerList", triggers_list_fields_copy)
webhook_trigger_model = get_or_create_model("WebhookTrigger", webhook_trigger_fields)
class Parser(BaseModel):
@ -41,10 +33,52 @@ class ParserEnable(BaseModel):
enable_trigger: bool
console_ns.schema_model(Parser.__name__, Parser.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
class WorkflowTriggerResponse(ResponseModel):
id: str
trigger_type: str
title: str
node_id: str
provider_name: str
icon: str
status: str
created_at: datetime | None = None
updated_at: datetime | None = None
console_ns.schema_model(
ParserEnable.__name__, ParserEnable.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
@field_validator("id", "trigger_type", "title", "node_id", "provider_name", "icon", "status", mode="before")
@classmethod
def _normalize_string_fields(cls, value: object) -> str:
if isinstance(value, str):
return value
return str(value)
class WorkflowTriggerListResponse(ResponseModel):
data: list[WorkflowTriggerResponse]
class WebhookTriggerResponse(ResponseModel):
id: str
webhook_id: str
webhook_url: str
webhook_debug_url: str
node_id: str
created_at: datetime | None = None
@field_validator("id", "webhook_id", "webhook_url", "webhook_debug_url", "node_id", mode="before")
@classmethod
def _normalize_string_fields(cls, value: object) -> str:
if isinstance(value, str):
return value
return str(value)
register_schema_models(
console_ns,
Parser,
ParserEnable,
WorkflowTriggerResponse,
WorkflowTriggerListResponse,
WebhookTriggerResponse,
)
@ -57,14 +91,14 @@ class WebhookTriggerApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.WORKFLOW)
@marshal_with(webhook_trigger_model)
@console_ns.response(200, "Success", console_ns.models[WebhookTriggerResponse.__name__])
def get(self, app_model: App):
"""Get webhook trigger for a node"""
args = Parser.model_validate(request.args.to_dict(flat=True)) # type: ignore
node_id = args.node_id
with sessionmaker(db.engine).begin() as session:
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
# Get webhook trigger for this app and node
webhook_trigger = session.scalar(
select(WorkflowWebhookTrigger)
@ -78,7 +112,7 @@ class WebhookTriggerApi(Resource):
if not webhook_trigger:
raise NotFound("Webhook trigger not found for this node")
return webhook_trigger
return WebhookTriggerResponse.model_validate(webhook_trigger, from_attributes=True).model_dump(mode="json")
@console_ns.route("/apps/<uuid:app_id>/triggers")
@ -89,13 +123,13 @@ class AppTriggersApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.WORKFLOW)
@marshal_with(triggers_list_model)
@console_ns.response(200, "Success", console_ns.models[WorkflowTriggerListResponse.__name__])
def get(self, app_model: App):
"""Get app triggers list"""
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
with sessionmaker(db.engine).begin() as session:
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
# Get all triggers for this app using select API
triggers = (
session.execute(
@ -118,7 +152,9 @@ class AppTriggersApi(Resource):
else:
trigger.icon = "" # type: ignore
return {"data": triggers}
return WorkflowTriggerListResponse.model_validate({"data": triggers}, from_attributes=True).model_dump(
mode="json"
)
@console_ns.route("/apps/<uuid:app_id>/trigger-enable")
@ -129,7 +165,7 @@ class AppTriggerEnableApi(Resource):
@account_initialization_required
@edit_permission_required
@get_app_model(mode=AppMode.WORKFLOW)
@marshal_with(trigger_model)
@console_ns.response(200, "Success", console_ns.models[WorkflowTriggerResponse.__name__])
def post(self, app_model: App):
"""Update app trigger (enable/disable)"""
args = ParserEnable.model_validate(console_ns.payload)
@ -160,4 +196,4 @@ class AppTriggerEnableApi(Resource):
else:
trigger.icon = "" # type: ignore
return trigger
return WorkflowTriggerResponse.model_validate(trigger, from_attributes=True).model_dump(mode="json")

View File

@ -1,3 +1,5 @@
from typing import Any
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
@ -40,7 +42,7 @@ class ActivatePayload(BaseModel):
class ActivationCheckResponse(BaseModel):
is_valid: bool = Field(description="Whether token is valid")
data: dict | None = Field(default=None, description="Activation data if valid")
data: dict[str, Any] | None = Field(default=None, description="Activation data if valid")
class ActivationResponse(BaseModel):

View File

@ -1,7 +1,10 @@
import logging
import flask_login
from flask import make_response, request
from flask_restx import Resource
from pydantic import BaseModel, Field
from werkzeug.exceptions import Unauthorized
import services
from configs import dify_config
@ -42,12 +45,13 @@ from libs.token import (
)
from services.account_service import AccountService, InvitationDetailDict, RegisterService, TenantService
from services.billing_service import BillingService
from services.entities.auth_entities import LoginPayloadBase
from services.entities.auth_entities import LoginFailureReason, LoginPayloadBase
from services.errors.account import AccountRegisterError
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError
from services.feature_service import FeatureService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
logger = logging.getLogger(__name__)
class LoginPayload(LoginPayloadBase):
@ -91,10 +95,12 @@ class LoginApi(Resource):
normalized_email = request_email.lower()
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
_log_console_login_failure(email=normalized_email, reason=LoginFailureReason.ACCOUNT_IN_FREEZE)
raise AccountInFreezeError()
is_login_error_rate_limit = AccountService.is_login_error_rate_limit(normalized_email)
if is_login_error_rate_limit:
_log_console_login_failure(email=normalized_email, reason=LoginFailureReason.LOGIN_RATE_LIMITED)
raise EmailPasswordLoginLimitError()
invite_token = args.invite_token
@ -110,14 +116,20 @@ class LoginApi(Resource):
invitee_email = data.get("email") if data else None
invitee_email_normalized = invitee_email.lower() if isinstance(invitee_email, str) else invitee_email
if invitee_email_normalized != normalized_email:
_log_console_login_failure(
email=normalized_email,
reason=LoginFailureReason.INVALID_INVITATION_EMAIL,
)
raise InvalidEmailError()
account = _authenticate_account_with_case_fallback(
request_email, normalized_email, args.password, invite_token
)
except services.errors.account.AccountLoginError:
_log_console_login_failure(email=normalized_email, reason=LoginFailureReason.ACCOUNT_BANNED)
raise AccountBannedError()
except services.errors.account.AccountPasswordError as exc:
AccountService.add_login_error_rate_limit(normalized_email)
_log_console_login_failure(email=normalized_email, reason=LoginFailureReason.INVALID_CREDENTIALS)
raise AuthenticationFailedError() from exc
# SELF_HOSTED only have one workspace
tenants = TenantService.get_join_tenants(account)
@ -240,20 +252,27 @@ class EmailCodeLoginApi(Resource):
token_data = AccountService.get_email_code_login_data(args.token)
if token_data is None:
_log_console_login_failure(email=user_email, reason=LoginFailureReason.INVALID_EMAIL_CODE_TOKEN)
raise InvalidTokenError()
token_email = token_data.get("email")
normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
if normalized_token_email != user_email:
_log_console_login_failure(email=user_email, reason=LoginFailureReason.EMAIL_CODE_EMAIL_MISMATCH)
raise InvalidEmailError()
if token_data["code"] != args.code:
_log_console_login_failure(email=user_email, reason=LoginFailureReason.INVALID_EMAIL_CODE)
raise EmailCodeError()
AccountService.revoke_email_code_login_token(args.token)
try:
account = _get_account_with_case_fallback(original_email)
except Unauthorized as exc:
_log_console_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_BANNED)
raise AccountBannedError() from exc
except AccountRegisterError:
_log_console_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_IN_FREEZE)
raise AccountInFreezeError()
if account:
tenants = TenantService.get_join_tenants(account)
@ -279,6 +298,7 @@ class EmailCodeLoginApi(Resource):
except WorkSpaceNotAllowedCreateError:
raise NotAllowedCreateWorkspace()
except AccountRegisterError:
_log_console_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_IN_FREEZE)
raise AccountInFreezeError()
except WorkspacesLimitExceededError:
raise WorkspacesLimitExceeded()
@ -336,3 +356,12 @@ def _authenticate_account_with_case_fallback(
if original_email == normalized_email:
raise
return AccountService.authenticate(normalized_email, password, invite_token)
def _log_console_login_failure(*, email: str, reason: LoginFailureReason) -> None:
logger.warning(
"Console login failed: email=%s reason=%s ip_address=%s",
email,
reason,
extract_remote_ip(request),
)

View File

@ -1026,7 +1026,7 @@ class DocumentMetadataApi(DocumentResource):
if not isinstance(doc_metadata, dict):
raise ValueError("doc_metadata must be a dictionary.")
metadata_schema: dict = cast(dict, DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type])
metadata_schema: dict[str, Any] = cast(dict[str, Any], DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type])
document.doc_metadata = {}
if doc_type == "others":

View File

@ -1,13 +1,13 @@
from flask_restx import Resource, fields
from __future__ import annotations
from controllers.common.schema import register_schema_model
from fields.hit_testing_fields import (
child_chunk_fields,
document_fields,
files_fields,
hit_testing_record_fields,
segment_fields,
)
from datetime import datetime
from typing import Any
from flask_restx import Resource
from pydantic import Field, field_validator
from controllers.common.schema import register_schema_models
from fields.base import ResponseModel
from libs.login import login_required
from .. import console_ns
@ -18,39 +18,92 @@ from ..wraps import (
setup_required,
)
register_schema_model(console_ns, HitTestingPayload)
def _to_timestamp(value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
def _get_or_create_model(model_name: str, field_def):
"""Get or create a flask_restx model to avoid dict type issues in Swagger."""
existing = console_ns.models.get(model_name)
if existing is None:
existing = console_ns.model(model_name, field_def)
return existing
class HitTestingDocument(ResponseModel):
id: str | None = None
data_source_type: str | None = None
name: str | None = None
doc_type: str | None = None
doc_metadata: Any | None = None
# Register models for flask_restx to avoid dict type issues in Swagger
document_model = _get_or_create_model("HitTestingDocument", document_fields)
class HitTestingSegment(ResponseModel):
id: str | None = None
position: int | None = None
document_id: str | None = None
content: str | None = None
sign_content: str | None = None
answer: str | None = None
word_count: int | None = None
tokens: int | None = None
keywords: list[str] = Field(default_factory=list)
index_node_id: str | None = None
index_node_hash: str | None = None
hit_count: int | None = None
enabled: bool | None = None
disabled_at: int | None = None
disabled_by: str | None = None
status: str | None = None
created_by: str | None = None
created_at: int | None = None
indexing_at: int | None = None
completed_at: int | None = None
error: str | None = None
stopped_at: int | None = None
document: HitTestingDocument | None = None
segment_fields_copy = segment_fields.copy()
segment_fields_copy["document"] = fields.Nested(document_model)
segment_model = _get_or_create_model("HitTestingSegment", segment_fields_copy)
@field_validator("disabled_at", "created_at", "indexing_at", "completed_at", "stopped_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
child_chunk_model = _get_or_create_model("HitTestingChildChunk", child_chunk_fields)
files_model = _get_or_create_model("HitTestingFile", files_fields)
hit_testing_record_fields_copy = hit_testing_record_fields.copy()
hit_testing_record_fields_copy["segment"] = fields.Nested(segment_model)
hit_testing_record_fields_copy["child_chunks"] = fields.List(fields.Nested(child_chunk_model))
hit_testing_record_fields_copy["files"] = fields.List(fields.Nested(files_model))
hit_testing_record_model = _get_or_create_model("HitTestingRecord", hit_testing_record_fields_copy)
class HitTestingChildChunk(ResponseModel):
id: str | None = None
content: str | None = None
position: int | None = None
score: float | None = None
# Response model for hit testing API
hit_testing_response_fields = {
"query": fields.String,
"records": fields.List(fields.Nested(hit_testing_record_model)),
}
hit_testing_response_model = _get_or_create_model("HitTestingResponse", hit_testing_response_fields)
class HitTestingFile(ResponseModel):
id: str | None = None
name: str | None = None
size: int | None = None
extension: str | None = None
mime_type: str | None = None
source_url: str | None = None
class HitTestingRecord(ResponseModel):
segment: HitTestingSegment | None = None
child_chunks: list[HitTestingChildChunk] = Field(default_factory=list)
score: float | None = None
tsne_position: Any | None = None
files: list[HitTestingFile] = Field(default_factory=list)
summary: str | None = None
class HitTestingResponse(ResponseModel):
query: str
records: list[HitTestingRecord] = Field(default_factory=list)
register_schema_models(
console_ns,
HitTestingPayload,
HitTestingDocument,
HitTestingSegment,
HitTestingChildChunk,
HitTestingFile,
HitTestingRecord,
HitTestingResponse,
)
@console_ns.route("/datasets/<uuid:dataset_id>/hit-testing")
@ -59,7 +112,11 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
@console_ns.doc(description="Test dataset knowledge retrieval")
@console_ns.doc(params={"dataset_id": "Dataset ID"})
@console_ns.expect(console_ns.models[HitTestingPayload.__name__])
@console_ns.response(200, "Hit testing completed successfully", model=hit_testing_response_model)
@console_ns.response(
200,
"Hit testing completed successfully",
model=console_ns.models[HitTestingResponse.__name__],
)
@console_ns.response(404, "Dataset not found")
@console_ns.response(400, "Invalid parameters")
@setup_required
@ -74,4 +131,4 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
args = payload.model_dump(exclude_none=True)
self.hit_testing_args_check(args)
return self.perform_hit_testing(dataset, args)
return HitTestingResponse.model_validate(self.perform_hit_testing(dataset, args)).model_dump(mode="json")

View File

@ -1,21 +1,24 @@
import logging
from datetime import datetime
from typing import Any
from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from flask_restx import Resource
from graphon.file import helpers as file_helpers
from pydantic import BaseModel, Field, computed_field, field_validator
from sqlalchemy import and_, select
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
from controllers.common.schema import get_or_create_model
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.explore.wraps import InstalledAppResource
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from extensions.ext_database import db
from fields.installed_app_fields import app_fields, installed_app_fields, installed_app_list_fields
from fields.base import ResponseModel
from libs.datetime_utils import naive_utc_now
from libs.login import current_account_with_tenant, login_required
from models import App, InstalledApp, RecommendedApp
from models.model import IconType
from services.account_service import TenantService
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
@ -36,22 +39,97 @@ class InstalledAppsListQuery(BaseModel):
logger = logging.getLogger(__name__)
app_model = get_or_create_model("InstalledAppInfo", app_fields)
def _build_icon_url(icon_type: str | IconType | None, icon: str | None) -> str | None:
if icon is None or icon_type is None:
return None
icon_type_value = icon_type.value if isinstance(icon_type, IconType) else str(icon_type)
if icon_type_value.lower() != IconType.IMAGE:
return None
return file_helpers.get_signed_file_url(icon)
installed_app_fields_copy = installed_app_fields.copy()
installed_app_fields_copy["app"] = fields.Nested(app_model)
installed_app_model = get_or_create_model("InstalledApp", installed_app_fields_copy)
installed_app_list_fields_copy = installed_app_list_fields.copy()
installed_app_list_fields_copy["installed_apps"] = fields.List(fields.Nested(installed_app_model))
installed_app_list_model = get_or_create_model("InstalledAppList", installed_app_list_fields_copy)
def _safe_primitive(value: Any) -> Any:
if value is None or isinstance(value, (str, int, float, bool, datetime)):
return value
return None
class InstalledAppInfoResponse(ResponseModel):
id: str
name: str | None = None
mode: str | None = None
icon_type: str | None = None
icon: str | None = None
icon_background: str | None = None
use_icon_as_answer_icon: bool | None = None
@field_validator("mode", "icon_type", mode="before")
@classmethod
def _normalize_enum_like(cls, value: Any) -> str | None:
if value is None:
return None
if isinstance(value, str):
return value
return str(getattr(value, "value", value))
@computed_field(return_type=str | None) # type: ignore[prop-decorator]
@property
def icon_url(self) -> str | None:
return _build_icon_url(self.icon_type, self.icon)
class InstalledAppResponse(ResponseModel):
id: str
app: InstalledAppInfoResponse
app_owner_tenant_id: str
is_pinned: bool
last_used_at: int | None = None
editable: bool
uninstallable: bool
@field_validator("app", mode="before")
@classmethod
def _normalize_app(cls, value: Any) -> Any:
if isinstance(value, dict):
return value
return {
"id": _safe_primitive(getattr(value, "id", "")) or "",
"name": _safe_primitive(getattr(value, "name", None)),
"mode": _safe_primitive(getattr(value, "mode", None)),
"icon_type": _safe_primitive(getattr(value, "icon_type", None)),
"icon": _safe_primitive(getattr(value, "icon", None)),
"icon_background": _safe_primitive(getattr(value, "icon_background", None)),
"use_icon_as_answer_icon": _safe_primitive(getattr(value, "use_icon_as_answer_icon", None)),
}
@field_validator("last_used_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
class InstalledAppListResponse(ResponseModel):
installed_apps: list[InstalledAppResponse]
register_schema_models(
console_ns,
InstalledAppCreatePayload,
InstalledAppUpdatePayload,
InstalledAppsListQuery,
InstalledAppInfoResponse,
InstalledAppResponse,
InstalledAppListResponse,
)
@console_ns.route("/installed-apps")
class InstalledAppsListApi(Resource):
@login_required
@account_initialization_required
@marshal_with(installed_app_list_model)
@console_ns.response(200, "Success", console_ns.models[InstalledAppListResponse.__name__])
def get(self):
query = InstalledAppsListQuery.model_validate(request.args.to_dict())
current_user, current_tenant_id = current_account_with_tenant()
@ -125,7 +203,9 @@ class InstalledAppsListApi(Resource):
)
)
return {"installed_apps": installed_app_list}
return InstalledAppListResponse.model_validate(
{"installed_apps": installed_app_list}, from_attributes=True
).model_dump(mode="json")
@login_required
@account_initialization_required

View File

@ -1,66 +1,83 @@
from typing import Any
from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from flask_restx import Resource
from pydantic import BaseModel, Field, computed_field, field_validator
from constants.languages import languages
from controllers.common.schema import get_or_create_model
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required
from libs.helper import AppIconUrlField
from fields.base import ResponseModel
from libs.helper import build_icon_url
from libs.login import current_user, login_required
from services.recommended_app_service import RecommendedAppService
app_fields = {
"id": fields.String,
"name": fields.String,
"mode": fields.String,
"icon": fields.String,
"icon_type": fields.String,
"icon_url": AppIconUrlField,
"icon_background": fields.String,
}
app_model = get_or_create_model("RecommendedAppInfo", app_fields)
recommended_app_fields = {
"app": fields.Nested(app_model, attribute="app"),
"app_id": fields.String,
"description": fields.String(attribute="description"),
"copyright": fields.String,
"privacy_policy": fields.String,
"custom_disclaimer": fields.String,
"category": fields.String,
"position": fields.Integer,
"is_listed": fields.Boolean,
"can_trial": fields.Boolean,
}
recommended_app_model = get_or_create_model("RecommendedApp", recommended_app_fields)
recommended_app_list_fields = {
"recommended_apps": fields.List(fields.Nested(recommended_app_model)),
"categories": fields.List(fields.String),
}
recommended_app_list_model = get_or_create_model("RecommendedAppList", recommended_app_list_fields)
class RecommendedAppsQuery(BaseModel):
language: str | None = Field(default=None)
console_ns.schema_model(
RecommendedAppsQuery.__name__,
RecommendedAppsQuery.model_json_schema(ref_template="#/definitions/{model}"),
class RecommendedAppInfoResponse(ResponseModel):
id: str
name: str | None = None
mode: str | None = None
icon: str | None = None
icon_type: str | None = None
icon_background: str | None = None
@staticmethod
def _normalize_enum_like(value: Any) -> str | None:
if value is None:
return None
if isinstance(value, str):
return value
return str(getattr(value, "value", value))
@field_validator("mode", "icon_type", mode="before")
@classmethod
def _normalize_enum_fields(cls, value: Any) -> str | None:
return cls._normalize_enum_like(value)
@computed_field(return_type=str | None) # type: ignore[prop-decorator]
@property
def icon_url(self) -> str | None:
return build_icon_url(self.icon_type, self.icon)
class RecommendedAppResponse(ResponseModel):
app: RecommendedAppInfoResponse | None = None
app_id: str
description: str | None = None
copyright: str | None = None
privacy_policy: str | None = None
custom_disclaimer: str | None = None
category: str | None = None
position: int | None = None
is_listed: bool | None = None
can_trial: bool | None = None
class RecommendedAppListResponse(ResponseModel):
recommended_apps: list[RecommendedAppResponse]
categories: list[str]
register_schema_models(
console_ns,
RecommendedAppsQuery,
RecommendedAppInfoResponse,
RecommendedAppResponse,
RecommendedAppListResponse,
)
@console_ns.route("/explore/apps")
class RecommendedAppListApi(Resource):
@console_ns.expect(console_ns.models[RecommendedAppsQuery.__name__])
@console_ns.response(200, "Success", console_ns.models[RecommendedAppListResponse.__name__])
@login_required
@account_initialization_required
@marshal_with(recommended_app_list_model)
def get(self):
# language args
args = RecommendedAppsQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
@ -72,7 +89,10 @@ class RecommendedAppListApi(Resource):
else:
language_prefix = languages[0]
return RecommendedAppService.get_recommended_apps_and_categories(language_prefix)
return RecommendedAppListResponse.model_validate(
RecommendedAppService.get_recommended_apps_and_categories(language_prefix),
from_attributes=True,
).model_dump(mode="json")
@console_ns.route("/explore/apps/<uuid:app_id>")

View File

@ -169,6 +169,7 @@ console_ns.schema_model(
class TrialAppWorkflowRunApi(TrialAppResource):
@trial_feature_enable
@console_ns.expect(console_ns.models[WorkflowRunRequest.__name__])
def post(self, trial_app):
"""
@ -210,6 +211,7 @@ class TrialAppWorkflowRunApi(TrialAppResource):
class TrialAppWorkflowTaskStopApi(TrialAppResource):
@trial_feature_enable
def post(self, trial_app, task_id: str):
"""
Stop workflow task
@ -290,7 +292,6 @@ class TrialChatApi(TrialAppResource):
class TrialMessageSuggestedQuestionApi(TrialAppResource):
@trial_feature_enable
def get(self, trial_app, message_id):
app_model = trial_app
app_mode = AppMode.value_of(app_model.mode)
@ -470,7 +471,6 @@ class TrialCompletionApi(TrialAppResource):
class TrialSitApi(Resource):
"""Resource for trial app sites."""
@trial_feature_enable
@get_app_model_with_trial(None)
def get(self, app_model):
"""Retrieve app site info.
@ -492,7 +492,6 @@ class TrialSitApi(Resource):
class TrialAppParameterApi(Resource):
"""Resource for app variables."""
@trial_feature_enable
@get_app_model_with_trial(None)
def get(self, app_model):
"""Retrieve app parameters."""
@ -521,7 +520,6 @@ class TrialAppParameterApi(Resource):
class AppApi(Resource):
@trial_feature_enable
@get_app_model_with_trial(None)
@marshal_with(app_detail_with_site_model)
def get(self, app_model):
@ -534,7 +532,6 @@ class AppApi(Resource):
class AppWorkflowApi(Resource):
@trial_feature_enable
@get_app_model_with_trial(None)
@marshal_with(workflow_model)
def get(self, app_model):
@ -547,7 +544,6 @@ class AppWorkflowApi(Resource):
class DatasetListApi(Resource):
@trial_feature_enable
@get_app_model_with_trial(None)
def get(self, app_model):
page = request.args.get("page", default=1, type=int)

View File

@ -1,15 +1,18 @@
from datetime import datetime
from typing import Any
from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from flask_restx import Resource
from pydantic import BaseModel, Field, TypeAdapter, field_validator
from constants import HIDDEN_VALUE
from fields.api_based_extension_fields import api_based_extension_fields
from fields.base import ResponseModel
from libs.login import current_account_with_tenant, login_required
from models.api_based_extension import APIBasedExtension
from services.api_based_extension_service import APIBasedExtensionService
from services.code_based_extension_service import CodeBasedExtensionService
from ..common.schema import register_schema_models
from ..common.schema import DEFAULT_REF_TEMPLATE_SWAGGER_2_0, register_schema_models
from . import console_ns
from .wraps import account_initialization_required, setup_required
@ -24,12 +27,52 @@ class APIBasedExtensionPayload(BaseModel):
api_key: str = Field(description="API key for authentication")
register_schema_models(console_ns, APIBasedExtensionPayload)
class CodeBasedExtensionResponse(ResponseModel):
module: str = Field(description="Module name")
data: Any = Field(description="Extension data")
api_based_extension_model = console_ns.model("ApiBasedExtensionModel", api_based_extension_fields)
def _mask_api_key(api_key: str) -> str:
if not api_key:
return api_key
if len(api_key) <= 8:
return api_key[0] + "******" + api_key[-1]
return api_key[:3] + "******" + api_key[-3:]
api_based_extension_list_model = fields.List(fields.Nested(api_based_extension_model))
def _to_timestamp(value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
class APIBasedExtensionResponse(ResponseModel):
id: str
name: str
api_endpoint: str
api_key: str
created_at: int | None = None
@field_validator("api_key", mode="before")
@classmethod
def _normalize_api_key(cls, value: str) -> str:
return _mask_api_key(value)
@field_validator("created_at", mode="before")
@classmethod
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
register_schema_models(console_ns, APIBasedExtensionPayload, CodeBasedExtensionResponse, APIBasedExtensionResponse)
console_ns.schema_model(
"APIBasedExtensionListResponse",
TypeAdapter(list[APIBasedExtensionResponse]).json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
def _serialize_api_based_extension(extension: APIBasedExtension) -> dict[str, Any]:
return APIBasedExtensionResponse.model_validate(extension, from_attributes=True).model_dump(mode="json")
@console_ns.route("/code-based-extension")
@ -40,10 +83,7 @@ class CodeBasedExtensionAPI(Resource):
@console_ns.response(
200,
"Success",
console_ns.model(
"CodeBasedExtensionResponse",
{"module": fields.String(description="Module name"), "data": fields.Raw(description="Extension data")},
),
console_ns.models[CodeBasedExtensionResponse.__name__],
)
@setup_required
@login_required
@ -51,30 +91,34 @@ class CodeBasedExtensionAPI(Resource):
def get(self):
query = CodeBasedExtensionQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
return {"module": query.module, "data": CodeBasedExtensionService.get_code_based_extension(query.module)}
return CodeBasedExtensionResponse(
module=query.module,
data=CodeBasedExtensionService.get_code_based_extension(query.module),
).model_dump(mode="json")
@console_ns.route("/api-based-extension")
class APIBasedExtensionAPI(Resource):
@console_ns.doc("get_api_based_extensions")
@console_ns.doc(description="Get all API-based extensions for current tenant")
@console_ns.response(200, "Success", api_based_extension_list_model)
@console_ns.response(200, "Success", console_ns.models["APIBasedExtensionListResponse"])
@setup_required
@login_required
@account_initialization_required
@marshal_with(api_based_extension_model)
def get(self):
_, tenant_id = current_account_with_tenant()
return APIBasedExtensionService.get_all_by_tenant_id(tenant_id)
return [
_serialize_api_based_extension(extension)
for extension in APIBasedExtensionService.get_all_by_tenant_id(tenant_id)
]
@console_ns.doc("create_api_based_extension")
@console_ns.doc(description="Create a new API-based extension")
@console_ns.expect(console_ns.models[APIBasedExtensionPayload.__name__])
@console_ns.response(201, "Extension created successfully", api_based_extension_model)
@console_ns.response(201, "Extension created successfully", console_ns.models[APIBasedExtensionResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@marshal_with(api_based_extension_model)
def post(self):
payload = APIBasedExtensionPayload.model_validate(console_ns.payload or {})
_, current_tenant_id = current_account_with_tenant()
@ -86,7 +130,7 @@ class APIBasedExtensionAPI(Resource):
api_key=payload.api_key,
)
return APIBasedExtensionService.save(extension_data)
return _serialize_api_based_extension(APIBasedExtensionService.save(extension_data))
@console_ns.route("/api-based-extension/<uuid:id>")
@ -94,26 +138,26 @@ class APIBasedExtensionDetailAPI(Resource):
@console_ns.doc("get_api_based_extension")
@console_ns.doc(description="Get API-based extension by ID")
@console_ns.doc(params={"id": "Extension ID"})
@console_ns.response(200, "Success", api_based_extension_model)
@console_ns.response(200, "Success", console_ns.models[APIBasedExtensionResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@marshal_with(api_based_extension_model)
def get(self, id):
api_based_extension_id = str(id)
_, tenant_id = current_account_with_tenant()
return APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
return _serialize_api_based_extension(
APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
)
@console_ns.doc("update_api_based_extension")
@console_ns.doc(description="Update API-based extension")
@console_ns.doc(params={"id": "Extension ID"})
@console_ns.expect(console_ns.models[APIBasedExtensionPayload.__name__])
@console_ns.response(200, "Extension updated successfully", api_based_extension_model)
@console_ns.response(200, "Extension updated successfully", console_ns.models[APIBasedExtensionResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@marshal_with(api_based_extension_model)
def post(self, id):
api_based_extension_id = str(id)
_, current_tenant_id = current_account_with_tenant()
@ -128,7 +172,7 @@ class APIBasedExtensionDetailAPI(Resource):
if payload.api_key != HIDDEN_VALUE:
extension_data_from_db.api_key = payload.api_key
return APIBasedExtensionService.save(extension_data_from_db)
return _serialize_api_based_extension(APIBasedExtensionService.save(extension_data_from_db))
@console_ns.doc("delete_api_based_extension")
@console_ns.doc(description="Delete API-based extension")

View File

@ -1,3 +1,4 @@
from collections.abc import Mapping
from typing import TypedDict
from flask import request
@ -13,6 +14,14 @@ from services.billing_service import BillingService
_FALLBACK_LANG = "en-US"
class NotificationLangContent(TypedDict, total=False):
lang: str
title: str
subtitle: str
body: str
titlePicUrl: str
class NotificationItemDict(TypedDict):
notification_id: str | None
frequency: str | None
@ -28,9 +37,11 @@ class NotificationResponseDict(TypedDict):
notifications: list[NotificationItemDict]
def _pick_lang_content(contents: dict, lang: str) -> dict:
def _pick_lang_content(contents: Mapping[str, NotificationLangContent], lang: str) -> NotificationLangContent:
"""Return the single LangContent for *lang*, falling back to English."""
return contents.get(lang) or contents.get(_FALLBACK_LANG) or next(iter(contents.values()), {})
return (
contents.get(lang) or contents.get(_FALLBACK_LANG) or next(iter(contents.values()), NotificationLangContent())
)
class DismissNotificationPayload(BaseModel):
@ -71,7 +82,7 @@ class NotificationApi(Resource):
notifications: list[NotificationItemDict] = []
for notification in result.get("notifications") or []:
contents: dict = notification.get("contents") or {}
contents: Mapping[str, NotificationLangContent] = notification.get("contents") or {}
lang_content = _pick_lang_content(contents, lang)
item: NotificationItemDict = {
"notification_id": notification.get("notificationId"),

View File

@ -1,119 +0,0 @@
import logging
from collections.abc import Callable
from typing import cast
from flask import Request as FlaskRequest
from extensions.ext_socketio import sio
from libs.passport import PassportService
from libs.token import extract_access_token
from repositories.workflow_collaboration_repository import WorkflowCollaborationRepository
from services.account_service import AccountService
from services.workflow_collaboration_service import WorkflowCollaborationService
repository = WorkflowCollaborationRepository()
collaboration_service = WorkflowCollaborationService(repository, sio)
def _sio_on(event: str) -> Callable[[Callable[..., object]], Callable[..., object]]:
return cast(Callable[[Callable[..., object]], Callable[..., object]], sio.on(event))
@_sio_on("connect")
def socket_connect(sid, environ, auth):
"""
WebSocket connect event, do authentication here.
"""
try:
request_environ = FlaskRequest(environ)
token = extract_access_token(request_environ)
except Exception:
logging.exception("Failed to extract token")
token = None
if not token:
logging.warning("Socket connect rejected: missing token (sid=%s)", sid)
return False
try:
decoded = PassportService().verify(token)
user_id = decoded.get("user_id")
if not user_id:
logging.warning("Socket connect rejected: missing user_id (sid=%s)", sid)
return False
with sio.app.app_context():
user = AccountService.load_logged_in_account(account_id=user_id)
if not user:
logging.warning("Socket connect rejected: user not found (user_id=%s, sid=%s)", user_id, sid)
return False
if not user.has_edit_permission:
logging.warning("Socket connect rejected: no edit permission (user_id=%s, sid=%s)", user_id, sid)
return False
collaboration_service.save_session(sid, user)
return True
except Exception:
logging.exception("Socket authentication failed")
return False
@_sio_on("user_connect")
def handle_user_connect(sid, data):
"""
Handle user connect event. Each session (tab) is treated as an independent collaborator.
"""
workflow_id = data.get("workflow_id")
if not workflow_id:
return {"msg": "workflow_id is required"}, 400
result = collaboration_service.register_session(workflow_id, sid)
if not result:
return {"msg": "unauthorized"}, 401
user_id, is_leader = result
return {"msg": "connected", "user_id": user_id, "sid": sid, "isLeader": is_leader}
@_sio_on("disconnect")
def handle_disconnect(sid):
"""
Handle session disconnect event. Remove the specific session from online users.
"""
collaboration_service.disconnect_session(sid)
@_sio_on("collaboration_event")
def handle_collaboration_event(sid, data):
"""
Handle general collaboration events, include:
1. mouse_move
2. vars_and_features_update
3. sync_request (ask leader to update graph)
4. app_state_update
5. mcp_server_update
6. workflow_update
7. comments_update
8. node_panel_presence
9. skill_file_active
10. skill_sync_request
11. skill_resync_request
"""
return collaboration_service.relay_collaboration_event(sid, data)
@_sio_on("graph_event")
def handle_graph_event(sid, data):
"""
Handle graph events - simple broadcast relay.
"""
return collaboration_service.relay_graph_event(sid, data)
@_sio_on("skill_event")
def handle_skill_event(sid, data):
"""
Handle skill events - simple broadcast relay.
"""
return collaboration_service.relay_skill_event(sid, data)

View File

@ -1,13 +1,14 @@
from typing import Literal
from flask import request
from flask_restx import Namespace, Resource, fields, marshal_with
from pydantic import BaseModel, Field
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import Forbidden
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from fields.base import ResponseModel
from libs.login import current_account_with_tenant, login_required
from models.enums import TagType
from services.tag_service import (
@ -18,17 +19,6 @@ from services.tag_service import (
UpdateTagPayload,
)
dataset_tag_fields = {
"id": fields.String,
"name": fields.String,
"type": fields.String,
"binding_count": fields.String,
}
def build_dataset_tag_fields(api_or_ns: Namespace):
return api_or_ns.model("DataSetTag", dataset_tag_fields)
class TagBasePayload(BaseModel):
name: str = Field(description="Tag name", min_length=1, max_length=50)
@ -52,12 +42,36 @@ class TagListQueryParam(BaseModel):
keyword: str | None = Field(None, description="Search keyword")
class TagResponse(ResponseModel):
id: str
name: str
type: str | None = None
binding_count: str | None = None
@field_validator("type", mode="before")
@classmethod
def normalize_type(cls, value: TagType | str | None) -> str | None:
if value is None:
return None
if isinstance(value, TagType):
return value.value
return value
@field_validator("binding_count", mode="before")
@classmethod
def normalize_binding_count(cls, value: int | str | None) -> str | None:
if value is None:
return None
return str(value)
register_schema_models(
console_ns,
TagBasePayload,
TagBindingPayload,
TagBindingRemovePayload,
TagListQueryParam,
TagResponse,
)
@ -69,14 +83,18 @@ class TagListApi(Resource):
@console_ns.doc(
params={"type": 'Tag type filter. Can be "knowledge" or "app".', "keyword": "Search keyword for tag name."}
)
@marshal_with(dataset_tag_fields)
@console_ns.doc(responses={200: ("Success", [console_ns.models[TagResponse.__name__]])})
def get(self):
_, current_tenant_id = current_account_with_tenant()
raw_args = request.args.to_dict()
param = TagListQueryParam.model_validate(raw_args)
tags = TagService.get_tags(param.type, current_tenant_id, param.keyword)
return tags, 200
serialized_tags = [
TagResponse.model_validate(tag, from_attributes=True).model_dump(mode="json") for tag in tags
]
return serialized_tags, 200
@console_ns.expect(console_ns.models[TagBasePayload.__name__])
@setup_required
@ -91,7 +109,9 @@ class TagListApi(Resource):
payload = TagBasePayload.model_validate(console_ns.payload or {})
tag = TagService.save_tags(SaveTagPayload(name=payload.name, type=payload.type))
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
response = TagResponse.model_validate(
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
).model_dump(mode="json")
return response, 200
@ -114,7 +134,9 @@ class TagUpdateDeleteApi(Resource):
binding_count = TagService.get_tag_binding_count(tag_id)
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
response = TagResponse.model_validate(
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
).model_dump(mode="json")
return response, 200

View File

@ -35,22 +35,24 @@ def plugin_permission_required(
return view(*args, **kwargs)
if install_required:
if permission.install_permission == TenantPluginPermission.InstallPermission.NOBODY:
raise Forbidden()
if permission.install_permission == TenantPluginPermission.InstallPermission.ADMINS:
if not user.is_admin_or_owner:
match permission.install_permission:
case TenantPluginPermission.InstallPermission.NOBODY:
raise Forbidden()
if permission.install_permission == TenantPluginPermission.InstallPermission.EVERYONE:
pass
case TenantPluginPermission.InstallPermission.ADMINS:
if not user.is_admin_or_owner:
raise Forbidden()
case TenantPluginPermission.InstallPermission.EVERYONE:
pass
if debug_required:
if permission.debug_permission == TenantPluginPermission.DebugPermission.NOBODY:
raise Forbidden()
if permission.debug_permission == TenantPluginPermission.DebugPermission.ADMINS:
if not user.is_admin_or_owner:
match permission.debug_permission:
case TenantPluginPermission.DebugPermission.NOBODY:
raise Forbidden()
if permission.debug_permission == TenantPluginPermission.DebugPermission.EVERYONE:
pass
case TenantPluginPermission.DebugPermission.ADMINS:
if not user.is_admin_or_owner:
raise Forbidden()
case TenantPluginPermission.DebugPermission.EVERYONE:
pass
return view(*args, **kwargs)

View File

@ -1,11 +1,11 @@
from __future__ import annotations
from datetime import datetime
from typing import Literal
from typing import Any, Literal
import pytz
from flask import request
from flask_restx import Resource, fields, marshal_with
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator, model_validator
from sqlalchemy import select
@ -37,9 +37,10 @@ from controllers.console.wraps import (
setup_required,
)
from extensions.ext_database import db
from fields.base import ResponseModel
from fields.member_fields import Account as AccountResponse
from libs.datetime_utils import naive_utc_now
from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone
from libs.helper import EmailStr, extract_remote_ip, timezone
from libs.login import current_account_with_tenant, login_required
from models import AccountIntegrate, InvitationCode
from models.account import AccountStatus, InvitationCodeStatus
@ -174,21 +175,61 @@ reg(CheckEmailUniquePayload)
register_schema_models(console_ns, AccountResponse)
def _serialize_account(account) -> dict:
def _serialize_account(account) -> dict[str, Any]:
return AccountResponse.model_validate(account, from_attributes=True).model_dump(mode="json")
integrate_fields = {
"provider": fields.String,
"created_at": TimestampField,
"is_bound": fields.Boolean,
"link": fields.String,
}
def _to_timestamp(value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
integrate_model = console_ns.model("AccountIntegrate", integrate_fields)
integrate_list_model = console_ns.model(
"AccountIntegrateList",
{"data": fields.List(fields.Nested(integrate_model))},
class AccountIntegrateResponse(ResponseModel):
provider: str
created_at: int | None = None
is_bound: bool
link: str | None = None
@field_validator("created_at", mode="before")
@classmethod
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
class AccountIntegrateListResponse(ResponseModel):
data: list[AccountIntegrateResponse]
class EducationVerifyResponse(ResponseModel):
token: str | None = None
class EducationStatusResponse(ResponseModel):
result: bool | None = None
is_student: bool | None = None
expire_at: int | None = None
allow_refresh: bool | None = None
@field_validator("expire_at", mode="before")
@classmethod
def _normalize_expire_at(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
class EducationAutocompleteResponse(ResponseModel):
data: list[str] = Field(default_factory=list)
curr_page: int | None = None
has_next: bool | None = None
register_schema_models(
console_ns,
AccountIntegrateResponse,
AccountIntegrateListResponse,
EducationVerifyResponse,
EducationStatusResponse,
EducationAutocompleteResponse,
)
@ -359,7 +400,7 @@ class AccountIntegrateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(integrate_list_model)
@console_ns.response(200, "Success", console_ns.models[AccountIntegrateListResponse.__name__])
def get(self):
account, _ = current_account_with_tenant()
@ -395,7 +436,9 @@ class AccountIntegrateApi(Resource):
}
)
return {"data": integrate_data}
return AccountIntegrateListResponse(
data=[AccountIntegrateResponse.model_validate(item) for item in integrate_data]
).model_dump(mode="json")
@console_ns.route("/account/delete/verify")
@ -447,31 +490,22 @@ class AccountDeleteUpdateFeedbackApi(Resource):
@console_ns.route("/account/education/verify")
class EducationVerifyApi(Resource):
verify_fields = {
"token": fields.String,
}
@setup_required
@login_required
@account_initialization_required
@only_edition_cloud
@cloud_edition_billing_enabled
@marshal_with(verify_fields)
@console_ns.response(200, "Success", console_ns.models[EducationVerifyResponse.__name__])
def get(self):
account, _ = current_account_with_tenant()
return BillingService.EducationIdentity.verify(account.id, account.email)
return EducationVerifyResponse.model_validate(
BillingService.EducationIdentity.verify(account.id, account.email) or {}
).model_dump(mode="json")
@console_ns.route("/account/education")
class EducationApi(Resource):
status_fields = {
"result": fields.Boolean,
"is_student": fields.Boolean,
"expire_at": TimestampField,
"allow_refresh": fields.Boolean,
}
@console_ns.expect(console_ns.models[EducationActivatePayload.__name__])
@setup_required
@login_required
@ -491,37 +525,33 @@ class EducationApi(Resource):
@account_initialization_required
@only_edition_cloud
@cloud_edition_billing_enabled
@marshal_with(status_fields)
@console_ns.response(200, "Success", console_ns.models[EducationStatusResponse.__name__])
def get(self):
account, _ = current_account_with_tenant()
res = BillingService.EducationIdentity.status(account.id)
res = BillingService.EducationIdentity.status(account.id) or {}
# convert expire_at to UTC timestamp from isoformat
if res and "expire_at" in res:
res["expire_at"] = datetime.fromisoformat(res["expire_at"]).astimezone(pytz.utc)
return res
return EducationStatusResponse.model_validate(res).model_dump(mode="json")
@console_ns.route("/account/education/autocomplete")
class EducationAutoCompleteApi(Resource):
data_fields = {
"data": fields.List(fields.String),
"curr_page": fields.Integer,
"has_next": fields.Boolean,
}
@console_ns.expect(console_ns.models[EducationAutocompleteQuery.__name__])
@setup_required
@login_required
@account_initialization_required
@only_edition_cloud
@cloud_edition_billing_enabled
@marshal_with(data_fields)
@console_ns.response(200, "Success", console_ns.models[EducationAutocompleteResponse.__name__])
def get(self):
payload = request.args.to_dict(flat=True)
args = EducationAutocompleteQuery.model_validate(payload)
return BillingService.EducationIdentity.autocomplete(args.keywords, args.page, args.limit)
return EducationAutocompleteResponse.model_validate(
BillingService.EducationIdentity.autocomplete(args.keywords, args.page, args.limit) or {}
).model_dump(mode="json")
@console_ns.route("/account/change-email")

View File

@ -1,67 +0,0 @@
import json
import httpx
import yaml
from flask import request
from flask_restx import Resource
from pydantic import BaseModel
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from core.plugin.impl.exc import PluginPermissionDeniedError
from extensions.ext_database import db
from libs.login import current_account_with_tenant, login_required
from models.model import App
from models.workflow import Workflow
from services.app_dsl_service import AppDslService
class DSLPredictRequest(BaseModel):
app_id: str
current_node_id: str
@console_ns.route("/workspaces/current/dsl/predict")
class DSLPredictApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
user, _ = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
args = DSLPredictRequest.model_validate(request.get_json())
app_id: str = args.app_id
current_node_id: str = args.current_node_id
with Session(db.engine) as session:
app = session.query(App).filter_by(id=app_id).first()
workflow = session.query(Workflow).filter_by(app_id=app_id, version=Workflow.VERSION_DRAFT).first()
if not app:
raise ValueError("App not found")
if not workflow:
raise ValueError("Workflow not found")
try:
i = 0
for node_id, _ in workflow.walk_nodes():
if node_id == current_node_id:
break
i += 1
dsl = yaml.safe_load(AppDslService.export_dsl(app_model=app))
response = httpx.post(
"http://spark-832c:8000/predict",
json={"graph_data": dsl, "source_node_index": i},
)
return {
"nodes": json.loads(response.json()),
}
except PluginPermissionDeniedError as e:
raise ValueError(e.description) from e

View File

@ -465,7 +465,7 @@ class ModelProviderModelDisableApi(Resource):
class ParserValidate(BaseModel):
model: str
model_type: ModelType
credentials: dict
credentials: dict[str, Any]
console_ns.schema_model(

View File

@ -1,8 +1,9 @@
import logging
from datetime import datetime
from flask import request
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field
from flask_restx import Resource, fields, marshal
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from werkzeug.exceptions import Unauthorized
@ -26,6 +27,7 @@ from controllers.console.wraps import (
)
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from fields.base import ResponseModel
from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
from models.account import Tenant, TenantCustomConfigDict, TenantStatus
@ -58,6 +60,37 @@ class WorkspaceInfoPayload(BaseModel):
name: str
class TenantInfoResponse(ResponseModel):
id: str
name: str | None = None
plan: str | None = None
status: str | None = None
created_at: int | None = None
role: str | None = None
in_trial: bool | None = None
trial_end_reason: str | None = None
custom_config: dict | None = None
trial_credits: int | None = None
trial_credits_used: int | None = None
next_credit_reset_date: int | None = None
@field_validator("plan", "status", "trial_end_reason", mode="before")
@classmethod
def _normalize_enum_like(cls, value):
if value is None:
return None
if isinstance(value, str):
return value
return str(getattr(value, "value", value))
@field_validator("created_at", mode="before")
@classmethod
def _normalize_created_at(cls, value: datetime | int | None):
if isinstance(value, datetime):
return int(value.timestamp())
return value
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
@ -66,6 +99,7 @@ reg(WorkspaceListQuery)
reg(SwitchWorkspacePayload)
reg(WorkspaceCustomConfigPayload)
reg(WorkspaceInfoPayload)
reg(TenantInfoResponse)
provider_fields = {
"provider_name": fields.String,
@ -180,7 +214,7 @@ class TenantApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(tenant_fields)
@console_ns.response(200, "Success", console_ns.models[TenantInfoResponse.__name__])
def post(self):
if request.path == "/info":
logger.warning("Deprecated URL /info was used.")
@ -200,7 +234,13 @@ class TenantApi(Resource):
else:
raise Unauthorized("workspace is archived")
return WorkspaceService.get_tenant_info(tenant), 200
return (
TenantInfoResponse.model_validate(
WorkspaceService.get_tenant_info(tenant),
from_attributes=True,
).model_dump(mode="json"),
200,
)
@console_ns.route("/workspaces/switch")

View File

@ -20,7 +20,7 @@ from models.account import AccountStatus
from models.dataset import RateLimitLog
from models.model import DifySetup
from services.feature_service import FeatureService, LicenseStatus
from services.operation_service import OperationService
from services.operation_service import OperationService, UtmInfo
from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout
@ -205,7 +205,7 @@ def cloud_utm_record[**P, R](view: Callable[P, R]) -> Callable[P, R]:
utm_info = request.cookies.get("utm_info")
if utm_info:
utm_info_dict: dict = json.loads(utm_info)
utm_info_dict: UtmInfo = json.loads(utm_info)
OperationService.record_utm(current_tenant_id, utm_info_dict)
return view(*args, **kwargs)

View File

@ -1,80 +0,0 @@
"""Token-based file proxy controller for storage operations.
This controller handles file download and upload operations using opaque UUID tokens.
The token maps to the real storage key in Redis, so the actual storage path is never
exposed in the URL.
Routes:
GET /files/storage-files/{token} - Download a file
PUT /files/storage-files/{token} - Upload a file
The operation type (download/upload) is determined by the ticket stored in Redis,
not by the HTTP method. This ensures a download ticket cannot be used for upload
and vice versa.
"""
from urllib.parse import quote
from flask import Response, request
from flask_restx import Resource
from werkzeug.exceptions import Forbidden, NotFound, RequestEntityTooLarge
from controllers.files import files_ns
from extensions.ext_storage import storage
from services.storage_ticket_service import StorageTicketService
@files_ns.route("/storage-files/<string:token>")
class StorageFilesApi(Resource):
"""Handle file operations through token-based URLs."""
def get(self, token: str):
"""Download a file using a token.
The ticket must have op="download", otherwise returns 403.
"""
ticket = StorageTicketService.get_ticket(token)
if ticket is None:
raise Forbidden("Invalid or expired token")
if ticket.op != "download":
raise Forbidden("This token is not valid for download")
try:
generator = storage.load_stream(ticket.storage_key)
except FileNotFoundError:
raise NotFound("File not found")
filename = ticket.filename or ticket.storage_key.rsplit("/", 1)[-1]
encoded_filename = quote(filename)
return Response(
generator,
mimetype="application/octet-stream",
direct_passthrough=True,
headers={
"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}",
},
)
def put(self, token: str):
"""Upload a file using a token.
The ticket must have op="upload", otherwise returns 403.
If the request body exceeds max_bytes, returns 413.
"""
ticket = StorageTicketService.get_ticket(token)
if ticket is None:
raise Forbidden("Invalid or expired token")
if ticket.op != "upload":
raise Forbidden("This token is not valid for upload")
content = request.get_data()
if ticket.max_bytes is not None and len(content) > ticket.max_bytes:
raise RequestEntityTooLarge(f"Upload exceeds maximum size of {ticket.max_bytes} bytes")
storage.save(ticket.storage_key, content)
return Response(status=204)

View File

@ -9,7 +9,7 @@ from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session
from controllers.common.schema import register_schema_model
from controllers.console.wraps import setup_required
@ -56,7 +56,7 @@ class EnterpriseAppDSLImport(Resource):
account.set_tenant_id(workspace_id)
with sessionmaker(db.engine).begin() as session:
with Session(db.engine, expire_on_commit=False) as session:
dsl_service = AppDslService(session)
result = dsl_service.import_app(
account=account,
@ -65,6 +65,10 @@ class EnterpriseAppDSLImport(Resource):
name=args.name,
description=args.description,
)
if result.status == ImportStatus.FAILED:
session.rollback()
else:
session.commit()
if result.status == ImportStatus.FAILED:
return result.model_dump(mode="json"), 400

View File

@ -2,7 +2,7 @@ from typing import Any, Union
from flask import Response
from flask_restx import Resource
from graphon.variables.input_entities import VariableEntity
from graphon.variables.input_entities import VariableEntity, VariableEntityType
from pydantic import BaseModel, Field, ValidationError
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
@ -158,14 +158,20 @@ class MCPAppApi(Resource):
except ValidationError as e:
raise MCPRequestError(mcp_types.INVALID_PARAMS, f"Invalid user_input_form: {str(e)}")
def _convert_user_input_form(self, raw_form: list[dict]) -> list[VariableEntity]:
def _convert_user_input_form(self, raw_form: list[dict[str, Any]]) -> list[VariableEntity]:
"""Convert raw user input form to VariableEntity objects"""
return [self._create_variable_entity(item) for item in raw_form]
def _create_variable_entity(self, item: dict) -> VariableEntity:
def _create_variable_entity(self, item: dict[str, Any]) -> VariableEntity:
"""Create a single VariableEntity from raw form item"""
variable_type = item.get("type", "") or list(item.keys())[0]
variable = item[variable_type]
variable_type_raw: str = item.get("type", "") or list(item.keys())[0]
try:
variable_type = VariableEntityType(variable_type_raw)
except ValueError as e:
raise MCPRequestError(
mcp_types.INVALID_PARAMS, f"Invalid user_input_form variable type: {variable_type_raw}"
) from e
variable = item[variable_type_raw]
return VariableEntity(
type=variable_type,
@ -178,7 +184,7 @@ class MCPAppApi(Resource):
json_schema=variable.get("json_schema"),
)
def _parse_mcp_request(self, args: dict) -> mcp_types.ClientRequest | mcp_types.ClientNotification:
def _parse_mcp_request(self, args: dict[str, Any]) -> mcp_types.ClientRequest | mcp_types.ClientNotification:
"""Parse and validate MCP request"""
try:
return mcp_types.ClientRequest.model_validate(args)

View File

@ -194,7 +194,7 @@ class ChatApi(Resource):
Supports conversation management and both blocking and streaming response modes.
"""
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
payload = ChatRequestPayload.model_validate(service_api_ns.payload or {})
@ -258,7 +258,7 @@ class ChatStopApi(Resource):
def post(self, app_model: App, end_user: EndUser, task_id: str):
"""Stop a running chat message generation."""
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
AppTaskService.stop_task(

View File

@ -1,7 +1,9 @@
from datetime import datetime
from typing import Any, Literal
from flask import request
from flask_restx import Resource
from graphon.variables.types import SegmentType
from pydantic import BaseModel, Field, TypeAdapter, field_validator
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import BadRequest, NotFound
@ -14,14 +16,12 @@ from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from fields._value_type_serializer import serialize_value_type
from fields.base import ResponseModel
from fields.conversation_fields import (
ConversationInfiniteScrollPagination,
SimpleConversation,
)
from fields.conversation_variable_fields import (
build_conversation_variable_infinite_scroll_pagination_model,
build_conversation_variable_model,
)
from libs.helper import UUIDStrOrEmpty
from models.model import App, AppMode, EndUser
from services.conversation_service import ConversationService
@ -70,12 +70,70 @@ class ConversationVariableUpdatePayload(BaseModel):
value: Any
class ConversationVariableResponse(ResponseModel):
id: str
name: str
value_type: str
value: str | None = None
description: str | None = None
created_at: int | None = None
updated_at: int | None = None
@field_validator("value_type", mode="before")
@classmethod
def normalize_value_type(cls, value: Any) -> str:
exposed_type = getattr(value, "exposed_type", None)
if callable(exposed_type):
return str(exposed_type().value)
if isinstance(value, str):
try:
return str(SegmentType(value).exposed_type().value)
except ValueError:
return value
try:
return serialize_value_type(value)
except (AttributeError, TypeError, ValueError):
pass
try:
return serialize_value_type({"value_type": value})
except (AttributeError, TypeError, ValueError):
value_attr = getattr(value, "value", None)
if value_attr is not None:
return str(value_attr)
return str(value)
@field_validator("value", mode="before")
@classmethod
def normalize_value(cls, value: Any | None) -> str | None:
if value is None:
return None
if isinstance(value, str):
return value
return str(value)
@field_validator("created_at", "updated_at", mode="before")
@classmethod
def normalize_timestamp(cls, value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
class ConversationVariableInfiniteScrollPaginationResponse(ResponseModel):
limit: int
has_more: bool
data: list[ConversationVariableResponse]
register_schema_models(
service_api_ns,
ConversationListQuery,
ConversationRenamePayload,
ConversationVariablesQuery,
ConversationVariableUpdatePayload,
ConversationVariableResponse,
ConversationVariableInfiniteScrollPaginationResponse,
)
@ -98,7 +156,7 @@ class ConversationApi(Resource):
Supports pagination using last_id and limit parameters.
"""
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
query_args = ConversationListQuery.model_validate(request.args.to_dict())
@ -142,7 +200,7 @@ class ConversationDetailApi(Resource):
def delete(self, app_model: App, end_user: EndUser, c_id):
"""Delete a specific conversation."""
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
conversation_id = str(c_id)
@ -171,7 +229,7 @@ class ConversationRenameApi(Resource):
def post(self, app_model: App, end_user: EndUser, c_id):
"""Rename a conversation or auto-generate a name."""
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
conversation_id = str(c_id)
@ -204,8 +262,12 @@ class ConversationVariablesApi(Resource):
404: "Conversation not found",
}
)
@service_api_ns.response(
200,
"Variables retrieved successfully",
service_api_ns.models[ConversationVariableInfiniteScrollPaginationResponse.__name__],
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
@service_api_ns.marshal_with(build_conversation_variable_infinite_scroll_pagination_model(service_api_ns))
def get(self, app_model: App, end_user: EndUser, c_id):
"""List all variables for a conversation.
@ -213,7 +275,7 @@ class ConversationVariablesApi(Resource):
"""
# conversational variable only for chat app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
conversation_id = str(c_id)
@ -222,9 +284,12 @@ class ConversationVariablesApi(Resource):
last_id = str(query_args.last_id) if query_args.last_id else None
try:
return ConversationService.get_conversational_variable(
pagination = ConversationService.get_conversational_variable(
app_model, conversation_id, end_user, query_args.limit, last_id, query_args.variable_name
)
return ConversationVariableInfiniteScrollPaginationResponse.model_validate(
pagination, from_attributes=True
).model_dump(mode="json")
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
@ -243,8 +308,12 @@ class ConversationVariableDetailApi(Resource):
404: "Conversation or variable not found",
}
)
@service_api_ns.response(
200,
"Variable updated successfully",
service_api_ns.models[ConversationVariableResponse.__name__],
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
@service_api_ns.marshal_with(build_conversation_variable_model(service_api_ns))
def put(self, app_model: App, end_user: EndUser, c_id, variable_id):
"""Update a conversation variable's value.
@ -252,7 +321,7 @@ class ConversationVariableDetailApi(Resource):
The value must match the variable's expected type.
"""
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
conversation_id = str(c_id)
@ -261,9 +330,10 @@ class ConversationVariableDetailApi(Resource):
payload = ConversationVariableUpdatePayload.model_validate(service_api_ns.payload or {})
try:
return ConversationService.update_conversation_variable(
variable = ConversationService.update_conversation_variable(
app_model, conversation_id, variable_id, end_user, payload.value
)
return ConversationVariableResponse.model_validate(variable, from_attributes=True).model_dump(mode="json")
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except services.errors.conversation.ConversationVariableNotExistsError:

View File

@ -53,7 +53,7 @@ class MessageListApi(Resource):
Retrieves messages with pagination support using first_id.
"""
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
query_args = MessageListQuery.model_validate(request.args.to_dict())
@ -158,7 +158,7 @@ class MessageSuggestedApi(Resource):
"""
message_id = str(message_id)
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
try:

View File

@ -1,13 +1,15 @@
import logging
from collections.abc import Mapping
from datetime import datetime
from typing import Literal
from dateutil.parser import isoparse
from flask import request
from flask_restx import Namespace, Resource, fields
from flask_restx import Resource, fields
from graphon.enums import WorkflowExecutionStatus
from graphon.graph_engine.manager import GraphEngineManager
from graphon.model_runtime.errors.invoke import InvokeError
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
@ -33,9 +35,10 @@ from core.errors.error import (
from core.helper.trace_id_helper import get_external_trace_id
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model
from fields.base import ResponseModel
from fields.end_user_fields import SimpleEndUser
from fields.member_fields import SimpleAccount
from libs import helper
from libs.helper import OptionalTimestampField, TimestampField
from models.model import App, AppMode, EndUser
from models.workflow import WorkflowRun
from repositories.factory import DifyAPIRepositoryFactory
@ -65,38 +68,142 @@ class WorkflowLogQuery(BaseModel):
register_schema_models(service_api_ns, WorkflowRunPayload, WorkflowLogQuery)
def _to_timestamp(value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
def _enum_value(value):
return getattr(value, "value", value)
class WorkflowRunStatusField(fields.Raw):
def output(self, key, obj: WorkflowRun, **kwargs):
return obj.status.value
return _enum_value(obj.status)
class WorkflowRunOutputsField(fields.Raw):
def output(self, key, obj: WorkflowRun, **kwargs):
if obj.status == WorkflowExecutionStatus.PAUSED:
status = _enum_value(obj.status)
if status == WorkflowExecutionStatus.PAUSED.value:
return {}
outputs = obj.outputs_dict
return outputs or {}
workflow_run_fields = {
"id": fields.String,
"workflow_id": fields.String,
"status": WorkflowRunStatusField,
"inputs": fields.Raw,
"outputs": WorkflowRunOutputsField,
"error": fields.String,
"total_steps": fields.Integer,
"total_tokens": fields.Integer,
"created_at": TimestampField,
"finished_at": OptionalTimestampField,
"elapsed_time": fields.Float,
}
class WorkflowRunResponse(ResponseModel):
id: str
workflow_id: str
status: str
inputs: dict | list | str | int | float | bool | None = None
outputs: dict = Field(default_factory=dict)
error: str | None = None
total_steps: int | None = None
total_tokens: int | None = None
created_at: int | None = None
finished_at: int | None = None
elapsed_time: float | int | None = None
@field_validator("created_at", "finished_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
def build_workflow_run_model(api_or_ns: Namespace):
"""Build the workflow run model for the API or Namespace."""
return api_or_ns.model("WorkflowRun", workflow_run_fields)
class WorkflowRunForLogResponse(ResponseModel):
id: str
version: str | None = None
status: str | None = None
triggered_from: str | None = None
error: str | None = None
elapsed_time: float | int | None = None
total_tokens: int | None = None
total_steps: int | None = None
created_at: int | None = None
finished_at: int | None = None
exceptions_count: int | None = None
@field_validator("status", "triggered_from", mode="before")
@classmethod
def _normalize_enum(cls, value):
return _enum_value(value)
@field_validator("created_at", "finished_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
class WorkflowAppLogPartialResponse(ResponseModel):
id: str
workflow_run: WorkflowRunForLogResponse | None = None
details: dict | list | str | int | float | bool | None = None
created_from: str | None = None
created_by_role: str | None = None
created_by_account: SimpleAccount | None = None
created_by_end_user: SimpleEndUser | None = None
created_at: int | None = None
@field_validator("created_from", "created_by_role", mode="before")
@classmethod
def _normalize_enum(cls, value):
return _enum_value(value)
@field_validator("created_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
class WorkflowAppLogPaginationResponse(ResponseModel):
page: int
limit: int
total: int
has_more: bool
data: list[WorkflowAppLogPartialResponse]
register_schema_models(
service_api_ns,
WorkflowRunResponse,
WorkflowRunForLogResponse,
WorkflowAppLogPartialResponse,
WorkflowAppLogPaginationResponse,
)
def _serialize_workflow_run(workflow_run: WorkflowRun) -> dict:
status = _enum_value(workflow_run.status)
raw_outputs = workflow_run.outputs_dict
if status == WorkflowExecutionStatus.PAUSED.value or raw_outputs is None:
outputs: dict = {}
elif isinstance(raw_outputs, dict):
outputs = raw_outputs
elif isinstance(raw_outputs, Mapping):
outputs = dict(raw_outputs)
else:
outputs = {}
return WorkflowRunResponse.model_validate(
{
"id": workflow_run.id,
"workflow_id": workflow_run.workflow_id,
"status": status,
"inputs": workflow_run.inputs,
"outputs": outputs,
"error": workflow_run.error,
"total_steps": workflow_run.total_steps,
"total_tokens": workflow_run.total_tokens,
"created_at": workflow_run.created_at,
"finished_at": workflow_run.finished_at,
"elapsed_time": workflow_run.elapsed_time,
}
).model_dump(mode="json")
def _serialize_workflow_log_pagination(pagination) -> dict:
return WorkflowAppLogPaginationResponse.model_validate(pagination, from_attributes=True).model_dump(mode="json")
@service_api_ns.route("/workflows/run/<string:workflow_run_id>")
@ -112,7 +219,11 @@ class WorkflowRunDetailApi(Resource):
}
)
@validate_app_token
@service_api_ns.marshal_with(build_workflow_run_model(service_api_ns))
@service_api_ns.response(
200,
"Workflow run details retrieved successfully",
service_api_ns.models[WorkflowRunResponse.__name__],
)
def get(self, app_model: App, workflow_run_id: str):
"""Get a workflow task running detail.
@ -133,7 +244,7 @@ class WorkflowRunDetailApi(Resource):
)
if not workflow_run:
raise NotFound("Workflow run not found.")
return workflow_run
return _serialize_workflow_run(workflow_run)
@service_api_ns.route("/workflows/run")
@ -299,7 +410,11 @@ class WorkflowAppLogApi(Resource):
}
)
@validate_app_token
@service_api_ns.marshal_with(build_workflow_app_log_pagination_model(service_api_ns))
@service_api_ns.response(
200,
"Logs retrieved successfully",
service_api_ns.models[WorkflowAppLogPaginationResponse.__name__],
)
def get(self, app_model: App):
"""Get workflow app logs.
@ -327,4 +442,4 @@ class WorkflowAppLogApi(Resource):
created_by_account=args.created_by_account,
)
return workflow_app_log_pagination
return _serialize_workflow_log_pagination(workflow_app_log_pagination)

View File

@ -33,25 +33,25 @@ from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingS
from services.summary_index_service import SummaryIndexService
def _marshal_segment_with_summary(segment, dataset_id: str) -> dict:
def _marshal_segment_with_summary(segment, dataset_id: str) -> dict[str, Any]:
"""Marshal a single segment and enrich it with summary content."""
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
segment_dict: dict[str, Any] = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id)
segment_dict["summary"] = summary.summary_content if summary else None
return segment_dict
def _marshal_segments_with_summary(segments, dataset_id: str) -> list[dict]:
def _marshal_segments_with_summary(segments, dataset_id: str) -> list[dict[str, Any]]:
"""Marshal multiple segments and enrich them with summary content (batch query)."""
segment_ids = [segment.id for segment in segments]
summaries: dict = {}
summaries: dict[str, str | None] = {}
if segment_ids:
summary_records = SummaryIndexService.get_segments_summaries(segment_ids=segment_ids, dataset_id=dataset_id)
summaries = {chunk_id: record.summary_content for chunk_id, record in summary_records.items()}
result = []
result: list[dict[str, Any]] = []
for segment in segments:
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
segment_dict: dict[str, Any] = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
segment_dict["summary"] = summaries.get(segment.id)
result.append(segment_dict)
return result

View File

@ -5,6 +5,7 @@ Web App Human Input Form APIs.
import json
import logging
from datetime import datetime
from typing import Any, NotRequired, TypedDict
from flask import Response, request
from flask_restx import Resource
@ -58,10 +59,19 @@ def _to_timestamp(value: datetime) -> int:
return int(value.timestamp())
class FormDefinitionPayload(TypedDict):
form_content: Any
inputs: Any
resolved_default_values: dict[str, str]
user_actions: Any
expiration_time: int
site: NotRequired[dict]
def _jsonify_form_definition(form: Form, site_payload: dict | None = None) -> Response:
"""Return the form payload (optionally with site) as a JSON response."""
definition_payload = form.get_definition().model_dump()
payload = {
payload: FormDefinitionPayload = {
"form_content": definition_payload["rendered_content"],
"inputs": definition_payload["inputs"],
"resolved_default_values": _stringify_default_values(definition_payload["default_values"]),
@ -92,7 +102,7 @@ class HumanInputFormApi(Resource):
_FORM_ACCESS_RATE_LIMITER.increment_rate_limit(ip_address)
service = HumanInputService(db.engine)
# TODO(QuantumGhost): forbid submision for form tokens
# TODO(QuantumGhost): forbid submission for form tokens
# that are only for console.
form = service.get_form_by_token(form_token)

View File

@ -1,7 +1,10 @@
import logging
from flask import make_response, request
from flask_restx import Resource
from jwt import InvalidTokenError
from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import Unauthorized
import services
from configs import dify_config
@ -20,7 +23,7 @@ from controllers.console.wraps import (
)
from controllers.web import web_ns
from controllers.web.wraps import decode_jwt_token
from libs.helper import EmailStr
from libs.helper import EmailStr, extract_remote_ip
from libs.passport import PassportService
from libs.password import valid_password
from libs.token import (
@ -29,9 +32,11 @@ from libs.token import (
)
from services.account_service import AccountService
from services.app_service import AppService
from services.entities.auth_entities import LoginPayloadBase
from services.entities.auth_entities import LoginFailureReason, LoginPayloadBase
from services.webapp_auth_service import WebAppAuthService
logger = logging.getLogger(__name__)
class LoginPayload(LoginPayloadBase):
@field_validator("password")
@ -76,14 +81,18 @@ class LoginApi(Resource):
def post(self):
"""Authenticate user and login."""
payload = LoginPayload.model_validate(web_ns.payload or {})
normalized_email = payload.email.lower()
try:
account = WebAppAuthService.authenticate(payload.email, payload.password)
except services.errors.account.AccountLoginError:
_log_web_login_failure(email=normalized_email, reason=LoginFailureReason.ACCOUNT_BANNED)
raise AccountBannedError()
except services.errors.account.AccountPasswordError:
_log_web_login_failure(email=normalized_email, reason=LoginFailureReason.INVALID_CREDENTIALS)
raise AuthenticationFailedError()
except services.errors.account.AccountNotFoundError:
_log_web_login_failure(email=normalized_email, reason=LoginFailureReason.ACCOUNT_NOT_FOUND)
raise AuthenticationFailedError()
token = WebAppAuthService.login(account=account)
@ -212,21 +221,30 @@ class EmailCodeLoginApi(Resource):
token_data = WebAppAuthService.get_email_code_login_data(payload.token)
if token_data is None:
_log_web_login_failure(email=user_email, reason=LoginFailureReason.INVALID_EMAIL_CODE_TOKEN)
raise InvalidTokenError()
token_email = token_data.get("email")
if not isinstance(token_email, str):
_log_web_login_failure(email=user_email, reason=LoginFailureReason.EMAIL_CODE_EMAIL_MISMATCH)
raise InvalidEmailError()
normalized_token_email = token_email.lower()
if normalized_token_email != user_email:
_log_web_login_failure(email=user_email, reason=LoginFailureReason.EMAIL_CODE_EMAIL_MISMATCH)
raise InvalidEmailError()
if token_data["code"] != payload.code:
_log_web_login_failure(email=user_email, reason=LoginFailureReason.INVALID_EMAIL_CODE)
raise EmailCodeError()
WebAppAuthService.revoke_email_code_login_token(payload.token)
account = WebAppAuthService.get_user_through_email(token_email)
try:
account = WebAppAuthService.get_user_through_email(token_email)
except Unauthorized as exc:
_log_web_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_BANNED)
raise AccountBannedError() from exc
if not account:
_log_web_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_NOT_FOUND)
raise AuthenticationFailedError()
token = WebAppAuthService.login(account=account)
@ -234,3 +252,12 @@ class EmailCodeLoginApi(Resource):
response = make_response({"result": "success", "data": {"access_token": token}})
# set_access_token_to_cookie(request, response, token, samesite="None", httponly=False)
return response
def _log_web_login_failure(*, email: str, reason: LoginFailureReason) -> None:
logger.warning(
"Web login failed: email=%s reason=%s ip_address=%s",
email,
reason,
extract_remote_ip(request),
)

View File

@ -1,5 +1,6 @@
import uuid
from datetime import UTC, datetime, timedelta
from typing import Any
from flask import make_response, request
from flask_restx import Resource
@ -103,21 +104,23 @@ class PassportResource(Resource):
return response
def decode_enterprise_webapp_user_id(jwt_token: str | None):
def decode_enterprise_webapp_user_id(jwt_token: str | None) -> dict[str, Any] | None:
"""
Decode the enterprise user session from the Authorization header.
"""
if not jwt_token:
return None
decoded = PassportService().verify(jwt_token)
decoded: dict[str, Any] = PassportService().verify(jwt_token)
source = decoded.get("token_source")
if not source or source != "webapp_login_token":
raise Unauthorized("Invalid token source. Expected 'webapp_login_token'.")
return decoded
def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: dict, auth_type: WebAppAuthType):
def exchange_token_for_existing_web_user(
app_code: str, enterprise_user_decoded: dict[str, Any], auth_type: WebAppAuthType
):
"""
Exchange a token for an existing web user session.
"""

View File

@ -1,4 +1,4 @@
from typing import cast
from typing import Any, cast
from flask_restx import fields, marshal, marshal_with
from sqlalchemy import select
@ -113,12 +113,12 @@ class AppSiteInfo:
}
def serialize_site(site: Site) -> dict:
def serialize_site(site: Site) -> dict[str, Any]:
"""Serialize Site model using the same schema as AppSiteApi."""
return cast(dict, marshal(site, AppSiteApi.site_fields))
return cast(dict[str, Any], marshal(site, AppSiteApi.site_fields))
def serialize_app_site_payload(app_model: App, site: Site, end_user_id: str | None) -> dict:
def serialize_app_site_payload(app_model: App, site: Site, end_user_id: str | None) -> dict[str, Any]:
can_replace_logo = FeatureService.get_features(app_model.tenant_id).can_replace_logo
app_site_info = AppSiteInfo(app_model.tenant, app_model, site, end_user_id, can_replace_logo)
return cast(dict, marshal(app_site_info, AppSiteApi.app_fields))
return cast(dict[str, Any], marshal(app_site_info, AppSiteApi.app_fields))

View File

@ -1,399 +0,0 @@
import logging
from collections.abc import Generator
from copy import deepcopy
from typing import Any, cast
from core.agent.base_agent_runner import BaseAgentRunner
from core.agent.entities import AgentEntity, AgentLog, AgentResult, ExecutionContext
from core.agent.patterns.strategy_factory import StrategyFactory
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
from graphon.file import file_manager
from graphon.model_runtime.entities import (
AssistantPromptMessage,
LLMResult,
LLMResultChunk,
LLMUsage,
PromptMessage,
PromptMessageContentType,
SystemPromptMessage,
TextPromptMessageContent,
UserPromptMessage,
)
from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from graphon.model_runtime.entities.model_entities import ModelFeature
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from models.model import Message
logger = logging.getLogger(__name__)
class AgentAppRunner(BaseAgentRunner):
@property
def model_features(self) -> list[ModelFeature]:
llm_model = cast(LargeLanguageModel, self.model_instance.model_type_instance)
model_schema = llm_model.get_model_schema(self.model_instance.model_name, self.model_instance.credentials)
if not model_schema:
return []
return list(model_schema.features or [])
def build_execution_context(self) -> ExecutionContext:
return ExecutionContext(
user_id=self.user_id,
app_id=self.application_generate_entity.app_config.app_id,
conversation_id=self.conversation.id if self.conversation else None,
message_id=self.message.id if self.message else None,
tenant_id=self.tenant_id,
)
def _create_tool_invoke_hook(self, message: Message):
"""
Create a tool invoke hook that uses ToolEngine.agent_invoke.
This hook handles file creation and returns proper meta information.
"""
# Get trace manager from app generate entity
trace_manager = self.application_generate_entity.trace_manager
def tool_invoke_hook(
tool: Tool, tool_args: dict[str, Any], tool_name: str
) -> tuple[str, list[str], ToolInvokeMeta]:
"""Hook that uses agent_invoke for proper file and meta handling."""
tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
tool=tool,
tool_parameters=tool_args,
user_id=self.user_id,
tenant_id=self.tenant_id,
message=message,
invoke_from=self.application_generate_entity.invoke_from,
agent_tool_callback=self.agent_callback,
trace_manager=trace_manager,
app_id=self.application_generate_entity.app_config.app_id,
message_id=message.id,
conversation_id=self.conversation.id,
)
# Publish files and track IDs
for message_file_id in message_files:
self.queue_manager.publish(
QueueMessageFileEvent(message_file_id=message_file_id),
PublishFrom.APPLICATION_MANAGER,
)
self._current_message_file_ids.append(message_file_id)
return tool_invoke_response, message_files, tool_invoke_meta
return tool_invoke_hook
def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]:
"""
Run Agent application
"""
self.query = query
app_generate_entity = self.application_generate_entity
app_config = self.app_config
assert app_config is not None, "app_config is required"
assert app_config.agent is not None, "app_config.agent is required"
# convert tools into ModelRuntime Tool format
tool_instances, _ = self._init_prompt_tools()
assert app_config.agent
# Create tool invoke hook for agent_invoke
tool_invoke_hook = self._create_tool_invoke_hook(message)
# Get instruction for ReAct strategy
instruction = self.app_config.prompt_template.simple_prompt_template or ""
# Use factory to create appropriate strategy
strategy = StrategyFactory.create_strategy(
model_features=self.model_features,
model_instance=self.model_instance,
tools=list(tool_instances.values()),
files=list(self.files),
max_iterations=app_config.agent.max_iteration,
context=self.build_execution_context(),
agent_strategy=self.config.strategy,
tool_invoke_hook=tool_invoke_hook,
instruction=instruction,
)
# Initialize state variables
current_agent_thought_id: str | None = None
has_published_thought = False
current_tool_name: str | None = None
self._current_message_file_ids: list[str] = []
# organize prompt messages
prompt_messages = self._organize_prompt_messages()
# Run strategy
generator = strategy.run(
prompt_messages=prompt_messages,
model_parameters=app_generate_entity.model_conf.parameters,
stop=app_generate_entity.model_conf.stop,
stream=True,
)
# Consume generator and collect result
result: AgentResult | None = None
try:
while True:
try:
output = next(generator)
except StopIteration as e:
# Generator finished, get the return value
result = e.value
break
if isinstance(output, LLMResultChunk):
# Handle LLM chunk
if current_agent_thought_id and not has_published_thought:
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
PublishFrom.APPLICATION_MANAGER,
)
has_published_thought = True
yield output
elif isinstance(output, AgentLog):
# Handle Agent Log using log_type for type-safe dispatch
if output.status == AgentLog.LogStatus.START:
if output.log_type == AgentLog.LogType.ROUND:
# Start of a new round
message_file_ids: list[str] = []
current_agent_thought_id = self.create_agent_thought(
message_id=message.id,
message="",
tool_name="",
tool_input="",
messages_ids=message_file_ids,
)
has_published_thought = False
elif output.log_type == AgentLog.LogType.TOOL_CALL:
if current_agent_thought_id is None:
continue
# Tool call start - extract data from structured fields
current_tool_name = output.data.get("tool_name", "")
tool_input = output.data.get("tool_args", {})
self.save_agent_thought(
agent_thought_id=current_agent_thought_id,
tool_name=current_tool_name,
tool_input=tool_input,
thought=None,
observation=None,
tool_invoke_meta=None,
answer=None,
messages_ids=[],
)
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
PublishFrom.APPLICATION_MANAGER,
)
elif output.status == AgentLog.LogStatus.SUCCESS:
if output.log_type == AgentLog.LogType.THOUGHT:
if current_agent_thought_id is None:
continue
thought_text = output.data.get("thought")
self.save_agent_thought(
agent_thought_id=current_agent_thought_id,
tool_name=None,
tool_input=None,
thought=thought_text,
observation=None,
tool_invoke_meta=None,
answer=None,
messages_ids=[],
)
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
PublishFrom.APPLICATION_MANAGER,
)
elif output.log_type == AgentLog.LogType.TOOL_CALL:
if current_agent_thought_id is None:
continue
# Tool call finished
tool_output = output.data.get("output")
# Get meta from strategy output (now properly populated)
tool_meta = output.data.get("meta")
# Wrap tool_meta with tool_name as key (required by agent_service)
if tool_meta and current_tool_name:
tool_meta = {current_tool_name: tool_meta}
self.save_agent_thought(
agent_thought_id=current_agent_thought_id,
tool_name=None,
tool_input=None,
thought=None,
observation=tool_output,
tool_invoke_meta=tool_meta,
answer=None,
messages_ids=self._current_message_file_ids,
)
# Clear message file ids after saving
self._current_message_file_ids = []
current_tool_name = None
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
PublishFrom.APPLICATION_MANAGER,
)
elif output.log_type == AgentLog.LogType.ROUND:
if current_agent_thought_id is None:
continue
# Round finished - save LLM usage and answer
llm_usage = output.metadata.get(AgentLog.LogMetadata.LLM_USAGE)
llm_result = output.data.get("llm_result")
final_answer = output.data.get("final_answer")
self.save_agent_thought(
agent_thought_id=current_agent_thought_id,
tool_name=None,
tool_input=None,
thought=llm_result,
observation=None,
tool_invoke_meta=None,
answer=final_answer,
messages_ids=[],
llm_usage=llm_usage,
)
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
PublishFrom.APPLICATION_MANAGER,
)
except Exception:
# Re-raise any other exceptions
raise
# Process final result
if isinstance(result, AgentResult):
final_answer = result.text
usage = result.usage or LLMUsage.empty_usage()
# Publish end event
self.queue_manager.publish(
QueueMessageEndEvent(
llm_result=LLMResult(
model=self.model_instance.model_name,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=final_answer),
usage=usage,
system_fingerprint="",
)
),
PublishFrom.APPLICATION_MANAGER,
)
def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
Initialize system message
"""
if not prompt_template:
return prompt_messages or []
prompt_messages = prompt_messages or []
if prompt_messages and isinstance(prompt_messages[0], SystemPromptMessage):
prompt_messages[0] = SystemPromptMessage(content=prompt_template)
return prompt_messages
if not prompt_messages:
return [SystemPromptMessage(content=prompt_template)]
prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))
return prompt_messages
def _organize_user_query(self, query: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
Organize user query
"""
if self.files:
# get image detail config
image_detail_config = (
self.application_generate_entity.file_upload_config.image_config.detail
if (
self.application_generate_entity.file_upload_config
and self.application_generate_entity.file_upload_config.image_config
)
else None
)
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
for file in self.files:
prompt_message_contents.append(
file_manager.to_prompt_message_content(
file,
image_detail_config=image_detail_config,
)
)
prompt_message_contents.append(TextPromptMessageContent(data=query))
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else:
prompt_messages.append(UserPromptMessage(content=query))
return prompt_messages
def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
As for now, gpt supports both fc and vision at the first iteration.
We need to remove the image messages from the prompt messages at the first iteration.
"""
prompt_messages = deepcopy(prompt_messages)
for prompt_message in prompt_messages:
if isinstance(prompt_message, UserPromptMessage):
if isinstance(prompt_message.content, list):
prompt_message.content = "\n".join(
[
content.data
if content.type == PromptMessageContentType.TEXT
else "[image]"
if content.type == PromptMessageContentType.IMAGE
else "[file]"
for content in prompt_message.content
]
)
return prompt_messages
def _organize_prompt_messages(self):
# For ReAct strategy, use the agent prompt template
if self.config.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT and self.config.prompt:
prompt_template = self.config.prompt.first_prompt
else:
prompt_template = self.app_config.prompt_template.simple_prompt_template or ""
self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
query_prompt_messages = self._organize_user_query(self.query or "", [])
self.history_prompt_messages = AgentHistoryPromptTransform(
model_config=self.model_config,
prompt_messages=[*query_prompt_messages, *self._current_thoughts],
history_messages=self.history_prompt_messages,
memory=self.memory,
).get_prompt()
prompt_messages = [*self.history_prompt_messages, *query_prompt_messages, *self._current_thoughts]
if len(self._current_thoughts) != 0:
# clear messages after the first iteration
prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
return prompt_messages

View File

@ -1,5 +1,3 @@
import uuid
from collections.abc import Mapping
from enum import StrEnum
from typing import Any, Union
@ -94,79 +92,3 @@ class AgentInvokeMessage(ToolInvokeMessage):
"""
pass
class ExecutionContext(BaseModel):
"""Execution context containing trace and audit information.
Carries IDs and metadata needed for tracing, auditing, and correlation
but not part of the core business logic.
"""
user_id: str | None = None
app_id: str | None = None
conversation_id: str | None = None
message_id: str | None = None
tenant_id: str | None = None
@classmethod
def create_minimal(cls, user_id: str | None = None) -> "ExecutionContext":
return cls(user_id=user_id)
def to_dict(self) -> dict[str, Any]:
return {
"user_id": self.user_id,
"app_id": self.app_id,
"conversation_id": self.conversation_id,
"message_id": self.message_id,
"tenant_id": self.tenant_id,
}
def with_updates(self, **kwargs) -> "ExecutionContext":
data = self.to_dict()
data.update(kwargs)
return ExecutionContext(**{k: v for k, v in data.items() if k in ExecutionContext.model_fields})
class AgentLog(BaseModel):
"""Structured log entry for agent execution tracing."""
class LogType(StrEnum):
ROUND = "round"
THOUGHT = "thought"
TOOL_CALL = "tool_call"
class LogMetadata(StrEnum):
STARTED_AT = "started_at"
FINISHED_AT = "finished_at"
ELAPSED_TIME = "elapsed_time"
TOTAL_PRICE = "total_price"
TOTAL_TOKENS = "total_tokens"
PROVIDER = "provider"
CURRENCY = "currency"
LLM_USAGE = "llm_usage"
ICON = "icon"
ICON_DARK = "icon_dark"
class LogStatus(StrEnum):
START = "start"
ERROR = "error"
SUCCESS = "success"
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
label: str = Field(...)
log_type: LogType = Field(...)
parent_id: str | None = Field(default=None)
error: str | None = Field(default=None)
status: LogStatus = Field(...)
data: Mapping[str, Any] = Field(...)
metadata: Mapping[LogMetadata, Any] = Field(default={})
class AgentResult(BaseModel):
"""Agent execution result."""
text: str = Field(default="")
files: list[Any] = Field(default_factory=list)
usage: Any | None = Field(default=None)
finish_reason: str | None = Field(default=None)

View File

@ -1,7 +1,7 @@
import json
import re
from collections.abc import Generator
from typing import Union
from typing import Any, Union
from graphon.model_runtime.entities.llm_entities import LLMResultChunk
@ -11,7 +11,7 @@ from core.agent.entities import AgentScratchpadUnit
class CotAgentOutputParser:
@classmethod
def handle_react_stream_output(
cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict
cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict[str, Any]
) -> Generator[Union[str, AgentScratchpadUnit.Action], None, None]:
def parse_action(action) -> Union[str, AgentScratchpadUnit.Action]:
action_name = None

View File

@ -1,19 +0,0 @@
"""Agent patterns module.
This module provides different strategies for agent execution:
- FunctionCallStrategy: Uses native function/tool calling
- ReActStrategy: Uses ReAct (Reasoning + Acting) approach
- StrategyFactory: Factory for creating strategies based on model features
"""
from .base import AgentPattern
from .function_call import FunctionCallStrategy
from .react import ReActStrategy
from .strategy_factory import StrategyFactory
__all__ = [
"AgentPattern",
"FunctionCallStrategy",
"ReActStrategy",
"StrategyFactory",
]

View File

@ -1,506 +0,0 @@
"""Base class for agent strategies."""
from __future__ import annotations
import json
import re
import time
from abc import ABC, abstractmethod
from collections.abc import Callable, Generator
from typing import TYPE_CHECKING, Any
from core.agent.entities import AgentLog, AgentResult, ExecutionContext
from core.model_manager import ModelInstance
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolInvokeMeta
from graphon.file import File
from graphon.model_runtime.entities import (
AssistantPromptMessage,
LLMResult,
LLMResultChunk,
LLMResultChunkDelta,
PromptMessage,
PromptMessageTool,
)
from graphon.model_runtime.entities.llm_entities import LLMUsage
from graphon.model_runtime.entities.message_entities import TextPromptMessageContent
if TYPE_CHECKING:
from core.tools.__base.tool import Tool
# Type alias for tool invoke hook
# Returns: (response_content, message_file_ids, tool_invoke_meta)
ToolInvokeHook = Callable[["Tool", dict[str, Any], str], tuple[str, list[str], ToolInvokeMeta]]
class AgentPattern(ABC):
"""Base class for agent execution strategies."""
def __init__(
self,
model_instance: ModelInstance,
tools: list[Tool],
context: ExecutionContext,
max_iterations: int = 10,
workflow_call_depth: int = 0,
files: list[File] = [],
tool_invoke_hook: ToolInvokeHook | None = None,
):
"""Initialize the agent strategy."""
self.model_instance = model_instance
self.tools = tools
self.context = context
self.max_iterations = min(max_iterations, 99) # Cap at 99 iterations
self.workflow_call_depth = workflow_call_depth
self.files: list[File] = files
self.tool_invoke_hook = tool_invoke_hook
@abstractmethod
def run(
self,
prompt_messages: list[PromptMessage],
model_parameters: dict[str, Any],
stop: list[str] = [],
stream: bool = True,
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
"""Execute the agent strategy."""
pass
def _accumulate_usage(self, total_usage: dict[str, Any], delta_usage: LLMUsage) -> None:
"""Accumulate LLM usage statistics."""
if not total_usage.get("usage"):
# Create a copy to avoid modifying the original
total_usage["usage"] = LLMUsage(
prompt_tokens=delta_usage.prompt_tokens,
prompt_unit_price=delta_usage.prompt_unit_price,
prompt_price_unit=delta_usage.prompt_price_unit,
prompt_price=delta_usage.prompt_price,
completion_tokens=delta_usage.completion_tokens,
completion_unit_price=delta_usage.completion_unit_price,
completion_price_unit=delta_usage.completion_price_unit,
completion_price=delta_usage.completion_price,
total_tokens=delta_usage.total_tokens,
total_price=delta_usage.total_price,
currency=delta_usage.currency,
latency=delta_usage.latency,
)
else:
current: LLMUsage = total_usage["usage"]
current.prompt_tokens += delta_usage.prompt_tokens
current.completion_tokens += delta_usage.completion_tokens
current.total_tokens += delta_usage.total_tokens
current.prompt_price += delta_usage.prompt_price
current.completion_price += delta_usage.completion_price
current.total_price += delta_usage.total_price
def _extract_content(self, content: Any) -> str:
"""Extract text content from message content."""
if isinstance(content, list):
# Content items are PromptMessageContentUnionTypes
text_parts = []
for c in content:
# Check if it's a TextPromptMessageContent (which has data attribute)
if isinstance(c, TextPromptMessageContent):
text_parts.append(c.data)
return "".join(text_parts)
return str(content)
def _has_tool_calls(self, chunk: LLMResultChunk) -> bool:
"""Check if chunk contains tool calls."""
# LLMResultChunk always has delta attribute
return bool(chunk.delta.message and chunk.delta.message.tool_calls)
def _has_tool_calls_result(self, result: LLMResult) -> bool:
"""Check if result contains tool calls (non-streaming)."""
# LLMResult always has message attribute
return bool(result.message and result.message.tool_calls)
def _extract_tool_calls(self, chunk: LLMResultChunk) -> list[tuple[str, str, dict[str, Any]]]:
"""Extract tool calls from streaming chunk."""
tool_calls: list[tuple[str, str, dict[str, Any]]] = []
if chunk.delta.message and chunk.delta.message.tool_calls:
for tool_call in chunk.delta.message.tool_calls:
if tool_call.function:
try:
args = json.loads(tool_call.function.arguments) if tool_call.function.arguments else {}
except json.JSONDecodeError:
args = {}
tool_calls.append((tool_call.id or "", tool_call.function.name, args))
return tool_calls
def _extract_tool_calls_result(self, result: LLMResult) -> list[tuple[str, str, dict[str, Any]]]:
"""Extract tool calls from non-streaming result."""
tool_calls = []
if result.message and result.message.tool_calls:
for tool_call in result.message.tool_calls:
if tool_call.function:
try:
args = json.loads(tool_call.function.arguments) if tool_call.function.arguments else {}
except json.JSONDecodeError:
args = {}
tool_calls.append((tool_call.id or "", tool_call.function.name, args))
return tool_calls
def _extract_text_from_message(self, message: PromptMessage) -> str:
"""Extract text content from a prompt message."""
# PromptMessage always has content attribute
content = message.content
if isinstance(content, str):
return content
elif isinstance(content, list):
# Extract text from content list
text_parts = []
for item in content:
if isinstance(item, TextPromptMessageContent):
text_parts.append(item.data)
return " ".join(text_parts)
return ""
def _get_tool_metadata(self, tool_instance: Tool) -> dict[AgentLog.LogMetadata, Any]:
"""Get metadata for a tool including provider and icon info."""
from core.tools.tool_manager import ToolManager
metadata: dict[AgentLog.LogMetadata, Any] = {}
if tool_instance.entity and tool_instance.entity.identity:
identity = tool_instance.entity.identity
if identity.provider:
metadata[AgentLog.LogMetadata.PROVIDER] = identity.provider
# Get icon using ToolManager for proper URL generation
tenant_id = self.context.tenant_id
if tenant_id and identity.provider:
try:
provider_type = tool_instance.tool_provider_type()
icon = ToolManager.get_tool_icon(tenant_id, provider_type, identity.provider)
if isinstance(icon, str):
metadata[AgentLog.LogMetadata.ICON] = icon
elif isinstance(icon, dict):
# Handle icon dict with background/content or light/dark variants
metadata[AgentLog.LogMetadata.ICON] = icon
except Exception:
# Fallback to identity.icon if ToolManager fails
if identity.icon:
metadata[AgentLog.LogMetadata.ICON] = identity.icon
elif identity.icon:
metadata[AgentLog.LogMetadata.ICON] = identity.icon
return metadata
def _create_log(
self,
label: str,
log_type: AgentLog.LogType,
status: AgentLog.LogStatus,
data: dict[str, Any] | None = None,
parent_id: str | None = None,
extra_metadata: dict[AgentLog.LogMetadata, Any] | None = None,
) -> AgentLog:
"""Create a new AgentLog with standard metadata."""
metadata: dict[AgentLog.LogMetadata, Any] = {
AgentLog.LogMetadata.STARTED_AT: time.perf_counter(),
}
if extra_metadata:
metadata.update(extra_metadata)
return AgentLog(
label=label,
log_type=log_type,
status=status,
data=data or {},
parent_id=parent_id,
metadata=metadata,
)
def _finish_log(
self,
log: AgentLog,
data: dict[str, Any] | None = None,
usage: LLMUsage | None = None,
) -> AgentLog:
"""Finish an AgentLog by updating its status and metadata."""
log.status = AgentLog.LogStatus.SUCCESS
if data is not None:
log.data = data
# Calculate elapsed time
started_at = log.metadata.get(AgentLog.LogMetadata.STARTED_AT, time.perf_counter())
finished_at = time.perf_counter()
# Update metadata
log.metadata = {
**log.metadata,
AgentLog.LogMetadata.FINISHED_AT: finished_at,
# Calculate elapsed time in seconds
AgentLog.LogMetadata.ELAPSED_TIME: round(finished_at - started_at, 4),
}
# Add usage information if provided
if usage:
log.metadata.update(
{
AgentLog.LogMetadata.TOTAL_PRICE: usage.total_price,
AgentLog.LogMetadata.CURRENCY: usage.currency,
AgentLog.LogMetadata.TOTAL_TOKENS: usage.total_tokens,
AgentLog.LogMetadata.LLM_USAGE: usage,
}
)
return log
def _replace_file_references(self, tool_args: dict[str, Any]) -> dict[str, Any]:
"""
Replace file references in tool arguments with actual File objects.
Args:
tool_args: Dictionary of tool arguments
Returns:
Updated tool arguments with file references replaced
"""
# Process each argument in the dictionary
processed_args: dict[str, Any] = {}
for key, value in tool_args.items():
processed_args[key] = self._process_file_reference(value)
return processed_args
def _process_file_reference(self, data: Any) -> Any:
"""
Recursively process data to replace file references.
Supports both single file [File: file_id] and multiple files [Files: file_id1, file_id2, ...].
Args:
data: The data to process (can be dict, list, str, or other types)
Returns:
Processed data with file references replaced
"""
single_file_pattern = re.compile(r"^\[File:\s*([^\]]+)\]$")
multiple_files_pattern = re.compile(r"^\[Files:\s*([^\]]+)\]$")
if isinstance(data, dict):
# Process dictionary recursively
return {key: self._process_file_reference(value) for key, value in data.items()}
elif isinstance(data, list):
# Process list recursively
return [self._process_file_reference(item) for item in data]
elif isinstance(data, str):
# Check for single file pattern [File: file_id]
single_match = single_file_pattern.match(data.strip())
if single_match:
file_id = single_match.group(1).strip()
# Find the file in self.files
for file in self.files:
if file.id and str(file.id) == file_id:
return file
# If file not found, return original value
return data
# Check for multiple files pattern [Files: file_id1, file_id2, ...]
multiple_match = multiple_files_pattern.match(data.strip())
if multiple_match:
file_ids_str = multiple_match.group(1).strip()
# Split by comma and strip whitespace
file_ids = [fid.strip() for fid in file_ids_str.split(",")]
# Find all matching files
matched_files: list[File] = []
for file_id in file_ids:
for file in self.files:
if file.id and str(file.id) == file_id:
matched_files.append(file)
break
# Return list of files if any were found, otherwise return original
return matched_files or data
return data
else:
# Return other types as-is
return data
def _create_text_chunk(self, text: str, prompt_messages: list[PromptMessage]) -> LLMResultChunk:
"""Create a text chunk for streaming."""
return LLMResultChunk(
model=self.model_instance.model_name,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content=text),
usage=None,
),
system_fingerprint="",
)
def _invoke_tool(
self,
tool_instance: Tool,
tool_args: dict[str, Any],
tool_name: str,
) -> tuple[str, list[File], ToolInvokeMeta | None]:
"""
Invoke a tool and collect its response.
Args:
tool_instance: The tool instance to invoke
tool_args: Tool arguments
tool_name: Name of the tool
Returns:
Tuple of (response_content, tool_files, tool_invoke_meta)
"""
# Process tool_args to replace file references with actual File objects
tool_args = self._replace_file_references(tool_args)
# If a tool invoke hook is set, use it instead of generic_invoke
if self.tool_invoke_hook:
response_content, _, tool_invoke_meta = self.tool_invoke_hook(tool_instance, tool_args, tool_name)
# Note: message_file_ids are stored in DB, we don't convert them to File objects here
# The caller (AgentAppRunner) handles file publishing
return response_content, [], tool_invoke_meta
# Default: use generic_invoke for workflow scenarios
# Import here to avoid circular import
from core.tools.tool_engine import DifyWorkflowCallbackHandler, ToolEngine
tool_response = ToolEngine.generic_invoke(
tool=tool_instance,
tool_parameters=tool_args,
user_id=self.context.user_id or "",
workflow_tool_callback=DifyWorkflowCallbackHandler(),
workflow_call_depth=self.workflow_call_depth,
app_id=self.context.app_id,
conversation_id=self.context.conversation_id,
message_id=self.context.message_id,
)
# Collect response and files
response_content = ""
tool_files: list[File] = []
for response in tool_response:
if response.type == ToolInvokeMessage.MessageType.TEXT:
assert isinstance(response.message, ToolInvokeMessage.TextMessage)
response_content += response.message.text
elif response.type == ToolInvokeMessage.MessageType.LINK:
# Handle link messages
if isinstance(response.message, ToolInvokeMessage.TextMessage):
response_content += f"[Link: {response.message.text}]"
elif response.type == ToolInvokeMessage.MessageType.IMAGE:
# Handle image URL messages
if isinstance(response.message, ToolInvokeMessage.TextMessage):
response_content += f"[Image: {response.message.text}]"
elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK:
# Handle image link messages
if isinstance(response.message, ToolInvokeMessage.TextMessage):
response_content += f"[Image: {response.message.text}]"
elif response.type == ToolInvokeMessage.MessageType.BINARY_LINK:
# Handle binary file link messages
if isinstance(response.message, ToolInvokeMessage.TextMessage):
filename = response.meta.get("filename", "file") if response.meta else "file"
response_content += f"[File: {filename} - {response.message.text}]"
elif response.type == ToolInvokeMessage.MessageType.JSON:
# Handle JSON messages
if isinstance(response.message, ToolInvokeMessage.JsonMessage):
response_content += json.dumps(response.message.json_object, ensure_ascii=False, indent=2)
elif response.type == ToolInvokeMessage.MessageType.BLOB:
# Handle blob messages - convert to text representation
if isinstance(response.message, ToolInvokeMessage.BlobMessage):
mime_type = (
response.meta.get("mime_type", "application/octet-stream")
if response.meta
else "application/octet-stream"
)
size = len(response.message.blob)
response_content += f"[Binary data: {mime_type}, size: {size} bytes]"
elif response.type == ToolInvokeMessage.MessageType.VARIABLE:
# Handle variable messages
if isinstance(response.message, ToolInvokeMessage.VariableMessage):
var_name = response.message.variable_name
var_value = response.message.variable_value
if isinstance(var_value, str):
response_content += var_value
else:
response_content += f"[Variable {var_name}: {json.dumps(var_value, ensure_ascii=False)}]"
elif response.type == ToolInvokeMessage.MessageType.BLOB_CHUNK:
# Handle blob chunk messages - these are parts of a larger blob
if isinstance(response.message, ToolInvokeMessage.BlobChunkMessage):
response_content += f"[Blob chunk {response.message.sequence}: {len(response.message.blob)} bytes]"
elif response.type == ToolInvokeMessage.MessageType.RETRIEVER_RESOURCES:
# Handle retriever resources messages
if isinstance(response.message, ToolInvokeMessage.RetrieverResourceMessage):
response_content += response.message.context
elif response.type == ToolInvokeMessage.MessageType.FILE:
# Extract file from meta
if response.meta and "file" in response.meta:
file = response.meta["file"]
if isinstance(file, File):
# Check if file is for model or tool output
if response.meta.get("target") == "self":
# File is for model - add to files for next prompt
self.files.append(file)
response_content += f"File '{file.filename}' has been loaded into your context."
else:
# File is tool output
tool_files.append(file)
return response_content, tool_files, None
def _validate_tool_args(self, tool_instance: Tool, tool_args: dict[str, Any]) -> str | None:
"""Validate tool arguments against the tool's required parameters.
Checks that all required LLM-facing parameters are present and non-empty
before actual execution, preventing wasted tool invocations when the model
generates calls with missing arguments (e.g. empty ``{}``).
Returns:
Error message if validation fails, None if all required parameters are satisfied.
"""
prompt_tool = tool_instance.to_prompt_message_tool()
required_params: list[str] = prompt_tool.parameters.get("required", [])
if not required_params:
return None
missing = [
p
for p in required_params
if p not in tool_args
or tool_args[p] is None
or (isinstance(tool_args[p], str) and not tool_args[p].strip())
]
if not missing:
return None
return (
f"Missing required parameter(s): {', '.join(missing)}. "
f"Please provide all required parameters before calling this tool."
)
def _find_tool_by_name(self, tool_name: str) -> Tool | None:
"""Find a tool instance by its name."""
for tool in self.tools:
if tool.entity.identity.name == tool_name:
return tool
return None
def _convert_tools_to_prompt_format(self) -> list[PromptMessageTool]:
"""Convert tools to prompt message format."""
prompt_tools: list[PromptMessageTool] = []
for tool in self.tools:
prompt_tools.append(tool.to_prompt_message_tool())
return prompt_tools
def _update_usage_with_empty(self, llm_usage: dict[str, Any]) -> None:
"""Initialize usage tracking with empty usage if not set."""
if "usage" not in llm_usage or llm_usage["usage"] is None:
llm_usage["usage"] = LLMUsage.empty_usage()

View File

@ -1,358 +0,0 @@
"""Function Call strategy implementation.
Implements the Function Call agent pattern where the LLM uses native tool-calling
capability to invoke tools. Includes pre-execution parameter validation that
intercepts invalid calls (e.g. empty arguments) before they reach tool backends,
and avoids counting purely-invalid rounds against the iteration budget.
"""
import json
import logging
from collections.abc import Generator
from typing import Any, Union
from core.agent.entities import AgentLog, AgentResult
from core.tools.entities.tool_entities import ToolInvokeMeta
from graphon.file import File
from graphon.model_runtime.entities import (
AssistantPromptMessage,
LLMResult,
LLMResultChunk,
LLMResultChunkDelta,
LLMUsage,
PromptMessage,
PromptMessageTool,
ToolPromptMessage,
)
from .base import AgentPattern
logger = logging.getLogger(__name__)
class FunctionCallStrategy(AgentPattern):
"""Function Call strategy using model's native tool calling capability."""
def run(
self,
prompt_messages: list[PromptMessage],
model_parameters: dict[str, Any],
stop: list[str] = [],
stream: bool = True,
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
"""Execute the function call agent strategy."""
# Convert tools to prompt format
prompt_tools: list[PromptMessageTool] = self._convert_tools_to_prompt_format()
# Initialize tracking
iteration_step: int = 1
max_iterations: int = self.max_iterations + 1
function_call_state: bool = True
total_usage: dict[str, LLMUsage | None] = {"usage": None}
messages: list[PromptMessage] = list(prompt_messages) # Create mutable copy
final_text: str = ""
finish_reason: str | None = None
output_files: list[File] = [] # Track files produced by tools
# Consecutive rounds where ALL tool calls failed parameter validation.
# When this happens the round is "free" (iteration_step not incremented)
# up to a safety cap to prevent infinite loops.
consecutive_validation_failures: int = 0
max_validation_retries: int = 3
while function_call_state and iteration_step <= max_iterations:
function_call_state = False
round_log = self._create_log(
label=f"ROUND {iteration_step}",
log_type=AgentLog.LogType.ROUND,
status=AgentLog.LogStatus.START,
data={},
)
yield round_log
# On last iteration, remove tools to force final answer
current_tools: list[PromptMessageTool] = [] if iteration_step == max_iterations else prompt_tools
model_log = self._create_log(
label=f"{self.model_instance.model_name} Thought",
log_type=AgentLog.LogType.THOUGHT,
status=AgentLog.LogStatus.START,
data={},
parent_id=round_log.id,
extra_metadata={
AgentLog.LogMetadata.PROVIDER: self.model_instance.provider,
},
)
yield model_log
# Track usage for this round only
round_usage: dict[str, LLMUsage | None] = {"usage": None}
# Invoke model
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = self.model_instance.invoke_llm(
prompt_messages=messages,
model_parameters=model_parameters,
tools=current_tools,
stop=stop,
stream=stream,
callbacks=[],
)
# Process response
tool_calls, response_content, chunk_finish_reason = yield from self._handle_chunks(
chunks, round_usage, model_log
)
messages.append(self._create_assistant_message(response_content, tool_calls))
# Accumulate to total usage
round_usage_value = round_usage.get("usage")
if round_usage_value:
self._accumulate_usage(total_usage, round_usage_value)
# Update final text if no tool calls (this is likely the final answer)
if not tool_calls:
final_text = response_content
# Update finish reason
if chunk_finish_reason:
finish_reason = chunk_finish_reason
# Process tool calls
tool_outputs: dict[str, str] = {}
all_validation_errors: bool = True
if tool_calls:
function_call_state = True
# Execute tools (with pre-execution parameter validation)
for tool_call_id, tool_name, tool_args in tool_calls:
tool_response, tool_files, _, is_validation_error = yield from self._handle_tool_call(
tool_name, tool_args, tool_call_id, messages, round_log
)
tool_outputs[tool_name] = tool_response
output_files.extend(tool_files)
if not is_validation_error:
all_validation_errors = False
else:
all_validation_errors = False
yield self._finish_log(
round_log,
data={
"llm_result": response_content,
"tool_calls": [
{"name": tc[1], "args": tc[2], "output": tool_outputs.get(tc[1], "")} for tc in tool_calls
]
if tool_calls
else [],
"final_answer": final_text if not function_call_state else None,
},
usage=round_usage.get("usage"),
)
# Skip iteration counter when every tool call in this round failed validation,
# giving the model a free retry — but cap retries to prevent infinite loops.
if tool_calls and all_validation_errors:
consecutive_validation_failures += 1
if consecutive_validation_failures >= max_validation_retries:
logger.warning(
"Agent hit %d consecutive validation-only rounds, forcing iteration increment",
consecutive_validation_failures,
)
iteration_step += 1
consecutive_validation_failures = 0
else:
logger.info(
"All tool calls failed validation (attempt %d/%d), not counting iteration",
consecutive_validation_failures,
max_validation_retries,
)
else:
consecutive_validation_failures = 0
iteration_step += 1
# Return final result
from core.agent.entities import AgentResult
return AgentResult(
text=final_text,
files=output_files,
usage=total_usage.get("usage") or LLMUsage.empty_usage(),
finish_reason=finish_reason,
)
def _handle_chunks(
self,
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult],
llm_usage: dict[str, LLMUsage | None],
start_log: AgentLog,
) -> Generator[
LLMResultChunk | AgentLog,
None,
tuple[list[tuple[str, str, dict[str, Any]]], str, str | None],
]:
"""Handle LLM response chunks and extract tool calls and content.
Returns a tuple of (tool_calls, response_content, finish_reason).
"""
tool_calls: list[tuple[str, str, dict[str, Any]]] = []
response_content: str = ""
finish_reason: str | None = None
if not isinstance(chunks, LLMResult):
# Streaming response
for chunk in chunks:
# Extract tool calls
if self._has_tool_calls(chunk):
tool_calls.extend(self._extract_tool_calls(chunk))
# Extract content
if chunk.delta.message and chunk.delta.message.content:
response_content += self._extract_content(chunk.delta.message.content)
# Track usage
if chunk.delta.usage:
self._accumulate_usage(llm_usage, chunk.delta.usage)
# Capture finish reason
if chunk.delta.finish_reason:
finish_reason = chunk.delta.finish_reason
yield chunk
else:
# Non-streaming response
result: LLMResult = chunks
if self._has_tool_calls_result(result):
tool_calls.extend(self._extract_tool_calls_result(result))
if result.message and result.message.content:
response_content += self._extract_content(result.message.content)
if result.usage:
self._accumulate_usage(llm_usage, result.usage)
# Convert to streaming format
yield LLMResultChunk(
model=result.model,
prompt_messages=result.prompt_messages,
delta=LLMResultChunkDelta(index=0, message=result.message, usage=result.usage),
)
yield self._finish_log(
start_log,
data={
"result": response_content,
},
usage=llm_usage.get("usage"),
)
return tool_calls, response_content, finish_reason
def _create_assistant_message(
self, content: str, tool_calls: list[tuple[str, str, dict[str, Any]]] | None = None
) -> AssistantPromptMessage:
"""Create assistant message with tool calls."""
if tool_calls is None:
return AssistantPromptMessage(content=content)
return AssistantPromptMessage(
content=content or "",
tool_calls=[
AssistantPromptMessage.ToolCall(
id=tc[0],
type="function",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tc[1], arguments=json.dumps(tc[2])),
)
for tc in tool_calls
],
)
def _handle_tool_call(
self,
tool_name: str,
tool_args: dict[str, Any],
tool_call_id: str,
messages: list[PromptMessage],
round_log: AgentLog,
) -> Generator[AgentLog, None, tuple[str, list[File], ToolInvokeMeta | None, bool]]:
"""Handle a single tool call and return response with files, meta, and validation status.
Validates required parameters before execution. When validation fails the tool
is never invoked — a synthetic error is fed back to the model so it can self-correct
without consuming a real iteration.
Returns:
(response_content, tool_files, tool_invoke_meta, is_validation_error).
``is_validation_error`` is True when the call was rejected due to missing
required parameters, allowing the caller to skip the iteration counter.
"""
# Find tool
tool_instance = self._find_tool_by_name(tool_name)
if not tool_instance:
raise ValueError(f"Tool {tool_name} not found")
# Get tool metadata (provider, icon, etc.)
tool_metadata = self._get_tool_metadata(tool_instance)
# Create tool call log
tool_call_log = self._create_log(
label=f"CALL {tool_name}",
log_type=AgentLog.LogType.TOOL_CALL,
status=AgentLog.LogStatus.START,
data={
"tool_call_id": tool_call_id,
"tool_name": tool_name,
"tool_args": tool_args,
},
parent_id=round_log.id,
extra_metadata=tool_metadata,
)
yield tool_call_log
# Validate required parameters before execution to avoid wasted invocations
validation_error = self._validate_tool_args(tool_instance, tool_args)
if validation_error:
tool_call_log.status = AgentLog.LogStatus.ERROR
tool_call_log.error = validation_error
tool_call_log.data = {**tool_call_log.data, "error": validation_error}
yield tool_call_log
messages.append(ToolPromptMessage(content=validation_error, tool_call_id=tool_call_id, name=tool_name))
return validation_error, [], None, True
# Invoke tool using base class method with error handling
try:
response_content, tool_files, tool_invoke_meta = self._invoke_tool(tool_instance, tool_args, tool_name)
yield self._finish_log(
tool_call_log,
data={
**tool_call_log.data,
"output": response_content,
"files": len(tool_files),
"meta": tool_invoke_meta.to_dict() if tool_invoke_meta else None,
},
)
final_content = response_content or "Tool executed successfully"
# Add tool response to messages
messages.append(
ToolPromptMessage(
content=final_content,
tool_call_id=tool_call_id,
name=tool_name,
)
)
return response_content, tool_files, tool_invoke_meta, False
except Exception as e:
# Tool invocation failed, yield error log
error_message = str(e)
tool_call_log.status = AgentLog.LogStatus.ERROR
tool_call_log.error = error_message
tool_call_log.data = {
**tool_call_log.data,
"error": error_message,
}
yield tool_call_log
# Add error message to conversation
error_content = f"Tool execution failed: {error_message}"
messages.append(
ToolPromptMessage(
content=error_content,
tool_call_id=tool_call_id,
name=tool_name,
)
)
return error_content, [], None, False

View File

@ -1,418 +0,0 @@
"""ReAct strategy implementation."""
from __future__ import annotations
import json
from collections.abc import Generator
from typing import TYPE_CHECKING, Any, Union
from core.agent.entities import AgentLog, AgentResult, AgentScratchpadUnit, ExecutionContext
from core.agent.output_parser.cot_output_parser import CotAgentOutputParser
from core.model_manager import ModelInstance
from graphon.file import File
from graphon.model_runtime.entities import (
AssistantPromptMessage,
LLMResult,
LLMResultChunk,
LLMResultChunkDelta,
PromptMessage,
SystemPromptMessage,
)
from .base import AgentPattern, ToolInvokeHook
if TYPE_CHECKING:
from core.tools.__base.tool import Tool
class ReActStrategy(AgentPattern):
"""ReAct strategy using reasoning and acting approach."""
def __init__(
self,
model_instance: ModelInstance,
tools: list[Tool],
context: ExecutionContext,
max_iterations: int = 10,
workflow_call_depth: int = 0,
files: list[File] = [],
tool_invoke_hook: ToolInvokeHook | None = None,
instruction: str = "",
):
"""Initialize the ReAct strategy with instruction support."""
super().__init__(
model_instance=model_instance,
tools=tools,
context=context,
max_iterations=max_iterations,
workflow_call_depth=workflow_call_depth,
files=files,
tool_invoke_hook=tool_invoke_hook,
)
self.instruction = instruction
def run(
self,
prompt_messages: list[PromptMessage],
model_parameters: dict[str, Any],
stop: list[str] = [],
stream: bool = True,
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
"""Execute the ReAct agent strategy."""
# Initialize tracking
agent_scratchpad: list[AgentScratchpadUnit] = []
iteration_step: int = 1
max_iterations: int = self.max_iterations + 1
react_state: bool = True
total_usage: dict[str, Any] = {"usage": None}
output_files: list[File] = [] # Track files produced by tools
final_text: str = ""
finish_reason: str | None = None
# Add "Observation" to stop sequences
if "Observation" not in stop:
stop = stop.copy()
stop.append("Observation")
while react_state and iteration_step <= max_iterations:
react_state = False
round_log = self._create_log(
label=f"ROUND {iteration_step}",
log_type=AgentLog.LogType.ROUND,
status=AgentLog.LogStatus.START,
data={},
)
yield round_log
# Build prompt with/without tools based on iteration
include_tools = iteration_step < max_iterations
current_messages = self._build_prompt_with_react_format(
prompt_messages, agent_scratchpad, include_tools, self.instruction
)
model_log = self._create_log(
label=f"{self.model_instance.model_name} Thought",
log_type=AgentLog.LogType.THOUGHT,
status=AgentLog.LogStatus.START,
data={},
parent_id=round_log.id,
extra_metadata={
AgentLog.LogMetadata.PROVIDER: self.model_instance.provider,
},
)
yield model_log
# Track usage for this round only
round_usage: dict[str, Any] = {"usage": None}
# Use current messages directly (files are handled by base class if needed)
messages_to_use = current_messages
# Invoke model
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = self.model_instance.invoke_llm(
prompt_messages=messages_to_use,
model_parameters=model_parameters,
stop=stop,
stream=stream,
callbacks=[],
)
# Process response
scratchpad, chunk_finish_reason = yield from self._handle_chunks(
chunks, round_usage, model_log, current_messages
)
agent_scratchpad.append(scratchpad)
# Accumulate to total usage
round_usage_value = round_usage.get("usage")
if round_usage_value:
self._accumulate_usage(total_usage, round_usage_value)
# Update finish reason
if chunk_finish_reason:
finish_reason = chunk_finish_reason
# Check if we have an action to execute
if scratchpad.action and scratchpad.action.action_name.lower() != "final answer":
react_state = True
# Execute tool
observation, tool_files = yield from self._handle_tool_call(
scratchpad.action, current_messages, round_log
)
scratchpad.observation = observation
# Track files produced by tools
output_files.extend(tool_files)
# Add observation to scratchpad for display
yield self._create_text_chunk(f"\nObservation: {observation}\n", current_messages)
else:
# Extract final answer
if scratchpad.action and scratchpad.action.action_input:
final_answer = scratchpad.action.action_input
if isinstance(final_answer, dict):
final_answer = json.dumps(final_answer, ensure_ascii=False)
final_text = str(final_answer)
elif scratchpad.thought:
# If no action but we have thought, use thought as final answer
final_text = scratchpad.thought
yield self._finish_log(
round_log,
data={
"thought": scratchpad.thought,
"action": scratchpad.action_str if scratchpad.action else None,
"observation": scratchpad.observation or None,
"final_answer": final_text if not react_state else None,
},
usage=round_usage.get("usage"),
)
iteration_step += 1
# Return final result
from core.agent.entities import AgentResult
return AgentResult(
text=final_text, files=output_files, usage=total_usage.get("usage"), finish_reason=finish_reason
)
def _build_prompt_with_react_format(
self,
original_messages: list[PromptMessage],
agent_scratchpad: list[AgentScratchpadUnit],
include_tools: bool = True,
instruction: str = "",
) -> list[PromptMessage]:
"""Build prompt messages with ReAct format."""
# Copy messages to avoid modifying original
messages = list(original_messages)
# Find and update the system prompt that should already exist
system_prompt_found = False
for i, msg in enumerate(messages):
if isinstance(msg, SystemPromptMessage):
system_prompt_found = True
# The system prompt from frontend already has the template, just replace placeholders
# Format tools
tools_str = ""
tool_names = []
if include_tools and self.tools:
# Convert tools to prompt message tools format
prompt_tools = [tool.to_prompt_message_tool() for tool in self.tools]
tool_names = [tool.name for tool in prompt_tools]
# Format tools as JSON for comprehensive information
from graphon.model_runtime.utils.encoders import jsonable_encoder
tools_str = json.dumps(jsonable_encoder(prompt_tools), indent=2)
tool_names_str = ", ".join(f'"{name}"' for name in tool_names)
else:
tools_str = "No tools available"
tool_names_str = ""
# Replace placeholders in the existing system prompt
updated_content = msg.content
assert isinstance(updated_content, str)
updated_content = updated_content.replace("{{instruction}}", instruction)
updated_content = updated_content.replace("{{tools}}", tools_str)
updated_content = updated_content.replace("{{tool_names}}", tool_names_str)
# Create new SystemPromptMessage with updated content
messages[i] = SystemPromptMessage(content=updated_content)
break
# If no system prompt found, that's unexpected but add scratchpad anyway
if not system_prompt_found:
# This shouldn't happen if frontend is working correctly
pass
# Format agent scratchpad
scratchpad_str = ""
if agent_scratchpad:
scratchpad_parts: list[str] = []
for unit in agent_scratchpad:
if unit.thought:
scratchpad_parts.append(f"Thought: {unit.thought}")
if unit.action_str:
scratchpad_parts.append(f"Action:\n```\n{unit.action_str}\n```")
if unit.observation:
scratchpad_parts.append(f"Observation: {unit.observation}")
scratchpad_str = "\n".join(scratchpad_parts)
# If there's a scratchpad, append it to the last message
if scratchpad_str:
messages.append(AssistantPromptMessage(content=scratchpad_str))
return messages
def _handle_chunks(
self,
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult],
llm_usage: dict[str, Any],
model_log: AgentLog,
current_messages: list[PromptMessage],
) -> Generator[
LLMResultChunk | AgentLog,
None,
tuple[AgentScratchpadUnit, str | None],
]:
"""Handle LLM response chunks and extract action/thought.
Returns a tuple of (scratchpad_unit, finish_reason).
"""
usage_dict: dict[str, Any] = {}
# Convert non-streaming to streaming format if needed
if isinstance(chunks, LLMResult):
result = chunks
def result_to_chunks() -> Generator[LLMResultChunk, None, None]:
yield LLMResultChunk(
model=result.model,
prompt_messages=result.prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=result.message,
usage=result.usage,
finish_reason=None,
),
system_fingerprint=result.system_fingerprint or "",
)
streaming_chunks = result_to_chunks()
else:
streaming_chunks = chunks
react_chunks = CotAgentOutputParser.handle_react_stream_output(streaming_chunks, usage_dict)
# Initialize scratchpad unit
scratchpad = AgentScratchpadUnit(
agent_response="",
thought="",
action_str="",
observation="",
action=None,
)
finish_reason: str | None = None
# Process chunks
for chunk in react_chunks:
if isinstance(chunk, AgentScratchpadUnit.Action):
# Action detected
action_str = json.dumps(chunk.model_dump())
scratchpad.agent_response = (scratchpad.agent_response or "") + action_str
scratchpad.action_str = action_str
scratchpad.action = chunk
yield self._create_text_chunk(json.dumps(chunk.model_dump()), current_messages)
else:
# Text chunk
chunk_text = str(chunk)
scratchpad.agent_response = (scratchpad.agent_response or "") + chunk_text
scratchpad.thought = (scratchpad.thought or "") + chunk_text
yield self._create_text_chunk(chunk_text, current_messages)
# Update usage
if usage_dict.get("usage"):
if llm_usage.get("usage"):
self._accumulate_usage(llm_usage, usage_dict["usage"])
else:
llm_usage["usage"] = usage_dict["usage"]
# Clean up thought
scratchpad.thought = (scratchpad.thought or "").strip() or "I am thinking about how to help you"
# Finish model log
yield self._finish_log(
model_log,
data={
"thought": scratchpad.thought,
"action": scratchpad.action_str if scratchpad.action else None,
},
usage=llm_usage.get("usage"),
)
return scratchpad, finish_reason
def _handle_tool_call(
self,
action: AgentScratchpadUnit.Action,
prompt_messages: list[PromptMessage],
round_log: AgentLog,
) -> Generator[AgentLog, None, tuple[str, list[File]]]:
"""Handle tool call and return observation with files."""
tool_name = action.action_name
tool_args: dict[str, Any] | str = action.action_input
# Find tool instance first to get metadata
tool_instance = self._find_tool_by_name(tool_name)
tool_metadata = self._get_tool_metadata(tool_instance) if tool_instance else {}
# Start tool log with tool metadata
tool_log = self._create_log(
label=f"CALL {tool_name}",
log_type=AgentLog.LogType.TOOL_CALL,
status=AgentLog.LogStatus.START,
data={
"tool_name": tool_name,
"tool_args": tool_args,
},
parent_id=round_log.id,
extra_metadata=tool_metadata,
)
yield tool_log
if not tool_instance:
# Finish tool log with error
yield self._finish_log(
tool_log,
data={
**tool_log.data,
"error": f"Tool {tool_name} not found",
},
)
return f"Tool {tool_name} not found", []
# Ensure tool_args is a dict
tool_args_dict: dict[str, Any]
if isinstance(tool_args, str):
try:
tool_args_dict = json.loads(tool_args)
except json.JSONDecodeError:
tool_args_dict = {"input": tool_args}
elif not isinstance(tool_args, dict):
tool_args_dict = {"input": str(tool_args)}
else:
tool_args_dict = tool_args
# Invoke tool using base class method with error handling
try:
response_content, tool_files, tool_invoke_meta = self._invoke_tool(tool_instance, tool_args_dict, tool_name)
# Finish tool log
yield self._finish_log(
tool_log,
data={
**tool_log.data,
"output": response_content,
"files": len(tool_files),
"meta": tool_invoke_meta.to_dict() if tool_invoke_meta else None,
},
)
return response_content or "Tool executed successfully", tool_files
except Exception as e:
# Tool invocation failed, yield error log
error_message = str(e)
tool_log.status = AgentLog.LogStatus.ERROR
tool_log.error = error_message
tool_log.data = {
**tool_log.data,
"error": error_message,
}
yield tool_log
return f"Tool execution failed: {error_message}", []

View File

@ -1,108 +0,0 @@
"""Strategy factory for creating agent strategies."""
from __future__ import annotations
from typing import TYPE_CHECKING
from core.agent.entities import AgentEntity, ExecutionContext
from core.model_manager import ModelInstance
from graphon.file.models import File
from graphon.model_runtime.entities.model_entities import ModelFeature
from .base import AgentPattern, ToolInvokeHook
from .function_call import FunctionCallStrategy
from .react import ReActStrategy
if TYPE_CHECKING:
from core.tools.__base.tool import Tool
class StrategyFactory:
"""Factory for creating agent strategies based on model features."""
# Tool calling related features
TOOL_CALL_FEATURES = {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL}
@staticmethod
def create_strategy(
model_features: list[ModelFeature],
model_instance: ModelInstance,
context: ExecutionContext,
tools: list[Tool],
files: list[File],
max_iterations: int = 10,
workflow_call_depth: int = 0,
agent_strategy: AgentEntity.Strategy | None = None,
tool_invoke_hook: ToolInvokeHook | None = None,
instruction: str = "",
) -> AgentPattern:
"""
Create an appropriate strategy based on model features.
Args:
model_features: List of model features/capabilities
model_instance: Model instance to use
context: Execution context containing trace/audit information
tools: Available tools
files: Available files
max_iterations: Maximum iterations for the strategy
workflow_call_depth: Depth of workflow calls
agent_strategy: Optional explicit strategy override
tool_invoke_hook: Optional hook for custom tool invocation (e.g., agent_invoke)
instruction: Optional instruction for ReAct strategy
Returns:
AgentStrategy instance
"""
# If explicit strategy is provided and it's Function Calling, try to use it if supported
if agent_strategy == AgentEntity.Strategy.FUNCTION_CALLING:
if set(model_features) & StrategyFactory.TOOL_CALL_FEATURES:
return FunctionCallStrategy(
model_instance=model_instance,
context=context,
tools=tools,
files=files,
max_iterations=max_iterations,
workflow_call_depth=workflow_call_depth,
tool_invoke_hook=tool_invoke_hook,
)
# Fallback to ReAct if FC is requested but not supported
# If explicit strategy is Chain of Thought (ReAct)
if agent_strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
return ReActStrategy(
model_instance=model_instance,
context=context,
tools=tools,
files=files,
max_iterations=max_iterations,
workflow_call_depth=workflow_call_depth,
tool_invoke_hook=tool_invoke_hook,
instruction=instruction,
)
# Default auto-selection logic
if set(model_features) & StrategyFactory.TOOL_CALL_FEATURES:
# Model supports native function calling
return FunctionCallStrategy(
model_instance=model_instance,
context=context,
tools=tools,
files=files,
max_iterations=max_iterations,
workflow_call_depth=workflow_call_depth,
tool_invoke_hook=tool_invoke_hook,
)
else:
# Use ReAct strategy for models without function calling
return ReActStrategy(
model_instance=model_instance,
context=context,
tools=tools,
files=files,
max_iterations=max_iterations,
workflow_call_depth=workflow_call_depth,
tool_invoke_hook=tool_invoke_hook,
instruction=instruction,
)

View File

@ -84,7 +84,7 @@ class AgentStrategyEntity(BaseModel):
identity: AgentStrategyIdentity
parameters: list[AgentStrategyParameter] = Field(default_factory=list)
description: I18nObject = Field(..., description="The description of the agent strategy")
output_schema: dict | None = None
output_schema: dict[str, Any] | None = None
features: list[AgentFeature] | None = None
meta_version: str | None = None
# pydantic configs

View File

@ -22,8 +22,8 @@ class SensitiveWordAvoidanceConfigManager:
@classmethod
def validate_and_set_defaults(
cls, tenant_id: str, config: dict, only_structure_validate: bool = False
) -> tuple[dict, list[str]]:
cls, tenant_id: str, config: dict[str, Any], only_structure_validate: bool = False
) -> tuple[dict[str, Any], list[str]]:
if not config.get("sensitive_word_avoidance"):
config["sensitive_word_avoidance"] = {"enabled": False}

View File

@ -138,7 +138,9 @@ class DatasetConfigManager:
)
@classmethod
def validate_and_set_defaults(cls, tenant_id: str, app_mode: AppMode, config: dict) -> tuple[dict, list[str]]:
def validate_and_set_defaults(
cls, tenant_id: str, app_mode: AppMode, config: dict[str, Any]
) -> tuple[dict[str, Any], list[str]]:
"""
Validate and set defaults for dataset feature
@ -172,7 +174,7 @@ class DatasetConfigManager:
return config, ["agent_mode", "dataset_configs", "dataset_query_variable"]
@classmethod
def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mode: AppMode, config: dict):
def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mode: AppMode, config: dict[str, Any]):
"""
Extract dataset config for legacy compatibility

View File

@ -41,7 +41,7 @@ class ModelConfigManager:
)
@classmethod
def validate_and_set_defaults(cls, tenant_id: str, config: Mapping[str, Any]) -> tuple[dict, list[str]]:
def validate_and_set_defaults(cls, tenant_id: str, config: Mapping[str, Any]) -> tuple[dict[str, Any], list[str]]:
"""
Validate and set defaults for model config
@ -108,7 +108,7 @@ class ModelConfigManager:
return dict(config), ["model"]
@classmethod
def validate_model_completion_params(cls, cp: dict):
def validate_model_completion_params(cls, cp: dict[str, Any]):
# model.completion_params
if not isinstance(cp, dict):
raise ValueError("model.completion_params must be of object type")

View File

@ -65,7 +65,7 @@ class PromptTemplateConfigManager:
)
@classmethod
def validate_and_set_defaults(cls, app_mode: AppMode, config: dict) -> tuple[dict, list[str]]:
def validate_and_set_defaults(cls, app_mode: AppMode, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
"""
Validate pre_prompt and set defaults for prompt feature
depending on the config['model']
@ -130,7 +130,7 @@ class PromptTemplateConfigManager:
return config, ["prompt_type", "pre_prompt", "chat_prompt_config", "completion_prompt_config"]
@classmethod
def validate_post_prompt_and_set_defaults(cls, config: dict):
def validate_post_prompt_and_set_defaults(cls, config: dict[str, Any]):
"""
Validate post_prompt and set defaults for prompt feature

View File

@ -1,5 +1,5 @@
import re
from typing import cast
from typing import Any, cast
from graphon.variables.input_entities import VariableEntity, VariableEntityType
@ -82,7 +82,7 @@ class BasicVariablesConfigManager:
return variable_entities, external_data_variables
@classmethod
def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
def validate_and_set_defaults(cls, tenant_id: str, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
"""
Validate and set defaults for user input form
@ -99,7 +99,7 @@ class BasicVariablesConfigManager:
return config, related_config_keys
@classmethod
def validate_variables_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
def validate_variables_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
"""
Validate and set defaults for user input form
@ -164,7 +164,9 @@ class BasicVariablesConfigManager:
return config, ["user_input_form"]
@classmethod
def validate_external_data_tools_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
def validate_external_data_tools_and_set_defaults(
cls, tenant_id: str, config: dict[str, Any]
) -> tuple[dict[str, Any], list[str]]:
"""
Validate and set defaults for external data fetch feature

View File

@ -30,7 +30,7 @@ class FileUploadConfigManager:
return FileUploadConfig.model_validate(file_upload_dict)
@classmethod
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
"""
Validate and set defaults for file upload feature

View File

@ -1,3 +1,5 @@
from typing import Any
from pydantic import BaseModel, ConfigDict, Field, ValidationError
@ -13,7 +15,7 @@ class AppConfigModel(BaseModel):
class MoreLikeThisConfigManager:
@classmethod
def convert(cls, config: dict) -> bool:
def convert(cls, config: dict[str, Any]) -> bool:
"""
Convert model config to model config
@ -23,7 +25,7 @@ class MoreLikeThisConfigManager:
return AppConfigModel.model_validate(validated_config).more_like_this.enabled
@classmethod
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
try:
return AppConfigModel.model_validate(config).model_dump(), ["more_like_this"]
except ValidationError:

View File

@ -1,6 +1,9 @@
from typing import Any
class OpeningStatementConfigManager:
@classmethod
def convert(cls, config: dict) -> tuple[str, list]:
def convert(cls, config: dict[str, Any]) -> tuple[str, list[str]]:
"""
Convert model config to model config
@ -15,7 +18,7 @@ class OpeningStatementConfigManager:
return opening_statement, suggested_questions_list
@classmethod
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
"""
Validate and set defaults for opening statement feature

View File

@ -1,6 +1,9 @@
from typing import Any
class RetrievalResourceConfigManager:
@classmethod
def convert(cls, config: dict) -> bool:
def convert(cls, config: dict[str, Any]) -> bool:
show_retrieve_source = False
retriever_resource_dict = config.get("retriever_resource")
if retriever_resource_dict:
@ -10,7 +13,7 @@ class RetrievalResourceConfigManager:
return show_retrieve_source
@classmethod
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
"""
Validate and set defaults for retriever resource feature

View File

@ -1,6 +1,9 @@
from typing import Any
class SpeechToTextConfigManager:
@classmethod
def convert(cls, config: dict) -> bool:
def convert(cls, config: dict[str, Any]) -> bool:
"""
Convert model config to model config
@ -15,7 +18,7 @@ class SpeechToTextConfigManager:
return speech_to_text
@classmethod
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
"""
Validate and set defaults for speech to text feature

View File

@ -1,6 +1,9 @@
from typing import Any
class SuggestedQuestionsAfterAnswerConfigManager:
@classmethod
def convert(cls, config: dict) -> bool:
def convert(cls, config: dict[str, Any]) -> bool:
"""
Convert model config to model config
@ -15,7 +18,7 @@ class SuggestedQuestionsAfterAnswerConfigManager:
return suggested_questions_after_answer
@classmethod
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
"""
Validate and set defaults for suggested questions feature

View File

@ -1,9 +1,11 @@
from typing import Any
from core.app.app_config.entities import TextToSpeechEntity
class TextToSpeechConfigManager:
@classmethod
def convert(cls, config: dict):
def convert(cls, config: dict[str, Any]):
"""
Convert model config to model config
@ -22,7 +24,7 @@ class TextToSpeechConfigManager:
return text_to_speech
@classmethod
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
"""
Validate and set defaults for text to speech feature

View File

@ -177,14 +177,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# always enable retriever resource in debugger mode
app_config.additional_features.show_retrieve_source = True # type: ignore
# Resolve parent_message_id for thread continuity
if invoke_from == InvokeFrom.SERVICE_API:
parent_message_id: str | None = UUID_NIL
else:
parent_message_id = args.get("parent_message_id")
if not parent_message_id and conversation:
parent_message_id = self._resolve_latest_message_id(conversation.id)
# init application generate entity
application_generate_entity = AdvancedChatAppGenerateEntity(
task_id=str(uuid.uuid4()),
@ -196,7 +188,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
),
query=query,
files=list(file_objs),
parent_message_id=parent_message_id,
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
user_id=user.id,
stream=streaming,
invoke_from=invoke_from,
@ -697,17 +689,3 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
else:
logger.exception("Failed to process generate task pipeline, conversation_id: %s", conversation.id)
raise e
@staticmethod
def _resolve_latest_message_id(conversation_id: str) -> str | None:
"""Auto-resolve parent_message_id to the latest message when client doesn't provide one."""
from sqlalchemy import select
stmt = (
select(Message.id)
.where(Message.conversation_id == conversation_id)
.order_by(Message.created_at.desc())
.limit(1)
)
latest_id = db.session.scalar(stmt)
return str(latest_id) if latest_id else None

View File

@ -57,7 +57,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, Any, None]:
) -> Generator[dict[str, Any] | str, Any, None]:
"""
Convert stream full response.
:param stream_response: stream response
@ -88,7 +88,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, Any, None]:
) -> Generator[dict[str, Any] | str, Any, None]:
"""
Convert stream simple response.
:param stream_response: stream response

View File

@ -1,12 +1,15 @@
import logging
from typing import cast
from graphon.model_runtime.entities.model_entities import ModelFeature
from graphon.model_runtime.entities.llm_entities import LLMMode
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from sqlalchemy import select
from core.agent.agent_app_runner import AgentAppRunner
from core.agent.cot_chat_agent_runner import CotChatAgentRunner
from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner
from core.agent.entities import AgentEntity
from core.agent.fc_agent_runner import FunctionCallAgentRunner
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner
@ -189,8 +192,24 @@ class AgentChatAppRunner(AppRunner):
message_result = db.session.scalar(msg_stmt)
if message_result is None:
raise ValueError("Message not found")
db.session.close()
runner = AgentAppRunner(
runner_cls: type[FunctionCallAgentRunner] | type[CotChatAgentRunner] | type[CotCompletionAgentRunner]
# start agent runner
if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
# check LLM mode
if model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT:
runner_cls = CotChatAgentRunner
elif model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.COMPLETION:
runner_cls = CotCompletionAgentRunner
else:
raise ValueError(f"Invalid LLM mode: {model_schema.model_properties.get(ModelPropertyKey.MODE)}")
elif agent_entity.strategy == AgentEntity.Strategy.FUNCTION_CALLING:
runner_cls = FunctionCallAgentRunner
else:
raise ValueError(f"Invalid agent strategy: {agent_entity.strategy}")
runner = runner_cls(
tenant_id=app_config.tenant_id,
application_generate_entity=application_generate_entity,
conversation=conversation_result,

View File

@ -1,5 +1,5 @@
from collections.abc import Generator
from typing import cast
from typing import Any, cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
@ -56,7 +56,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
) -> Generator[dict[str, Any] | str, None, None]:
"""
Convert stream full response.
:param stream_response: stream response
@ -87,7 +87,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
) -> Generator[dict[str, Any] | str, None, None]:
"""
Convert stream simple response.
:param stream_response: stream response

View File

@ -24,7 +24,7 @@ class AppGenerateResponseConverter(ABC):
return cls.convert_blocking_full_response(response)
else:
def _generate_full_response() -> Generator[dict | str, Any, None]:
def _generate_full_response() -> Generator[dict[str, Any] | str, Any, None]:
yield from cls.convert_stream_full_response(response)
return _generate_full_response()
@ -33,7 +33,7 @@ class AppGenerateResponseConverter(ABC):
return cls.convert_blocking_simple_response(response)
else:
def _generate_simple_response() -> Generator[dict | str, Any, None]:
def _generate_simple_response() -> Generator[dict[str, Any] | str, Any, None]:
yield from cls.convert_stream_simple_response(response)
return _generate_simple_response()
@ -52,14 +52,14 @@ class AppGenerateResponseConverter(ABC):
@abstractmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
) -> Generator[dict[str, Any] | str, None, None]:
raise NotImplementedError
@classmethod
@abstractmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
) -> Generator[dict[str, Any] | str, None, None]:
raise NotImplementedError
@classmethod

View File

@ -1,5 +1,5 @@
from collections.abc import Generator
from typing import cast
from typing import Any, cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
@ -56,7 +56,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
) -> Generator[dict[str, Any] | str, None, None]:
"""
Convert stream full response.
:param stream_response: stream response
@ -87,7 +87,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
) -> Generator[dict[str, Any] | str, None, None]:
"""
Convert stream simple response.
:param stream_response: stream response

View File

@ -1,53 +0,0 @@
"""Legacy Response Adapter for transparent upgrade.
When old apps (chat/completion/agent-chat) run through the Agent V2
workflow engine via transparent upgrade, the SSE events are in workflow
format (workflow_started, node_started, etc.). This adapter filters out
workflow-specific events and passes through only the events that old
clients expect (message, message_end, etc.).
"""
from __future__ import annotations
import json
import logging
from collections.abc import Generator
logger = logging.getLogger(__name__)
WORKFLOW_ONLY_EVENTS = frozenset({
"workflow_started",
"workflow_finished",
"node_started",
"node_finished",
"iteration_started",
"iteration_next",
"iteration_completed",
})
def adapt_workflow_stream_for_legacy(
stream: Generator[str, None, None],
) -> Generator[str, None, None]:
"""Filter workflow-specific SSE events from a streaming response.
Passes through message, message_end, agent_log, error, ping events.
Suppresses workflow_started, workflow_finished, node_started, node_finished.
This makes the SSE stream look more like what old easy-UI apps produce,
while still carrying the actual LLM response content.
"""
for chunk in stream:
if not chunk or not chunk.strip():
yield chunk
continue
try:
if chunk.startswith("data: "):
data = json.loads(chunk[6:])
event = data.get("event", "")
if event in WORKFLOW_ONLY_EVENTS:
continue
yield chunk
except (json.JSONDecodeError, TypeError):
yield chunk

View File

@ -1,5 +1,5 @@
from collections.abc import Generator
from typing import cast
from typing import Any, cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
@ -55,7 +55,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
) -> Generator[dict[str, Any] | str, None, None]:
"""
Convert stream full response.
:param stream_response: stream response
@ -85,7 +85,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
) -> Generator[dict[str, Any] | str, None, None]:
"""
Convert stream simple response.
:param stream_response: stream response

View File

@ -1,5 +1,5 @@
from collections.abc import Generator
from typing import cast
from typing import Any, cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
@ -17,7 +17,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = WorkflowAppBlockingResponse
@classmethod
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, Any]: # type: ignore[override]
"""
Convert blocking full response.
:param blocking_response: blocking response
@ -26,7 +26,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
return dict(blocking_response.model_dump())
@classmethod
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, Any]: # type: ignore[override]
"""
Convert blocking simple response.
:param blocking_response: blocking response
@ -37,7 +37,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
) -> Generator[dict[str, Any] | str, None, None]:
"""
Convert stream full response.
:param stream_response: stream response
@ -66,7 +66,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
) -> Generator[dict[str, Any] | str, None, None]:
"""
Convert stream simple response.
:param stream_response: stream response

View File

@ -1,3 +1,5 @@
from typing import Any
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
from core.app.app_config.entities import RagPipelineVariableEntity, WorkflowUIBasedAppConfig
@ -34,7 +36,9 @@ class PipelineConfigManager(BaseAppConfigManager):
return pipeline_config
@classmethod
def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict:
def config_validate(
cls, tenant_id: str, config: dict[str, Any], only_structure_validate: bool = False
) -> dict[str, Any]:
"""
Validate for pipeline config

View File

@ -782,7 +782,7 @@ class PipelineGenerator(BaseAppGenerator):
user_id: str,
all_files: list,
datasource_info: Mapping[str, Any],
next_page_parameters: dict | None = None,
next_page_parameters: dict[str, Any] | None = None,
):
"""
Get files in a folder.

View File

@ -1,5 +1,5 @@
from collections.abc import Generator
from typing import cast
from typing import Any, cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
@ -37,7 +37,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
) -> Generator[dict[str, Any] | str, None, None]:
"""
Convert stream full response.
:param stream_response: stream response
@ -66,7 +66,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
) -> Generator[dict[str, Any] | str, None, None]:
"""
Convert stream simple response.
:param stream_response: stream response

View File

@ -682,15 +682,16 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
def _save_workflow_app_log(self, *, session: Session, workflow_run_id: str | None):
invoke_from = self._application_generate_entity.invoke_from
if invoke_from == InvokeFrom.SERVICE_API:
created_from = WorkflowAppLogCreatedFrom.SERVICE_API
elif invoke_from == InvokeFrom.EXPLORE:
created_from = WorkflowAppLogCreatedFrom.INSTALLED_APP
elif invoke_from == InvokeFrom.WEB_APP:
created_from = WorkflowAppLogCreatedFrom.WEB_APP
else:
# not save log for debugging
return
match invoke_from:
case InvokeFrom.SERVICE_API:
created_from = WorkflowAppLogCreatedFrom.SERVICE_API
case InvokeFrom.EXPLORE:
created_from = WorkflowAppLogCreatedFrom.INSTALLED_APP
case InvokeFrom.WEB_APP:
created_from = WorkflowAppLogCreatedFrom.WEB_APP
case InvokeFrom.DEBUGGER | InvokeFrom.TRIGGER | InvokeFrom.PUBLISHED_PIPELINE | InvokeFrom.VALIDATION:
# not save log for debugging
return
if not workflow_run_id:
return

View File

@ -146,6 +146,8 @@ class WorkflowBasedAppRunner:
call_depth=0,
)
# Use the provided graph_runtime_state for consistent state management
node_factory = DifyNodeFactory.from_graph_init_context(
graph_init_context=graph_init_context,
graph_runtime_state=graph_runtime_state,

View File

@ -1,72 +0,0 @@
"""
LLM Generation Detail entities.
Defines the structure for storing and transmitting LLM generation details
including reasoning content, tool calls, and their sequence.
"""
from typing import Literal
from pydantic import BaseModel, Field
class ContentSegment(BaseModel):
"""Represents a content segment in the generation sequence."""
type: Literal["content"] = "content"
start: int = Field(..., description="Start position in the text")
end: int = Field(..., description="End position in the text")
class ReasoningSegment(BaseModel):
"""Represents a reasoning segment in the generation sequence."""
type: Literal["reasoning"] = "reasoning"
index: int = Field(..., description="Index into reasoning_content array")
class ToolCallSegment(BaseModel):
"""Represents a tool call segment in the generation sequence."""
type: Literal["tool_call"] = "tool_call"
index: int = Field(..., description="Index into tool_calls array")
SequenceSegment = ContentSegment | ReasoningSegment | ToolCallSegment
class ToolCallDetail(BaseModel):
"""Represents a tool call with its arguments and result."""
id: str = Field(default="", description="Unique identifier for the tool call")
name: str = Field(..., description="Name of the tool")
arguments: str = Field(default="", description="JSON string of tool arguments")
result: str = Field(default="", description="Result from the tool execution")
elapsed_time: float | None = Field(default=None, description="Elapsed time in seconds")
icon: str | dict | None = Field(default=None, description="Icon of the tool")
icon_dark: str | dict | None = Field(default=None, description="Dark theme icon of the tool")
class LLMGenerationDetailData(BaseModel):
"""
Domain model for LLM generation detail.
Contains the structured data for reasoning content, tool calls,
and their display sequence.
"""
reasoning_content: list[str] = Field(default_factory=list, description="List of reasoning segments")
tool_calls: list[ToolCallDetail] = Field(default_factory=list, description="List of tool call details")
sequence: list[SequenceSegment] = Field(default_factory=list, description="Display order of segments")
def is_empty(self) -> bool:
"""Check if there's any meaningful generation detail."""
return not self.reasoning_content and not self.tool_calls
def to_response_dict(self) -> dict:
"""Convert to dictionary for API response."""
return {
"reasoning_content": self.reasoning_content,
"tool_calls": [tc.model_dump() for tc in self.tool_calls],
"sequence": [seg.model_dump() for seg in self.sequence],
}

View File

@ -521,7 +521,7 @@ class IterationNodeStartStreamResponse(StreamResponse):
node_type: str
title: str
created_at: int
extras: dict = Field(default_factory=dict)
extras: dict[str, Any] = Field(default_factory=dict)
metadata: Mapping = {}
inputs: Mapping = {}
inputs_truncated: bool = False
@ -547,7 +547,7 @@ class IterationNodeNextStreamResponse(StreamResponse):
title: str
index: int
created_at: int
extras: dict = Field(default_factory=dict)
extras: dict[str, Any] = Field(default_factory=dict)
event: StreamEvent = StreamEvent.ITERATION_NEXT
workflow_run_id: str
@ -571,7 +571,7 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
outputs: Mapping | None = None
outputs_truncated: bool = False
created_at: int
extras: dict | None = None
extras: dict[str, Any] | None = None
inputs: Mapping | None = None
inputs_truncated: bool = False
status: WorkflowNodeExecutionStatus
@ -602,7 +602,7 @@ class LoopNodeStartStreamResponse(StreamResponse):
node_type: str
title: str
created_at: int
extras: dict = Field(default_factory=dict)
extras: dict[str, Any] = Field(default_factory=dict)
metadata: Mapping = {}
inputs: Mapping = {}
inputs_truncated: bool = False
@ -653,7 +653,7 @@ class LoopNodeCompletedStreamResponse(StreamResponse):
outputs: Mapping | None = None
outputs_truncated: bool = False
created_at: int
extras: dict | None = None
extras: dict[str, Any] | None = None
inputs: Mapping | None = None
inputs_truncated: bool = False
status: WorkflowNodeExecutionStatus

Some files were not shown because too many files have changed in this diff Show More