mirror of
https://github.com/langgenius/dify.git
synced 2026-01-19 11:45:05 +08:00
Compare commits
143 Commits
1.11.3
...
feat/queue
| Author | SHA1 | Date | |
|---|---|---|---|
| 1cf788c43b | |||
| 73a7756350 | |||
| 02d15ebd5a | |||
| 976b3b5e83 | |||
| b5684f1992 | |||
| bd13cf05eb | |||
| 5f263147f9 | |||
| b68afdfa64 | |||
| da87fce751 | |||
| d5342927d0 | |||
| 754d790c89 | |||
| a099a35e51 | |||
| 2dd893e60d | |||
| b8ee1d4697 | |||
| b4ef1de30f | |||
| 0f15a2baca | |||
| 4cdc19fd05 | |||
| efa5f35277 | |||
| 766fda395b | |||
| b0e815c3c7 | |||
| 462ba354a4 | |||
| 3c668e4a5c | |||
| 872cff7bab | |||
| 8fb69429f9 | |||
| 85064bd8cf | |||
| ba5df3612b | |||
| a923ab1ab8 | |||
| b4c1766932 | |||
| 00a1af8506 | |||
| f56fccee9d | |||
| b6b98a2c8e | |||
| 7e69403dda | |||
| 9796cede72 | |||
| 836ed1f380 | |||
| 80f39963f1 | |||
| 9cf2b2b231 | |||
| 2a97a69825 | |||
| f17c71e08a | |||
| d52621fce3 | |||
| e060d7c28c | |||
| ea5dfe41d5 | |||
| a23c8fcb1a | |||
| e0e82fbfaa | |||
| 1c9f40f92a | |||
| 6ffa2ebabf | |||
| 95dc1e2fe8 | |||
| 6fe7cf5ebf | |||
| a1e8ac4c96 | |||
| b46858d87d | |||
| 5ab6838849 | |||
| ef974e484b | |||
| 299141ae01 | |||
| cc1d437dc1 | |||
| 7aef0b54e5 | |||
| 3c28936796 | |||
| 81fdc7c54b | |||
| abb53f11ad | |||
| d9aa0ec046 | |||
| 6c3302a192 | |||
| 7ba1f0a046 | |||
| 2adf5d0eee | |||
| 103a9a4e67 | |||
| 15b3443e9e | |||
| 81e9d6f63a | |||
| 9c2943183e | |||
| f6a2a09815 | |||
| e229510e73 | |||
| 36048d1526 | |||
| aff7ca12b8 | |||
| ad9eed2551 | |||
| 07109846e0 | |||
| 2aeaefccec | |||
| 4d63bd2083 | |||
| 226f14a20f | |||
| 2b28aed4e2 | |||
| 938a845852 | |||
| ead8568bfc | |||
| ed22d04ea0 | |||
| 04bbf540d9 | |||
| 657c27ec75 | |||
| 16e9cd5ac5 | |||
| 61c79b0013 | |||
| 8332472944 | |||
| fe3f03e50a | |||
| 9c96b23d55 | |||
| 8c97937cae | |||
| f6acff4cce | |||
| 3fa48cb5cf | |||
| b81745aed8 | |||
| 8c41d95d03 | |||
| 9d004a0971 | |||
| 02fcd08c08 | |||
| 77a9a73d0d | |||
| 1770b93e5b | |||
| d8ff4aa9ba | |||
| 9f8f21bf87 | |||
| 0b0dc63f29 | |||
| 8433cf4437 | |||
| bb5d52539c | |||
| 88622f70fb | |||
| 0fdb1b2bc9 | |||
| a5cb9d2b73 | |||
| 64c1234724 | |||
| 202fdfcb81 | |||
| e2f4c9ba8d | |||
| 546d75d84d | |||
| a8fe4ea802 | |||
| 82193580de | |||
| 1fd27cf3ad | |||
| 11d32ca87d | |||
| 5415d0c6d1 | |||
| d8af8ae4e6 | |||
| 04e5d4692f | |||
| 3aa48efd0a | |||
| 8eb78c04b2 | |||
| 22ee318cf8 | |||
| f2bc4f5d87 | |||
| d7d456349d | |||
| dce4d0ff80 | |||
| 3dee8064ba | |||
| bfbb36756a | |||
| d7e0c5f759 | |||
| c396788128 | |||
| e3a7b1f691 | |||
| 8aab7f49c3 | |||
| 1e12c1cbf2 | |||
| affedd6ce4 | |||
| ef21097774 | |||
| 1d377fe994 | |||
| c82697f267 | |||
| 98b25c0bbc | |||
| 1cd0792606 | |||
| 7cbf4093f4 | |||
| 8129ca7c05 | |||
| 65617f000d | |||
| 635eff2e25 | |||
| 55085a9ca2 | |||
| 9dc1e9724e | |||
| c3f66e2901 | |||
| 86e7cb713c | |||
| 0f29244459 | |||
| 48cbf4c78f | |||
| 8c35663220 |
5
.github/workflows/style.yml
vendored
5
.github/workflows/style.yml
vendored
@ -12,7 +12,6 @@ permissions:
|
||||
statuses: write
|
||||
contents: read
|
||||
|
||||
|
||||
jobs:
|
||||
python-style:
|
||||
name: Python Style
|
||||
@ -44,6 +43,10 @@ jobs:
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: uv sync --project api --dev
|
||||
|
||||
- name: Run Import Linter
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: uv run --directory api --dev lint-imports
|
||||
|
||||
- name: Run Basedpyright Checks
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: dev/basedpyright-check
|
||||
|
||||
@ -461,6 +461,16 @@ WORKFLOW_CALL_MAX_DEPTH=5
|
||||
WORKFLOW_PARALLEL_DEPTH_LIMIT=3
|
||||
MAX_VARIABLE_SIZE=204800
|
||||
|
||||
# GraphEngine Worker Pool Configuration
|
||||
# Minimum number of workers per GraphEngine instance (default: 1)
|
||||
GRAPH_ENGINE_MIN_WORKERS=1
|
||||
# Maximum number of workers per GraphEngine instance (default: 10)
|
||||
GRAPH_ENGINE_MAX_WORKERS=10
|
||||
# Queue depth threshold that triggers worker scale up (default: 3)
|
||||
GRAPH_ENGINE_SCALE_UP_THRESHOLD=3
|
||||
# Seconds of idle time before scaling down workers (default: 5.0)
|
||||
GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME=5.0
|
||||
|
||||
# Workflow storage configuration
|
||||
# Options: rdbms, hybrid
|
||||
# rdbms: Use only the relational database (default)
|
||||
|
||||
105
api/.importlinter
Normal file
105
api/.importlinter
Normal file
@ -0,0 +1,105 @@
|
||||
[importlinter]
|
||||
root_packages =
|
||||
core
|
||||
configs
|
||||
controllers
|
||||
models
|
||||
tasks
|
||||
services
|
||||
|
||||
[importlinter:contract:workflow]
|
||||
name = Workflow
|
||||
type=layers
|
||||
layers =
|
||||
graph_engine
|
||||
graph_events
|
||||
graph
|
||||
nodes
|
||||
node_events
|
||||
entities
|
||||
containers =
|
||||
core.workflow
|
||||
ignore_imports =
|
||||
core.workflow.nodes.base.node -> core.workflow.graph_events
|
||||
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_events
|
||||
core.workflow.nodes.loop.loop_node -> core.workflow.graph_events
|
||||
|
||||
core.workflow.nodes.node_factory -> core.workflow.graph
|
||||
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine
|
||||
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph
|
||||
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine.command_channels
|
||||
core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine
|
||||
core.workflow.nodes.loop.loop_node -> core.workflow.graph
|
||||
core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine.command_channels
|
||||
|
||||
[importlinter:contract:rsc]
|
||||
name = RSC
|
||||
type = layers
|
||||
layers =
|
||||
graph_engine
|
||||
response_coordinator
|
||||
containers =
|
||||
core.workflow.graph_engine
|
||||
|
||||
[importlinter:contract:worker]
|
||||
name = Worker
|
||||
type = layers
|
||||
layers =
|
||||
graph_engine
|
||||
worker
|
||||
containers =
|
||||
core.workflow.graph_engine
|
||||
|
||||
[importlinter:contract:graph-engine-architecture]
|
||||
name = Graph Engine Architecture
|
||||
type = layers
|
||||
layers =
|
||||
graph_engine
|
||||
orchestration
|
||||
command_processing
|
||||
event_management
|
||||
error_handler
|
||||
graph_traversal
|
||||
graph_state_manager
|
||||
worker_management
|
||||
domain
|
||||
containers =
|
||||
core.workflow.graph_engine
|
||||
|
||||
[importlinter:contract:domain-isolation]
|
||||
name = Domain Model Isolation
|
||||
type = forbidden
|
||||
source_modules =
|
||||
core.workflow.graph_engine.domain
|
||||
forbidden_modules =
|
||||
core.workflow.graph_engine.worker_management
|
||||
core.workflow.graph_engine.command_channels
|
||||
core.workflow.graph_engine.layers
|
||||
core.workflow.graph_engine.protocols
|
||||
|
||||
[importlinter:contract:worker-management]
|
||||
name = Worker Management
|
||||
type = forbidden
|
||||
source_modules =
|
||||
core.workflow.graph_engine.worker_management
|
||||
forbidden_modules =
|
||||
core.workflow.graph_engine.orchestration
|
||||
core.workflow.graph_engine.command_processing
|
||||
core.workflow.graph_engine.event_management
|
||||
|
||||
|
||||
[importlinter:contract:graph-traversal-components]
|
||||
name = Graph Traversal Components
|
||||
type = layers
|
||||
layers =
|
||||
edge_processor
|
||||
skip_propagator
|
||||
containers =
|
||||
core.workflow.graph_engine.graph_traversal
|
||||
|
||||
[importlinter:contract:command-channels]
|
||||
name = Command Channels Independence
|
||||
type = independence
|
||||
modules =
|
||||
core.workflow.graph_engine.command_channels.in_memory_channel
|
||||
core.workflow.graph_engine.command_channels.redis_channel
|
||||
@ -14,7 +14,6 @@ from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from configs import dify_config
|
||||
from constants.languages import languages
|
||||
from core.plugin.entities.plugin import ToolProviderID
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.index_processor.constant.built_in_field import BuiltInField
|
||||
@ -32,6 +31,7 @@ from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, D
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation
|
||||
from models.provider import Provider, ProviderModel
|
||||
from models.provider_ids import ToolProviderID
|
||||
from models.tools import ToolOAuthSystemClient
|
||||
from services.account_service import AccountService, RegisterService, TenantService
|
||||
from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpiredLogs
|
||||
|
||||
@ -535,6 +535,28 @@ class WorkflowConfig(BaseSettings):
|
||||
default=200 * 1024,
|
||||
)
|
||||
|
||||
# GraphEngine Worker Pool Configuration
|
||||
GRAPH_ENGINE_MIN_WORKERS: PositiveInt = Field(
|
||||
description="Minimum number of workers per GraphEngine instance",
|
||||
default=1,
|
||||
)
|
||||
|
||||
GRAPH_ENGINE_MAX_WORKERS: PositiveInt = Field(
|
||||
description="Maximum number of workers per GraphEngine instance",
|
||||
default=10,
|
||||
)
|
||||
|
||||
GRAPH_ENGINE_SCALE_UP_THRESHOLD: PositiveInt = Field(
|
||||
description="Queue depth threshold that triggers worker scale up",
|
||||
default=3,
|
||||
)
|
||||
|
||||
GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME: float = Field(
|
||||
description="Seconds of idle time before scaling down workers",
|
||||
default=5.0,
|
||||
ge=0.1,
|
||||
)
|
||||
|
||||
|
||||
class WorkflowNodeExecutionConfig(BaseSettings):
|
||||
"""
|
||||
|
||||
@ -29,7 +29,7 @@ def no_key_cache_key(namespace: str, key: str) -> str:
|
||||
|
||||
|
||||
# Returns whether the obtained value is obtained, and None if it does not
|
||||
def get_value_from_dict(namespace_cache: dict[str, Any] | None, key: str) -> Any | None:
|
||||
def get_value_from_dict(namespace_cache: dict[str, Any] | None, key: str) -> Any:
|
||||
if namespace_cache:
|
||||
kv_data = namespace_cache.get(CONFIGURATIONS)
|
||||
if kv_data is None:
|
||||
|
||||
@ -16,7 +16,10 @@ from core.helper.code_executor.javascript.javascript_code_provider import Javasc
|
||||
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
|
||||
from core.llm_generator.llm_generator import LLMGenerator
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from extensions.ext_database import db
|
||||
from libs.login import login_required
|
||||
from models import App
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
|
||||
@console_ns.route("/rule-generate")
|
||||
@ -205,9 +208,6 @@ class InstructionGenerateApi(Resource):
|
||||
try:
|
||||
# Generate from nothing for a workflow node
|
||||
if (args["current"] == code_template or args["current"] == "") and args["node_id"] != "":
|
||||
from models import App, db
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
app = db.session.query(App).where(App.id == args["flow_id"]).first()
|
||||
if not app:
|
||||
return {"error": f"app {args['flow_id']} not found"}, 400
|
||||
|
||||
@ -20,6 +20,7 @@ from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file.models import File
|
||||
from core.helper.trace_id_helper import get_external_trace_id
|
||||
from core.workflow.graph_engine.manager import GraphEngineManager
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory, variable_factory
|
||||
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
|
||||
@ -536,7 +537,12 @@ class WorkflowTaskStopApi(Resource):
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
|
||||
# Stop using both mechanisms for backward compatibility
|
||||
# Legacy stop flag mechanism (without user check)
|
||||
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||
|
||||
# New graph engine command channel mechanism
|
||||
GraphEngineManager.send_stop_command(task_id)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from extensions.ext_database import db
|
||||
from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
|
||||
from libs.login import login_required
|
||||
|
||||
@ -17,10 +17,11 @@ from core.variables.segment_group import SegmentGroup
|
||||
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from extensions.ext_database import db
|
||||
from factories.file_factory import build_from_mapping, build_from_mappings
|
||||
from factories.variable_factory import build_segment_with_type
|
||||
from libs.login import current_user, login_required
|
||||
from models import App, AppMode, db
|
||||
from models import App, AppMode
|
||||
from models.account import Account
|
||||
from models.workflow import WorkflowDraftVariable
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
|
||||
|
||||
@ -20,7 +20,6 @@ from controllers.console.wraps import (
|
||||
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.indexing_runner import IndexingRunner
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.plugin.entities.plugin import ModelProviderID
|
||||
from core.provider_manager import ProviderManager
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.extractor.entity.datasource_type import DatasourceType
|
||||
@ -33,6 +32,7 @@ from fields.document_fields import document_status_fields
|
||||
from libs.login import login_required
|
||||
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
|
||||
from models.dataset import DatasetPermissionEnum
|
||||
from models.provider_ids import ModelProviderID
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
|
||||
|
||||
|
||||
|
||||
@ -20,6 +20,7 @@ from core.errors.error import (
|
||||
QuotaExceededError,
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from core.workflow.graph_engine.manager import GraphEngineManager
|
||||
from libs import helper
|
||||
from libs.login import current_user
|
||||
from models.model import AppMode, InstalledApp
|
||||
@ -82,6 +83,11 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource):
|
||||
raise NotWorkflowAppError()
|
||||
assert current_user is not None
|
||||
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
|
||||
# Stop using both mechanisms for backward compatibility
|
||||
# Legacy stop flag mechanism (without user check)
|
||||
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||
|
||||
# New graph engine command channel mechanism
|
||||
GraphEngineManager.send_stop_command(task_id)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@ -21,11 +21,11 @@ from core.mcp.auth.auth_provider import OAuthClientProvider
|
||||
from core.mcp.error import MCPAuthError, MCPError
|
||||
from core.mcp.mcp_client import MCPClient
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.entities.plugin import ToolProviderID
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from core.tools.entities.tool_entities import CredentialType
|
||||
from libs.helper import StrLen, alphanumeric, uuid_value
|
||||
from libs.login import login_required
|
||||
from models.provider_ids import ToolProviderID
|
||||
from services.plugin.oauth_service import OAuthProxyService
|
||||
from services.tools.api_tools_manage_service import ApiToolManageService
|
||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||
|
||||
@ -8,7 +8,7 @@ from controllers.common.errors import UnsupportedFileTypeError
|
||||
from controllers.files import files_ns
|
||||
from core.tools.signature import verify_tool_file_signature
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from models import db as global_db
|
||||
from extensions.ext_database import db as global_db
|
||||
|
||||
|
||||
@files_ns.route("/tools/<uuid:file_id>.<string:extension>")
|
||||
|
||||
@ -26,7 +26,8 @@ from core.errors.error import (
|
||||
)
|
||||
from core.helper.trace_id_helper import get_external_trace_id
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from core.workflow.graph_engine.manager import GraphEngineManager
|
||||
from extensions.ext_database import db
|
||||
from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model
|
||||
from libs import helper
|
||||
@ -262,7 +263,12 @@ class WorkflowTaskStopApi(Resource):
|
||||
if app_mode != AppMode.WORKFLOW:
|
||||
raise NotWorkflowAppError()
|
||||
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
|
||||
# Stop using both mechanisms for backward compatibility
|
||||
# Legacy stop flag mechanism (without user check)
|
||||
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||
|
||||
# New graph engine command channel mechanism
|
||||
GraphEngineManager.send_stop_command(task_id)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@ -13,13 +13,13 @@ from controllers.service_api.wraps import (
|
||||
validate_dataset_token,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.plugin.entities.plugin import ModelProviderID
|
||||
from core.provider_manager import ProviderManager
|
||||
from fields.dataset_fields import dataset_detail_fields
|
||||
from fields.tag_fields import build_dataset_tag_fields
|
||||
from libs.login import current_user
|
||||
from models.account import Account
|
||||
from models.dataset import Dataset, DatasetPermissionEnum
|
||||
from models.provider_ids import ModelProviderID
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
|
||||
from services.tag_service import TagService
|
||||
|
||||
@ -21,6 +21,7 @@ from core.errors.error import (
|
||||
QuotaExceededError,
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from core.workflow.graph_engine.manager import GraphEngineManager
|
||||
from libs import helper
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.app_generate_service import AppGenerateService
|
||||
@ -112,6 +113,11 @@ class WorkflowTaskStopApi(WebApiResource):
|
||||
if app_mode != AppMode.WORKFLOW:
|
||||
raise NotWorkflowAppError()
|
||||
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)
|
||||
# Stop using both mechanisms for backward compatibility
|
||||
# Legacy stop flag mechanism (without user check)
|
||||
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||
|
||||
# New graph engine command channel mechanism
|
||||
GraphEngineManager.send_stop_command(task_id)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@ -4,8 +4,8 @@ from typing import Any
|
||||
from core.app.app_config.entities import ModelConfigEntity
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from core.plugin.entities.plugin import ModelProviderID
|
||||
from core.provider_manager import ProviderManager
|
||||
from models.provider_ids import ModelProviderID
|
||||
|
||||
|
||||
class ModelConfigManager:
|
||||
|
||||
@ -1,11 +1,11 @@
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
|
||||
@ -23,16 +23,17 @@ from core.app.features.annotation_reply.annotation_reply import AnnotationReplyF
|
||||
from core.moderation.base import ModerationError
|
||||
from core.moderation.input_moderation import InputModeration
|
||||
from core.variables.variables import VariableUnion
|
||||
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities import GraphRuntimeState, VariablePool
|
||||
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.variable_loader import VariableLoader
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models import Workflow
|
||||
from models.enums import UserFrom
|
||||
from models.model import App, Conversation, Message, MessageAnnotation
|
||||
from models.workflow import ConversationVariable, WorkflowType
|
||||
from models.workflow import ConversationVariable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -78,23 +79,29 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
if not app_record:
|
||||
raise ValueError("App not found")
|
||||
|
||||
workflow_callbacks: list[WorkflowCallback] = []
|
||||
if dify_config.DEBUG:
|
||||
workflow_callbacks.append(WorkflowLoggingCallback())
|
||||
|
||||
if self.application_generate_entity.single_iteration_run:
|
||||
# if only single iteration run is requested
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool.empty(),
|
||||
start_at=time.time(),
|
||||
)
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
|
||||
workflow=self._workflow,
|
||||
node_id=self.application_generate_entity.single_iteration_run.node_id,
|
||||
user_inputs=dict(self.application_generate_entity.single_iteration_run.inputs),
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
elif self.application_generate_entity.single_loop_run:
|
||||
# if only single loop run is requested
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool.empty(),
|
||||
start_at=time.time(),
|
||||
)
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
|
||||
workflow=self._workflow,
|
||||
node_id=self.application_generate_entity.single_loop_run.node_id,
|
||||
user_inputs=dict(self.application_generate_entity.single_loop_run.inputs),
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
else:
|
||||
inputs = self.application_generate_entity.inputs
|
||||
@ -146,16 +153,27 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
)
|
||||
|
||||
# init graph
|
||||
graph = self._init_graph(graph_config=self._workflow.graph_dict)
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.time())
|
||||
graph = self._init_graph(
|
||||
graph_config=self._workflow.graph_dict,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
workflow_id=self._workflow.id,
|
||||
tenant_id=self._workflow.tenant_id,
|
||||
user_id=self.application_generate_entity.user_id,
|
||||
)
|
||||
|
||||
db.session.close()
|
||||
|
||||
# RUN WORKFLOW
|
||||
# Create Redis command channel for this workflow execution
|
||||
task_id = self.application_generate_entity.task_id
|
||||
channel_key = f"workflow:{task_id}:commands"
|
||||
command_channel = RedisChannel(redis_client, channel_key)
|
||||
|
||||
workflow_entry = WorkflowEntry(
|
||||
tenant_id=self._workflow.tenant_id,
|
||||
app_id=self._workflow.app_id,
|
||||
workflow_id=self._workflow.id,
|
||||
workflow_type=WorkflowType.value_of(self._workflow.type),
|
||||
graph=graph,
|
||||
graph_config=self._workflow.graph_dict,
|
||||
user_id=self.application_generate_entity.user_id,
|
||||
@ -167,11 +185,11 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
invoke_from=self.application_generate_entity.invoke_from,
|
||||
call_depth=self.application_generate_entity.call_depth,
|
||||
variable_pool=variable_pool,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
command_channel=command_channel,
|
||||
)
|
||||
|
||||
generator = workflow_entry.run(
|
||||
callbacks=workflow_callbacks,
|
||||
)
|
||||
generator = workflow_entry.run()
|
||||
|
||||
for event in generator:
|
||||
self._handle_event(workflow_entry, event)
|
||||
|
||||
@ -31,14 +31,9 @@ from core.app.entities.queue_entities import (
|
||||
QueueMessageReplaceEvent,
|
||||
QueueNodeExceptionEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeInIterationFailedEvent,
|
||||
QueueNodeInLoopFailedEvent,
|
||||
QueueNodeRetryEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueParallelBranchRunFailedEvent,
|
||||
QueueParallelBranchRunStartedEvent,
|
||||
QueueParallelBranchRunSucceededEvent,
|
||||
QueuePingEvent,
|
||||
QueueRetrieverResourcesEvent,
|
||||
QueueStopEvent,
|
||||
@ -65,8 +60,8 @@ from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
|
||||
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus, WorkflowType
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.entities import GraphRuntimeState
|
||||
from core.workflow.enums import WorkflowExecutionStatus, WorkflowType
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
@ -387,9 +382,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
|
||||
def _handle_node_failed_events(
|
||||
self,
|
||||
event: Union[
|
||||
QueueNodeFailedEvent, QueueNodeInIterationFailedEvent, QueueNodeInLoopFailedEvent, QueueNodeExceptionEvent
|
||||
],
|
||||
event: Union[QueueNodeFailedEvent, QueueNodeExceptionEvent],
|
||||
**kwargs,
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle various node failure events."""
|
||||
@ -434,32 +427,6 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
|
||||
)
|
||||
|
||||
def _handle_parallel_branch_started_event(
|
||||
self, event: QueueParallelBranchRunStartedEvent, **kwargs
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle parallel branch started events."""
|
||||
self._ensure_workflow_initialized()
|
||||
|
||||
parallel_start_resp = self._workflow_response_converter.workflow_parallel_branch_start_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_execution_id=self._workflow_run_id,
|
||||
event=event,
|
||||
)
|
||||
yield parallel_start_resp
|
||||
|
||||
def _handle_parallel_branch_finished_events(
|
||||
self, event: Union[QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent], **kwargs
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle parallel branch finished events."""
|
||||
self._ensure_workflow_initialized()
|
||||
|
||||
parallel_finish_resp = self._workflow_response_converter.workflow_parallel_branch_finished_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_execution_id=self._workflow_run_id,
|
||||
event=event,
|
||||
)
|
||||
yield parallel_finish_resp
|
||||
|
||||
def _handle_iteration_start_event(
|
||||
self, event: QueueIterationStartEvent, **kwargs
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
@ -751,8 +718,6 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
QueueNodeRetryEvent: self._handle_node_retry_event,
|
||||
QueueNodeStartedEvent: self._handle_node_started_event,
|
||||
QueueNodeSucceededEvent: self._handle_node_succeeded_event,
|
||||
# Parallel branch events
|
||||
QueueParallelBranchRunStartedEvent: self._handle_parallel_branch_started_event,
|
||||
# Iteration events
|
||||
QueueIterationStartEvent: self._handle_iteration_start_event,
|
||||
QueueIterationNextEvent: self._handle_iteration_next_event,
|
||||
@ -800,8 +765,6 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
event,
|
||||
(
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeInIterationFailedEvent,
|
||||
QueueNodeInLoopFailedEvent,
|
||||
QueueNodeExceptionEvent,
|
||||
),
|
||||
):
|
||||
@ -814,17 +777,6 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
)
|
||||
return
|
||||
|
||||
# Handle parallel branch finished events with isinstance check
|
||||
if isinstance(event, (QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent)):
|
||||
yield from self._handle_parallel_branch_finished_events(
|
||||
event,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
tts_publisher=tts_publisher,
|
||||
trace_manager=trace_manager,
|
||||
queue_message=queue_message,
|
||||
)
|
||||
return
|
||||
|
||||
# For unhandled events, we continue (original behavior)
|
||||
return
|
||||
|
||||
@ -848,11 +800,6 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
graph_runtime_state = event.graph_runtime_state
|
||||
yield from self._handle_workflow_started_event(event)
|
||||
|
||||
case QueueTextChunkEvent():
|
||||
yield from self._handle_text_chunk_event(
|
||||
event, tts_publisher=tts_publisher, queue_message=queue_message
|
||||
)
|
||||
|
||||
case QueueErrorEvent():
|
||||
yield from self._handle_error_event(event)
|
||||
break
|
||||
|
||||
@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
|
||||
from core.app.app_config.entities import VariableEntityType
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file import File, FileUploadConfig
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.repositories.draft_variable_repository import (
|
||||
DraftVariableSaver,
|
||||
DraftVariableSaverFactory,
|
||||
|
||||
@ -127,6 +127,21 @@ class AppQueueManager:
|
||||
stopped_cache_key = cls._generate_stopped_cache_key(task_id)
|
||||
redis_client.setex(stopped_cache_key, 600, 1)
|
||||
|
||||
@classmethod
|
||||
def set_stop_flag_no_user_check(cls, task_id: str) -> None:
|
||||
"""
|
||||
Set task stop flag without user permission check.
|
||||
This method allows stopping workflows without user context.
|
||||
|
||||
:param task_id: The task ID to stop
|
||||
:return:
|
||||
"""
|
||||
if not task_id:
|
||||
return
|
||||
|
||||
stopped_cache_key = cls._generate_stopped_cache_key(task_id)
|
||||
redis_client.setex(stopped_cache_key, 600, 1)
|
||||
|
||||
def _is_stopped(self) -> bool:
|
||||
"""
|
||||
Check if task is stopped
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import time
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, Union, cast
|
||||
from typing import Any, Union
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@ -16,14 +16,9 @@ from core.app.entities.queue_entities import (
|
||||
QueueLoopStartEvent,
|
||||
QueueNodeExceptionEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeInIterationFailedEvent,
|
||||
QueueNodeInLoopFailedEvent,
|
||||
QueueNodeRetryEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueParallelBranchRunFailedEvent,
|
||||
QueueParallelBranchRunStartedEvent,
|
||||
QueueParallelBranchRunSucceededEvent,
|
||||
)
|
||||
from core.app.entities.task_entities import (
|
||||
AgentLogStreamResponse,
|
||||
@ -36,18 +31,16 @@ from core.app.entities.task_entities import (
|
||||
NodeFinishStreamResponse,
|
||||
NodeRetryStreamResponse,
|
||||
NodeStartStreamResponse,
|
||||
ParallelBranchFinishedStreamResponse,
|
||||
ParallelBranchStartStreamResponse,
|
||||
WorkflowFinishStreamResponse,
|
||||
WorkflowStartStreamResponse,
|
||||
)
|
||||
from core.file import FILE_MODEL_IDENTITY, File
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecution
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus
|
||||
from core.workflow.entities import WorkflowExecution, WorkflowNodeExecution
|
||||
from core.workflow.enums import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.tool.entities import ToolNodeData
|
||||
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import (
|
||||
@ -171,11 +164,10 @@ class WorkflowResponseConverter:
|
||||
|
||||
# extras logic
|
||||
if event.node_type == NodeType.TOOL:
|
||||
node_data = cast(ToolNodeData, event.node_data)
|
||||
response.data.extras["icon"] = ToolManager.get_tool_icon(
|
||||
tenant_id=self._application_generate_entity.app_config.tenant_id,
|
||||
provider_type=node_data.provider_type,
|
||||
provider_id=node_data.provider_id,
|
||||
provider_type=ToolProviderType(event.provider_type),
|
||||
provider_id=event.provider_id,
|
||||
)
|
||||
|
||||
return response
|
||||
@ -183,11 +175,7 @@ class WorkflowResponseConverter:
|
||||
def workflow_node_finish_to_stream_response(
|
||||
self,
|
||||
*,
|
||||
event: QueueNodeSucceededEvent
|
||||
| QueueNodeFailedEvent
|
||||
| QueueNodeInIterationFailedEvent
|
||||
| QueueNodeInLoopFailedEvent
|
||||
| QueueNodeExceptionEvent,
|
||||
event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeExceptionEvent,
|
||||
task_id: str,
|
||||
workflow_node_execution: WorkflowNodeExecution,
|
||||
) -> NodeFinishStreamResponse | None:
|
||||
@ -221,9 +209,6 @@ class WorkflowResponseConverter:
|
||||
finished_at=int(workflow_node_execution.finished_at.timestamp()),
|
||||
files=self.fetch_files_from_node_outputs(workflow_node_execution.outputs or {}),
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
iteration_id=event.in_iteration_id,
|
||||
loop_id=event.in_loop_id,
|
||||
),
|
||||
@ -275,50 +260,6 @@ class WorkflowResponseConverter:
|
||||
),
|
||||
)
|
||||
|
||||
def workflow_parallel_branch_start_to_stream_response(
|
||||
self,
|
||||
*,
|
||||
task_id: str,
|
||||
workflow_execution_id: str,
|
||||
event: QueueParallelBranchRunStartedEvent,
|
||||
) -> ParallelBranchStartStreamResponse:
|
||||
return ParallelBranchStartStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_execution_id,
|
||||
data=ParallelBranchStartStreamResponse.Data(
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_branch_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
iteration_id=event.in_iteration_id,
|
||||
loop_id=event.in_loop_id,
|
||||
created_at=int(time.time()),
|
||||
),
|
||||
)
|
||||
|
||||
def workflow_parallel_branch_finished_to_stream_response(
|
||||
self,
|
||||
*,
|
||||
task_id: str,
|
||||
workflow_execution_id: str,
|
||||
event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent,
|
||||
) -> ParallelBranchFinishedStreamResponse:
|
||||
return ParallelBranchFinishedStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_execution_id,
|
||||
data=ParallelBranchFinishedStreamResponse.Data(
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_branch_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
iteration_id=event.in_iteration_id,
|
||||
loop_id=event.in_loop_id,
|
||||
status="succeeded" if isinstance(event, QueueParallelBranchRunSucceededEvent) else "failed",
|
||||
error=event.error if isinstance(event, QueueParallelBranchRunFailedEvent) else None,
|
||||
created_at=int(time.time()),
|
||||
),
|
||||
)
|
||||
|
||||
def workflow_iteration_start_to_stream_response(
|
||||
self,
|
||||
*,
|
||||
@ -333,13 +274,11 @@ class WorkflowResponseConverter:
|
||||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
title=event.node_data.title,
|
||||
title=event.node_title,
|
||||
created_at=int(time.time()),
|
||||
extras={},
|
||||
inputs=event.inputs or {},
|
||||
metadata=event.metadata or {},
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
),
|
||||
)
|
||||
|
||||
@ -357,15 +296,10 @@ class WorkflowResponseConverter:
|
||||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
title=event.node_data.title,
|
||||
title=event.node_title,
|
||||
index=event.index,
|
||||
pre_iteration_output=event.output,
|
||||
created_at=int(time.time()),
|
||||
extras={},
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parallel_mode_run_id=event.parallel_mode_run_id,
|
||||
duration=event.duration,
|
||||
),
|
||||
)
|
||||
|
||||
@ -384,8 +318,8 @@ class WorkflowResponseConverter:
|
||||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
title=event.node_data.title,
|
||||
outputs=json_converter.to_json_encodable(event.outputs),
|
||||
title=event.node_title,
|
||||
outputs=json_converter.to_json_encodable(event.outputs) or {},
|
||||
created_at=int(time.time()),
|
||||
extras={},
|
||||
inputs=event.inputs or {},
|
||||
@ -394,12 +328,10 @@ class WorkflowResponseConverter:
|
||||
else WorkflowNodeExecutionStatus.FAILED,
|
||||
error=None,
|
||||
elapsed_time=(naive_utc_now() - event.start_at).total_seconds(),
|
||||
total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0,
|
||||
total_tokens=(lambda x: x if isinstance(x, int) else 0)(event.metadata.get("total_tokens", 0)),
|
||||
execution_metadata=event.metadata,
|
||||
finished_at=int(time.time()),
|
||||
steps=event.steps,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
),
|
||||
)
|
||||
|
||||
@ -413,7 +345,7 @@ class WorkflowResponseConverter:
|
||||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
title=event.node_data.title,
|
||||
title=event.node_title,
|
||||
created_at=int(time.time()),
|
||||
extras={},
|
||||
inputs=event.inputs or {},
|
||||
@ -437,7 +369,7 @@ class WorkflowResponseConverter:
|
||||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
title=event.node_data.title,
|
||||
title=event.node_title,
|
||||
index=event.index,
|
||||
pre_loop_output=event.output,
|
||||
created_at=int(time.time()),
|
||||
@ -445,7 +377,6 @@ class WorkflowResponseConverter:
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parallel_mode_run_id=event.parallel_mode_run_id,
|
||||
duration=event.duration,
|
||||
),
|
||||
)
|
||||
|
||||
@ -463,8 +394,8 @@ class WorkflowResponseConverter:
|
||||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
title=event.node_data.title,
|
||||
outputs=WorkflowRuntimeTypeConverter().to_json_encodable(event.outputs),
|
||||
title=event.node_title,
|
||||
outputs=WorkflowRuntimeTypeConverter().to_json_encodable(event.outputs) or {},
|
||||
created_at=int(time.time()),
|
||||
extras={},
|
||||
inputs=event.inputs or {},
|
||||
@ -473,7 +404,7 @@ class WorkflowResponseConverter:
|
||||
else WorkflowNodeExecutionStatus.FAILED,
|
||||
error=None,
|
||||
elapsed_time=(naive_utc_now() - event.start_at).total_seconds(),
|
||||
total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0,
|
||||
total_tokens=(lambda x: x if isinstance(x, int) else 0)(event.metadata.get("total_tokens", 0)),
|
||||
execution_metadata=event.metadata,
|
||||
finished_at=int(time.time()),
|
||||
steps=event.steps,
|
||||
|
||||
@ -53,7 +53,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[True],
|
||||
call_depth: int,
|
||||
workflow_thread_pool_id: str | None,
|
||||
) -> Generator[Mapping | str, None, None]: ...
|
||||
|
||||
@overload
|
||||
@ -67,7 +66,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[False],
|
||||
call_depth: int,
|
||||
workflow_thread_pool_id: str | None,
|
||||
) -> Mapping[str, Any]: ...
|
||||
|
||||
@overload
|
||||
@ -81,7 +79,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool,
|
||||
call_depth: int,
|
||||
workflow_thread_pool_id: str | None,
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...
|
||||
|
||||
def generate(
|
||||
@ -94,7 +91,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
call_depth: int = 0,
|
||||
workflow_thread_pool_id: str | None = None,
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]:
|
||||
files: Sequence[Mapping[str, Any]] = args.get("files") or []
|
||||
|
||||
@ -186,7 +182,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
streaming=streaming,
|
||||
workflow_thread_pool_id=workflow_thread_pool_id,
|
||||
)
|
||||
|
||||
def _generate(
|
||||
@ -200,7 +195,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
streaming: bool = True,
|
||||
workflow_thread_pool_id: str | None = None,
|
||||
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
|
||||
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
|
||||
"""
|
||||
@ -214,7 +208,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
:param workflow_execution_repository: repository for workflow execution
|
||||
:param workflow_node_execution_repository: repository for workflow node execution
|
||||
:param streaming: is stream
|
||||
:param workflow_thread_pool_id: workflow thread pool id
|
||||
"""
|
||||
# init queue manager
|
||||
queue_manager = WorkflowAppQueueManager(
|
||||
@ -237,7 +230,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
"application_generate_entity": application_generate_entity,
|
||||
"queue_manager": queue_manager,
|
||||
"context": context,
|
||||
"workflow_thread_pool_id": workflow_thread_pool_id,
|
||||
"variable_loader": variable_loader,
|
||||
},
|
||||
)
|
||||
@ -434,17 +426,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
queue_manager: AppQueueManager,
|
||||
context: contextvars.Context,
|
||||
variable_loader: VariableLoader,
|
||||
workflow_thread_pool_id: str | None = None,
|
||||
):
|
||||
"""
|
||||
Generate worker in a new thread.
|
||||
:param flask_app: Flask app
|
||||
:param application_generate_entity: application generate entity
|
||||
:param queue_manager: queue manager
|
||||
:param workflow_thread_pool_id: workflow thread pool id
|
||||
:return:
|
||||
"""
|
||||
|
||||
) -> None:
|
||||
with preserve_flask_contexts(flask_app, context_vars=context):
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow = session.scalar(
|
||||
@ -474,7 +456,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
runner = WorkflowAppRunner(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
workflow_thread_pool_id=workflow_thread_pool_id,
|
||||
variable_loader=variable_loader,
|
||||
workflow=workflow,
|
||||
system_user_id=system_user_id,
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import logging
|
||||
import time
|
||||
from typing import cast
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
|
||||
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
|
||||
@ -9,13 +9,14 @@ from core.app.entities.app_invoke_entities import (
|
||||
InvokeFrom,
|
||||
WorkflowAppGenerateEntity,
|
||||
)
|
||||
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities import GraphRuntimeState, VariablePool
|
||||
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.variable_loader import VariableLoader
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import Workflow, WorkflowType
|
||||
from models.workflow import Workflow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -31,7 +32,6 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
application_generate_entity: WorkflowAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
variable_loader: VariableLoader,
|
||||
workflow_thread_pool_id: str | None = None,
|
||||
workflow: Workflow,
|
||||
system_user_id: str,
|
||||
):
|
||||
@ -41,7 +41,6 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
)
|
||||
self.application_generate_entity = application_generate_entity
|
||||
self.workflow_thread_pool_id = workflow_thread_pool_id
|
||||
self._workflow = workflow
|
||||
self._sys_user_id = system_user_id
|
||||
|
||||
@ -52,24 +51,30 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
app_config = self.application_generate_entity.app_config
|
||||
app_config = cast(WorkflowAppConfig, app_config)
|
||||
|
||||
workflow_callbacks: list[WorkflowCallback] = []
|
||||
if dify_config.DEBUG:
|
||||
workflow_callbacks.append(WorkflowLoggingCallback())
|
||||
|
||||
# if only single iteration run is requested
|
||||
if self.application_generate_entity.single_iteration_run:
|
||||
# if only single iteration run is requested
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool.empty(),
|
||||
start_at=time.time(),
|
||||
)
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
|
||||
workflow=self._workflow,
|
||||
node_id=self.application_generate_entity.single_iteration_run.node_id,
|
||||
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
elif self.application_generate_entity.single_loop_run:
|
||||
# if only single loop run is requested
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool.empty(),
|
||||
start_at=time.time(),
|
||||
)
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
|
||||
workflow=self._workflow,
|
||||
node_id=self.application_generate_entity.single_loop_run.node_id,
|
||||
user_inputs=self.application_generate_entity.single_loop_run.inputs,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
else:
|
||||
inputs = self.application_generate_entity.inputs
|
||||
@ -92,15 +97,27 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
conversation_variables=[],
|
||||
)
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
# init graph
|
||||
graph = self._init_graph(graph_config=self._workflow.graph_dict)
|
||||
graph = self._init_graph(
|
||||
graph_config=self._workflow.graph_dict,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
workflow_id=self._workflow.id,
|
||||
tenant_id=self._workflow.tenant_id,
|
||||
user_id=self.application_generate_entity.user_id,
|
||||
)
|
||||
|
||||
# RUN WORKFLOW
|
||||
# Create Redis command channel for this workflow execution
|
||||
task_id = self.application_generate_entity.task_id
|
||||
channel_key = f"workflow:{task_id}:commands"
|
||||
command_channel = RedisChannel(redis_client, channel_key)
|
||||
|
||||
workflow_entry = WorkflowEntry(
|
||||
tenant_id=self._workflow.tenant_id,
|
||||
app_id=self._workflow.app_id,
|
||||
workflow_id=self._workflow.id,
|
||||
workflow_type=WorkflowType.value_of(self._workflow.type),
|
||||
graph=graph,
|
||||
graph_config=self._workflow.graph_dict,
|
||||
user_id=self.application_generate_entity.user_id,
|
||||
@ -112,10 +129,11 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
invoke_from=self.application_generate_entity.invoke_from,
|
||||
call_depth=self.application_generate_entity.call_depth,
|
||||
variable_pool=variable_pool,
|
||||
thread_pool_id=self.workflow_thread_pool_id,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
command_channel=command_channel,
|
||||
)
|
||||
|
||||
generator = workflow_entry.run(callbacks=workflow_callbacks)
|
||||
generator = workflow_entry.run()
|
||||
|
||||
for event in generator:
|
||||
self._handle_event(workflow_entry, event)
|
||||
|
||||
@ -2,7 +2,7 @@ import logging
|
||||
import time
|
||||
from collections.abc import Callable, Generator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Union
|
||||
from typing import Union
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@ -14,6 +14,7 @@ from core.app.entities.app_invoke_entities import (
|
||||
WorkflowAppGenerateEntity,
|
||||
)
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
MessageQueueMessage,
|
||||
QueueAgentLogEvent,
|
||||
QueueErrorEvent,
|
||||
@ -25,14 +26,9 @@ from core.app.entities.queue_entities import (
|
||||
QueueLoopStartEvent,
|
||||
QueueNodeExceptionEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeInIterationFailedEvent,
|
||||
QueueNodeInLoopFailedEvent,
|
||||
QueueNodeRetryEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueParallelBranchRunFailedEvent,
|
||||
QueueParallelBranchRunStartedEvent,
|
||||
QueueParallelBranchRunSucceededEvent,
|
||||
QueuePingEvent,
|
||||
QueueStopEvent,
|
||||
QueueTextChunkEvent,
|
||||
@ -57,8 +53,8 @@ from core.app.entities.task_entities import (
|
||||
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
||||
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.entities import GraphRuntimeState, WorkflowExecution
|
||||
from core.workflow.enums import WorkflowExecutionStatus, WorkflowType
|
||||
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
@ -349,9 +345,7 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
|
||||
def _handle_node_failed_events(
|
||||
self,
|
||||
event: Union[
|
||||
QueueNodeFailedEvent, QueueNodeInIterationFailedEvent, QueueNodeInLoopFailedEvent, QueueNodeExceptionEvent
|
||||
],
|
||||
event: Union[QueueNodeFailedEvent, QueueNodeExceptionEvent],
|
||||
**kwargs,
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle various node failure events."""
|
||||
@ -370,32 +364,6 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
if node_failed_response:
|
||||
yield node_failed_response
|
||||
|
||||
def _handle_parallel_branch_started_event(
|
||||
self, event: QueueParallelBranchRunStartedEvent, **kwargs
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle parallel branch started events."""
|
||||
self._ensure_workflow_initialized()
|
||||
|
||||
parallel_start_resp = self._workflow_response_converter.workflow_parallel_branch_start_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_execution_id=self._workflow_run_id,
|
||||
event=event,
|
||||
)
|
||||
yield parallel_start_resp
|
||||
|
||||
def _handle_parallel_branch_finished_events(
|
||||
self, event: Union[QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent], **kwargs
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle parallel branch finished events."""
|
||||
self._ensure_workflow_initialized()
|
||||
|
||||
parallel_finish_resp = self._workflow_response_converter.workflow_parallel_branch_finished_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_execution_id=self._workflow_run_id,
|
||||
event=event,
|
||||
)
|
||||
yield parallel_finish_resp
|
||||
|
||||
def _handle_iteration_start_event(
|
||||
self, event: QueueIterationStartEvent, **kwargs
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
@ -617,8 +585,6 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
QueueNodeRetryEvent: self._handle_node_retry_event,
|
||||
QueueNodeStartedEvent: self._handle_node_started_event,
|
||||
QueueNodeSucceededEvent: self._handle_node_succeeded_event,
|
||||
# Parallel branch events
|
||||
QueueParallelBranchRunStartedEvent: self._handle_parallel_branch_started_event,
|
||||
# Iteration events
|
||||
QueueIterationStartEvent: self._handle_iteration_start_event,
|
||||
QueueIterationNextEvent: self._handle_iteration_next_event,
|
||||
@ -633,7 +599,7 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
|
||||
def _dispatch_event(
|
||||
self,
|
||||
event: Any,
|
||||
event: AppQueueEvent,
|
||||
*,
|
||||
graph_runtime_state: GraphRuntimeState | None = None,
|
||||
tts_publisher: AppGeneratorTTSPublisher | None = None,
|
||||
@ -660,8 +626,6 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
event,
|
||||
(
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeInIterationFailedEvent,
|
||||
QueueNodeInLoopFailedEvent,
|
||||
QueueNodeExceptionEvent,
|
||||
),
|
||||
):
|
||||
@ -674,17 +638,6 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
)
|
||||
return
|
||||
|
||||
# Handle parallel branch finished events with isinstance check
|
||||
if isinstance(event, (QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent)):
|
||||
yield from self._handle_parallel_branch_finished_events(
|
||||
event,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
tts_publisher=tts_publisher,
|
||||
trace_manager=trace_manager,
|
||||
queue_message=queue_message,
|
||||
)
|
||||
return
|
||||
|
||||
# Handle workflow failed and stop events with isinstance check
|
||||
if isinstance(event, (QueueWorkflowFailedEvent, QueueStopEvent)):
|
||||
yield from self._handle_workflow_failed_and_stop_events(
|
||||
|
||||
@ -2,6 +2,7 @@ from collections.abc import Mapping
|
||||
from typing import Any, cast
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
QueueAgentLogEvent,
|
||||
@ -13,14 +14,9 @@ from core.app.entities.queue_entities import (
|
||||
QueueLoopStartEvent,
|
||||
QueueNodeExceptionEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeInIterationFailedEvent,
|
||||
QueueNodeInLoopFailedEvent,
|
||||
QueueNodeRetryEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueParallelBranchRunFailedEvent,
|
||||
QueueParallelBranchRunStartedEvent,
|
||||
QueueParallelBranchRunSucceededEvent,
|
||||
QueueRetrieverResourcesEvent,
|
||||
QueueTextChunkEvent,
|
||||
QueueWorkflowFailedEvent,
|
||||
@ -28,42 +24,39 @@ from core.app.entities.queue_entities import (
|
||||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
AgentLogEvent,
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import (
|
||||
GraphEngineEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunPartialSucceededEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
IterationRunFailedEvent,
|
||||
IterationRunNextEvent,
|
||||
IterationRunStartedEvent,
|
||||
IterationRunSucceededEvent,
|
||||
LoopRunFailedEvent,
|
||||
LoopRunNextEvent,
|
||||
LoopRunStartedEvent,
|
||||
LoopRunSucceededEvent,
|
||||
NodeInIterationFailedEvent,
|
||||
NodeInLoopFailedEvent,
|
||||
NodeRunAgentLogEvent,
|
||||
NodeRunExceptionEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunIterationFailedEvent,
|
||||
NodeRunIterationNextEvent,
|
||||
NodeRunIterationStartedEvent,
|
||||
NodeRunIterationSucceededEvent,
|
||||
NodeRunLoopFailedEvent,
|
||||
NodeRunLoopNextEvent,
|
||||
NodeRunLoopStartedEvent,
|
||||
NodeRunLoopSucceededEvent,
|
||||
NodeRunRetrieverResourceEvent,
|
||||
NodeRunRetryEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
ParallelBranchRunFailedEvent,
|
||||
ParallelBranchRunStartedEvent,
|
||||
ParallelBranchRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_events.graph import GraphRunAbortedEvent
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import Workflow
|
||||
|
||||
|
||||
@ -79,7 +72,14 @@ class WorkflowBasedAppRunner:
|
||||
self._variable_loader = variable_loader
|
||||
self._app_id = app_id
|
||||
|
||||
def _init_graph(self, graph_config: Mapping[str, Any]) -> Graph:
|
||||
def _init_graph(
|
||||
self,
|
||||
graph_config: Mapping[str, Any],
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
workflow_id: str = "",
|
||||
tenant_id: str = "",
|
||||
user_id: str = "",
|
||||
) -> Graph:
|
||||
"""
|
||||
Init graph
|
||||
"""
|
||||
@ -91,8 +91,28 @@ class WorkflowBasedAppRunner:
|
||||
|
||||
if not isinstance(graph_config.get("edges"), list):
|
||||
raise ValueError("edges in workflow graph must be a list")
|
||||
|
||||
# Create required parameters for Graph.init
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id=tenant_id or "",
|
||||
app_id=self._app_id,
|
||||
workflow_id=workflow_id,
|
||||
graph_config=graph_config,
|
||||
user_id=user_id,
|
||||
user_from=UserFrom.ACCOUNT.value,
|
||||
invoke_from=InvokeFrom.SERVICE_API.value,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
# Use the provided graph_runtime_state for consistent state management
|
||||
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
# init graph
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
|
||||
|
||||
if not graph:
|
||||
raise ValueError("graph not found in workflow")
|
||||
@ -104,6 +124,7 @@ class WorkflowBasedAppRunner:
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user_inputs: dict,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
) -> tuple[Graph, VariablePool]:
|
||||
"""
|
||||
Get variable pool of single iteration
|
||||
@ -145,8 +166,25 @@ class WorkflowBasedAppRunner:
|
||||
|
||||
graph_config["edges"] = edge_configs
|
||||
|
||||
# Create required parameters for Graph.init
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=self._app_id,
|
||||
workflow_id=workflow.id,
|
||||
graph_config=graph_config,
|
||||
user_id="",
|
||||
user_from=UserFrom.ACCOUNT.value,
|
||||
invoke_from=InvokeFrom.SERVICE_API.value,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
# init graph
|
||||
graph = Graph.init(graph_config=graph_config, root_node_id=node_id)
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=node_id)
|
||||
|
||||
if not graph:
|
||||
raise ValueError("graph not found in workflow")
|
||||
@ -201,6 +239,7 @@ class WorkflowBasedAppRunner:
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user_inputs: dict,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
) -> tuple[Graph, VariablePool]:
|
||||
"""
|
||||
Get variable pool of single loop
|
||||
@ -242,8 +281,25 @@ class WorkflowBasedAppRunner:
|
||||
|
||||
graph_config["edges"] = edge_configs
|
||||
|
||||
# Create required parameters for Graph.init
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=self._app_id,
|
||||
workflow_id=workflow.id,
|
||||
graph_config=graph_config,
|
||||
user_id="",
|
||||
user_from=UserFrom.ACCOUNT.value,
|
||||
invoke_from=InvokeFrom.SERVICE_API.value,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
# init graph
|
||||
graph = Graph.init(graph_config=graph_config, root_node_id=node_id)
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=node_id)
|
||||
|
||||
if not graph:
|
||||
raise ValueError("graph not found in workflow")
|
||||
@ -310,39 +366,32 @@ class WorkflowBasedAppRunner:
|
||||
)
|
||||
elif isinstance(event, GraphRunFailedEvent):
|
||||
self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count))
|
||||
elif isinstance(event, GraphRunAbortedEvent):
|
||||
self._publish_event(QueueWorkflowFailedEvent(error=event.reason or "Unknown error", exceptions_count=0))
|
||||
elif isinstance(event, NodeRunRetryEvent):
|
||||
node_run_result = event.route_node_state.node_run_result
|
||||
inputs: Mapping[str, Any] | None = {}
|
||||
process_data: Mapping[str, Any] | None = {}
|
||||
outputs: Mapping[str, Any] | None = {}
|
||||
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = {}
|
||||
if node_run_result:
|
||||
inputs = node_run_result.inputs
|
||||
process_data = node_run_result.process_data
|
||||
outputs = node_run_result.outputs
|
||||
execution_metadata = node_run_result.metadata
|
||||
node_run_result = event.node_run_result
|
||||
inputs = node_run_result.inputs
|
||||
process_data = node_run_result.process_data
|
||||
outputs = node_run_result.outputs
|
||||
execution_metadata = node_run_result.metadata
|
||||
self._publish_event(
|
||||
QueueNodeRetryEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_title=event.node_title,
|
||||
node_type=event.node_type,
|
||||
node_data=event.node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
start_at=event.start_at,
|
||||
node_run_index=event.route_node_state.index,
|
||||
predecessor_node_id=event.predecessor_node_id,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
parallel_mode_run_id=event.parallel_mode_run_id,
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
outputs=outputs,
|
||||
error=event.error,
|
||||
execution_metadata=execution_metadata,
|
||||
retry_index=event.retry_index,
|
||||
provider_type=event.provider_type,
|
||||
provider_id=event.provider_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunStartedEvent):
|
||||
@ -350,44 +399,29 @@ class WorkflowBasedAppRunner:
|
||||
QueueNodeStartedEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_title=event.node_title,
|
||||
node_type=event.node_type,
|
||||
node_data=event.node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
start_at=event.route_node_state.start_at,
|
||||
node_run_index=event.route_node_state.index,
|
||||
start_at=event.start_at,
|
||||
predecessor_node_id=event.predecessor_node_id,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
parallel_mode_run_id=event.parallel_mode_run_id,
|
||||
agent_strategy=event.agent_strategy,
|
||||
provider_type=event.provider_type,
|
||||
provider_id=event.provider_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunSucceededEvent):
|
||||
node_run_result = event.route_node_state.node_run_result
|
||||
if node_run_result:
|
||||
inputs = node_run_result.inputs
|
||||
process_data = node_run_result.process_data
|
||||
outputs = node_run_result.outputs
|
||||
execution_metadata = node_run_result.metadata
|
||||
else:
|
||||
inputs = {}
|
||||
process_data = {}
|
||||
outputs = {}
|
||||
execution_metadata = {}
|
||||
node_run_result = event.node_run_result
|
||||
inputs = node_run_result.inputs
|
||||
process_data = node_run_result.process_data
|
||||
outputs = node_run_result.outputs
|
||||
execution_metadata = node_run_result.metadata
|
||||
self._publish_event(
|
||||
QueueNodeSucceededEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_data=event.node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
start_at=event.route_node_state.start_at,
|
||||
start_at=event.start_at,
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
outputs=outputs,
|
||||
@ -396,34 +430,18 @@ class WorkflowBasedAppRunner:
|
||||
in_loop_id=event.in_loop_id,
|
||||
)
|
||||
)
|
||||
|
||||
elif isinstance(event, NodeRunFailedEvent):
|
||||
self._publish_event(
|
||||
QueueNodeFailedEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_data=event.node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
start_at=event.route_node_state.start_at,
|
||||
inputs=event.route_node_state.node_run_result.inputs
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
process_data=event.route_node_state.node_run_result.process_data
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
outputs=event.route_node_state.node_run_result.outputs or {}
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
error=event.route_node_state.node_run_result.error
|
||||
if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error
|
||||
else "Unknown error",
|
||||
execution_metadata=event.route_node_state.node_run_result.metadata
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
start_at=event.start_at,
|
||||
inputs=event.node_run_result.inputs,
|
||||
process_data=event.node_run_result.process_data,
|
||||
outputs=event.node_run_result.outputs,
|
||||
error=event.node_run_result.error or "Unknown error",
|
||||
execution_metadata=event.node_run_result.metadata,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
)
|
||||
@ -434,93 +452,21 @@ class WorkflowBasedAppRunner:
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_data=event.node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
start_at=event.route_node_state.start_at,
|
||||
inputs=event.route_node_state.node_run_result.inputs
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
process_data=event.route_node_state.node_run_result.process_data
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
outputs=event.route_node_state.node_run_result.outputs
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
error=event.route_node_state.node_run_result.error
|
||||
if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error
|
||||
else "Unknown error",
|
||||
execution_metadata=event.route_node_state.node_run_result.metadata
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
start_at=event.start_at,
|
||||
inputs=event.node_run_result.inputs,
|
||||
process_data=event.node_run_result.process_data,
|
||||
outputs=event.node_run_result.outputs,
|
||||
error=event.node_run_result.error or "Unknown error",
|
||||
execution_metadata=event.node_run_result.metadata,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
)
|
||||
)
|
||||
|
||||
elif isinstance(event, NodeInIterationFailedEvent):
|
||||
self._publish_event(
|
||||
QueueNodeInIterationFailedEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_data=event.node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
start_at=event.route_node_state.start_at,
|
||||
inputs=event.route_node_state.node_run_result.inputs
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
process_data=event.route_node_state.node_run_result.process_data
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
outputs=event.route_node_state.node_run_result.outputs or {}
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
execution_metadata=event.route_node_state.node_run_result.metadata
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
error=event.error,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeInLoopFailedEvent):
|
||||
self._publish_event(
|
||||
QueueNodeInLoopFailedEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_data=event.node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
start_at=event.route_node_state.start_at,
|
||||
inputs=event.route_node_state.node_run_result.inputs
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
process_data=event.route_node_state.node_run_result.process_data
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
outputs=event.route_node_state.node_run_result.outputs or {}
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
execution_metadata=event.route_node_state.node_run_result.metadata
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
in_loop_id=event.in_loop_id,
|
||||
error=event.error,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunStreamChunkEvent):
|
||||
self._publish_event(
|
||||
QueueTextChunkEvent(
|
||||
text=event.chunk_content,
|
||||
from_variable_selector=event.from_variable_selector,
|
||||
text=event.chunk,
|
||||
from_variable_selector=list(event.selector),
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
)
|
||||
@ -533,10 +479,10 @@ class WorkflowBasedAppRunner:
|
||||
in_loop_id=event.in_loop_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, AgentLogEvent):
|
||||
elif isinstance(event, NodeRunAgentLogEvent):
|
||||
self._publish_event(
|
||||
QueueAgentLogEvent(
|
||||
id=event.id,
|
||||
id=event.message_id,
|
||||
label=event.label,
|
||||
node_execution_id=event.node_execution_id,
|
||||
parent_id=event.parent_id,
|
||||
@ -547,51 +493,13 @@ class WorkflowBasedAppRunner:
|
||||
node_id=event.node_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, ParallelBranchRunStartedEvent):
|
||||
self._publish_event(
|
||||
QueueParallelBranchRunStartedEvent(
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, ParallelBranchRunSucceededEvent):
|
||||
self._publish_event(
|
||||
QueueParallelBranchRunSucceededEvent(
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, ParallelBranchRunFailedEvent):
|
||||
self._publish_event(
|
||||
QueueParallelBranchRunFailedEvent(
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
error=event.error,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, IterationRunStartedEvent):
|
||||
elif isinstance(event, NodeRunIterationStartedEvent):
|
||||
self._publish_event(
|
||||
QueueIterationStartEvent(
|
||||
node_execution_id=event.iteration_id,
|
||||
node_id=event.iteration_node_id,
|
||||
node_type=event.iteration_node_type,
|
||||
node_data=event.iteration_node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_title=event.node_title,
|
||||
start_at=event.start_at,
|
||||
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||
inputs=event.inputs,
|
||||
@ -599,55 +507,41 @@ class WorkflowBasedAppRunner:
|
||||
metadata=event.metadata,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, IterationRunNextEvent):
|
||||
elif isinstance(event, NodeRunIterationNextEvent):
|
||||
self._publish_event(
|
||||
QueueIterationNextEvent(
|
||||
node_execution_id=event.iteration_id,
|
||||
node_id=event.iteration_node_id,
|
||||
node_type=event.iteration_node_type,
|
||||
node_data=event.iteration_node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_title=event.node_title,
|
||||
index=event.index,
|
||||
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||
output=event.pre_iteration_output,
|
||||
parallel_mode_run_id=event.parallel_mode_run_id,
|
||||
duration=event.duration,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, (IterationRunSucceededEvent | IterationRunFailedEvent)):
|
||||
elif isinstance(event, (NodeRunIterationSucceededEvent | NodeRunIterationFailedEvent)):
|
||||
self._publish_event(
|
||||
QueueIterationCompletedEvent(
|
||||
node_execution_id=event.iteration_id,
|
||||
node_id=event.iteration_node_id,
|
||||
node_type=event.iteration_node_type,
|
||||
node_data=event.iteration_node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_title=event.node_title,
|
||||
start_at=event.start_at,
|
||||
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||
inputs=event.inputs,
|
||||
outputs=event.outputs,
|
||||
metadata=event.metadata,
|
||||
steps=event.steps,
|
||||
error=event.error if isinstance(event, IterationRunFailedEvent) else None,
|
||||
error=event.error if isinstance(event, NodeRunIterationFailedEvent) else None,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, LoopRunStartedEvent):
|
||||
elif isinstance(event, NodeRunLoopStartedEvent):
|
||||
self._publish_event(
|
||||
QueueLoopStartEvent(
|
||||
node_execution_id=event.loop_id,
|
||||
node_id=event.loop_node_id,
|
||||
node_type=event.loop_node_type,
|
||||
node_data=event.loop_node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_title=event.node_title,
|
||||
start_at=event.start_at,
|
||||
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||
inputs=event.inputs,
|
||||
@ -655,42 +549,32 @@ class WorkflowBasedAppRunner:
|
||||
metadata=event.metadata,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, LoopRunNextEvent):
|
||||
elif isinstance(event, NodeRunLoopNextEvent):
|
||||
self._publish_event(
|
||||
QueueLoopNextEvent(
|
||||
node_execution_id=event.loop_id,
|
||||
node_id=event.loop_node_id,
|
||||
node_type=event.loop_node_type,
|
||||
node_data=event.loop_node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_title=event.node_title,
|
||||
index=event.index,
|
||||
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||
output=event.pre_loop_output,
|
||||
parallel_mode_run_id=event.parallel_mode_run_id,
|
||||
duration=event.duration,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, (LoopRunSucceededEvent | LoopRunFailedEvent)):
|
||||
elif isinstance(event, (NodeRunLoopSucceededEvent | NodeRunLoopFailedEvent)):
|
||||
self._publish_event(
|
||||
QueueLoopCompletedEvent(
|
||||
node_execution_id=event.loop_id,
|
||||
node_id=event.loop_node_id,
|
||||
node_type=event.loop_node_type,
|
||||
node_data=event.loop_node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_title=event.node_title,
|
||||
start_at=event.start_at,
|
||||
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||
inputs=event.inputs,
|
||||
outputs=event.outputs,
|
||||
metadata=event.metadata,
|
||||
steps=event.steps,
|
||||
error=event.error if isinstance(event, LoopRunFailedEvent) else None,
|
||||
error=event.error if isinstance(event, NodeRunLoopFailedEvent) else None,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@ -3,15 +3,13 @@ from datetime import datetime
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.entities.node_entities import AgentNodeStrategyInit
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.entities import AgentNodeStrategyInit, GraphRuntimeState
|
||||
from core.workflow.enums import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
|
||||
class QueueEvent(StrEnum):
|
||||
@ -43,9 +41,6 @@ class QueueEvent(StrEnum):
|
||||
ANNOTATION_REPLY = "annotation_reply"
|
||||
AGENT_THOUGHT = "agent_thought"
|
||||
MESSAGE_FILE = "message_file"
|
||||
PARALLEL_BRANCH_RUN_STARTED = "parallel_branch_run_started"
|
||||
PARALLEL_BRANCH_RUN_SUCCEEDED = "parallel_branch_run_succeeded"
|
||||
PARALLEL_BRANCH_RUN_FAILED = "parallel_branch_run_failed"
|
||||
AGENT_LOG = "agent_log"
|
||||
ERROR = "error"
|
||||
PING = "ping"
|
||||
@ -80,21 +75,13 @@ class QueueIterationStartEvent(AppQueueEvent):
|
||||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
parallel_id: str | None = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: str | None = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: str | None = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: str | None = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
node_title: str
|
||||
start_at: datetime
|
||||
|
||||
node_run_index: int
|
||||
inputs: Mapping[str, Any] | None = None
|
||||
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
predecessor_node_id: str | None = None
|
||||
metadata: Mapping[str, Any] | None = None
|
||||
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class QueueIterationNextEvent(AppQueueEvent):
|
||||
@ -108,20 +95,9 @@ class QueueIterationNextEvent(AppQueueEvent):
|
||||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
parallel_id: str | None = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: str | None = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: str | None = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: str | None = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
parallel_mode_run_id: str | None = None
|
||||
"""iteration run in parallel mode run id"""
|
||||
node_title: str
|
||||
node_run_index: int
|
||||
output: Any | None = None # output for the current iteration
|
||||
duration: float | None = None
|
||||
output: Any = None # output for the current iteration
|
||||
|
||||
|
||||
class QueueIterationCompletedEvent(AppQueueEvent):
|
||||
@ -134,21 +110,13 @@ class QueueIterationCompletedEvent(AppQueueEvent):
|
||||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
parallel_id: str | None = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: str | None = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: str | None = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: str | None = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
node_title: str
|
||||
start_at: datetime
|
||||
|
||||
node_run_index: int
|
||||
inputs: Mapping[str, Any] | None = None
|
||||
outputs: Mapping[str, Any] | None = None
|
||||
metadata: Mapping[str, Any] | None = None
|
||||
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
outputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||
steps: int = 0
|
||||
|
||||
error: str | None = None
|
||||
@ -163,7 +131,7 @@ class QueueLoopStartEvent(AppQueueEvent):
|
||||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
node_title: str
|
||||
parallel_id: str | None = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: str | None = None
|
||||
@ -175,9 +143,9 @@ class QueueLoopStartEvent(AppQueueEvent):
|
||||
start_at: datetime
|
||||
|
||||
node_run_index: int
|
||||
inputs: Mapping[str, Any] | None = None
|
||||
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
predecessor_node_id: str | None = None
|
||||
metadata: Mapping[str, Any] | None = None
|
||||
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class QueueLoopNextEvent(AppQueueEvent):
|
||||
@ -191,7 +159,7 @@ class QueueLoopNextEvent(AppQueueEvent):
|
||||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
node_title: str
|
||||
parallel_id: str | None = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: str | None = None
|
||||
@ -203,8 +171,7 @@ class QueueLoopNextEvent(AppQueueEvent):
|
||||
parallel_mode_run_id: str | None = None
|
||||
"""iteration run in parallel mode run id"""
|
||||
node_run_index: int
|
||||
output: Any | None = None # output for the current loop
|
||||
duration: float | None = None
|
||||
output: Any = None # output for the current loop
|
||||
|
||||
|
||||
class QueueLoopCompletedEvent(AppQueueEvent):
|
||||
@ -217,7 +184,7 @@ class QueueLoopCompletedEvent(AppQueueEvent):
|
||||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
node_title: str
|
||||
parallel_id: str | None = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: str | None = None
|
||||
@ -229,9 +196,9 @@ class QueueLoopCompletedEvent(AppQueueEvent):
|
||||
start_at: datetime
|
||||
|
||||
node_run_index: int
|
||||
inputs: Mapping[str, Any] | None = None
|
||||
outputs: Mapping[str, Any] | None = None
|
||||
metadata: Mapping[str, Any] | None = None
|
||||
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
outputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||
steps: int = 0
|
||||
|
||||
error: str | None = None
|
||||
@ -332,7 +299,7 @@ class QueueWorkflowSucceededEvent(AppQueueEvent):
|
||||
"""
|
||||
|
||||
event: QueueEvent = QueueEvent.WORKFLOW_SUCCEEDED
|
||||
outputs: dict[str, Any] | None = None
|
||||
outputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class QueueWorkflowFailedEvent(AppQueueEvent):
|
||||
@ -352,7 +319,7 @@ class QueueWorkflowPartialSuccessEvent(AppQueueEvent):
|
||||
|
||||
event: QueueEvent = QueueEvent.WORKFLOW_PARTIAL_SUCCEEDED
|
||||
exceptions_count: int
|
||||
outputs: dict[str, Any] | None = None
|
||||
outputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class QueueNodeStartedEvent(AppQueueEvent):
|
||||
@ -364,27 +331,24 @@ class QueueNodeStartedEvent(AppQueueEvent):
|
||||
|
||||
node_execution_id: str
|
||||
node_id: str
|
||||
node_title: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
node_run_index: int = 1
|
||||
node_run_index: int = 1 # FIXME(-LAN-): may not used
|
||||
predecessor_node_id: str | None = None
|
||||
parallel_id: str | None = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: str | None = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: str | None = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: str | None = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
in_iteration_id: str | None = None
|
||||
"""iteration id if node is in iteration"""
|
||||
in_loop_id: str | None = None
|
||||
"""loop id if node is in loop"""
|
||||
start_at: datetime
|
||||
parallel_mode_run_id: str | None = None
|
||||
"""iteration run in parallel mode run id"""
|
||||
agent_strategy: AgentNodeStrategyInit | None = None
|
||||
|
||||
# FIXME(-LAN-): only for ToolNode, need to refactor
|
||||
provider_type: str # should be a core.tools.entities.tool_entities.ToolProviderType
|
||||
provider_id: str
|
||||
|
||||
|
||||
class QueueNodeSucceededEvent(AppQueueEvent):
|
||||
"""
|
||||
@ -396,7 +360,6 @@ class QueueNodeSucceededEvent(AppQueueEvent):
|
||||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
parallel_id: str | None = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: str | None = None
|
||||
@ -411,16 +374,12 @@ class QueueNodeSucceededEvent(AppQueueEvent):
|
||||
"""loop id if node is in loop"""
|
||||
start_at: datetime
|
||||
|
||||
inputs: Mapping[str, Any] | None = None
|
||||
process_data: Mapping[str, Any] | None = None
|
||||
outputs: Mapping[str, Any] | None = None
|
||||
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
process_data: Mapping[str, object] = Field(default_factory=dict)
|
||||
outputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
|
||||
|
||||
error: str | None = None
|
||||
"""single iteration duration map"""
|
||||
iteration_duration_map: dict[str, float] | None = None
|
||||
"""single loop duration map"""
|
||||
loop_duration_map: dict[str, float] | None = None
|
||||
|
||||
|
||||
class QueueAgentLogEvent(AppQueueEvent):
|
||||
@ -436,7 +395,7 @@ class QueueAgentLogEvent(AppQueueEvent):
|
||||
error: str | None = None
|
||||
status: str
|
||||
data: Mapping[str, Any]
|
||||
metadata: Mapping[str, Any] | None = None
|
||||
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||
node_id: str
|
||||
|
||||
|
||||
@ -445,81 +404,15 @@ class QueueNodeRetryEvent(QueueNodeStartedEvent):
|
||||
|
||||
event: QueueEvent = QueueEvent.RETRY
|
||||
|
||||
inputs: Mapping[str, Any] | None = None
|
||||
process_data: Mapping[str, Any] | None = None
|
||||
outputs: Mapping[str, Any] | None = None
|
||||
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
process_data: Mapping[str, object] = Field(default_factory=dict)
|
||||
outputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
|
||||
|
||||
error: str
|
||||
retry_index: int # retry index
|
||||
|
||||
|
||||
class QueueNodeInIterationFailedEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueNodeInIterationFailedEvent entity
|
||||
"""
|
||||
|
||||
event: QueueEvent = QueueEvent.NODE_FAILED
|
||||
|
||||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
parallel_id: str | None = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: str | None = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: str | None = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: str | None = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
in_iteration_id: str | None = None
|
||||
"""iteration id if node is in iteration"""
|
||||
in_loop_id: str | None = None
|
||||
"""loop id if node is in loop"""
|
||||
start_at: datetime
|
||||
|
||||
inputs: Mapping[str, Any] | None = None
|
||||
process_data: Mapping[str, Any] | None = None
|
||||
outputs: Mapping[str, Any] | None = None
|
||||
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
|
||||
|
||||
error: str
|
||||
|
||||
|
||||
class QueueNodeInLoopFailedEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueNodeInLoopFailedEvent entity
|
||||
"""
|
||||
|
||||
event: QueueEvent = QueueEvent.NODE_FAILED
|
||||
|
||||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
parallel_id: str | None = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: str | None = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: str | None = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: str | None = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
in_iteration_id: str | None = None
|
||||
"""iteration id if node is in iteration"""
|
||||
in_loop_id: str | None = None
|
||||
"""loop id if node is in loop"""
|
||||
start_at: datetime
|
||||
|
||||
inputs: Mapping[str, Any] | None = None
|
||||
process_data: Mapping[str, Any] | None = None
|
||||
outputs: Mapping[str, Any] | None = None
|
||||
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
|
||||
|
||||
error: str
|
||||
|
||||
|
||||
class QueueNodeExceptionEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueNodeExceptionEvent entity
|
||||
@ -530,7 +423,6 @@ class QueueNodeExceptionEvent(AppQueueEvent):
|
||||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
parallel_id: str | None = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: str | None = None
|
||||
@ -545,9 +437,9 @@ class QueueNodeExceptionEvent(AppQueueEvent):
|
||||
"""loop id if node is in loop"""
|
||||
start_at: datetime
|
||||
|
||||
inputs: Mapping[str, Any] | None = None
|
||||
process_data: Mapping[str, Any] | None = None
|
||||
outputs: Mapping[str, Any] | None = None
|
||||
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
process_data: Mapping[str, object] = Field(default_factory=dict)
|
||||
outputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
|
||||
|
||||
error: str
|
||||
@ -563,24 +455,16 @@ class QueueNodeFailedEvent(AppQueueEvent):
|
||||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
parallel_id: str | None = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: str | None = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: str | None = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: str | None = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
in_iteration_id: str | None = None
|
||||
"""iteration id if node is in iteration"""
|
||||
in_loop_id: str | None = None
|
||||
"""loop id if node is in loop"""
|
||||
start_at: datetime
|
||||
|
||||
inputs: Mapping[str, Any] | None = None
|
||||
process_data: Mapping[str, Any] | None = None
|
||||
outputs: Mapping[str, Any] | None = None
|
||||
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
process_data: Mapping[str, object] = Field(default_factory=dict)
|
||||
outputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
|
||||
|
||||
error: str
|
||||
@ -610,7 +494,7 @@ class QueueErrorEvent(AppQueueEvent):
|
||||
"""
|
||||
|
||||
event: QueueEvent = QueueEvent.ERROR
|
||||
error: Any | None = None
|
||||
error: Any = None
|
||||
|
||||
|
||||
class QueuePingEvent(AppQueueEvent):
|
||||
@ -678,61 +562,3 @@ class WorkflowQueueMessage(QueueMessage):
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QueueParallelBranchRunStartedEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueParallelBranchRunStartedEvent entity
|
||||
"""
|
||||
|
||||
event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_STARTED
|
||||
|
||||
parallel_id: str
|
||||
parallel_start_node_id: str
|
||||
parent_parallel_id: str | None = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: str | None = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
in_iteration_id: str | None = None
|
||||
"""iteration id if node is in iteration"""
|
||||
in_loop_id: str | None = None
|
||||
"""loop id if node is in loop"""
|
||||
|
||||
|
||||
class QueueParallelBranchRunSucceededEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueParallelBranchRunSucceededEvent entity
|
||||
"""
|
||||
|
||||
event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_SUCCEEDED
|
||||
|
||||
parallel_id: str
|
||||
parallel_start_node_id: str
|
||||
parent_parallel_id: str | None = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: str | None = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
in_iteration_id: str | None = None
|
||||
"""iteration id if node is in iteration"""
|
||||
in_loop_id: str | None = None
|
||||
"""loop id if node is in loop"""
|
||||
|
||||
|
||||
class QueueParallelBranchRunFailedEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueParallelBranchRunFailedEvent entity
|
||||
"""
|
||||
|
||||
event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_FAILED
|
||||
|
||||
parallel_id: str
|
||||
parallel_start_node_id: str
|
||||
parent_parallel_id: str | None = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: str | None = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
in_iteration_id: str | None = None
|
||||
"""iteration id if node is in iteration"""
|
||||
in_loop_id: str | None = None
|
||||
"""loop id if node is in loop"""
|
||||
error: str
|
||||
|
||||
@ -6,8 +6,8 @@ from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.entities.node_entities import AgentNodeStrategyInit
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.entities import AgentNodeStrategyInit
|
||||
from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class AnnotationReplyAccount(BaseModel):
|
||||
@ -138,7 +138,7 @@ class MessageEndStreamResponse(StreamResponse):
|
||||
|
||||
event: StreamEvent = StreamEvent.MESSAGE_END
|
||||
id: str
|
||||
metadata: dict = Field(default_factory=dict)
|
||||
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||
files: Sequence[Mapping[str, Any]] | None = None
|
||||
|
||||
|
||||
@ -175,7 +175,7 @@ class AgentThoughtStreamResponse(StreamResponse):
|
||||
thought: str | None = None
|
||||
observation: str | None = None
|
||||
tool: str | None = None
|
||||
tool_labels: dict | None = None
|
||||
tool_labels: Mapping[str, object] = Field(default_factory=dict)
|
||||
tool_input: str | None = None
|
||||
message_files: list[str] | None = None
|
||||
|
||||
@ -228,7 +228,7 @@ class WorkflowFinishStreamResponse(StreamResponse):
|
||||
elapsed_time: float
|
||||
total_tokens: int
|
||||
total_steps: int
|
||||
created_by: dict | None = None
|
||||
created_by: Mapping[str, object] = Field(default_factory=dict)
|
||||
created_at: int
|
||||
finished_at: int
|
||||
exceptions_count: int | None = 0
|
||||
@ -257,7 +257,7 @@ class NodeStartStreamResponse(StreamResponse):
|
||||
predecessor_node_id: str | None = None
|
||||
inputs: Mapping[str, Any] | None = None
|
||||
created_at: int
|
||||
extras: dict = Field(default_factory=dict)
|
||||
extras: dict[str, object] = Field(default_factory=dict)
|
||||
parallel_id: str | None = None
|
||||
parallel_start_node_id: str | None = None
|
||||
parent_parallel_id: str | None = None
|
||||
@ -436,54 +436,6 @@ class NodeRetryStreamResponse(StreamResponse):
|
||||
}
|
||||
|
||||
|
||||
class ParallelBranchStartStreamResponse(StreamResponse):
|
||||
"""
|
||||
ParallelBranchStartStreamResponse entity
|
||||
"""
|
||||
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
"""
|
||||
|
||||
parallel_id: str
|
||||
parallel_branch_id: str
|
||||
parent_parallel_id: str | None = None
|
||||
parent_parallel_start_node_id: str | None = None
|
||||
iteration_id: str | None = None
|
||||
loop_id: str | None = None
|
||||
created_at: int
|
||||
|
||||
event: StreamEvent = StreamEvent.PARALLEL_BRANCH_STARTED
|
||||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
|
||||
class ParallelBranchFinishedStreamResponse(StreamResponse):
|
||||
"""
|
||||
ParallelBranchFinishedStreamResponse entity
|
||||
"""
|
||||
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
"""
|
||||
|
||||
parallel_id: str
|
||||
parallel_branch_id: str
|
||||
parent_parallel_id: str | None = None
|
||||
parent_parallel_start_node_id: str | None = None
|
||||
iteration_id: str | None = None
|
||||
loop_id: str | None = None
|
||||
status: str
|
||||
error: str | None = None
|
||||
created_at: int
|
||||
|
||||
event: StreamEvent = StreamEvent.PARALLEL_BRANCH_FINISHED
|
||||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
|
||||
class IterationNodeStartStreamResponse(StreamResponse):
|
||||
"""
|
||||
NodeStartStreamResponse entity
|
||||
@ -499,11 +451,9 @@ class IterationNodeStartStreamResponse(StreamResponse):
|
||||
node_type: str
|
||||
title: str
|
||||
created_at: int
|
||||
extras: dict = Field(default_factory=dict)
|
||||
metadata: Mapping = {}
|
||||
inputs: Mapping = {}
|
||||
parallel_id: str | None = None
|
||||
parallel_start_node_id: str | None = None
|
||||
extras: Mapping[str, object] = Field(default_factory=dict)
|
||||
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
|
||||
event: StreamEvent = StreamEvent.ITERATION_STARTED
|
||||
workflow_run_id: str
|
||||
@ -526,12 +476,7 @@ class IterationNodeNextStreamResponse(StreamResponse):
|
||||
title: str
|
||||
index: int
|
||||
created_at: int
|
||||
pre_iteration_output: Any | None = None
|
||||
extras: dict = Field(default_factory=dict)
|
||||
parallel_id: str | None = None
|
||||
parallel_start_node_id: str | None = None
|
||||
parallel_mode_run_id: str | None = None
|
||||
duration: float | None = None
|
||||
extras: Mapping[str, object] = Field(default_factory=dict)
|
||||
|
||||
event: StreamEvent = StreamEvent.ITERATION_NEXT
|
||||
workflow_run_id: str
|
||||
@ -552,19 +497,17 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
|
||||
node_id: str
|
||||
node_type: str
|
||||
title: str
|
||||
outputs: Mapping | None = None
|
||||
outputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
created_at: int
|
||||
extras: dict | None = None
|
||||
inputs: Mapping | None = None
|
||||
extras: Mapping[str, object] = Field(default_factory=dict)
|
||||
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
status: WorkflowNodeExecutionStatus
|
||||
error: str | None = None
|
||||
elapsed_time: float
|
||||
total_tokens: int
|
||||
execution_metadata: Mapping | None = None
|
||||
execution_metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||
finished_at: int
|
||||
steps: int
|
||||
parallel_id: str | None = None
|
||||
parallel_start_node_id: str | None = None
|
||||
|
||||
event: StreamEvent = StreamEvent.ITERATION_COMPLETED
|
||||
workflow_run_id: str
|
||||
@ -586,9 +529,9 @@ class LoopNodeStartStreamResponse(StreamResponse):
|
||||
node_type: str
|
||||
title: str
|
||||
created_at: int
|
||||
extras: dict = Field(default_factory=dict)
|
||||
metadata: Mapping = {}
|
||||
inputs: Mapping = {}
|
||||
extras: Mapping[str, object] = Field(default_factory=dict)
|
||||
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
parallel_id: str | None = None
|
||||
parallel_start_node_id: str | None = None
|
||||
|
||||
@ -613,12 +556,11 @@ class LoopNodeNextStreamResponse(StreamResponse):
|
||||
title: str
|
||||
index: int
|
||||
created_at: int
|
||||
pre_loop_output: Any | None = None
|
||||
extras: dict = Field(default_factory=dict)
|
||||
pre_loop_output: Any = None
|
||||
extras: Mapping[str, object] = Field(default_factory=dict)
|
||||
parallel_id: str | None = None
|
||||
parallel_start_node_id: str | None = None
|
||||
parallel_mode_run_id: str | None = None
|
||||
duration: float | None = None
|
||||
|
||||
event: StreamEvent = StreamEvent.LOOP_NEXT
|
||||
workflow_run_id: str
|
||||
@ -639,15 +581,15 @@ class LoopNodeCompletedStreamResponse(StreamResponse):
|
||||
node_id: str
|
||||
node_type: str
|
||||
title: str
|
||||
outputs: Mapping | None = None
|
||||
outputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
created_at: int
|
||||
extras: dict | None = None
|
||||
inputs: Mapping | None = None
|
||||
extras: Mapping[str, object] = Field(default_factory=dict)
|
||||
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
status: WorkflowNodeExecutionStatus
|
||||
error: str | None = None
|
||||
elapsed_time: float
|
||||
total_tokens: int
|
||||
execution_metadata: Mapping | None = None
|
||||
execution_metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||
finished_at: int
|
||||
steps: int
|
||||
parallel_id: str | None = None
|
||||
@ -757,7 +699,7 @@ class ChatbotAppBlockingResponse(AppBlockingResponse):
|
||||
conversation_id: str
|
||||
message_id: str
|
||||
answer: str
|
||||
metadata: dict = Field(default_factory=dict)
|
||||
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||
created_at: int
|
||||
|
||||
data: Data
|
||||
@ -777,7 +719,7 @@ class CompletionAppBlockingResponse(AppBlockingResponse):
|
||||
mode: str
|
||||
message_id: str
|
||||
answer: str
|
||||
metadata: dict = Field(default_factory=dict)
|
||||
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||
created_at: int
|
||||
|
||||
data: Data
|
||||
@ -825,7 +767,7 @@ class AgentLogStreamResponse(StreamResponse):
|
||||
error: str | None = None
|
||||
status: str
|
||||
data: Mapping[str, Any]
|
||||
metadata: Mapping[str, Any] | None = None
|
||||
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||
node_id: str
|
||||
|
||||
event: StreamEvent = StreamEvent.AGENT_LOG
|
||||
|
||||
@ -109,7 +109,9 @@ class AppGeneratorTTSPublisher:
|
||||
elif isinstance(message.event, QueueNodeSucceededEvent):
|
||||
if message.event.outputs is None:
|
||||
continue
|
||||
self.msg_text += message.event.outputs.get("output", "")
|
||||
output = message.event.outputs.get("output", "")
|
||||
if isinstance(output, str):
|
||||
self.msg_text += output
|
||||
self.last_message = message
|
||||
sentence_arr, text_tmp = self._extract_sentence(self.msg_text)
|
||||
if len(sentence_arr) >= min(self.max_sentence, 7):
|
||||
@ -119,7 +121,7 @@ class AppGeneratorTTSPublisher:
|
||||
_invoice_tts, text_content, self.model_instance, self.tenant_id, self.voice
|
||||
)
|
||||
future_queue.put(futures_result)
|
||||
if text_tmp:
|
||||
if isinstance(text_tmp, str):
|
||||
self.msg_text = text_tmp
|
||||
else:
|
||||
self.msg_text = ""
|
||||
|
||||
@ -28,7 +28,6 @@ from core.model_runtime.entities.provider_entities import (
|
||||
)
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from core.plugin.entities.plugin import ModelProviderID
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.provider import (
|
||||
@ -41,6 +40,7 @@ from models.provider import (
|
||||
ProviderType,
|
||||
TenantPreferredModelProvider,
|
||||
)
|
||||
from models.provider_ids import ModelProviderID
|
||||
from services.enterprise.plugin_manager_service import PluginCredentialType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -704,6 +704,7 @@ class ProviderConfiguration(BaseModel):
|
||||
Get custom model credentials.
|
||||
"""
|
||||
# get provider model
|
||||
|
||||
model_provider_id = ModelProviderID(self.provider.provider)
|
||||
provider_names = [self.provider.provider]
|
||||
if model_provider_id.is_langgenius():
|
||||
@ -1203,6 +1204,7 @@ class ProviderConfiguration(BaseModel):
|
||||
"""
|
||||
Get provider model setting.
|
||||
"""
|
||||
|
||||
model_provider_id = ModelProviderID(self.provider.provider)
|
||||
provider_names = [self.provider.provider]
|
||||
if model_provider_id.is_langgenius():
|
||||
@ -1286,6 +1288,7 @@ class ProviderConfiguration(BaseModel):
|
||||
:param model: model name
|
||||
:return:
|
||||
"""
|
||||
|
||||
model_provider_id = ModelProviderID(self.provider.provider)
|
||||
provider_names = [self.provider.provider]
|
||||
if model_provider_id.is_langgenius():
|
||||
|
||||
@ -1,9 +1,33 @@
|
||||
from abc import abstractmethod
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import TypedDict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class CodeNodeProvider(BaseModel):
|
||||
class VariableConfig(TypedDict):
|
||||
variable: str
|
||||
value_selector: Sequence[str | int]
|
||||
|
||||
|
||||
class OutputConfig(TypedDict):
|
||||
type: str
|
||||
children: None
|
||||
|
||||
|
||||
class CodeConfig(TypedDict):
|
||||
variables: Sequence[VariableConfig]
|
||||
code_language: str
|
||||
code: str
|
||||
outputs: Mapping[str, OutputConfig]
|
||||
|
||||
|
||||
class DefaultConfig(TypedDict):
|
||||
type: str
|
||||
config: CodeConfig
|
||||
|
||||
|
||||
class CodeNodeProvider(BaseModel, ABC):
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def get_language() -> str:
|
||||
@ -22,11 +46,14 @@ class CodeNodeProvider(BaseModel):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls):
|
||||
def get_default_config(cls) -> DefaultConfig:
|
||||
return {
|
||||
"type": "code",
|
||||
"config": {
|
||||
"variables": [{"variable": "arg1", "value_selector": []}, {"variable": "arg2", "value_selector": []}],
|
||||
"variables": [
|
||||
{"variable": "arg1", "value_selector": []},
|
||||
{"variable": "arg2", "value_selector": []},
|
||||
],
|
||||
"code_language": cls.get_language(),
|
||||
"code": cls.get_default_code(),
|
||||
"outputs": {"result": {"type": "string", "children": None}},
|
||||
|
||||
@ -16,8 +16,8 @@ def full_mask_token(token_length=20):
|
||||
|
||||
|
||||
def encrypt_token(tenant_id: str, token: str):
|
||||
from extensions.ext_database import db
|
||||
from models.account import Tenant
|
||||
from models.engine import db
|
||||
|
||||
if not (tenant := db.session.query(Tenant).where(Tenant.id == tenant_id).first()):
|
||||
raise ValueError(f"Tenant with id {tenant_id} not found")
|
||||
|
||||
@ -28,8 +28,9 @@ from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.ops.utils import measure_time
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.graph_engine.entities.event import AgentLogEvent
|
||||
from models import App, Message, WorkflowNodeExecutionModel, db
|
||||
from core.workflow.node_events import AgentLogEvent
|
||||
from extensions.ext_database import db
|
||||
from models import App, Message, WorkflowNodeExecutionModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -160,7 +160,7 @@ class ErrorData(BaseModel):
|
||||
sentence.
|
||||
"""
|
||||
|
||||
data: Any | None = None
|
||||
data: Any = None
|
||||
"""
|
||||
Additional information about the error. The value of this member is defined by the
|
||||
sender (e.g. detailed error information, nested errors etc.).
|
||||
|
||||
@ -23,8 +23,7 @@ from core.model_runtime.errors.invoke import (
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from core.plugin.entities.plugin_daemon import PluginDaemonInnerError, PluginModelProviderEntity
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
||||
|
||||
|
||||
class AIModel(BaseModel):
|
||||
@ -52,6 +51,8 @@ class AIModel(BaseModel):
|
||||
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
from core.plugin.entities.plugin_daemon import PluginDaemonInnerError
|
||||
|
||||
return {
|
||||
InvokeConnectionError: [InvokeConnectionError],
|
||||
InvokeServerUnavailableError: [InvokeServerUnavailableError],
|
||||
@ -139,6 +140,8 @@ class AIModel(BaseModel):
|
||||
:param credentials: model credentials
|
||||
:return: model schema
|
||||
"""
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
plugin_model_manager = PluginModelClient()
|
||||
cache_key = f"{self.tenant_id}:{self.plugin_id}:{self.provider_name}:{self.model_type.value}:{model}"
|
||||
# sort credentials
|
||||
|
||||
@ -22,7 +22,6 @@ from core.model_runtime.entities.model_entities import (
|
||||
PriceType,
|
||||
)
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -142,6 +141,8 @@ class LargeLanguageModel(AIModel):
|
||||
result: Union[LLMResult, Generator[LLMResultChunk, None, None]]
|
||||
|
||||
try:
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
plugin_model_manager = PluginModelClient()
|
||||
result = plugin_model_manager.invoke_llm(
|
||||
tenant_id=self.tenant_id,
|
||||
@ -340,6 +341,8 @@ class LargeLanguageModel(AIModel):
|
||||
:return:
|
||||
"""
|
||||
if dify_config.PLUGIN_BASED_TOKEN_COUNTING_ENABLED:
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
plugin_model_manager = PluginModelClient()
|
||||
return plugin_model_manager.get_llm_num_tokens(
|
||||
tenant_id=self.tenant_id,
|
||||
|
||||
@ -4,7 +4,6 @@ from pydantic import ConfigDict
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
|
||||
class ModerationModel(AIModel):
|
||||
@ -30,6 +29,8 @@ class ModerationModel(AIModel):
|
||||
self.started_at = time.perf_counter()
|
||||
|
||||
try:
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
plugin_model_manager = PluginModelClient()
|
||||
return plugin_model_manager.invoke_moderation(
|
||||
tenant_id=self.tenant_id,
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.entities.rerank_entities import RerankResult
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
|
||||
class RerankModel(AIModel):
|
||||
@ -34,6 +33,8 @@ class RerankModel(AIModel):
|
||||
:return: rerank result
|
||||
"""
|
||||
try:
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
plugin_model_manager = PluginModelClient()
|
||||
return plugin_model_manager.invoke_rerank(
|
||||
tenant_id=self.tenant_id,
|
||||
|
||||
@ -4,7 +4,6 @@ from pydantic import ConfigDict
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
|
||||
class Speech2TextModel(AIModel):
|
||||
@ -28,6 +27,8 @@ class Speech2TextModel(AIModel):
|
||||
:return: text for given audio file
|
||||
"""
|
||||
try:
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
plugin_model_manager = PluginModelClient()
|
||||
return plugin_model_manager.invoke_speech_to_text(
|
||||
tenant_id=self.tenant_id,
|
||||
|
||||
@ -4,7 +4,6 @@ from core.entities.embedding_type import EmbeddingInputType
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
|
||||
class TextEmbeddingModel(AIModel):
|
||||
@ -35,6 +34,8 @@ class TextEmbeddingModel(AIModel):
|
||||
:param input_type: input type
|
||||
:return: embeddings result
|
||||
"""
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
try:
|
||||
plugin_model_manager = PluginModelClient()
|
||||
return plugin_model_manager.invoke_text_embedding(
|
||||
@ -59,6 +60,8 @@ class TextEmbeddingModel(AIModel):
|
||||
:param texts: texts to embed
|
||||
:return:
|
||||
"""
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
plugin_model_manager = PluginModelClient()
|
||||
return plugin_model_manager.get_text_embedding_num_tokens(
|
||||
tenant_id=self.tenant_id,
|
||||
|
||||
@ -5,7 +5,6 @@ from pydantic import ConfigDict
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -41,6 +40,8 @@ class TTSModel(AIModel):
|
||||
:return: translated audio file
|
||||
"""
|
||||
try:
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
plugin_model_manager = PluginModelClient()
|
||||
return plugin_model_manager.invoke_tts(
|
||||
tenant_id=self.tenant_id,
|
||||
@ -64,6 +65,8 @@ class TTSModel(AIModel):
|
||||
:param credentials: The credentials required to access the TTS model.
|
||||
:return: A list of voices supported by the TTS model.
|
||||
"""
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
plugin_model_manager = PluginModelClient()
|
||||
return plugin_model_manager.get_tts_model_voices(
|
||||
tenant_id=self.tenant_id,
|
||||
|
||||
@ -15,16 +15,16 @@ from core.model_runtime.model_providers.__base.text_embedding_model import TextE
|
||||
from core.model_runtime.model_providers.__base.tts_model import TTSModel
|
||||
from core.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator
|
||||
from core.model_runtime.schema_validators.provider_credential_schema_validator import ProviderCredentialSchemaValidator
|
||||
from core.plugin.entities.plugin import ModelProviderID
|
||||
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
||||
from core.plugin.impl.asset import PluginAssetManager
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
from models.provider_ids import ModelProviderID
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModelProviderFactory:
|
||||
def __init__(self, tenant_id: str):
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
self.tenant_id = tenant_id
|
||||
self.plugin_model_manager = PluginModelClient()
|
||||
|
||||
@ -38,7 +38,7 @@ class ModelProviderFactory:
|
||||
plugin_providers = self.get_plugin_model_providers()
|
||||
return [provider.declaration for provider in plugin_providers]
|
||||
|
||||
def get_plugin_model_providers(self) -> Sequence[PluginModelProviderEntity]:
|
||||
def get_plugin_model_providers(self) -> Sequence["PluginModelProviderEntity"]:
|
||||
"""
|
||||
Get all plugin model providers
|
||||
:return: list of plugin model providers
|
||||
@ -76,7 +76,7 @@ class ModelProviderFactory:
|
||||
plugin_model_provider_entity = self.get_plugin_model_provider(provider=provider)
|
||||
return plugin_model_provider_entity.declaration
|
||||
|
||||
def get_plugin_model_provider(self, provider: str) -> PluginModelProviderEntity:
|
||||
def get_plugin_model_provider(self, provider: str) -> "PluginModelProviderEntity":
|
||||
"""
|
||||
Get plugin model provider
|
||||
:param provider: provider name
|
||||
@ -331,6 +331,8 @@ class ModelProviderFactory:
|
||||
mime_type = image_mime_types.get(extension, "image/png")
|
||||
|
||||
# get icon bytes from plugin asset manager
|
||||
from core.plugin.impl.asset import PluginAssetManager
|
||||
|
||||
plugin_asset_manager = PluginAssetManager()
|
||||
return plugin_asset_manager.fetch_asset(tenant_id=self.tenant_id, id=file_name), mime_type
|
||||
|
||||
@ -340,5 +342,6 @@ class ModelProviderFactory:
|
||||
:param provider: provider name
|
||||
:return: plugin id and provider name
|
||||
"""
|
||||
|
||||
provider_id = ModelProviderID(provider)
|
||||
return provider_id.plugin_id, provider_id.provider_name
|
||||
|
||||
@ -54,13 +54,10 @@ from core.ops.entities.trace_entity import (
|
||||
)
|
||||
from core.rag.models.document import Document
|
||||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.workflow.entities.workflow_node_execution import (
|
||||
WorkflowNodeExecution,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from core.workflow.nodes import NodeType
|
||||
from models import Account, App, EndUser, TenantAccountJoin, WorkflowNodeExecutionTriggeredFrom, db
|
||||
from core.workflow.entities import WorkflowNodeExecution
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from extensions.ext_database import db
|
||||
from models import Account, App, EndUser, TenantAccountJoin, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -28,7 +28,7 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
|
||||
)
|
||||
from core.ops.utils import filter_none_values
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.enums import NodeType
|
||||
from extensions.ext_database import db
|
||||
from models import EndUser, WorkflowNodeExecutionTriggeredFrom
|
||||
from models.enums import MessageStatus
|
||||
|
||||
@ -28,8 +28,7 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
|
||||
)
|
||||
from core.ops.utils import filter_none_values, generate_dotted_order
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey
|
||||
from extensions.ext_database import db
|
||||
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
|
||||
@ -22,8 +22,7 @@ from core.ops.entities.trace_entity import (
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey
|
||||
from extensions.ext_database import db
|
||||
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
|
||||
@ -6,7 +6,7 @@ import queue
|
||||
import threading
|
||||
import time
|
||||
from datetime import timedelta
|
||||
from typing import Any, Union
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from cachetools import LRUCache
|
||||
@ -31,13 +31,15 @@ from core.ops.entities.trace_entity import (
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.ops.utils import get_message_data
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecution
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig
|
||||
from models.workflow import WorkflowAppLog, WorkflowRun
|
||||
from tasks.ops_trace_task import process_trace_tasks
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities import WorkflowExecution
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -407,7 +409,7 @@ class TraceTask:
|
||||
self,
|
||||
trace_type: Any,
|
||||
message_id: str | None = None,
|
||||
workflow_execution: WorkflowExecution | None = None,
|
||||
workflow_execution: Optional["WorkflowExecution"] = None,
|
||||
conversation_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
timer: Any | None = None,
|
||||
|
||||
@ -23,8 +23,7 @@ from core.ops.entities.trace_entity import (
|
||||
)
|
||||
from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey
|
||||
from extensions.ext_database import db
|
||||
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
|
||||
@ -167,7 +167,6 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=stream,
|
||||
call_depth=1,
|
||||
workflow_thread_pool_id=None,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.nodes.parameter_extractor.entities import (
|
||||
ModelConfig as ParameterExtractorModelConfig,
|
||||
)
|
||||
|
||||
@ -6,7 +6,6 @@ from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from core.entities.parameter_entities import CommonParameterType
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.workflow.nodes.base.entities import NumberType
|
||||
|
||||
|
||||
class PluginParameterOption(BaseModel):
|
||||
@ -153,7 +152,7 @@ def cast_parameter_value(typ: StrEnum, value: Any, /):
|
||||
raise ValueError("The tools selector must be a list.")
|
||||
return value
|
||||
case PluginParameterType.ANY:
|
||||
if value and not isinstance(value, str | dict | list | NumberType):
|
||||
if value and not isinstance(value, str | dict | list | int | float):
|
||||
raise ValueError("The var selector must be a string, dictionary, list or number.")
|
||||
return value
|
||||
case PluginParameterType.ARRAY:
|
||||
|
||||
@ -1,12 +1,10 @@
|
||||
import datetime
|
||||
import re
|
||||
from collections.abc import Mapping
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any
|
||||
|
||||
from packaging.version import InvalidVersion, Version
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from core.agent.plugin_entities import AgentStrategyProviderEntity
|
||||
from core.model_runtime.entities.provider_entities import ProviderEntity
|
||||
@ -156,55 +154,6 @@ class PluginEntity(PluginInstallation):
|
||||
return self
|
||||
|
||||
|
||||
class GenericProviderID:
|
||||
organization: str
|
||||
plugin_name: str
|
||||
provider_name: str
|
||||
is_hardcoded: bool
|
||||
|
||||
def to_string(self) -> str:
|
||||
return str(self)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.organization}/{self.plugin_name}/{self.provider_name}"
|
||||
|
||||
def __init__(self, value: str, is_hardcoded: bool = False):
|
||||
if not value:
|
||||
raise NotFound("plugin not found, please add plugin")
|
||||
# check if the value is a valid plugin id with format: $organization/$plugin_name/$provider_name
|
||||
if not re.match(r"^[a-z0-9_-]+\/[a-z0-9_-]+\/[a-z0-9_-]+$", value):
|
||||
# check if matches [a-z0-9_-]+, if yes, append with langgenius/$value/$value
|
||||
if re.match(r"^[a-z0-9_-]+$", value):
|
||||
value = f"langgenius/{value}/{value}"
|
||||
else:
|
||||
raise ValueError(f"Invalid plugin id {value}")
|
||||
|
||||
self.organization, self.plugin_name, self.provider_name = value.split("/")
|
||||
self.is_hardcoded = is_hardcoded
|
||||
|
||||
def is_langgenius(self) -> bool:
|
||||
return self.organization == "langgenius"
|
||||
|
||||
@property
|
||||
def plugin_id(self) -> str:
|
||||
return f"{self.organization}/{self.plugin_name}"
|
||||
|
||||
|
||||
class ModelProviderID(GenericProviderID):
|
||||
def __init__(self, value: str, is_hardcoded: bool = False):
|
||||
super().__init__(value, is_hardcoded)
|
||||
if self.organization == "langgenius" and self.provider_name == "google":
|
||||
self.plugin_name = "gemini"
|
||||
|
||||
|
||||
class ToolProviderID(GenericProviderID):
|
||||
def __init__(self, value: str, is_hardcoded: bool = False):
|
||||
super().__init__(value, is_hardcoded)
|
||||
if self.organization == "langgenius":
|
||||
if self.provider_name in ["jina", "siliconflow", "stepfun", "gitee_ai"]:
|
||||
self.plugin_name = f"{self.provider_name}_tool"
|
||||
|
||||
|
||||
class PluginDependency(BaseModel):
|
||||
class Type(StrEnum):
|
||||
Github = PluginInstallationSource.Github
|
||||
|
||||
@ -2,13 +2,13 @@ from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from core.agent.entities import AgentInvokeMessage
|
||||
from core.plugin.entities.plugin import GenericProviderID
|
||||
from core.plugin.entities.plugin_daemon import (
|
||||
PluginAgentProviderEntity,
|
||||
)
|
||||
from core.plugin.entities.request import PluginInvokeContext
|
||||
from core.plugin.impl.base import BasePluginClient
|
||||
from core.plugin.utils.chunk_merger import merge_blob_chunks
|
||||
from models.provider_ids import GenericProviderID
|
||||
|
||||
|
||||
class PluginAgentClient(BasePluginClient):
|
||||
|
||||
@ -1,9 +1,9 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.plugin.entities.plugin import GenericProviderID
|
||||
from core.plugin.entities.plugin_daemon import PluginDynamicSelectOptionsResponse
|
||||
from core.plugin.impl.base import BasePluginClient
|
||||
from models.provider_ids import GenericProviderID
|
||||
|
||||
|
||||
class DynamicSelectClient(BasePluginClient):
|
||||
|
||||
@ -2,7 +2,6 @@ from collections.abc import Sequence
|
||||
|
||||
from core.plugin.entities.bundle import PluginBundleDependency
|
||||
from core.plugin.entities.plugin import (
|
||||
GenericProviderID,
|
||||
MissingPluginDependency,
|
||||
PluginDeclaration,
|
||||
PluginEntity,
|
||||
@ -16,6 +15,7 @@ from core.plugin.entities.plugin_daemon import (
|
||||
PluginListResponse,
|
||||
)
|
||||
from core.plugin.impl.base import BasePluginClient
|
||||
from models.provider_ids import GenericProviderID
|
||||
|
||||
|
||||
class PluginInstaller(BasePluginClient):
|
||||
|
||||
@ -3,11 +3,11 @@ from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
|
||||
from core.plugin.entities.plugin_daemon import PluginBasicBooleanResponse, PluginToolProviderEntity
|
||||
from core.plugin.impl.base import BasePluginClient
|
||||
from core.plugin.utils.chunk_merger import merge_blob_chunks
|
||||
from core.tools.entities.tool_entities import CredentialType, ToolInvokeMessage, ToolParameter
|
||||
from models.provider_ids import GenericProviderID, ToolProviderID
|
||||
|
||||
|
||||
class PluginToolManager(BasePluginClient):
|
||||
|
||||
@ -36,7 +36,6 @@ from core.model_runtime.entities.provider_entities import (
|
||||
ProviderEntity,
|
||||
)
|
||||
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from core.plugin.entities.plugin import ModelProviderID
|
||||
from extensions import ext_hosting_provider
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
@ -51,6 +50,7 @@ from models.provider import (
|
||||
TenantDefaultModel,
|
||||
TenantPreferredModelProvider,
|
||||
)
|
||||
from models.provider_ids import ModelProviderID
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
|
||||
|
||||
@ -1,9 +1,9 @@
|
||||
"""Abstract interface for document loader implementations."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from configs import dify_config
|
||||
from core.model_manager import ModelInstance
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.models.document import Document
|
||||
from core.rag.splitter.fixed_text_splitter import (
|
||||
@ -13,6 +13,9 @@ from core.rag.splitter.fixed_text_splitter import (
|
||||
from core.rag.splitter.text_splitter import TextSplitter
|
||||
from models.dataset import Dataset, DatasetProcessRule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.model_manager import ModelInstance
|
||||
|
||||
|
||||
class BaseIndexProcessor(ABC):
|
||||
"""Interface for extract files."""
|
||||
@ -51,7 +54,7 @@ class BaseIndexProcessor(ABC):
|
||||
max_tokens: int,
|
||||
chunk_overlap: int,
|
||||
separator: str,
|
||||
embedding_model_instance: ModelInstance | None,
|
||||
embedding_model_instance: Optional["ModelInstance"],
|
||||
) -> TextSplitter:
|
||||
"""
|
||||
Get the NodeParser object according to the processing rule.
|
||||
|
||||
@ -9,11 +9,8 @@ from typing import Union
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.workflow.entities.workflow_execution import (
|
||||
WorkflowExecution,
|
||||
WorkflowExecutionStatus,
|
||||
WorkflowType,
|
||||
)
|
||||
from core.workflow.entities import WorkflowExecution
|
||||
from core.workflow.enums import WorkflowExecutionStatus, WorkflowType
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
|
||||
from libs.helper import extract_tenant_id
|
||||
@ -203,5 +200,4 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
|
||||
session.commit()
|
||||
|
||||
# Update the in-memory cache for faster subsequent lookups
|
||||
logger.debug("Updating cache for execution_id: %s", db_model.id)
|
||||
self._execution_cache[db_model.id] = db_model
|
||||
|
||||
@ -15,12 +15,8 @@ from sqlalchemy.orm import sessionmaker
|
||||
from tenacity import before_sleep_log, retry, retry_if_exception, stop_after_attempt
|
||||
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.workflow.entities.workflow_node_execution import (
|
||||
WorkflowNodeExecution,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.entities import WorkflowNodeExecution
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
|
||||
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
|
||||
from libs.helper import extract_tenant_id
|
||||
@ -245,7 +241,6 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
||||
|
||||
# Update the in-memory cache after successful save
|
||||
if db_model.node_execution_id:
|
||||
logger.debug("Updating cache for node_execution_id: %s", db_model.node_execution_id)
|
||||
self._node_execution_cache[db_model.node_execution_id] = db_model
|
||||
|
||||
except Exception:
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal, Optional
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
@ -16,10 +17,10 @@ class ToolApiEntity(BaseModel):
|
||||
description: I18nObject
|
||||
parameters: list[ToolParameter] | None = None
|
||||
labels: list[str] = Field(default_factory=list)
|
||||
output_schema: dict | None = None
|
||||
output_schema: Mapping[str, object] = Field(default_factory=dict)
|
||||
|
||||
|
||||
ToolProviderTypeApiLiteral = Optional[Literal["builtin", "api", "workflow", "mcp"]]
|
||||
ToolProviderTypeApiLiteral = Literal["builtin", "api", "workflow", "mcp"] | None
|
||||
|
||||
|
||||
class ToolProviderApiEntity(BaseModel):
|
||||
@ -27,17 +28,17 @@ class ToolProviderApiEntity(BaseModel):
|
||||
author: str
|
||||
name: str # identifier
|
||||
description: I18nObject
|
||||
icon: str | dict
|
||||
icon_dark: str | dict | None = Field(default=None, description="The dark icon of the tool")
|
||||
icon: str | Mapping[str, str]
|
||||
icon_dark: str | Mapping[str, str] = ""
|
||||
label: I18nObject # label
|
||||
type: ToolProviderType
|
||||
masked_credentials: dict | None = None
|
||||
original_credentials: dict | None = None
|
||||
masked_credentials: Mapping[str, object] = Field(default_factory=dict)
|
||||
original_credentials: Mapping[str, object] = Field(default_factory=dict)
|
||||
is_team_authorization: bool = False
|
||||
allow_delete: bool = True
|
||||
plugin_id: str | None = Field(default="", description="The plugin id of the tool")
|
||||
plugin_unique_identifier: str | None = Field(default="", description="The unique identifier of the tool")
|
||||
tools: list[ToolApiEntity] = Field(default_factory=list)
|
||||
tools: list[ToolApiEntity] = Field(default_factory=list[ToolApiEntity])
|
||||
labels: list[str] = Field(default_factory=list)
|
||||
# MCP
|
||||
server_url: str | None = Field(default="", description="The server url of the tool")
|
||||
@ -105,7 +106,7 @@ class ToolProviderCredentialApiEntity(BaseModel):
|
||||
is_default: bool = Field(
|
||||
default=False, description="Whether the credential is the default credential for the provider in the workspace"
|
||||
)
|
||||
credentials: dict = Field(description="The credentials of the provider")
|
||||
credentials: Mapping[str, object] = Field(description="The credentials of the provider", default_factory=dict)
|
||||
|
||||
|
||||
class ToolProviderCredentialInfoApiEntity(BaseModel):
|
||||
|
||||
@ -186,7 +186,7 @@ class ToolInvokeMessage(BaseModel):
|
||||
error: str | None = Field(default=None, description="The error message")
|
||||
status: LogStatus = Field(..., description="The status of the log")
|
||||
data: Mapping[str, Any] = Field(..., description="Detailed log data")
|
||||
metadata: Mapping[str, Any] | None = Field(default=None, description="The metadata of the log")
|
||||
metadata: Mapping[str, Any] = Field(default_factory=dict, description="The metadata of the log")
|
||||
|
||||
class RetrieverResourceMessage(BaseModel):
|
||||
retriever_resources: list[RetrievalSourceMetadata] = Field(..., description="retriever resources")
|
||||
@ -362,9 +362,9 @@ class ToolDescription(BaseModel):
|
||||
|
||||
class ToolEntity(BaseModel):
|
||||
identity: ToolIdentity
|
||||
parameters: list[ToolParameter] = Field(default_factory=list)
|
||||
parameters: list[ToolParameter] = Field(default_factory=list[ToolParameter])
|
||||
description: ToolDescription | None = None
|
||||
output_schema: dict | None = None
|
||||
output_schema: Mapping[str, object] = Field(default_factory=dict)
|
||||
has_runtime_parameters: bool = Field(default=False, description="Whether the tool has runtime parameters")
|
||||
|
||||
# pydantic configs
|
||||
@ -377,21 +377,23 @@ class ToolEntity(BaseModel):
|
||||
|
||||
|
||||
class OAuthSchema(BaseModel):
|
||||
client_schema: list[ProviderConfig] = Field(default_factory=list, description="The schema of the OAuth client")
|
||||
client_schema: list[ProviderConfig] = Field(
|
||||
default_factory=list[ProviderConfig], description="The schema of the OAuth client"
|
||||
)
|
||||
credentials_schema: list[ProviderConfig] = Field(
|
||||
default_factory=list, description="The schema of the OAuth credentials"
|
||||
default_factory=list[ProviderConfig], description="The schema of the OAuth credentials"
|
||||
)
|
||||
|
||||
|
||||
class ToolProviderEntity(BaseModel):
|
||||
identity: ToolProviderIdentity
|
||||
plugin_id: str | None = None
|
||||
credentials_schema: list[ProviderConfig] = Field(default_factory=list)
|
||||
credentials_schema: list[ProviderConfig] = Field(default_factory=list[ProviderConfig])
|
||||
oauth_schema: OAuthSchema | None = None
|
||||
|
||||
|
||||
class ToolProviderEntityWithPlugin(ToolProviderEntity):
|
||||
tools: list[ToolEntity] = Field(default_factory=list)
|
||||
tools: list[ToolEntity] = Field(default_factory=list[ToolEntity])
|
||||
|
||||
|
||||
class WorkflowToolParameterConfiguration(BaseModel):
|
||||
|
||||
@ -72,7 +72,6 @@ class MCPToolProviderController(ToolProviderController):
|
||||
),
|
||||
llm=remote_mcp_tool.description or "",
|
||||
),
|
||||
output_schema=None,
|
||||
has_runtime_parameters=len(remote_mcp_tool.inputSchema) > 0,
|
||||
)
|
||||
for remote_mcp_tool in remote_mcp_tools
|
||||
|
||||
@ -152,7 +152,6 @@ class ToolEngine:
|
||||
user_id: str,
|
||||
workflow_tool_callback: DifyWorkflowCallbackHandler,
|
||||
workflow_call_depth: int,
|
||||
thread_pool_id: str | None = None,
|
||||
conversation_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
@ -166,7 +165,6 @@ class ToolEngine:
|
||||
|
||||
if isinstance(tool, WorkflowTool):
|
||||
tool.workflow_call_depth = workflow_call_depth + 1
|
||||
tool.thread_pool_id = thread_pool_id
|
||||
|
||||
if tool.runtime and tool.runtime.runtime_parameters:
|
||||
tool_parameters = {**tool.runtime.runtime_parameters, **tool_parameters}
|
||||
|
||||
@ -5,7 +5,7 @@ import time
|
||||
from collections.abc import Generator, Mapping
|
||||
from os import listdir, path
|
||||
from threading import Lock
|
||||
from typing import TYPE_CHECKING, Any, Literal, Union, cast
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
|
||||
|
||||
import sqlalchemy as sa
|
||||
from pydantic import TypeAdapter
|
||||
@ -14,32 +14,17 @@ from sqlalchemy.orm import Session
|
||||
from yarl import URL
|
||||
|
||||
import contexts
|
||||
from core.helper.provider_cache import ToolProviderCredentialsCache
|
||||
from core.plugin.entities.plugin import ToolProviderID
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from core.plugin.impl.tool import PluginToolManager
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.mcp_tool.provider import MCPToolProviderController
|
||||
from core.tools.mcp_tool.tool import MCPTool
|
||||
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||
from core.tools.plugin_tool.tool import PluginTool
|
||||
from core.tools.utils.uuid_utils import is_valid_uuid
|
||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from services.enterprise.plugin_manager_service import PluginCredentialType
|
||||
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.nodes.tool.entities import ToolEntity
|
||||
|
||||
from configs import dify_config
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.helper.module_import_helper import load_single_subclass_from_source
|
||||
from core.helper.position_helper import is_filtered
|
||||
from core.helper.provider_cache import ToolProviderCredentialsCache
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.impl.tool import PluginToolManager
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
|
||||
from core.tools.builtin_tool.tool import BuiltinTool
|
||||
@ -55,14 +40,28 @@ from core.tools.entities.tool_entities import (
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.errors import ToolProviderNotFoundError
|
||||
from core.tools.mcp_tool.provider import MCPToolProviderController
|
||||
from core.tools.mcp_tool.tool import MCPTool
|
||||
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||
from core.tools.plugin_tool.tool import PluginTool
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
from core.tools.utils.configuration import ToolParameterConfigurationManager
|
||||
from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter
|
||||
from core.tools.utils.uuid_utils import is_valid_uuid
|
||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from extensions.ext_database import db
|
||||
from models.provider_ids import ToolProviderID
|
||||
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
|
||||
from services.enterprise.plugin_manager_service import PluginCredentialType
|
||||
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities import VariablePool
|
||||
from core.workflow.nodes.tool.entities import ToolEntity
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -117,6 +116,7 @@ class ToolManager:
|
||||
get the plugin provider
|
||||
"""
|
||||
# check if context is set
|
||||
|
||||
try:
|
||||
contexts.plugin_tool_providers.get()
|
||||
except LookupError:
|
||||
@ -172,6 +172,7 @@ class ToolManager:
|
||||
|
||||
:return: the tool
|
||||
"""
|
||||
|
||||
if provider_type == ToolProviderType.BUILT_IN:
|
||||
# check if the builtin tool need credentials
|
||||
provider_controller = cls.get_builtin_provider(provider_id, tenant_id)
|
||||
@ -213,16 +214,16 @@ class ToolManager:
|
||||
# fallback to the default provider
|
||||
if builtin_provider is None:
|
||||
# use the default provider
|
||||
builtin_provider = (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
.where(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
(BuiltinToolProvider.provider == str(provider_id_entity))
|
||||
| (BuiltinToolProvider.provider == provider_id_entity.provider_name),
|
||||
with Session(db.engine) as session:
|
||||
builtin_provider = session.scalar(
|
||||
sa.select(BuiltinToolProvider)
|
||||
.where(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
(BuiltinToolProvider.provider == str(provider_id_entity))
|
||||
| (BuiltinToolProvider.provider == provider_id_entity.provider_name),
|
||||
)
|
||||
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
|
||||
)
|
||||
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
|
||||
.first()
|
||||
)
|
||||
if builtin_provider is None:
|
||||
raise ToolProviderNotFoundError(f"no default provider for {provider_id}")
|
||||
else:
|
||||
@ -263,6 +264,7 @@ class ToolManager:
|
||||
# check if the credentials is expired
|
||||
if builtin_provider.expires_at != -1 and (builtin_provider.expires_at - 60) < int(time.time()):
|
||||
# TODO: circular import
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||
|
||||
# refresh the credentials
|
||||
@ -270,6 +272,7 @@ class ToolManager:
|
||||
provider_name = tool_provider.provider_name
|
||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/tool/callback"
|
||||
system_credentials = BuiltinToolManageService.get_oauth_client(tenant_id, provider_id)
|
||||
|
||||
oauth_handler = OAuthHandler()
|
||||
# refresh the credentials
|
||||
refreshed_credentials = oauth_handler.refresh_credentials(
|
||||
@ -358,7 +361,7 @@ class ToolManager:
|
||||
app_id: str,
|
||||
agent_tool: AgentToolEntity,
|
||||
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
|
||||
variable_pool: VariablePool | None = None,
|
||||
variable_pool: Optional["VariablePool"] = None,
|
||||
) -> Tool:
|
||||
"""
|
||||
get the agent tool runtime
|
||||
@ -400,7 +403,7 @@ class ToolManager:
|
||||
node_id: str,
|
||||
workflow_tool: "ToolEntity",
|
||||
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
|
||||
variable_pool: VariablePool | None = None,
|
||||
variable_pool: Optional["VariablePool"] = None,
|
||||
) -> Tool:
|
||||
"""
|
||||
get the workflow tool runtime
|
||||
@ -516,6 +519,7 @@ class ToolManager:
|
||||
"""
|
||||
list all the plugin providers
|
||||
"""
|
||||
|
||||
manager = PluginToolManager()
|
||||
provider_entities = manager.fetch_tool_providers(tenant_id)
|
||||
return [
|
||||
@ -882,7 +886,7 @@ class ToolManager:
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str):
|
||||
def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str) -> Mapping[str, str]:
|
||||
try:
|
||||
workflow_provider: WorkflowToolProvider | None = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
@ -893,13 +897,13 @@ class ToolManager:
|
||||
if workflow_provider is None:
|
||||
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
|
||||
|
||||
icon: dict = json.loads(workflow_provider.icon)
|
||||
icon = json.loads(workflow_provider.icon)
|
||||
return icon
|
||||
except Exception:
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
|
||||
@classmethod
|
||||
def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str):
|
||||
def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str) -> Mapping[str, str]:
|
||||
try:
|
||||
api_provider: ApiToolProvider | None = (
|
||||
db.session.query(ApiToolProvider)
|
||||
@ -910,13 +914,13 @@ class ToolManager:
|
||||
if api_provider is None:
|
||||
raise ToolProviderNotFoundError(f"api provider {provider_id} not found")
|
||||
|
||||
icon: dict = json.loads(api_provider.icon)
|
||||
icon = json.loads(api_provider.icon)
|
||||
return icon
|
||||
except Exception:
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
|
||||
@classmethod
|
||||
def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> dict[str, str] | str:
|
||||
def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> Mapping[str, str] | str:
|
||||
try:
|
||||
mcp_provider: MCPToolProvider | None = (
|
||||
db.session.query(MCPToolProvider)
|
||||
@ -937,7 +941,7 @@ class ToolManager:
|
||||
tenant_id: str,
|
||||
provider_type: ToolProviderType,
|
||||
provider_id: str,
|
||||
) -> Union[str, dict[str, Any]]:
|
||||
) -> str | Mapping[str, str]:
|
||||
"""
|
||||
get the tool icon
|
||||
|
||||
@ -962,11 +966,10 @@ class ToolManager:
|
||||
return cls.generate_workflow_tool_icon_url(tenant_id, provider_id)
|
||||
elif provider_type == ToolProviderType.PLUGIN:
|
||||
provider = ToolManager.get_plugin_provider(provider_id, tenant_id)
|
||||
if isinstance(provider, PluginToolProviderController):
|
||||
try:
|
||||
return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon)
|
||||
except Exception:
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
try:
|
||||
return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon)
|
||||
except Exception:
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
raise ValueError(f"plugin provider {provider_id} not found")
|
||||
elif provider_type == ToolProviderType.MCP:
|
||||
return cls.generate_mcp_tool_icon_url(tenant_id, provider_id)
|
||||
@ -977,7 +980,7 @@ class ToolManager:
|
||||
def _convert_tool_parameters_type(
|
||||
cls,
|
||||
parameters: list[ToolParameter],
|
||||
variable_pool: VariablePool | None,
|
||||
variable_pool: Optional["VariablePool"],
|
||||
tool_configurations: dict[str, Any],
|
||||
typ: Literal["agent", "workflow", "tool"] = "workflow",
|
||||
) -> dict[str, Any]:
|
||||
|
||||
@ -39,14 +39,12 @@ class WorkflowTool(Tool):
|
||||
entity: ToolEntity,
|
||||
runtime: ToolRuntime,
|
||||
label: str = "Workflow",
|
||||
thread_pool_id: str | None = None,
|
||||
):
|
||||
self.workflow_app_id = workflow_app_id
|
||||
self.workflow_as_tool_id = workflow_as_tool_id
|
||||
self.version = version
|
||||
self.workflow_entities = workflow_entities
|
||||
self.workflow_call_depth = workflow_call_depth
|
||||
self.thread_pool_id = thread_pool_id
|
||||
self.label = label
|
||||
|
||||
super().__init__(entity=entity, runtime=runtime)
|
||||
@ -90,7 +88,6 @@ class WorkflowTool(Tool):
|
||||
invoke_from=self.runtime.invoke_from,
|
||||
streaming=False,
|
||||
call_depth=self.workflow_call_depth + 1,
|
||||
workflow_thread_pool_id=self.thread_pool_id,
|
||||
)
|
||||
assert isinstance(result, dict)
|
||||
data = result.get("data", {})
|
||||
|
||||
@ -130,7 +130,7 @@ class ArraySegment(Segment):
|
||||
def markdown(self) -> str:
|
||||
items = []
|
||||
for item in self.value:
|
||||
items.append(str(item))
|
||||
items.append(f"- {item}")
|
||||
return "\n".join(items)
|
||||
|
||||
|
||||
|
||||
132
api/core/workflow/README.md
Normal file
132
api/core/workflow/README.md
Normal file
@ -0,0 +1,132 @@
|
||||
# Workflow
|
||||
|
||||
## Project Overview
|
||||
|
||||
This is the workflow graph engine module of Dify, implementing a queue-based distributed workflow execution system. The engine handles agentic AI workflows with support for parallel execution, node iteration, conditional logic, and external command control.
|
||||
|
||||
## Architecture
|
||||
|
||||
### Core Components
|
||||
|
||||
The graph engine follows a layered architecture with strict dependency rules:
|
||||
|
||||
1. **Graph Engine** (`graph_engine/`) - Orchestrates workflow execution
|
||||
|
||||
- **Manager** - External control interface for stop/pause/resume commands
|
||||
- **Worker** - Node execution runtime
|
||||
- **Command Processing** - Handles control commands (abort, pause, resume)
|
||||
- **Event Management** - Event propagation and layer notifications
|
||||
- **Graph Traversal** - Edge processing and skip propagation
|
||||
- **Response Coordinator** - Path tracking and session management
|
||||
- **Layers** - Pluggable middleware (debug logging, execution limits)
|
||||
- **Command Channels** - Communication channels (InMemory, Redis)
|
||||
|
||||
1. **Graph** (`graph/`) - Graph structure and runtime state
|
||||
|
||||
- **Graph Template** - Workflow definition
|
||||
- **Edge** - Node connections with conditions
|
||||
- **Runtime State Protocol** - State management interface
|
||||
|
||||
1. **Nodes** (`nodes/`) - Node implementations
|
||||
|
||||
- **Base** - Abstract node classes and variable parsing
|
||||
- **Specific Nodes** - LLM, Agent, Code, HTTP Request, Iteration, Loop, etc.
|
||||
|
||||
1. **Events** (`node_events/`) - Event system
|
||||
|
||||
- **Base** - Event protocols
|
||||
- **Node Events** - Node lifecycle events
|
||||
|
||||
1. **Entities** (`entities/`) - Domain models
|
||||
|
||||
- **Variable Pool** - Variable storage
|
||||
- **Graph Init Params** - Initialization configuration
|
||||
|
||||
## Key Design Patterns
|
||||
|
||||
### Command Channel Pattern
|
||||
|
||||
External workflow control via Redis or in-memory channels:
|
||||
|
||||
```python
|
||||
# Send stop command to running workflow
|
||||
channel = RedisChannel(redis_client, f"workflow:{task_id}:commands")
|
||||
channel.send_command(AbortCommand(reason="User requested"))
|
||||
```
|
||||
|
||||
### Layer System
|
||||
|
||||
Extensible middleware for cross-cutting concerns:
|
||||
|
||||
```python
|
||||
engine = GraphEngine(graph)
|
||||
engine.add_layer(DebugLoggingLayer(level="INFO"))
|
||||
engine.add_layer(ExecutionLimitsLayer(max_nodes=100))
|
||||
```
|
||||
|
||||
### Event-Driven Architecture
|
||||
|
||||
All node executions emit events for monitoring and integration:
|
||||
|
||||
- `NodeRunStartedEvent` - Node execution begins
|
||||
- `NodeRunSucceededEvent` - Node completes successfully
|
||||
- `NodeRunFailedEvent` - Node encounters error
|
||||
- `GraphRunStartedEvent/GraphRunCompletedEvent` - Workflow lifecycle
|
||||
|
||||
### Variable Pool
|
||||
|
||||
Centralized variable storage with namespace isolation:
|
||||
|
||||
```python
|
||||
# Variables scoped by node_id
|
||||
pool.add(["node1", "output"], value)
|
||||
result = pool.get(["node1", "output"])
|
||||
```
|
||||
|
||||
## Import Architecture Rules
|
||||
|
||||
The codebase enforces strict layering via import-linter:
|
||||
|
||||
1. **Workflow Layers** (top to bottom):
|
||||
|
||||
- graph_engine → graph_events → graph → nodes → node_events → entities
|
||||
|
||||
1. **Graph Engine Internal Layers**:
|
||||
|
||||
- orchestration → command_processing → event_management → graph_traversal → domain
|
||||
|
||||
1. **Domain Isolation**:
|
||||
|
||||
- Domain models cannot import from infrastructure layers
|
||||
|
||||
1. **Command Channel Independence**:
|
||||
|
||||
- InMemory and Redis channels must remain independent
|
||||
|
||||
## Common Tasks
|
||||
|
||||
### Adding a New Node Type
|
||||
|
||||
1. Create node class in `nodes/<node_type>/`
|
||||
1. Inherit from `BaseNode` or appropriate base class
|
||||
1. Implement `_run()` method
|
||||
1. Register in `nodes/node_mapping.py`
|
||||
1. Add tests in `tests/unit_tests/core/workflow/nodes/`
|
||||
|
||||
### Implementing a Custom Layer
|
||||
|
||||
1. Create class inheriting from `Layer` base
|
||||
1. Override lifecycle methods: `on_graph_start()`, `on_event()`, `on_graph_end()`
|
||||
1. Add to engine via `engine.add_layer()`
|
||||
|
||||
### Debugging Workflow Execution
|
||||
|
||||
Enable debug logging layer:
|
||||
|
||||
```python
|
||||
debug_layer = DebugLoggingLayer(
|
||||
level="DEBUG",
|
||||
include_inputs=True,
|
||||
include_outputs=True
|
||||
)
|
||||
```
|
||||
@ -1,7 +0,0 @@
|
||||
from .base_workflow_callback import WorkflowCallback
|
||||
from .workflow_logging_callback import WorkflowLoggingCallback
|
||||
|
||||
__all__ = [
|
||||
"WorkflowCallback",
|
||||
"WorkflowLoggingCallback",
|
||||
]
|
||||
@ -1,12 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from core.workflow.graph_engine.entities.event import GraphEngineEvent
|
||||
|
||||
|
||||
class WorkflowCallback(ABC):
|
||||
@abstractmethod
|
||||
def on_event(self, event: GraphEngineEvent):
|
||||
"""
|
||||
Published event
|
||||
"""
|
||||
raise NotImplementedError
|
||||
@ -1,259 +0,0 @@
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
GraphEngineEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunPartialSucceededEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
IterationRunFailedEvent,
|
||||
IterationRunNextEvent,
|
||||
IterationRunStartedEvent,
|
||||
IterationRunSucceededEvent,
|
||||
LoopRunFailedEvent,
|
||||
LoopRunNextEvent,
|
||||
LoopRunStartedEvent,
|
||||
LoopRunSucceededEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
ParallelBranchRunFailedEvent,
|
||||
ParallelBranchRunStartedEvent,
|
||||
ParallelBranchRunSucceededEvent,
|
||||
)
|
||||
|
||||
from .base_workflow_callback import WorkflowCallback
|
||||
|
||||
_TEXT_COLOR_MAPPING = {
|
||||
"blue": "36;1",
|
||||
"yellow": "33;1",
|
||||
"pink": "38;5;200",
|
||||
"green": "32;1",
|
||||
"red": "31;1",
|
||||
}
|
||||
|
||||
|
||||
class WorkflowLoggingCallback(WorkflowCallback):
|
||||
def __init__(self):
|
||||
self.current_node_id: str | None = None
|
||||
|
||||
def on_event(self, event: GraphEngineEvent):
|
||||
if isinstance(event, GraphRunStartedEvent):
|
||||
self.print_text("\n[GraphRunStartedEvent]", color="pink")
|
||||
elif isinstance(event, GraphRunSucceededEvent):
|
||||
self.print_text("\n[GraphRunSucceededEvent]", color="green")
|
||||
elif isinstance(event, GraphRunPartialSucceededEvent):
|
||||
self.print_text("\n[GraphRunPartialSucceededEvent]", color="pink")
|
||||
elif isinstance(event, GraphRunFailedEvent):
|
||||
self.print_text(f"\n[GraphRunFailedEvent] reason: {event.error}", color="red")
|
||||
elif isinstance(event, NodeRunStartedEvent):
|
||||
self.on_workflow_node_execute_started(event=event)
|
||||
elif isinstance(event, NodeRunSucceededEvent):
|
||||
self.on_workflow_node_execute_succeeded(event=event)
|
||||
elif isinstance(event, NodeRunFailedEvent):
|
||||
self.on_workflow_node_execute_failed(event=event)
|
||||
elif isinstance(event, NodeRunStreamChunkEvent):
|
||||
self.on_node_text_chunk(event=event)
|
||||
elif isinstance(event, ParallelBranchRunStartedEvent):
|
||||
self.on_workflow_parallel_started(event=event)
|
||||
elif isinstance(event, ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent):
|
||||
self.on_workflow_parallel_completed(event=event)
|
||||
elif isinstance(event, IterationRunStartedEvent):
|
||||
self.on_workflow_iteration_started(event=event)
|
||||
elif isinstance(event, IterationRunNextEvent):
|
||||
self.on_workflow_iteration_next(event=event)
|
||||
elif isinstance(event, IterationRunSucceededEvent | IterationRunFailedEvent):
|
||||
self.on_workflow_iteration_completed(event=event)
|
||||
elif isinstance(event, LoopRunStartedEvent):
|
||||
self.on_workflow_loop_started(event=event)
|
||||
elif isinstance(event, LoopRunNextEvent):
|
||||
self.on_workflow_loop_next(event=event)
|
||||
elif isinstance(event, LoopRunSucceededEvent | LoopRunFailedEvent):
|
||||
self.on_workflow_loop_completed(event=event)
|
||||
else:
|
||||
self.print_text(f"\n[{event.__class__.__name__}]", color="blue")
|
||||
|
||||
def on_workflow_node_execute_started(self, event: NodeRunStartedEvent):
|
||||
"""
|
||||
Workflow node execute started
|
||||
"""
|
||||
self.print_text("\n[NodeRunStartedEvent]", color="yellow")
|
||||
self.print_text(f"Node ID: {event.node_id}", color="yellow")
|
||||
self.print_text(f"Node Title: {event.node_data.title}", color="yellow")
|
||||
self.print_text(f"Type: {event.node_type.value}", color="yellow")
|
||||
|
||||
def on_workflow_node_execute_succeeded(self, event: NodeRunSucceededEvent):
|
||||
"""
|
||||
Workflow node execute succeeded
|
||||
"""
|
||||
route_node_state = event.route_node_state
|
||||
|
||||
self.print_text("\n[NodeRunSucceededEvent]", color="green")
|
||||
self.print_text(f"Node ID: {event.node_id}", color="green")
|
||||
self.print_text(f"Node Title: {event.node_data.title}", color="green")
|
||||
self.print_text(f"Type: {event.node_type.value}", color="green")
|
||||
|
||||
if route_node_state.node_run_result:
|
||||
node_run_result = route_node_state.node_run_result
|
||||
self.print_text(
|
||||
f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}",
|
||||
color="green",
|
||||
)
|
||||
self.print_text(
|
||||
f"Process Data: "
|
||||
f"{jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}",
|
||||
color="green",
|
||||
)
|
||||
self.print_text(
|
||||
f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}",
|
||||
color="green",
|
||||
)
|
||||
self.print_text(
|
||||
f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}",
|
||||
color="green",
|
||||
)
|
||||
|
||||
def on_workflow_node_execute_failed(self, event: NodeRunFailedEvent):
|
||||
"""
|
||||
Workflow node execute failed
|
||||
"""
|
||||
route_node_state = event.route_node_state
|
||||
|
||||
self.print_text("\n[NodeRunFailedEvent]", color="red")
|
||||
self.print_text(f"Node ID: {event.node_id}", color="red")
|
||||
self.print_text(f"Node Title: {event.node_data.title}", color="red")
|
||||
self.print_text(f"Type: {event.node_type.value}", color="red")
|
||||
|
||||
if route_node_state.node_run_result:
|
||||
node_run_result = route_node_state.node_run_result
|
||||
self.print_text(f"Error: {node_run_result.error}", color="red")
|
||||
self.print_text(
|
||||
f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}",
|
||||
color="red",
|
||||
)
|
||||
self.print_text(
|
||||
f"Process Data: "
|
||||
f"{jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}",
|
||||
color="red",
|
||||
)
|
||||
self.print_text(
|
||||
f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}",
|
||||
color="red",
|
||||
)
|
||||
|
||||
def on_node_text_chunk(self, event: NodeRunStreamChunkEvent):
|
||||
"""
|
||||
Publish text chunk
|
||||
"""
|
||||
route_node_state = event.route_node_state
|
||||
if not self.current_node_id or self.current_node_id != route_node_state.node_id:
|
||||
self.current_node_id = route_node_state.node_id
|
||||
self.print_text("\n[NodeRunStreamChunkEvent]")
|
||||
self.print_text(f"Node ID: {route_node_state.node_id}")
|
||||
|
||||
node_run_result = route_node_state.node_run_result
|
||||
if node_run_result:
|
||||
self.print_text(
|
||||
f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}"
|
||||
)
|
||||
|
||||
self.print_text(event.chunk_content, color="pink", end="")
|
||||
|
||||
def on_workflow_parallel_started(self, event: ParallelBranchRunStartedEvent):
|
||||
"""
|
||||
Publish parallel started
|
||||
"""
|
||||
self.print_text("\n[ParallelBranchRunStartedEvent]", color="blue")
|
||||
self.print_text(f"Parallel ID: {event.parallel_id}", color="blue")
|
||||
self.print_text(f"Branch ID: {event.parallel_start_node_id}", color="blue")
|
||||
if event.in_iteration_id:
|
||||
self.print_text(f"Iteration ID: {event.in_iteration_id}", color="blue")
|
||||
if event.in_loop_id:
|
||||
self.print_text(f"Loop ID: {event.in_loop_id}", color="blue")
|
||||
|
||||
def on_workflow_parallel_completed(self, event: ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent):
|
||||
"""
|
||||
Publish parallel completed
|
||||
"""
|
||||
if isinstance(event, ParallelBranchRunSucceededEvent):
|
||||
color = "blue"
|
||||
elif isinstance(event, ParallelBranchRunFailedEvent):
|
||||
color = "red"
|
||||
|
||||
self.print_text(
|
||||
"\n[ParallelBranchRunSucceededEvent]"
|
||||
if isinstance(event, ParallelBranchRunSucceededEvent)
|
||||
else "\n[ParallelBranchRunFailedEvent]",
|
||||
color=color,
|
||||
)
|
||||
self.print_text(f"Parallel ID: {event.parallel_id}", color=color)
|
||||
self.print_text(f"Branch ID: {event.parallel_start_node_id}", color=color)
|
||||
if event.in_iteration_id:
|
||||
self.print_text(f"Iteration ID: {event.in_iteration_id}", color=color)
|
||||
if event.in_loop_id:
|
||||
self.print_text(f"Loop ID: {event.in_loop_id}", color=color)
|
||||
|
||||
if isinstance(event, ParallelBranchRunFailedEvent):
|
||||
self.print_text(f"Error: {event.error}", color=color)
|
||||
|
||||
def on_workflow_iteration_started(self, event: IterationRunStartedEvent):
|
||||
"""
|
||||
Publish iteration started
|
||||
"""
|
||||
self.print_text("\n[IterationRunStartedEvent]", color="blue")
|
||||
self.print_text(f"Iteration Node ID: {event.iteration_id}", color="blue")
|
||||
|
||||
def on_workflow_iteration_next(self, event: IterationRunNextEvent):
|
||||
"""
|
||||
Publish iteration next
|
||||
"""
|
||||
self.print_text("\n[IterationRunNextEvent]", color="blue")
|
||||
self.print_text(f"Iteration Node ID: {event.iteration_id}", color="blue")
|
||||
self.print_text(f"Iteration Index: {event.index}", color="blue")
|
||||
|
||||
def on_workflow_iteration_completed(self, event: IterationRunSucceededEvent | IterationRunFailedEvent):
|
||||
"""
|
||||
Publish iteration completed
|
||||
"""
|
||||
self.print_text(
|
||||
"\n[IterationRunSucceededEvent]"
|
||||
if isinstance(event, IterationRunSucceededEvent)
|
||||
else "\n[IterationRunFailedEvent]",
|
||||
color="blue",
|
||||
)
|
||||
self.print_text(f"Node ID: {event.iteration_id}", color="blue")
|
||||
|
||||
def on_workflow_loop_started(self, event: LoopRunStartedEvent):
|
||||
"""
|
||||
Publish loop started
|
||||
"""
|
||||
self.print_text("\n[LoopRunStartedEvent]", color="blue")
|
||||
self.print_text(f"Loop Node ID: {event.loop_node_id}", color="blue")
|
||||
|
||||
def on_workflow_loop_next(self, event: LoopRunNextEvent):
|
||||
"""
|
||||
Publish loop next
|
||||
"""
|
||||
self.print_text("\n[LoopRunNextEvent]", color="blue")
|
||||
self.print_text(f"Loop Node ID: {event.loop_node_id}", color="blue")
|
||||
self.print_text(f"Loop Index: {event.index}", color="blue")
|
||||
|
||||
def on_workflow_loop_completed(self, event: LoopRunSucceededEvent | LoopRunFailedEvent):
|
||||
"""
|
||||
Publish loop completed
|
||||
"""
|
||||
self.print_text(
|
||||
"\n[LoopRunSucceededEvent]" if isinstance(event, LoopRunSucceededEvent) else "\n[LoopRunFailedEvent]",
|
||||
color="blue",
|
||||
)
|
||||
self.print_text(f"Loop Node ID: {event.loop_node_id}", color="blue")
|
||||
|
||||
def print_text(self, text: str, color: str | None = None, end: str = "\n"):
|
||||
"""Print text with highlighting and no end characters."""
|
||||
text_to_print = self._get_colored_text(text, color) if color else text
|
||||
print(f"{text_to_print}", end=end)
|
||||
|
||||
def _get_colored_text(self, text: str, color: str) -> str:
|
||||
"""Get colored text."""
|
||||
color_str = _TEXT_COLOR_MAPPING[color]
|
||||
return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m"
|
||||
@ -0,0 +1,18 @@
|
||||
from .agent import AgentNodeStrategyInit
|
||||
from .graph_init_params import GraphInitParams
|
||||
from .graph_runtime_state import GraphRuntimeState
|
||||
from .run_condition import RunCondition
|
||||
from .variable_pool import VariablePool, VariableValue
|
||||
from .workflow_execution import WorkflowExecution
|
||||
from .workflow_node_execution import WorkflowNodeExecution
|
||||
|
||||
__all__ = [
|
||||
"AgentNodeStrategyInit",
|
||||
"GraphInitParams",
|
||||
"GraphRuntimeState",
|
||||
"RunCondition",
|
||||
"VariablePool",
|
||||
"VariableValue",
|
||||
"WorkflowExecution",
|
||||
"WorkflowNodeExecution",
|
||||
]
|
||||
|
||||
8
api/core/workflow/entities/agent.py
Normal file
8
api/core/workflow/entities/agent.py
Normal file
@ -0,0 +1,8 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class AgentNodeStrategyInit(BaseModel):
|
||||
"""Agent node strategy initialization data."""
|
||||
|
||||
name: str
|
||||
icon: str | None = None
|
||||
@ -3,19 +3,18 @@ from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import WorkflowType
|
||||
|
||||
|
||||
class GraphInitParams(BaseModel):
|
||||
# init params
|
||||
tenant_id: str = Field(..., description="tenant / workspace id")
|
||||
app_id: str = Field(..., description="app id")
|
||||
workflow_type: WorkflowType = Field(..., description="workflow type")
|
||||
workflow_id: str = Field(..., description="workflow id")
|
||||
graph_config: Mapping[str, Any] = Field(..., description="graph config")
|
||||
user_id: str = Field(..., description="user id")
|
||||
user_from: UserFrom = Field(..., description="user from, account or end-user")
|
||||
invoke_from: InvokeFrom = Field(..., description="invoke from, service-api, web-app, explore or debugger")
|
||||
user_from: str = Field(
|
||||
..., description="user from, account or end-user"
|
||||
) # Should be UserFrom enum: 'account' | 'end-user'
|
||||
invoke_from: str = Field(
|
||||
..., description="invoke from, service-api, web-app, explore or debugger"
|
||||
) # Should be InvokeFrom enum: 'service-api' | 'web-app' | 'explore' | 'debugger'
|
||||
call_depth: int = Field(..., description="call depth")
|
||||
160
api/core/workflow/entities/graph_runtime_state.py
Normal file
160
api/core/workflow/entities/graph_runtime_state.py
Normal file
@ -0,0 +1,160 @@
|
||||
from copy import deepcopy
|
||||
|
||||
from pydantic import BaseModel, PrivateAttr
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
|
||||
from .variable_pool import VariablePool
|
||||
|
||||
|
||||
class GraphRuntimeState(BaseModel):
|
||||
# Private attributes to prevent direct modification
|
||||
_variable_pool: VariablePool = PrivateAttr()
|
||||
_start_at: float = PrivateAttr()
|
||||
_total_tokens: int = PrivateAttr(default=0)
|
||||
_llm_usage: LLMUsage = PrivateAttr(default_factory=LLMUsage.empty_usage)
|
||||
_outputs: dict[str, object] = PrivateAttr(default_factory=dict[str, object])
|
||||
_node_run_steps: int = PrivateAttr(default=0)
|
||||
_ready_queue_json: str = PrivateAttr()
|
||||
_graph_execution_json: str = PrivateAttr()
|
||||
_response_coordinator_json: str = PrivateAttr()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
variable_pool: VariablePool,
|
||||
start_at: float,
|
||||
total_tokens: int = 0,
|
||||
llm_usage: LLMUsage | None = None,
|
||||
outputs: dict[str, object] | None = None,
|
||||
node_run_steps: int = 0,
|
||||
ready_queue_json: str = "",
|
||||
graph_execution_json: str = "",
|
||||
response_coordinator_json: str = "",
|
||||
**kwargs: object,
|
||||
):
|
||||
"""Initialize the GraphRuntimeState with validation."""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# Initialize private attributes with validation
|
||||
self._variable_pool = variable_pool
|
||||
|
||||
self._start_at = start_at
|
||||
|
||||
if total_tokens < 0:
|
||||
raise ValueError("total_tokens must be non-negative")
|
||||
self._total_tokens = total_tokens
|
||||
|
||||
if llm_usage is None:
|
||||
llm_usage = LLMUsage.empty_usage()
|
||||
self._llm_usage = llm_usage
|
||||
|
||||
if outputs is None:
|
||||
outputs = {}
|
||||
self._outputs = deepcopy(outputs)
|
||||
|
||||
if node_run_steps < 0:
|
||||
raise ValueError("node_run_steps must be non-negative")
|
||||
self._node_run_steps = node_run_steps
|
||||
|
||||
self._ready_queue_json = ready_queue_json
|
||||
self._graph_execution_json = graph_execution_json
|
||||
self._response_coordinator_json = response_coordinator_json
|
||||
|
||||
@property
|
||||
def variable_pool(self) -> VariablePool:
|
||||
"""Get the variable pool."""
|
||||
return self._variable_pool
|
||||
|
||||
@property
|
||||
def start_at(self) -> float:
|
||||
"""Get the start time."""
|
||||
return self._start_at
|
||||
|
||||
@start_at.setter
|
||||
def start_at(self, value: float) -> None:
|
||||
"""Set the start time."""
|
||||
self._start_at = value
|
||||
|
||||
@property
|
||||
def total_tokens(self) -> int:
|
||||
"""Get the total tokens count."""
|
||||
return self._total_tokens
|
||||
|
||||
@total_tokens.setter
|
||||
def total_tokens(self, value: int):
|
||||
"""Set the total tokens count."""
|
||||
if value < 0:
|
||||
raise ValueError("total_tokens must be non-negative")
|
||||
self._total_tokens = value
|
||||
|
||||
@property
|
||||
def llm_usage(self) -> LLMUsage:
|
||||
"""Get the LLM usage info."""
|
||||
# Return a copy to prevent external modification
|
||||
return self._llm_usage.model_copy()
|
||||
|
||||
@llm_usage.setter
|
||||
def llm_usage(self, value: LLMUsage):
|
||||
"""Set the LLM usage info."""
|
||||
self._llm_usage = value.model_copy()
|
||||
|
||||
@property
|
||||
def outputs(self) -> dict[str, object]:
|
||||
"""Get a copy of the outputs dictionary."""
|
||||
return deepcopy(self._outputs)
|
||||
|
||||
@outputs.setter
|
||||
def outputs(self, value: dict[str, object]) -> None:
|
||||
"""Set the outputs dictionary."""
|
||||
self._outputs = deepcopy(value)
|
||||
|
||||
def set_output(self, key: str, value: object) -> None:
|
||||
"""Set a single output value."""
|
||||
self._outputs[key] = deepcopy(value)
|
||||
|
||||
def get_output(self, key: str, default: object = None) -> object:
|
||||
"""Get a single output value."""
|
||||
return deepcopy(self._outputs.get(key, default))
|
||||
|
||||
def update_outputs(self, updates: dict[str, object]) -> None:
|
||||
"""Update multiple output values."""
|
||||
for key, value in updates.items():
|
||||
self._outputs[key] = deepcopy(value)
|
||||
|
||||
@property
|
||||
def node_run_steps(self) -> int:
|
||||
"""Get the node run steps count."""
|
||||
return self._node_run_steps
|
||||
|
||||
@node_run_steps.setter
|
||||
def node_run_steps(self, value: int) -> None:
|
||||
"""Set the node run steps count."""
|
||||
if value < 0:
|
||||
raise ValueError("node_run_steps must be non-negative")
|
||||
self._node_run_steps = value
|
||||
|
||||
def increment_node_run_steps(self) -> None:
|
||||
"""Increment the node run steps by 1."""
|
||||
self._node_run_steps += 1
|
||||
|
||||
def add_tokens(self, tokens: int) -> None:
|
||||
"""Add tokens to the total count."""
|
||||
if tokens < 0:
|
||||
raise ValueError("tokens must be non-negative")
|
||||
self._total_tokens += tokens
|
||||
|
||||
@property
|
||||
def ready_queue_json(self) -> str:
|
||||
"""Get a copy of the ready queue state."""
|
||||
return self._ready_queue_json
|
||||
|
||||
@property
|
||||
def graph_execution_json(self) -> str:
|
||||
"""Get a copy of the serialized graph execution state."""
|
||||
return self._graph_execution_json
|
||||
|
||||
@property
|
||||
def response_coordinator_json(self) -> str:
|
||||
"""Get a copy of the serialized response coordinator state."""
|
||||
return self._response_coordinator_json
|
||||
@ -1,34 +0,0 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class NodeRunResult(BaseModel):
|
||||
"""
|
||||
Node Run Result.
|
||||
"""
|
||||
|
||||
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING
|
||||
|
||||
inputs: Mapping[str, Any] | None = None # node inputs
|
||||
process_data: Mapping[str, Any] | None = None # process data
|
||||
outputs: Mapping[str, Any] | None = None # node outputs
|
||||
metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None # node metadata
|
||||
llm_usage: LLMUsage | None = None # llm usage
|
||||
|
||||
edge_source_handle: str | None = None # source handle id of node with multiple branches
|
||||
|
||||
error: str | None = None # error message if status is failed
|
||||
error_type: str | None = None # error type if status is failed
|
||||
|
||||
# single step node run retry
|
||||
retry_index: int = 0
|
||||
|
||||
|
||||
class AgentNodeStrategyInit(BaseModel):
|
||||
name: str
|
||||
icon: str | None = None
|
||||
@ -1,12 +0,0 @@
|
||||
from collections.abc import Sequence
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class VariableSelector(BaseModel):
|
||||
"""
|
||||
Variable Selector.
|
||||
"""
|
||||
|
||||
variable: str
|
||||
value_selector: Sequence[str]
|
||||
@ -14,7 +14,7 @@ from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_V
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from factories import variable_factory
|
||||
|
||||
VariableValue = Union[str, int, float, dict, list, File]
|
||||
VariableValue = Union[str, int, float, dict[str, object], list[object], File]
|
||||
|
||||
VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}")
|
||||
|
||||
@ -40,11 +40,11 @@ class VariablePool(BaseModel):
|
||||
)
|
||||
environment_variables: Sequence[VariableUnion] = Field(
|
||||
description="Environment variables.",
|
||||
default_factory=list,
|
||||
default_factory=list[VariableUnion],
|
||||
)
|
||||
conversation_variables: Sequence[VariableUnion] = Field(
|
||||
description="Conversation variables.",
|
||||
default_factory=list,
|
||||
default_factory=list[VariableUnion],
|
||||
)
|
||||
|
||||
def model_post_init(self, context: Any, /):
|
||||
@ -191,7 +191,7 @@ class VariablePool(BaseModel):
|
||||
|
||||
def convert_template(self, template: str, /):
|
||||
parts = VARIABLE_PATTERN.split(template)
|
||||
segments = []
|
||||
segments: list[Segment] = []
|
||||
for part in filter(lambda x: x, parts):
|
||||
if "." in part and (variable := self.get(part.split("."))):
|
||||
segments.append(variable)
|
||||
|
||||
@ -7,31 +7,14 @@ implementation details like tenant_id, app_id, etc.
|
||||
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.enums import WorkflowExecutionStatus, WorkflowType
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
|
||||
class WorkflowType(StrEnum):
|
||||
"""
|
||||
Workflow Type Enum for domain layer
|
||||
"""
|
||||
|
||||
WORKFLOW = "workflow"
|
||||
CHAT = "chat"
|
||||
|
||||
|
||||
class WorkflowExecutionStatus(StrEnum):
|
||||
RUNNING = "running"
|
||||
SUCCEEDED = "succeeded"
|
||||
FAILED = "failed"
|
||||
STOPPED = "stopped"
|
||||
PARTIAL_SUCCEEDED = "partial-succeeded"
|
||||
|
||||
|
||||
class WorkflowExecution(BaseModel):
|
||||
"""
|
||||
Domain model for workflow execution based on WorkflowRun but without
|
||||
|
||||
@ -8,49 +8,11 @@ and don't contain implementation details like tenant_id, app_id, etc.
|
||||
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
|
||||
|
||||
class WorkflowNodeExecutionMetadataKey(StrEnum):
|
||||
"""
|
||||
Node Run Metadata Key.
|
||||
"""
|
||||
|
||||
TOTAL_TOKENS = "total_tokens"
|
||||
TOTAL_PRICE = "total_price"
|
||||
CURRENCY = "currency"
|
||||
TOOL_INFO = "tool_info"
|
||||
AGENT_LOG = "agent_log"
|
||||
ITERATION_ID = "iteration_id"
|
||||
ITERATION_INDEX = "iteration_index"
|
||||
LOOP_ID = "loop_id"
|
||||
LOOP_INDEX = "loop_index"
|
||||
PARALLEL_ID = "parallel_id"
|
||||
PARALLEL_START_NODE_ID = "parallel_start_node_id"
|
||||
PARENT_PARALLEL_ID = "parent_parallel_id"
|
||||
PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id"
|
||||
PARALLEL_MODE_RUN_ID = "parallel_mode_run_id"
|
||||
ITERATION_DURATION_MAP = "iteration_duration_map" # single iteration duration if iteration node runs
|
||||
LOOP_DURATION_MAP = "loop_duration_map" # single loop duration if loop node runs
|
||||
ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field
|
||||
LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output
|
||||
|
||||
|
||||
class WorkflowNodeExecutionStatus(StrEnum):
|
||||
"""
|
||||
Node Execution Status Enum.
|
||||
"""
|
||||
|
||||
RUNNING = "running"
|
||||
SUCCEEDED = "succeeded"
|
||||
FAILED = "failed"
|
||||
EXCEPTION = "exception"
|
||||
RETRY = "retry"
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class WorkflowNodeExecution(BaseModel):
|
||||
|
||||
@ -1,4 +1,12 @@
|
||||
from enum import StrEnum
|
||||
from enum import Enum, StrEnum
|
||||
|
||||
|
||||
class NodeState(Enum):
|
||||
"""State of a node or edge during workflow execution."""
|
||||
|
||||
UNKNOWN = "unknown"
|
||||
TAKEN = "taken"
|
||||
SKIPPED = "skipped"
|
||||
|
||||
|
||||
class SystemVariableKey(StrEnum):
|
||||
@ -14,3 +22,104 @@ class SystemVariableKey(StrEnum):
|
||||
APP_ID = "app_id"
|
||||
WORKFLOW_ID = "workflow_id"
|
||||
WORKFLOW_EXECUTION_ID = "workflow_run_id"
|
||||
|
||||
|
||||
class NodeType(StrEnum):
|
||||
START = "start"
|
||||
END = "end"
|
||||
ANSWER = "answer"
|
||||
LLM = "llm"
|
||||
KNOWLEDGE_RETRIEVAL = "knowledge-retrieval"
|
||||
IF_ELSE = "if-else"
|
||||
CODE = "code"
|
||||
TEMPLATE_TRANSFORM = "template-transform"
|
||||
QUESTION_CLASSIFIER = "question-classifier"
|
||||
HTTP_REQUEST = "http-request"
|
||||
TOOL = "tool"
|
||||
VARIABLE_AGGREGATOR = "variable-aggregator"
|
||||
LEGACY_VARIABLE_AGGREGATOR = "variable-assigner" # TODO: Merge this into VARIABLE_AGGREGATOR in the database.
|
||||
LOOP = "loop"
|
||||
LOOP_START = "loop-start"
|
||||
LOOP_END = "loop-end"
|
||||
ITERATION = "iteration"
|
||||
ITERATION_START = "iteration-start" # Fake start node for iteration.
|
||||
PARAMETER_EXTRACTOR = "parameter-extractor"
|
||||
VARIABLE_ASSIGNER = "assigner"
|
||||
DOCUMENT_EXTRACTOR = "document-extractor"
|
||||
LIST_OPERATOR = "list-operator"
|
||||
AGENT = "agent"
|
||||
|
||||
|
||||
class NodeExecutionType(StrEnum):
|
||||
"""Node execution type classification."""
|
||||
|
||||
EXECUTABLE = "executable" # Regular nodes that execute and produce outputs
|
||||
RESPONSE = "response" # Response nodes that stream outputs (Answer, End)
|
||||
BRANCH = "branch" # Nodes that can choose different branches (if-else, question-classifier)
|
||||
CONTAINER = "container" # Container nodes that manage subgraphs (iteration, loop, graph)
|
||||
ROOT = "root" # Nodes that can serve as execution entry points
|
||||
|
||||
|
||||
class ErrorStrategy(StrEnum):
|
||||
FAIL_BRANCH = "fail-branch"
|
||||
DEFAULT_VALUE = "default-value"
|
||||
|
||||
|
||||
class FailBranchSourceHandle(StrEnum):
|
||||
FAILED = "fail-branch"
|
||||
SUCCESS = "success-branch"
|
||||
|
||||
|
||||
class WorkflowType(StrEnum):
|
||||
"""
|
||||
Workflow Type Enum for domain layer
|
||||
"""
|
||||
|
||||
WORKFLOW = "workflow"
|
||||
CHAT = "chat"
|
||||
|
||||
|
||||
class WorkflowExecutionStatus(StrEnum):
|
||||
RUNNING = "running"
|
||||
SUCCEEDED = "succeeded"
|
||||
FAILED = "failed"
|
||||
STOPPED = "stopped"
|
||||
PARTIAL_SUCCEEDED = "partial-succeeded"
|
||||
|
||||
|
||||
class WorkflowNodeExecutionMetadataKey(StrEnum):
|
||||
"""
|
||||
Node Run Metadata Key.
|
||||
"""
|
||||
|
||||
TOTAL_TOKENS = "total_tokens"
|
||||
TOTAL_PRICE = "total_price"
|
||||
CURRENCY = "currency"
|
||||
TOOL_INFO = "tool_info"
|
||||
AGENT_LOG = "agent_log"
|
||||
ITERATION_ID = "iteration_id"
|
||||
ITERATION_INDEX = "iteration_index"
|
||||
LOOP_ID = "loop_id"
|
||||
LOOP_INDEX = "loop_index"
|
||||
PARALLEL_ID = "parallel_id"
|
||||
PARALLEL_START_NODE_ID = "parallel_start_node_id"
|
||||
PARENT_PARALLEL_ID = "parent_parallel_id"
|
||||
PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id"
|
||||
PARALLEL_MODE_RUN_ID = "parallel_mode_run_id"
|
||||
ITERATION_DURATION_MAP = "iteration_duration_map" # single iteration duration if iteration node runs
|
||||
LOOP_DURATION_MAP = "loop_duration_map" # single loop duration if loop node runs
|
||||
ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field
|
||||
LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output
|
||||
|
||||
|
||||
class WorkflowNodeExecutionStatus(StrEnum):
|
||||
PENDING = "pending" # Node is scheduled but not yet executing
|
||||
RUNNING = "running"
|
||||
SUCCEEDED = "succeeded"
|
||||
FAILED = "failed"
|
||||
EXCEPTION = "exception"
|
||||
STOPPED = "stopped"
|
||||
PAUSED = "paused"
|
||||
|
||||
# Legacy statuses - kept for backward compatibility
|
||||
RETRY = "retry" # Legacy: replaced by retry mechanism in error handling
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.base.node import Node
|
||||
|
||||
|
||||
class WorkflowNodeRunFailedError(Exception):
|
||||
def __init__(self, node: BaseNode, err_msg: str):
|
||||
def __init__(self, node: Node, err_msg: str):
|
||||
self.node = node
|
||||
self.error = err_msg
|
||||
super().__init__(f"Node {node.title} run failed: {err_msg}")
|
||||
|
||||
16
api/core/workflow/graph/__init__.py
Normal file
16
api/core/workflow/graph/__init__.py
Normal file
@ -0,0 +1,16 @@
|
||||
from .edge import Edge
|
||||
from .graph import Graph, NodeFactory
|
||||
from .graph_runtime_state_protocol import ReadOnlyGraphRuntimeState, ReadOnlyVariablePool
|
||||
from .graph_template import GraphTemplate
|
||||
from .read_only_state_wrapper import ReadOnlyGraphRuntimeStateWrapper, ReadOnlyVariablePoolWrapper
|
||||
|
||||
__all__ = [
|
||||
"Edge",
|
||||
"Graph",
|
||||
"GraphTemplate",
|
||||
"NodeFactory",
|
||||
"ReadOnlyGraphRuntimeState",
|
||||
"ReadOnlyGraphRuntimeStateWrapper",
|
||||
"ReadOnlyVariablePool",
|
||||
"ReadOnlyVariablePoolWrapper",
|
||||
]
|
||||
15
api/core/workflow/graph/edge.py
Normal file
15
api/core/workflow/graph/edge.py
Normal file
@ -0,0 +1,15 @@
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from core.workflow.enums import NodeState
|
||||
|
||||
|
||||
@dataclass
|
||||
class Edge:
|
||||
"""Edge connecting two nodes in a workflow graph."""
|
||||
|
||||
id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||
tail: str = "" # tail node id (source)
|
||||
head: str = "" # head node id (target)
|
||||
source_handle: str = "source" # source handle for conditional branching
|
||||
state: NodeState = field(default=NodeState.UNKNOWN) # edge execution state
|
||||
346
api/core/workflow/graph/graph.py
Normal file
346
api/core/workflow/graph/graph.py
Normal file
@ -0,0 +1,346 @@
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Protocol, cast, final
|
||||
|
||||
from core.workflow.enums import NodeExecutionType, NodeState, NodeType
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from libs.typing import is_str, is_str_dict
|
||||
|
||||
from .edge import Edge
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NodeFactory(Protocol):
|
||||
"""
|
||||
Protocol for creating Node instances from node data dictionaries.
|
||||
|
||||
This protocol decouples the Graph class from specific node mapping implementations,
|
||||
allowing for different node creation strategies while maintaining type safety.
|
||||
"""
|
||||
|
||||
def create_node(self, node_config: dict[str, object]) -> Node:
|
||||
"""
|
||||
Create a Node instance from node configuration data.
|
||||
|
||||
:param node_config: node configuration dictionary containing type and other data
|
||||
:return: initialized Node instance
|
||||
:raises ValueError: if node type is unknown or configuration is invalid
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
@final
|
||||
class Graph:
|
||||
"""Graph representation with nodes and edges for workflow execution."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
nodes: dict[str, Node] | None = None,
|
||||
edges: dict[str, Edge] | None = None,
|
||||
in_edges: dict[str, list[str]] | None = None,
|
||||
out_edges: dict[str, list[str]] | None = None,
|
||||
root_node: Node,
|
||||
):
|
||||
"""
|
||||
Initialize Graph instance.
|
||||
|
||||
:param nodes: graph nodes mapping (node id: node object)
|
||||
:param edges: graph edges mapping (edge id: edge object)
|
||||
:param in_edges: incoming edges mapping (node id: list of edge ids)
|
||||
:param out_edges: outgoing edges mapping (node id: list of edge ids)
|
||||
:param root_node: root node object
|
||||
"""
|
||||
self.nodes = nodes or {}
|
||||
self.edges = edges or {}
|
||||
self.in_edges = in_edges or {}
|
||||
self.out_edges = out_edges or {}
|
||||
self.root_node = root_node
|
||||
|
||||
@classmethod
|
||||
def _parse_node_configs(cls, node_configs: list[dict[str, object]]) -> dict[str, dict[str, object]]:
|
||||
"""
|
||||
Parse node configurations and build a mapping of node IDs to configs.
|
||||
|
||||
:param node_configs: list of node configuration dictionaries
|
||||
:return: mapping of node ID to node config
|
||||
"""
|
||||
node_configs_map: dict[str, dict[str, object]] = {}
|
||||
|
||||
for node_config in node_configs:
|
||||
node_id = node_config.get("id")
|
||||
if not node_id or not isinstance(node_id, str):
|
||||
continue
|
||||
|
||||
node_configs_map[node_id] = node_config
|
||||
|
||||
return node_configs_map
|
||||
|
||||
@classmethod
|
||||
def _find_root_node_id(
|
||||
cls,
|
||||
node_configs_map: Mapping[str, Mapping[str, object]],
|
||||
edge_configs: Sequence[Mapping[str, object]],
|
||||
root_node_id: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Find the root node ID if not specified.
|
||||
|
||||
:param node_configs_map: mapping of node ID to node config
|
||||
:param edge_configs: list of edge configurations
|
||||
:param root_node_id: explicitly specified root node ID
|
||||
:return: determined root node ID
|
||||
"""
|
||||
if root_node_id:
|
||||
if root_node_id not in node_configs_map:
|
||||
raise ValueError(f"Root node id {root_node_id} not found in the graph")
|
||||
return root_node_id
|
||||
|
||||
# Find nodes with no incoming edges
|
||||
nodes_with_incoming: set[str] = set()
|
||||
for edge_config in edge_configs:
|
||||
target = edge_config.get("target")
|
||||
if isinstance(target, str):
|
||||
nodes_with_incoming.add(target)
|
||||
|
||||
root_candidates = [nid for nid in node_configs_map if nid not in nodes_with_incoming]
|
||||
|
||||
# Prefer START node if available
|
||||
start_node_id = None
|
||||
for nid in root_candidates:
|
||||
node_data = node_configs_map[nid].get("data")
|
||||
if not is_str_dict(node_data):
|
||||
continue
|
||||
node_type = node_data.get("type")
|
||||
if not isinstance(node_type, str):
|
||||
continue
|
||||
if node_type == NodeType.START:
|
||||
start_node_id = nid
|
||||
break
|
||||
|
||||
root_node_id = start_node_id or (root_candidates[0] if root_candidates else None)
|
||||
|
||||
if not root_node_id:
|
||||
raise ValueError("Unable to determine root node ID")
|
||||
|
||||
return root_node_id
|
||||
|
||||
@classmethod
|
||||
def _build_edges(
|
||||
cls, edge_configs: list[dict[str, object]]
|
||||
) -> tuple[dict[str, Edge], dict[str, list[str]], dict[str, list[str]]]:
|
||||
"""
|
||||
Build edge objects and mappings from edge configurations.
|
||||
|
||||
:param edge_configs: list of edge configurations
|
||||
:return: tuple of (edges dict, in_edges dict, out_edges dict)
|
||||
"""
|
||||
edges: dict[str, Edge] = {}
|
||||
in_edges: dict[str, list[str]] = defaultdict(list)
|
||||
out_edges: dict[str, list[str]] = defaultdict(list)
|
||||
|
||||
edge_counter = 0
|
||||
for edge_config in edge_configs:
|
||||
source = edge_config.get("source")
|
||||
target = edge_config.get("target")
|
||||
|
||||
if not is_str(source) or not is_str(target):
|
||||
continue
|
||||
|
||||
# Create edge
|
||||
edge_id = f"edge_{edge_counter}"
|
||||
edge_counter += 1
|
||||
|
||||
source_handle = edge_config.get("sourceHandle", "source")
|
||||
if not is_str(source_handle):
|
||||
continue
|
||||
|
||||
edge = Edge(
|
||||
id=edge_id,
|
||||
tail=source,
|
||||
head=target,
|
||||
source_handle=source_handle,
|
||||
)
|
||||
|
||||
edges[edge_id] = edge
|
||||
out_edges[source].append(edge_id)
|
||||
in_edges[target].append(edge_id)
|
||||
|
||||
return edges, dict(in_edges), dict(out_edges)
|
||||
|
||||
@classmethod
|
||||
def _create_node_instances(
|
||||
cls,
|
||||
node_configs_map: dict[str, dict[str, object]],
|
||||
node_factory: "NodeFactory",
|
||||
) -> dict[str, Node]:
|
||||
"""
|
||||
Create node instances from configurations using the node factory.
|
||||
|
||||
:param node_configs_map: mapping of node ID to node config
|
||||
:param node_factory: factory for creating node instances
|
||||
:return: mapping of node ID to node instance
|
||||
"""
|
||||
nodes: dict[str, Node] = {}
|
||||
|
||||
for node_id, node_config in node_configs_map.items():
|
||||
try:
|
||||
node_instance = node_factory.create_node(node_config)
|
||||
except Exception:
|
||||
logger.exception("Failed to create node instance for node_id %s", node_id)
|
||||
raise
|
||||
nodes[node_id] = node_instance
|
||||
|
||||
return nodes
|
||||
|
||||
@classmethod
|
||||
def _mark_inactive_root_branches(
|
||||
cls,
|
||||
nodes: dict[str, Node],
|
||||
edges: dict[str, Edge],
|
||||
in_edges: dict[str, list[str]],
|
||||
out_edges: dict[str, list[str]],
|
||||
active_root_id: str,
|
||||
) -> None:
|
||||
"""
|
||||
Mark nodes and edges from inactive root branches as skipped.
|
||||
|
||||
Algorithm:
|
||||
1. Mark inactive root nodes as skipped
|
||||
2. For skipped nodes, mark all their outgoing edges as skipped
|
||||
3. For each edge marked as skipped, check its target node:
|
||||
- If ALL incoming edges are skipped, mark the node as skipped
|
||||
- Otherwise, leave the node state unchanged
|
||||
|
||||
:param nodes: mapping of node ID to node instance
|
||||
:param edges: mapping of edge ID to edge instance
|
||||
:param in_edges: mapping of node ID to incoming edge IDs
|
||||
:param out_edges: mapping of node ID to outgoing edge IDs
|
||||
:param active_root_id: ID of the active root node
|
||||
"""
|
||||
# Find all top-level root nodes (nodes with ROOT execution type and no incoming edges)
|
||||
top_level_roots: list[str] = [
|
||||
node.id for node in nodes.values() if node.execution_type == NodeExecutionType.ROOT
|
||||
]
|
||||
|
||||
# If there's only one root or the active root is not a top-level root, no marking needed
|
||||
if len(top_level_roots) <= 1 or active_root_id not in top_level_roots:
|
||||
return
|
||||
|
||||
# Mark inactive root nodes as skipped
|
||||
inactive_roots: list[str] = [root_id for root_id in top_level_roots if root_id != active_root_id]
|
||||
for root_id in inactive_roots:
|
||||
if root_id in nodes:
|
||||
nodes[root_id].state = NodeState.SKIPPED
|
||||
|
||||
# Recursively mark downstream nodes and edges
|
||||
def mark_downstream(node_id: str) -> None:
|
||||
"""Recursively mark downstream nodes and edges as skipped."""
|
||||
if nodes[node_id].state != NodeState.SKIPPED:
|
||||
return
|
||||
# If this node is skipped, mark all its outgoing edges as skipped
|
||||
out_edge_ids = out_edges.get(node_id, [])
|
||||
for edge_id in out_edge_ids:
|
||||
edge = edges[edge_id]
|
||||
edge.state = NodeState.SKIPPED
|
||||
|
||||
# Check the target node of this edge
|
||||
target_node = nodes[edge.head]
|
||||
in_edge_ids = in_edges.get(target_node.id, [])
|
||||
in_edge_states = [edges[eid].state for eid in in_edge_ids]
|
||||
|
||||
# If all incoming edges are skipped, mark the node as skipped
|
||||
if all(state == NodeState.SKIPPED for state in in_edge_states):
|
||||
target_node.state = NodeState.SKIPPED
|
||||
# Recursively process downstream nodes
|
||||
mark_downstream(target_node.id)
|
||||
|
||||
# Process each inactive root and its downstream nodes
|
||||
for root_id in inactive_roots:
|
||||
mark_downstream(root_id)
|
||||
|
||||
@classmethod
|
||||
def init(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, object],
|
||||
node_factory: "NodeFactory",
|
||||
root_node_id: str | None = None,
|
||||
) -> "Graph":
|
||||
"""
|
||||
Initialize graph
|
||||
|
||||
:param graph_config: graph config containing nodes and edges
|
||||
:param node_factory: factory for creating node instances from config data
|
||||
:param root_node_id: root node id
|
||||
:return: graph instance
|
||||
"""
|
||||
# Parse configs
|
||||
edge_configs = graph_config.get("edges", [])
|
||||
node_configs = graph_config.get("nodes", [])
|
||||
|
||||
edge_configs = cast(list[dict[str, object]], edge_configs)
|
||||
node_configs = cast(list[dict[str, object]], node_configs)
|
||||
|
||||
if not node_configs:
|
||||
raise ValueError("Graph must have at least one node")
|
||||
|
||||
node_configs = [node_config for node_config in node_configs if node_config.get("type", "") != "custom-note"]
|
||||
|
||||
# Parse node configurations
|
||||
node_configs_map = cls._parse_node_configs(node_configs)
|
||||
|
||||
# Find root node
|
||||
root_node_id = cls._find_root_node_id(node_configs_map, edge_configs, root_node_id)
|
||||
|
||||
# Build edges
|
||||
edges, in_edges, out_edges = cls._build_edges(edge_configs)
|
||||
|
||||
# Create node instances
|
||||
nodes = cls._create_node_instances(node_configs_map, node_factory)
|
||||
|
||||
# Get root node instance
|
||||
root_node = nodes[root_node_id]
|
||||
|
||||
# Mark inactive root branches as skipped
|
||||
cls._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, root_node_id)
|
||||
|
||||
# Create and return the graph
|
||||
return cls(
|
||||
nodes=nodes,
|
||||
edges=edges,
|
||||
in_edges=in_edges,
|
||||
out_edges=out_edges,
|
||||
root_node=root_node,
|
||||
)
|
||||
|
||||
@property
|
||||
def node_ids(self) -> list[str]:
|
||||
"""
|
||||
Get list of node IDs (compatibility property for existing code)
|
||||
|
||||
:return: list of node IDs
|
||||
"""
|
||||
return list(self.nodes.keys())
|
||||
|
||||
def get_outgoing_edges(self, node_id: str) -> list[Edge]:
|
||||
"""
|
||||
Get all outgoing edges from a node (V2 method)
|
||||
|
||||
:param node_id: node id
|
||||
:return: list of outgoing edges
|
||||
"""
|
||||
edge_ids = self.out_edges.get(node_id, [])
|
||||
return [self.edges[eid] for eid in edge_ids if eid in self.edges]
|
||||
|
||||
def get_incoming_edges(self, node_id: str) -> list[Edge]:
|
||||
"""
|
||||
Get all incoming edges to a node (V2 method)
|
||||
|
||||
:param node_id: node id
|
||||
:return: list of incoming edges
|
||||
"""
|
||||
edge_ids = self.in_edges.get(node_id, [])
|
||||
return [self.edges[eid] for eid in edge_ids if eid in self.edges]
|
||||
61
api/core/workflow/graph/graph_runtime_state_protocol.py
Normal file
61
api/core/workflow/graph/graph_runtime_state_protocol.py
Normal file
@ -0,0 +1,61 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Protocol
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.variables.segments import Segment
|
||||
|
||||
|
||||
class ReadOnlyVariablePool(Protocol):
|
||||
"""Read-only interface for VariablePool."""
|
||||
|
||||
def get(self, node_id: str, variable_key: str) -> Segment | None:
|
||||
"""Get a variable value (read-only)."""
|
||||
...
|
||||
|
||||
def get_all_by_node(self, node_id: str) -> Mapping[str, object]:
|
||||
"""Get all variables for a node (read-only)."""
|
||||
...
|
||||
|
||||
|
||||
class ReadOnlyGraphRuntimeState(Protocol):
|
||||
"""
|
||||
Read-only view of GraphRuntimeState for layers.
|
||||
|
||||
This protocol defines a read-only interface that prevents layers from
|
||||
modifying the graph runtime state while still allowing observation.
|
||||
All methods return defensive copies to ensure immutability.
|
||||
"""
|
||||
|
||||
@property
|
||||
def variable_pool(self) -> ReadOnlyVariablePool:
|
||||
"""Get read-only access to the variable pool."""
|
||||
...
|
||||
|
||||
@property
|
||||
def start_at(self) -> float:
|
||||
"""Get the start time (read-only)."""
|
||||
...
|
||||
|
||||
@property
|
||||
def total_tokens(self) -> int:
|
||||
"""Get the total tokens count (read-only)."""
|
||||
...
|
||||
|
||||
@property
|
||||
def llm_usage(self) -> LLMUsage:
|
||||
"""Get a copy of LLM usage info (read-only)."""
|
||||
...
|
||||
|
||||
@property
|
||||
def outputs(self) -> dict[str, Any]:
|
||||
"""Get a defensive copy of outputs (read-only)."""
|
||||
...
|
||||
|
||||
@property
|
||||
def node_run_steps(self) -> int:
|
||||
"""Get the node run steps count (read-only)."""
|
||||
...
|
||||
|
||||
def get_output(self, key: str, default: Any = None) -> Any:
|
||||
"""Get a single output value (returns a copy)."""
|
||||
...
|
||||
20
api/core/workflow/graph/graph_template.py
Normal file
20
api/core/workflow/graph/graph_template.py
Normal file
@ -0,0 +1,20 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class GraphTemplate(BaseModel):
|
||||
"""
|
||||
Graph Template for container nodes and subgraph expansion
|
||||
|
||||
According to GraphEngine V2 spec, GraphTemplate contains:
|
||||
- nodes: mapping of node definitions
|
||||
- edges: mapping of edge definitions
|
||||
- root_ids: list of root node IDs
|
||||
- output_selectors: list of output selectors for the template
|
||||
"""
|
||||
|
||||
nodes: dict[str, dict[str, Any]] = Field(default_factory=dict, description="node definitions mapping")
|
||||
edges: dict[str, dict[str, Any]] = Field(default_factory=dict, description="edge definitions mapping")
|
||||
root_ids: list[str] = Field(default_factory=list, description="root node IDs")
|
||||
output_selectors: list[str] = Field(default_factory=list, description="output selectors")
|
||||
77
api/core/workflow/graph/read_only_state_wrapper.py
Normal file
77
api/core/workflow/graph/read_only_state_wrapper.py
Normal file
@ -0,0 +1,77 @@
|
||||
from collections.abc import Mapping
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.variables.segments import Segment
|
||||
from core.workflow.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
|
||||
|
||||
class ReadOnlyVariablePoolWrapper:
|
||||
"""Wrapper that provides read-only access to VariablePool."""
|
||||
|
||||
def __init__(self, variable_pool: VariablePool):
|
||||
self._variable_pool = variable_pool
|
||||
|
||||
def get(self, node_id: str, variable_key: str) -> Segment | None:
|
||||
"""Get a variable value (returns a defensive copy)."""
|
||||
value = self._variable_pool.get([node_id, variable_key])
|
||||
return deepcopy(value) if value is not None else None
|
||||
|
||||
def get_all_by_node(self, node_id: str) -> Mapping[str, object]:
|
||||
"""Get all variables for a node (returns defensive copies)."""
|
||||
variables: dict[str, object] = {}
|
||||
if node_id in self._variable_pool.variable_dictionary:
|
||||
for key, var in self._variable_pool.variable_dictionary[node_id].items():
|
||||
# Variables have a value property that contains the actual data
|
||||
variables[key] = deepcopy(var.value)
|
||||
return variables
|
||||
|
||||
|
||||
class ReadOnlyGraphRuntimeStateWrapper:
|
||||
"""
|
||||
Wrapper that provides read-only access to GraphRuntimeState.
|
||||
|
||||
This wrapper ensures that layers can observe the state without
|
||||
modifying it. All returned values are defensive copies.
|
||||
"""
|
||||
|
||||
def __init__(self, state: GraphRuntimeState):
|
||||
self._state = state
|
||||
self._variable_pool_wrapper = ReadOnlyVariablePoolWrapper(state.variable_pool)
|
||||
|
||||
@property
|
||||
def variable_pool(self) -> ReadOnlyVariablePoolWrapper:
|
||||
"""Get read-only access to the variable pool."""
|
||||
return self._variable_pool_wrapper
|
||||
|
||||
@property
|
||||
def start_at(self) -> float:
|
||||
"""Get the start time (read-only)."""
|
||||
return self._state.start_at
|
||||
|
||||
@property
|
||||
def total_tokens(self) -> int:
|
||||
"""Get the total tokens count (read-only)."""
|
||||
return self._state.total_tokens
|
||||
|
||||
@property
|
||||
def llm_usage(self) -> LLMUsage:
|
||||
"""Get a copy of LLM usage info (read-only)."""
|
||||
# Return a copy to prevent modification
|
||||
return self._state.llm_usage.model_copy()
|
||||
|
||||
@property
|
||||
def outputs(self) -> dict[str, Any]:
|
||||
"""Get a defensive copy of outputs (read-only)."""
|
||||
return deepcopy(self._state.outputs)
|
||||
|
||||
@property
|
||||
def node_run_steps(self) -> int:
|
||||
"""Get the node run steps count (read-only)."""
|
||||
return self._state.node_run_steps
|
||||
|
||||
def get_output(self, key: str, default: Any = None) -> Any:
|
||||
"""Get a single output value (returns a copy)."""
|
||||
return self._state.get_output(key, default)
|
||||
@ -1,4 +1,3 @@
|
||||
from .entities import Graph, GraphInitParams, GraphRuntimeState, RuntimeRouteState
|
||||
from .graph_engine import GraphEngine
|
||||
|
||||
__all__ = ["Graph", "GraphEngine", "GraphInitParams", "GraphRuntimeState", "RuntimeRouteState"]
|
||||
__all__ = ["GraphEngine"]
|
||||
|
||||
33
api/core/workflow/graph_engine/command_channels/README.md
Normal file
33
api/core/workflow/graph_engine/command_channels/README.md
Normal file
@ -0,0 +1,33 @@
|
||||
# Command Channels
|
||||
|
||||
Channel implementations for external workflow control.
|
||||
|
||||
## Components
|
||||
|
||||
### InMemoryChannel
|
||||
|
||||
Thread-safe in-memory queue for single-process deployments.
|
||||
|
||||
- `fetch_commands()` - Get pending commands
|
||||
- `send_command()` - Add command to queue
|
||||
|
||||
### RedisChannel
|
||||
|
||||
Redis-based queue for distributed deployments.
|
||||
|
||||
- `fetch_commands()` - Get commands with JSON deserialization
|
||||
- `send_command()` - Store commands with TTL
|
||||
|
||||
## Usage
|
||||
|
||||
```python
|
||||
# Local execution
|
||||
channel = InMemoryChannel()
|
||||
channel.send_command(AbortCommand(graph_id="workflow-123"))
|
||||
|
||||
# Distributed execution
|
||||
redis_channel = RedisChannel(
|
||||
redis_client=redis_client,
|
||||
channel_key="workflow:123:commands"
|
||||
)
|
||||
```
|
||||
@ -0,0 +1,6 @@
|
||||
"""Command channel implementations for GraphEngine."""
|
||||
|
||||
from .in_memory_channel import InMemoryChannel
|
||||
from .redis_channel import RedisChannel
|
||||
|
||||
__all__ = ["InMemoryChannel", "RedisChannel"]
|
||||
@ -0,0 +1,53 @@
|
||||
"""
|
||||
In-memory implementation of CommandChannel for local/testing scenarios.
|
||||
|
||||
This implementation uses a thread-safe queue for command communication
|
||||
within a single process. Each instance handles commands for one workflow execution.
|
||||
"""
|
||||
|
||||
from queue import Queue
|
||||
from typing import final
|
||||
|
||||
from ..entities.commands import GraphEngineCommand
|
||||
|
||||
|
||||
@final
|
||||
class InMemoryChannel:
|
||||
"""
|
||||
In-memory command channel implementation using a thread-safe queue.
|
||||
|
||||
Each instance is dedicated to a single GraphEngine/workflow execution.
|
||||
Suitable for local development, testing, and single-instance deployments.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the in-memory channel with a single queue."""
|
||||
self._queue: Queue[GraphEngineCommand] = Queue()
|
||||
|
||||
def fetch_commands(self) -> list[GraphEngineCommand]:
|
||||
"""
|
||||
Fetch all pending commands from the queue.
|
||||
|
||||
Returns:
|
||||
List of pending commands (drains the queue)
|
||||
"""
|
||||
commands: list[GraphEngineCommand] = []
|
||||
|
||||
# Drain all available commands from the queue
|
||||
while not self._queue.empty():
|
||||
try:
|
||||
command = self._queue.get_nowait()
|
||||
commands.append(command)
|
||||
except Exception:
|
||||
break
|
||||
|
||||
return commands
|
||||
|
||||
def send_command(self, command: GraphEngineCommand) -> None:
|
||||
"""
|
||||
Send a command to this channel's queue.
|
||||
|
||||
Args:
|
||||
command: The command to send
|
||||
"""
|
||||
self._queue.put(command)
|
||||
114
api/core/workflow/graph_engine/command_channels/redis_channel.py
Normal file
114
api/core/workflow/graph_engine/command_channels/redis_channel.py
Normal file
@ -0,0 +1,114 @@
|
||||
"""
|
||||
Redis-based implementation of CommandChannel for distributed scenarios.
|
||||
|
||||
This implementation uses Redis lists for command queuing, supporting
|
||||
multi-instance deployments and cross-server communication.
|
||||
Each instance uses a unique key for its command queue.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, final
|
||||
|
||||
from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from extensions.ext_redis import RedisClientWrapper
|
||||
|
||||
|
||||
@final
|
||||
class RedisChannel:
|
||||
"""
|
||||
Redis-based command channel implementation for distributed systems.
|
||||
|
||||
Each instance uses a unique Redis key for its command queue.
|
||||
Commands are JSON-serialized for transport.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
redis_client: "RedisClientWrapper",
|
||||
channel_key: str,
|
||||
command_ttl: int = 3600,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the Redis channel.
|
||||
|
||||
Args:
|
||||
redis_client: Redis client instance
|
||||
channel_key: Unique key for this channel's command queue
|
||||
command_ttl: TTL for command keys in seconds (default: 3600)
|
||||
"""
|
||||
self._redis = redis_client
|
||||
self._key = channel_key
|
||||
self._command_ttl = command_ttl
|
||||
|
||||
def fetch_commands(self) -> list[GraphEngineCommand]:
|
||||
"""
|
||||
Fetch all pending commands from Redis.
|
||||
|
||||
Returns:
|
||||
List of pending commands (drains the Redis list)
|
||||
"""
|
||||
commands: list[GraphEngineCommand] = []
|
||||
|
||||
# Use pipeline for atomic operations
|
||||
with self._redis.pipeline() as pipe:
|
||||
# Get all commands and clear the list atomically
|
||||
pipe.lrange(self._key, 0, -1)
|
||||
pipe.delete(self._key)
|
||||
results = pipe.execute()
|
||||
|
||||
# Parse commands from JSON
|
||||
if results[0]:
|
||||
for command_json in results[0]:
|
||||
try:
|
||||
command_data = json.loads(command_json)
|
||||
command = self._deserialize_command(command_data)
|
||||
if command:
|
||||
commands.append(command)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
# Skip invalid commands
|
||||
continue
|
||||
|
||||
return commands
|
||||
|
||||
def send_command(self, command: GraphEngineCommand) -> None:
|
||||
"""
|
||||
Send a command to Redis.
|
||||
|
||||
Args:
|
||||
command: The command to send
|
||||
"""
|
||||
command_json = json.dumps(command.model_dump())
|
||||
|
||||
# Push to list and set expiry
|
||||
with self._redis.pipeline() as pipe:
|
||||
pipe.rpush(self._key, command_json)
|
||||
pipe.expire(self._key, self._command_ttl)
|
||||
pipe.execute()
|
||||
|
||||
def _deserialize_command(self, data: dict[str, Any]) -> GraphEngineCommand | None:
|
||||
"""
|
||||
Deserialize a command from dictionary data.
|
||||
|
||||
Args:
|
||||
data: Command data dictionary
|
||||
|
||||
Returns:
|
||||
Deserialized command or None if invalid
|
||||
"""
|
||||
command_type_value = data.get("command_type")
|
||||
if not isinstance(command_type_value, str):
|
||||
return None
|
||||
|
||||
try:
|
||||
command_type = CommandType(command_type_value)
|
||||
|
||||
if command_type == CommandType.ABORT:
|
||||
return AbortCommand(**data)
|
||||
else:
|
||||
# For other command types, use base class
|
||||
return GraphEngineCommand(**data)
|
||||
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
@ -0,0 +1,14 @@
|
||||
"""
|
||||
Command processing subsystem for graph engine.
|
||||
|
||||
This package handles external commands sent to the engine
|
||||
during execution.
|
||||
"""
|
||||
|
||||
from .command_handlers import AbortCommandHandler
|
||||
from .command_processor import CommandProcessor
|
||||
|
||||
__all__ = [
|
||||
"AbortCommandHandler",
|
||||
"CommandProcessor",
|
||||
]
|
||||
@ -0,0 +1,32 @@
|
||||
"""
|
||||
Command handler implementations.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import final
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from ..domain.graph_execution import GraphExecution
|
||||
from ..entities.commands import AbortCommand, GraphEngineCommand
|
||||
from .command_processor import CommandHandler
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@final
|
||||
class AbortCommandHandler(CommandHandler):
|
||||
"""Handles abort commands."""
|
||||
|
||||
@override
|
||||
def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None:
|
||||
"""
|
||||
Handle an abort command.
|
||||
|
||||
Args:
|
||||
command: The abort command
|
||||
execution: Graph execution to abort
|
||||
"""
|
||||
assert isinstance(command, AbortCommand)
|
||||
logger.debug("Aborting workflow %s: %s", execution.workflow_id, command.reason)
|
||||
execution.abort(command.reason or "User requested abort")
|
||||
@ -0,0 +1,79 @@
|
||||
"""
|
||||
Main command processor for handling external commands.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Protocol, final
|
||||
|
||||
from ..domain.graph_execution import GraphExecution
|
||||
from ..entities.commands import GraphEngineCommand
|
||||
from ..protocols.command_channel import CommandChannel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CommandHandler(Protocol):
|
||||
"""Protocol for command handlers."""
|
||||
|
||||
def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None: ...
|
||||
|
||||
|
||||
@final
|
||||
class CommandProcessor:
|
||||
"""
|
||||
Processes external commands sent to the engine.
|
||||
|
||||
This polls the command channel and dispatches commands to
|
||||
appropriate handlers.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
command_channel: CommandChannel,
|
||||
graph_execution: GraphExecution,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the command processor.
|
||||
|
||||
Args:
|
||||
command_channel: Channel for receiving commands
|
||||
graph_execution: Graph execution aggregate
|
||||
"""
|
||||
self._command_channel = command_channel
|
||||
self._graph_execution = graph_execution
|
||||
self._handlers: dict[type[GraphEngineCommand], CommandHandler] = {}
|
||||
|
||||
def register_handler(self, command_type: type[GraphEngineCommand], handler: CommandHandler) -> None:
|
||||
"""
|
||||
Register a handler for a command type.
|
||||
|
||||
Args:
|
||||
command_type: Type of command to handle
|
||||
handler: Handler for the command
|
||||
"""
|
||||
self._handlers[command_type] = handler
|
||||
|
||||
def process_commands(self) -> None:
|
||||
"""Check for and process any pending commands."""
|
||||
try:
|
||||
commands = self._command_channel.fetch_commands()
|
||||
for command in commands:
|
||||
self._handle_command(command)
|
||||
except Exception as e:
|
||||
logger.warning("Error processing commands: %s", e)
|
||||
|
||||
def _handle_command(self, command: GraphEngineCommand) -> None:
|
||||
"""
|
||||
Handle a single command.
|
||||
|
||||
Args:
|
||||
command: The command to handle
|
||||
"""
|
||||
handler = self._handlers.get(type(command))
|
||||
if handler:
|
||||
try:
|
||||
handler.handle(command, self._graph_execution)
|
||||
except Exception:
|
||||
logger.exception("Error handling command %s", command.__class__.__name__)
|
||||
else:
|
||||
logger.warning("No handler registered for command: %s", command.__class__.__name__)
|
||||
@ -1,25 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.graph_engine.entities.run_condition import RunCondition
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||
|
||||
|
||||
class RunConditionHandler(ABC):
|
||||
def __init__(self, init_params: GraphInitParams, graph: Graph, condition: RunCondition):
|
||||
self.init_params = init_params
|
||||
self.graph = graph
|
||||
self.condition = condition
|
||||
|
||||
@abstractmethod
|
||||
def check(self, graph_runtime_state: GraphRuntimeState, previous_route_node_state: RouteNodeState) -> bool:
|
||||
"""
|
||||
Check if the condition can be executed
|
||||
|
||||
:param graph_runtime_state: graph runtime state
|
||||
:param previous_route_node_state: previous route node state
|
||||
:return: bool
|
||||
"""
|
||||
raise NotImplementedError
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user