Merge branch 'main' into feat/r2

# Conflicts:
#	api/core/repositories/sqlalchemy_workflow_node_execution_repository.py
#	api/core/workflow/entities/node_entities.py
#	api/core/workflow/enums.py
This commit is contained in:
jyong
2025-06-03 18:56:49 +08:00
277 changed files with 4440 additions and 1927 deletions

View File

@ -1,37 +1,10 @@
from collections.abc import Mapping
from enum import StrEnum
from typing import Any, Optional
from pydantic import BaseModel
from core.model_runtime.entities.llm_entities import LLMUsage
from models.workflow import WorkflowNodeExecutionStatus
class NodeRunMetadataKey(StrEnum):
"""
Node Run Metadata Key.
"""
TOTAL_TOKENS = "total_tokens"
TOTAL_PRICE = "total_price"
CURRENCY = "currency"
TOOL_INFO = "tool_info"
DATASOURCE_INFO = "datasource_info"
AGENT_LOG = "agent_log"
ITERATION_ID = "iteration_id"
ITERATION_INDEX = "iteration_index"
LOOP_ID = "loop_id"
LOOP_INDEX = "loop_index"
PARALLEL_ID = "parallel_id"
PARALLEL_START_NODE_ID = "parallel_start_node_id"
PARENT_PARALLEL_ID = "parent_parallel_id"
PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id"
PARALLEL_MODE_RUN_ID = "parallel_mode_run_id"
ITERATION_DURATION_MAP = "iteration_duration_map" # single iteration duration if iteration node runs
LOOP_DURATION_MAP = "loop_duration_map" # single loop duration if loop node runs
ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field
LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
class NodeRunResult(BaseModel):
@ -44,7 +17,7 @@ class NodeRunResult(BaseModel):
inputs: Optional[Mapping[str, Any]] = None # node inputs
process_data: Optional[Mapping[str, Any]] = None # process data
outputs: Optional[Mapping[str, Any]] = None # node outputs
metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None # node metadata
metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None # node metadata
llm_usage: Optional[LLMUsage] = None # llm usage
edge_source_handle: Optional[str] = None # source handle id of node with multiple branches

View File

