Compare commits


63 Commits

Author SHA1 Message Date
5874b920b2 fix: code owners 2025-11-28 14:36:30 +08:00
c51ab6ec37 fix: the consistency of the go-to-anything interaction (#28857) 2025-11-28 14:29:15 +08:00
1fc2255219 test: add comprehensive unit tests for EndUserService (#28840) 2025-11-28 14:22:19 +08:00
037389137d feat: complete test script of indexing runner (#28828)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-11-28 14:18:59 +08:00
8cd3e84c06 chore: bump dify plugin version in docker.middleware (#28847) 2025-11-28 13:55:13 +08:00
b3c6ac1430 chore: assign code owners to frontend and backend modules in CODEOWNERS (#28713) 2025-11-28 12:42:58 +08:00
68bb97919a feat: add comprehensive unit tests for MessageService (#28837)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-11-28 12:36:15 +08:00
f268d7c7be feat: complete test script of website crawl (#28826) 2025-11-28 12:34:27 +08:00
d695a79ba1 test: add comprehensive unit tests for DocumentIndexingTaskProxy (#28830)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-11-28 12:30:54 +08:00
cd5a745bd2 feat: complete test script of notion provider (#28833) 2025-11-28 12:30:45 +08:00
51e5f422c4 test: add comprehensive unit tests for VectorService and Vector classes (#28834)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-11-28 12:30:02 +08:00
ec3b2b40c2 test: add comprehensive unit tests for FeedbackService (#28771)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-11-28 11:33:56 +08:00
67ae3e9253 docker: use COPY --chown in api Dockerfile to avoid adding layers by explicit chown calls (#28756) 2025-11-28 11:33:06 +08:00
d38e3b7792 test: add unit tests for document service status management (#28804)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-11-28 11:25:36 +08:00
43d27edef2 feat: complete test script of embedding service (#28817)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-11-28 11:24:30 +08:00
94b87eac72 feat: add comprehensive unit tests for provider models (#28702)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-11-28 11:24:20 +08:00
fd31af6012 fix(ci): use dynamic branch name for i18n workflow to prevent race condition (#28823)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-11-28 11:23:28 +08:00
228deccec2 chore: update packageManager version in package.json to pnpm@10.24.0 (#28820) 2025-11-28 11:23:20 +08:00
639f1d31f7 feat: complete test script of text splitter (#28813)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-11-28 11:22:52 +08:00
ec786fe236 test: add unit tests for document service validation and configuration (#28810)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-11-28 11:21:45 +08:00
fe3a6ef049 feat: complete test script of reranker (#28806)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-11-28 11:21:35 +08:00
8b761319f6 Refactor workflow nodes to use generic node_data (#28782) 2025-11-27 20:46:56 +08:00
002d8769b0 chore: translate i18n files and update type definitions (#28784)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2025-11-27 20:28:17 +08:00
5aba111297 Feat zen mode (#28794) 2025-11-27 20:10:50 +08:00
dc9b3a7e03 refactor: rename VariableAssignerNodeData to VariableAggregatorNodeData (#28780) 2025-11-27 17:45:48 +08:00
5f2e0d6347 pref: reduce next step components reRender (#28783) 2025-11-27 17:12:00 +08:00
1f72571c06 edit analyze-component (#28781)
Co-authored-by: CodingOnStar <hanxujiang@dify.ai>
Co-authored-by: 姜涵煦 <hanxujiang@jianghanxudeMacBook-Pro.local>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-11-27 16:54:44 +08:00
820925a866 feat(workflow): workflow as tool output schema (#26241)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: Novice <novice12185727@gmail.com>
2025-11-27 16:50:48 +08:00
299bd351fd perf: reduce reRender in candidate node (#28776) 2025-11-27 15:57:36 +08:00
13bf6547ee Refactor: centralize node data hydration (#27771) 2025-11-27 15:41:56 +08:00
1b733abe82 feat: creates logs immediately when workflows start (not at completion) (#28701) 2025-11-27 15:22:33 +08:00
5782e26ab2 test: add unit tests for dataset service update/delete operations (#28757)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-11-27 15:01:43 +08:00
38d329e75a test: add unit tests for dataset permission service (#28760)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-11-27 15:00:55 +08:00
58f448a926 chore: remove outdated model config doc (#28765) 2025-11-27 14:40:06 +08:00
7a7fea40d9 feat: complete test script of dataset retrieval (#28762)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-11-27 14:39:33 +08:00
0309545ff1 Feat/test script of workflow service (#28726)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-11-27 11:23:55 +08:00
6deabfdad3 Use naive_utc_now in graph engine tests (#28735)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-11-27 11:23:20 +08:00
f9b4c31344 fix: MCP tool time configuration not work (#28740) 2025-11-27 11:22:49 +08:00
8d8800e632 upgrade docker compose milvus version to 2.6.0 to fix installation error (#26618)
Co-authored-by: crazywoola <427733928@qq.com>
2025-11-27 11:01:14 +08:00
4ca4493084 Add comprehensive unit tests for MetadataService (dataset metadata CRUD operations and filtering) (#28748)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-11-27 11:00:10 +08:00
7efa0df1fd Add comprehensive API/controller tests for dataset endpoints (list, create, update, delete, documents, segments, hit testing, external datasets) (#28750)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-11-27 10:59:17 +08:00
b786e101e5 fix: querying and setting the system default model (#28743) 2025-11-27 11:58:35 +09:00
09a8046b10 fix: querying webhook trigger issue (#28753) 2025-11-27 10:56:21 +08:00
2f6b3f1c5f hotfix: fix _extract_filename for rfc 5987 (#26230)
Signed-off-by: NeatGuyCoding <15627489+NeatGuyCoding@users.noreply.github.com>
2025-11-27 10:54:00 +08:00
2551f6f279 feat: add APP_DEFAULT_ACTIVE_REQUESTS as the default value for APP_AC… (#26930) 2025-11-27 10:51:48 +08:00
01afa56166 chore: enhance the test script of current billing service (#28747)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-11-27 10:37:24 +08:00
5815950092 add unit tests for iteration node (#28719)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-11-27 10:36:47 +08:00
766e16b26f add unit tests for code node (#28717) 2025-11-27 10:36:37 +08:00
0fdb4e7c12 chore: enhance the test script of conversation service (#28739)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-11-27 09:57:52 +08:00
64babb35e2 feat: Add comprehensive unit tests for DatasetCollectionBindingService (dataset collection binding methods) (#28724)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-11-27 09:55:42 +08:00
38522e5dfa fix: use default_factory for callable defaults in ORM dataclasses (#28730) 2025-11-27 09:39:49 +09:00
4ccc150fd1 test: add comprehensive unit tests for ExternalDatasetService (external knowledge API integration) (#28716)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-11-26 23:33:46 +08:00
a4c57017d5 add: badges (#28722) 2025-11-26 23:30:41 +08:00
b2a7cec644 add unit tests for template transform node (#28595)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-11-26 22:50:20 +08:00
ddc5cbe865 feat: complete test script of dataset service (#28710)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-11-26 22:48:08 +08:00
1e23957657 fix(ops): add streaming metrics and LLM span for agent-chat traces (#28320)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2025-11-26 22:45:20 +08:00
2731b04ff9 Pydantic models (#28697)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-11-26 22:44:14 +08:00
e8ca80a61a add unit tests for list operator node (#28597)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-11-26 22:43:30 +08:00
e76129b5a4 test: add comprehensive unit tests for HitTestingService Fix: #28667 (#28668)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-11-26 22:42:58 +08:00
6635ea62c2 fix: change existing node to a webhook node raise 404 (#28686) 2025-11-26 22:41:52 +08:00
6b8c649876 fix: prevent auto-scrolling from stopping in chat (#28690)
Signed-off-by: Yuichiro Utsumi <utsumi.yuichiro@fujitsu.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-11-26 22:39:29 +08:00
af587f3869 chore: update packageManager version to pnpm@10.23.0 (#28708) 2025-11-26 22:37:05 +08:00
1c1f124891 Enhanced GraphEngine Pause Handling (#28196)
This commit: 

1. Convert `pause_reason` to `pause_reasons` in `GraphExecution` and relevant classes. Change the field from a scalar value to a list that can contain multiple `PauseReason` objects, ensuring all pause events are properly captured.
2. Introduce a new `WorkflowPauseReason` model to record reasons associated with a specific `WorkflowPause`.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-11-26 19:59:34 +08:00
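The scalar-to-list change this commit describes, as a minimal sketch with assumed names (the actual Dify classes, enum members, and persistence details differ):

from dataclasses import dataclass, field
from enum import Enum

class PauseReason(Enum):
    # hypothetical members, for illustration only
    HUMAN_INPUT_REQUIRED = "human_input_required"
    SCHEDULED_RESUME = "scheduled_resume"

@dataclass
class GraphExecution:
    # before: pause_reason: PauseReason | None = None  -- only the last pause survived
    # after: a list, so every pause event during a run is captured
    pause_reasons: list[PauseReason] = field(default_factory=list)

    def record_pause(self, reason: PauseReason) -> None:
        self.pause_reasons.append(reason)

Note the default_factory: a mutable list default on a dataclass field must be built per instance, the same pitfall commit 38522e5dfa addresses for ORM dataclasses.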
287 changed files with 39416 additions and 5761 deletions

.github/CODEOWNERS

@@ -0,0 +1,226 @@
# CODEOWNERS
# This file defines code ownership for the Dify project.
# Each line is a file pattern followed by one or more owners.
# Owners can be @username, @org/team-name, or email addresses.
# For more information, see: https://docs.github.com/en/repositories/managing-your-repositorys-settings-and-features/customizing-your-repository/about-code-owners
* @crazywoola @laipz8200 @Yeuoly
# Backend (default owner, more specific rules below will override)
api/ @QuantumGhost
# Backend - Workflow - Engine (Core graph execution engine)
api/core/workflow/graph_engine/ @laipz8200 @QuantumGhost
api/core/workflow/runtime/ @laipz8200 @QuantumGhost
api/core/workflow/graph/ @laipz8200 @QuantumGhost
api/core/workflow/graph_events/ @laipz8200 @QuantumGhost
api/core/workflow/node_events/ @laipz8200 @QuantumGhost
api/core/model_runtime/ @laipz8200 @QuantumGhost
# Backend - Workflow - Nodes (Agent, Iteration, Loop, LLM)
api/core/workflow/nodes/agent/ @Nov1c444
api/core/workflow/nodes/iteration/ @Nov1c444
api/core/workflow/nodes/loop/ @Nov1c444
api/core/workflow/nodes/llm/ @Nov1c444
# Backend - RAG (Retrieval Augmented Generation)
api/core/rag/ @JohnJyong
api/services/rag_pipeline/ @JohnJyong
api/services/dataset_service.py @JohnJyong
api/services/knowledge_service.py @JohnJyong
api/services/external_knowledge_service.py @JohnJyong
api/services/hit_testing_service.py @JohnJyong
api/services/metadata_service.py @JohnJyong
api/services/vector_service.py @JohnJyong
api/services/entities/knowledge_entities/ @JohnJyong
api/services/entities/external_knowledge_entities/ @JohnJyong
api/controllers/console/datasets/ @JohnJyong
api/controllers/service_api/dataset/ @JohnJyong
api/models/dataset.py @JohnJyong
api/tasks/rag_pipeline/ @JohnJyong
api/tasks/add_document_to_index_task.py @JohnJyong
api/tasks/batch_clean_document_task.py @JohnJyong
api/tasks/clean_document_task.py @JohnJyong
api/tasks/clean_notion_document_task.py @JohnJyong
api/tasks/document_indexing_task.py @JohnJyong
api/tasks/document_indexing_sync_task.py @JohnJyong
api/tasks/document_indexing_update_task.py @JohnJyong
api/tasks/duplicate_document_indexing_task.py @JohnJyong
api/tasks/recover_document_indexing_task.py @JohnJyong
api/tasks/remove_document_from_index_task.py @JohnJyong
api/tasks/retry_document_indexing_task.py @JohnJyong
api/tasks/sync_website_document_indexing_task.py @JohnJyong
api/tasks/batch_create_segment_to_index_task.py @JohnJyong
api/tasks/create_segment_to_index_task.py @JohnJyong
api/tasks/delete_segment_from_index_task.py @JohnJyong
api/tasks/disable_segment_from_index_task.py @JohnJyong
api/tasks/disable_segments_from_index_task.py @JohnJyong
api/tasks/enable_segment_to_index_task.py @JohnJyong
api/tasks/enable_segments_to_index_task.py @JohnJyong
api/tasks/clean_dataset_task.py @JohnJyong
api/tasks/deal_dataset_index_update_task.py @JohnJyong
api/tasks/deal_dataset_vector_index_task.py @JohnJyong
# Backend - Plugins
api/core/plugin/ @Mairuis @Yeuoly @Stream29
api/services/plugin/ @Mairuis @Yeuoly @Stream29
api/controllers/console/workspace/plugin.py @Mairuis @Yeuoly @Stream29
api/controllers/inner_api/plugin/ @Mairuis @Yeuoly @Stream29
api/tasks/process_tenant_plugin_autoupgrade_check_task.py @Mairuis @Yeuoly @Stream29
# Backend - Trigger/Schedule/Webhook
api/controllers/trigger/ @Mairuis @Yeuoly
api/controllers/console/app/workflow_trigger.py @Mairuis @Yeuoly
api/controllers/console/workspace/trigger_providers.py @Mairuis @Yeuoly
api/core/trigger/ @Mairuis @Yeuoly
api/core/app/layers/trigger_post_layer.py @Mairuis @Yeuoly
api/services/trigger/ @Mairuis @Yeuoly
api/models/trigger.py @Mairuis @Yeuoly
api/fields/workflow_trigger_fields.py @Mairuis @Yeuoly
api/repositories/workflow_trigger_log_repository.py @Mairuis @Yeuoly
api/repositories/sqlalchemy_workflow_trigger_log_repository.py @Mairuis @Yeuoly
api/libs/schedule_utils.py @Mairuis @Yeuoly
api/services/workflow/scheduler.py @Mairuis @Yeuoly
api/schedule/trigger_provider_refresh_task.py @Mairuis @Yeuoly
api/schedule/workflow_schedule_task.py @Mairuis @Yeuoly
api/tasks/trigger_processing_tasks.py @Mairuis @Yeuoly
api/tasks/trigger_subscription_refresh_tasks.py @Mairuis @Yeuoly
api/tasks/workflow_schedule_tasks.py @Mairuis @Yeuoly
api/tasks/workflow_cfs_scheduler/ @Mairuis @Yeuoly
api/events/event_handlers/sync_plugin_trigger_when_app_created.py @Mairuis @Yeuoly
api/events/event_handlers/update_app_triggers_when_app_published_workflow_updated.py @Mairuis @Yeuoly
api/events/event_handlers/sync_workflow_schedule_when_app_published.py @Mairuis @Yeuoly
api/events/event_handlers/sync_webhook_when_app_created.py @Mairuis @Yeuoly
# Backend - Async Workflow
api/services/async_workflow_service.py @Mairuis @Yeuoly
api/tasks/async_workflow_tasks.py @Mairuis @Yeuoly
# Backend - Billing
api/services/billing_service.py @hj24 @zyssyz123
api/controllers/console/billing/ @hj24 @zyssyz123
# Backend - Enterprise
api/configs/enterprise/ @GarfieldDai @GareArc
api/services/enterprise/ @GarfieldDai @GareArc
api/services/feature_service.py @GarfieldDai @GareArc
api/controllers/console/feature.py @GarfieldDai @GareArc
api/controllers/web/feature.py @GarfieldDai @GareArc
# Backend - Database Migrations
api/migrations/ @snakevash @laipz8200
# Frontend
web/ @iamjoel
# Frontend - App - Orchestration
web/app/components/workflow/ @iamjoel @zxhlyh
web/app/components/workflow-app/ @iamjoel @zxhlyh
web/app/components/app/configuration/ @iamjoel @zxhlyh
web/app/components/app/app-publisher/ @iamjoel @zxhlyh
# Frontend - WebApp - Chat
web/app/components/base/chat/ @iamjoel @zxhlyh
# Frontend - WebApp - Completion
web/app/components/share/text-generation/ @iamjoel @zxhlyh
# Frontend - App - List and Creation
web/app/components/apps/ @JzoNgKVO @iamjoel
web/app/components/app/create-app-dialog/ @JzoNgKVO @iamjoel
web/app/components/app/create-app-modal/ @JzoNgKVO @iamjoel
web/app/components/app/create-from-dsl-modal/ @JzoNgKVO @iamjoel
# Frontend - App - API Documentation
web/app/components/develop/ @JzoNgKVO @iamjoel
# Frontend - App - Logs and Annotations
web/app/components/app/workflow-log/ @JzoNgKVO @iamjoel
web/app/components/app/log/ @JzoNgKVO @iamjoel
web/app/components/app/log-annotation/ @JzoNgKVO @iamjoel
web/app/components/app/annotation/ @JzoNgKVO @iamjoel
# Frontend - App - Monitoring
web/app/(commonLayout)/app/(appDetailLayout)/\[appId\]/overview/ @JzoNgKVO @iamjoel
web/app/components/app/overview/ @JzoNgKVO @iamjoel
# Frontend - App - Settings
web/app/components/app-sidebar/ @JzoNgKVO @iamjoel
# Frontend - RAG - Hit Testing
web/app/components/datasets/hit-testing/ @JzoNgKVO @iamjoel
# Frontend - RAG - List and Creation
web/app/components/datasets/list/ @iamjoel @WTW0313
web/app/components/datasets/create/ @iamjoel @WTW0313
web/app/components/datasets/create-from-pipeline/ @iamjoel @WTW0313
web/app/components/datasets/external-knowledge-base/ @iamjoel @WTW0313
# Frontend - RAG - Orchestration (general rule first, specific rules below override)
web/app/components/rag-pipeline/ @iamjoel @WTW0313
web/app/components/rag-pipeline/components/rag-pipeline-main.tsx @iamjoel @zxhlyh
web/app/components/rag-pipeline/store/ @iamjoel @zxhlyh
# Frontend - RAG - Documents List
web/app/components/datasets/documents/list.tsx @iamjoel @WTW0313
web/app/components/datasets/documents/create-from-pipeline/ @iamjoel @WTW0313
# Frontend - RAG - Segments List
web/app/components/datasets/documents/detail/ @iamjoel @WTW0313
# Frontend - RAG - Settings
web/app/components/datasets/settings/ @iamjoel @WTW0313
# Frontend - Ecosystem - Plugins
web/app/components/plugins/ @iamjoel @zhsama
# Frontend - Ecosystem - Tools
web/app/components/tools/ @iamjoel @Yessenia-d
# Frontend - Ecosystem - MarketPlace
web/app/components/plugins/marketplace/ @iamjoel @Yessenia-d
# Frontend - Login and Registration
web/app/signin/ @douxc @iamjoel
web/app/signup/ @douxc @iamjoel
web/app/reset-password/ @douxc @iamjoel
web/app/install/ @douxc @iamjoel
web/app/init/ @douxc @iamjoel
web/app/forgot-password/ @douxc @iamjoel
web/app/account/ @douxc @iamjoel
# Frontend - Service Authentication
web/service/base.ts @douxc @iamjoel
# Frontend - WebApp Authentication and Access Control
web/app/(shareLayout)/components/ @douxc @iamjoel
web/app/(shareLayout)/webapp-signin/ @douxc @iamjoel
web/app/(shareLayout)/webapp-reset-password/ @douxc @iamjoel
web/app/components/app/app-access-control/ @douxc @iamjoel
# Frontend - Explore Page
web/app/components/explore/ @CodingOnStar @iamjoel
# Frontend - Personal Settings
web/app/components/header/account-setting/ @CodingOnStar @iamjoel
web/app/components/header/account-dropdown/ @CodingOnStar @iamjoel
# Frontend - Analytics
web/app/components/base/ga/ @CodingOnStar @iamjoel
# Frontend - Base Components
web/app/components/base/ @iamjoel @zxhlyh
# Frontend - Utils and Hooks
web/utils/classnames.ts @iamjoel @zxhlyh
web/utils/time.ts @iamjoel @zxhlyh
web/utils/format.ts @iamjoel @zxhlyh
web/utils/clipboard.ts @iamjoel @zxhlyh
web/hooks/use-document-title.ts @iamjoel @zxhlyh
# Frontend - Billing and Education
web/app/components/billing/ @iamjoel @zxhlyh
web/app/education-apply/ @iamjoel @zxhlyh
# Frontend - Workspace
web/app/components/header/account-dropdown/workplace-selector/ @iamjoel @zxhlyh
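For reference: when several CODEOWNERS patterns match a path, GitHub applies the last matching pattern in the file, which is why the broad api/ and web/ rules come first and the narrower rules that should win come later. A rough sketch of that resolution rule (simplified matching; GitHub's actual glob semantics are richer):

from fnmatch import fnmatch

def owners_for(path: str, rules: list[tuple[str, list[str]]]) -> list[str]:
    # rules are kept in file order; the last match wins
    matched: list[str] = []
    for pattern, owners in rules:
        # treat a trailing slash as "everything under this directory"
        if pattern == "*" or (pattern.endswith("/") and path.startswith(pattern)) or fnmatch(path, pattern):
            matched = owners
    return matched

rules = [
    ("*", ["@crazywoola", "@laipz8200", "@Yeuoly"]),
    ("api/", ["@QuantumGhost"]),
    ("api/core/rag/", ["@JohnJyong"]),
]
print(owners_for("api/core/rag/rerank/rerank_base.py", rules))  # ['@JohnJyong']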


@@ -77,12 +77,15 @@ jobs:
uses: peter-evans/create-pull-request@v6
with:
token: ${{ secrets.GITHUB_TOKEN }}
commit-message: Update i18n files and type definitions based on en-US changes
title: 'chore: translate i18n files and update type definitions'
commit-message: 'chore(i18n): update translations based on en-US changes'
title: 'chore(i18n): translate i18n files and update type definitions'
body: |
This PR was automatically created to update i18n files and TypeScript type definitions based on changes in en-US locale.
**Triggered by:** ${{ github.sha }}
**Changes included:**
- Updated translation files for all locales
- Regenerated TypeScript type definitions for type safety
branch: chore/automated-i18n-updates
branch: chore/automated-i18n-updates-${{ github.sha }}
delete-branch: true
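The race the commit title refers to: with a fixed branch name, two overlapping workflow runs would both push to chore/automated-i18n-updates and clobber each other's pull request. Suffixing the branch with ${{ github.sha }} gives each run its own branch, and delete-branch: true cleans it up once the PR is merged or closed.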


@@ -36,6 +36,12 @@
<img alt="Issues closed" src="https://img.shields.io/github/issues-search?query=repo%3Alanggenius%2Fdify%20is%3Aclosed&label=issues%20closed&labelColor=%20%237d89b0&color=%20%235d6b98"></a>
<a href="https://github.com/langgenius/dify/discussions/" target="_blank">
<img alt="Discussion posts" src="https://img.shields.io/github/discussions/langgenius/dify?labelColor=%20%239b8afb&color=%20%237a5af8"></a>
<a href="https://insights.linuxfoundation.org/project/langgenius-dify" target="_blank">
<img alt="LFX Health Score" src="https://insights.linuxfoundation.org/api/badge/health-score?project=langgenius-dify"></a>
<a href="https://insights.linuxfoundation.org/project/langgenius-dify" target="_blank">
<img alt="LFX Contributors" src="https://insights.linuxfoundation.org/api/badge/contributors?project=langgenius-dify"></a>
<a href="https://insights.linuxfoundation.org/project/langgenius-dify" target="_blank">
<img alt="LFX Active Contributors" src="https://insights.linuxfoundation.org/api/badge/active-contributors?project=langgenius-dify"></a>
</p>
<p align="center">


@@ -540,6 +540,7 @@ WORKFLOW_LOG_CLEANUP_BATCH_SIZE=100
# App configuration
APP_MAX_EXECUTION_TIME=1200
APP_DEFAULT_ACTIVE_REQUESTS=0
APP_MAX_ACTIVE_REQUESTS=0
# Celery beat configuration


@@ -16,6 +16,7 @@ layers =
graph
nodes
node_events
runtime
entities
containers =
core.workflow
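import-linter lists the packages in a layers contract from highest to lowest; a layer may import those below it but not those above. Adding runtime between node_events and entities slots the new core.workflow.runtime package into that ordering.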


@@ -48,6 +48,12 @@ ENV PYTHONIOENCODING=utf-8
WORKDIR /app/api
# Create non-root user
ARG dify_uid=1001
RUN groupadd -r -g ${dify_uid} dify && \
useradd -r -u ${dify_uid} -g ${dify_uid} -s /bin/bash dify && \
chown -R dify:dify /app
RUN \
apt-get update \
# Install dependencies
@@ -69,7 +75,7 @@ RUN \
# Copy Python environment and packages
ENV VIRTUAL_ENV=/app/api/.venv
COPY --from=packages ${VIRTUAL_ENV} ${VIRTUAL_ENV}
COPY --from=packages --chown=dify:dify ${VIRTUAL_ENV} ${VIRTUAL_ENV}
ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
# Download nltk data
@@ -78,24 +84,20 @@ RUN mkdir -p /usr/local/share/nltk_data && NLTK_DATA=/usr/local/share/nltk_data
ENV TIKTOKEN_CACHE_DIR=/app/api/.tiktoken_cache
RUN python -c "import tiktoken; tiktoken.encoding_for_model('gpt2')"
RUN python -c "import tiktoken; tiktoken.encoding_for_model('gpt2')" \
&& chown -R dify:dify ${TIKTOKEN_CACHE_DIR}
# Copy source code
COPY . /app/api/
COPY --chown=dify:dify . /app/api/
# Copy entrypoint
COPY docker/entrypoint.sh /entrypoint.sh
RUN chmod +x /entrypoint.sh
# Prepare entrypoint script
COPY --chown=dify:dify --chmod=755 docker/entrypoint.sh /entrypoint.sh
# Create non-root user and set permissions
RUN groupadd -r -g 1001 dify && \
useradd -r -u 1001 -g 1001 -s /bin/bash dify && \
mkdir -p /home/dify && \
chown -R 1001:1001 /app /home/dify ${TIKTOKEN_CACHE_DIR} /entrypoint.sh
ARG COMMIT_SHA
ENV COMMIT_SHA=${COMMIT_SHA}
ENV NLTK_DATA=/usr/local/share/nltk_data
USER 1001
USER dify
ENTRYPOINT ["/bin/bash", "/entrypoint.sh"]
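Why this helps: a standalone RUN chown -R rewrites every affected file into a fresh image layer, roughly doubling the space those files take, whereas COPY --chown sets ownership as the files are first written. It is also why the groupadd/useradd block moves up: the dify user must exist before any COPY --chown refers to it, after which USER dify can replace the numeric USER 1001.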


@@ -73,6 +73,10 @@ class AppExecutionConfig(BaseSettings):
description="Maximum allowed execution time for the application in seconds",
default=1200,
)
APP_DEFAULT_ACTIVE_REQUESTS: NonNegativeInt = Field(
description="Default number of concurrent active requests per app (0 for unlimited)",
default=0,
)
APP_MAX_ACTIVE_REQUESTS: NonNegativeInt = Field(
description="Maximum number of concurrent active requests per app (0 for unlimited)",
default=0,
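How the new default interacts with the existing maximum is not visible in this hunk; one plausible reading, stated purely as an assumption (the helper below is illustrative, not from the Dify codebase):

def resolve_active_request_limit(app_value: int | None, default: int, maximum: int) -> int | None:
    # assumed semantics: a per-app value falls back to APP_DEFAULT_ACTIVE_REQUESTS,
    # APP_MAX_ACTIVE_REQUESTS caps the result, and 0 means unlimited (returned as None)
    limit = app_value if app_value is not None else default
    if maximum and (limit == 0 or limit > maximum):
        limit = maximum
    return limit or None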


@@ -1,6 +1,8 @@
import logging
from flask_restx import Resource, marshal_with, reqparse
from flask import request
from flask_restx import Resource, marshal_with
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
@@ -18,16 +20,30 @@ from ..app.wraps import get_app_model
from ..wraps import account_initialization_required, edit_permission_required, setup_required
logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
parser = reqparse.RequestParser().add_argument("node_id", type=str, required=True, help="Node ID is required")
class Parser(BaseModel):
node_id: str
class ParserEnable(BaseModel):
trigger_id: str
enable_trigger: bool
console_ns.schema_model(Parser.__name__, Parser.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
console_ns.schema_model(
ParserEnable.__name__, ParserEnable.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
@console_ns.route("/apps/<uuid:app_id>/workflows/triggers/webhook")
class WebhookTriggerApi(Resource):
"""Webhook Trigger API"""
@console_ns.expect(parser)
@console_ns.expect(console_ns.models[Parser.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -35,9 +51,9 @@ class WebhookTriggerApi(Resource):
@marshal_with(webhook_trigger_fields)
def get(self, app_model: App):
"""Get webhook trigger for a node"""
args = parser.parse_args()
args = Parser.model_validate(request.args.to_dict(flat=True)) # type: ignore
node_id = str(args["node_id"])
node_id = args.node_id
with Session(db.engine) as session:
# Get webhook trigger for this app and node
@@ -96,16 +112,9 @@ class AppTriggersApi(Resource):
return {"data": triggers}
parser_enable = (
reqparse.RequestParser()
.add_argument("trigger_id", type=str, required=True, nullable=False, location="json")
.add_argument("enable_trigger", type=bool, required=True, nullable=False, location="json")
)
@console_ns.route("/apps/<uuid:app_id>/trigger-enable")
class AppTriggerEnableApi(Resource):
@console_ns.expect(parser_enable)
@console_ns.expect(console_ns.models[ParserEnable.__name__], validate=True)
@setup_required
@login_required
@account_initialization_required
@@ -114,12 +123,11 @@ class AppTriggerEnableApi(Resource):
@marshal_with(trigger_fields)
def post(self, app_model: App):
"""Update app trigger (enable/disable)"""
args = parser_enable.parse_args()
args = ParserEnable.model_validate(console_ns.payload)
assert current_user.current_tenant_id is not None
trigger_id = args["trigger_id"]
trigger_id = args.trigger_id
with Session(db.engine) as session:
# Find the trigger using select
trigger = session.execute(
@@ -134,7 +142,7 @@ class AppTriggerEnableApi(Resource):
raise NotFound("Trigger not found")
# Update status based on enable_trigger boolean
trigger.status = AppTriggerStatus.ENABLED if args["enable_trigger"] else AppTriggerStatus.DISABLED
trigger.status = AppTriggerStatus.ENABLED if args.enable_trigger else AppTriggerStatus.DISABLED
session.commit()
session.refresh(trigger)
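The migration pattern used in this and the remaining controller diffs, reduced to a minimal runnable sketch (simplified route and decorators; the real handlers keep Dify's setup/login/permission wrappers):

from flask import Flask, request
from flask_restx import Api, Namespace, Resource
from pydantic import BaseModel

DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"

app = Flask(__name__)
api = Api(app)
console_ns = Namespace("console", path="/console/api")
api.add_namespace(console_ns)

class Parser(BaseModel):
    node_id: str

# register the Pydantic JSON schema so @expect can reference it in the Swagger docs
console_ns.schema_model(Parser.__name__, Parser.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))

@console_ns.route("/example")
class ExampleApi(Resource):
    @console_ns.expect(console_ns.models[Parser.__name__])
    def get(self):
        # query strings arrive as a flat dict; JSON bodies come from console_ns.payload instead
        args = Parser.model_validate(request.args.to_dict(flat=True))
        return {"node_id": args.node_id}

The trade: hand-built reqparse parsers become typed, self-documenting payload models, and args["key"] lookups become attribute access.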


@@ -1,8 +1,10 @@
from datetime import datetime
from typing import Literal
import pytz
from flask import request
from flask_restx import Resource, fields, marshal_with, reqparse
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field, field_validator, model_validator
from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -42,20 +44,198 @@ from services.account_service import AccountService
from services.billing_service import BillingService
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
def _init_parser():
parser = reqparse.RequestParser()
if dify_config.EDITION == "CLOUD":
parser.add_argument("invitation_code", type=str, location="json")
parser.add_argument("interface_language", type=supported_language, required=True, location="json").add_argument(
"timezone", type=timezone, required=True, location="json"
)
return parser
class AccountInitPayload(BaseModel):
interface_language: str
timezone: str
invitation_code: str | None = None
@field_validator("interface_language")
@classmethod
def validate_language(cls, value: str) -> str:
return supported_language(value)
@field_validator("timezone")
@classmethod
def validate_timezone(cls, value: str) -> str:
return timezone(value)
class AccountNamePayload(BaseModel):
name: str = Field(min_length=3, max_length=30)
class AccountAvatarPayload(BaseModel):
avatar: str
class AccountInterfaceLanguagePayload(BaseModel):
interface_language: str
@field_validator("interface_language")
@classmethod
def validate_language(cls, value: str) -> str:
return supported_language(value)
class AccountInterfaceThemePayload(BaseModel):
interface_theme: Literal["light", "dark"]
class AccountTimezonePayload(BaseModel):
timezone: str
@field_validator("timezone")
@classmethod
def validate_timezone(cls, value: str) -> str:
return timezone(value)
class AccountPasswordPayload(BaseModel):
password: str | None = None
new_password: str
repeat_new_password: str
@model_validator(mode="after")
def check_passwords_match(self) -> "AccountPasswordPayload":
if self.new_password != self.repeat_new_password:
raise RepeatPasswordNotMatchError()
return self
class AccountDeletePayload(BaseModel):
token: str
code: str
class AccountDeletionFeedbackPayload(BaseModel):
email: str
feedback: str
@field_validator("email")
@classmethod
def validate_email(cls, value: str) -> str:
return email(value)
class EducationActivatePayload(BaseModel):
token: str
institution: str
role: str
class EducationAutocompleteQuery(BaseModel):
keywords: str
page: int = 0
limit: int = 20
class ChangeEmailSendPayload(BaseModel):
email: str
language: str | None = None
phase: str | None = None
token: str | None = None
@field_validator("email")
@classmethod
def validate_email(cls, value: str) -> str:
return email(value)
class ChangeEmailValidityPayload(BaseModel):
email: str
code: str
token: str
@field_validator("email")
@classmethod
def validate_email(cls, value: str) -> str:
return email(value)
class ChangeEmailResetPayload(BaseModel):
new_email: str
token: str
@field_validator("new_email")
@classmethod
def validate_email(cls, value: str) -> str:
return email(value)
class CheckEmailUniquePayload(BaseModel):
email: str
@field_validator("email")
@classmethod
def validate_email(cls, value: str) -> str:
return email(value)
console_ns.schema_model(
AccountInitPayload.__name__, AccountInitPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
AccountNamePayload.__name__, AccountNamePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
AccountAvatarPayload.__name__, AccountAvatarPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
AccountInterfaceLanguagePayload.__name__,
AccountInterfaceLanguagePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
AccountInterfaceThemePayload.__name__,
AccountInterfaceThemePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
AccountTimezonePayload.__name__,
AccountTimezonePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
AccountPasswordPayload.__name__,
AccountPasswordPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
AccountDeletePayload.__name__,
AccountDeletePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
AccountDeletionFeedbackPayload.__name__,
AccountDeletionFeedbackPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
EducationActivatePayload.__name__,
EducationActivatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
EducationAutocompleteQuery.__name__,
EducationAutocompleteQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ChangeEmailSendPayload.__name__,
ChangeEmailSendPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ChangeEmailValidityPayload.__name__,
ChangeEmailValidityPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ChangeEmailResetPayload.__name__,
ChangeEmailResetPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
CheckEmailUniquePayload.__name__,
CheckEmailUniquePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@console_ns.route("/account/init")
class AccountInitApi(Resource):
@console_ns.expect(_init_parser())
@console_ns.expect(console_ns.models[AccountInitPayload.__name__])
@setup_required
@login_required
def post(self):
@@ -64,17 +244,18 @@ class AccountInitApi(Resource):
if account.status == "active":
raise AccountAlreadyInitedError()
args = _init_parser().parse_args()
payload = console_ns.payload or {}
args = AccountInitPayload.model_validate(payload)
if dify_config.EDITION == "CLOUD":
if not args["invitation_code"]:
if not args.invitation_code:
raise ValueError("invitation_code is required")
# check invitation code
invitation_code = (
db.session.query(InvitationCode)
.where(
InvitationCode.code == args["invitation_code"],
InvitationCode.code == args.invitation_code,
InvitationCode.status == "unused",
)
.first()
@@ -88,8 +269,8 @@
invitation_code.used_by_tenant_id = account.current_tenant_id
invitation_code.used_by_account_id = account.id
account.interface_language = args["interface_language"]
account.timezone = args["timezone"]
account.interface_language = args.interface_language
account.timezone = args.timezone
account.interface_theme = "light"
account.status = "active"
account.initialized_at = naive_utc_now()
@@ -110,137 +291,104 @@ class AccountProfileApi(Resource):
return current_user
parser_name = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json")
@console_ns.route("/account/name")
class AccountNameApi(Resource):
@console_ns.expect(parser_name)
@console_ns.expect(console_ns.models[AccountNamePayload.__name__])
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def post(self):
current_user, _ = current_account_with_tenant()
args = parser_name.parse_args()
# Validate account name length
if len(args["name"]) < 3 or len(args["name"]) > 30:
raise ValueError("Account name must be between 3 and 30 characters.")
updated_account = AccountService.update_account(current_user, name=args["name"])
payload = console_ns.payload or {}
args = AccountNamePayload.model_validate(payload)
updated_account = AccountService.update_account(current_user, name=args.name)
return updated_account
parser_avatar = reqparse.RequestParser().add_argument("avatar", type=str, required=True, location="json")
@console_ns.route("/account/avatar")
class AccountAvatarApi(Resource):
@console_ns.expect(parser_avatar)
@console_ns.expect(console_ns.models[AccountAvatarPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def post(self):
current_user, _ = current_account_with_tenant()
args = parser_avatar.parse_args()
payload = console_ns.payload or {}
args = AccountAvatarPayload.model_validate(payload)
updated_account = AccountService.update_account(current_user, avatar=args["avatar"])
updated_account = AccountService.update_account(current_user, avatar=args.avatar)
return updated_account
parser_interface = reqparse.RequestParser().add_argument(
"interface_language", type=supported_language, required=True, location="json"
)
@console_ns.route("/account/interface-language")
class AccountInterfaceLanguageApi(Resource):
@console_ns.expect(parser_interface)
@console_ns.expect(console_ns.models[AccountInterfaceLanguagePayload.__name__])
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def post(self):
current_user, _ = current_account_with_tenant()
args = parser_interface.parse_args()
payload = console_ns.payload or {}
args = AccountInterfaceLanguagePayload.model_validate(payload)
updated_account = AccountService.update_account(current_user, interface_language=args["interface_language"])
updated_account = AccountService.update_account(current_user, interface_language=args.interface_language)
return updated_account
parser_theme = reqparse.RequestParser().add_argument(
"interface_theme", type=str, choices=["light", "dark"], required=True, location="json"
)
@console_ns.route("/account/interface-theme")
class AccountInterfaceThemeApi(Resource):
@console_ns.expect(parser_theme)
@console_ns.expect(console_ns.models[AccountInterfaceThemePayload.__name__])
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def post(self):
current_user, _ = current_account_with_tenant()
args = parser_theme.parse_args()
payload = console_ns.payload or {}
args = AccountInterfaceThemePayload.model_validate(payload)
updated_account = AccountService.update_account(current_user, interface_theme=args["interface_theme"])
updated_account = AccountService.update_account(current_user, interface_theme=args.interface_theme)
return updated_account
parser_timezone = reqparse.RequestParser().add_argument("timezone", type=str, required=True, location="json")
@console_ns.route("/account/timezone")
class AccountTimezoneApi(Resource):
@console_ns.expect(parser_timezone)
@console_ns.expect(console_ns.models[AccountTimezonePayload.__name__])
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def post(self):
current_user, _ = current_account_with_tenant()
args = parser_timezone.parse_args()
payload = console_ns.payload or {}
args = AccountTimezonePayload.model_validate(payload)
# Validate timezone string, e.g. America/New_York, Asia/Shanghai
if args["timezone"] not in pytz.all_timezones:
raise ValueError("Invalid timezone string.")
updated_account = AccountService.update_account(current_user, timezone=args["timezone"])
updated_account = AccountService.update_account(current_user, timezone=args.timezone)
return updated_account
parser_pw = (
reqparse.RequestParser()
.add_argument("password", type=str, required=False, location="json")
.add_argument("new_password", type=str, required=True, location="json")
.add_argument("repeat_new_password", type=str, required=True, location="json")
)
@console_ns.route("/account/password")
class AccountPasswordApi(Resource):
@console_ns.expect(parser_pw)
@console_ns.expect(console_ns.models[AccountPasswordPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def post(self):
current_user, _ = current_account_with_tenant()
args = parser_pw.parse_args()
if args["new_password"] != args["repeat_new_password"]:
raise RepeatPasswordNotMatchError()
payload = console_ns.payload or {}
args = AccountPasswordPayload.model_validate(payload)
try:
AccountService.update_account_password(current_user, args["password"], args["new_password"])
AccountService.update_account_password(current_user, args.password, args.new_password)
except ServiceCurrentPasswordIncorrectError:
raise CurrentPasswordIncorrectError()
@@ -316,25 +464,19 @@ class AccountDeleteVerifyApi(Resource):
return {"result": "success", "data": token}
parser_delete = (
reqparse.RequestParser()
.add_argument("token", type=str, required=True, location="json")
.add_argument("code", type=str, required=True, location="json")
)
@console_ns.route("/account/delete")
class AccountDeleteApi(Resource):
@console_ns.expect(parser_delete)
@console_ns.expect(console_ns.models[AccountDeletePayload.__name__])
@setup_required
@login_required
@account_initialization_required
def post(self):
account, _ = current_account_with_tenant()
args = parser_delete.parse_args()
payload = console_ns.payload or {}
args = AccountDeletePayload.model_validate(payload)
if not AccountService.verify_account_deletion_code(args["token"], args["code"]):
if not AccountService.verify_account_deletion_code(args.token, args.code):
raise InvalidAccountDeletionCodeError()
AccountService.delete_account(account)
@@ -342,21 +484,15 @@ class AccountDeleteApi(Resource):
return {"result": "success"}
parser_feedback = (
reqparse.RequestParser()
.add_argument("email", type=str, required=True, location="json")
.add_argument("feedback", type=str, required=True, location="json")
)
@console_ns.route("/account/delete/feedback")
class AccountDeleteUpdateFeedbackApi(Resource):
@console_ns.expect(parser_feedback)
@console_ns.expect(console_ns.models[AccountDeletionFeedbackPayload.__name__])
@setup_required
def post(self):
args = parser_feedback.parse_args()
payload = console_ns.payload or {}
args = AccountDeletionFeedbackPayload.model_validate(payload)
BillingService.update_account_deletion_feedback(args["email"], args["feedback"])
BillingService.update_account_deletion_feedback(args.email, args.feedback)
return {"result": "success"}
@@ -379,14 +515,6 @@ class EducationVerifyApi(Resource):
return BillingService.EducationIdentity.verify(account.id, account.email)
parser_edu = (
reqparse.RequestParser()
.add_argument("token", type=str, required=True, location="json")
.add_argument("institution", type=str, required=True, location="json")
.add_argument("role", type=str, required=True, location="json")
)
@console_ns.route("/account/education")
class EducationApi(Resource):
status_fields = {
@@ -396,7 +524,7 @@ class EducationApi(Resource):
"allow_refresh": fields.Boolean,
}
@console_ns.expect(parser_edu)
@console_ns.expect(console_ns.models[EducationActivatePayload.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -405,9 +533,10 @@
def post(self):
account, _ = current_account_with_tenant()
args = parser_edu.parse_args()
payload = console_ns.payload or {}
args = EducationActivatePayload.model_validate(payload)
return BillingService.EducationIdentity.activate(account, args["token"], args["institution"], args["role"])
return BillingService.EducationIdentity.activate(account, args.token, args.institution, args.role)
@setup_required
@login_required
@@ -425,14 +554,6 @@ class EducationApi(Resource):
return res
parser_autocomplete = (
reqparse.RequestParser()
.add_argument("keywords", type=str, required=True, location="args")
.add_argument("page", type=int, required=False, location="args", default=0)
.add_argument("limit", type=int, required=False, location="args", default=20)
)
@console_ns.route("/account/education/autocomplete")
class EducationAutoCompleteApi(Resource):
data_fields = {
@@ -441,7 +562,7 @@ class EducationAutoCompleteApi(Resource):
"has_next": fields.Boolean,
}
@console_ns.expect(parser_autocomplete)
@console_ns.expect(console_ns.models[EducationAutocompleteQuery.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -449,46 +570,39 @@ class EducationAutoCompleteApi(Resource):
@cloud_edition_billing_enabled
@marshal_with(data_fields)
def get(self):
args = parser_autocomplete.parse_args()
payload = request.args.to_dict(flat=True) # type: ignore
args = EducationAutocompleteQuery.model_validate(payload)
return BillingService.EducationIdentity.autocomplete(args["keywords"], args["page"], args["limit"])
parser_change_email = (
reqparse.RequestParser()
.add_argument("email", type=email, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
.add_argument("phase", type=str, required=False, location="json")
.add_argument("token", type=str, required=False, location="json")
)
return BillingService.EducationIdentity.autocomplete(args.keywords, args.page, args.limit)
@console_ns.route("/account/change-email")
class ChangeEmailSendEmailApi(Resource):
@console_ns.expect(parser_change_email)
@console_ns.expect(console_ns.models[ChangeEmailSendPayload.__name__])
@enable_change_email
@setup_required
@login_required
@account_initialization_required
def post(self):
current_user, _ = current_account_with_tenant()
args = parser_change_email.parse_args()
payload = console_ns.payload or {}
args = ChangeEmailSendPayload.model_validate(payload)
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError()
if args["language"] is not None and args["language"] == "zh-Hans":
if args.language is not None and args.language == "zh-Hans":
language = "zh-Hans"
else:
language = "en-US"
account = None
user_email = args["email"]
if args["phase"] is not None and args["phase"] == "new_email":
if args["token"] is None:
user_email = args.email
if args.phase is not None and args.phase == "new_email":
if args.token is None:
raise InvalidTokenError()
reset_data = AccountService.get_change_email_data(args["token"])
reset_data = AccountService.get_change_email_data(args.token)
if reset_data is None:
raise InvalidTokenError()
user_email = reset_data.get("email", "")
@@ -497,118 +611,103 @@ class ChangeEmailSendEmailApi(Resource):
raise InvalidEmailError()
else:
with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
if account is None:
raise AccountNotFound()
token = AccountService.send_change_email_email(
account=account, email=args["email"], old_email=user_email, language=language, phase=args["phase"]
account=account, email=args.email, old_email=user_email, language=language, phase=args.phase
)
return {"result": "success", "data": token}
parser_validity = (
reqparse.RequestParser()
.add_argument("email", type=email, required=True, location="json")
.add_argument("code", type=str, required=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
)
@console_ns.route("/account/change-email/validity")
class ChangeEmailCheckApi(Resource):
@console_ns.expect(parser_validity)
@console_ns.expect(console_ns.models[ChangeEmailValidityPayload.__name__])
@enable_change_email
@setup_required
@login_required
@account_initialization_required
def post(self):
args = parser_validity.parse_args()
payload = console_ns.payload or {}
args = ChangeEmailValidityPayload.model_validate(payload)
user_email = args["email"]
user_email = args.email
is_change_email_error_rate_limit = AccountService.is_change_email_error_rate_limit(args["email"])
is_change_email_error_rate_limit = AccountService.is_change_email_error_rate_limit(args.email)
if is_change_email_error_rate_limit:
raise EmailChangeLimitError()
token_data = AccountService.get_change_email_data(args["token"])
token_data = AccountService.get_change_email_data(args.token)
if token_data is None:
raise InvalidTokenError()
if user_email != token_data.get("email"):
raise InvalidEmailError()
if args["code"] != token_data.get("code"):
AccountService.add_change_email_error_rate_limit(args["email"])
if args.code != token_data.get("code"):
AccountService.add_change_email_error_rate_limit(args.email)
raise EmailCodeError()
# Verified, revoke the first token
AccountService.revoke_change_email_token(args["token"])
AccountService.revoke_change_email_token(args.token)
# Refresh token data by generating a new token
_, new_token = AccountService.generate_change_email_token(
user_email, code=args["code"], old_email=token_data.get("old_email"), additional_data={}
user_email, code=args.code, old_email=token_data.get("old_email"), additional_data={}
)
AccountService.reset_change_email_error_rate_limit(args["email"])
AccountService.reset_change_email_error_rate_limit(args.email)
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
parser_reset = (
reqparse.RequestParser()
.add_argument("new_email", type=email, required=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
)
@console_ns.route("/account/change-email/reset")
class ChangeEmailResetApi(Resource):
@console_ns.expect(parser_reset)
@console_ns.expect(console_ns.models[ChangeEmailResetPayload.__name__])
@enable_change_email
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def post(self):
args = parser_reset.parse_args()
payload = console_ns.payload or {}
args = ChangeEmailResetPayload.model_validate(payload)
if AccountService.is_account_in_freeze(args["new_email"]):
if AccountService.is_account_in_freeze(args.new_email):
raise AccountInFreezeError()
if not AccountService.check_email_unique(args["new_email"]):
if not AccountService.check_email_unique(args.new_email):
raise EmailAlreadyInUseError()
reset_data = AccountService.get_change_email_data(args["token"])
reset_data = AccountService.get_change_email_data(args.token)
if not reset_data:
raise InvalidTokenError()
AccountService.revoke_change_email_token(args["token"])
AccountService.revoke_change_email_token(args.token)
old_email = reset_data.get("old_email", "")
current_user, _ = current_account_with_tenant()
if current_user.email != old_email:
raise AccountNotFound()
updated_account = AccountService.update_account_email(current_user, email=args["new_email"])
updated_account = AccountService.update_account_email(current_user, email=args.new_email)
AccountService.send_change_email_completed_notify_email(
email=args["new_email"],
email=args.new_email,
)
return updated_account
parser_check = reqparse.RequestParser().add_argument("email", type=email, required=True, location="json")
@console_ns.route("/account/change-email/check-email-unique")
class CheckEmailUnique(Resource):
@console_ns.expect(parser_check)
@console_ns.expect(console_ns.models[CheckEmailUniquePayload.__name__])
@setup_required
def post(self):
args = parser_check.parse_args()
if AccountService.is_account_in_freeze(args["email"]):
payload = console_ns.payload or {}
args = CheckEmailUniquePayload.model_validate(payload)
if AccountService.is_account_in_freeze(args.email):
raise AccountInFreezeError()
if not AccountService.check_email_unique(args["email"]):
if not AccountService.check_email_unique(args.email):
raise EmailAlreadyInUseError()
return {"result": "success"}


@@ -1,7 +1,8 @@
from urllib import parse
from flask import abort, request
from flask_restx import Resource, marshal_with, reqparse
from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field
import services
from configs import dify_config
@@ -31,6 +32,53 @@ from services.account_service import AccountService, RegisterService, TenantServ
from services.errors.account import AccountAlreadyInTenantError
from services.feature_service import FeatureService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class MemberInvitePayload(BaseModel):
emails: list[str] = Field(default_factory=list)
role: TenantAccountRole
language: str | None = None
class MemberRoleUpdatePayload(BaseModel):
role: str
class OwnerTransferEmailPayload(BaseModel):
language: str | None = None
class OwnerTransferCheckPayload(BaseModel):
code: str
token: str
class OwnerTransferPayload(BaseModel):
token: str
console_ns.schema_model(
MemberInvitePayload.__name__,
MemberInvitePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
MemberRoleUpdatePayload.__name__,
MemberRoleUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
OwnerTransferEmailPayload.__name__,
OwnerTransferEmailPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
OwnerTransferCheckPayload.__name__,
OwnerTransferCheckPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
OwnerTransferPayload.__name__,
OwnerTransferPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@console_ns.route("/workspaces/current/members")
class MemberListApi(Resource):
@@ -48,29 +96,22 @@ class MemberListApi(Resource):
return {"result": "success", "accounts": members}, 200
parser_invite = (
reqparse.RequestParser()
.add_argument("emails", type=list, required=True, location="json")
.add_argument("role", type=str, required=True, default="admin", location="json")
.add_argument("language", type=str, required=False, location="json")
)
@console_ns.route("/workspaces/current/members/invite-email")
class MemberInviteEmailApi(Resource):
"""Invite a new member by email."""
@console_ns.expect(parser_invite)
@console_ns.expect(console_ns.models[MemberInvitePayload.__name__])
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("members")
def post(self):
args = parser_invite.parse_args()
payload = console_ns.payload or {}
args = MemberInvitePayload.model_validate(payload)
invitee_emails = args["emails"]
invitee_role = args["role"]
interface_language = args["language"]
invitee_emails = args.emails
invitee_role = args.role
interface_language = args.language
if not TenantAccountRole.is_non_owner_role(invitee_role):
return {"code": "invalid-role", "message": "Invalid role"}, 400
current_user, _ = current_account_with_tenant()
@@ -146,20 +187,18 @@ class MemberCancelInviteApi(Resource):
}, 200
parser_update = reqparse.RequestParser().add_argument("role", type=str, required=True, location="json")
@console_ns.route("/workspaces/current/members/<uuid:member_id>/update-role")
class MemberUpdateRoleApi(Resource):
"""Update member role."""
@console_ns.expect(parser_update)
@console_ns.expect(console_ns.models[MemberRoleUpdatePayload.__name__])
@setup_required
@login_required
@account_initialization_required
def put(self, member_id):
args = parser_update.parse_args()
new_role = args["role"]
payload = console_ns.payload or {}
args = MemberRoleUpdatePayload.model_validate(payload)
new_role = args.role
if not TenantAccountRole.is_valid_role(new_role):
return {"code": "invalid-role", "message": "Invalid role"}, 400
@@ -197,20 +236,18 @@ class DatasetOperatorMemberListApi(Resource):
return {"result": "success", "accounts": members}, 200
parser_send = reqparse.RequestParser().add_argument("language", type=str, required=False, location="json")
@console_ns.route("/workspaces/current/members/send-owner-transfer-confirm-email")
class SendOwnerTransferEmailApi(Resource):
"""Send owner transfer email."""
@console_ns.expect(parser_send)
@console_ns.expect(console_ns.models[OwnerTransferEmailPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@is_allow_transfer_owner
def post(self):
args = parser_send.parse_args()
payload = console_ns.payload or {}
args = OwnerTransferEmailPayload.model_validate(payload)
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError()
@@ -221,7 +258,7 @@ class SendOwnerTransferEmailApi(Resource):
if not TenantService.is_owner(current_user, current_user.current_tenant):
raise NotOwnerError()
if args["language"] is not None and args["language"] == "zh-Hans":
if args.language is not None and args.language == "zh-Hans":
language = "zh-Hans"
else:
language = "en-US"
@@ -238,22 +275,16 @@ class SendOwnerTransferEmailApi(Resource):
return {"result": "success", "data": token}
parser_owner = (
reqparse.RequestParser()
.add_argument("code", type=str, required=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
)
@console_ns.route("/workspaces/current/members/owner-transfer-check")
class OwnerTransferCheckApi(Resource):
@console_ns.expect(parser_owner)
@console_ns.expect(console_ns.models[OwnerTransferCheckPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@is_allow_transfer_owner
def post(self):
args = parser_owner.parse_args()
payload = console_ns.payload or {}
args = OwnerTransferCheckPayload.model_validate(payload)
# check if the current user is the owner of the workspace
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
@@ -267,41 +298,37 @@ class OwnerTransferCheckApi(Resource):
if is_owner_transfer_error_rate_limit:
raise OwnerTransferLimitError()
token_data = AccountService.get_owner_transfer_data(args["token"])
token_data = AccountService.get_owner_transfer_data(args.token)
if token_data is None:
raise InvalidTokenError()
if user_email != token_data.get("email"):
raise InvalidEmailError()
if args["code"] != token_data.get("code"):
if args.code != token_data.get("code"):
AccountService.add_owner_transfer_error_rate_limit(user_email)
raise EmailCodeError()
# Verified, revoke the first token
AccountService.revoke_owner_transfer_token(args["token"])
AccountService.revoke_owner_transfer_token(args.token)
# Refresh token data by generating a new token
_, new_token = AccountService.generate_owner_transfer_token(user_email, code=args["code"], additional_data={})
_, new_token = AccountService.generate_owner_transfer_token(user_email, code=args.code, additional_data={})
AccountService.reset_owner_transfer_error_rate_limit(user_email)
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
parser_owner_transfer = reqparse.RequestParser().add_argument(
"token", type=str, required=True, nullable=False, location="json"
)
@console_ns.route("/workspaces/current/members/<uuid:member_id>/owner-transfer")
class OwnerTransfer(Resource):
@console_ns.expect(parser_owner_transfer)
@console_ns.expect(console_ns.models[OwnerTransferPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@is_allow_transfer_owner
def post(self, member_id):
args = parser_owner_transfer.parse_args()
payload = console_ns.payload or {}
args = OwnerTransferPayload.model_validate(payload)
# check if the current user is the owner of the workspace
current_user, _ = current_account_with_tenant()
@@ -313,14 +340,14 @@ class OwnerTransfer(Resource):
if current_user.id == str(member_id):
raise CannotTransferOwnerToSelfError()
transfer_token_data = AccountService.get_owner_transfer_data(args["token"])
transfer_token_data = AccountService.get_owner_transfer_data(args.token)
if not transfer_token_data:
raise InvalidTokenError()
if transfer_token_data.get("email") != current_user.email:
raise InvalidEmailError()
AccountService.revoke_owner_transfer_token(args["token"])
AccountService.revoke_owner_transfer_token(args.token)
member = db.session.get(Account, str(member_id))
if not member:
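Aside: the recurring change in this file is the swap from reqparse parsers to Pydantic payload models that are registered on the namespace and validated from console_ns.payload. Below is a minimal, self-contained sketch of that pattern; the route and the payload fields (emails, role, language) mirror the invite endpoint above, but the model definition itself is an assumption, not Dify's original.

from flask import Flask
from flask_restx import Api, Namespace, Resource
from pydantic import BaseModel, Field

DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"

app = Flask(__name__)
api = Api(app)
console_ns = Namespace("console", path="/console")
api.add_namespace(console_ns)


class MemberInvitePayload(BaseModel):
    # Assumed field shapes; the real model lives in the Dify codebase.
    emails: list[str]
    role: str
    language: str = Field(default="en-US")


# Register the JSON schema so @console_ns.expect can reference it and it
# shows up in the generated Swagger 2.0 documentation.
console_ns.schema_model(
    MemberInvitePayload.__name__,
    MemberInvitePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)


@console_ns.route("/members/invite-email")
class MemberInviteEmailSketchApi(Resource):
    @console_ns.expect(console_ns.models[MemberInvitePayload.__name__])
    def post(self):
        # console_ns.payload is the parsed JSON body of the current request;
        # model_validate raises pydantic.ValidationError on bad input.
        args = MemberInvitePayload.model_validate(console_ns.payload or {})
        return {"invited": args.emails, "role": args.role}, 201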

View File

@@ -1,31 +1,123 @@
import io
from typing import Any, Literal
from flask import send_file
from flask_restx import Resource, reqparse
from flask import request, send_file
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder
from libs.helper import StrLen, uuid_value
from libs.helper import uuid_value
from libs.login import current_account_with_tenant, login_required
from services.billing_service import BillingService
from services.model_provider_service import ModelProviderService
parser_model = reqparse.RequestParser().add_argument(
"model_type",
type=str,
required=False,
nullable=True,
choices=[mt.value for mt in ModelType],
location="args",
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class ParserModelList(BaseModel):
model_type: ModelType | None = None
class ParserCredentialId(BaseModel):
credential_id: str | None = None
@field_validator("credential_id")
@classmethod
def validate_optional_credential_id(cls, value: str | None) -> str | None:
if value is None:
return value
return uuid_value(value)
class ParserCredentialCreate(BaseModel):
credentials: dict[str, Any]
name: str | None = Field(default=None, max_length=30)
class ParserCredentialUpdate(BaseModel):
credential_id: str
credentials: dict[str, Any]
name: str | None = Field(default=None, max_length=30)
@field_validator("credential_id")
@classmethod
def validate_update_credential_id(cls, value: str) -> str:
return uuid_value(value)
class ParserCredentialDelete(BaseModel):
credential_id: str
@field_validator("credential_id")
@classmethod
def validate_delete_credential_id(cls, value: str) -> str:
return uuid_value(value)
class ParserCredentialSwitch(BaseModel):
credential_id: str
@field_validator("credential_id")
@classmethod
def validate_switch_credential_id(cls, value: str) -> str:
return uuid_value(value)
class ParserCredentialValidate(BaseModel):
credentials: dict[str, Any]
class ParserPreferredProviderType(BaseModel):
preferred_provider_type: Literal["system", "custom"]
console_ns.schema_model(
ParserModelList.__name__, ParserModelList.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
ParserCredentialId.__name__,
ParserCredentialId.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserCredentialCreate.__name__,
ParserCredentialCreate.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserCredentialUpdate.__name__,
ParserCredentialUpdate.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserCredentialDelete.__name__,
ParserCredentialDelete.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserCredentialSwitch.__name__,
ParserCredentialSwitch.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserCredentialValidate.__name__,
ParserCredentialValidate.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserPreferredProviderType.__name__,
ParserPreferredProviderType.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@console_ns.route("/workspaces/current/model-providers")
class ModelProviderListApi(Resource):
@console_ns.expect(parser_model)
@console_ns.expect(console_ns.models[ParserModelList.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -33,38 +125,18 @@ class ModelProviderListApi(Resource):
_, current_tenant_id = current_account_with_tenant()
tenant_id = current_tenant_id
args = parser_model.parse_args()
payload = request.args.to_dict(flat=True) # type: ignore
args = ParserModelList.model_validate(payload)
model_provider_service = ModelProviderService()
provider_list = model_provider_service.get_provider_list(tenant_id=tenant_id, model_type=args.get("model_type"))
provider_list = model_provider_service.get_provider_list(tenant_id=tenant_id, model_type=args.model_type)
return jsonable_encoder({"data": provider_list})
parser_cred = reqparse.RequestParser().add_argument(
"credential_id", type=uuid_value, required=False, nullable=True, location="args"
)
parser_post_cred = (
reqparse.RequestParser()
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
)
parser_put_cred = (
reqparse.RequestParser()
.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
)
parser_delete_cred = reqparse.RequestParser().add_argument(
"credential_id", type=uuid_value, required=True, nullable=False, location="json"
)
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials")
class ModelProviderCredentialApi(Resource):
@console_ns.expect(parser_cred)
@console_ns.expect(console_ns.models[ParserCredentialId.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -72,23 +144,25 @@ class ModelProviderCredentialApi(Resource):
_, current_tenant_id = current_account_with_tenant()
tenant_id = current_tenant_id
# if credential_id is not provided, return the currently used credential
args = parser_cred.parse_args()
payload = request.args.to_dict(flat=True) # type: ignore
args = ParserCredentialId.model_validate(payload)
model_provider_service = ModelProviderService()
credentials = model_provider_service.get_provider_credential(
tenant_id=tenant_id, provider=provider, credential_id=args.get("credential_id")
tenant_id=tenant_id, provider=provider, credential_id=args.credential_id
)
return {"credentials": credentials}
@console_ns.expect(parser_post_cred)
@console_ns.expect(console_ns.models[ParserCredentialCreate.__name__])
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
args = parser_post_cred.parse_args()
payload = console_ns.payload or {}
args = ParserCredentialCreate.model_validate(payload)
model_provider_service = ModelProviderService()
@@ -96,15 +170,15 @@ class ModelProviderCredentialApi(Resource):
model_provider_service.create_provider_credential(
tenant_id=current_tenant_id,
provider=provider,
credentials=args["credentials"],
credential_name=args["name"],
credentials=args.credentials,
credential_name=args.name,
)
except CredentialsValidateFailedError as ex:
raise ValueError(str(ex))
return {"result": "success"}, 201
@console_ns.expect(parser_put_cred)
@console_ns.expect(console_ns.models[ParserCredentialUpdate.__name__])
@setup_required
@login_required
@is_admin_or_owner_required
@@ -112,7 +186,8 @@ class ModelProviderCredentialApi(Resource):
def put(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
args = parser_put_cred.parse_args()
payload = console_ns.payload or {}
args = ParserCredentialUpdate.model_validate(payload)
model_provider_service = ModelProviderService()
@@ -120,71 +195,64 @@ class ModelProviderCredentialApi(Resource):
model_provider_service.update_provider_credential(
tenant_id=current_tenant_id,
provider=provider,
credentials=args["credentials"],
credential_id=args["credential_id"],
credential_name=args["name"],
credentials=args.credentials,
credential_id=args.credential_id,
credential_name=args.name,
)
except CredentialsValidateFailedError as ex:
raise ValueError(str(ex))
return {"result": "success"}
@console_ns.expect(parser_delete_cred)
@console_ns.expect(console_ns.models[ParserCredentialDelete.__name__])
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def delete(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
args = parser_delete_cred.parse_args()
payload = console_ns.payload or {}
args = ParserCredentialDelete.model_validate(payload)
model_provider_service = ModelProviderService()
model_provider_service.remove_provider_credential(
tenant_id=current_tenant_id, provider=provider, credential_id=args["credential_id"]
tenant_id=current_tenant_id, provider=provider, credential_id=args.credential_id
)
return {"result": "success"}, 204
parser_switch = reqparse.RequestParser().add_argument(
"credential_id", type=str, required=True, nullable=False, location="json"
)
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials/switch")
class ModelProviderCredentialSwitchApi(Resource):
@console_ns.expect(parser_switch)
@console_ns.expect(console_ns.models[ParserCredentialSwitch.__name__])
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
args = parser_switch.parse_args()
payload = console_ns.payload or {}
args = ParserCredentialSwitch.model_validate(payload)
service = ModelProviderService()
service.switch_active_provider_credential(
tenant_id=current_tenant_id,
provider=provider,
credential_id=args["credential_id"],
credential_id=args.credential_id,
)
return {"result": "success"}
parser_validate = reqparse.RequestParser().add_argument(
"credentials", type=dict, required=True, nullable=False, location="json"
)
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials/validate")
class ModelProviderValidateApi(Resource):
@console_ns.expect(parser_validate)
@console_ns.expect(console_ns.models[ParserCredentialValidate.__name__])
@setup_required
@login_required
@account_initialization_required
def post(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
args = parser_validate.parse_args()
payload = console_ns.payload or {}
args = ParserCredentialValidate.model_validate(payload)
tenant_id = current_tenant_id
@@ -195,7 +263,7 @@ class ModelProviderValidateApi(Resource):
try:
model_provider_service.validate_provider_credentials(
tenant_id=tenant_id, provider=provider, credentials=args["credentials"]
tenant_id=tenant_id, provider=provider, credentials=args.credentials
)
except CredentialsValidateFailedError as ex:
result = False
@@ -228,19 +296,9 @@ class ModelProviderIconApi(Resource):
return send_file(io.BytesIO(icon), mimetype=mimetype)
parser_preferred = reqparse.RequestParser().add_argument(
"preferred_provider_type",
type=str,
required=True,
nullable=False,
choices=["system", "custom"],
location="json",
)
@console_ns.route("/workspaces/current/model-providers/<path:provider>/preferred-provider-type")
class PreferredProviderTypeUpdateApi(Resource):
@console_ns.expect(parser_preferred)
@console_ns.expect(console_ns.models[ParserPreferredProviderType.__name__])
@setup_required
@login_required
@is_admin_or_owner_required
@@ -250,11 +308,12 @@ class PreferredProviderTypeUpdateApi(Resource):
tenant_id = current_tenant_id
args = parser_preferred.parse_args()
payload = console_ns.payload or {}
args = ParserPreferredProviderType.model_validate(payload)
model_provider_service = ModelProviderService()
model_provider_service.switch_preferred_provider(
tenant_id=tenant_id, provider=provider, preferred_provider_type=args["preferred_provider_type"]
tenant_id=tenant_id, provider=provider, preferred_provider_type=args.preferred_provider_type
)
return {"result": "success"}

View File

@@ -1,52 +1,172 @@
import logging
from typing import Any, cast
from flask_restx import Resource, reqparse
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder
from libs.helper import StrLen, uuid_value
from libs.helper import uuid_value
from libs.login import current_account_with_tenant, login_required
from services.model_load_balancing_service import ModelLoadBalancingService
from services.model_provider_service import ModelProviderService
logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
parser_get_default = reqparse.RequestParser().add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="args",
class ParserGetDefault(BaseModel):
model_type: ModelType
class ParserPostDefault(BaseModel):
class Inner(BaseModel):
model_type: ModelType
model: str | None = None
provider: str | None = None
model_settings: list[Inner]
console_ns.schema_model(
ParserGetDefault.__name__, ParserGetDefault.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
parser_post_default = reqparse.RequestParser().add_argument(
"model_settings", type=list, required=True, nullable=False, location="json"
console_ns.schema_model(
ParserPostDefault.__name__, ParserPostDefault.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
class ParserDeleteModels(BaseModel):
model: str
model_type: ModelType
console_ns.schema_model(
ParserDeleteModels.__name__, ParserDeleteModels.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
class LoadBalancingPayload(BaseModel):
configs: list[dict[str, Any]] | None = None
enabled: bool | None = None
class ParserPostModels(BaseModel):
model: str
model_type: ModelType
load_balancing: LoadBalancingPayload | None = None
config_from: str | None = None
credential_id: str | None = None
@field_validator("credential_id")
@classmethod
def validate_credential_id(cls, value: str | None) -> str | None:
if value is None:
return value
return uuid_value(value)
class ParserGetCredentials(BaseModel):
model: str
model_type: ModelType
config_from: str | None = None
credential_id: str | None = None
@field_validator("credential_id")
@classmethod
def validate_get_credential_id(cls, value: str | None) -> str | None:
if value is None:
return value
return uuid_value(value)
class ParserCredentialBase(BaseModel):
model: str
model_type: ModelType
class ParserCreateCredential(ParserCredentialBase):
name: str | None = Field(default=None, max_length=30)
credentials: dict[str, Any]
class ParserUpdateCredential(ParserCredentialBase):
credential_id: str
credentials: dict[str, Any]
name: str | None = Field(default=None, max_length=30)
@field_validator("credential_id")
@classmethod
def validate_update_credential_id(cls, value: str) -> str:
return uuid_value(value)
class ParserDeleteCredential(ParserCredentialBase):
credential_id: str
@field_validator("credential_id")
@classmethod
def validate_delete_credential_id(cls, value: str) -> str:
return uuid_value(value)
class ParserParameter(BaseModel):
model: str
console_ns.schema_model(
ParserPostModels.__name__, ParserPostModels.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
ParserGetCredentials.__name__,
ParserGetCredentials.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserCreateCredential.__name__,
ParserCreateCredential.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserUpdateCredential.__name__,
ParserUpdateCredential.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserDeleteCredential.__name__,
ParserDeleteCredential.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserParameter.__name__, ParserParameter.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
@console_ns.route("/workspaces/current/default-model")
class DefaultModelApi(Resource):
@console_ns.expect(parser_get_default)
@console_ns.expect(console_ns.models[ParserGetDefault.__name__])
@setup_required
@login_required
@account_initialization_required
def get(self):
_, tenant_id = current_account_with_tenant()
args = parser_get_default.parse_args()
args = ParserGetDefault.model_validate(request.args.to_dict(flat=True)) # type: ignore
model_provider_service = ModelProviderService()
default_model_entity = model_provider_service.get_default_model_of_model_type(
tenant_id=tenant_id, model_type=args["model_type"]
tenant_id=tenant_id, model_type=args.model_type
)
return jsonable_encoder({"data": default_model_entity})
@console_ns.expect(parser_post_default)
@console_ns.expect(console_ns.models[ParserPostDefault.__name__])
@setup_required
@login_required
@is_admin_or_owner_required
@@ -54,66 +174,31 @@ class DefaultModelApi(Resource):
def post(self):
_, tenant_id = current_account_with_tenant()
args = parser_post_default.parse_args()
args = ParserPostDefault.model_validate(console_ns.payload)
model_provider_service = ModelProviderService()
model_settings = args["model_settings"]
model_settings = args.model_settings
for model_setting in model_settings:
if "model_type" not in model_setting or model_setting["model_type"] not in [mt.value for mt in ModelType]:
raise ValueError("invalid model type")
if "provider" not in model_setting:
if model_setting.provider is None:
continue
if "model" not in model_setting:
raise ValueError("invalid model")
try:
model_provider_service.update_default_model_of_model_type(
tenant_id=tenant_id,
model_type=model_setting["model_type"],
provider=model_setting["provider"],
model=model_setting["model"],
model_type=model_setting.model_type,
provider=model_setting.provider,
model=cast(str, model_setting.model),
)
except Exception as ex:
logger.exception(
"Failed to update default model, model type: %s, model: %s",
model_setting["model_type"],
model_setting.get("model"),
model_setting.model_type,
model_setting.model,
)
raise ex
return {"result": "success"}
parser_post_models = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("load_balancing", type=dict, required=False, nullable=True, location="json")
.add_argument("config_from", type=str, required=False, nullable=True, location="json")
.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="json")
)
parser_delete_models = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
)
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models")
class ModelProviderModelApi(Resource):
@setup_required
@@ -127,7 +212,7 @@ class ModelProviderModelApi(Resource):
return jsonable_encoder({"data": models})
@console_ns.expect(parser_post_models)
@console_ns.expect(console_ns.models[ParserPostModels.__name__])
@setup_required
@login_required
@is_admin_or_owner_required
@@ -135,45 +220,45 @@ class ModelProviderModelApi(Resource):
def post(self, provider: str):
# Save the model's load-balancing configs
_, tenant_id = current_account_with_tenant()
args = parser_post_models.parse_args()
args = ParserPostModels.model_validate(console_ns.payload)
if args.get("config_from", "") == "custom-model":
if not args.get("credential_id"):
if args.config_from == "custom-model":
if not args.credential_id:
raise ValueError("credential_id is required when configuring a custom-model")
service = ModelProviderService()
service.switch_active_custom_model_credential(
tenant_id=tenant_id,
provider=provider,
model_type=args["model_type"],
model=args["model"],
credential_id=args["credential_id"],
model_type=args.model_type,
model=args.model,
credential_id=args.credential_id,
)
model_load_balancing_service = ModelLoadBalancingService()
if "load_balancing" in args and args["load_balancing"] and "configs" in args["load_balancing"]:
if args.load_balancing and args.load_balancing.configs:
# save load balancing configs
model_load_balancing_service.update_load_balancing_configs(
tenant_id=tenant_id,
provider=provider,
model=args["model"],
model_type=args["model_type"],
configs=args["load_balancing"]["configs"],
config_from=args.get("config_from", ""),
model=args.model,
model_type=args.model_type,
configs=args.load_balancing.configs,
config_from=args.config_from or "",
)
if args.get("load_balancing", {}).get("enabled"):
if args.load_balancing.enabled:
model_load_balancing_service.enable_model_load_balancing(
tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
tenant_id=tenant_id, provider=provider, model=args.model, model_type=args.model_type
)
else:
model_load_balancing_service.disable_model_load_balancing(
tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
tenant_id=tenant_id, provider=provider, model=args.model, model_type=args.model_type
)
return {"result": "success"}, 200
@console_ns.expect(parser_delete_models)
@console_ns.expect(console_ns.models[ParserDeleteModels.__name__], validate=True)
@setup_required
@login_required
@is_admin_or_owner_required
@@ -181,113 +266,53 @@ class ModelProviderModelApi(Resource):
def delete(self, provider: str):
_, tenant_id = current_account_with_tenant()
args = parser_delete_models.parse_args()
args = ParserDeleteModels.model_validate(console_ns.payload)
model_provider_service = ModelProviderService()
model_provider_service.remove_model(
tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
tenant_id=tenant_id, provider=provider, model=args.model, model_type=args.model_type
)
return {"result": "success"}, 204
parser_get_credentials = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="args")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="args",
)
.add_argument("config_from", type=str, required=False, nullable=True, location="args")
.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args")
)
parser_post_cred = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
)
parser_put_cred = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
)
parser_delete_cred = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
)
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials")
class ModelProviderModelCredentialApi(Resource):
@console_ns.expect(parser_get_credentials)
@console_ns.expect(console_ns.models[ParserGetCredentials.__name__])
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str):
_, tenant_id = current_account_with_tenant()
args = parser_get_credentials.parse_args()
args = ParserGetCredentials.model_validate(request.args.to_dict(flat=True)) # type: ignore
model_provider_service = ModelProviderService()
current_credential = model_provider_service.get_model_credential(
tenant_id=tenant_id,
provider=provider,
model_type=args["model_type"],
model=args["model"],
credential_id=args.get("credential_id"),
model_type=args.model_type,
model=args.model,
credential_id=args.credential_id,
)
model_load_balancing_service = ModelLoadBalancingService()
is_load_balancing_enabled, load_balancing_configs = model_load_balancing_service.get_load_balancing_configs(
tenant_id=tenant_id,
provider=provider,
model=args["model"],
model_type=args["model_type"],
config_from=args.get("config_from", ""),
model=args.model,
model_type=args.model_type,
config_from=args.config_from or "",
)
if args.get("config_from", "") == "predefined-model":
if args.config_from == "predefined-model":
available_credentials = model_provider_service.provider_manager.get_provider_available_credentials(
tenant_id=tenant_id, provider_name=provider
)
else:
model_type = ModelType.value_of(args["model_type"]).to_origin_model_type()
model_type = args.model_type
available_credentials = model_provider_service.provider_manager.get_provider_model_available_credentials(
tenant_id=tenant_id, provider_name=provider, model_type=model_type, model_name=args["model"]
tenant_id=tenant_id, provider_name=provider, model_type=model_type, model_name=args.model
)
return jsonable_encoder(
@@ -304,7 +329,7 @@ class ModelProviderModelCredentialApi(Resource):
}
)
@console_ns.expect(parser_post_cred)
@console_ns.expect(console_ns.models[ParserCreateCredential.__name__])
@setup_required
@login_required
@is_admin_or_owner_required
@@ -312,7 +337,7 @@ class ModelProviderModelCredentialApi(Resource):
def post(self, provider: str):
_, tenant_id = current_account_with_tenant()
args = parser_post_cred.parse_args()
args = ParserCreateCredential.model_validate(console_ns.payload)
model_provider_service = ModelProviderService()
@@ -320,30 +345,30 @@ class ModelProviderModelCredentialApi(Resource):
model_provider_service.create_model_credential(
tenant_id=tenant_id,
provider=provider,
model=args["model"],
model_type=args["model_type"],
credentials=args["credentials"],
credential_name=args["name"],
model=args.model,
model_type=args.model_type,
credentials=args.credentials,
credential_name=args.name,
)
except CredentialsValidateFailedError as ex:
logger.exception(
"Failed to save model credentials, tenant_id: %s, model: %s, model_type: %s",
tenant_id,
args.get("model"),
args.get("model_type"),
args.model,
args.model_type,
)
raise ValueError(str(ex))
return {"result": "success"}, 201
@console_ns.expect(parser_put_cred)
@console_ns.expect(console_ns.models[ParserUpdateCredential.__name__])
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def put(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
args = parser_put_cred.parse_args()
args = ParserUpdateCredential.model_validate(console_ns.payload)
model_provider_service = ModelProviderService()
@@ -351,106 +376,87 @@ class ModelProviderModelCredentialApi(Resource):
model_provider_service.update_model_credential(
tenant_id=current_tenant_id,
provider=provider,
model_type=args["model_type"],
model=args["model"],
credentials=args["credentials"],
credential_id=args["credential_id"],
credential_name=args["name"],
model_type=args.model_type,
model=args.model,
credentials=args.credentials,
credential_id=args.credential_id,
credential_name=args.name,
)
except CredentialsValidateFailedError as ex:
raise ValueError(str(ex))
return {"result": "success"}
@console_ns.expect(parser_delete_cred)
@console_ns.expect(console_ns.models[ParserDeleteCredential.__name__])
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def delete(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
args = parser_delete_cred.parse_args()
args = ParserDeleteCredential.model_validate(console_ns.payload)
model_provider_service = ModelProviderService()
model_provider_service.remove_model_credential(
tenant_id=current_tenant_id,
provider=provider,
model_type=args["model_type"],
model=args["model"],
credential_id=args["credential_id"],
model_type=args.model_type,
model=args.model,
credential_id=args.credential_id,
)
return {"result": "success"}, 204
parser_switch = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
class ParserSwitch(BaseModel):
model: str
model_type: ModelType
credential_id: str
console_ns.schema_model(
ParserSwitch.__name__, ParserSwitch.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials/switch")
class ModelProviderModelCredentialSwitchApi(Resource):
@console_ns.expect(parser_switch)
@console_ns.expect(console_ns.models[ParserSwitch.__name__])
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
args = parser_switch.parse_args()
args = ParserSwitch.model_validate(console_ns.payload)
service = ModelProviderService()
service.add_model_credential_to_model_list(
tenant_id=current_tenant_id,
provider=provider,
model_type=args["model_type"],
model=args["model"],
credential_id=args["credential_id"],
model_type=args.model_type,
model=args.model,
credential_id=args.credential_id,
)
return {"result": "success"}
parser_model_enable_disable = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
)
@console_ns.route(
"/workspaces/current/model-providers/<path:provider>/models/enable", endpoint="model-provider-model-enable"
)
class ModelProviderModelEnableApi(Resource):
@console_ns.expect(parser_model_enable_disable)
@console_ns.expect(console_ns.models[ParserDeleteModels.__name__])
@setup_required
@login_required
@account_initialization_required
def patch(self, provider: str):
_, tenant_id = current_account_with_tenant()
args = parser_model_enable_disable.parse_args()
args = ParserDeleteModels.model_validate(console_ns.payload)
model_provider_service = ModelProviderService()
model_provider_service.enable_model(
tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
tenant_id=tenant_id, provider=provider, model=args.model, model_type=args.model_type
)
return {"result": "success"}
@@ -460,48 +466,43 @@ class ModelProviderModelEnableApi(Resource):
"/workspaces/current/model-providers/<path:provider>/models/disable", endpoint="model-provider-model-disable"
)
class ModelProviderModelDisableApi(Resource):
@console_ns.expect(parser_model_enable_disable)
@console_ns.expect(console_ns.models[ParserDeleteModels.__name__])
@setup_required
@login_required
@account_initialization_required
def patch(self, provider: str):
_, tenant_id = current_account_with_tenant()
args = parser_model_enable_disable.parse_args()
args = ParserDeleteModels.model_validate(console_ns.payload)
model_provider_service = ModelProviderService()
model_provider_service.disable_model(
tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
tenant_id=tenant_id, provider=provider, model=args.model, model_type=args.model_type
)
return {"result": "success"}
parser_validate = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
class ParserValidate(BaseModel):
model: str
model_type: ModelType
credentials: dict
console_ns.schema_model(
ParserValidate.__name__, ParserValidate.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials/validate")
class ModelProviderModelValidateApi(Resource):
@console_ns.expect(parser_validate)
@console_ns.expect(console_ns.models[ParserValidate.__name__])
@setup_required
@login_required
@account_initialization_required
def post(self, provider: str):
_, tenant_id = current_account_with_tenant()
args = parser_validate.parse_args()
args = ParserValidate.model_validate(console_ns.payload)
model_provider_service = ModelProviderService()
@@ -512,9 +513,9 @@ class ModelProviderModelValidateApi(Resource):
model_provider_service.validate_model_credentials(
tenant_id=tenant_id,
provider=provider,
model=args["model"],
model_type=args["model_type"],
credentials=args["credentials"],
model=args.model,
model_type=args.model_type,
credentials=args.credentials,
)
except CredentialsValidateFailedError as ex:
result = False
@@ -528,24 +529,19 @@ class ModelProviderModelValidateApi(Resource):
return response
parser_parameter = reqparse.RequestParser().add_argument(
"model", type=str, required=True, nullable=False, location="args"
)
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/parameter-rules")
class ModelProviderModelParameterRuleApi(Resource):
@console_ns.expect(parser_parameter)
@console_ns.expect(console_ns.models[ParserParameter.__name__])
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str):
args = parser_parameter.parse_args()
args = ParserParameter.model_validate(request.args.to_dict(flat=True)) # type: ignore
_, tenant_id = current_account_with_tenant()
model_provider_service = ModelProviderService()
parameter_rules = model_provider_service.get_model_parameter_rules(
tenant_id=tenant_id, provider=provider, model=args["model"]
tenant_id=tenant_id, provider=provider, model=args.model
)
return jsonable_encoder({"data": parameter_rules})

View File

@@ -1,7 +1,9 @@
import io
from typing import Literal
from flask import request, send_file
from flask_restx import Resource, reqparse
from flask_restx import Resource
from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden
from configs import dify_config
@@ -17,6 +19,8 @@ from services.plugin.plugin_parameter_service import PluginParameterService
from services.plugin.plugin_permission_service import PluginPermissionService
from services.plugin.plugin_service import PluginService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
@console_ns.route("/workspaces/current/plugin/debugging-key")
class PluginDebuggingKeyApi(Resource):
@@ -37,88 +41,251 @@ class PluginDebuggingKeyApi(Resource):
raise ValueError(e)
parser_list = (
reqparse.RequestParser()
.add_argument("page", type=int, required=False, location="args", default=1)
.add_argument("page_size", type=int, required=False, location="args", default=256)
class ParserList(BaseModel):
page: int = Field(default=1)
page_size: int = Field(default=256)
console_ns.schema_model(
ParserList.__name__, ParserList.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
@console_ns.route("/workspaces/current/plugin/list")
class PluginListApi(Resource):
@console_ns.expect(parser_list)
@console_ns.expect(console_ns.models[ParserList.__name__])
@setup_required
@login_required
@account_initialization_required
def get(self):
_, tenant_id = current_account_with_tenant()
args = parser_list.parse_args()
args = ParserList.model_validate(request.args.to_dict(flat=True)) # type: ignore
try:
plugins_with_total = PluginService.list_with_total(tenant_id, args["page"], args["page_size"])
plugins_with_total = PluginService.list_with_total(tenant_id, args.page, args.page_size)
except PluginDaemonClientSideError as e:
raise ValueError(e)
return jsonable_encoder({"plugins": plugins_with_total.list, "total": plugins_with_total.total})
parser_latest = reqparse.RequestParser().add_argument("plugin_ids", type=list, required=True, location="json")
class ParserLatest(BaseModel):
plugin_ids: list[str]
console_ns.schema_model(
ParserLatest.__name__, ParserLatest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
class ParserIcon(BaseModel):
tenant_id: str
filename: str
class ParserAsset(BaseModel):
plugin_unique_identifier: str
file_name: str
class ParserGithubUpload(BaseModel):
repo: str
version: str
package: str
class ParserPluginIdentifiers(BaseModel):
plugin_unique_identifiers: list[str]
class ParserGithubInstall(BaseModel):
plugin_unique_identifier: str
repo: str
version: str
package: str
class ParserPluginIdentifierQuery(BaseModel):
plugin_unique_identifier: str
class ParserTasks(BaseModel):
page: int
page_size: int
class ParserMarketplaceUpgrade(BaseModel):
original_plugin_unique_identifier: str
new_plugin_unique_identifier: str
class ParserGithubUpgrade(BaseModel):
original_plugin_unique_identifier: str
new_plugin_unique_identifier: str
repo: str
version: str
package: str
class ParserUninstall(BaseModel):
plugin_installation_id: str
class ParserPermissionChange(BaseModel):
install_permission: TenantPluginPermission.InstallPermission
debug_permission: TenantPluginPermission.DebugPermission
class ParserDynamicOptions(BaseModel):
plugin_id: str
provider: str
action: str
parameter: str
credential_id: str | None = None
provider_type: Literal["tool", "trigger"]
class PluginPermissionSettingsPayload(BaseModel):
install_permission: TenantPluginPermission.InstallPermission = TenantPluginPermission.InstallPermission.EVERYONE
debug_permission: TenantPluginPermission.DebugPermission = TenantPluginPermission.DebugPermission.EVERYONE
class PluginAutoUpgradeSettingsPayload(BaseModel):
strategy_setting: TenantPluginAutoUpgradeStrategy.StrategySetting = (
TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY
)
upgrade_time_of_day: int = 0
upgrade_mode: TenantPluginAutoUpgradeStrategy.UpgradeMode = TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE
exclude_plugins: list[str] = Field(default_factory=list)
include_plugins: list[str] = Field(default_factory=list)
class ParserPreferencesChange(BaseModel):
permission: PluginPermissionSettingsPayload
auto_upgrade: PluginAutoUpgradeSettingsPayload
class ParserExcludePlugin(BaseModel):
plugin_id: str
class ParserReadme(BaseModel):
plugin_unique_identifier: str
language: str = Field(default="en-US")
console_ns.schema_model(
ParserIcon.__name__, ParserIcon.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
ParserAsset.__name__, ParserAsset.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
ParserGithubUpload.__name__, ParserGithubUpload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
ParserPluginIdentifiers.__name__,
ParserPluginIdentifiers.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserGithubInstall.__name__, ParserGithubInstall.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
ParserPluginIdentifierQuery.__name__,
ParserPluginIdentifierQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserTasks.__name__, ParserTasks.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
ParserMarketplaceUpgrade.__name__,
ParserMarketplaceUpgrade.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserGithubUpgrade.__name__, ParserGithubUpgrade.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
ParserUninstall.__name__, ParserUninstall.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
ParserPermissionChange.__name__,
ParserPermissionChange.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserDynamicOptions.__name__,
ParserDynamicOptions.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserPreferencesChange.__name__,
ParserPreferencesChange.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserExcludePlugin.__name__,
ParserExcludePlugin.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserReadme.__name__, ParserReadme.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
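Aside: ParserPreferencesChange is the deepest nesting registered above; a single model_validate call walks the whole JSON body recursively. A minimal sketch with stand-in enums for the tenant permission types and an illustrative plugin id:

from enum import StrEnum

from pydantic import BaseModel, Field


class InstallPermission(StrEnum):
    # Stand-in for TenantPluginPermission.InstallPermission.
    EVERYONE = "everyone"
    ADMINS = "admins"


class DebugPermission(StrEnum):
    # Stand-in for TenantPluginPermission.DebugPermission.
    EVERYONE = "everyone"
    ADMINS = "admins"


class PluginPermissionSettingsPayload(BaseModel):
    install_permission: InstallPermission = InstallPermission.EVERYONE
    debug_permission: DebugPermission = DebugPermission.EVERYONE


class PluginAutoUpgradeSettingsPayload(BaseModel):
    upgrade_time_of_day: int = 0
    exclude_plugins: list[str] = Field(default_factory=list)
    include_plugins: list[str] = Field(default_factory=list)


class ParserPreferencesChange(BaseModel):
    permission: PluginPermissionSettingsPayload
    auto_upgrade: PluginAutoUpgradeSettingsPayload


# Nested dicts are validated recursively; omitted keys take their defaults.
args = ParserPreferencesChange.model_validate(
    {
        "permission": {"install_permission": "admins"},
        "auto_upgrade": {"exclude_plugins": ["vendor/example-plugin"]},
    }
)
assert args.permission.debug_permission is DebugPermission.EVERYONE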
@console_ns.route("/workspaces/current/plugin/list/latest-versions")
class PluginListLatestVersionsApi(Resource):
@console_ns.expect(parser_latest)
@console_ns.expect(console_ns.models[ParserLatest.__name__])
@setup_required
@login_required
@account_initialization_required
def post(self):
args = parser_latest.parse_args()
args = ParserLatest.model_validate(console_ns.payload)
try:
versions = PluginService.list_latest_versions(args["plugin_ids"])
versions = PluginService.list_latest_versions(args.plugin_ids)
except PluginDaemonClientSideError as e:
raise ValueError(e)
return jsonable_encoder({"versions": versions})
parser_ids = reqparse.RequestParser().add_argument("plugin_ids", type=list, required=True, location="json")
@console_ns.route("/workspaces/current/plugin/list/installations/ids")
class PluginListInstallationsFromIdsApi(Resource):
@console_ns.expect(parser_ids)
@console_ns.expect(console_ns.models[ParserLatest.__name__])
@setup_required
@login_required
@account_initialization_required
def post(self):
_, tenant_id = current_account_with_tenant()
args = parser_ids.parse_args()
args = ParserLatest.model_validate(console_ns.payload)
try:
plugins = PluginService.list_installations_from_ids(tenant_id, args["plugin_ids"])
plugins = PluginService.list_installations_from_ids(tenant_id, args.plugin_ids)
except PluginDaemonClientSideError as e:
raise ValueError(e)
return jsonable_encoder({"plugins": plugins})
parser_icon = (
reqparse.RequestParser()
.add_argument("tenant_id", type=str, required=True, location="args")
.add_argument("filename", type=str, required=True, location="args")
)
@console_ns.route("/workspaces/current/plugin/icon")
class PluginIconApi(Resource):
@console_ns.expect(parser_icon)
@console_ns.expect(console_ns.models[ParserIcon.__name__])
@setup_required
def get(self):
args = parser_icon.parse_args()
args = ParserIcon.model_validate(request.args.to_dict(flat=True)) # type: ignore
try:
icon_bytes, mimetype = PluginService.get_asset(args["tenant_id"], args["filename"])
icon_bytes, mimetype = PluginService.get_asset(args.tenant_id, args.filename)
except PluginDaemonClientSideError as e:
raise ValueError(e)
@@ -128,20 +295,16 @@ class PluginIconApi(Resource):
@console_ns.route("/workspaces/current/plugin/asset")
class PluginAssetApi(Resource):
@console_ns.expect(console_ns.models[ParserAsset.__name__])
@setup_required
@login_required
@account_initialization_required
def get(self):
req = (
reqparse.RequestParser()
.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
.add_argument("file_name", type=str, required=True, location="args")
)
args = req.parse_args()
args = ParserAsset.model_validate(request.args.to_dict(flat=True)) # type: ignore
_, tenant_id = current_account_with_tenant()
try:
binary = PluginService.extract_asset(tenant_id, args["plugin_unique_identifier"], args["file_name"])
binary = PluginService.extract_asset(tenant_id, args.plugin_unique_identifier, args.file_name)
return send_file(io.BytesIO(binary), mimetype="application/octet-stream")
except PluginDaemonClientSideError as e:
raise ValueError(e)
@@ -171,17 +334,9 @@ class PluginUploadFromPkgApi(Resource):
return jsonable_encoder(response)
parser_github = (
reqparse.RequestParser()
.add_argument("repo", type=str, required=True, location="json")
.add_argument("version", type=str, required=True, location="json")
.add_argument("package", type=str, required=True, location="json")
)
@console_ns.route("/workspaces/current/plugin/upload/github")
class PluginUploadFromGithubApi(Resource):
@console_ns.expect(parser_github)
@console_ns.expect(console_ns.models[ParserGithubUpload.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -189,10 +344,10 @@ class PluginUploadFromGithubApi(Resource):
def post(self):
_, tenant_id = current_account_with_tenant()
args = parser_github.parse_args()
args = ParserGithubUpload.model_validate(console_ns.payload)
try:
response = PluginService.upload_pkg_from_github(tenant_id, args["repo"], args["version"], args["package"])
response = PluginService.upload_pkg_from_github(tenant_id, args.repo, args.version, args.package)
except PluginDaemonClientSideError as e:
raise ValueError(e)
@@ -223,47 +378,28 @@ class PluginUploadFromBundleApi(Resource):
return jsonable_encoder(response)
parser_pkg = reqparse.RequestParser().add_argument(
"plugin_unique_identifiers", type=list, required=True, location="json"
)
@console_ns.route("/workspaces/current/plugin/install/pkg")
class PluginInstallFromPkgApi(Resource):
@console_ns.expect(parser_pkg)
@console_ns.expect(console_ns.models[ParserPluginIdentifiers.__name__])
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
_, tenant_id = current_account_with_tenant()
args = parser_pkg.parse_args()
# check that all plugin_unique_identifiers are valid strings
for plugin_unique_identifier in args["plugin_unique_identifiers"]:
if not isinstance(plugin_unique_identifier, str):
raise ValueError("Invalid plugin unique identifier")
args = ParserPluginIdentifiers.model_validate(console_ns.payload)
try:
response = PluginService.install_from_local_pkg(tenant_id, args["plugin_unique_identifiers"])
response = PluginService.install_from_local_pkg(tenant_id, args.plugin_unique_identifiers)
except PluginDaemonClientSideError as e:
raise ValueError(e)
return jsonable_encoder(response)
parser_githubapi = (
reqparse.RequestParser()
.add_argument("repo", type=str, required=True, location="json")
.add_argument("version", type=str, required=True, location="json")
.add_argument("package", type=str, required=True, location="json")
.add_argument("plugin_unique_identifier", type=str, required=True, location="json")
)
@console_ns.route("/workspaces/current/plugin/install/github")
class PluginInstallFromGithubApi(Resource):
@console_ns.expect(parser_githubapi)
@console_ns.expect(console_ns.models[ParserGithubInstall.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -271,15 +407,15 @@ class PluginInstallFromGithubApi(Resource):
def post(self):
_, tenant_id = current_account_with_tenant()
args = parser_githubapi.parse_args()
args = ParserGithubInstall.model_validate(console_ns.payload)
try:
response = PluginService.install_from_github(
tenant_id,
args["plugin_unique_identifier"],
args["repo"],
args["version"],
args["package"],
args.plugin_unique_identifier,
args.repo,
args.version,
args.package,
)
except PluginDaemonClientSideError as e:
raise ValueError(e)
@@ -287,14 +423,9 @@ class PluginInstallFromGithubApi(Resource):
return jsonable_encoder(response)
parser_marketplace = reqparse.RequestParser().add_argument(
"plugin_unique_identifiers", type=list, required=True, location="json"
)
@console_ns.route("/workspaces/current/plugin/install/marketplace")
class PluginInstallFromMarketplaceApi(Resource):
@console_ns.expect(parser_marketplace)
@console_ns.expect(console_ns.models[ParserPluginIdentifiers.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -302,43 +433,33 @@ class PluginInstallFromMarketplaceApi(Resource):
def post(self):
_, tenant_id = current_account_with_tenant()
args = parser_marketplace.parse_args()
# check that all plugin_unique_identifiers are valid strings
for plugin_unique_identifier in args["plugin_unique_identifiers"]:
if not isinstance(plugin_unique_identifier, str):
raise ValueError("Invalid plugin unique identifier")
args = ParserPluginIdentifiers.model_validate(console_ns.payload)
try:
response = PluginService.install_from_marketplace_pkg(tenant_id, args["plugin_unique_identifiers"])
response = PluginService.install_from_marketplace_pkg(tenant_id, args.plugin_unique_identifiers)
except PluginDaemonClientSideError as e:
raise ValueError(e)
return jsonable_encoder(response)
parser_pkgapi = reqparse.RequestParser().add_argument(
"plugin_unique_identifier", type=str, required=True, location="args"
)
@console_ns.route("/workspaces/current/plugin/marketplace/pkg")
class PluginFetchMarketplacePkgApi(Resource):
@console_ns.expect(parser_pkgapi)
@console_ns.expect(console_ns.models[ParserPluginIdentifierQuery.__name__])
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def get(self):
_, tenant_id = current_account_with_tenant()
args = parser_pkgapi.parse_args()
args = ParserPluginIdentifierQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
try:
return jsonable_encoder(
{
"manifest": PluginService.fetch_marketplace_pkg(
tenant_id,
args["plugin_unique_identifier"],
args.plugin_unique_identifier,
)
}
)
@@ -346,14 +467,9 @@ class PluginFetchMarketplacePkgApi(Resource):
raise ValueError(e)
parser_fetch = reqparse.RequestParser().add_argument(
"plugin_unique_identifier", type=str, required=True, location="args"
)
@console_ns.route("/workspaces/current/plugin/fetch-manifest")
class PluginFetchManifestApi(Resource):
@console_ns.expect(parser_fetch)
@console_ns.expect(console_ns.models[ParserPluginIdentifierQuery.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -361,30 +477,19 @@ class PluginFetchManifestApi(Resource):
def get(self):
_, tenant_id = current_account_with_tenant()
args = parser_fetch.parse_args()
args = ParserPluginIdentifierQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
try:
return jsonable_encoder(
{
"manifest": PluginService.fetch_plugin_manifest(
tenant_id, args["plugin_unique_identifier"]
).model_dump()
}
{"manifest": PluginService.fetch_plugin_manifest(tenant_id, args.plugin_unique_identifier).model_dump()}
)
except PluginDaemonClientSideError as e:
raise ValueError(e)
parser_tasks = (
reqparse.RequestParser()
.add_argument("page", type=int, required=True, location="args")
.add_argument("page_size", type=int, required=True, location="args")
)
@console_ns.route("/workspaces/current/plugin/tasks")
class PluginFetchInstallTasksApi(Resource):
@console_ns.expect(parser_tasks)
@console_ns.expect(console_ns.models[ParserTasks.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -392,12 +497,10 @@ class PluginFetchInstallTasksApi(Resource):
def get(self):
_, tenant_id = current_account_with_tenant()
args = parser_tasks.parse_args()
args = ParserTasks.model_validate(request.args.to_dict(flat=True)) # type: ignore
try:
return jsonable_encoder(
{"tasks": PluginService.fetch_install_tasks(tenant_id, args["page"], args["page_size"])}
)
return jsonable_encoder({"tasks": PluginService.fetch_install_tasks(tenant_id, args.page, args.page_size)})
except PluginDaemonClientSideError as e:
raise ValueError(e)
@@ -462,16 +565,9 @@ class PluginDeleteInstallTaskItemApi(Resource):
raise ValueError(e)
parser_marketplace_api = (
reqparse.RequestParser()
.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
)
@console_ns.route("/workspaces/current/plugin/upgrade/marketplace")
class PluginUpgradeFromMarketplaceApi(Resource):
@console_ns.expect(parser_marketplace_api)
@console_ns.expect(console_ns.models[ParserMarketplaceUpgrade.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -479,31 +575,21 @@ class PluginUpgradeFromMarketplaceApi(Resource):
def post(self):
_, tenant_id = current_account_with_tenant()
args = parser_marketplace_api.parse_args()
args = ParserMarketplaceUpgrade.model_validate(console_ns.payload)
try:
return jsonable_encoder(
PluginService.upgrade_plugin_with_marketplace(
tenant_id, args["original_plugin_unique_identifier"], args["new_plugin_unique_identifier"]
tenant_id, args.original_plugin_unique_identifier, args.new_plugin_unique_identifier
)
)
except PluginDaemonClientSideError as e:
raise ValueError(e)
parser_github_post = (
reqparse.RequestParser()
.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
.add_argument("repo", type=str, required=True, location="json")
.add_argument("version", type=str, required=True, location="json")
.add_argument("package", type=str, required=True, location="json")
)
@console_ns.route("/workspaces/current/plugin/upgrade/github")
class PluginUpgradeFromGithubApi(Resource):
@console_ns.expect(parser_github_post)
@console_ns.expect(console_ns.models[ParserGithubUpgrade.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -511,56 +597,44 @@ class PluginUpgradeFromGithubApi(Resource):
def post(self):
_, tenant_id = current_account_with_tenant()
args = parser_github_post.parse_args()
args = ParserGithubUpgrade.model_validate(console_ns.payload)
try:
return jsonable_encoder(
PluginService.upgrade_plugin_with_github(
tenant_id,
args["original_plugin_unique_identifier"],
args["new_plugin_unique_identifier"],
args["repo"],
args["version"],
args["package"],
args.original_plugin_unique_identifier,
args.new_plugin_unique_identifier,
args.repo,
args.version,
args.package,
)
)
except PluginDaemonClientSideError as e:
raise ValueError(e)
parser_uninstall = reqparse.RequestParser().add_argument(
"plugin_installation_id", type=str, required=True, location="json"
)
@console_ns.route("/workspaces/current/plugin/uninstall")
class PluginUninstallApi(Resource):
@console_ns.expect(parser_uninstall)
@console_ns.expect(console_ns.models[ParserUninstall.__name__])
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
args = parser_uninstall.parse_args()
args = ParserUninstall.model_validate(console_ns.payload)
_, tenant_id = current_account_with_tenant()
try:
return {"success": PluginService.uninstall(tenant_id, args["plugin_installation_id"])}
return {"success": PluginService.uninstall(tenant_id, args.plugin_installation_id)}
except PluginDaemonClientSideError as e:
raise ValueError(e)
parser_change_post = (
reqparse.RequestParser()
.add_argument("install_permission", type=str, required=True, location="json")
.add_argument("debug_permission", type=str, required=True, location="json")
)
@console_ns.route("/workspaces/current/plugin/permission/change")
class PluginChangePermissionApi(Resource):
@console_ns.expect(parser_change_post)
@console_ns.expect(console_ns.models[ParserPermissionChange.__name__])
@setup_required
@login_required
@account_initialization_required
@ -570,14 +644,15 @@ class PluginChangePermissionApi(Resource):
if not user.is_admin_or_owner:
raise Forbidden()
args = parser_change_post.parse_args()
install_permission = TenantPluginPermission.InstallPermission(args["install_permission"])
debug_permission = TenantPluginPermission.DebugPermission(args["debug_permission"])
args = ParserPermissionChange.model_validate(console_ns.payload)
tenant_id = current_tenant_id
return {"success": PluginPermissionService.change_permission(tenant_id, install_permission, debug_permission)}
return {
"success": PluginPermissionService.change_permission(
tenant_id, args.install_permission, args.debug_permission
)
}
@console_ns.route("/workspaces/current/plugin/permission/fetch")
@ -605,20 +680,9 @@ class PluginFetchPermissionApi(Resource):
)
parser_dynamic = (
reqparse.RequestParser()
.add_argument("plugin_id", type=str, required=True, location="args")
.add_argument("provider", type=str, required=True, location="args")
.add_argument("action", type=str, required=True, location="args")
.add_argument("parameter", type=str, required=True, location="args")
.add_argument("credential_id", type=str, required=False, location="args")
.add_argument("provider_type", type=str, required=True, location="args")
)
@console_ns.route("/workspaces/current/plugin/parameters/dynamic-options")
class PluginFetchDynamicSelectOptionsApi(Resource):
@console_ns.expect(parser_dynamic)
@console_ns.expect(console_ns.models[ParserDynamicOptions.__name__])
@setup_required
@login_required
@is_admin_or_owner_required
@ -627,18 +691,18 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
current_user, tenant_id = current_account_with_tenant()
user_id = current_user.id
args = parser_dynamic.parse_args()
args = ParserDynamicOptions.model_validate(request.args.to_dict(flat=True)) # type: ignore
try:
options = PluginParameterService.get_dynamic_select_options(
tenant_id=tenant_id,
user_id=user_id,
plugin_id=args["plugin_id"],
provider=args["provider"],
action=args["action"],
parameter=args["parameter"],
credential_id=args["credential_id"],
provider_type=args["provider_type"],
plugin_id=args.plugin_id,
provider=args.provider,
action=args.action,
parameter=args.parameter,
credential_id=args.credential_id,
provider_type=args.provider_type,
)
except PluginDaemonClientSideError as e:
raise ValueError(e)
@ -646,16 +710,9 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
return jsonable_encoder({"options": options})
parser_change = (
reqparse.RequestParser()
.add_argument("permission", type=dict, required=True, location="json")
.add_argument("auto_upgrade", type=dict, required=True, location="json")
)
@console_ns.route("/workspaces/current/plugin/preferences/change")
class PluginChangePreferencesApi(Resource):
@console_ns.expect(parser_change)
@console_ns.expect(console_ns.models[ParserPreferencesChange.__name__])
@setup_required
@login_required
@account_initialization_required
@ -664,22 +721,20 @@ class PluginChangePreferencesApi(Resource):
if not user.is_admin_or_owner:
raise Forbidden()
args = parser_change.parse_args()
args = ParserPreferencesChange.model_validate(console_ns.payload)
permission = args["permission"]
permission = args.permission
install_permission = TenantPluginPermission.InstallPermission(permission.get("install_permission", "everyone"))
debug_permission = TenantPluginPermission.DebugPermission(permission.get("debug_permission", "everyone"))
install_permission = permission.install_permission
debug_permission = permission.debug_permission
auto_upgrade = args["auto_upgrade"]
auto_upgrade = args.auto_upgrade
strategy_setting = TenantPluginAutoUpgradeStrategy.StrategySetting(
auto_upgrade.get("strategy_setting", "fix_only")
)
upgrade_time_of_day = auto_upgrade.get("upgrade_time_of_day", 0)
upgrade_mode = TenantPluginAutoUpgradeStrategy.UpgradeMode(auto_upgrade.get("upgrade_mode", "exclude"))
exclude_plugins = auto_upgrade.get("exclude_plugins", [])
include_plugins = auto_upgrade.get("include_plugins", [])
strategy_setting = auto_upgrade.strategy_setting
upgrade_time_of_day = auto_upgrade.upgrade_time_of_day
upgrade_mode = auto_upgrade.upgrade_mode
exclude_plugins = auto_upgrade.exclude_plugins
include_plugins = auto_upgrade.include_plugins
# set permission
set_permission_result = PluginPermissionService.change_permission(
@ -744,12 +799,9 @@ class PluginFetchPreferencesApi(Resource):
return jsonable_encoder({"permission": permission_dict, "auto_upgrade": auto_upgrade_dict})
parser_exclude = reqparse.RequestParser().add_argument("plugin_id", type=str, required=True, location="json")
@console_ns.route("/workspaces/current/plugin/preferences/autoupgrade/exclude")
class PluginAutoUpgradeExcludePluginApi(Resource):
@console_ns.expect(parser_exclude)
@console_ns.expect(console_ns.models[ParserExcludePlugin.__name__])
@setup_required
@login_required
@account_initialization_required
@ -757,28 +809,20 @@ class PluginAutoUpgradeExcludePluginApi(Resource):
# exclude one single plugin
_, tenant_id = current_account_with_tenant()
args = parser_exclude.parse_args()
args = ParserExcludePlugin.model_validate(console_ns.payload)
return jsonable_encoder({"success": PluginAutoUpgradeService.exclude_plugin(tenant_id, args["plugin_id"])})
return jsonable_encoder({"success": PluginAutoUpgradeService.exclude_plugin(tenant_id, args.plugin_id)})
@console_ns.route("/workspaces/current/plugin/readme")
class PluginReadmeApi(Resource):
@console_ns.expect(console_ns.models[ParserReadme.__name__])
@setup_required
@login_required
@account_initialization_required
def get(self):
_, tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
.add_argument("language", type=str, required=False, location="args")
)
args = parser.parse_args()
args = ParserReadme.model_validate(request.args.to_dict(flat=True)) # type: ignore
return jsonable_encoder(
{
"readme": PluginService.fetch_plugin_readme(
tenant_id, args["plugin_unique_identifier"], args.get("language", "en-US")
)
}
{"readme": PluginService.fetch_plugin_readme(tenant_id, args.plugin_unique_identifier, args.language)}
)

View File

@ -1,7 +1,8 @@
import logging
from flask import request
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy import select
from werkzeug.exceptions import Unauthorized
@ -32,6 +33,45 @@ from services.file_service import FileService
from services.workspace_service import WorkspaceService
logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class WorkspaceListQuery(BaseModel):
page: int = Field(default=1, ge=1, le=99999)
limit: int = Field(default=20, ge=1, le=100)
class SwitchWorkspacePayload(BaseModel):
tenant_id: str
class WorkspaceCustomConfigPayload(BaseModel):
remove_webapp_brand: bool | None = None
replace_webapp_logo: str | None = None
class WorkspaceInfoPayload(BaseModel):
name: str
console_ns.schema_model(
WorkspaceListQuery.__name__, WorkspaceListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
SwitchWorkspacePayload.__name__,
SwitchWorkspacePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
WorkspaceCustomConfigPayload.__name__,
WorkspaceCustomConfigPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
WorkspaceInfoPayload.__name__,
WorkspaceInfoPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
provider_fields = {
@ -95,18 +135,15 @@ class TenantListApi(Resource):
@console_ns.route("/all-workspaces")
class WorkspaceListApi(Resource):
@console_ns.expect(console_ns.models[WorkspaceListQuery.__name__])
@setup_required
@admin_required
def get(self):
parser = (
reqparse.RequestParser()
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
)
args = parser.parse_args()
payload = request.args.to_dict(flat=True) # type: ignore
args = WorkspaceListQuery.model_validate(payload)
stmt = select(Tenant).order_by(Tenant.created_at.desc())
tenants = db.paginate(select=stmt, page=args["page"], per_page=args["limit"], error_out=False)
tenants = db.paginate(select=stmt, page=args.page, per_page=args.limit, error_out=False)
has_more = False
if tenants.has_next:
@ -115,8 +152,8 @@ class WorkspaceListApi(Resource):
return {
"data": marshal(tenants.items, workspace_fields),
"has_more": has_more,
"limit": args["limit"],
"page": args["page"],
"limit": args.limit,
"page": args.page,
"total": tenants.total,
}, 200
@ -150,26 +187,24 @@ class TenantApi(Resource):
return WorkspaceService.get_tenant_info(tenant), 200
parser_switch = reqparse.RequestParser().add_argument("tenant_id", type=str, required=True, location="json")
@console_ns.route("/workspaces/switch")
class SwitchWorkspaceApi(Resource):
@console_ns.expect(parser_switch)
@console_ns.expect(console_ns.models[SwitchWorkspacePayload.__name__])
@setup_required
@login_required
@account_initialization_required
def post(self):
current_user, _ = current_account_with_tenant()
args = parser_switch.parse_args()
payload = console_ns.payload or {}
args = SwitchWorkspacePayload.model_validate(payload)
# check if tenant_id is valid, 403 if not
try:
TenantService.switch_tenant(current_user, args["tenant_id"])
TenantService.switch_tenant(current_user, args.tenant_id)
except Exception:
raise AccountNotLinkTenantError("Account not link tenant")
new_tenant = db.session.query(Tenant).get(args["tenant_id"]) # Get new tenant
new_tenant = db.session.query(Tenant).get(args.tenant_id) # Get new tenant
if new_tenant is None:
raise ValueError("Tenant not found")
@ -178,24 +213,21 @@ class SwitchWorkspaceApi(Resource):
@console_ns.route("/workspaces/custom-config")
class CustomConfigWorkspaceApi(Resource):
@console_ns.expect(console_ns.models[WorkspaceCustomConfigPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("workspace_custom")
def post(self):
_, current_tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("remove_webapp_brand", type=bool, location="json")
.add_argument("replace_webapp_logo", type=str, location="json")
)
args = parser.parse_args()
payload = console_ns.payload or {}
args = WorkspaceCustomConfigPayload.model_validate(payload)
tenant = db.get_or_404(Tenant, current_tenant_id)
custom_config_dict = {
"remove_webapp_brand": args["remove_webapp_brand"],
"replace_webapp_logo": args["replace_webapp_logo"]
if args["replace_webapp_logo"] is not None
"remove_webapp_brand": args.remove_webapp_brand,
"replace_webapp_logo": args.replace_webapp_logo
if args.replace_webapp_logo is not None
else tenant.custom_config_dict.get("replace_webapp_logo"),
}
@ -245,24 +277,22 @@ class WebappLogoWorkspaceApi(Resource):
return {"id": upload_file.id}, 201
parser_info = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json")
@console_ns.route("/workspaces/info")
class WorkspaceInfoApi(Resource):
@console_ns.expect(parser_info)
@console_ns.expect(console_ns.models[WorkspaceInfoPayload.__name__])
@setup_required
@login_required
@account_initialization_required
# Change workspace name
def post(self):
_, current_tenant_id = current_account_with_tenant()
args = parser_info.parse_args()
payload = console_ns.payload or {}
args = WorkspaceInfoPayload.model_validate(payload)
if not current_tenant_id:
raise ValueError("No current tenant")
tenant = db.get_or_404(Tenant, current_tenant_id)
tenant.name = args["name"]
tenant.name = args.name
db.session.commit()
return {"result": "success", "tenant": marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)}

View File

@ -62,7 +62,8 @@ from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.ops_trace_manager import TraceQueueManager
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.workflow.enums import WorkflowExecutionStatus
from core.workflow.nodes import NodeType
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
@ -72,7 +73,7 @@ from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models import Account, Conversation, EndUser, Message, MessageFile
from models.enums import CreatorUserRole
from models.workflow import Workflow
from models.workflow import Workflow, WorkflowNodeExecutionModel
logger = logging.getLogger(__name__)
@ -580,7 +581,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
with self._database_session() as session:
# Save message
self._save_message(session=session, graph_runtime_state=resolved_state)
self._save_message(session=session, graph_runtime_state=resolved_state, trace_manager=trace_manager)
yield workflow_finish_resp
elif event.stopped_by in (
@ -590,7 +591,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
# When hitting input-moderation or annotation-reply, the workflow will not start
with self._database_session() as session:
# Save message
self._save_message(session=session)
self._save_message(session=session, trace_manager=trace_manager)
yield self._message_end_to_stream_response()
@ -599,6 +600,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
event: QueueAdvancedChatMessageEndEvent,
*,
graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle advanced chat message end events."""
@ -616,7 +618,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
# Save message
with self._database_session() as session:
self._save_message(session=session, graph_runtime_state=resolved_state)
self._save_message(session=session, graph_runtime_state=resolved_state, trace_manager=trace_manager)
yield self._message_end_to_stream_response()
@ -770,7 +772,13 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
if self._conversation_name_generate_thread:
self._conversation_name_generate_thread.join()
def _save_message(self, *, session: Session, graph_runtime_state: GraphRuntimeState | None = None):
def _save_message(
self,
*,
session: Session,
graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
):
message = self._get_message(session=session)
# If there are assistant files, remove markdown image links from answer
@ -809,6 +817,14 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
metadata = self._task_state.metadata.model_dump()
message.message_metadata = json.dumps(jsonable_encoder(metadata))
# Extract model provider and model_id from workflow node executions for tracing
if message.workflow_run_id:
model_info = self._extract_model_info_from_workflow(session, message.workflow_run_id)
if model_info:
message.model_provider = model_info.get("provider")
message.model_id = model_info.get("model")
message_files = [
MessageFile(
message_id=message.id,
@ -826,6 +842,68 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
]
session.add_all(message_files)
# Trigger MESSAGE_TRACE for tracing integrations
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.MESSAGE_TRACE, conversation_id=self._conversation_id, message_id=self._message_id
)
)
def _extract_model_info_from_workflow(self, session: Session, workflow_run_id: str) -> dict[str, str] | None:
"""
Extract model provider and model_id from workflow node executions.
Returns dict with 'provider' and 'model' keys, or None if not found.
"""
try:
# Query workflow node executions for LLM or Agent nodes
stmt = (
select(WorkflowNodeExecutionModel)
.where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id)
.where(WorkflowNodeExecutionModel.node_type.in_(["llm", "agent"]))
.order_by(WorkflowNodeExecutionModel.created_at.desc())
.limit(1)
)
node_execution = session.scalar(stmt)
if not node_execution:
return None
# Try to extract from execution_metadata for agent nodes
if node_execution.execution_metadata:
try:
metadata = json.loads(node_execution.execution_metadata)
agent_log = metadata.get("agent_log", [])
# Look for the first agent thought with provider info
for log_entry in agent_log:
entry_metadata = log_entry.get("metadata", {})
provider_str = entry_metadata.get("provider")
if provider_str:
# Parse format like "langgenius/deepseek/deepseek"
parts = provider_str.split("/")
if len(parts) >= 3:
return {"provider": parts[1], "model": parts[2]}
elif len(parts) == 2:
return {"provider": parts[0], "model": parts[1]}
except (json.JSONDecodeError, KeyError, AttributeError) as e:
logger.debug("Failed to parse execution_metadata: %s", e)
# Try to extract from process_data for llm nodes
if node_execution.process_data:
try:
process_data = json.loads(node_execution.process_data)
provider = process_data.get("model_provider")
model = process_data.get("model_name")
if provider and model:
return {"provider": provider, "model": model}
except (json.JSONDecodeError, KeyError) as e:
logger.debug("Failed to parse process_data: %s", e)
return None
except Exception as e:
logger.warning("Failed to extract model info from workflow: %s", e)
return None
def _seed_graph_runtime_state_from_queue_manager(self) -> None:
"""Bootstrap the cached runtime state from the queue manager when present."""
candidate = self._base_task_pipeline.queue_manager.graph_runtime_state

View File

@ -258,6 +258,10 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
run_id = self._extract_workflow_run_id(runtime_state)
self._workflow_execution_id = run_id
with self._database_session() as session:
self._save_workflow_app_log(session=session, workflow_run_id=self._workflow_execution_id)
start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run_id=run_id,
@ -414,9 +418,6 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
graph_runtime_state=validated_state,
)
with self._database_session() as session:
self._save_workflow_app_log(session=session, workflow_run_id=self._workflow_execution_id)
yield workflow_finish_resp
def _handle_workflow_partial_success_event(
@ -437,10 +438,6 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
graph_runtime_state=validated_state,
exceptions_count=event.exceptions_count,
)
with self._database_session() as session:
self._save_workflow_app_log(session=session, workflow_run_id=self._workflow_execution_id)
yield workflow_finish_resp
def _handle_workflow_failed_and_stop_events(
@ -471,10 +468,6 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
error=error,
exceptions_count=exceptions_count,
)
with self._database_session() as session:
self._save_workflow_app_log(session=session, workflow_run_id=self._workflow_execution_id)
yield workflow_finish_resp
def _handle_text_chunk_event(
@ -655,7 +648,6 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
)
session.add(workflow_app_log)
session.commit()
def _text_chunk_to_stream_response(
self, text: str, from_variable_selector: list[str] | None = None

View File

@ -40,6 +40,9 @@ class EasyUITaskState(TaskState):
"""
llm_result: LLMResult
first_token_time: float | None = None
last_token_time: float | None = None
is_streaming_response: bool = False
class WorkflowTaskState(TaskState):

View File

@ -118,6 +118,7 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
workflow_run_id=workflow_run_id,
state_owner_user_id=self._state_owner_user_id,
state=state.dumps(),
pause_reasons=event.reasons,
)
def on_graph_end(self, error: Exception | None) -> None:

View File

@ -332,6 +332,12 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
if not self._task_state.llm_result.prompt_messages:
self._task_state.llm_result.prompt_messages = chunk.prompt_messages
# Track streaming response times
if self._task_state.first_token_time is None:
self._task_state.first_token_time = time.perf_counter()
self._task_state.is_streaming_response = True
self._task_state.last_token_time = time.perf_counter()
# handle output moderation chunk
should_direct_answer = self._handle_output_moderation_chunk(cast(str, delta_text))
if should_direct_answer:
@ -398,6 +404,18 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
message.total_price = usage.total_price
message.currency = usage.currency
self._task_state.llm_result.usage.latency = message.provider_response_latency
# Add streaming metrics to usage if available
if self._task_state.is_streaming_response and self._task_state.first_token_time:
start_time = self.start_at
first_token_time = self._task_state.first_token_time
last_token_time = self._task_state.last_token_time or first_token_time
usage.time_to_first_token = round(first_token_time - start_time, 3)
usage.time_to_generate = round(last_token_time - first_token_time, 3)
# Update metadata with the complete usage info
self._task_state.metadata.usage = usage
message.message_metadata = self._task_state.metadata.model_dump_json()
if trace_manager:

View File

@ -152,10 +152,5 @@ class CodeExecutor:
raise CodeExecutionError(f"Unsupported language {language}")
runner, preload = template_transformer.transform_caller(code, inputs)
try:
response = cls.execute_code(language, preload, runner)
except CodeExecutionError as e:
raise e
response = cls.execute_code(language, preload, runner)
return template_transformer.transform_response(response)

View File

@ -1,308 +0,0 @@
## Custom Integration of Pre-defined Models
### Introduction
After completing the vendor integration, the next step is to connect the vendor's models. To illustrate the entire connection process, we will use Xinference as an example to demonstrate a complete vendor integration.
It is important to note that for custom models, each model connection requires a complete vendor credential.
Unlike pre-defined models, a custom vendor integration always includes the following two parameters, which do not need to be defined in the vendor YAML file.
![](images/index/image-3.png)
As mentioned earlier, vendors do not need to implement validate_provider_credential. The runtime will automatically call the corresponding model layer's validate_credentials to validate the credentials based on the model type and name selected by the user.
### Writing the Vendor YAML
First, we need to identify the types of models supported by the vendor we are integrating.
Currently supported model types are as follows:
- `llm` Text Generation Models
- `text_embedding` Text Embedding Models
- `rerank` Rerank Models
- `speech2text` Speech-to-Text
- `tts` Text-to-Speech
- `moderation` Moderation
Xinference supports LLM, Text Embedding, and Rerank. So we will start by writing xinference.yaml.
```yaml
provider: xinference #Define the vendor identifier
label: # Vendor display name, supports both en_US (English) and zh_Hans (Simplified Chinese). If zh_Hans is not set, it will use en_US by default.
en_US: Xorbits Inference
icon_small: # Small icon, refer to other vendors' icons stored in the _assets directory within the vendor implementation directory; follows the same language policy as the label
en_US: icon_s_en.svg
icon_large: # Large icon
en_US: icon_l_en.svg
help: # Help information
title:
en_US: How to deploy Xinference
zh_Hans: 如何部署 Xinference
url:
en_US: https://github.com/xorbitsai/inference
supported_model_types: # Supported model types. Xinference supports LLM, Text Embedding, and Rerank
- llm
- text-embedding
- rerank
configurate_methods: # Since Xinference is a locally deployed vendor with no predefined models, users need to deploy whatever models they need according to Xinference documentation. Thus, it only supports custom models.
- customizable-model
provider_credential_schema:
credential_form_schemas:
```
Then, we need to determine what credentials are required to define a model in Xinference.
- Since it supports three different types of models, we need to specify the model_type to denote the model type. Here is how we can define it:
```yaml
provider_credential_schema:
credential_form_schemas:
- variable: model_type
type: select
label:
en_US: Model type
zh_Hans: 模型类型
required: true
options:
- value: text-generation
label:
en_US: Language Model
zh_Hans: 语言模型
- value: embeddings
label:
en_US: Text Embedding
- value: reranking
label:
en_US: Rerank
```
- Next, each model has its own model_name, so we need to define that here:
```yaml
- variable: model_name
type: text-input
label:
en_US: Model name
zh_Hans: 模型名称
required: true
placeholder:
zh_Hans: 填写模型名称
en_US: Input model name
```
- Specify the Xinference local deployment address:
```yaml
- variable: server_url
label:
zh_Hans: 服务器 URL
en_US: Server url
type: text-input
required: true
placeholder:
zh_Hans: 在此输入 Xinference 的服务器地址,如 https://example.com/xxx
en_US: Enter the url of your Xinference, for example https://example.com/xxx
```
- Each model has a unique model_uid, so we also need to define that here:
```yaml
- variable: model_uid
label:
zh_Hans: 模型 UID
en_US: Model uid
type: text-input
required: true
placeholder:
zh_Hans: 在此输入您的 Model UID
en_US: Enter the model uid
```
Now, we have completed the basic definition of the vendor.
### Writing the Model Code
Next, let's take the `llm` type as an example and write `xinference.llm.llm.py`.
In `llm.py`, create a Xinference LLM class named `XinferenceAILargeLanguageModel` (the name is arbitrary), inheriting from the `__base.large_language_model.LargeLanguageModel` base class, and implement the following methods:
- LLM Invocation
Implement the core method for LLM invocation, supporting both stream and synchronous responses.
```python
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
"""
Invoke large language model
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool usage
:param stop: stop words
:param stream: is the response a stream
:param user: unique user id
:return: full response or stream response chunk generator result
"""
```
When implementing, be sure to use two separate functions to return data for synchronous and stream responses. This is important because Python treats functions containing the `yield` keyword as generator functions, mandating them to return `Generator` types. Here's an example (note that the example uses simplified parameters; in a real implementation, use the parameter list as defined above):
```python
def _invoke(self, stream: bool, **kwargs) \
-> Union[LLMResult, Generator]:
if stream:
return self._handle_stream_response(**kwargs)
return self._handle_sync_response(**kwargs)
def _handle_stream_response(self, **kwargs) -> Generator:
for chunk in response:
yield chunk
def _handle_sync_response(self, **kwargs) -> LLMResult:
return LLMResult(**response)
```
- Pre-compute Input Tokens
If the model does not provide an interface for pre-computing tokens, you can return 0 directly.
```python
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None) -> int:
"""
Get number of tokens for given prompt messages
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param tools: tools for tool usage
:return: token count
"""
```
Sometimes, you might not want to return 0 directly. In such cases, you can use `self._get_num_tokens_by_gpt2(text: str)` to get an estimated token count, making sure the environment variable `PLUGIN_BASED_TOKEN_COUNTING_ENABLED` is set to `true`. This method, provided by the `AIModel` base class, uses GPT-2's tokenizer for the calculation. Note, however, that it is only a substitute and may not be fully accurate.
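As a rough illustration, a fallback implementation along these lines is possible (the prompt serialization below is an assumption made for the example, not Xinference's actual format):
```python
def get_num_tokens(self, model: str, credentials: dict,
                   prompt_messages: list[PromptMessage],
                   tools: Optional[list[PromptMessageTool]] = None) -> int:
    # Approximate: flatten message contents into one string and count with
    # the GPT-2 tokenizer fallback provided by the AIModel base class.
    text = "\n".join(str(message.content) for message in prompt_messages)
    return self._get_num_tokens_by_gpt2(text)
```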
- Model Credentials Validation
Similar to vendor credentials validation, this method validates individual model credentials.
```python
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return: None
"""
```
- Model Parameter Schema
Unlike predefined models, the YAML file does not define which parameters a custom model supports, so we need to dynamically generate the model parameter schema.
For instance, Xinference supports `max_tokens`, `temperature`, and `top_p` parameters.
However, some vendors may support different parameters for different models. For example, the `OpenLLM` vendor supports `top_k`, but not all models provided by this vendor support `top_k`. Let's say model A supports `top_k` but model B does not. In such cases, we need to dynamically generate the model parameter schema, as illustrated below:
```python
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
"""
used to define customizable model schema
"""
rules = [
ParameterRule(
name='temperature', type=ParameterType.FLOAT,
use_template='temperature',
label=I18nObject(
zh_Hans='温度', en_US='Temperature'
)
),
ParameterRule(
name='top_p', type=ParameterType.FLOAT,
use_template='top_p',
label=I18nObject(
zh_Hans='Top P', en_US='Top P'
)
),
ParameterRule(
name='max_tokens', type=ParameterType.INT,
use_template='max_tokens',
min=1,
default=512,
label=I18nObject(
zh_Hans='最大生成长度', en_US='Max Tokens'
)
)
]
# if model is A, add top_k to rules
if model == 'A':
rules.append(
ParameterRule(
name='top_k', type=ParameterType.INT,
use_template='top_k',
min=1,
default=50,
label=I18nObject(
zh_Hans='Top K', en_US='Top K'
)
)
)
"""
some NOT IMPORTANT code here
"""
entity = AIModelEntity(
model=model,
label=I18nObject(
en_US=model
),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=model_type,
model_properties={
ModelPropertyKey.MODE: ModelType.LLM,
},
parameter_rules=rules
)
return entity
```
- Exception Error Mapping
When a model invocation error occurs, it should be mapped to the runtime's specified `InvokeError` type, enabling Dify to handle different errors appropriately.
Runtime Errors:
- `InvokeConnectionError` Connection error during invocation
- `InvokeServerUnavailableError` Service provider unavailable
- `InvokeRateLimitError` Rate limit reached
- `InvokeAuthorizationError` Authorization failure
- `InvokeBadRequestError` Invalid request parameters
```python
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
```
For interface method details, see: [Interfaces](./interfaces.md). For specific implementations, refer to: [llm.py](https://github.com/langgenius/dify-runtime/blob/main/lib/model_providers/anthropic/llm/llm.py).

(Ten binary image files deleted; binary content not shown. Sizes range from 44 KiB to 541 KiB.)
View File

@ -1,701 +0,0 @@
# Interface Methods
This section describes the interface methods and parameter explanations that need to be implemented by providers and various model types.
## Provider
Inherit the `__base.model_provider.ModelProvider` base class and implement the following interfaces:
```python
def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
You can choose any model type's validate_credentials method, or implement your own validation,
such as calling a model-list API.
If validation fails, raise an exception.
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
"""
```
- `credentials` (object) Credential information
The parameters of credential information are defined by the `provider_credential_schema` in the provider's YAML configuration file. Inputs such as `api_key` are included.
If verification fails, throw the `errors.validate.CredentialsValidateFailedError` error.
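For example, a minimal sketch that validates credentials by listing models over HTTP; the `endpoint_url` and `api_key` credential fields and the `/v1/models` path are assumptions made for illustration:
```python
import requests

def validate_provider_credentials(self, credentials: dict) -> None:
    response = requests.get(
        f"{credentials['endpoint_url']}/v1/models",
        headers={"Authorization": f"Bearer {credentials['api_key']}"},
        timeout=10,
    )
    if response.status_code != 200:
        raise CredentialsValidateFailedError(
            f"Credential validation failed with HTTP {response.status_code}"
        )
```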
## Model
Models are divided into 5 different types, each inheriting from different base classes and requiring the implementation of different methods.
All models need to uniformly implement the following 2 methods:
- Model Credential Verification
Similar to provider credential verification, this step involves verification for an individual model.
```python
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
```
Parameters:
- `model` (string) Model name
- `credentials` (object) Credential information
The parameters of credential information are defined by either the `provider_credential_schema` or `model_credential_schema` in the provider's YAML configuration file. Inputs such as `api_key` are included.
If verification fails, throw the `errors.validate.CredentialsValidateFailedError` error.
- Invocation Error Mapping Table
When there is an exception in model invocation, it needs to be mapped to the `InvokeError` type specified by Runtime. This facilitates Dify's ability to handle different errors with appropriate follow-up actions.
Runtime Errors:
- `InvokeConnectionError` Invocation connection error
- `InvokeServerUnavailableError` Invocation service provider unavailable
- `InvokeRateLimitError` Invocation reached rate limit
- `InvokeAuthorizationError` Invocation authorization failure
- `InvokeBadRequestError` Invocation parameter error
```python
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
```
You can refer to OpenAI's `_invoke_error_mapping` for an example.
### LLM
Inherit the `__base.large_language_model.LargeLanguageModel` base class and implement the following interfaces:
- LLM Invocation
Implement the core method for LLM invocation, which can support both streaming and synchronous returns.
```python
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
"""
Invoke large language model
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:return: full response or stream response chunk generator result
"""
```
- Parameters:
- `model` (string) Model name
- `credentials` (object) Credential information
The parameters of credential information are defined by either the `provider_credential_schema` or `model_credential_schema` in the provider's YAML configuration file. Inputs such as `api_key` are included.
- `prompt_messages` (array\[[PromptMessage](#PromptMessage)\]) List of prompts
If the model is of the `Completion` type, the list only needs to include one [UserPromptMessage](#UserPromptMessage) element;
If the model is of the `Chat` type, it requires a list of elements such as [SystemPromptMessage](#SystemPromptMessage), [UserPromptMessage](#UserPromptMessage), [AssistantPromptMessage](#AssistantPromptMessage), [ToolPromptMessage](#ToolPromptMessage) depending on the message.
- `model_parameters` (object) Model parameters
The model parameters are defined by the `parameter_rules` in the model's YAML configuration.
- `tools` (array\[[PromptMessageTool](#PromptMessageTool)\]) [optional] List of tools, equivalent to the `function` in `function calling`.
That is, the tool list for tool calling.
- `stop` (array[string]) [optional] Stop sequences
The model output will stop before the string defined by the stop sequence.
- `stream` (bool) Whether to output in a streaming manner, default is True
Streaming output returns Generator\[[LLMResultChunk](#LLMResultChunk)\], non-streaming output returns [LLMResult](#LLMResult).
- `user` (string) [optional] Unique identifier of the user
This can help the provider monitor and detect abusive behavior.
- Returns
Streaming output returns Generator\[[LLMResultChunk](#LLMResultChunk)\], non-streaming output returns [LLMResult](#LLMResult).
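A hypothetical call, invoking the implementation directly for illustration with placeholder model name, credentials, and parameters, might look like this:
```python
messages = [
    SystemPromptMessage(content="You are a helpful assistant."),
    UserPromptMessage(content="Say hello."),
]
result = model_instance._invoke(
    model="my-model",                  # placeholder
    credentials={"api_key": "..."},    # placeholder
    prompt_messages=messages,
    model_parameters={"temperature": 0.7},
    stream=True,
)
for chunk in result:  # Generator[LLMResultChunk] when stream=True
    print(chunk.delta.message.content, end="")
```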
- Pre-calculating Input Tokens
If the model does not provide a pre-calculated tokens interface, you can directly return 0.
```python
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
"""
Get number of tokens for given prompt messages
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param tools: tools for tool calling
:return:
"""
```
For parameter explanations, refer to the above section on `LLM Invocation`.
- Fetch Custom Model Schema [Optional]
```python
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
"""
Get customizable model schema
:param model: model name
:param credentials: model credentials
:return: model schema
"""
```
When the provider supports adding custom LLMs, this method can be implemented so that custom models can fetch their model schema. The default implementation returns None.
### TextEmbedding
Inherit the `__base.text_embedding_model.TextEmbeddingModel` base class and implement the following interfaces:
- Embedding Invocation
```python
def _invoke(self, model: str, credentials: dict,
texts: list[str], user: Optional[str] = None) \
-> TextEmbeddingResult:
"""
Invoke text embedding model
:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:param user: unique user id
:return: embeddings result
"""
```
- Parameters:
- `model` (string) Model name
- `credentials` (object) Credential information
The parameters of credential information are defined by either the `provider_credential_schema` or `model_credential_schema` in the provider's YAML configuration file. Inputs such as `api_key` are included.
- `texts` (array[string]) List of texts, capable of batch processing
- `user` (string) [optional] Unique identifier of the user
This can help the provider monitor and detect abusive behavior.
- Returns:
[TextEmbeddingResult](#TextEmbeddingResult) entity.
- Pre-calculating Tokens
```python
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
"""
Get number of tokens for given prompt messages
:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:return:
"""
```
For parameter explanations, refer to the above section on `Embedding Invocation`.
### Rerank
Inherit the `__base.rerank_model.RerankModel` base class and implement the following interfaces:
- Rerank Invocation
```python
def _invoke(self, model: str, credentials: dict,
query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
user: Optional[str] = None) \
-> RerankResult:
"""
Invoke rerank model
:param model: model name
:param credentials: model credentials
:param query: search query
:param docs: docs for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id
:return: rerank result
"""
```
- Parameters:
- `model` (string) Model name
- `credentials` (object) Credential information
The parameters of credential information are defined by either the `provider_credential_schema` or `model_credential_schema` in the provider's YAML configuration file. Inputs such as `api_key` are included.
- `query` (string) Query request content
- `docs` (array[string]) List of segments to be reranked
- `score_threshold` (float) [optional] Score threshold
- `top_n` (int) [optional] Select the top n segments
- `user` (string) [optional] Unique identifier of the user
This can help the provider monitor and detect abusive behavior.
- Returns:
[RerankResult](#RerankResult) entity.
### Speech2text
Inherit the `__base.speech2text_model.Speech2TextModel` base class and implement the following interfaces:
- Speech-to-Text Invocation
```python
def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str:
"""
Invoke speech-to-text model
:param model: model name
:param credentials: model credentials
:param file: audio file
:param user: unique user id
:return: text for given audio file
"""
```
- Parameters:
- `model` (string) Model name
- `credentials` (object) Credential information
The parameters of credential information are defined by either the `provider_credential_schema` or `model_credential_schema` in the provider's YAML configuration file. Inputs such as `api_key` are included.
- `file` (File) File stream
- `user` (string) [optional] Unique identifier of the user
This can help the provider monitor and detect abusive behavior.
- Returns:
The string after speech-to-text conversion.
### Text2speech
Inherit the `__base.text2speech_model.Text2SpeechModel` base class and implement the following interfaces:
- Text-to-Speech Invocation
```python
def _invoke(self, model: str, credentials: dict, content_text: str, streaming: bool, user: Optional[str] = None):
"""
Invoke text-to-speech model
:param model: model name
:param credentials: model credentials
:param content_text: text content to be translated
:param streaming: output is streaming
:param user: unique user id
:return: translated audio file
"""
```
- Parameters
- `model` (string) Model name
- `credentials` (object) Credential information
The parameters of credential information are defined by either the `provider_credential_schema` or `model_credential_schema` in the provider's YAML configuration file. Inputs such as `api_key` are included.
- `content_text` (string) The text content that needs to be converted
- `streaming` (bool) Whether to stream output
- `user` (string) [optional] Unique identifier of the user
This can help the provider monitor and detect abusive behavior.
- Returns
The audio stream converted from the input text.
### Moderation
Inherit the `__base.moderation_model.ModerationModel` base class and implement the following interfaces:
- Moderation Invocation
```python
def _invoke(self, model: str, credentials: dict,
text: str, user: Optional[str] = None) \
-> bool:
"""
Invoke moderation model
:param model: model name
:param credentials: model credentials
:param text: text to moderate
:param user: unique user id
:return: false if text is safe, true otherwise
"""
```
- Parameters:
- `model` (string) Model name
- `credentials` (object) Credential information
The parameters of credential information are defined by either the `provider_credential_schema` or `model_credential_schema` in the provider's YAML configuration file. Inputs such as `api_key` are included.
- `text` (string) Text content
- `user` (string) [optional] Unique identifier of the user
This can help the provider monitor and detect abusive behavior.
- Returns:
False indicates that the input text is safe, True indicates otherwise.
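A minimal illustrative sketch, assuming a hypothetical local blocklist rather than a real provider moderation endpoint:
```python
BLOCKED_TERMS = {"forbidden_term"}  # hypothetical blocklist

def _invoke(self, model: str, credentials: dict,
            text: str, user: Optional[str] = None) -> bool:
    # True means the text is unsafe, False means it passed moderation.
    lowered = text.lower()
    return any(term in lowered for term in BLOCKED_TERMS)
```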
## Entities
### PromptMessageRole
Message role
```python
class PromptMessageRole(Enum):
"""
Enum class for prompt message.
"""
SYSTEM = "system"
USER = "user"
ASSISTANT = "assistant"
TOOL = "tool"
```
### PromptMessageContentType
Message content types, divided into text and image.
```python
class PromptMessageContentType(Enum):
"""
Enum class for prompt message content type.
"""
TEXT = 'text'
IMAGE = 'image'
```
### PromptMessageContent
Message content base class, used only for parameter declaration; it should not be instantiated directly.
```python
class PromptMessageContent(BaseModel):
"""
Model class for prompt message content.
"""
type: PromptMessageContentType
data: str
```
Currently, two types are supported: text and image. It's possible to simultaneously input text and multiple images.
You need to initialize `TextPromptMessageContent` and `ImagePromptMessageContent` separately for input.
### TextPromptMessageContent
```python
class TextPromptMessageContent(PromptMessageContent):
"""
Model class for text prompt message content.
"""
type: PromptMessageContentType = PromptMessageContentType.TEXT
```
If inputting a combination of text and images, the text needs to be constructed into this entity as part of the `content` list.
### ImagePromptMessageContent
```python
class ImagePromptMessageContent(PromptMessageContent):
"""
Model class for image prompt message content.
"""
class DETAIL(Enum):
LOW = 'low'
HIGH = 'high'
type: PromptMessageContentType = PromptMessageContentType.IMAGE
detail: DETAIL = DETAIL.LOW # Resolution
```
If inputting a combination of text and images, the images need to be constructed into this entity as part of the `content` list.
`data` can be either a `url` or a `base64` encoded string of the image.
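For example, a multimodal user message combining text and one image (using the UserPromptMessage entity defined below; the URL is a placeholder):
```python
message = UserPromptMessage(content=[
    TextPromptMessageContent(data="What is shown in this picture?"),
    ImagePromptMessageContent(
        data="https://example.com/cat.png",  # or a base64-encoded string
        detail=ImagePromptMessageContent.DETAIL.HIGH,
    ),
])
```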
### PromptMessage
The base class for all Role message bodies, used only for parameter declaration; it should not be instantiated directly.
```python
class PromptMessage(BaseModel):
"""
Model class for prompt message.
"""
role: PromptMessageRole
content: Optional[str | list[PromptMessageContent]] = None # Supports two types: string and content list. The content list is designed to meet the needs of multimodal inputs. For more details, see the PromptMessageContent explanation.
name: Optional[str] = None
```
### UserPromptMessage
UserMessage message body, representing a user's message.
```python
class UserPromptMessage(PromptMessage):
"""
Model class for user prompt message.
"""
role: PromptMessageRole = PromptMessageRole.USER
```
### AssistantPromptMessage
Represents a message returned by the model, typically used for `few-shots` or inputting chat history.
```python
class AssistantPromptMessage(PromptMessage):
"""
Model class for assistant prompt message.
"""
class ToolCall(BaseModel):
"""
Model class for assistant prompt message tool call.
"""
class ToolCallFunction(BaseModel):
"""
Model class for assistant prompt message tool call function.
"""
name: str # tool name
arguments: str # tool arguments
id: str # Tool ID, effective only in OpenAI tool calls. It's the unique ID for tool invocation and the same tool can be called multiple times.
type: str # default: function
function: ToolCallFunction # tool call information
role: PromptMessageRole = PromptMessageRole.ASSISTANT
tool_calls: list[ToolCall] = [] # The result of tool invocation in response from the model (returned only when tools are input and the model deems it necessary to invoke a tool).
```
Where `tool_calls` are the list of `tool calls` returned by the model after invoking the model with the `tools` input.
### SystemPromptMessage
Represents system messages, usually used for setting system commands given to the model.
```python
class SystemPromptMessage(PromptMessage):
"""
Model class for system prompt message.
"""
role: PromptMessageRole = PromptMessageRole.SYSTEM
```
### ToolPromptMessage
Represents tool messages, used for conveying the results of a tool execution to the model for the next step of processing.
```python
class ToolPromptMessage(PromptMessage):
"""
Model class for tool prompt message.
"""
role: PromptMessageRole = PromptMessageRole.TOOL
tool_call_id: str # Tool invocation ID. If OpenAI tool calls are not supported, the tool name can be used instead.
```
The base class's `content` takes in the results of tool execution.
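A hypothetical round trip, in which the model requested a tool call, the caller executed it, and its output is sent back:
```python
tool_call = assistant_message.tool_calls[0]   # from the model's response
tool_result = ToolPromptMessage(
    content='{"temperature_c": 21}',          # the tool's output
    tool_call_id=tool_call.id,
    name=tool_call.function.name,
)
prompt_messages.append(tool_result)           # include in the next _invoke call
```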
### PromptMessageTool
```python
class PromptMessageTool(BaseModel):
"""
Model class for prompt message tool.
"""
name: str
description: str
parameters: dict
```
______________________________________________________________________
### LLMResult
```python
class LLMResult(BaseModel):
"""
Model class for llm result.
"""
model: str # Actual model used
prompt_messages: list[PromptMessage] # prompt messages
message: AssistantPromptMessage # response message
usage: LLMUsage # usage info
system_fingerprint: Optional[str] = None # request fingerprint, refer to OpenAI definition
```
### LLMResultChunkDelta
In streaming returns, each iteration contains the `delta` entity.
```python
class LLMResultChunkDelta(BaseModel):
"""
Model class for llm result chunk delta.
"""
index: int
message: AssistantPromptMessage # response message
usage: Optional[LLMUsage] = None # usage info
finish_reason: Optional[str] = None # finish reason, returned only by the last chunk
```
### LLMResultChunk
Each iteration entity in streaming returns.
```python
class LLMResultChunk(BaseModel):
"""
Model class for llm result chunk.
"""
model: str # Actual model used
prompt_messages: list[PromptMessage] # prompt messages
system_fingerprint: Optional[str] = None # request fingerprint, refer to OpenAI definition
delta: LLMResultChunkDelta
```
### LLMUsage
```python
class LLMUsage(ModelUsage):
"""
Model class for LLM usage.
"""
prompt_tokens: int # Tokens used for prompt
prompt_unit_price: Decimal # Unit price for prompt
prompt_price_unit: Decimal # Price unit for prompt, i.e., the unit price based on how many tokens
prompt_price: Decimal # Cost for prompt
completion_tokens: int # Tokens used for response
completion_unit_price: Decimal # Unit price for response
completion_price_unit: Decimal # Price unit for response, i.e., the unit price based on how many tokens
completion_price: Decimal # Cost for response
total_tokens: int # Total number of tokens used
total_price: Decimal # Total cost
currency: str # Currency unit
latency: float # Request latency (s)
```
______________________________________________________________________
### TextEmbeddingResult
```python
class TextEmbeddingResult(BaseModel):
"""
Model class for text embedding result.
"""
model: str # Actual model used
embeddings: list[list[float]] # List of embedding vectors, corresponding to the input texts list
usage: EmbeddingUsage # Usage information
```
### EmbeddingUsage
```python
class EmbeddingUsage(ModelUsage):
"""
Model class for embedding usage.
"""
tokens: int # Number of tokens used
total_tokens: int # Total number of tokens used
unit_price: Decimal # Unit price
price_unit: Decimal # Price unit, i.e., the unit price based on how many tokens
total_price: Decimal # Total cost
currency: str # Currency unit
latency: float # Request latency (s)
```
______________________________________________________________________
### RerankResult
```python
class RerankResult(BaseModel):
"""
Model class for rerank result.
"""
model: str # Actual model used
docs: list[RerankDocument] # Reranked document list
```
### RerankDocument
```python
class RerankDocument(BaseModel):
"""
Model class for rerank document.
"""
index: int # original index
text: str
score: float
```

View File

@ -1,176 +0,0 @@
## Predefined Model Integration
After completing the vendor integration, the next step is to integrate the models from the vendor.
First, we need to determine the type of model to be integrated and create the corresponding model type `module` under the respective vendor's directory.
Currently supported model types are:
- `llm` Text Generation Model
- `text_embedding` Text Embedding Model
- `rerank` Rerank Model
- `speech2text` Speech-to-Text
- `tts` Text-to-Speech
- `moderation` Moderation
Continuing with `Anthropic` as an example, `Anthropic` only supports LLM, so create a `module` named `llm` under `model_providers.anthropic`.
For predefined models, we first need to create a YAML file named after the model under the `llm` `module`, such as `claude-2.1.yaml`.
### Prepare Model YAML
```yaml
model: claude-2.1 # Model identifier
# Display name of the model, which can be set to en_US English or zh_Hans Chinese. If zh_Hans is not set, it will default to en_US.
# This can also be omitted, in which case the model identifier will be used as the label
label:
en_US: claude-2.1
model_type: llm # Model type, claude-2.1 is an LLM
features: # Supported features, agent-thought supports Agent reasoning, vision supports image understanding
- agent-thought
model_properties: # Model properties
mode: chat # LLM mode, complete for text completion models, chat for conversation models
context_size: 200000 # Maximum context size
parameter_rules: # Parameter rules for the model call; only LLM requires this
- name: temperature # Parameter variable name
# Five default configuration templates are provided: temperature/top_p/max_tokens/presence_penalty/frequency_penalty
# The template variable name can be set directly in use_template, which will use the default configuration in entities.defaults.PARAMETER_RULE_TEMPLATE
# Additional configuration parameters will override the default configuration if set
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label: # Display name of the parameter
zh_Hans: 取样数量
en_US: Top k
type: int # Parameter type, supports float/int/string/boolean
help: # Help information, describing the parameter's function
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false # Whether the parameter is mandatory; can be omitted
- name: max_tokens_to_sample
use_template: max_tokens
default: 4096 # Default value of the parameter
min: 1 # Minimum value of the parameter, applicable to float/int only
max: 4096 # Maximum value of the parameter, applicable to float/int only
pricing: # Pricing information
input: '8.00' # Input unit price, i.e., prompt price
output: '24.00' # Output unit price, i.e., response content price
unit: '0.000001' # Price unit, meaning the above prices are per 1M tokens
currency: USD # Price currency
```
It is recommended to prepare all model configurations before starting the implementation of the model code.
You can also refer to the YAML configuration information under the corresponding model type directories of other vendors in the `model_providers` directory. For the complete YAML rules, refer to: [Schema](schema.md#aimodelentity).
### Implement the Model Call Code
Next, create a Python file named `llm.py` under the `llm` `module` to write the implementation code.
Create an Anthropic LLM class named `AnthropicLargeLanguageModel` (or any other name), inheriting from the `__base.large_language_model.LargeLanguageModel` base class, and implement the following methods:
- LLM Call
Implement the core method for calling the LLM, supporting both streaming and synchronous responses.
```python
def _invoke(self, model: str, credentials: dict,
            prompt_messages: list[PromptMessage], model_parameters: dict,
            tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
            stream: bool = True, user: Optional[str] = None) \
        -> Union[LLMResult, Generator]:
    """
    Invoke large language model

    :param model: model name
    :param credentials: model credentials
    :param prompt_messages: prompt messages
    :param model_parameters: model parameters
    :param tools: tools for tool calling
    :param stop: stop words
    :param stream: is stream response
    :param user: unique user id
    :return: full response or stream response chunk generator result
    """
```
Be sure to use two separate functions to return data, one for synchronous responses and one for streaming. Because Python treats any function containing the `yield` keyword as a generator function whose return type is fixed to `Generator`, the synchronous and streaming paths must be implemented separately, as shown below (note that the example uses simplified parameters; an actual implementation should follow the parameter list above):
```python
def _invoke(self, stream: bool, **kwargs) \
        -> Union[LLMResult, Generator]:
    if stream:
        return self._handle_stream_response(**kwargs)
    return self._handle_sync_response(**kwargs)

def _handle_stream_response(self, **kwargs) -> Generator:
    for chunk in response:
        yield chunk

def _handle_sync_response(self, **kwargs) -> LLMResult:
    return LLMResult(**response)
```
- Pre-compute Input Tokens
If the model does not provide an interface to precompute tokens, return 0 directly.
```python
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                   tools: Optional[list[PromptMessageTool]] = None) -> int:
    """
    Get number of tokens for given prompt messages

    :param model: model name
    :param credentials: model credentials
    :param prompt_messages: prompt messages
    :param tools: tools for tool calling
    :return:
    """
```
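If you do want a non-zero estimate without a provider tokenizer, a rough heuristic is possible; the sketch below is illustrative only and not billing-accurate:
```python
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                   tools: Optional[list[PromptMessageTool]] = None) -> int:
    # No official token-counting endpoint is assumed here; per the note above,
    # returning 0 is always acceptable. A crude word-count approximation is
    # shown as an alternative.
    text = " ".join(str(message.content) for message in prompt_messages if message.content)
    return len(text.split())
```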
- Validate Model Credentials
Similar to vendor credential validation, but specific to a single model.
```python
def validate_credentials(self, model: str, credentials: dict) -> None:
    """
    Validate model credentials

    :param model: model name
    :param credentials: model credentials
    :return:
    """
```
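One common approach (a sketch, reusing the `_invoke` method above and the `CredentialsValidateFailedError` from `errors.validate`; the one-token probe request is an assumption, not Anthropic's mandated check) is to issue a minimal request and surface any failure as a validation error:
```python
def validate_credentials(self, model: str, credentials: dict) -> None:
    try:
        # A cheap one-token request is enough to prove the credentials work.
        self._invoke(
            model=model,
            credentials=credentials,
            prompt_messages=[UserPromptMessage(content="ping")],
            model_parameters={"max_tokens_to_sample": 1, "temperature": 0},
            stream=False,
        )
    except Exception as ex:
        raise CredentialsValidateFailedError(str(ex))
```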
- Map Invoke Errors
When a model call fails, map it to a specific `InvokeError` type as required by Runtime, allowing Dify to handle different errors accordingly.
Runtime Errors:
- `InvokeConnectionError` Connection error
- `InvokeServerUnavailableError` Service provider unavailable
- `InvokeRateLimitError` Rate limit reached
- `InvokeAuthorizationError` Authorization failed
- `InvokeBadRequestError` Parameter error
```python
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
    """
    Map model invoke error to unified error
    The key is the error type thrown to the caller
    The value is the error type thrown by the model,
    which needs to be converted into a unified error type for the caller.

    :return: Invoke error mapping
    """
```
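A simple option, shown in the interface reference for this Runtime, is a pass-through mapping: each unified error maps to itself, so the implementation can raise `InvokeConnectionError` and friends directly:
```python
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
    # Pass-through mapping: implementation code raises the unified error
    # types directly, and they are returned to the caller unchanged.
    return {
        InvokeConnectionError: [InvokeConnectionError],
        InvokeServerUnavailableError: [InvokeServerUnavailableError],
        InvokeRateLimitError: [InvokeRateLimitError],
        InvokeAuthorizationError: [InvokeAuthorizationError],
        InvokeBadRequestError: [InvokeBadRequestError],
    }
```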
For interface method explanations, see: [Interfaces](./interfaces.md). For detailed implementation, refer to: [llm.py](https://github.com/langgenius/dify-runtime/blob/main/lib/model_providers/anthropic/llm/llm.py).


@@ -1,266 +0,0 @@
## Adding a New Provider
Providers support three types of model configuration methods:
- `predefined-model` Predefined model
This indicates that users only need to configure the unified provider credentials to use the predefined models under the provider.
- `customizable-model` Customizable model
Users need to add credential configurations for each model.
- `fetch-from-remote` Fetch from remote
This is consistent with the `predefined-model` configuration method. Only unified provider credentials need to be configured, and models are obtained from the provider through credential information.
These three configuration methods **can coexist**, meaning a provider can support `predefined-model` + `customizable-model` or `predefined-model` + `fetch-from-remote`, etc. In other words, configuring the unified provider credentials allows the use of predefined and remotely fetched models, and if new models are added, they can be used in addition to the custom models.
## Getting Started
Adding a new provider starts with determining the English identifier of the provider, such as `anthropic`, and using this identifier to create a `module` in `model_providers`.
Under this `module`, we first need to prepare the provider's YAML configuration.
### Preparing Provider YAML
Here, using `Anthropic` as an example, we preset the provider's basic information, supported model types, configuration methods, and credential rules.
```YAML
provider: anthropic # Provider identifier
label: # Provider display name, can be set in en_US English and zh_Hans Chinese; zh_Hans falls back to en_US if not set.
  en_US: Anthropic
icon_small: # Small provider icon, stored in the _assets directory under the corresponding provider implementation directory; same language strategy as label
  en_US: icon_s_en.png
icon_large: # Large provider icon, stored in the _assets directory under the corresponding provider implementation directory; same language strategy as label
  en_US: icon_l_en.png
supported_model_types: # Supported model types; Anthropic only supports LLM
  - llm
configurate_methods: # Supported configuration methods; Anthropic only supports predefined models
  - predefined-model
provider_credential_schema: # Provider credential rules; as Anthropic only supports predefined models, unified provider credential rules need to be defined
  credential_form_schemas: # List of credential form items
    - variable: anthropic_api_key # Credential parameter variable name
      label: # Display name
        en_US: API Key
      type: secret-input # Form type; secret-input is an encrypted input box that shows masked information when editing.
      required: true # Whether required
      placeholder: # Placeholder text
        zh_Hans: 在此输入您的 API Key
        en_US: Enter your API Key
    - variable: anthropic_api_url
      label:
        en_US: API URL
      type: text-input # Form type; text-input is a plain text input box
      required: false
      placeholder:
        zh_Hans: 在此输入您的 API URL
        en_US: Enter your API URL
```
You can also refer to the YAML configuration information under other provider directories in `model_providers`. The complete YAML rules are available at: [Schema](schema.md#provider).
### Implementing Provider Code
Providers need to inherit the `__base.model_provider.ModelProvider` base class and implement the `validate_provider_credentials` method for unified provider credential verification. For reference, see [AnthropicProvider](https://github.com/langgenius/dify-runtime/blob/main/lib/model_providers/anthropic/anthropic.py).
> If the provider is of the `customizable-model` type, there is no need to implement the `validate_provider_credentials` method.
```python
def validate_provider_credentials(self, credentials: dict) -> None:
    """
    Validate provider credentials
    You can choose any validate_credentials method of model type or implement validate method by yourself,
    such as: get model list api

    if validate failed, raise exception

    :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
    """
```
Of course, you can also stub out the `validate_provider_credentials` implementation first and reuse the model credential verification method directly once it is implemented, as sketched below.
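A minimal sketch of that reuse (the `get_model_instance` helper and the hard-coded `claude-2.1` probe model are assumptions; adapt them to the actual base class API):
```python
class AnthropicProvider(ModelProvider):
    def validate_provider_credentials(self, credentials: dict) -> None:
        try:
            # Delegate to the model-level check implemented in llm.py.
            model_instance = self.get_model_instance(ModelType.LLM)
            model_instance.validate_credentials(model="claude-2.1", credentials=credentials)
        except CredentialsValidateFailedError:
            raise
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))
```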
______________________________________________________________________
### Adding Models
After the provider integration is complete, the next step is to integrate models under the provider.
First, we need to determine the type of the model to be integrated and create a `module` for the corresponding model type in the provider's directory.
The currently supported model types are as follows:
- `llm` Text generation model
- `text_embedding` Text Embedding model
- `rerank` Rerank model
- `speech2text` Speech to text
- `tts` Text to speech
- `moderation` Moderation
Continuing with `Anthropic` as an example, since `Anthropic` only supports LLM, we create a `module` named `llm` in `model_providers.anthropic`.
For predefined models, we first need to create a YAML file named after the model, such as `claude-2.1.yaml`, under the `llm` `module`.
#### Preparing Model YAML
```yaml
model: claude-2.1 # Model identifier
# Model display name, can be set in en_US English and zh_Hans Chinese; zh_Hans falls back to en_US if not set.
# Alternatively, if the label is not set, the model identifier is used.
label:
  en_US: claude-2.1
model_type: llm # Model type; claude-2.1 is an LLM
features: # Supported features; agent-thought for Agent reasoning, vision for image understanding
  - agent-thought
model_properties: # Model properties
  mode: chat # LLM mode, complete for text completion models, chat for dialogue models
  context_size: 200000 # Maximum supported context size
parameter_rules: # Model invocation parameter rules, only required for LLM
  - name: temperature # Invocation parameter variable name
    # Five variable content configuration templates are preset by default: temperature/top_p/max_tokens/presence_penalty/frequency_penalty
    # Set the template variable name directly in use_template to use the default configuration in entities.defaults.PARAMETER_RULE_TEMPLATE
    # Any additional configuration parameters override the default configuration
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label: # Invocation parameter display name
      zh_Hans: 取样数量
      en_US: Top k
    type: int # Parameter type, supports float/int/string/boolean
    help: # Help information, describing the role of the parameter
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false # Whether required; can be left unset
  - name: max_tokens_to_sample
    use_template: max_tokens
    default: 4096 # Default parameter value
    min: 1 # Minimum parameter value, only applicable for float/int
    max: 4096 # Maximum parameter value, only applicable for float/int
pricing: # Pricing information
  input: '8.00' # Input price, i.e., prompt price
  output: '24.00' # Output price, i.e., returned content price
  unit: '0.000001' # Pricing unit, i.e., the above prices are per 1M tokens
  currency: USD # Currency
```
It is recommended to prepare all model configurations before starting the implementation of the model code.
Similarly, you can also refer to the YAML configuration information for the corresponding model types of other providers in the `model_providers` directory. The complete YAML rules can be found at: [Schema](schema.md#aimodelentity).
#### Implementing Model Invocation Code
Next, create a Python file named `llm.py` under the `llm` `module` to write the implementation code.
In `llm.py`, create an Anthropic LLM class named `AnthropicLargeLanguageModel` (the name is arbitrary), inheriting from the `__base.large_language_model.LargeLanguageModel` base class, and implement the following methods:
- LLM Invocation
Implement the core method for LLM invocation, which can support both streaming and synchronous returns.
```python
def _invoke(self, model: str, credentials: dict,
            prompt_messages: list[PromptMessage], model_parameters: dict,
            tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
            stream: bool = True, user: Optional[str] = None) \
        -> Union[LLMResult, Generator]:
    """
    Invoke large language model

    :param model: model name
    :param credentials: model credentials
    :param prompt_messages: prompt messages
    :param model_parameters: model parameters
    :param tools: tools for tool calling
    :param stop: stop words
    :param stream: is stream response
    :param user: unique user id
    :return: full response or stream response chunk generator result
    """
```
- Pre-calculating Input Tokens
If the model does not provide a pre-calculated tokens interface, you can directly return 0.
```python
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                   tools: Optional[list[PromptMessageTool]] = None) -> int:
    """
    Get number of tokens for given prompt messages

    :param model: model name
    :param credentials: model credentials
    :param prompt_messages: prompt messages
    :param tools: tools for tool calling
    :return:
    """
```
- Model Credential Verification
Similar to provider credential verification, this step involves verification for an individual model.
```python
def validate_credentials(self, model: str, credentials: dict) -> None:
    """
    Validate model credentials

    :param model: model name
    :param credentials: model credentials
    :return:
    """
```
- Invocation Error Mapping Table
When a model invocation raises an exception, it must be mapped to the `InvokeError` type specified by the Runtime, which lets Dify handle different errors with appropriate follow-up actions.
Runtime Errors:
- `InvokeConnectionError` Invocation connection error
- `InvokeServerUnavailableError` Invocation service provider unavailable
- `InvokeRateLimitError` Invocation reached rate limit
- `InvokeAuthorizationError` Invocation authorization failure
- `InvokeBadRequestError` Invocation parameter error
```python
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
    """
    Map model invoke error to unified error
    The key is the error type thrown to the caller
    The value is the error type thrown by the model,
    which needs to be converted into a unified error type for the caller.

    :return: Invoke error mapping
    """
```
For details on the interface methods, see: [Interfaces](interfaces.md). For specific implementations, refer to: [llm.py](https://github.com/langgenius/dify-runtime/blob/main/lib/model_providers/anthropic/llm/llm.py).
### Testing
To ensure the availability of integrated providers/models, every implemented method needs corresponding integration test code in the `tests` directory.
Continuing with `Anthropic` as an example:
Before writing test code, first add the credential environment variables required by the provider under test to `.env.example`, for example: `ANTHROPIC_API_KEY`.
Before running the tests, copy `.env.example` to `.env`.
#### Writing Test Code
Create a `module` with the same name as the provider in the `tests` directory: `anthropic`, and within this module create `test_provider.py` along with test `.py` files for the corresponding model types, as shown below:
```shell
.
├── __init__.py
├── anthropic
│   ├── __init__.py
│   ├── test_llm.py # LLM Testing
│   └── test_provider.py # Provider Testing
```
Write test code for all the various cases implemented above and submit the code after passing the tests.
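As a sketch of what `test_provider.py` might contain (the import paths are assumptions; adapt them to the actual package layout):
```python
import os

import pytest

from errors.validate import CredentialsValidateFailedError  # path is an assumption
from model_providers.anthropic.anthropic import AnthropicProvider  # path is an assumption


def test_validate_provider_credentials():
    provider = AnthropicProvider()

    # Invalid credentials must be rejected with the unified validation error.
    with pytest.raises(CredentialsValidateFailedError):
        provider.validate_provider_credentials(credentials={"anthropic_api_key": "invalid"})

    # Real credentials loaded from .env should pass without raising.
    provider.validate_provider_credentials(
        credentials={"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")}
    )
```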


@@ -1,208 +0,0 @@
# Configuration Rules
- Provider rules are based on the [Provider](#Provider) entity.
- Model rules are based on the [AIModelEntity](#AIModelEntity) entity.
> All entities mentioned below are based on `Pydantic BaseModel` and can be found in the `entities` module.
### Provider
- `provider` (string) Provider identifier, e.g., `openai`
- `label` (object) Provider display name, i18n, with `en_US` English and `zh_Hans` Chinese language settings
  - `zh_Hans` (string) [optional] Chinese label name; if `zh_Hans` is not set, `en_US` will be used by default.
  - `en_US` (string) English label name
- `description` (object) Provider description, i18n
  - `zh_Hans` (string) [optional] Chinese description
  - `en_US` (string) English description
- `icon_small` (string) [optional] Small provider ICON, stored in the `_assets` directory under the corresponding provider implementation directory, with the same language strategy as `label`
  - `zh_Hans` (string) Chinese ICON
  - `en_US` (string) English ICON
- `icon_large` (string) [optional] Large provider ICON, stored in the `_assets` directory under the corresponding provider implementation directory, with the same language strategy as `label`
  - `zh_Hans` (string) Chinese ICON
  - `en_US` (string) English ICON
- `background` (string) [optional] Background color value, e.g., #FFFFFF; if empty, the default frontend color value will be displayed.
- `help` (object) [optional] Help information
  - `title` (object) Help title, i18n
    - `zh_Hans` (string) [optional] Chinese title
    - `en_US` (string) English title
  - `url` (object) Help link, i18n
    - `zh_Hans` (string) [optional] Chinese link
    - `en_US` (string) English link
- `supported_model_types` (array\[[ModelType](#ModelType)\]) Supported model types
- `configurate_methods` (array\[[ConfigurateMethod](#ConfigurateMethod)\]) Configuration methods
- `provider_credential_schema` ([ProviderCredentialSchema](#ProviderCredentialSchema)) Provider credential specification
- `model_credential_schema` ([ModelCredentialSchema](#ModelCredentialSchema)) Model credential specification
### AIModelEntity
- `model` (string) Model identifier, e.g., `gpt-3.5-turbo`
- `label` (object) [optional] Model display name, i18n, with `en_US` English and `zh_Hans` Chinese language settings
  - `zh_Hans` (string) [optional] Chinese label name
  - `en_US` (string) English label name
- `model_type` ([ModelType](#ModelType)) Model type
- `features` (array\[[ModelFeature](#ModelFeature)\]) [optional] Supported feature list
- `model_properties` (object) Model properties
  - `mode` ([LLMMode](#LLMMode)) Mode (available for model type `llm`)
  - `context_size` (int) Context size (available for model types `llm`, `text-embedding`)
  - `max_chunks` (int) Maximum number of chunks (available for model types `text-embedding`, `moderation`)
  - `file_upload_limit` (int) Maximum file upload limit, in MB (available for model type `speech2text`)
  - `supported_file_extensions` (string) Supported file extension formats, e.g., mp3, mp4 (available for model type `speech2text`)
  - `default_voice` (string) Default voice, e.g., alloy, echo, fable, onyx, nova, shimmer (available for model type `tts`)
  - `voices` (list) List of available voices (available for model type `tts`)
    - `mode` (string) Voice identifier (available for model type `tts`)
    - `name` (string) Voice display name (available for model type `tts`)
    - `language` (string) Languages supported by the voice (available for model type `tts`)
  - `word_limit` (int) Word limit per conversion, paragraph-wise by default (available for model type `tts`)
  - `audio_type` (string) Supported audio file extension formats, e.g., mp3, wav (available for model type `tts`)
  - `max_workers` (int) Number of concurrent workers for text-to-audio conversion (available for model type `tts`)
  - `max_characters_per_chunk` (int) Maximum characters per chunk (available for model type `moderation`)
- `parameter_rules` (array\[[ParameterRule](#ParameterRule)\]) [optional] Model invocation parameter rules
- `pricing` ([PriceConfig](#PriceConfig)) [optional] Pricing information
- `deprecated` (bool) Whether deprecated. If deprecated, the model will no longer be displayed in the list, but those already configured can continue to be used. Default False.
### ModelType
- `llm` Text generation model
- `text-embedding` Text Embedding model
- `rerank` Rerank model
- `speech2text` Speech-to-text
- `tts` Text-to-speech
- `moderation` Moderation
### ConfigurateMethod
- `predefined-model` Predefined model
  Indicates that users can use the predefined models under the provider by configuring the unified provider credentials.
- `customizable-model` Customizable model
  Users need to add a credential configuration for each model.
- `fetch-from-remote` Fetch from remote
  Consistent with the `predefined-model` configuration method: only unified provider credentials need to be configured, and models are obtained from the provider through the credential information.
### ModelFeature
- `agent-thought` Agent reasoning, generally models over 70B with chain-of-thought capability.
- `vision` Vision, i.e., image understanding.
- `tool-call` Tool calling
- `multi-tool-call` Multiple tool calling
- `stream-tool-call` Streaming tool calling
### FetchFrom
- `predefined-model` Predefined model
- `fetch-from-remote` Remote model
### LLMMode
- `completion` Text completion
- `chat` Chat
### ParameterRule
- `name` (string) Actual parameter name used when invoking the model
- `use_template` (string) [optional] Template to use
  Five variable content configuration templates are preset by default:
  - `temperature`
  - `top_p`
  - `frequency_penalty`
  - `presence_penalty`
  - `max_tokens`
  You can set the template variable name directly in `use_template`, which will use the default configuration in `entities.defaults.PARAMETER_RULE_TEMPLATE`.
  No parameters other than `name` and `use_template` need to be set; any additional configuration parameters override the default configuration.
  See `openai/llm/gpt-3.5-turbo.yaml` for reference.
- `label` (object) [optional] Label, i18n
  - `zh_Hans` (string) [optional] Chinese label name
  - `en_US` (string) English label name
- `type` (string) [optional] Parameter type
  - `int` Integer
  - `float` Float
  - `string` String
  - `boolean` Boolean
- `help` (string) [optional] Help information
  - `zh_Hans` (string) [optional] Chinese help information
  - `en_US` (string) English help information
- `required` (bool) Required, default False.
- `default` (int/float/string/bool) [optional] Default value
- `min` (int/float) [optional] Minimum value, applicable only to numeric types
- `max` (int/float) [optional] Maximum value, applicable only to numeric types
- `precision` (int) [optional] Precision, number of decimal places to keep, applicable only to numeric types
- `options` (array[string]) [optional] Dropdown option values, applicable only when `type` is `string`; if not set or null, option values are not restricted
### PriceConfig
- `input` (float) Input price, i.e., Prompt price
- `output` (float) Output price, i.e., returned content price
- `unit` (float) Pricing unit, e.g., if the price is measured in 1M tokens, the corresponding token amount for the unit price is `0.000001`.
- `currency` (string) Currency unit
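For example, with `input: '8.00'` and `unit: '0.000001'`, the effective rate is 8.00 × 0.000001 = $0.000008 per token, i.e., $8 per 1M input tokens.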
### ProviderCredentialSchema
- `credential_form_schemas` (array\[[CredentialFormSchema](#CredentialFormSchema)\]) Credential form specification
### ModelCredentialSchema
- `model` (object) Model identifier, variable name defaults to `model`
  - `label` (object) Display name of the model form item
    - `en_US` (string) English
    - `zh_Hans` (string) [optional] Chinese
  - `placeholder` (object) Placeholder content for the model
    - `en_US` (string) English
    - `zh_Hans` (string) [optional] Chinese
- `credential_form_schemas` (array\[[CredentialFormSchema](#CredentialFormSchema)\]) Credential form specification
### CredentialFormSchema
- `variable` (string) Form item variable name
- `label` (object) Form item label name
  - `en_US` (string) English
  - `zh_Hans` (string) [optional] Chinese
- `type` ([FormType](#FormType)) Form item type
- `required` (bool) Whether required
- `default` (string) Default value
- `options` (array\[[FormOption](#FormOption)\]) Property specific to form items of type `select` or `radio`, defining the dropdown content
- `placeholder` (object) Property specific to form items of type `text-input`, the placeholder content
  - `en_US` (string) English
  - `zh_Hans` (string) [optional] Chinese
- `max_length` (int) Property specific to form items of type `text-input`, defining the maximum input length; 0 means no limit.
- `show_on` (array\[[FormShowOnObject](#FormShowOnObject)\]) Displayed when the values of other form items meet certain conditions; always displayed if empty.
### FormType
- `text-input` Text input component
- `secret-input` Password input component
- `select` Single-choice dropdown
- `radio` Radio component
- `switch` Switch component, only supports `true` and `false` values
### FormOption
- `label` (object) Label
  - `en_US` (string) English
  - `zh_Hans` (string) [optional] Chinese
- `value` (string) Dropdown option value
- `show_on` (array\[[FormShowOnObject](#FormShowOnObject)\]) Displayed when the values of other form items meet certain conditions; always displayed if empty.
### FormShowOnObject
- `variable` (string) Variable name of the other form item
- `value` (string) Variable value of the other form item


@@ -1,304 +0,0 @@
## Custom Model Integration
### Introduction
After completing the provider integration, the next step is to integrate the provider's models. To help you understand the whole process, we will use `Xinference` as an example and complete a full provider integration step by step.
Note that for custom models, each model integration requires a complete set of provider credentials.
Unlike predefined models, a custom provider integration always carries the following two parameters, which do not need to be defined in the provider YAML:
![Alt text](images/index/image-3.png)
As mentioned earlier, such providers do not need to implement `validate_provider_credentials`; the Runtime will call the model layer's `validate_credentials` on its own, based on the model type and model name the user selects here.
### Writing the Provider YAML
First, determine which model types the provider to be integrated supports.
The currently supported model types are:
- `llm` Text generation model
- `text_embedding` Text Embedding model
- `rerank` Rerank model
- `speech2text` Speech-to-text
- `tts` Text-to-speech
- `moderation` Moderation
`Xinference` supports `LLM`, `Text Embedding`, and `Rerank`, so we start writing `xinference.yaml`.
```yaml
provider: xinference # Provider identifier
label: # Provider display name, can be set in en_US English and zh_Hans Chinese; zh_Hans falls back to en_US if not set.
  en_US: Xorbits Inference
icon_small: # Small icon; see other providers' icons for reference, stored in the _assets directory under the provider implementation directory; same language strategy as label
  en_US: icon_s_en.svg
icon_large: # Large icon
  en_US: icon_l_en.svg
help: # Help information
  title:
    en_US: How to deploy Xinference
    zh_Hans: 如何部署 Xinference
  url:
    en_US: https://github.com/xorbitsai/inference
supported_model_types: # Supported model types; Xinference supports LLM/Text Embedding/Rerank
  - llm
  - text-embedding
  - rerank
configurate_methods: # Xinference is a locally deployed provider with no predefined models; which models are available depends on how you deploy them per the Xinference documentation, so only custom models are supported here
  - customizable-model
provider_credential_schema:
  credential_form_schemas:
```
Next, we need to think about which credentials are required to define a model in Xinference.
- It supports three different model types, so we need a `model_type` field to specify the type of the model. Since there are three types, we write it like this:
```yaml
provider_credential_schema:
  credential_form_schemas:
    - variable: model_type
      type: select
      label:
        en_US: Model type
        zh_Hans: 模型类型
      required: true
      options:
        - value: text-generation
          label:
            en_US: Language Model
            zh_Hans: 语言模型
        - value: embeddings
          label:
            en_US: Text Embedding
        - value: reranking
          label:
            en_US: Rerank
```
- Each model has its own name, `model_name`, so it needs to be defined here:
```yaml
- variable: model_name
  type: text-input
  label:
    en_US: Model name
    zh_Hans: 模型名称
  required: true
  placeholder:
    zh_Hans: 填写模型名称
    en_US: Input model name
```
- The address of the locally deployed Xinference server must be provided:
```yaml
- variable: server_url
  label:
    zh_Hans: 服务器 URL
    en_US: Server url
  type: text-input
  required: true
  placeholder:
    zh_Hans: 在此输入 Xinference 的服务器地址,如 https://example.com/xxx
    en_US: Enter the url of your Xinference, for example https://example.com/xxx
```
- Each model has a unique model_uid, so it also needs to be defined here:
```yaml
- variable: model_uid
  label:
    zh_Hans: 模型 UID
    en_US: Model uid
  type: text-input
  required: true
  placeholder:
    zh_Hans: 在此输入您的 Model UID
    en_US: Enter the model uid
```
Now we have completed the basic definition of the provider.
### Writing the Model Code
Next, taking the `llm` type as an example, we write `xinference.llm.llm.py`.
In `llm.py`, create a Xinference LLM class named `XinferenceAILargeLanguageModel` (the name is arbitrary), inheriting from the `__base.large_language_model.LargeLanguageModel` base class, and implement the following methods:
- LLM invocation
  Implement the core method for LLM invocation, supporting both streaming and synchronous responses.
```python
def _invoke(self, model: str, credentials: dict,
            prompt_messages: list[PromptMessage], model_parameters: dict,
            tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
            stream: bool = True, user: Optional[str] = None) \
        -> Union[LLMResult, Generator]:
    """
    Invoke large language model

    :param model: model name
    :param credentials: model credentials
    :param prompt_messages: prompt messages
    :param model_parameters: model parameters
    :param tools: tools for tool calling
    :param stop: stop words
    :param stream: is stream response
    :param user: unique user id
    :return: full response or stream response chunk generator result
    """
```
When implementing, note that you must use two separate functions to return data, one for synchronous responses and one for streaming. Because Python treats any function containing the `yield` keyword as a generator function whose return type is fixed to `Generator`, the synchronous and streaming paths must be implemented separately, as shown below (note that the example uses simplified parameters; an actual implementation should follow the parameter list above):
```python
def _invoke(self, stream: bool, **kwargs) \
        -> Union[LLMResult, Generator]:
    if stream:
        return self._handle_stream_response(**kwargs)
    return self._handle_sync_response(**kwargs)

def _handle_stream_response(self, **kwargs) -> Generator:
    for chunk in response:
        yield chunk

def _handle_sync_response(self, **kwargs) -> LLMResult:
    return LLMResult(**response)
```
- Precompute input tokens
  If the model does not provide an interface for precomputing tokens, you can directly return 0.
```python
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                   tools: Optional[list[PromptMessageTool]] = None) -> int:
    """
    Get number of tokens for given prompt messages

    :param model: model name
    :param credentials: model credentials
    :param prompt_messages: prompt messages
    :param tools: tools for tool calling
    :return:
    """
```
Sometimes you may not want to simply return 0. In that case you can use `self._get_num_tokens_by_gpt2(text: str)` to obtain precomputed tokens, making sure the environment variable `PLUGIN_BASED_TOKEN_COUNTING_ENABLED` is set to `true`. This method, located in the `AIModel` base class, uses the GPT-2 tokenizer for the calculation, but it is only a fallback and is not perfectly accurate.
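A sketch of that fallback (assuming the inherited helper and the entity types shown above):
```python
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                   tools: Optional[list[PromptMessageTool]] = None) -> int:
    # Approximate with the GPT-2 tokenizer inherited from AIModel; requires
    # PLUGIN_BASED_TOKEN_COUNTING_ENABLED=true and yields only an estimate.
    text = " ".join(str(message.content) for message in prompt_messages if message.content)
    return self._get_num_tokens_by_gpt2(text)
```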
- Model credential validation
  Similar to provider credential validation, but validates a single model.
```python
def validate_credentials(self, model: str, credentials: dict) -> None:
    """
    Validate model credentials

    :param model: model name
    :param credentials: model credentials
    :return:
    """
```
- Model parameter schema
  Unlike the predefined case, since no YAML file defines which parameters a model supports, we need to generate the model parameter schema dynamically.
  For example, Xinference supports the `max_tokens`, `temperature`, and `top_p` model parameters.
  However, some providers support different parameters for different models. For example, the provider `OpenLLM` supports `top_k`, but not every model it serves supports `top_k`. Suppose model A supports `top_k` while model B does not; then we need to generate the model parameter schema dynamically, as shown below:
```python
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
    """
    used to define customizable model schema
    """
    rules = [
        ParameterRule(
            name='temperature', type=ParameterType.FLOAT,
            use_template='temperature',
            label=I18nObject(
                zh_Hans='温度', en_US='Temperature'
            )
        ),
        ParameterRule(
            name='top_p', type=ParameterType.FLOAT,
            use_template='top_p',
            label=I18nObject(
                zh_Hans='Top P', en_US='Top P'
            )
        ),
        ParameterRule(
            name='max_tokens', type=ParameterType.INT,
            use_template='max_tokens',
            min=1,
            default=512,
            label=I18nObject(
                zh_Hans='最大生成长度', en_US='Max Tokens'
            )
        )
    ]

    # if model is A, add top_k to rules
    if model == 'A':
        rules.append(
            ParameterRule(
                name='top_k', type=ParameterType.INT,
                use_template='top_k',
                min=1,
                default=50,
                label=I18nObject(
                    zh_Hans='Top K', en_US='Top K'
                )
            )
        )

    """
    some NOT IMPORTANT code here
    """

    entity = AIModelEntity(
        model=model,
        label=I18nObject(
            en_US=model
        ),
        fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
        model_type=model_type,
        model_properties={
            ModelPropertyKey.MODE: ModelType.LLM,
        },
        parameter_rules=rules
    )

    return entity
```
- Invocation error mapping table
  When a model invocation raises an exception, it must be mapped to the `InvokeError` type specified by the Runtime, which lets Dify handle different errors with appropriate follow-up actions.
  Runtime Errors:
  - `InvokeConnectionError` Connection error
  - `InvokeServerUnavailableError` Service provider unavailable
  - `InvokeRateLimitError` Rate limit reached
  - `InvokeAuthorizationError` Authorization failed
  - `InvokeBadRequestError` Invalid request parameters
```python
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
    """
    Map model invoke error to unified error
    The key is the error type thrown to the caller
    The value is the error type thrown by the model,
    which needs to be converted into a unified error type for the caller.

    :return: Invoke error mapping
    """
```
For interface method explanations, see: [Interfaces](./interfaces.md). For a concrete implementation, refer to: [llm.py](https://github.com/langgenius/dify-runtime/blob/main/lib/model_providers/anthropic/llm/llm.py).

Binary files not shown: ten image files removed (sizes: 230 KiB, 205 KiB, 385 KiB, 113 KiB, 109 KiB, 70 KiB, 75 KiB, 541 KiB, 44 KiB, 262 KiB).


@@ -1,744 +0,0 @@
# Interface Methods
This section describes the interface methods and parameters that providers and each model type need to implement.
## Provider
Inherit the `__base.model_provider.ModelProvider` base class and implement the following interface:
```python
def validate_provider_credentials(self, credentials: dict) -> None:
    """
    Validate provider credentials
    You can choose any validate_credentials method of model type or implement validate method by yourself,
    such as: get model list api

    if validate failed, raise exception

    :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
    """
```
- `credentials` (object) Credential information
  The credential parameters are defined by `provider_credential_schema` in the provider's YAML configuration file, e.g., `api_key`.
If validation fails, raise the `errors.validate.CredentialsValidateFailedError` error.
**Note: providers with predefined models must fully implement this interface; custom-model providers only need the following simple implementation**
```python
class XinferenceProvider(Provider):
    def validate_provider_credentials(self, credentials: dict) -> None:
        pass
```
## Models
Models are divided into six different model types; different model types inherit different base classes and require different methods to be implemented.
### Common Interfaces
All models need to implement the following two methods:
- Model credential validation
  Similar to provider credential validation, but validates a single model.
```python
def validate_credentials(self, model: str, credentials: dict) -> None:
    """
    Validate model credentials

    :param model: model name
    :param credentials: model credentials
    :return:
    """
```
  Parameters:
  - `model` (string) Model name
  - `credentials` (object) Credential information
    The credential parameters are defined by `provider_credential_schema` or `model_credential_schema` in the provider's YAML configuration file, e.g., `api_key`.
  If validation fails, raise the `errors.validate.CredentialsValidateFailedError` error.
- Invocation error mapping table
  When a model invocation raises an exception, it must be mapped to the `InvokeError` type specified by the Runtime, which lets Dify handle different errors with appropriate follow-up actions.
  Runtime Errors:
  - `InvokeConnectionError` Connection error
  - `InvokeServerUnavailableError` Service provider unavailable
  - `InvokeRateLimitError` Rate limit reached
  - `InvokeAuthorizationError` Authorization failed
  - `InvokeBadRequestError` Invalid request parameters
```python
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
    """
    Map model invoke error to unified error
    The key is the error type thrown to the caller
    The value is the error type thrown by the model,
    which needs to be converted into a unified error type for the caller.

    :return: Invoke error mapping
    """
```
You can also raise the corresponding errors directly and define the mapping as follows, so that later calls can raise exceptions such as `InvokeConnectionError` directly.
```python
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
    return {
        InvokeConnectionError: [
            InvokeConnectionError
        ],
        InvokeServerUnavailableError: [
            InvokeServerUnavailableError
        ],
        InvokeRateLimitError: [
            InvokeRateLimitError
        ],
        InvokeAuthorizationError: [
            InvokeAuthorizationError
        ],
        InvokeBadRequestError: [
            InvokeBadRequestError
        ],
    }
```
See OpenAI's `_invoke_error_mapping` for reference.
### LLM
Inherit the `__base.large_language_model.LargeLanguageModel` base class and implement the following interfaces:
- LLM invocation
  Implement the core method for LLM invocation, supporting both streaming and synchronous responses.
```python
def _invoke(self, model: str, credentials: dict,
            prompt_messages: list[PromptMessage], model_parameters: dict,
            tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
            stream: bool = True, user: Optional[str] = None) \
        -> Union[LLMResult, Generator]:
    """
    Invoke large language model

    :param model: model name
    :param credentials: model credentials
    :param prompt_messages: prompt messages
    :param model_parameters: model parameters
    :param tools: tools for tool calling
    :param stop: stop words
    :param stream: is stream response
    :param user: unique user id
    :return: full response or stream response chunk generator result
    """
```
- Parameters:
  - `model` (string) Model name
  - `credentials` (object) Credential information
    The credential parameters are defined by `provider_credential_schema` or `model_credential_schema` in the provider's YAML configuration file, e.g., `api_key`.
  - `prompt_messages` (array\[[PromptMessage](#PromptMessage)\]) Prompt list
    If the model is of the `Completion` type, the list only needs to contain one [UserPromptMessage](#UserPromptMessage) element;
    if the model is of the `Chat` type, pass a list of [SystemPromptMessage](#SystemPromptMessage), [UserPromptMessage](#UserPromptMessage), [AssistantPromptMessage](#AssistantPromptMessage), and [ToolPromptMessage](#ToolPromptMessage) elements depending on the message
  - `model_parameters` (object) Model parameters
    Model parameters are defined by `parameter_rules` in the model's YAML configuration.
  - `tools` (array\[[PromptMessageTool](#PromptMessageTool)\]) [optional] Tool list, equivalent to `function` in `function calling`,
    i.e., the tool list passed in for tool calling.
  - `stop` (array[string]) [optional] Stop sequences
    The model output will stop before the strings defined by the stop sequences.
  - `stream` (bool) Whether to stream the output, default True
    Streaming output returns Generator\[[LLMResultChunk](#LLMResultChunk)\]; non-streaming output returns [LLMResult](#LLMResult).
  - `user` (string) [optional] Unique identifier of the user,
    which can help the provider monitor and detect abuse.
- Returns
  Streaming output returns Generator\[[LLMResultChunk](#LLMResultChunk)\]; non-streaming output returns [LLMResult](#LLMResult).
- Precompute input tokens
  If the model does not provide an interface for precomputing tokens, you can directly return 0.
```python
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                   tools: Optional[list[PromptMessageTool]] = None) -> int:
    """
    Get number of tokens for given prompt messages

    :param model: model name
    :param credentials: model credentials
    :param prompt_messages: prompt messages
    :param tools: tools for tool calling
    :return:
    """
```
  See `LLM invocation` above for parameter descriptions.
  This interface should use a `tokenizer` appropriate to the given `model` for the calculation. If the model does not provide a `tokenizer`, the `_get_num_tokens_by_gpt2(text: str)` method of the `AIModel` base class can be used.
- Get custom model rules [optional]
```python
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
    """
    Get customizable model schema

    :param model: model name
    :param credentials: model credentials
    :return: model schema
    """
```
  When the provider supports adding custom LLMs, implement this method so that custom models can fetch their model rules; it returns None by default.
  For most fine-tuned models under the `OpenAI` provider, the base model can be derived from the fine-tuned model's name, such as `gpt-3.5-turbo-1106`, and the predefined parameter rules of the base model can then be returned; see the concrete [openai](https://github.com/langgenius/dify/blob/feat/model-runtime/api/core/model_runtime/model_providers/openai/llm/llm.py#L801) implementation.
### TextEmbedding
Inherit the `__base.text_embedding_model.TextEmbeddingModel` base class and implement the following interfaces:
- Embedding invocation
```python
def _invoke(self, model: str, credentials: dict,
            texts: list[str], user: Optional[str] = None) \
        -> TextEmbeddingResult:
    """
    Invoke large language model

    :param model: model name
    :param credentials: model credentials
    :param texts: texts to embed
    :param user: unique user id
    :return: embeddings result
    """
```
- Parameters:
  - `model` (string) Model name
  - `credentials` (object) Credential information
    The credential parameters are defined by `provider_credential_schema` or `model_credential_schema` in the provider's YAML configuration file, e.g., `api_key`.
  - `texts` (array[string]) List of texts, which can be processed in batch
  - `user` (string) [optional] Unique identifier of the user,
    which can help the provider monitor and detect abuse.
- Returns:
  A [TextEmbeddingResult](#TextEmbeddingResult) entity.
- Precompute tokens
```python
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
    """
    Get number of tokens for given prompt messages

    :param model: model name
    :param credentials: model credentials
    :param texts: texts to embed
    :return:
    """
```
  See `Embedding invocation` above for parameter descriptions.
  As with `LargeLanguageModel` above, this interface should use a `tokenizer` appropriate to the given `model`; if the model does not provide one, the `_get_num_tokens_by_gpt2(text: str)` method of the `AIModel` base class can be used.
### Rerank
Inherit the `__base.rerank_model.RerankModel` base class and implement the following interfaces:
- Rerank invocation
```python
def _invoke(self, model: str, credentials: dict,
            query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
            user: Optional[str] = None) \
        -> RerankResult:
    """
    Invoke rerank model

    :param model: model name
    :param credentials: model credentials
    :param query: search query
    :param docs: docs for reranking
    :param score_threshold: score threshold
    :param top_n: top n
    :param user: unique user id
    :return: rerank result
    """
```
- Parameters:
  - `model` (string) Model name
  - `credentials` (object) Credential information
    The credential parameters are defined by `provider_credential_schema` or `model_credential_schema` in the provider's YAML configuration file, e.g., `api_key`.
  - `query` (string) Query content
  - `docs` (array[string]) List of segments to rerank
  - `score_threshold` (float) [optional] Score threshold
  - `top_n` (int) [optional] Take the top n segments
  - `user` (string) [optional] Unique identifier of the user,
    which can help the provider monitor and detect abuse.
- Returns:
  A [RerankResult](#RerankResult) entity.
### Speech2text
Inherit the `__base.speech2text_model.Speech2TextModel` base class and implement the following interfaces:
- Invoke
```python
def _invoke(self, model: str, credentials: dict,
            file: IO[bytes], user: Optional[str] = None) \
        -> str:
    """
    Invoke large language model

    :param model: model name
    :param credentials: model credentials
    :param file: audio file
    :param user: unique user id
    :return: text for given audio file
    """
```
- Parameters:
  - `model` (string) Model name
  - `credentials` (object) Credential information
    The credential parameters are defined by `provider_credential_schema` or `model_credential_schema` in the provider's YAML configuration file, e.g., `api_key`.
  - `file` (File) File stream
  - `user` (string) [optional] Unique identifier of the user,
    which can help the provider monitor and detect abuse.
- Returns:
  The string transcribed from the speech.
### Text2speech
Inherit the `__base.text2speech_model.Text2SpeechModel` base class and implement the following interfaces:
- Invoke
```python
def _invoke(self, model: str, credentials: dict, content_text: str, streaming: bool, user: Optional[str] = None):
    """
    Invoke large language model

    :param model: model name
    :param credentials: model credentials
    :param content_text: text content to be translated
    :param streaming: output is streaming
    :param user: unique user id
    :return: translated audio file
    """
```
- Parameters:
  - `model` (string) Model name
  - `credentials` (object) Credential information
    The credential parameters are defined by `provider_credential_schema` or `model_credential_schema` in the provider's YAML configuration file, e.g., `api_key`.
  - `content_text` (string) Text content to convert
  - `streaming` (bool) Whether to stream the output
  - `user` (string) [optional] Unique identifier of the user,
    which can help the provider monitor and detect abuse.
- Returns:
  The audio stream converted from the text.
### Moderation
Inherit the `__base.moderation_model.ModerationModel` base class and implement the following interfaces:
- Invoke
```python
def _invoke(self, model: str, credentials: dict,
            text: str, user: Optional[str] = None) \
        -> bool:
    """
    Invoke large language model

    :param model: model name
    :param credentials: model credentials
    :param text: text to moderate
    :param user: unique user id
    :return: false if text is safe, true otherwise
    """
```
- Parameters:
  - `model` (string) Model name
  - `credentials` (object) Credential information
    The credential parameters are defined by `provider_credential_schema` or `model_credential_schema` in the provider's YAML configuration file, e.g., `api_key`.
  - `text` (string) Text content
  - `user` (string) [optional] Unique identifier of the user,
    which can help the provider monitor and detect abuse.
- Returns:
  False means the input text is safe; True means otherwise.
## Entities
### PromptMessageRole
Message role
```python
class PromptMessageRole(Enum):
    """
    Enum class for prompt message.
    """
    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"
    TOOL = "tool"
```
### PromptMessageContentType
Message content type: plain text or image.
```python
class PromptMessageContentType(Enum):
    """
    Enum class for prompt message content type.
    """
    TEXT = 'text'
    IMAGE = 'image'
```
### PromptMessageContent
Base class for message content, used only for parameter declaration; it cannot be initialized.
```python
class PromptMessageContent(BaseModel):
    """
    Model class for prompt message content.
    """
    type: PromptMessageContentType
    data: str  # content data
```
Currently, the text and image types are supported, and text can be passed in together with multiple images.
`TextPromptMessageContent` and `ImagePromptMessageContent` need to be initialized separately for this.
### TextPromptMessageContent
```python
class TextPromptMessageContent(PromptMessageContent):
    """
    Model class for text prompt message content.
    """
    type: PromptMessageContentType = PromptMessageContentType.TEXT
```
If text and images are passed in together, the text needs to be constructed as this entity as part of the `content` list.
### ImagePromptMessageContent
```python
class ImagePromptMessageContent(PromptMessageContent):
    """
    Model class for image prompt message content.
    """
    class DETAIL(Enum):
        LOW = 'low'
        HIGH = 'high'

    type: PromptMessageContentType = PromptMessageContentType.IMAGE
    detail: DETAIL = DETAIL.LOW  # resolution
```
If text and images are passed in together, the image needs to be constructed as this entity as part of the `content` list.
`data` can be either a `url` or a `base64`-encoded string of the image.
### PromptMessage
Base class for all Role message bodies, used only for parameter declaration; it cannot be initialized.
```python
class PromptMessage(BaseModel):
    """
    Model class for prompt message.
    """
    role: PromptMessageRole  # message role
    content: Optional[str | list[PromptMessageContent]] = None  # supports two types, string and content list; the content list serves multimodal needs, see the PromptMessageContent description
    name: Optional[str] = None  # name, optional
```
### UserPromptMessage
UserMessage body, representing a user message.
```python
class UserPromptMessage(PromptMessage):
    """
    Model class for user prompt message.
    """
    role: PromptMessageRole = PromptMessageRole.USER
```
### AssistantPromptMessage
Represents a message returned by the model, typically used for `few-shots` or passing in chat history.
```python
class AssistantPromptMessage(PromptMessage):
    """
    Model class for assistant prompt message.
    """
    class ToolCall(BaseModel):
        """
        Model class for assistant prompt message tool call.
        """
        class ToolCallFunction(BaseModel):
            """
            Model class for assistant prompt message tool call function.
            """
            name: str  # tool name
            arguments: str  # tool arguments

        id: str  # tool ID, only effective for OpenAI tool calls; the unique ID of a tool call, and the same tool can be called multiple times
        type: str  # defaults to function
        function: ToolCallFunction  # tool call information

    role: PromptMessageRole = PromptMessageRole.ASSISTANT
    tool_calls: list[ToolCall] = []  # tool call results returned by the model (only returned when tools are passed in and the model decides a tool call is needed)
```
`tool_calls` is the list of `tool call`s returned by the model after `tools` are passed in.
### SystemPromptMessage
Represents a system message, typically used to set system instructions for the model.
```python
class SystemPromptMessage(PromptMessage):
    """
    Model class for system prompt message.
    """
    role: PromptMessageRole = PromptMessageRole.SYSTEM
```
### ToolPromptMessage
Represents a tool message, used to hand a tool's execution result back to the model for the next planning step.
```python
class ToolPromptMessage(PromptMessage):
    """
    Model class for tool prompt message.
    """
    role: PromptMessageRole = PromptMessageRole.TOOL
    tool_call_id: str  # tool call ID; if OpenAI tool call is not supported, the tool name can also be passed in
```
The base class's `content` takes the tool execution result.
### PromptMessageTool
```python
class PromptMessageTool(BaseModel):
    """
    Model class for prompt message tool.
    """
    name: str  # tool name
    description: str  # tool description
    parameters: dict  # tool parameters dict
```
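As a quick illustration of how these entities fit together, here is a sketch that builds a chat-style prompt with one tool definition (the import path and the weather tool are purely illustrative):
```python
from entities.message_entities import (  # import path is an assumption
    PromptMessageTool,
    SystemPromptMessage,
    UserPromptMessage,
)

# A chat-style prompt: one system instruction plus one user turn.
prompt_messages = [
    SystemPromptMessage(content="You are a concise assistant."),
    UserPromptMessage(content="What is the weather in Berlin?"),
]

# A tool definition, equivalent to a `function` in function calling.
tools = [
    PromptMessageTool(
        name="get_weather",
        description="Look up the current weather for a city",
        parameters={
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    )
]
```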
______________________________________________________________________
### LLMResult
```python
class LLMResult(BaseModel):
    """
    Model class for llm result.
    """
    model: str  # the model actually used
    prompt_messages: list[PromptMessage]  # prompt message list
    message: AssistantPromptMessage  # reply message
    usage: LLMUsage  # tokens used and cost information
    system_fingerprint: Optional[str] = None  # request fingerprint, see OpenAI's definition of this parameter
```
### LLMResultChunkDelta
The `delta` entity inside each iteration of a streaming response
```python
class LLMResultChunkDelta(BaseModel):
    """
    Model class for llm result chunk delta.
    """
    index: int  # sequence number
    message: AssistantPromptMessage  # reply message
    usage: Optional[LLMUsage] = None  # tokens used and cost information, only returned in the final chunk
    finish_reason: Optional[str] = None  # finish reason, only returned in the final chunk
```
### LLMResultChunk
Each iteration entity in a streaming response
```python
class LLMResultChunk(BaseModel):
    """
    Model class for llm result chunk.
    """
    model: str  # the model actually used
    prompt_messages: list[PromptMessage]  # prompt message list
    system_fingerprint: Optional[str] = None  # request fingerprint, see OpenAI's definition of this parameter
    delta: LLMResultChunkDelta  # content that changes in each iteration
```
### LLMUsage
```python
class LLMUsage(ModelUsage):
    """
    Model class for llm usage.
    """
    prompt_tokens: int  # tokens used by the prompt
    prompt_unit_price: Decimal  # prompt unit price
    prompt_price_unit: Decimal  # prompt price unit, i.e., how many tokens the unit price is based on
    prompt_price: Decimal  # prompt cost
    completion_tokens: int  # tokens used by the completion
    completion_unit_price: Decimal  # completion unit price
    completion_price_unit: Decimal  # completion price unit, i.e., how many tokens the unit price is based on
    completion_price: Decimal  # completion cost
    total_tokens: int  # total number of tokens used
    total_price: Decimal  # total cost
    currency: str  # currency unit
    latency: float  # request latency (s)
```
______________________________________________________________________
### TextEmbeddingResult
```python
class TextEmbeddingResult(BaseModel):
    """
    Model class for text embedding result.
    """
    model: str  # the model actually used
    embeddings: list[list[float]]  # list of embedding vectors, corresponding to the input texts list
    usage: EmbeddingUsage  # usage information
```
### EmbeddingUsage
```python
class EmbeddingUsage(ModelUsage):
    """
    Model class for embedding usage.
    """
    tokens: int  # number of tokens used
    total_tokens: int  # total number of tokens used
    unit_price: Decimal  # unit price
    price_unit: Decimal  # price unit, i.e., how many tokens the unit price is based on
    total_price: Decimal  # total cost
    currency: str  # currency unit
    latency: float  # request latency (s)
```
______________________________________________________________________
### RerankResult
```python
class RerankResult(BaseModel):
    """
    Model class for rerank result.
    """
    model: str  # the model actually used
    docs: list[RerankDocument]  # list of reranked segments
```
### RerankDocument
```python
class RerankDocument(BaseModel):
    """
    Model class for rerank document.
    """
    index: int  # original sequence number
    text: str  # segment text content
    score: float  # score
```


@@ -1,172 +0,0 @@
## Predefined Model Integration
After the provider integration is complete, the next step is to integrate the provider's models.
We first need to determine the type of the model to be integrated and create the corresponding model type `module` under the respective provider's directory.
The currently supported model types are:
- `llm` Text generation model
- `text_embedding` Text Embedding model
- `rerank` Rerank model
- `speech2text` Speech-to-text
- `tts` Text-to-speech
- `moderation` Moderation
Continuing with `Anthropic` as an example, since `Anthropic` only supports LLM, we create a `module` named `llm` under `model_providers.anthropic`.
For predefined models, we first need to create a YAML file named after the model under the `llm` `module`, such as `claude-2.1.yaml`.
### Preparing the Model YAML
```yaml
model: claude-2.1 # Model identifier
# Model display name, can be set in en_US English and zh_Hans Chinese; zh_Hans falls back to en_US if not set.
# The label can also be omitted, in which case the model identifier is used.
label:
  en_US: claude-2.1
model_type: llm # Model type; claude-2.1 is an LLM
features: # Supported features; agent-thought for Agent reasoning, vision for image understanding
  - agent-thought
model_properties: # Model properties
  mode: chat # LLM mode, complete for text completion models, chat for dialogue models
  context_size: 200000 # Maximum supported context size
parameter_rules: # Model invocation parameter rules, only required for LLM
  - name: temperature # Invocation parameter variable name
    # Five variable content configuration templates are preset by default: temperature/top_p/max_tokens/presence_penalty/frequency_penalty
    # Set the template variable name directly in use_template to use the default configuration in entities.defaults.PARAMETER_RULE_TEMPLATE
    # Any additional configuration parameters override the default configuration
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label: # Invocation parameter display name
      zh_Hans: 取样数量
      en_US: Top k
    type: int # Parameter type, supports float/int/string/boolean
    help: # Help information, describing the role of the parameter
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false # Whether required; can be left unset
  - name: max_tokens_to_sample
    use_template: max_tokens
    default: 4096 # Default parameter value
    min: 1 # Minimum parameter value, only applicable for float/int
    max: 4096 # Maximum parameter value, only applicable for float/int
pricing: # Pricing information
  input: '8.00' # Input price, i.e., prompt price
  output: '24.00' # Output price, i.e., returned content price
  unit: '0.000001' # Pricing unit, i.e., the above prices are per 1M tokens
  currency: USD # Currency
```
It is recommended to prepare all model configurations before starting the implementation of the model code.
Likewise, you can refer to the YAML configuration information under the corresponding model type directories of other providers in the `model_providers` directory. For the complete YAML rules, see: [Schema](schema.md#aimodelentity).
### Implementing the Model Invocation Code
Next, create a Python file of the same name, `llm.py`, under the `llm` `module` to write the implementation.
In `llm.py`, create an Anthropic LLM class named `AnthropicLargeLanguageModel` (the name is arbitrary), inheriting from the `__base.large_language_model.LargeLanguageModel` base class, and implement the following methods:
- LLM invocation
  Implement the core method for LLM invocation, supporting both streaming and synchronous responses.
```python
def _invoke(self, model: str, credentials: dict,
            prompt_messages: list[PromptMessage], model_parameters: dict,
            tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
            stream: bool = True, user: Optional[str] = None) \
        -> Union[LLMResult, Generator]:
    """
    Invoke large language model

    :param model: model name
    :param credentials: model credentials
    :param prompt_messages: prompt messages
    :param model_parameters: model parameters
    :param tools: tools for tool calling
    :param stop: stop words
    :param stream: is stream response
    :param user: unique user id
    :return: full response or stream response chunk generator result
    """
```
When implementing, note that you must use two separate functions to return data, one for synchronous responses and one for streaming. Because Python treats any function containing the `yield` keyword as a generator function whose return type is fixed to `Generator`, the synchronous and streaming paths must be implemented separately, as shown below (note that the example uses simplified parameters; an actual implementation should follow the parameter list above):
```python
def _invoke(self, stream: bool, **kwargs) \
        -> Union[LLMResult, Generator]:
    if stream:
        return self._handle_stream_response(**kwargs)
    return self._handle_sync_response(**kwargs)

def _handle_stream_response(self, **kwargs) -> Generator:
    for chunk in response:
        yield chunk

def _handle_sync_response(self, **kwargs) -> LLMResult:
    return LLMResult(**response)
```
- Precompute input tokens
  If the model does not provide an interface for precomputing tokens, you can directly return 0.
```python
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                   tools: Optional[list[PromptMessageTool]] = None) -> int:
    """
    Get number of tokens for given prompt messages

    :param model: model name
    :param credentials: model credentials
    :param prompt_messages: prompt messages
    :param tools: tools for tool calling
    :return:
    """
```
- Model credential validation
  Similar to provider credential validation, but validates a single model.
```python
def validate_credentials(self, model: str, credentials: dict) -> None:
    """
    Validate model credentials

    :param model: model name
    :param credentials: model credentials
    :return:
    """
```
- Invocation error mapping table
  When a model invocation raises an exception, it must be mapped to the `InvokeError` type specified by the Runtime, which lets Dify handle different errors with appropriate follow-up actions.
  Runtime Errors:
  - `InvokeConnectionError` Connection error
  - `InvokeServerUnavailableError` Service provider unavailable
  - `InvokeRateLimitError` Rate limit reached
  - `InvokeAuthorizationError` Authorization failed
  - `InvokeBadRequestError` Invalid request parameters
```python
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
    """
    Map model invoke error to unified error
    The key is the error type thrown to the caller
    The value is the error type thrown by the model,
    which needs to be converted into a unified error type for the caller.

    :return: Invoke error mapping
    """
```
For interface method explanations, see: [Interfaces](./interfaces.md). For a concrete implementation, refer to: [llm.py](https://github.com/langgenius/dify-runtime/blob/main/lib/model_providers/anthropic/llm/llm.py).


@@ -1,192 +0,0 @@
## Adding a New Provider
Providers support three types of model configuration methods:
- `predefined-model` Predefined model
  This indicates that users only need to configure the unified provider credentials to use the predefined models under the provider.
- `customizable-model` Customizable model
  Users need to add a credential configuration for each model. For example, Xinference supports both LLM and Text Embedding, but each model has a unique **model_uid**; to integrate both, a **model_uid** must be configured for each model.
- `fetch-from-remote` Fetch from remote
  Consistent with the `predefined-model` configuration method: only unified provider credentials need to be configured, and models are fetched from the provider using the credential information.
  For example, with OpenAI, we can fine-tune multiple models on top of gpt-turbo-3.5, all under the same **api_key**. When configured as `fetch-from-remote`, a developer only needs to configure the unified **api_key** for DifyRuntime to fetch all of the developer's fine-tuned models and integrate them into Dify.
These three configuration methods **can coexist**, meaning a provider can support `predefined-model` + `customizable-model`, `predefined-model` + `fetch-from-remote`, etc. In other words, configuring the unified provider credentials makes the predefined and remotely fetched models available, and newly added custom models can be used on top of that.
## Getting Started
### Introduction
#### Terminology
- `module`: a `module` is a Python package, or more informally a folder, containing an `__init__.py` file along with other `.py` files.
#### Steps
Adding a new provider consists of a few main steps, listed briefly here to give you an overview; the details are covered below.
- Create the provider YAML file, written according to the [ProviderSchema](./schema.md#provider)
- Create the provider code and implement a `class`
- Depending on the model types, create the corresponding model type `module`s under the provider `module`, such as `llm` or `text_embedding`
- Depending on the model types, create code files of the same name under the corresponding model `module`s, such as `llm.py`, and implement a `class`
- If there are predefined models, create YAML files named after the models under the model `module`, such as `claude-2.1.yaml`, written according to the [AIModelEntity](./schema.md#aimodelentity)
- Write test code to make sure everything works.
### Let's Begin
Adding a new provider starts with determining the provider's English identifier, such as `anthropic`, and using this identifier to create a `module` named after it in `model_providers`.
Under this `module`, we first need to prepare the provider's YAML configuration.
#### Preparing the Provider YAML
Here, using `Anthropic` as an example, we preset the provider's basic information, supported model types, configuration methods, and credential rules.
```YAML
provider: anthropic # Provider identifier
label: # Provider display name, can be set in en_US English and zh_Hans Chinese; zh_Hans falls back to en_US if not set.
  en_US: Anthropic
icon_small: # Small provider icon, stored in the _assets directory under the corresponding provider implementation directory; same language strategy as label
  en_US: icon_s_en.png
icon_large: # Large provider icon, stored in the _assets directory under the corresponding provider implementation directory; same language strategy as label
  en_US: icon_l_en.png
supported_model_types: # Supported model types; Anthropic only supports LLM
  - llm
configurate_methods: # Supported configuration methods; Anthropic only supports predefined models
  - predefined-model
provider_credential_schema: # Provider credential rules; as Anthropic only supports predefined models, unified provider credential rules need to be defined
  credential_form_schemas: # List of credential form items
    - variable: anthropic_api_key # Credential parameter variable name
      label: # Display name
        en_US: API Key
      type: secret-input # Form type; secret-input is an encrypted input box that shows masked information when editing.
      required: true # Whether required
      placeholder: # Placeholder text
        zh_Hans: 在此输入您的 API Key
        en_US: Enter your API Key
    - variable: anthropic_api_url
      label:
        en_US: API URL
      type: text-input # Form type; text-input is a plain text input box
      required: false
      placeholder:
        zh_Hans: 在此输入您的 API URL
        en_US: Enter your API URL
```
If the provider to be integrated offers custom models, for example `OpenAI` offers fine-tuned models, then we need to add a [`model_credential_schema`](./schema.md#modelcredentialschema). Using `OpenAI` as an example:
```yaml
model_credential_schema:
  model: # Fine-tuned model name
    label:
      en_US: Model Name
      zh_Hans: 模型名称
    placeholder:
      en_US: Enter your model name
      zh_Hans: 输入模型名称
  credential_form_schemas:
    - variable: openai_api_key
      label:
        en_US: API Key
      type: secret-input
      required: true
      placeholder:
        zh_Hans: 在此输入您的 API Key
        en_US: Enter your API Key
    - variable: openai_organization
      label:
        zh_Hans: 组织 ID
        en_US: Organization
      type: text-input
      required: false
      placeholder:
        zh_Hans: 在此输入您的组织 ID
        en_US: Enter your Organization ID
    - variable: openai_api_base
      label:
        zh_Hans: API Base
        en_US: API Base
      type: text-input
      required: false
      placeholder:
        zh_Hans: 在此输入您的 API Base
        en_US: Enter your API Base
```
You can also refer to the YAML configuration information under other provider directories in `model_providers`. The complete YAML rules are available at: [Schema](schema.md#provider).
#### Implementing the Provider Code
We need to create a Python file of the same name, such as `anthropic.py`, under `model_providers` and implement a `class` that inherits the `__base.provider.Provider` base class, such as `AnthropicProvider`.
##### Custom model providers
For a custom-model provider such as Xinference, this step can be skipped. Simply create an empty `XinferenceProvider` class and implement an empty `validate_provider_credentials` method in it; the method is never actually used and exists only to keep the abstract class instantiable.
```python
class XinferenceProvider(Provider):
    def validate_provider_credentials(self, credentials: dict) -> None:
        pass
```
##### Predefined model providers
The provider needs to inherit the `__base.model_provider.ModelProvider` base class and implement the `validate_provider_credentials` method for unified provider credential validation; see [AnthropicProvider](https://github.com/langgenius/dify-runtime/blob/main/lib/model_providers/anthropic/anthropic.py) for reference.
```python
def validate_provider_credentials(self, credentials: dict) -> None:
    """
    Validate provider credentials
    You can choose any validate_credentials method of model type or implement validate method by yourself,
    such as: get model list api

    if validate failed, raise exception

    :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
    """
```
Of course, you can also stub out `validate_provider_credentials` first and reuse the model credential validation method directly once it is implemented.
#### Adding Models
#### [Adding a predefined model 👈🏻](./predefined_model_scale_out.md)
For predefined models, we can integrate them simply by defining a YAML file and implementing the invocation code.
#### [Adding a custom model 👈🏻](./customizable_model_scale_out.md)
For custom models, we only need to implement the invocation code to integrate them, but the parameters they handle may be more complex.
______________________________________________________________________
### Testing
To ensure the availability of integrated providers/models, every implemented method needs corresponding integration test code in the `tests` directory.
Again using `Anthropic` as an example.
Before writing test code, first add the credential environment variables required by the provider under test to `.env.example`, for example: `ANTHROPIC_API_KEY`.
Before running the tests, copy `.env.example` to `.env`.
#### Writing Test Code
Create a `module` with the same name as the provider in the `tests` directory: `anthropic`, and within this module create `test_provider.py` along with test `.py` files for the corresponding model types, as shown below:
```shell
.
├── __init__.py
├── anthropic
│   ├── __init__.py
│   ├── test_llm.py # LLM tests
│   └── test_provider.py # Provider tests
```
Write tests covering the various cases of the code implemented above, and submit the code once the tests pass.


@@ -1,209 +0,0 @@
# 配置规则
- 供应商规则基于 [Provider](#Provider) 实体。
- 模型规则基于 [AIModelEntity](#AIModelEntity) 实体。
> 以下所有实体均基于 `Pydantic BaseModel`,可在 `entities` 模块中找到对应实体。
### Provider
- `provider` (string) Provider identifier, e.g. `openai`
- `label` (object) Provider display name, i18n; `en_US` (English) and `zh_Hans` (Chinese) can be set
  - `zh_Hans` (string) [optional] Chinese label; falls back to `en_US` if not set
  - `en_US` (string) English label
- `description` (object) [optional] Provider description, i18n
  - `zh_Hans` (string) [optional] Chinese description
  - `en_US` (string) English description
- `icon_small` (string) [optional] Small provider icon, stored in the `_assets` directory under the provider's implementation directory; same language strategy as `label`
  - `zh_Hans` (string) [optional] Chinese icon
  - `en_US` (string) English icon
- `icon_large` (string) [optional] Large provider icon, stored in the `_assets` directory under the provider's implementation directory; same language strategy as `label`
  - `zh_Hans` (string) [optional] Chinese icon
  - `en_US` (string) English icon
- `background` (string) [optional] Background color value, e.g. #FFFFFF; the frontend default color is used if empty.
- `help` (object) [optional] Help information
  - `title` (object) Help title, i18n
    - `zh_Hans` (string) [optional] Chinese title
    - `en_US` (string) English title
  - `url` (object) Help link, i18n
    - `zh_Hans` (string) [optional] Chinese link
    - `en_US` (string) English link
- `supported_model_types` (array\[[ModelType](#ModelType)\]) Supported model types
- `configurate_methods` (array\[[ConfigurateMethod](#ConfigurateMethod)\]) Configuration methods
- `provider_credential_schema` ([ProviderCredentialSchema](#ProviderCredentialSchema)) Provider credential schema
- `model_credential_schema` ([ModelCredentialSchema](#ModelCredentialSchema)) Model credential schema
### AIModelEntity
- `model` (string) Model identifier, e.g. `gpt-3.5-turbo`
- `label` (object) [optional] Model display name, i18n; `en_US` (English) and `zh_Hans` (Chinese) can be set
  - `zh_Hans` (string) [optional] Chinese label
  - `en_US` (string) English label
- `model_type` ([ModelType](#ModelType)) Model type
- `features` (array\[[ModelFeature](#ModelFeature)\]) [optional] List of supported features
- `model_properties` (object) Model properties
  - `mode` ([LLMMode](#LLMMode)) Mode (available for model type `llm`)
  - `context_size` (int) Context size (available for model types `llm`, `text-embedding`)
  - `max_chunks` (int) Maximum number of chunks (available for model types `text-embedding`, `moderation`)
  - `file_upload_limit` (int) Maximum file upload size, in MB (available for model type `speech2text`)
  - `supported_file_extensions` (string) Supported file extensions, e.g. mp3,mp4 (available for model type `speech2text`)
  - `default_voice` (string) Default voice, required: alloy,echo,fable,onyx,nova,shimmer (available for model type `tts`)
  - `voices` (list) List of available voices
    - `mode` (string) Voice model (available for model type `tts`)
    - `name` (string) Voice model display name (available for model type `tts`)
    - `language` (string) Languages supported by the voice model (available for model type `tts`)
  - `word_limit` (int) Word limit per conversion; defaults to splitting by paragraph (available for model type `tts`)
  - `audio_type` (string) Supported audio file extensions, e.g. mp3,wav (available for model type `tts`)
  - `max_workers` (int) Number of concurrent text-to-audio conversion tasks (available for model type `tts`)
  - `max_characters_per_chunk` (int) Maximum characters per chunk (available for model type `moderation`)
- `parameter_rules` (array\[[ParameterRule](#ParameterRule)\]) [optional] Model invocation parameter rules
- `pricing` ([PriceConfig](#PriceConfig)) [optional] Pricing information
- `deprecated` (bool) Whether the model is deprecated. If so, it is no longer shown in the model list, but already-configured instances keep working. Defaults to False.
### ModelType
- `llm` Text generation model
- `text-embedding` Text embedding model
- `rerank` Rerank model
- `speech2text` Speech to text
- `tts` Text to speech
- `moderation` Moderation
### ConfigurateMethod
- `predefined-model` Predefined models
  Users only need to configure the unified provider credentials to use the provider's predefined models.
- `customizable-model` Customizable models
  Users need to add a credential configuration for each model.
- `fetch-from-remote` Fetch from remote
  Same configuration as `predefined-model`: only the unified provider credentials are needed, and the models are fetched from the provider using those credentials.
### ModelFeature
- `agent-thought` Agent reasoning; generally, models above 70B have chain-of-thought capability.
- `vision` Vision, i.e. image understanding.
- `tool-call` Tool calling
- `multi-tool-call` Multiple tool calling
- `stream-tool-call` Streaming tool calling
### FetchFrom
- `predefined-model` Predefined model
- `fetch-from-remote` Remote model
### LLMMode
- `completion` Text completion
- `chat` Chat
### ParameterRule
- `name` (string) Actual parameter name used when invoking the model
- `use_template` (string) [optional] Use a template
  Five variable configuration templates are preset by default:
  - `temperature`
  - `top_p`
  - `frequency_penalty`
  - `presence_penalty`
  - `max_tokens`
  Set the template variable name directly in `use_template` to use the default configuration from `entities.defaults.PARAMETER_RULE_TEMPLATE`;
  no parameters other than `name` and `use_template` then need to be set, and any additional configuration overrides the corresponding defaults.
  See `openai/llm/gpt-3.5-turbo.yaml`, or the sketch after this list.
- `label` (object) [optional] Label, i18n
  - `zh_Hans` (string) [optional] Chinese label
  - `en_US` (string) English label
- `type` (string) [optional] Parameter type
  - `int` Integer
  - `float` Float
  - `string` String
  - `boolean` Boolean
- `help` (string) [optional] Help information
  - `zh_Hans` (string) [optional] Chinese help text
  - `en_US` (string) English help text
- `required` (bool) Whether the parameter is required; defaults to False.
- `default` (int/float/string/bool) [optional] Default value
- `min` (int/float) [optional] Minimum value; numeric types only
- `max` (int/float) [optional] Maximum value; numeric types only
- `precision` (int) [optional] Precision (number of decimal places kept); numeric types only
- `options` (array[string]) [optional] Dropdown option values; only applies when `type` is `string`; if unset or null, option values are unrestricted
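Below is a minimal sketch of `parameter_rules` entries that rely on `use_template`; the values are illustrative and not copied from `openai/llm/gpt-3.5-turbo.yaml`:
```yaml
parameter_rules:
  - name: temperature
    use_template: temperature # everything else comes from PARAMETER_RULE_TEMPLATE
  - name: max_tokens
    use_template: max_tokens
    default: 512 # extra keys override the template defaults
    min: 1
    max: 4096
```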
### PriceConfig
- `input` (float) Input unit price, i.e. the price for prompt tokens
- `output` (float) Output unit price, i.e. the price for completion tokens
- `unit` (float) Price unit; e.g. if pricing is per 1M tokens, the unit token count corresponding to the unit price is `0.000001`
- `currency` (string) Currency
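As a worked example with invented numbers: if `input` is 8.0 per 1M tokens, then `unit` is `0.000001`, and the cost of a 1,500-token prompt is:
```python
input_price = 8.0  # hypothetical price per 1M prompt tokens
unit = 0.000001    # 1M-token pricing: cost per token = input_price * unit
prompt_tokens = 1500

prompt_cost = prompt_tokens * input_price * unit  # 0.012 in the configured currency
```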
### ProviderCredentialSchema
- `credential_form_schemas` (array\[[CredentialFormSchema](#CredentialFormSchema)\]) Credential form schemas
### ModelCredentialSchema
- `model` (object) Model identifier; the variable name defaults to `model`
  - `label` (object) Display name of the model form field
    - `en_US` (string) English
    - `zh_Hans` (string) [optional] Chinese
  - `placeholder` (object) Placeholder text for the model field
    - `en_US` (string) English
    - `zh_Hans` (string) [optional] Chinese
- `credential_form_schemas` (array\[[CredentialFormSchema](#CredentialFormSchema)\]) Credential form schemas
### CredentialFormSchema
- `variable` (string) Form field variable name
- `label` (object) Form field label
  - `en_US` (string) English
  - `zh_Hans` (string) [optional] Chinese
- `type` ([FormType](#FormType)) Form field type
- `required` (bool) Whether the field is required
- `default` (string) Default value
- `options` (array\[[FormOption](#FormOption)\]) Property specific to `select` or `radio` fields; defines the dropdown contents
- `placeholder` (object) Property specific to `text-input` fields; the field's placeholder
  - `en_US` (string) English
  - `zh_Hans` (string) [optional] Chinese
- `max_length` (int) Property specific to `text-input` fields; defines the maximum input length, 0 for unlimited.
- `show_on` (array\[[FormShowOnObject](#FormShowOnObject)\]) Shown when the values of other form fields match the conditions; always shown if empty.
### FormType
- `text-input` Text input component
- `secret-input` Password input component
- `select` Single-select dropdown
- `radio` Radio component
- `switch` Switch component; only supports `true` and `false`
### FormOption
- `label` (object) Label
  - `en_US` (string) English
  - `zh_Hans` (string) [optional] Chinese
- `value` (string) Dropdown option value
- `show_on` (array\[[FormShowOnObject](#FormShowOnObject)\]) Shown when the values of other form fields match the conditions; always shown if empty.
### FormShowOnObject
- `variable` (string) Variable name of the other form field
- `value` (string) Value of the other form field

View File

@ -222,6 +222,59 @@ class TencentSpanBuilder:
links=links,
)
@staticmethod
def build_message_llm_span(
trace_info: MessageTraceInfo, trace_id: int, parent_span_id: int, user_id: str
) -> SpanData:
"""Build LLM span for message traces with detailed LLM attributes."""
status = Status(StatusCode.OK)
if trace_info.error:
status = Status(StatusCode.ERROR, trace_info.error)
# Extract model information from `metadata` or `message_data`
trace_metadata = trace_info.metadata or {}
message_data = trace_info.message_data or {}
model_provider = trace_metadata.get("ls_provider") or (
message_data.get("model_provider", "") if isinstance(message_data, dict) else ""
)
model_name = trace_metadata.get("ls_model_name") or (
message_data.get("model_id", "") if isinstance(message_data, dict) else ""
)
inputs_str = str(trace_info.inputs or "")
outputs_str = str(trace_info.outputs or "")
attributes = {
GEN_AI_SESSION_ID: trace_metadata.get("conversation_id", ""),
GEN_AI_USER_ID: str(user_id),
GEN_AI_SPAN_KIND: GenAISpanKind.GENERATION.value,
GEN_AI_FRAMEWORK: "dify",
GEN_AI_MODEL_NAME: str(model_name),
GEN_AI_PROVIDER: str(model_provider),
GEN_AI_USAGE_INPUT_TOKENS: str(trace_info.message_tokens or 0),
GEN_AI_USAGE_OUTPUT_TOKENS: str(trace_info.answer_tokens or 0),
GEN_AI_USAGE_TOTAL_TOKENS: str(trace_info.total_tokens or 0),
GEN_AI_PROMPT: inputs_str,
GEN_AI_COMPLETION: outputs_str,
INPUT_VALUE: inputs_str,
OUTPUT_VALUE: outputs_str,
}
if trace_info.is_streaming_request:
attributes[GEN_AI_IS_STREAMING_REQUEST] = "true"
return SpanData(
trace_id=trace_id,
parent_span_id=parent_span_id,
span_id=TencentTraceUtils.convert_to_span_id(trace_info.message_id, "llm"),
name="GENERATION",
start_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.start_time),
end_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.end_time),
attributes=attributes,
status=status,
)
@staticmethod
def build_tool_span(trace_info: ToolTraceInfo, trace_id: int, parent_span_id: int) -> SpanData:
"""Build tool span."""

View File

@ -107,9 +107,13 @@ class TencentDataTrace(BaseTraceInstance):
links.append(TencentTraceUtils.create_link(trace_info.trace_id))
message_span = TencentSpanBuilder.build_message_span(trace_info, trace_id, str(user_id), links)
self.trace_client.add_span(message_span)
# Add LLM child span with detailed attributes
parent_span_id = TencentTraceUtils.convert_to_span_id(trace_info.message_id, "message")
llm_span = TencentSpanBuilder.build_message_llm_span(trace_info, trace_id, parent_span_id, str(user_id))
self.trace_client.add_span(llm_span)
self._record_message_llm_metrics(trace_info)
# Record trace duration for entry span

View File

@ -1,4 +1,6 @@
from pydantic import BaseModel
from collections.abc import Mapping
from pydantic import BaseModel, Field
from core.tools.entities.tool_entities import ToolParameter
@ -25,3 +27,5 @@ class ApiToolBundle(BaseModel):
icon: str | None = None
# openapi operation
openapi: dict
# output schema
output_schema: Mapping[str, object] = Field(default_factory=dict)

View File

@ -3,6 +3,7 @@ from typing import Any
from core.app.app_config.entities import VariableEntity
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
from core.workflow.nodes.base.entities import OutputVariableEntity
class WorkflowToolConfigurationUtils:
@ -24,6 +25,31 @@ class WorkflowToolConfigurationUtils:
return [VariableEntity.model_validate(variable) for variable in start_node.get("data", {}).get("variables", [])]
@classmethod
def get_workflow_graph_output(cls, graph: Mapping[str, Any]) -> Sequence[OutputVariableEntity]:
"""
get workflow graph output
"""
nodes = graph.get("nodes", [])
outputs_by_variable: dict[str, OutputVariableEntity] = {}
variable_order: list[str] = []
for node in nodes:
if node.get("data", {}).get("type") != "end":
continue
for output in node.get("data", {}).get("outputs", []):
entity = OutputVariableEntity.model_validate(output)
variable = entity.variable
if variable not in variable_order:
variable_order.append(variable)
# Later end nodes override duplicated variable definitions.
outputs_by_variable[variable] = entity
return [outputs_by_variable[variable] for variable in variable_order]
@classmethod
def check_is_synced(
cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration]

View File

@ -162,6 +162,20 @@ class WorkflowToolProviderController(ToolProviderController):
else:
raise ValueError("variable not found")
# get output schema from workflow
outputs = WorkflowToolConfigurationUtils.get_workflow_graph_output(graph)
reserved_keys = {"json", "text", "files"}
properties = {}
for output in outputs:
if output.variable not in reserved_keys:
properties[output.variable] = {
"type": output.value_type,
"description": "",
}
output_schema = {"type": "object", "properties": properties}
return WorkflowTool(
workflow_as_tool_id=db_provider.id,
entity=ToolEntity(
@ -177,6 +191,7 @@ class WorkflowToolProviderController(ToolProviderController):
llm=db_provider.description,
),
parameters=workflow_tool_parameters,
output_schema=output_schema,
),
runtime=ToolRuntime(
tenant_id=db_provider.tenant_id,

View File

@ -114,6 +114,11 @@ class WorkflowTool(Tool):
for file in files:
yield self.create_file_message(file) # type: ignore
# traverse `outputs` field and create variable messages
for key, value in outputs.items():
if key not in {"text", "json", "files"}:
yield self.create_variable_message(variable_name=key, variable_value=value)
self._latest_usage = self._derive_usage_from_result(data)
yield self.create_text_message(json.dumps(outputs, ensure_ascii=False))

View File

@ -1,17 +1,11 @@
from ..runtime.graph_runtime_state import GraphRuntimeState
from ..runtime.variable_pool import VariablePool
from .agent import AgentNodeStrategyInit
from .graph_init_params import GraphInitParams
from .workflow_execution import WorkflowExecution
from .workflow_node_execution import WorkflowNodeExecution
from .workflow_pause import WorkflowPauseEntity
__all__ = [
"AgentNodeStrategyInit",
"GraphInitParams",
"GraphRuntimeState",
"VariablePool",
"WorkflowExecution",
"WorkflowNodeExecution",
"WorkflowPauseEntity",
]

View File

@ -1,49 +1,26 @@
from enum import StrEnum, auto
from typing import Annotated, Any, ClassVar, TypeAlias
from typing import Annotated, Literal, TypeAlias
from pydantic import BaseModel, Discriminator, Tag
from pydantic import BaseModel, Field
class _PauseReasonType(StrEnum):
class PauseReasonType(StrEnum):
HUMAN_INPUT_REQUIRED = auto()
SCHEDULED_PAUSE = auto()
class _PauseReasonBase(BaseModel):
TYPE: ClassVar[_PauseReasonType]
class HumanInputRequired(BaseModel):
TYPE: Literal[PauseReasonType.HUMAN_INPUT_REQUIRED] = PauseReasonType.HUMAN_INPUT_REQUIRED
form_id: str
# The identifier of the human input node causing the pause.
node_id: str
class HumanInputRequired(_PauseReasonBase):
TYPE = _PauseReasonType.HUMAN_INPUT_REQUIRED
class SchedulingPause(_PauseReasonBase):
TYPE = _PauseReasonType.SCHEDULED_PAUSE
class SchedulingPause(BaseModel):
TYPE: Literal[PauseReasonType.SCHEDULED_PAUSE] = PauseReasonType.SCHEDULED_PAUSE
message: str
def _get_pause_reason_discriminator(v: Any) -> _PauseReasonType | None:
if isinstance(v, _PauseReasonBase):
return v.TYPE
elif isinstance(v, dict):
reason_type_str = v.get("TYPE")
if reason_type_str is None:
return None
try:
reason_type = _PauseReasonType(reason_type_str)
except ValueError:
return None
return reason_type
else:
# return None if the discriminator value isn't found
return None
PauseReason: TypeAlias = Annotated[
(
Annotated[HumanInputRequired, Tag(_PauseReasonType.HUMAN_INPUT_REQUIRED)]
| Annotated[SchedulingPause, Tag(_PauseReasonType.SCHEDULED_PAUSE)]
),
Discriminator(_get_pause_reason_discriminator),
]
PauseReason: TypeAlias = Annotated[HumanInputRequired | SchedulingPause, Field(discriminator="TYPE")]
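# Illustrative sketch, not part of the diff: with the Pydantic-native discriminated
# union above, deserialization needs no custom discriminator callable. TypeAdapter
# is standard Pydantic v2 API; the payload values below are invented.
from pydantic import TypeAdapter

_adapter = TypeAdapter(PauseReason)
_reason = _adapter.validate_python(
    {"TYPE": "human_input_required", "form_id": "form-1", "node_id": "node-1"}
)
assert isinstance(_reason, HumanInputRequired)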

View File

@ -42,7 +42,7 @@ class GraphExecutionState(BaseModel):
completed: bool = Field(default=False)
aborted: bool = Field(default=False)
paused: bool = Field(default=False)
pause_reason: PauseReason | None = Field(default=None)
pause_reasons: list[PauseReason] = Field(default_factory=list)
error: GraphExecutionErrorState | None = Field(default=None)
exceptions_count: int = Field(default=0)
node_executions: list[NodeExecutionState] = Field(default_factory=list[NodeExecutionState])
@ -107,7 +107,7 @@ class GraphExecution:
completed: bool = False
aborted: bool = False
paused: bool = False
pause_reason: PauseReason | None = None
pause_reasons: list[PauseReason] = field(default_factory=list)
error: Exception | None = None
node_executions: dict[str, NodeExecution] = field(default_factory=dict[str, NodeExecution])
exceptions_count: int = 0
@ -137,10 +137,8 @@ class GraphExecution:
raise RuntimeError("Cannot pause execution that has completed")
if self.aborted:
raise RuntimeError("Cannot pause execution that has been aborted")
if self.paused:
return
self.paused = True
self.pause_reason = reason
self.pause_reasons.append(reason)
def fail(self, error: Exception) -> None:
"""Mark the graph execution as failed."""
@ -195,7 +193,7 @@ class GraphExecution:
completed=self.completed,
aborted=self.aborted,
paused=self.paused,
pause_reason=self.pause_reason,
pause_reasons=self.pause_reasons,
error=_serialize_error(self.error),
exceptions_count=self.exceptions_count,
node_executions=node_states,
@ -221,7 +219,7 @@ class GraphExecution:
self.completed = state.completed
self.aborted = state.aborted
self.paused = state.paused
self.pause_reason = state.pause_reason
self.pause_reasons = state.pause_reasons
self.error = _deserialize_error(state.error)
self.exceptions_count = state.exceptions_count
self.node_executions = {

View File

@ -110,7 +110,13 @@ class EventManager:
"""
with self._lock.write_lock():
self._events.append(event)
self._notify_layers(event)
# NOTE: `_notify_layers` is intentionally called outside the critical section
# to minimize lock contention and avoid blocking other readers or writers.
#
# The public `notify_layers` method also does not use a write lock,
# so protecting `_notify_layers` with a lock here is unnecessary.
self._notify_layers(event)
def _get_new_events(self, start_index: int) -> list[GraphEngineEvent]:
"""

View File

@ -232,7 +232,7 @@ class GraphEngine:
self._graph_execution.start()
else:
self._graph_execution.paused = False
self._graph_execution.pause_reason = None
self._graph_execution.pause_reasons = []
start_event = GraphRunStartedEvent()
self._event_manager.notify_layers(start_event)
@ -246,11 +246,11 @@ class GraphEngine:
# Handle completion
if self._graph_execution.is_paused:
pause_reason = self._graph_execution.pause_reason
assert pause_reason is not None, "pause_reason should not be None when execution is paused."
pause_reasons = self._graph_execution.pause_reasons
assert pause_reasons, "pause_reasons should not be empty when execution is paused."
# Ensure we have a valid PauseReason for the event
paused_event = GraphRunPausedEvent(
reason=pause_reason,
reasons=pause_reasons,
outputs=self._graph_runtime_state.outputs,
)
self._event_manager.notify_layers(paused_event)

View File

@ -45,8 +45,7 @@ class GraphRunAbortedEvent(BaseGraphEvent):
class GraphRunPausedEvent(BaseGraphEvent):
"""Event emitted when a graph run is paused by user command."""
# reason: str | None = Field(default=None, description="reason for pause")
reason: PauseReason = Field(..., description="reason for pause")
reasons: list[PauseReason] = Field(description="reason for pause", default_factory=list)
outputs: dict[str, object] = Field(
default_factory=dict,
description="Outputs available to the client while the run is paused.",

View File

@ -26,7 +26,6 @@ from core.tools.tool_manager import ToolManager
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.variables.segments import ArrayFileSegment, StringSegment
from core.workflow.enums import (
ErrorStrategy,
NodeType,
SystemVariableKey,
WorkflowNodeExecutionMetadataKey,
@ -40,7 +39,6 @@ from core.workflow.node_events import (
StreamCompletedEvent,
)
from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.runtime import VariablePool
@ -66,34 +64,12 @@ if TYPE_CHECKING:
from core.plugin.entities.request import InvokeCredentials
class AgentNode(Node):
class AgentNode(Node[AgentNodeData]):
"""
Agent Node
"""
node_type = NodeType.AGENT
_node_data: AgentNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = AgentNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def version(cls) -> str:
@ -105,8 +81,8 @@ class AgentNode(Node):
try:
strategy = get_plugin_agent_strategy(
tenant_id=self.tenant_id,
agent_strategy_provider_name=self._node_data.agent_strategy_provider_name,
agent_strategy_name=self._node_data.agent_strategy_name,
agent_strategy_provider_name=self.node_data.agent_strategy_provider_name,
agent_strategy_name=self.node_data.agent_strategy_name,
)
except Exception as e:
yield StreamCompletedEvent(
@ -124,13 +100,13 @@ class AgentNode(Node):
parameters = self._generate_agent_parameters(
agent_parameters=agent_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=self._node_data,
node_data=self.node_data,
strategy=strategy,
)
parameters_for_log = self._generate_agent_parameters(
agent_parameters=agent_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=self._node_data,
node_data=self.node_data,
for_log=True,
strategy=strategy,
)
@ -163,7 +139,7 @@ class AgentNode(Node):
messages=message_stream,
tool_info={
"icon": self.agent_strategy_icon,
"agent_strategy": self._node_data.agent_strategy_name,
"agent_strategy": self.node_data.agent_strategy_name,
},
parameters_for_log=parameters_for_log,
user_id=self.user_id,
@ -410,7 +386,7 @@ class AgentNode(Node):
current_plugin = next(
plugin
for plugin in plugins
if f"{plugin.plugin_id}/{plugin.name}" == self._node_data.agent_strategy_provider_name
if f"{plugin.plugin_id}/{plugin.name}" == self.node_data.agent_strategy_provider_name
)
icon = current_plugin.declaration.icon
except StopIteration:

View File

@ -2,48 +2,24 @@ from collections.abc import Mapping, Sequence
from typing import Any
from core.variables import ArrayFileSegment, FileSegment, Segment
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.answer.entities import AnswerNodeData
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.template import Template
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
class AnswerNode(Node):
class AnswerNode(Node[AnswerNodeData]):
node_type = NodeType.ANSWER
execution_type = NodeExecutionType.RESPONSE
_node_data: AnswerNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = AnswerNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult:
segments = self.graph_runtime_state.variable_pool.convert_template(self._node_data.answer)
segments = self.graph_runtime_state.variable_pool.convert_template(self.node_data.answer)
files = self._extract_files_from_segments(segments.value)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
@ -93,4 +69,4 @@ class AnswerNode(Node):
Returns:
Template instance for this Answer node
"""
return Template.from_answer_template(self._node_data.answer)
return Template.from_answer_template(self.node_data.answer)

View File

@ -5,7 +5,7 @@ from collections.abc import Sequence
from enum import StrEnum
from typing import Any, Union
from pydantic import BaseModel, model_validator
from pydantic import BaseModel, field_validator, model_validator
from core.workflow.enums import ErrorStrategy
@ -35,6 +35,45 @@ class VariableSelector(BaseModel):
value_selector: Sequence[str]
class OutputVariableType(StrEnum):
STRING = "string"
NUMBER = "number"
INTEGER = "integer"
SECRET = "secret"
BOOLEAN = "boolean"
OBJECT = "object"
FILE = "file"
ARRAY = "array"
ARRAY_STRING = "array[string]"
ARRAY_NUMBER = "array[number]"
ARRAY_OBJECT = "array[object]"
ARRAY_BOOLEAN = "array[boolean]"
ARRAY_FILE = "array[file]"
ANY = "any"
ARRAY_ANY = "array[any]"
class OutputVariableEntity(BaseModel):
"""
Output Variable Entity.
"""
variable: str
value_type: OutputVariableType
value_selector: Sequence[str]
@field_validator("value_type", mode="before")
@classmethod
def normalize_value_type(cls, v: Any) -> Any:
"""
Normalize value_type to handle case-insensitive array types.
Converts 'Array[...]' to 'array[...]' for backward compatibility.
"""
if isinstance(v, str) and v.startswith("Array["):
return v.lower()
return v
class DefaultValueType(StrEnum):
STRING = "string"
NUMBER = "number"

View File

@ -2,7 +2,7 @@ import logging
from abc import abstractmethod
from collections.abc import Generator, Mapping, Sequence
from functools import singledispatchmethod
from typing import Any, ClassVar
from typing import Any, ClassVar, Generic, TypeVar, cast, get_args, get_origin
from uuid import uuid4
from core.app.entities.app_invoke_entities import InvokeFrom
@ -49,12 +49,121 @@ from models.enums import UserFrom
from .entities import BaseNodeData, RetryConfig
NodeDataT = TypeVar("NodeDataT", bound=BaseNodeData)
logger = logging.getLogger(__name__)
class Node:
class Node(Generic[NodeDataT]):
node_type: ClassVar["NodeType"]
execution_type: NodeExecutionType = NodeExecutionType.EXECUTABLE
_node_data_type: ClassVar[type[BaseNodeData]] = BaseNodeData
def __init_subclass__(cls, **kwargs: Any) -> None:
"""
Automatically extract and validate the node data type from the generic parameter.
When a subclass is defined as `class MyNode(Node[MyNodeData])`, this method:
1. Inspects `__orig_bases__` to find the `Node[T]` parameterization
2. Extracts `T` (e.g., `MyNodeData`) from the generic argument
3. Validates that `T` is a proper `BaseNodeData` subclass
4. Stores it in `_node_data_type` for automatic hydration in `__init__`
This eliminates the need for subclasses to manually implement boilerplate
accessor methods like `_get_title()`, `_get_error_strategy()`, etc.
How it works:
::
class CodeNode(Node[CodeNodeData]):
│ │
│ └─────────────────────────────────┐
│ │
▼ ▼
┌─────────────────────────────┐ ┌─────────────────────────────────┐
│ __orig_bases__ = ( │ │ CodeNodeData(BaseNodeData) │
│ Node[CodeNodeData], │ │ title: str │
│ ) │ │ desc: str | None │
└──────────────┬──────────────┘ │ ... │
│ └─────────────────────────────────┘
▼ ▲
┌─────────────────────────────┐ │
│ get_origin(base) -> Node │ │
│ get_args(base) -> ( │ │
│ CodeNodeData, │ ──────────────────────┘
│ ) │
└──────────────┬──────────────┘
┌─────────────────────────────┐
│ Validate: │
│ - Is it a type? │
│ - Is it a BaseNodeData │
│ subclass? │
└──────────────┬──────────────┘
┌─────────────────────────────┐
│ cls._node_data_type = │
│ CodeNodeData │
└─────────────────────────────┘
Later, in __init__:
::
config["data"] ──► _hydrate_node_data() ──► _node_data_type.model_validate()
CodeNodeData instance
(stored in self._node_data)
Example:
class CodeNode(Node[CodeNodeData]): # CodeNodeData is auto-extracted
node_type = NodeType.CODE
# No need to implement _get_title, _get_error_strategy, etc.
"""
super().__init_subclass__(**kwargs)
if cls is Node:
return
node_data_type = cls._extract_node_data_type_from_generic()
if node_data_type is None:
raise TypeError(f"{cls.__name__} must inherit from Node[T] with a BaseNodeData subtype")
cls._node_data_type = node_data_type
@classmethod
def _extract_node_data_type_from_generic(cls) -> type[BaseNodeData] | None:
"""
Extract the node data type from the generic parameter `Node[T]`.
Inspects `__orig_bases__` to find the `Node[T]` parameterization and extracts `T`.
Returns:
The extracted BaseNodeData subtype, or None if not found.
Raises:
TypeError: If the generic argument is invalid (not exactly one argument,
or not a BaseNodeData subtype).
"""
# __orig_bases__ contains the original generic bases before type erasure.
# For `class CodeNode(Node[CodeNodeData])`, this would be `(Node[CodeNodeData],)`.
for base in getattr(cls, "__orig_bases__", ()): # type: ignore[attr-defined]
origin = get_origin(base) # Returns `Node` for `Node[CodeNodeData]`
if origin is Node:
args = get_args(base) # Returns `(CodeNodeData,)` for `Node[CodeNodeData]`
if len(args) != 1:
raise TypeError(f"{cls.__name__} must specify exactly one node data generic argument")
candidate = args[0]
if not isinstance(candidate, type) or not issubclass(candidate, BaseNodeData):
raise TypeError(f"{cls.__name__} must parameterize Node with a BaseNodeData subtype")
return candidate
return None
def __init__(
self,
@ -63,6 +172,7 @@ class Node:
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
) -> None:
self._graph_init_params = graph_init_params
self.id = id
self.tenant_id = graph_init_params.tenant_id
self.app_id = graph_init_params.app_id
@ -83,8 +193,24 @@ class Node:
self._node_execution_id: str = ""
self._start_at = naive_utc_now()
@abstractmethod
def init_node_data(self, data: Mapping[str, Any]) -> None: ...
raw_node_data = config.get("data") or {}
if not isinstance(raw_node_data, Mapping):
raise ValueError("Node config data must be a mapping.")
self._node_data: NodeDataT = self._hydrate_node_data(raw_node_data)
self.post_init()
def post_init(self) -> None:
"""Optional hook for subclasses requiring extra initialization."""
return
@property
def graph_init_params(self) -> "GraphInitParams":
return self._graph_init_params
def _hydrate_node_data(self, data: Mapping[str, Any]) -> NodeDataT:
return cast(NodeDataT, self._node_data_type.model_validate(data))
@abstractmethod
def _run(self) -> NodeRunResult | Generator[NodeEventBase, None, None]:
@ -273,38 +399,29 @@ class Node:
def retry(self) -> bool:
return False
# Abstract methods that subclasses must implement to provide access
# to BaseNodeData properties in a type-safe way
@abstractmethod
def _get_error_strategy(self) -> ErrorStrategy | None:
"""Get the error strategy for this node."""
...
return self._node_data.error_strategy
@abstractmethod
def _get_retry_config(self) -> RetryConfig:
"""Get the retry configuration for this node."""
...
return self._node_data.retry_config
@abstractmethod
def _get_title(self) -> str:
"""Get the node title."""
...
return self._node_data.title
@abstractmethod
def _get_description(self) -> str | None:
"""Get the node description."""
...
return self._node_data.desc
@abstractmethod
def _get_default_value_dict(self) -> dict[str, Any]:
"""Get the default values dictionary for this node."""
...
return self._node_data.default_value_dict
@abstractmethod
def get_base_node_data(self) -> BaseNodeData:
"""Get the BaseNodeData object for this node."""
...
return self._node_data
# Public interface properties that delegate to abstract methods
@property
@ -332,6 +449,11 @@ class Node:
"""Get the default values dictionary for this node."""
return self._get_default_value_dict()
@property
def node_data(self) -> NodeDataT:
"""Typed access to this node's configuration data."""
return self._node_data
def _convert_node_run_result_to_graph_node_event(self, result: NodeRunResult) -> GraphNodeEventBase:
match result.status:
case WorkflowNodeExecutionStatus.FAILED:

View File

@ -9,9 +9,8 @@ from core.helper.code_executor.javascript.javascript_code_provider import Javasc
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
from core.variables.segments import ArrayFileSegment
from core.variables.types import SegmentType
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.code.entities import CodeNodeData
@ -22,32 +21,9 @@ from .exc import (
)
class CodeNode(Node):
class CodeNode(Node[CodeNodeData]):
node_type = NodeType.CODE
_node_data: CodeNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = CodeNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
"""
@ -70,12 +46,12 @@ class CodeNode(Node):
def _run(self) -> NodeRunResult:
# Get code language
code_language = self._node_data.code_language
code = self._node_data.code
code_language = self.node_data.code_language
code = self.node_data.code
# Get variables
variables = {}
for variable_selector in self._node_data.variables:
for variable_selector in self.node_data.variables:
variable_name = variable_selector.variable
variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
if isinstance(variable, ArrayFileSegment):
@ -91,7 +67,7 @@ class CodeNode(Node):
)
# Transform result
result = self._transform_result(result=result, output_schema=self._node_data.outputs)
result = self._transform_result(result=result, output_schema=self.node_data.outputs)
except (CodeExecutionError, CodeNodeError) as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__
@ -428,7 +404,7 @@ class CodeNode(Node):
@property
def retry(self) -> bool:
return self._node_data.retry_config.retry_enabled
return self.node_data.retry_config.retry_enabled
@staticmethod
def _convert_boolean_to_int(value: bool | int | float | None) -> int | float | None:

View File

@ -20,9 +20,8 @@ from core.plugin.impl.exc import PluginDaemonClientSideError
from core.variables.segments import ArrayAnySegment
from core.variables.variables import ArrayAnyVariable
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, SystemVariableKey
from core.workflow.enums import NodeExecutionType, NodeType, SystemVariableKey
from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.nodes.tool.exc import ToolFileError
@ -38,42 +37,20 @@ from .entities import DatasourceNodeData
from .exc import DatasourceNodeError, DatasourceParameterError
class DatasourceNode(Node):
class DatasourceNode(Node[DatasourceNodeData]):
"""
Datasource Node
"""
_node_data: DatasourceNodeData
node_type = NodeType.DATASOURCE
execution_type = NodeExecutionType.ROOT
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = DatasourceNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
def _run(self) -> Generator:
"""
Run the datasource node
"""
node_data = self._node_data
node_data = self.node_data
variable_pool = self.graph_runtime_state.variable_pool
datasource_type_segement = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE])
if not datasource_type_segement:

View File

@ -25,9 +25,8 @@ from core.file import File, FileTransferMethod, file_manager
from core.helper import ssrf_proxy
from core.variables import ArrayFileSegment
from core.variables.segments import ArrayStringSegment, FileSegment
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from .entities import DocumentExtractorNodeData
@ -36,7 +35,7 @@ from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError,
logger = logging.getLogger(__name__)
class DocumentExtractorNode(Node):
class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
"""
Extracts text content from various file types.
Supports plain text, PDF, and DOC/DOCX files.
@ -44,35 +43,12 @@ class DocumentExtractorNode(Node):
node_type = NodeType.DOCUMENT_EXTRACTOR
_node_data: DocumentExtractorNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = DocumentExtractorNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def version(cls) -> str:
return "1"
def _run(self):
variable_selector = self._node_data.variable_selector
variable_selector = self.node_data.variable_selector
variable = self.graph_runtime_state.variable_pool.get(variable_selector)
if variable is None:

View File

@ -1,41 +1,14 @@
from collections.abc import Mapping
from typing import Any
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.template import Template
from core.workflow.nodes.end.entities import EndNodeData
class EndNode(Node):
class EndNode(Node[EndNodeData]):
node_type = NodeType.END
execution_type = NodeExecutionType.RESPONSE
_node_data: EndNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = EndNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def version(cls) -> str:
return "1"
@ -47,7 +20,7 @@ class EndNode(Node):
This method runs after streaming is complete (if streaming was enabled).
It collects all output variables and returns them.
"""
output_variables = self._node_data.outputs
output_variables = self.node_data.outputs
outputs = {}
for variable_selector in output_variables:
@ -69,6 +42,6 @@ class EndNode(Node):
Template instance for this End node
"""
outputs_config = [
{"variable": output.variable, "value_selector": output.value_selector} for output in self._node_data.outputs
{"variable": output.variable, "value_selector": output.value_selector} for output in self.node_data.outputs
]
return Template.from_end_outputs(outputs_config)

View File

@ -1,7 +1,6 @@
from pydantic import BaseModel, Field
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.base.entities import VariableSelector
from core.workflow.nodes.base.entities import BaseNodeData, OutputVariableEntity
class EndNodeData(BaseNodeData):
@ -9,7 +8,7 @@ class EndNodeData(BaseNodeData):
END Node Data.
"""
outputs: list[VariableSelector]
outputs: list[OutputVariableEntity]
class EndStreamParam(BaseModel):

View File

@ -7,10 +7,10 @@ from configs import dify_config
from core.file import File, FileTransferMethod
from core.tools.tool_file_manager import ToolFileManager
from core.variables.segments import ArrayFileSegment
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base import variable_template_parser
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig, VariableSelector
from core.workflow.nodes.base.entities import VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.http_request.executor import Executor
from factories import file_factory
@ -31,32 +31,9 @@ HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout(
logger = logging.getLogger(__name__)
class HttpRequestNode(Node):
class HttpRequestNode(Node[HttpRequestNodeData]):
node_type = NodeType.HTTP_REQUEST
_node_data: HttpRequestNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = HttpRequestNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
return {
@ -90,8 +67,8 @@ class HttpRequestNode(Node):
process_data = {}
try:
http_executor = Executor(
node_data=self._node_data,
timeout=self._get_request_timeout(self._node_data),
node_data=self.node_data,
timeout=self._get_request_timeout(self.node_data),
variable_pool=self.graph_runtime_state.variable_pool,
max_retries=0,
)
@ -246,4 +223,4 @@ class HttpRequestNode(Node):
@property
def retry(self) -> bool:
return self._node_data.retry_config.retry_enabled
return self.node_data.retry_config.retry_enabled

View File

@ -2,15 +2,14 @@ from collections.abc import Mapping
from typing import Any
from core.workflow.entities.pause_reason import HumanInputRequired
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult, PauseRequestedEvent
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from .entities import HumanInputNodeData
class HumanInputNode(Node):
class HumanInputNode(Node[HumanInputNodeData]):
node_type = NodeType.HUMAN_INPUT
execution_type = NodeExecutionType.BRANCH
@ -26,33 +25,10 @@ class HumanInputNode(Node):
"handle",
)
_node_data: HumanInputNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = HumanInputNodeData(**data)
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def version(cls) -> str:
return "1"
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def _run(self): # type: ignore[override]
if self._is_completion_ready():
branch_handle = self._resolve_branch_selection()
@ -65,17 +41,18 @@ class HumanInputNode(Node):
return self._pause_generator()
def _pause_generator(self):
yield PauseRequestedEvent(reason=HumanInputRequired())
# TODO(QuantumGhost): yield a real form id.
yield PauseRequestedEvent(reason=HumanInputRequired(form_id="test_form_id", node_id=self.id))
def _is_completion_ready(self) -> bool:
"""Determine whether all required inputs are satisfied."""
if not self._node_data.required_variables:
if not self.node_data.required_variables:
return False
variable_pool = self.graph_runtime_state.variable_pool
for selector_str in self._node_data.required_variables:
for selector_str in self.node_data.required_variables:
parts = selector_str.split(".")
if len(parts) != 2:
return False
@ -95,7 +72,7 @@ class HumanInputNode(Node):
if handle:
return handle
default_values = self._node_data.default_value_dict
default_values = self.node_data.default_value_dict
for key in self._BRANCH_SELECTION_KEYS:
handle = self._normalize_branch_value(default_values.get(key))
if handle:

View File

@ -3,9 +3,8 @@ from typing import Any, Literal
from typing_extensions import deprecated
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.if_else.entities import IfElseNodeData
from core.workflow.runtime import VariablePool
@ -13,33 +12,10 @@ from core.workflow.utils.condition.entities import Condition
from core.workflow.utils.condition.processor import ConditionProcessor
class IfElseNode(Node):
class IfElseNode(Node[IfElseNodeData]):
node_type = NodeType.IF_ELSE
execution_type = NodeExecutionType.BRANCH
_node_data: IfElseNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = IfElseNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def version(cls) -> str:
return "1"
@ -59,8 +35,8 @@ class IfElseNode(Node):
condition_processor = ConditionProcessor()
try:
# Check if the new cases structure is used
if self._node_data.cases:
for case in self._node_data.cases:
if self.node_data.cases:
for case in self.node_data.cases:
input_conditions, group_result, final_result = condition_processor.process_conditions(
variable_pool=self.graph_runtime_state.variable_pool,
conditions=case.conditions,
@ -86,8 +62,8 @@ class IfElseNode(Node):
input_conditions, group_result, final_result = _should_not_use_old_function( # pyright: ignore [reportDeprecated]
condition_processor=condition_processor,
variable_pool=self.graph_runtime_state.variable_pool,
conditions=self._node_data.conditions or [],
operator=self._node_data.logical_operator or "and",
conditions=self.node_data.conditions or [],
operator=self.node_data.logical_operator or "and",
)
selected_case_id = "true" if final_result else "false"

View File

@ -14,7 +14,6 @@ from core.variables.segments import ArrayAnySegment, ArraySegment
from core.variables.variables import VariableUnion
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.enums import (
ErrorStrategy,
NodeExecutionType,
NodeType,
WorkflowNodeExecutionMetadataKey,
@ -36,7 +35,6 @@ from core.workflow.node_events import (
StreamCompletedEvent,
)
from core.workflow.nodes.base import LLMUsageTrackingMixin
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
from core.workflow.runtime import VariablePool
@ -60,35 +58,13 @@ logger = logging.getLogger(__name__)
EmptyArraySegment = NewType("EmptyArraySegment", ArraySegment)
class IterationNode(LLMUsageTrackingMixin, Node):
class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
"""
Iteration Node.
"""
node_type = NodeType.ITERATION
execution_type = NodeExecutionType.CONTAINER
_node_data: IterationNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = IterationNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
@ -159,10 +135,10 @@ class IterationNode(LLMUsageTrackingMixin, Node):
)
def _get_iterator_variable(self) -> ArraySegment | NoneSegment:
variable = self.graph_runtime_state.variable_pool.get(self._node_data.iterator_selector)
variable = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector)
if not variable:
raise IteratorVariableNotFoundError(f"iterator variable {self._node_data.iterator_selector} not found")
raise IteratorVariableNotFoundError(f"iterator variable {self.node_data.iterator_selector} not found")
if not isinstance(variable, ArraySegment) and not isinstance(variable, NoneSegment):
raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.")
@ -197,7 +173,7 @@ class IterationNode(LLMUsageTrackingMixin, Node):
return cast(list[object], iterator_list_value)
def _validate_start_node(self) -> None:
if not self._node_data.start_node_id:
if not self.node_data.start_node_id:
raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self._node_id} not found")
def _execute_iterations(
@ -207,7 +183,7 @@ class IterationNode(LLMUsageTrackingMixin, Node):
iter_run_map: dict[str, float],
usage_accumulator: list[LLMUsage],
) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]:
if self._node_data.is_parallel:
if self.node_data.is_parallel:
# Parallel mode execution
yield from self._execute_parallel_iterations(
iterator_list_value=iterator_list_value,
@ -254,7 +230,7 @@ class IterationNode(LLMUsageTrackingMixin, Node):
outputs.extend([None] * len(iterator_list_value))
# Determine the number of parallel workers
max_workers = min(self._node_data.parallel_nums, len(iterator_list_value))
max_workers = min(self.node_data.parallel_nums, len(iterator_list_value))
with ThreadPoolExecutor(max_workers=max_workers) as executor:
# Submit all iteration tasks
@ -310,7 +286,7 @@ class IterationNode(LLMUsageTrackingMixin, Node):
except Exception as e:
# Handle errors based on error_handle_mode
match self._node_data.error_handle_mode:
match self.node_data.error_handle_mode:
case ErrorHandleMode.TERMINATED:
# Cancel remaining futures and re-raise
for f in future_to_index:
@ -323,7 +299,7 @@ class IterationNode(LLMUsageTrackingMixin, Node):
outputs[index] = None # Will be filtered later
# Remove None values if in REMOVE_ABNORMAL_OUTPUT mode
if self._node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
if self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
outputs[:] = [output for output in outputs if output is not None]
def _execute_single_iteration_parallel(
@ -412,7 +388,7 @@ class IterationNode(LLMUsageTrackingMixin, Node):
If flatten_output is True (default), flattens the list if all elements are lists.
"""
# If flatten_output is disabled, return outputs as-is
if not self._node_data.flatten_output:
if not self.node_data.flatten_output:
return outputs
if not outputs:
@ -592,14 +568,14 @@ class IterationNode(LLMUsageTrackingMixin, Node):
self._append_iteration_info_to_event(event=event, iter_run_index=current_index)
yield event
elif isinstance(event, (GraphRunSucceededEvent, GraphRunPartialSucceededEvent)):
result = variable_pool.get(self._node_data.output_selector)
result = variable_pool.get(self.node_data.output_selector)
if result is None:
outputs.append(None)
else:
outputs.append(result.to_object())
return
elif isinstance(event, GraphRunFailedEvent):
match self._node_data.error_handle_mode:
match self.node_data.error_handle_mode:
case ErrorHandleMode.TERMINATED:
raise IterationNodeError(event.error)
case ErrorHandleMode.CONTINUE_ON_ERROR:
@ -650,7 +626,7 @@ class IterationNode(LLMUsageTrackingMixin, Node):
# Initialize the iteration graph with the new node factory
iteration_graph = Graph.init(
graph_config=self.graph_config, node_factory=node_factory, root_node_id=self._node_data.start_node_id
graph_config=self.graph_config, node_factory=node_factory, root_node_id=self.node_data.start_node_id
)
if not iteration_graph:

View File

@ -1,43 +1,16 @@
from collections.abc import Mapping
from typing import Any
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.iteration.entities import IterationStartNodeData
class IterationStartNode(Node):
class IterationStartNode(Node[IterationStartNodeData]):
"""
Iteration Start Node.
"""
node_type = NodeType.ITERATION_START
_node_data: IterationStartNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = IterationStartNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def version(cls) -> str:
return "1"

View File

@ -10,9 +10,8 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, SystemVariableKey
from core.workflow.enums import NodeExecutionType, NodeType, SystemVariableKey
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.template import Template
from core.workflow.runtime import VariablePool
@ -35,34 +34,12 @@ default_retrieval_model = {
}
class KnowledgeIndexNode(Node):
_node_data: KnowledgeIndexNodeData
class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
node_type = NodeType.KNOWLEDGE_INDEX
execution_type = NodeExecutionType.RESPONSE
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = KnowledgeIndexNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
def _run(self) -> NodeRunResult: # type: ignore
node_data = self._node_data
node_data = self.node_data
variable_pool = self.graph_runtime_state.variable_pool
dataset_id = variable_pool.get(["sys", SystemVariableKey.DATASET_ID])
if not dataset_id:

View File

@ -30,14 +30,12 @@ from core.variables import (
from core.variables.segments import ArrayObjectSegment
from core.workflow.entities import GraphInitParams
from core.workflow.enums import (
ErrorStrategy,
NodeType,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.node_events import ModelInvokeCompletedEvent, NodeRunResult
from core.workflow.nodes.base import LLMUsageTrackingMixin
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.knowledge_retrieval.template_prompts import (
METADATA_FILTER_ASSISTANT_PROMPT_1,
@ -82,11 +80,9 @@ default_retrieval_model = {
}
class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node):
class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeData]):
node_type = NodeType.KNOWLEDGE_RETRIEVAL
_node_data: KnowledgeRetrievalNodeData
# Instance attributes specific to LLMNode.
# Output variable for file
_file_outputs: list["File"]
@ -118,34 +114,13 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node):
)
self._llm_file_saver = llm_file_saver
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = KnowledgeRetrievalNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def version(cls):
return "1"
def _run(self) -> NodeRunResult:
# extract variables
variable = self.graph_runtime_state.variable_pool.get(self._node_data.query_variable_selector)
variable = self.graph_runtime_state.variable_pool.get(self.node_data.query_variable_selector)
if not isinstance(variable, StringSegment):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
@ -186,7 +161,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node):
# retrieve knowledge
usage = LLMUsage.empty_usage()
try:
results, usage = self._fetch_dataset_retriever(node_data=self._node_data, query=query)
results, usage = self._fetch_dataset_retriever(node_data=self.node_data, query=query)
outputs = {"result": ArrayObjectSegment(value=results)}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
@ -559,7 +534,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node):
prompt_messages=prompt_messages,
stop=stop,
user_id=self.user_id,
structured_output_enabled=self._node_data.structured_output_enabled,
structured_output_enabled=self.node_data.structured_output_enabled,
structured_output=None,
file_saver=self._llm_file_saver,
file_outputs=self._file_outputs,

View File

@ -1,12 +1,11 @@
from collections.abc import Callable, Mapping, Sequence
from collections.abc import Callable, Sequence
from typing import Any, TypeAlias, TypeVar
from core.file import File
from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
from core.variables.segments import ArrayAnySegment, ArrayBooleanSegment, ArraySegment
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from .entities import FilterOperator, ListOperatorNodeData, Order
@ -35,32 +34,9 @@ def _negation(filter_: Callable[[_T], bool]) -> Callable[[_T], bool]:
return wrapper
class ListOperatorNode(Node):
class ListOperatorNode(Node[ListOperatorNodeData]):
node_type = NodeType.LIST_OPERATOR
_node_data: ListOperatorNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = ListOperatorNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def version(cls) -> str:
return "1"
@ -70,9 +46,9 @@ class ListOperatorNode(Node):
process_data: dict[str, Sequence[object]] = {}
outputs: dict[str, Any] = {}
variable = self.graph_runtime_state.variable_pool.get(self._node_data.variable)
variable = self.graph_runtime_state.variable_pool.get(self.node_data.variable)
if variable is None:
error_message = f"Variable not found for selector: {self._node_data.variable}"
error_message = f"Variable not found for selector: {self.node_data.variable}"
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs
)
@ -91,7 +67,7 @@ class ListOperatorNode(Node):
outputs=outputs,
)
if not isinstance(variable, _SUPPORTED_TYPES_TUPLE):
error_message = f"Variable {self._node_data.variable} is not an array type, actual type: {type(variable)}"
error_message = f"Variable {self.node_data.variable} is not an array type, actual type: {type(variable)}"
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs
)
@ -105,19 +81,19 @@ class ListOperatorNode(Node):
try:
# Filter
if self._node_data.filter_by.enabled:
if self.node_data.filter_by.enabled:
variable = self._apply_filter(variable)
# Extract
if self._node_data.extract_by.enabled:
if self.node_data.extract_by.enabled:
variable = self._extract_slice(variable)
# Order
if self._node_data.order_by.enabled:
if self.node_data.order_by.enabled:
variable = self._apply_order(variable)
# Slice
if self._node_data.limit.enabled:
if self.node_data.limit.enabled:
variable = self._apply_slice(variable)
outputs = {
@ -143,7 +119,7 @@ class ListOperatorNode(Node):
def _apply_filter(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
filter_func: Callable[[Any], bool]
result: list[Any] = []
for condition in self._node_data.filter_by.conditions:
for condition in self.node_data.filter_by.conditions:
if isinstance(variable, ArrayStringSegment):
if not isinstance(condition.value, str):
raise InvalidFilterValueError(f"Invalid filter value: {condition.value}")
@ -182,22 +158,22 @@ class ListOperatorNode(Node):
def _apply_order(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
if isinstance(variable, (ArrayStringSegment, ArrayNumberSegment, ArrayBooleanSegment)):
result = sorted(variable.value, reverse=self._node_data.order_by.value == Order.DESC)
result = sorted(variable.value, reverse=self.node_data.order_by.value == Order.DESC)
variable = variable.model_copy(update={"value": result})
else:
result = _order_file(
order=self._node_data.order_by.value, order_by=self._node_data.order_by.key, array=variable.value
order=self.node_data.order_by.value, order_by=self.node_data.order_by.key, array=variable.value
)
variable = variable.model_copy(update={"value": result})
return variable
def _apply_slice(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
result = variable.value[: self._node_data.limit.size]
result = variable.value[: self.node_data.limit.size]
return variable.model_copy(update={"value": result})
def _extract_slice(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
value = int(self.graph_runtime_state.variable_pool.convert_template(self._node_data.extract_by.serial).text)
value = int(self.graph_runtime_state.variable_pool.convert_template(self.node_data.extract_by.serial).text)
if value < 1:
raise ValueError(f"Invalid serial index: must be >= 1, got {value}")
if value > len(variable.value):
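
For orientation, the four optional stages above always run in the same fixed order: filter, extract, order, slice. A plain-list illustration of that pipeline, leaving out the 1-based extract_by.serial stage (the real node operates on ArraySegment values rather than bare lists):

items = ["b.txt", "a.txt", "c.log", "d.txt"]
items = [x for x in items if x.endswith(".txt")]  # filter_by: ends with ".txt"
items = sorted(items, reverse=True)               # order_by: DESC
items = items[:2]                                 # limit.size == 2
print(items)  # ['d.txt', 'b.txt']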

View File

@ -55,7 +55,6 @@ from core.variables import (
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities import GraphInitParams
from core.workflow.enums import (
ErrorStrategy,
NodeType,
SystemVariableKey,
WorkflowNodeExecutionMetadataKey,
@ -69,7 +68,7 @@ from core.workflow.node_events import (
StreamChunkEvent,
StreamCompletedEvent,
)
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig, VariableSelector
from core.workflow.nodes.base.entities import VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.runtime import VariablePool
@ -100,11 +99,9 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
class LLMNode(Node):
class LLMNode(Node[LLMNodeData]):
node_type = NodeType.LLM
_node_data: LLMNodeData
# Compiled regex for extracting <think> blocks (with compatibility for attributes)
_THINK_PATTERN = re.compile(r"<think[^>]*>(.*?)</think>", re.IGNORECASE | re.DOTALL)
@ -139,27 +136,6 @@ class LLMNode(Node):
)
self._llm_file_saver = llm_file_saver
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = LLMNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def version(cls) -> str:
return "1"
@ -176,13 +152,13 @@ class LLMNode(Node):
try:
# init messages template
self._node_data.prompt_template = self._transform_chat_messages(self._node_data.prompt_template)
self.node_data.prompt_template = self._transform_chat_messages(self.node_data.prompt_template)
# fetch variables and fetch values from variable pool
inputs = self._fetch_inputs(node_data=self._node_data)
inputs = self._fetch_inputs(node_data=self.node_data)
# fetch jinja2 inputs
jinja_inputs = self._fetch_jinja_inputs(node_data=self._node_data)
jinja_inputs = self._fetch_jinja_inputs(node_data=self.node_data)
# merge inputs
inputs.update(jinja_inputs)
@ -191,9 +167,9 @@ class LLMNode(Node):
files = (
llm_utils.fetch_files(
variable_pool=variable_pool,
selector=self._node_data.vision.configs.variable_selector,
selector=self.node_data.vision.configs.variable_selector,
)
if self._node_data.vision.enabled
if self.node_data.vision.enabled
else []
)
@ -201,7 +177,7 @@ class LLMNode(Node):
node_inputs["#files#"] = [file.to_dict() for file in files]
# fetch context value
generator = self._fetch_context(node_data=self._node_data)
generator = self._fetch_context(node_data=self.node_data)
context = None
for event in generator:
context = event.context
@ -211,7 +187,7 @@ class LLMNode(Node):
# fetch model config
model_instance, model_config = LLMNode._fetch_model_config(
node_data_model=self._node_data.model,
node_data_model=self.node_data.model,
tenant_id=self.tenant_id,
)
@ -219,13 +195,13 @@ class LLMNode(Node):
memory = llm_utils.fetch_memory(
variable_pool=variable_pool,
app_id=self.app_id,
node_data_memory=self._node_data.memory,
node_data_memory=self.node_data.memory,
model_instance=model_instance,
)
query: str | None = None
if self._node_data.memory:
query = self._node_data.memory.query_prompt_template
if self.node_data.memory:
query = self.node_data.memory.query_prompt_template
if not query and (
query_variable := variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
):
@ -237,29 +213,29 @@ class LLMNode(Node):
context=context,
memory=memory,
model_config=model_config,
prompt_template=self._node_data.prompt_template,
memory_config=self._node_data.memory,
vision_enabled=self._node_data.vision.enabled,
vision_detail=self._node_data.vision.configs.detail,
prompt_template=self.node_data.prompt_template,
memory_config=self.node_data.memory,
vision_enabled=self.node_data.vision.enabled,
vision_detail=self.node_data.vision.configs.detail,
variable_pool=variable_pool,
jinja2_variables=self._node_data.prompt_config.jinja2_variables,
jinja2_variables=self.node_data.prompt_config.jinja2_variables,
tenant_id=self.tenant_id,
)
# handle invoke result
generator = LLMNode.invoke_llm(
node_data_model=self._node_data.model,
node_data_model=self.node_data.model,
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop,
user_id=self.user_id,
structured_output_enabled=self._node_data.structured_output_enabled,
structured_output=self._node_data.structured_output,
structured_output_enabled=self.node_data.structured_output_enabled,
structured_output=self.node_data.structured_output,
file_saver=self._llm_file_saver,
file_outputs=self._file_outputs,
node_id=self._node_id,
node_type=self.node_type,
reasoning_format=self._node_data.reasoning_format,
reasoning_format=self.node_data.reasoning_format,
)
structured_output: LLMStructuredOutput | None = None
@ -275,12 +251,12 @@ class LLMNode(Node):
reasoning_content = event.reasoning_content or ""
# For downstream nodes, determine clean text based on reasoning_format
if self._node_data.reasoning_format == "tagged":
if self.node_data.reasoning_format == "tagged":
# Keep <think> tags for backward compatibility
clean_text = result_text
else:
# Extract clean text from <think> tags
clean_text, _ = LLMNode._split_reasoning(result_text, self._node_data.reasoning_format)
clean_text, _ = LLMNode._split_reasoning(result_text, self.node_data.reasoning_format)
# Process structured output if available from the event.
structured_output = (
@ -1226,7 +1202,7 @@ class LLMNode(Node):
@property
def retry(self) -> bool:
return self._node_data.retry_config.retry_enabled
return self.node_data.retry_config.retry_enabled
def _combine_message_content_with_role(
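
The reasoning_format branch above either keeps <think> tags verbatim ("tagged") or extracts clean text via _split_reasoning, whose body is not shown in this diff. The class-level _THINK_PATTERN makes the mechanics easy to illustrate; a sketch only, not the method itself:

import re

# Same pattern as LLMNode._THINK_PATTERN above.
_THINK_PATTERN = re.compile(r"<think[^>]*>(.*?)</think>", re.IGNORECASE | re.DOTALL)

text = "<think>check the units first</think>The answer is 42."
reasoning = _THINK_PATTERN.findall(text)   # ['check the units first']
clean_text = _THINK_PATTERN.sub("", text)  # 'The answer is 42.'
print(reasoning, clean_text)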

View File

@ -1,43 +1,16 @@
from collections.abc import Mapping
from typing import Any
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.loop.entities import LoopEndNodeData
class LoopEndNode(Node):
class LoopEndNode(Node[LoopEndNodeData]):
"""
Loop End Node.
"""
node_type = NodeType.LOOP_END
_node_data: LoopEndNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = LoopEndNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def version(cls) -> str:
return "1"

View File

@ -8,7 +8,6 @@ from typing import TYPE_CHECKING, Any, Literal, cast
from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables import Segment, SegmentType
from core.workflow.enums import (
ErrorStrategy,
NodeExecutionType,
NodeType,
WorkflowNodeExecutionMetadataKey,
@ -29,7 +28,6 @@ from core.workflow.node_events import (
StreamCompletedEvent,
)
from core.workflow.nodes.base import LLMUsageTrackingMixin
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.loop.entities import LoopNodeData, LoopVariableData
from core.workflow.utils.condition.processor import ConditionProcessor
@ -42,36 +40,14 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
class LoopNode(LLMUsageTrackingMixin, Node):
class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
"""
Loop Node.
"""
node_type = NodeType.LOOP
_node_data: LoopNodeData
execution_type = NodeExecutionType.CONTAINER
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = LoopNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def version(cls) -> str:
return "1"
@ -79,27 +55,27 @@ class LoopNode(LLMUsageTrackingMixin, Node):
def _run(self) -> Generator:
"""Run the node."""
# Get inputs
loop_count = self._node_data.loop_count
break_conditions = self._node_data.break_conditions
logical_operator = self._node_data.logical_operator
loop_count = self.node_data.loop_count
break_conditions = self.node_data.break_conditions
logical_operator = self.node_data.logical_operator
inputs = {"loop_count": loop_count}
if not self._node_data.start_node_id:
if not self.node_data.start_node_id:
raise ValueError(f"field start_node_id in loop {self._node_id} not found")
root_node_id = self._node_data.start_node_id
root_node_id = self.node_data.start_node_id
# Initialize loop variables in the original variable pool
loop_variable_selectors = {}
if self._node_data.loop_variables:
if self.node_data.loop_variables:
value_processor: dict[Literal["constant", "variable"], Callable[[LoopVariableData], Segment | None]] = {
"constant": lambda var: self._get_segment_for_constant(var.var_type, var.value),
"variable": lambda var: self.graph_runtime_state.variable_pool.get(var.value)
if isinstance(var.value, list)
else None,
}
for loop_variable in self._node_data.loop_variables:
for loop_variable in self.node_data.loop_variables:
if loop_variable.value_type not in value_processor:
raise ValueError(
f"Invalid value type '{loop_variable.value_type}' for loop variable {loop_variable.label}"
@ -187,7 +163,7 @@ class LoopNode(LLMUsageTrackingMixin, Node):
yield LoopNextEvent(
index=i + 1,
pre_loop_output=self._node_data.outputs,
pre_loop_output=self.node_data.outputs,
)
self._accumulate_usage(loop_usage)
@ -195,7 +171,7 @@ class LoopNode(LLMUsageTrackingMixin, Node):
yield LoopSucceededEvent(
start_at=start_at,
inputs=inputs,
outputs=self._node_data.outputs,
outputs=self.node_data.outputs,
steps=loop_count,
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
@ -217,7 +193,7 @@ class LoopNode(LLMUsageTrackingMixin, Node):
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
},
outputs=self._node_data.outputs,
outputs=self.node_data.outputs,
inputs=inputs,
llm_usage=loop_usage,
)
@ -275,11 +251,11 @@ class LoopNode(LLMUsageTrackingMixin, Node):
if isinstance(event, GraphRunFailedEvent):
raise Exception(event.error)
for loop_var in self._node_data.loop_variables or []:
for loop_var in self.node_data.loop_variables or []:
key, sel = loop_var.label, [self._node_id, loop_var.label]
segment = self.graph_runtime_state.variable_pool.get(sel)
self._node_data.outputs[key] = segment.value if segment else None
self._node_data.outputs["loop_round"] = current_index + 1
self.node_data.outputs[key] = segment.value if segment else None
self.node_data.outputs["loop_round"] = current_index + 1
return reach_break_node
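
A small illustration of the per-round bookkeeping at the end of the loop body above: each loop variable's label becomes an output key looked up from the variable pool under the loop node's id, and loop_round reports the 1-based round number. Plain dicts stand in for the pool here:

variable_pool = {("loop-1", "total"): 42, ("loop-1", "items"): ["a", "b"]}
outputs: dict[str, object] = {}
current_index = 2  # zero-based round counter

for label in ["total", "items"]:  # loop variable labels
    segment = variable_pool.get(("loop-1", label))
    outputs[label] = segment if segment is not None else None
outputs["loop_round"] = current_index + 1

print(outputs)  # {'total': 42, 'items': ['a', 'b'], 'loop_round': 3}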

View File

@ -1,43 +1,16 @@
from collections.abc import Mapping
from typing import Any
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.loop.entities import LoopStartNodeData
class LoopStartNode(Node):
class LoopStartNode(Node[LoopStartNodeData]):
"""
Loop Start Node.
"""
node_type = NodeType.LOOP_START
_node_data: LoopStartNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = LoopStartNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def version(cls) -> str:
return "1"

View File

@ -69,17 +69,9 @@ class DifyNodeFactory(NodeFactory):
raise ValueError(f"No latest version class found for node type: {node_type}")
# Create node instance
node_instance = node_class(
return node_class(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
)
# Initialize node with provided data
node_data = node_config.get("data", {})
if not is_str_dict(node_data):
raise ValueError(f"Node {node_id} missing data information")
node_instance.init_node_data(node_data)
return node_instance
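
The factory change is the caller-side consequence of the generic-node refactor: the constructor now validates config["data"] itself, so the create-then-init two-step collapses into a single expression and a malformed data mapping fails at construction rather than one call later. Side by side, with names from the diff above:

# Before: construct, then separately validate and attach the node data.
node_instance = node_class(
    id=node_id,
    config=node_config,
    graph_init_params=self.graph_init_params,
    graph_runtime_state=self.graph_runtime_state,
)
node_instance.init_node_data(node_config.get("data", {}))

# After: one step; pydantic validation happens inside __init__.
return node_class(
    id=node_id,
    config=node_config,
    graph_init_params=self.graph_init_params,
    graph_runtime_state=self.graph_runtime_state,
)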

View File

@ -27,10 +27,9 @@ from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, Comp
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.variables.types import ArrayValidation, SegmentType
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base import variable_template_parser
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.llm import ModelConfig, llm_utils
from core.workflow.runtime import VariablePool
@ -84,36 +83,13 @@ def extract_json(text):
return None
class ParameterExtractorNode(Node):
class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
"""
Parameter Extractor Node.
"""
node_type = NodeType.PARAMETER_EXTRACTOR
_node_data: ParameterExtractorNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = ParameterExtractorNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
_model_instance: ModelInstance | None = None
_model_config: ModelConfigWithCredentialsEntity | None = None
@ -138,7 +114,7 @@ class ParameterExtractorNode(Node):
"""
Run the node.
"""
node_data = self._node_data
node_data = self.node_data
variable = self.graph_runtime_state.variable_pool.get(node_data.query)
query = variable.text if variable else ""

View File

@ -13,14 +13,13 @@ from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities import GraphInitParams
from core.workflow.enums import (
ErrorStrategy,
NodeExecutionType,
NodeType,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.node_events import ModelInvokeCompletedEvent, NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig, VariableSelector
from core.workflow.nodes.base.entities import VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.nodes.llm import LLMNode, LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, llm_utils
@ -44,12 +43,10 @@ if TYPE_CHECKING:
from core.workflow.runtime import GraphRuntimeState
class QuestionClassifierNode(Node):
class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
node_type = NodeType.QUESTION_CLASSIFIER
execution_type = NodeExecutionType.BRANCH
_node_data: QuestionClassifierNodeData
_file_outputs: list["File"]
_llm_file_saver: LLMFileSaver
@ -78,33 +75,12 @@ class QuestionClassifierNode(Node):
)
self._llm_file_saver = llm_file_saver
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = QuestionClassifierNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def version(cls):
return "1"
def _run(self):
node_data = self._node_data
node_data = self.node_data
variable_pool = self.graph_runtime_state.variable_pool
# extract variables

View File

@ -1,41 +1,14 @@
from collections.abc import Mapping
from typing import Any
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.start.entities import StartNodeData
class StartNode(Node):
class StartNode(Node[StartNodeData]):
node_type = NodeType.START
execution_type = NodeExecutionType.ROOT
_node_data: StartNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = StartNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def version(cls) -> str:
return "1"

View File

@ -3,41 +3,17 @@ from typing import Any
from configs import dify_config
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData
MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
class TemplateTransformNode(Node):
class TemplateTransformNode(Node[TemplateTransformNodeData]):
node_type = NodeType.TEMPLATE_TRANSFORM
_node_data: TemplateTransformNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = TemplateTransformNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
"""
@ -57,14 +33,14 @@ class TemplateTransformNode(Node):
def _run(self) -> NodeRunResult:
# Get variables
variables: dict[str, Any] = {}
for variable_selector in self._node_data.variables:
for variable_selector in self.node_data.variables:
variable_name = variable_selector.variable
value = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
variables[variable_name] = value.to_object() if value else None
# Run code
try:
result = CodeExecutor.execute_workflow_code_template(
language=CodeLanguage.JINJA2, code=self._node_data.template, inputs=variables
language=CodeLanguage.JINJA2, code=self.node_data.template, inputs=variables
)
except CodeExecutionError as e:
return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e))

View File

@ -16,14 +16,12 @@ from core.tools.workflow_as_tool.tool import WorkflowTool
from core.variables.segments import ArrayAnySegment, ArrayFileSegment
from core.variables.variables import ArrayAnyVariable
from core.workflow.enums import (
ErrorStrategy,
NodeType,
SystemVariableKey,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.node_events import NodeEventBase, NodeRunResult, StreamChunkEvent, StreamCompletedEvent
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
@ -42,18 +40,13 @@ if TYPE_CHECKING:
from core.workflow.runtime import VariablePool
class ToolNode(Node):
class ToolNode(Node[ToolNodeData]):
"""
Tool Node
"""
node_type = NodeType.TOOL
_node_data: ToolNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = ToolNodeData.model_validate(data)
@classmethod
def version(cls) -> str:
return "1"
@ -64,13 +57,11 @@ class ToolNode(Node):
"""
from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError
node_data = self._node_data
# fetch tool icon
tool_info = {
"provider_type": node_data.provider_type.value,
"provider_id": node_data.provider_id,
"plugin_unique_identifier": node_data.plugin_unique_identifier,
"provider_type": self.node_data.provider_type.value,
"provider_id": self.node_data.provider_id,
"plugin_unique_identifier": self.node_data.plugin_unique_identifier,
}
# get tool runtime
@ -82,10 +73,10 @@ class ToolNode(Node):
# But for backward compatibility with historical data
# this version field judgment is still preserved here.
variable_pool: VariablePool | None = None
if node_data.version != "1" or node_data.tool_node_version is not None:
if self.node_data.version != "1" or self.node_data.tool_node_version is not None:
variable_pool = self.graph_runtime_state.variable_pool
tool_runtime = ToolManager.get_workflow_tool_runtime(
self.tenant_id, self.app_id, self._node_id, self._node_data, self.invoke_from, variable_pool
self.tenant_id, self.app_id, self._node_id, self.node_data, self.invoke_from, variable_pool
)
except ToolNodeError as e:
yield StreamCompletedEvent(
@ -104,12 +95,12 @@ class ToolNode(Node):
parameters = self._generate_parameters(
tool_parameters=tool_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=self._node_data,
node_data=self.node_data,
)
parameters_for_log = self._generate_parameters(
tool_parameters=tool_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=self._node_data,
node_data=self.node_data,
for_log=True,
)
# get conversation id
@ -154,7 +145,7 @@ class ToolNode(Node):
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
error=f"Failed to invoke tool {node_data.provider_name}: {str(e)}",
error=f"Failed to invoke tool {self.node_data.provider_name}: {str(e)}",
error_type=type(e).__name__,
)
)
@ -164,7 +155,7 @@ class ToolNode(Node):
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
error=e.to_user_friendly_error(plugin_name=node_data.provider_name),
error=e.to_user_friendly_error(plugin_name=self.node_data.provider_name),
error_type=type(e).__name__,
)
)
@ -498,24 +489,6 @@ class ToolNode(Node):
return result
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@property
def retry(self) -> bool:
return self._node_data.retry_config.retry_enabled
return self.node_data.retry_config.retry_enabled

View File

@ -1,43 +1,18 @@
from collections.abc import Mapping
from typing import Any
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType
from core.workflow.enums import NodeExecutionType, NodeType
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from .entities import TriggerEventNodeData
class TriggerEventNode(Node):
class TriggerEventNode(Node[TriggerEventNodeData]):
node_type = NodeType.TRIGGER_PLUGIN
execution_type = NodeExecutionType.ROOT
_node_data: TriggerEventNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = TriggerEventNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
return {
@ -68,9 +43,9 @@ class TriggerEventNode(Node):
# Get trigger data passed when workflow was triggered
metadata = {
WorkflowNodeExecutionMetadataKey.TRIGGER_INFO: {
"provider_id": self._node_data.provider_id,
"event_name": self._node_data.event_name,
"plugin_unique_identifier": self._node_data.plugin_unique_identifier,
"provider_id": self.node_data.provider_id,
"event_name": self.node_data.event_name,
"plugin_unique_identifier": self.node_data.plugin_unique_identifier,
},
}
node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)

View File

@ -1,42 +1,17 @@
from collections.abc import Mapping
from typing import Any
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType
from core.workflow.enums import NodeExecutionType, NodeType
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.trigger_schedule.entities import TriggerScheduleNodeData
class TriggerScheduleNode(Node):
class TriggerScheduleNode(Node[TriggerScheduleNodeData]):
node_type = NodeType.TRIGGER_SCHEDULE
execution_type = NodeExecutionType.ROOT
_node_data: TriggerScheduleNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = TriggerScheduleNodeData(**data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def version(cls) -> str:
return "1"

View File

@ -3,41 +3,17 @@ from typing import Any
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType
from core.workflow.enums import NodeExecutionType, NodeType
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from .entities import ContentType, WebhookData
class TriggerWebhookNode(Node):
class TriggerWebhookNode(Node[WebhookData]):
node_type = NodeType.TRIGGER_WEBHOOK
execution_type = NodeExecutionType.ROOT
_node_data: WebhookData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = WebhookData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
return {
@ -108,7 +84,7 @@ class TriggerWebhookNode(Node):
webhook_headers = webhook_data.get("headers", {})
webhook_headers_lower = {k.lower(): v for k, v in webhook_headers.items()}
for header in self._node_data.headers:
for header in self.node_data.headers:
header_name = header.name
value = _get_normalized(webhook_headers, header_name)
if value is None:
@ -117,20 +93,20 @@ class TriggerWebhookNode(Node):
outputs[sanitized_name] = value
# Extract configured query parameters
for param in self._node_data.params:
for param in self.node_data.params:
param_name = param.name
outputs[param_name] = webhook_data.get("query_params", {}).get(param_name)
# Extract configured body parameters
for body_param in self._node_data.body:
for body_param in self.node_data.body:
param_name = body_param.name
param_type = body_param.type
if self._node_data.content_type == ContentType.TEXT:
if self.node_data.content_type == ContentType.TEXT:
# For text/plain, the entire body is a single string parameter
outputs[param_name] = str(webhook_data.get("body", {}).get("raw", ""))
continue
elif self._node_data.content_type == ContentType.BINARY:
elif self.node_data.content_type == ContentType.BINARY:
outputs[param_name] = webhook_data.get("body", {}).get("raw", b"")
continue
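
_get_normalized itself is outside this diff, but the webhook_headers_lower mapping built above points at a case-insensitive header lookup. A hedged sketch of that idea (assumed behavior, not the real helper):

def _get_normalized(headers: dict[str, str], name: str) -> str | None:
    # Case-insensitive lookup: webhook senders rarely agree on header casing.
    lowered = {k.lower(): v for k, v in headers.items()}
    return lowered.get(name.lower())

print(_get_normalized({"X-Token": "abc"}, "x-token"))  # abc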

View File

@ -23,12 +23,11 @@ class AdvancedSettings(BaseModel):
groups: list[Group]
class VariableAssignerNodeData(BaseNodeData):
class VariableAggregatorNodeData(BaseNodeData):
"""
Variable Assigner Node Data.
Variable Aggregator Node Data.
"""
type: str = "variable-assigner"
output_type: str
variables: list[list[str]]
advanced_settings: AdvancedSettings | None = None
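
Worth noting in this rename: only the class name changes; the serialized type field keeps its historical "variable-assigner" value, so previously exported workflow definitions still deserialize. Illustrative only, assuming BaseNodeData requires at least a title:

data = VariableAggregatorNodeData(
    title="aggregate",
    output_type="string",
    variables=[["node_a", "text"], ["node_b", "text"]],
)
print(data.type)  # "variable-assigner", the unchanged wire value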

View File

@ -1,40 +1,15 @@
from collections.abc import Mapping
from typing import Any
from core.variables.segments import Segment
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.variable_aggregator.entities import VariableAssignerNodeData
from core.workflow.nodes.variable_aggregator.entities import VariableAggregatorNodeData
class VariableAggregatorNode(Node):
class VariableAggregatorNode(Node[VariableAggregatorNodeData]):
node_type = NodeType.VARIABLE_AGGREGATOR
_node_data: VariableAssignerNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = VariableAssignerNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def version(cls) -> str:
return "1"
@ -44,8 +19,8 @@ class VariableAggregatorNode(Node):
outputs: dict[str, Segment | Mapping[str, Segment]] = {}
inputs = {}
if not self._node_data.advanced_settings or not self._node_data.advanced_settings.group_enabled:
for selector in self._node_data.variables:
if not self.node_data.advanced_settings or not self.node_data.advanced_settings.group_enabled:
for selector in self.node_data.variables:
variable = self.graph_runtime_state.variable_pool.get(selector)
if variable is not None:
outputs = {"output": variable}
@ -53,7 +28,7 @@ class VariableAggregatorNode(Node):
inputs = {".".join(selector[1:]): variable.to_object()}
break
else:
for group in self._node_data.advanced_settings.groups:
for group in self.node_data.advanced_settings.groups:
for selector in group.variables:
variable = self.graph_runtime_state.variable_pool.get(selector)

View File

@ -5,9 +5,8 @@ from core.variables import SegmentType, Variable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.entities import GraphInitParams
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
@ -22,33 +21,10 @@ if TYPE_CHECKING:
_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater]
class VariableAssignerNode(Node):
class VariableAssignerNode(Node[VariableAssignerData]):
node_type = NodeType.VARIABLE_ASSIGNER
_conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY
_node_data: VariableAssignerData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = VariableAssignerData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
def __init__(
self,
id: str,
@ -93,21 +69,21 @@ class VariableAssignerNode(Node):
return mapping
def _run(self) -> NodeRunResult:
assigned_variable_selector = self._node_data.assigned_variable_selector
assigned_variable_selector = self.node_data.assigned_variable_selector
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector)
if not isinstance(original_variable, Variable):
raise VariableOperatorNodeError("assigned variable not found")
match self._node_data.write_mode:
match self.node_data.write_mode:
case WriteMode.OVER_WRITE:
income_value = self.graph_runtime_state.variable_pool.get(self._node_data.input_variable_selector)
income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
if not income_value:
raise VariableOperatorNodeError("input value not found")
updated_variable = original_variable.model_copy(update={"value": income_value.value})
case WriteMode.APPEND:
income_value = self.graph_runtime_state.variable_pool.get(self._node_data.input_variable_selector)
income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
if not income_value:
raise VariableOperatorNodeError("input value not found")
updated_value = original_variable.value + [income_value.value]

View File

@ -7,9 +7,8 @@ from core.variables import SegmentType, Variable
from core.variables.consts import SELECTORS_LENGTH
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
@ -51,32 +50,9 @@ def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_
mapping[key] = selector
class VariableAssignerNode(Node):
class VariableAssignerNode(Node[VariableAssignerNodeData]):
node_type = NodeType.VARIABLE_ASSIGNER
_node_data: VariableAssignerNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = VariableAssignerNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool:
"""
Check if this Variable Assigner node blocks the output of specific variables.
@ -84,7 +60,7 @@ class VariableAssignerNode(Node):
Returns True if this node updates any of the requested conversation variables.
"""
# Check each item in this Variable Assigner node
for item in self._node_data.items:
for item in self.node_data.items:
# Convert the item's variable_selector to tuple for comparison
item_selector_tuple = tuple(item.variable_selector)
@ -119,13 +95,13 @@ class VariableAssignerNode(Node):
return var_mapping
def _run(self) -> NodeRunResult:
inputs = self._node_data.model_dump()
inputs = self.node_data.model_dump()
process_data: dict[str, Any] = {}
# NOTE: This node has no outputs
updated_variable_selectors: list[Sequence[str]] = []
try:
for item in self._node_data.items:
for item in self.node_data.items:
variable = self.graph_runtime_state.variable_pool.get(item.variable_selector)
# ==================== Validation Part

View File

@ -10,6 +10,7 @@ from typing import Any, Protocol
from pydantic.json import pydantic_encoder
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.entities.pause_reason import PauseReason
from core.workflow.runtime.variable_pool import VariablePool
@ -46,7 +47,11 @@ class ReadyQueueProtocol(Protocol):
class GraphExecutionProtocol(Protocol):
"""Structural interface for graph execution aggregate."""
"""Structural interface for graph execution aggregate.
Defines the minimal set of attributes and methods required from a GraphExecution entity
for runtime orchestration and state management.
"""
workflow_id: str
started: bool
@ -54,6 +59,7 @@ class GraphExecutionProtocol(Protocol):
aborted: bool
error: Exception | None
exceptions_count: int
pause_reasons: list[PauseReason]
def start(self) -> None:
"""Transition execution into the running state."""

View File

@ -159,7 +159,6 @@ class WorkflowEntry:
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
node.init_node_data(node_config_data)
try:
# variable selector to variable mapping
@ -303,7 +302,6 @@ class WorkflowEntry:
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
node.init_node_data(node_data)
try:
# variable selector to variable mapping

View File

@ -1,5 +1,6 @@
import mimetypes
import os
import re
import urllib.parse
import uuid
from collections.abc import Callable, Mapping, Sequence
@ -268,15 +269,47 @@ def _build_from_remote_url(
def _extract_filename(url_path: str, content_disposition: str | None) -> str | None:
filename = None
filename: str | None = None
# Try to extract from Content-Disposition header first
if content_disposition:
_, params = parse_options_header(content_disposition)
# RFC 5987 https://datatracker.ietf.org/doc/html/rfc5987: filename* takes precedence over filename
filename = params.get("filename*") or params.get("filename")
# Manually extract filename* parameter since parse_options_header doesn't support it
filename_star_match = re.search(r"filename\*=([^;]+)", content_disposition)
if filename_star_match:
raw_star = filename_star_match.group(1).strip()
# Remove trailing quotes if present
raw_star = raw_star.removesuffix('"')
# format: charset'lang'value
try:
parts = raw_star.split("'", 2)
charset = (parts[0] or "utf-8").lower() if len(parts) >= 1 else "utf-8"
value = parts[2] if len(parts) == 3 else parts[-1]
filename = urllib.parse.unquote(value, encoding=charset, errors="replace")
except Exception:
# Fallback: try to extract value after the last single quote
if "''" in raw_star:
filename = urllib.parse.unquote(raw_star.split("''")[-1])
else:
filename = urllib.parse.unquote(raw_star)
if not filename:
# Fallback to regular filename parameter
_, params = parse_options_header(content_disposition)
raw = params.get("filename")
if raw:
# Strip surrounding quotes and percent-decode if present
if len(raw) >= 2 and raw[0] == raw[-1] == '"':
raw = raw[1:-1]
filename = urllib.parse.unquote(raw)
# Fallback to URL path if no filename from header
if not filename:
filename = os.path.basename(url_path)
candidate = os.path.basename(url_path)
filename = urllib.parse.unquote(candidate) if candidate else None
# Defense-in-depth: ensure basename only
if filename:
filename = os.path.basename(filename)
# Return None if filename is empty or only whitespace
if not filename or not filename.strip():
filename = None
return filename or None
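
RFC 5987 encoding is easy to get wrong, so here is the filename* branch above distilled into a standalone helper with two worked inputs; the expected outputs follow directly from the logic shown in the diff:

import re
import urllib.parse


def parse_filename_star(content_disposition: str) -> str | None:
    # Mirrors the filename* branch: charset'lang'value, percent-decoded.
    match = re.search(r"filename\*=([^;]+)", content_disposition)
    if not match:
        return None
    raw = match.group(1).strip().removesuffix('"')
    parts = raw.split("'", 2)
    charset = (parts[0] or "utf-8").lower()
    value = parts[2] if len(parts) == 3 else parts[-1]
    return urllib.parse.unquote(value, encoding=charset, errors="replace")


print(parse_filename_star("attachment; filename*=UTF-8''%E2%82%AC%20rates.pdf"))
# -> € rates.pdf
print(parse_filename_star('attachment; filename="plain.pdf"'))
# -> None (the caller then falls back to the regular filename parameter)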

View File

@ -0,0 +1,41 @@
"""Add workflow_pauses_reasons table
Revision ID: 7bb281b7a422
Revises: 09cfdda155d1
Create Date: 2025-11-18 18:59:26.999572
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "7bb281b7a422"
down_revision = "09cfdda155d1"
branch_labels = None
depends_on = None
def upgrade():
op.create_table(
"workflow_pause_reasons",
sa.Column("id", models.types.StringUUID(), nullable=False),
sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
sa.Column("pause_id", models.types.StringUUID(), nullable=False),
sa.Column("type_", sa.String(20), nullable=False),
sa.Column("form_id", sa.String(length=36), nullable=False),
sa.Column("node_id", sa.String(length=255), nullable=False),
sa.Column("message", sa.String(length=255), nullable=False),
sa.PrimaryKeyConstraint("id", name=op.f("workflow_pause_reasons_pkey")),
)
with op.batch_alter_table("workflow_pause_reasons", schema=None) as batch_op:
batch_op.create_index(batch_op.f("workflow_pause_reasons_pause_id_idx"), ["pause_id"], unique=False)
def downgrade():
op.drop_table("workflow_pause_reasons")
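
For reference, the table this migration creates maps onto an ORM model along these lines. This is a hedged sketch only: the actual model in models/ (its base class and the project's StringUUID type) is not part of this diff, so String(36) stands in for the UUID column type.

from datetime import datetime

import sqlalchemy as sa
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class WorkflowPauseReason(Base):
    __tablename__ = "workflow_pause_reasons"

    id: Mapped[str] = mapped_column(sa.String(36), primary_key=True)
    created_at: Mapped[datetime] = mapped_column(
        sa.DateTime, server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False
    )
    updated_at: Mapped[datetime] = mapped_column(
        sa.DateTime, server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False
    )
    pause_id: Mapped[str] = mapped_column(sa.String(36), nullable=False, index=True)
    type_: Mapped[str] = mapped_column(sa.String(20), nullable=False)
    form_id: Mapped[str] = mapped_column(sa.String(36), nullable=False)
    node_id: Mapped[str] = mapped_column(sa.String(255), nullable=False)
    message: Mapped[str] = mapped_column(sa.String(255), nullable=False)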

Some files were not shown because too many files have changed in this diff.