@ -37,12 +37,10 @@ class WorkflowExecution(BaseModel):
user, tenant, and app attributes.
"""
id: str = Field(...)
id_: str = Field(...)
workflow_id: str = Field(...)
workflow_version: str = Field(...)
sequence_number: int = Field(...)
type: WorkflowType = Field(...)
workflow_type: WorkflowType = Field(...)
graph: Mapping[str, Any] = Field(...)
inputs: Mapping[str, Any] = Field(...)
@ -70,20 +68,18 @@ class WorkflowExecution(BaseModel):
def new(
cls,
*,
id: str,
id_: str,
workflow_id: str,
sequence_number: int,
type: WorkflowType,
workflow_type: WorkflowType,
workflow_version: str,
graph: Mapping[str, Any],
inputs: Mapping[str, Any],
started_at: datetime,
) -> "WorkflowExecution":
return WorkflowExecution(
id=id,
id_=id_,
workflow_id=workflow_id,
sequence_number=sequence_number,
type=type,
workflow_type=workflow_type,
workflow_version=workflow_version,
graph=graph,
inputs=inputs,

View File

@ -13,11 +13,35 @@ from typing import Any, Optional
from pydantic import BaseModel, Field
from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.nodes.enums import NodeType
class NodeExecutionStatus(StrEnum):
class WorkflowNodeExecutionMetadataKey(StrEnum):
"""
Node Run Metadata Key.
"""
TOTAL_TOKENS = "total_tokens"
TOTAL_PRICE = "total_price"
CURRENCY = "currency"
TOOL_INFO = "tool_info"
AGENT_LOG = "agent_log"
ITERATION_ID = "iteration_id"
ITERATION_INDEX = "iteration_index"
LOOP_ID = "loop_id"
LOOP_INDEX = "loop_index"
PARALLEL_ID = "parallel_id"
PARALLEL_START_NODE_ID = "parallel_start_node_id"
PARENT_PARALLEL_ID = "parent_parallel_id"
PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id"
PARALLEL_MODE_RUN_ID = "parallel_mode_run_id"
ITERATION_DURATION_MAP = "iteration_duration_map" # single iteration duration if iteration node runs
LOOP_DURATION_MAP = "loop_duration_map" # single loop duration if loop node runs
ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field
LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output
class WorkflowNodeExecutionStatus(StrEnum):
"""
Node Execution Status Enum.
"""
@ -29,7 +53,7 @@ class NodeExecutionStatus(StrEnum):
RETRY = "retry"
class NodeExecution(BaseModel):
class WorkflowNodeExecution(BaseModel):
"""
Domain model for workflow node execution.
@ -46,7 +70,7 @@ class NodeExecution(BaseModel):
id: str # Unique identifier for this execution record
node_execution_id: Optional[str] = None # Optional secondary ID for cross-referencing
workflow_id: str # ID of the workflow this node belongs to
workflow_run_id: Optional[str] = None # ID of the specific workflow run (null for single-step debugging)
workflow_execution_id: Optional[str] = None # ID of the specific workflow run (null for single-step debugging)
# Execution positioning and flow
index: int # Sequence number for ordering in trace visualization
@ -61,12 +85,12 @@ class NodeExecution(BaseModel):
outputs: Optional[Mapping[str, Any]] = None # Output variables produced by this node
# Execution state
status: NodeExecutionStatus = NodeExecutionStatus.RUNNING # Current execution status
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING # Current execution status
error: Optional[str] = None # Error message if execution failed
elapsed_time: float = Field(default=0.0) # Time taken for execution in seconds
# Additional metadata
metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None # Execution metadata (tokens, cost, etc.)
metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None # Execution metadata (tokens, cost, etc.)
# Timing information
created_at: datetime # When execution started
@ -77,7 +101,7 @@ class NodeExecution(BaseModel):
inputs: Optional[Mapping[str, Any]] = None,
process_data: Optional[Mapping[str, Any]] = None,
outputs: Optional[Mapping[str, Any]] = None,
metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None,
metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None,
) -> None:
"""
Update the model from mappings.

View File

@ -13,7 +13,7 @@ class SystemVariableKey(StrEnum):
DIALOGUE_COUNT = "dialogue_count"
APP_ID = "app_id"
WORKFLOW_ID = "workflow_id"
WORKFLOW_RUN_ID = "workflow_run_id"
WORKFLOW_EXECUTION_ID = "workflow_run_id"
# RAG Pipeline
DOCUMENT_ID = "document_id"
BATCH = "batch"

View File

@ -1,9 +1,10 @@
from collections.abc import Mapping
from collections.abc import Mapping, Sequence
from datetime import datetime
from typing import Any, Optional
from pydantic import BaseModel, Field
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities.node_entities import AgentNodeStrategyInit
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.nodes import NodeType
@ -82,7 +83,7 @@ class NodeRunStreamChunkEvent(BaseNodeEvent):
class NodeRunRetrieverResourceEvent(BaseNodeEvent):
retriever_resources: list[dict] = Field(..., description="retriever resources")
retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources")
context: str = Field(..., description="context")

View File

@ -6,7 +6,7 @@ from typing import Optional
from pydantic import BaseModel, Field
from core.workflow.entities.node_entities import NodeRunResult
from models.workflow import WorkflowNodeExecutionStatus
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
class RouteNodeState(BaseModel):

View File

@ -14,8 +14,9 @@ from flask import Flask, current_app, has_request_context
from configs import dify_config
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunMetadataKey, NodeRunResult
from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunResult
from core.workflow.entities.variable_pool import VariablePool, VariableValue
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
from core.workflow.graph_engine.entities.event import (
BaseAgentEvent,
@ -52,9 +53,8 @@ from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from extensions.ext_database import db
from models.enums import UserFrom
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
from models.workflow import WorkflowType
logger = logging.getLogger(__name__)
@ -606,8 +606,6 @@ class GraphEngine:
error=str(e),
)
)
finally:
db.session.remove()
def _run_node(
self,
@ -645,7 +643,6 @@ class GraphEngine:
agent_strategy=agent_strategy,
)
db.session.close()
max_retries = node_instance.node_data.retry_config.max_retries
retry_interval = node_instance.node_data.retry_config.retry_interval_seconds
retries = 0
@ -759,10 +756,12 @@ class GraphEngine:
and node_instance.node_data.error_strategy is ErrorStrategy.FAIL_BRANCH
):
run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS
if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
if run_result.metadata and run_result.metadata.get(
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS
):
# plus state total_tokens
self.graph_runtime_state.total_tokens += int(
run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) # type: ignore[arg-type]
run_result.metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) # type: ignore[arg-type]
)
if run_result.llm_usage:
@ -785,13 +784,17 @@ class GraphEngine:
if parallel_id and parallel_start_node_id:
metadata_dict = dict(run_result.metadata)
metadata_dict[NodeRunMetadataKey.PARALLEL_ID] = parallel_id
metadata_dict[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = parallel_start_node_id
metadata_dict[WorkflowNodeExecutionMetadataKey.PARALLEL_ID] = parallel_id
metadata_dict[WorkflowNodeExecutionMetadataKey.PARALLEL_START_NODE_ID] = (
parallel_start_node_id
)
if parent_parallel_id and parent_parallel_start_node_id:
metadata_dict[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id
metadata_dict[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = (
parent_parallel_start_node_id
metadata_dict[WorkflowNodeExecutionMetadataKey.PARENT_PARALLEL_ID] = (
parent_parallel_id
)
metadata_dict[
WorkflowNodeExecutionMetadataKey.PARENT_PARALLEL_START_NODE_ID
] = parent_parallel_start_node_id
run_result.metadata = metadata_dict
yield NodeRunSucceededEvent(
@ -856,8 +859,6 @@ class GraphEngine:
except Exception as e:
logger.exception(f"Node {node_instance.node_data.title} run failed")
raise e
finally:
db.session.close()
def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue):
"""
@ -923,7 +924,7 @@ class GraphEngine:
"error": error_result.error,
"inputs": error_result.inputs,
"metadata": {
NodeRunMetadataKey.ERROR_STRATEGY: node_instance.node_data.error_strategy,
WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: node_instance.node_data.error_strategy,
},
}

View File

@ -2,6 +2,9 @@ import json
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.agent.entities import AgentToolEntity
from core.agent.plugin_entities import AgentStrategyParameter
from core.memory.token_buffer_memory import TokenBufferMemory
@ -15,6 +18,7 @@ from core.tools.tool_manager import ToolManager
from core.variables.segments import StringSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated
from core.workflow.nodes.base.entities import BaseNodeData
@ -25,7 +29,6 @@ from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
from factories.agent_factory import get_plugin_agent_strategy
from models.model import Conversation
from models.workflow import WorkflowNodeExecutionStatus
class AgentNode(ToolNode):
@ -320,15 +323,12 @@ class AgentNode(ToolNode):
return None
conversation_id = conversation_id_variable.value
# get conversation
conversation = (
db.session.query(Conversation)
.filter(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
.first()
)
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(Conversation).where(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
conversation = session.scalar(stmt)
if not conversation:
return None
if not conversation:
return None
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)

View File

@ -3,6 +3,7 @@ from typing import Any, cast
from core.variables import ArrayFileSegment, FileSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
from core.workflow.nodes.answer.entities import (
AnswerNodeData,
@ -13,7 +14,6 @@ from core.workflow.nodes.answer.entities import (
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from models.workflow import WorkflowNodeExecutionStatus
class AnswerNode(BaseNode[AnswerNodeData]):

View File

@ -4,9 +4,9 @@ from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, RETRY_ON_ERROR_NODE_TYPE, NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from models.workflow import WorkflowNodeExecutionStatus
from .entities import BaseNodeData

View File

@ -8,10 +8,10 @@ from core.helper.code_executor.javascript.javascript_code_provider import Javasc
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
from core.variables.segments import ArrayFileSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.code.entities import CodeNodeData
from core.workflow.nodes.enums import NodeType
from models.workflow import WorkflowNodeExecutionStatus
from .exc import (
CodeNodeError,

View File

@ -26,9 +26,9 @@ from core.helper import ssrf_proxy
from core.variables import ArrayFileSegment
from core.variables.segments import FileSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from models.workflow import WorkflowNodeExecutionStatus
from .entities import DocumentExtractorNodeData
from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError

View File

@ -1,8 +1,8 @@
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.end.entities import EndNodeData
from core.workflow.nodes.enums import NodeType
from models.workflow import WorkflowNodeExecutionStatus
class EndNode(BaseNode[EndNodeData]):

View File

@ -1,10 +1,12 @@
from collections.abc import Sequence
from datetime import datetime
from pydantic import BaseModel, Field
from core.model_runtime.entities.llm_entities import LLMUsage
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities.node_entities import NodeRunResult
from models.workflow import WorkflowNodeExecutionStatus
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
class RunCompletedEvent(BaseModel):
@ -17,7 +19,7 @@ class RunStreamChunkEvent(BaseModel):
class RunRetrieverResourceEvent(BaseModel):
retriever_resources: list[dict] = Field(..., description="retriever resources")
retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources")
context: str = Field(..., description="context")

View File

@ -8,12 +8,12 @@ from core.file import File, FileTransferMethod
from core.tools.tool_file_manager import ToolFileManager
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.http_request.executor import Executor
from core.workflow.utils import variable_template_parser
from factories import file_factory
from models.workflow import WorkflowNodeExecutionStatus
from .entities import (
HttpRequestNodeData,

View File

@ -4,12 +4,12 @@ from typing_extensions import deprecated
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.if_else.entities import IfElseNodeData
from core.workflow.utils.condition.entities import Condition
from core.workflow.utils.condition.processor import ConditionProcessor
from models.workflow import WorkflowNodeExecutionStatus
class IfElseNode(BaseNode[IfElseNodeData]):

View File

@ -12,10 +12,10 @@ from flask import Flask, current_app, has_request_context
from configs import dify_config
from core.variables import ArrayVariable, IntegerVariable, NoneVariable
from core.workflow.entities.node_entities import (
NodeRunMetadataKey,
NodeRunResult,
)
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.graph_engine.entities.event import (
BaseGraphEvent,
BaseNodeEvent,
@ -37,7 +37,6 @@ from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
from models.workflow import WorkflowNodeExecutionStatus
from .exc import (
InvalidIteratorValueError,
@ -249,8 +248,8 @@ class IterationNode(BaseNode[IterationNodeData]):
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"output": outputs},
metadata={
NodeRunMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
},
)
)
@ -361,16 +360,16 @@ class IterationNode(BaseNode[IterationNodeData]):
event.parallel_mode_run_id = parallel_mode_run_id
iter_metadata = {
NodeRunMetadataKey.ITERATION_ID: self.node_id,
NodeRunMetadataKey.ITERATION_INDEX: iter_run_index,
WorkflowNodeExecutionMetadataKey.ITERATION_ID: self.node_id,
WorkflowNodeExecutionMetadataKey.ITERATION_INDEX: iter_run_index,
}
if parallel_mode_run_id:
# for parallel, the specific branch ID is more important than the sequential index
iter_metadata[NodeRunMetadataKey.PARALLEL_MODE_RUN_ID] = parallel_mode_run_id
iter_metadata[WorkflowNodeExecutionMetadataKey.PARALLEL_MODE_RUN_ID] = parallel_mode_run_id
if event.route_node_state.node_run_result:
current_metadata = event.route_node_state.node_run_result.metadata or {}
if NodeRunMetadataKey.ITERATION_ID not in current_metadata:
if WorkflowNodeExecutionMetadataKey.ITERATION_ID not in current_metadata:
event.route_node_state.node_run_result.metadata = {**current_metadata, **iter_metadata}
return event

View File

@ -1,8 +1,8 @@
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.iteration.entities import IterationStartNodeData
from models.workflow import WorkflowNodeExecutionStatus
class IterationStartNode(BaseNode[IterationStartNodeData]):

View File

@ -8,6 +8,7 @@ from typing import Any, Optional, cast
from sqlalchemy import Float, and_, func, or_, text
from sqlalchemy import cast as sqlalchemy_cast
from sqlalchemy.orm import Session
from core.app.app_config.entities import DatasetRetrieveConfigEntity
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
@ -24,6 +25,7 @@ from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.variables import StringSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event.event import ModelInvokeCompletedEvent
from core.workflow.nodes.knowledge_retrieval.template_prompts import (
@ -41,7 +43,6 @@ from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.json_in_md_parser import parse_and_check_json_markdown
from models.dataset import Dataset, DatasetMetadata, Document, RateLimitLog
from models.workflow import WorkflowNodeExecutionStatus
from services.feature_service import FeatureService
from .entities import KnowledgeRetrievalNodeData, ModelConfig
@ -95,14 +96,15 @@ class KnowledgeRetrievalNode(LLMNode):
redis_client.zremrangebyscore(key, 0, current_time - 60000)
request_count = redis_client.zcard(key)
if request_count > knowledge_rate_limit.limit:
# add ratelimit record
rate_limit_log = RateLimitLog(
tenant_id=self.tenant_id,
subscription_plan=knowledge_rate_limit.subscription_plan,
operation="knowledge",
)
db.session.add(rate_limit_log)
db.session.commit()
with Session(db.engine) as session:
# add ratelimit record
rate_limit_log = RateLimitLog(
tenant_id=self.tenant_id,
subscription_plan=knowledge_rate_limit.subscription_plan,
operation="knowledge",
)
session.add(rate_limit_log)
session.commit()
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,
@ -173,7 +175,9 @@ class KnowledgeRetrievalNode(LLMNode):
dataset_retrieval = DatasetRetrieval()
if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value:
# fetch model config
model_instance, model_config = self._fetch_model_config(node_data.single_retrieval_config.model) # type: ignore
if node_data.single_retrieval_config is None:
raise ValueError("single_retrieval_config is required")
model_instance, model_config = self.get_model_config(node_data.single_retrieval_config.model)
# check model is support tool calling
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
@ -424,7 +428,7 @@ class KnowledgeRetrievalNode(LLMNode):
raise ValueError("metadata_model_config is required")
# get metadata model instance
# fetch model config
model_instance, model_config = self._fetch_model_config(node_data.metadata_model_config) # type: ignore
model_instance, model_config = self.get_model_config(metadata_model_config)
# fetch prompt messages
prompt_template = self._get_prompt_template(
node_data=node_data,
@ -550,14 +554,7 @@ class KnowledgeRetrievalNode(LLMNode):
variable_mapping[node_id + ".query"] = node_data.query_variable_selector
return variable_mapping
def _fetch_model_config(self, model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: # type: ignore
"""
Fetch model config
:param model: model
:return:
"""
if model is None:
raise ValueError("model is required")
def get_model_config(self, model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
model_name = model.name
provider_name = model.provider

View File

@ -4,9 +4,9 @@ from typing import Any, Literal, Union
from core.file import File
from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from models.workflow import WorkflowNodeExecutionStatus
from .entities import ListOperatorNodeData
from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError

View File

@ -7,6 +7,8 @@ from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, Optional, cast
import json_repair
from sqlalchemy import select, update
from sqlalchemy.orm import Session
from configs import dify_config
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
@ -43,6 +45,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin import ModelProviderID
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.variables import (
ArrayAnySegment,
ArrayFileSegment,
@ -53,9 +56,10 @@ from core.variables import (
StringSegment,
)
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.nodes.base import BaseNode
@ -77,7 +81,6 @@ from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
from models.model import Conversation
from models.provider import Provider, ProviderType
from models.workflow import WorkflowNodeExecutionStatus
from .entities import (
LLMNodeChatModelMessage,
@ -267,9 +270,9 @@ class LLMNode(BaseNode[LLMNodeData]):
process_data=process_data,
outputs=outputs,
metadata={
NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
NodeRunMetadataKey.CURRENCY: usage.currency,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
},
llm_usage=usage,
)
@ -302,8 +305,6 @@ class LLMNode(BaseNode[LLMNodeData]):
prompt_messages: Sequence[PromptMessage],
stop: Optional[Sequence[str]] = None,
) -> Generator[NodeEvent, None, None]:
db.session.close()
invoke_result = model_instance.invoke_llm(
prompt_messages=list(prompt_messages),
model_parameters=node_data_model.completion_params,
@ -474,7 +475,7 @@ class LLMNode(BaseNode[LLMNodeData]):
yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value_variable.value)
elif isinstance(context_value_variable, ArraySegment):
context_str = ""
original_retriever_resource = []
original_retriever_resource: list[RetrievalSourceMetadata] = []
for item in context_value_variable.value:
if isinstance(item, str):
context_str += item + "\n"
@ -492,7 +493,7 @@ class LLMNode(BaseNode[LLMNodeData]):
retriever_resources=original_retriever_resource, context=context_str.strip()
)
def _convert_to_original_retriever_resource(self, context_dict: dict) -> Optional[dict]:
def _convert_to_original_retriever_resource(self, context_dict: dict):
if (
"metadata" in context_dict
and "_source" in context_dict["metadata"]
@ -500,24 +501,24 @@ class LLMNode(BaseNode[LLMNodeData]):
):
metadata = context_dict.get("metadata", {})
source = {
"position": metadata.get("position"),
"dataset_id": metadata.get("dataset_id"),
"dataset_name": metadata.get("dataset_name"),
"document_id": metadata.get("document_id"),
"document_name": metadata.get("document_name"),
"data_source_type": metadata.get("data_source_type"),
"segment_id": metadata.get("segment_id"),
"retriever_from": metadata.get("retriever_from"),
"score": metadata.get("score"),
"hit_count": metadata.get("segment_hit_count"),
"word_count": metadata.get("segment_word_count"),
"segment_position": metadata.get("segment_position"),
"index_node_hash": metadata.get("segment_index_node_hash"),
"content": context_dict.get("content"),
"page": metadata.get("page"),
"doc_metadata": metadata.get("doc_metadata"),
}
source = RetrievalSourceMetadata(
position=metadata.get("position"),
dataset_id=metadata.get("dataset_id"),
dataset_name=metadata.get("dataset_name"),
document_id=metadata.get("document_id"),
document_name=metadata.get("document_name"),
data_source_type=metadata.get("data_source_type"),
segment_id=metadata.get("segment_id"),
retriever_from=metadata.get("retriever_from"),
score=metadata.get("score"),
hit_count=metadata.get("segment_hit_count"),
word_count=metadata.get("segment_word_count"),
segment_position=metadata.get("segment_position"),
index_node_hash=metadata.get("segment_index_node_hash"),
content=context_dict.get("content"),
page=metadata.get("page"),
doc_metadata=metadata.get("doc_metadata"),
)
return source
@ -602,15 +603,11 @@ class LLMNode(BaseNode[LLMNodeData]):
return None
conversation_id = conversation_id_variable.value
# get conversation
conversation = (
db.session.query(Conversation)
.filter(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
.first()
)
if not conversation:
return None
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(Conversation).where(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
conversation = session.scalar(stmt)
if not conversation:
return None
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
@ -846,20 +843,24 @@ class LLMNode(BaseNode[LLMNodeData]):
used_quota = 1
if used_quota is not None and system_configuration.current_quota_type is not None:
db.session.query(Provider).filter(
Provider.tenant_id == tenant_id,
# TODO: Use provider name with prefix after the data migration.
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == system_configuration.current_quota_type.value,
Provider.quota_limit > Provider.quota_used,
).update(
{
"quota_used": Provider.quota_used + used_quota,
"last_used": datetime.now(tz=UTC).replace(tzinfo=None),
}
)
db.session.commit()
with Session(db.engine) as session:
stmt = (
update(Provider)
.where(
Provider.tenant_id == tenant_id,
# TODO: Use provider name with prefix after the data migration.
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == system_configuration.current_quota_type.value,
Provider.quota_limit > Provider.quota_used,
)
.values(
quota_used=Provider.quota_used + used_quota,
last_used=datetime.now(tz=UTC).replace(tzinfo=None),
)
)
session.execute(stmt)
session.commit()
@classmethod
def _extract_variable_selector_to_variable_mapping(

View File

@ -1,8 +1,8 @@
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.loop.entities import LoopEndNodeData
from models.workflow import WorkflowNodeExecutionStatus
class LoopEndNode(BaseNode[LoopEndNodeData]):

View File

@ -15,7 +15,8 @@ from core.variables import (
SegmentType,
StringSegment,
)
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.graph_engine.entities.event import (
BaseGraphEvent,
BaseNodeEvent,
@ -37,7 +38,6 @@ from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from core.workflow.nodes.loop.entities import LoopNodeData
from core.workflow.utils.condition.processor import ConditionProcessor
from models.workflow import WorkflowNodeExecutionStatus
if TYPE_CHECKING:
from core.workflow.entities.variable_pool import VariablePool
@ -187,10 +187,10 @@ class LoopNode(BaseNode[LoopNodeData]):
outputs=self.node_data.outputs,
steps=loop_count,
metadata={
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
"completed_reason": "loop_break" if check_break_result else "loop_completed",
NodeRunMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
NodeRunMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
},
)
@ -198,9 +198,9 @@ class LoopNode(BaseNode[LoopNodeData]):
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
metadata={
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
NodeRunMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
NodeRunMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
},
outputs=self.node_data.outputs,
inputs=inputs,
@ -221,8 +221,8 @@ class LoopNode(BaseNode[LoopNodeData]):
metadata={
"total_tokens": graph_engine.graph_runtime_state.total_tokens,
"completed_reason": "error",
NodeRunMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
NodeRunMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
},
error=str(e),
)
@ -232,9 +232,9 @@ class LoopNode(BaseNode[LoopNodeData]):
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
metadata={
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
NodeRunMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
NodeRunMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
},
)
)
@ -322,7 +322,9 @@ class LoopNode(BaseNode[LoopNodeData]):
inputs=inputs,
steps=current_index,
metadata={
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: (
graph_engine.graph_runtime_state.total_tokens
),
"completed_reason": "error",
},
error=event.error,
@ -331,7 +333,11 @@ class LoopNode(BaseNode[LoopNodeData]):
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=event.error,
metadata={NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens},
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: (
graph_engine.graph_runtime_state.total_tokens
)
},
)
)
return {"check_break_result": True}
@ -347,7 +353,7 @@ class LoopNode(BaseNode[LoopNodeData]):
inputs=inputs,
steps=current_index,
metadata={
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
"completed_reason": "error",
},
error=event.error,
@ -356,7 +362,9 @@ class LoopNode(BaseNode[LoopNodeData]):
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=event.error,
metadata={NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens},
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens
},
)
)
return {"check_break_result": True}
@ -411,11 +419,11 @@ class LoopNode(BaseNode[LoopNodeData]):
metadata = event.route_node_state.node_run_result.metadata
if not metadata:
metadata = {}
if NodeRunMetadataKey.LOOP_ID not in metadata:
if WorkflowNodeExecutionMetadataKey.LOOP_ID not in metadata:
metadata = {
**metadata,
NodeRunMetadataKey.LOOP_ID: self.node_id,
NodeRunMetadataKey.LOOP_INDEX: iter_run_index,
WorkflowNodeExecutionMetadataKey.LOOP_ID: self.node_id,
WorkflowNodeExecutionMetadataKey.LOOP_INDEX: iter_run_index,
}
event.route_node_state.node_run_result.metadata = metadata
return event

View File

@ -1,8 +1,8 @@
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.loop.entities import LoopStartNodeData
from models.workflow import WorkflowNodeExecutionStatus
class LoopStartNode(BaseNode[LoopStartNodeData]):

View File

@ -25,13 +25,12 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.llm import LLMNode, ModelConfig
from core.workflow.utils import variable_template_parser
from extensions.ext_database import db
from models.workflow import WorkflowNodeExecutionStatus
from .entities import ParameterExtractorNodeData
from .exc import (
@ -244,9 +243,9 @@ class ParameterExtractorNode(LLMNode):
process_data=process_data,
outputs={"__is_success": 1 if not error else 0, "__reason": error, **result},
metadata={
NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
NodeRunMetadataKey.CURRENCY: usage.currency,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
},
llm_usage=usage,
)
@ -259,8 +258,6 @@ class ParameterExtractorNode(LLMNode):
tools: list[PromptMessageTool],
stop: list[str],
) -> tuple[str, LLMUsage, Optional[AssistantPromptMessage.ToolCall]]:
db.session.close()
invoke_result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=node_data_model.completion_params,
@ -816,7 +813,6 @@ class ParameterExtractorNode(LLMNode):
:param node_data: node data
:return:
"""
# FIXME: fix the type error later
variable_mapping: dict[str, Sequence[str]] = {"query": node_data.query}
if node_data.instruction:

View File

@ -10,7 +10,8 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event import ModelInvokeCompletedEvent
from core.workflow.nodes.llm import (
@ -20,7 +21,6 @@ from core.workflow.nodes.llm import (
)
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from libs.json_in_md_parser import parse_and_check_json_markdown
from models.workflow import WorkflowNodeExecutionStatus
from .entities import QuestionClassifierNodeData
from .exc import InvalidModelTypeError
@ -79,9 +79,13 @@ class QuestionClassifierNode(LLMNode):
memory=memory,
max_token_limit=rest_token,
)
# Some models (e.g. Gemma, Mistral) force roles alternation (user/assistant/user/assistant...).
# If both self._get_prompt_template and self._fetch_prompt_messages append a user prompt,
# two consecutive user prompts will be generated, causing model's error.
# To avoid this, set sys_query to an empty string so that only one user prompt is appended at the end.
prompt_messages, stop = self._fetch_prompt_messages(
prompt_template=prompt_template,
sys_query=query,
sys_query="",
memory=memory,
model_config=model_config,
sys_files=files,
@ -142,9 +146,9 @@ class QuestionClassifierNode(LLMNode):
outputs=outputs,
edge_source_handle=category_id,
metadata={
NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
NodeRunMetadataKey.CURRENCY: usage.currency,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
},
llm_usage=usage,
)
@ -154,9 +158,9 @@ class QuestionClassifierNode(LLMNode):
inputs=variables,
error=str(e),
metadata={
NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
NodeRunMetadataKey.CURRENCY: usage.currency,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
},
llm_usage=usage,
)

View File

@ -1,9 +1,9 @@
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.start.entities import StartNodeData
from models.workflow import WorkflowNodeExecutionStatus
class StartNode(BaseNode[StartNodeData]):

View File

@ -4,10 +4,10 @@ from typing import Any, Optional
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData
from models.workflow import WorkflowNodeExecutionStatus
MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = int(os.environ.get("TEMPLATE_TRANSFORM_MAX_LENGTH", "80000"))

View File

@ -14,8 +14,9 @@ from core.tools.tool_engine import ToolEngine
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.variables.segments import ArrayAnySegment
from core.variables.variables import ArrayAnyVariable
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import AgentLogEvent
from core.workflow.nodes.base import BaseNode
@ -25,7 +26,6 @@ from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
from factories import file_factory
from models import ToolFile
from models.workflow import WorkflowNodeExecutionStatus
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from .entities import ToolNodeData
@ -70,7 +70,7 @@ class ToolNode(BaseNode[ToolNodeData]):
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs={},
metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
error=f"Failed to get tool runtime: {str(e)}",
error_type=type(e).__name__,
)
@ -110,7 +110,7 @@ class ToolNode(BaseNode[ToolNodeData]):
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
error=f"Failed to invoke tool: {str(e)}",
error_type=type(e).__name__,
)
@ -125,7 +125,7 @@ class ToolNode(BaseNode[ToolNodeData]):
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
error=f"Failed to transform tool message: {str(e)}",
error_type=type(e).__name__,
)
@ -201,7 +201,7 @@ class ToolNode(BaseNode[ToolNodeData]):
json: list[dict] = []
agent_logs: list[AgentLogEvent] = []
agent_execution_metadata: Mapping[NodeRunMetadataKey, Any] = {}
agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {}
variables: dict[str, Any] = {}
@ -274,7 +274,7 @@ class ToolNode(BaseNode[ToolNodeData]):
agent_execution_metadata = {
key: value
for key, value in msg_metadata.items()
if key in NodeRunMetadataKey.__members__.values()
if key in WorkflowNodeExecutionMetadataKey.__members__.values()
}
json.append(message.message.json_object)
elif message.type == ToolInvokeMessage.MessageType.LINK:
@ -366,8 +366,8 @@ class ToolNode(BaseNode[ToolNodeData]):
outputs={"text": text, "files": files, "json": json, **variables},
metadata={
**agent_execution_metadata,
NodeRunMetadataKey.TOOL_INFO: tool_info,
NodeRunMetadataKey.AGENT_LOG: agent_logs,
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
WorkflowNodeExecutionMetadataKey.AGENT_LOG: agent_logs,
},
inputs=parameters_for_log,
)

View File

@ -1,7 +1,8 @@
from typing import Literal, Optional
from typing import Optional
from pydantic import BaseModel
from core.variables.types import SegmentType
from core.workflow.nodes.base import BaseNodeData
@ -17,7 +18,7 @@ class AdvancedSettings(BaseModel):
Group.
"""
output_type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"]
output_type: SegmentType
variables: list[list[str]]
group_name: str

View File

@ -1,8 +1,8 @@
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.variable_aggregator.entities import VariableAssignerNodeData
from models.workflow import WorkflowNodeExecutionStatus
class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]):

View File

@ -1,11 +1,11 @@
from core.variables import SegmentType, Variable
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from factories import variable_factory
from models.workflow import WorkflowNodeExecutionStatus
from .node_data import VariableAssignerData, WriteMode

View File

@ -6,11 +6,11 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.variables import SegmentType, Variable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from models.workflow import WorkflowNodeExecutionStatus
from . import helpers
from .constants import EMPTY_VALUE_MAPPING

View File

@ -6,7 +6,7 @@ for accessing and manipulating data, regardless of the underlying
storage mechanism.
"""
from core.workflow.repository.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
__all__ = [
"OrderConfig",

View File

@ -1,6 +1,6 @@
from typing import Optional, Protocol
from core.workflow.entities.workflow_execution_entities import WorkflowExecution
from core.workflow.entities.workflow_execution import WorkflowExecution
class WorkflowExecutionRepository(Protocol):

View File

@ -2,7 +2,7 @@ from collections.abc import Sequence
from dataclasses import dataclass
from typing import Literal, Optional, Protocol
from core.workflow.entities.node_execution_entities import NodeExecution
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution
@dataclass
@ -26,7 +26,7 @@ class WorkflowNodeExecutionRepository(Protocol):
application domains or deployment scenarios.
"""
def save(self, execution: NodeExecution) -> None:
def save(self, execution: WorkflowNodeExecution) -> None:
"""
Save or update a NodeExecution instance.
@ -39,7 +39,7 @@ class WorkflowNodeExecutionRepository(Protocol):
"""
...
def get_by_node_execution_id(self, node_execution_id: str) -> Optional[NodeExecution]:
def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]:
"""
Retrieve a NodeExecution by its node_execution_id.
@ -55,7 +55,7 @@ class WorkflowNodeExecutionRepository(Protocol):
self,
workflow_run_id: str,
order_config: Optional[OrderConfig] = None,
) -> Sequence[NodeExecution]:
) -> Sequence[WorkflowNodeExecution]:
"""
Retrieve all NodeExecution instances for a specific workflow run.
@ -70,7 +70,7 @@ class WorkflowNodeExecutionRepository(Protocol):
"""
...
def get_running_executions(self, workflow_run_id: str) -> Sequence[NodeExecution]:
def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]:
"""
Retrieve all running NodeExecution instances for a specific workflow run.

View File

@ -1,11 +1,9 @@
from collections.abc import Mapping
from dataclasses import dataclass
from datetime import UTC, datetime
from typing import Any, Optional, Union
from uuid import uuid4
from sqlalchemy import func, select
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import (
QueueNodeExceptionEvent,
@ -19,21 +17,24 @@ from core.app.entities.queue_entities import (
from core.app.task_pipeline.exc import WorkflowRunNotFoundError
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.entities.node_execution_entities import (
NodeExecution,
NodeExecutionStatus,
from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType
from core.workflow.entities.workflow_node_execution import (
WorkflowNodeExecution,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.entities.workflow_execution_entities import WorkflowExecution, WorkflowExecutionStatus, WorkflowType
from core.workflow.enums import SystemVariableKey
from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.workflow_entry import WorkflowEntry
from models import (
Workflow,
WorkflowRun,
WorkflowRunStatus,
)
@dataclass
class CycleManagerWorkflowInfo:
workflow_id: str
workflow_type: WorkflowType
version: str
graph_data: Mapping[str, Any]
class WorkflowCycleManager:
@ -42,32 +43,17 @@ class WorkflowCycleManager:
*,
application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
workflow_system_variables: dict[SystemVariableKey, Any],
workflow_info: CycleManagerWorkflowInfo,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
) -> None:
self._application_generate_entity = application_generate_entity
self._workflow_system_variables = workflow_system_variables
self._workflow_info = workflow_info
self._workflow_execution_repository = workflow_execution_repository
self._workflow_node_execution_repository = workflow_node_execution_repository
def handle_workflow_run_start(
self,
*,
session: Session,
workflow_id: str,
) -> WorkflowExecution:
workflow_stmt = select(Workflow).where(Workflow.id == workflow_id)
workflow = session.scalar(workflow_stmt)
if not workflow:
raise ValueError(f"Workflow not found: {workflow_id}")
max_sequence_stmt = select(func.max(WorkflowRun.sequence_number)).where(
WorkflowRun.tenant_id == workflow.tenant_id,
WorkflowRun.app_id == workflow.app_id,
)
max_sequence = session.scalar(max_sequence_stmt) or 0
new_sequence_number = max_sequence + 1
def handle_workflow_run_start(self) -> WorkflowExecution:
inputs = {**self._application_generate_entity.inputs}
for key, value in (self._workflow_system_variables or {}).items():
if key.value == "conversation":
@ -79,14 +65,13 @@ class WorkflowCycleManager:
# init workflow run
# TODO: This workflow_run_id should always not be None, maybe we can use a more elegant way to handle this
execution_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_RUN_ID) or uuid4())
execution_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_EXECUTION_ID) or uuid4())
execution = WorkflowExecution.new(
id=execution_id,
workflow_id=workflow.id,
sequence_number=new_sequence_number,
type=WorkflowType(workflow.type),
workflow_version=workflow.version,
graph=workflow.graph_dict,
id_=execution_id,
workflow_id=self._workflow_info.workflow_id,
workflow_type=self._workflow_info.workflow_type,
workflow_version=self._workflow_info.version,
graph=self._workflow_info.graph_data,
inputs=inputs,
started_at=datetime.now(UTC).replace(tzinfo=None),
)
@ -168,7 +153,7 @@ class WorkflowCycleManager:
workflow_run_id: str,
total_tokens: int,
total_steps: int,
status: WorkflowRunStatus,
status: WorkflowExecutionStatus,
error_message: str,
conversation_id: Optional[str] = None,
trace_manager: Optional[TraceQueueManager] = None,
@ -185,7 +170,7 @@ class WorkflowCycleManager:
# Use the instance repository to find running executions for a workflow run
running_node_executions = self._workflow_node_execution_repository.get_running_executions(
workflow_run_id=workflow_execution.id
workflow_run_id=workflow_execution.id_
)
# Update the domain models
@ -193,7 +178,7 @@ class WorkflowCycleManager:
for node_execution in running_node_executions:
if node_execution.node_execution_id:
# Update the domain model
node_execution.status = NodeExecutionStatus.FAILED
node_execution.status = WorkflowNodeExecutionStatus.FAILED
node_execution.error = error_message
node_execution.finished_at = now
node_execution.elapsed_time = (now - node_execution.created_at).total_seconds()
@ -219,28 +204,28 @@ class WorkflowCycleManager:
*,
workflow_execution_id: str,
event: QueueNodeStartedEvent,
) -> NodeExecution:
) -> WorkflowNodeExecution:
workflow_execution = self._get_workflow_execution_or_raise_error(workflow_execution_id)
# Create a domain model
created_at = datetime.now(UTC).replace(tzinfo=None)
metadata = {
NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
NodeRunMetadataKey.LOOP_ID: event.in_loop_id,
WorkflowNodeExecutionMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id,
WorkflowNodeExecutionMetadataKey.LOOP_ID: event.in_loop_id,
}
domain_execution = NodeExecution(
domain_execution = WorkflowNodeExecution(
id=str(uuid4()),
workflow_id=workflow_execution.workflow_id,
workflow_run_id=workflow_execution.id,
workflow_execution_id=workflow_execution.id_,
predecessor_node_id=event.predecessor_node_id,
index=event.node_run_index,
node_execution_id=event.node_execution_id,
node_id=event.node_id,
node_type=event.node_type,
title=event.node_data.title,
status=NodeExecutionStatus.RUNNING,
status=WorkflowNodeExecutionStatus.RUNNING,
metadata=metadata,
created_at=created_at,
)
@ -250,7 +235,7 @@ class WorkflowCycleManager:
return domain_execution
def handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> NodeExecution:
def handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution:
# Get the domain model from repository
domain_execution = self._workflow_node_execution_repository.get_by_node_execution_id(event.node_execution_id)
if not domain_execution:
@ -271,7 +256,7 @@ class WorkflowCycleManager:
elapsed_time = (finished_at - event.start_at).total_seconds()
# Update domain model
domain_execution.status = NodeExecutionStatus.SUCCEEDED
domain_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED
domain_execution.update_from_mapping(
inputs=inputs, process_data=process_data, outputs=outputs, metadata=execution_metadata_dict
)
@ -290,7 +275,7 @@ class WorkflowCycleManager:
| QueueNodeInIterationFailedEvent
| QueueNodeInLoopFailedEvent
| QueueNodeExceptionEvent,
) -> NodeExecution:
) -> WorkflowNodeExecution:
"""
Workflow node execution failed
:param event: queue node failed event
@ -317,9 +302,9 @@ class WorkflowCycleManager:
# Update domain model
domain_execution.status = (
NodeExecutionStatus.FAILED
WorkflowNodeExecutionStatus.FAILED
if not isinstance(event, QueueNodeExceptionEvent)
else NodeExecutionStatus.EXCEPTION
else WorkflowNodeExecutionStatus.EXCEPTION
)
domain_execution.error = event.error
domain_execution.update_from_mapping(
@ -335,7 +320,7 @@ class WorkflowCycleManager:
def handle_workflow_node_execution_retried(
self, *, workflow_execution_id: str, event: QueueNodeRetryEvent
) -> NodeExecution:
) -> WorkflowNodeExecution:
workflow_execution = self._get_workflow_execution_or_raise_error(workflow_execution_id)
created_at = event.start_at
finished_at = datetime.now(UTC).replace(tzinfo=None)
@ -345,13 +330,13 @@ class WorkflowCycleManager:
# Convert metadata keys to strings
origin_metadata = {
NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
NodeRunMetadataKey.LOOP_ID: event.in_loop_id,
WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id,
WorkflowNodeExecutionMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
WorkflowNodeExecutionMetadataKey.LOOP_ID: event.in_loop_id,
}
# Convert execution metadata keys to strings
execution_metadata_dict: dict[NodeRunMetadataKey, str | None] = {}
execution_metadata_dict: dict[WorkflowNodeExecutionMetadataKey, str | None] = {}
if event.execution_metadata:
for key, value in event.execution_metadata.items():
execution_metadata_dict[key] = value
@ -359,16 +344,16 @@ class WorkflowCycleManager:
merged_metadata = {**execution_metadata_dict, **origin_metadata} if execution_metadata_dict else origin_metadata
# Create a domain model
domain_execution = NodeExecution(
domain_execution = WorkflowNodeExecution(
id=str(uuid4()),
workflow_id=workflow_execution.workflow_id,
workflow_run_id=workflow_execution.id,
workflow_execution_id=workflow_execution.id_,
predecessor_node_id=event.predecessor_node_id,
node_execution_id=event.node_execution_id,
node_id=event.node_id,
node_type=event.node_type,
title=event.node_data.title,
status=NodeExecutionStatus.RETRY,
status=WorkflowNodeExecutionStatus.RETRY,
created_at=created_at,
finished_at=finished_at,
elapsed_time=elapsed_time,