Merge branch 'feat/mcp-06-18' into deploy/dev

This commit is contained in:
Novice
2025-10-23 17:20:21 +08:00
466 changed files with 31391 additions and 7797 deletions

View File

@ -1,3 +1,5 @@
from ..runtime.graph_runtime_state import GraphRuntimeState
from ..runtime.variable_pool import VariablePool
from .agent import AgentNodeStrategyInit
from .graph_init_params import GraphInitParams
from .workflow_execution import WorkflowExecution
@ -6,6 +8,8 @@ from .workflow_node_execution import WorkflowNodeExecution
__all__ = [
"AgentNodeStrategyInit",
"GraphInitParams",
"GraphRuntimeState",
"VariablePool",
"WorkflowExecution",
"WorkflowNodeExecution",
]

View File

@ -3,11 +3,12 @@ from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import Protocol, cast, final
from core.workflow.enums import NodeExecutionType, NodeState, NodeType
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType
from core.workflow.nodes.base.node import Node
from libs.typing import is_str, is_str_dict
from .edge import Edge
from .validation import get_graph_validator
logger = logging.getLogger(__name__)
@ -201,6 +202,17 @@ class Graph:
return GraphBuilder(graph_cls=cls)
@classmethod
def _promote_fail_branch_nodes(cls, nodes: dict[str, Node]) -> None:
"""
Promote nodes configured with FAIL_BRANCH error strategy to branch execution type.
:param nodes: mapping of node ID to node instance
"""
for node in nodes.values():
if node.error_strategy == ErrorStrategy.FAIL_BRANCH:
node.execution_type = NodeExecutionType.BRANCH
@classmethod
def _mark_inactive_root_branches(
cls,
@ -307,6 +319,9 @@ class Graph:
# Create node instances
nodes = cls._create_node_instances(node_configs_map, node_factory)
# Promote fail-branch nodes to branch execution type at graph level
cls._promote_fail_branch_nodes(nodes)
# Get root node instance
root_node = nodes[root_node_id]
@ -314,7 +329,7 @@ class Graph:
cls._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, root_node_id)
# Create and return the graph
return cls(
graph = cls(
nodes=nodes,
edges=edges,
in_edges=in_edges,
@ -322,6 +337,11 @@ class Graph:
root_node=root_node,
)
# Validate the graph structure using built-in validators
get_graph_validator().validate(graph)
return graph
@property
def node_ids(self) -> list[str]:
"""

View File

@ -0,0 +1,125 @@
from __future__ import annotations
from collections.abc import Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Protocol
from core.workflow.enums import NodeExecutionType, NodeType
if TYPE_CHECKING:
from .graph import Graph
@dataclass(frozen=True, slots=True)
class GraphValidationIssue:
"""Immutable value object describing a single validation issue."""
code: str
message: str
node_id: str | None = None
class GraphValidationError(ValueError):
"""Raised when graph validation fails."""
def __init__(self, issues: Sequence[GraphValidationIssue]) -> None:
if not issues:
raise ValueError("GraphValidationError requires at least one issue.")
self.issues: tuple[GraphValidationIssue, ...] = tuple(issues)
message = "; ".join(f"[{issue.code}] {issue.message}" for issue in self.issues)
super().__init__(message)
class GraphValidationRule(Protocol):
"""Protocol that individual validation rules must satisfy."""
def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
"""Validate the provided graph and return any discovered issues."""
...
@dataclass(frozen=True, slots=True)
class _EdgeEndpointValidator:
"""Ensures all edges reference existing nodes."""
missing_node_code: str = "MISSING_NODE"
def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
issues: list[GraphValidationIssue] = []
for edge in graph.edges.values():
if edge.tail not in graph.nodes:
issues.append(
GraphValidationIssue(
code=self.missing_node_code,
message=f"Edge {edge.id} references unknown source node '{edge.tail}'.",
node_id=edge.tail,
)
)
if edge.head not in graph.nodes:
issues.append(
GraphValidationIssue(
code=self.missing_node_code,
message=f"Edge {edge.id} references unknown target node '{edge.head}'.",
node_id=edge.head,
)
)
return issues
@dataclass(frozen=True, slots=True)
class _RootNodeValidator:
"""Validates root node invariants."""
invalid_root_code: str = "INVALID_ROOT"
container_entry_types: tuple[NodeType, ...] = (NodeType.ITERATION_START, NodeType.LOOP_START)
def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
root_node = graph.root_node
issues: list[GraphValidationIssue] = []
if root_node.id not in graph.nodes:
issues.append(
GraphValidationIssue(
code=self.invalid_root_code,
message=f"Root node '{root_node.id}' is missing from the node registry.",
node_id=root_node.id,
)
)
return issues
node_type = getattr(root_node, "node_type", None)
if root_node.execution_type != NodeExecutionType.ROOT and node_type not in self.container_entry_types:
issues.append(
GraphValidationIssue(
code=self.invalid_root_code,
message=f"Root node '{root_node.id}' must declare execution type 'root'.",
node_id=root_node.id,
)
)
return issues
@dataclass(frozen=True, slots=True)
class GraphValidator:
"""Coordinates execution of graph validation rules."""
rules: tuple[GraphValidationRule, ...]
def validate(self, graph: Graph) -> None:
"""Validate the graph against all configured rules."""
issues: list[GraphValidationIssue] = []
for rule in self.rules:
issues.extend(rule.validate(graph))
if issues:
raise GraphValidationError(issues)
_DEFAULT_RULES: tuple[GraphValidationRule, ...] = (
_EdgeEndpointValidator(),
_RootNodeValidator(),
)
def get_graph_validator() -> GraphValidator:
"""Construct the validator composed of default rules."""
return GraphValidator(_DEFAULT_RULES)

View File

@ -26,8 +26,8 @@ class AgentNodeData(BaseNodeData):
class ParamsAutoGenerated(IntEnum):
CLOSE = auto()
OPEN = auto()
CLOSE = 0
OPEN = 1
class AgentOldVersionModelFeatures(StrEnum):

View File

@ -1,4 +1,5 @@
from .entities import BaseIterationNodeData, BaseIterationState, BaseLoopNodeData, BaseLoopState, BaseNodeData
from .usage_tracking_mixin import LLMUsageTrackingMixin
__all__ = [
"BaseIterationNodeData",
@ -6,4 +7,5 @@ __all__ = [
"BaseLoopNodeData",
"BaseLoopState",
"BaseNodeData",
"LLMUsageTrackingMixin",
]

View File

@ -1,5 +1,6 @@
import json
from abc import ABC
from builtins import type as type_
from collections.abc import Sequence
from enum import StrEnum
from typing import Any, Union
@ -58,10 +59,9 @@ class DefaultValue(BaseModel):
raise DefaultValueTypeError(f"Invalid JSON format for value: {value}")
@staticmethod
def _validate_array(value: Any, element_type: DefaultValueType) -> bool:
def _validate_array(value: Any, element_type: type_ | tuple[type_, ...]) -> bool:
"""Unified array type validation"""
# FIXME, type ignore here for do not find the reason mypy complain, if find the root cause, please fix it
return isinstance(value, list) and all(isinstance(x, element_type) for x in value) # type: ignore
return isinstance(value, list) and all(isinstance(x, element_type) for x in value)
@staticmethod
def _convert_number(value: str) -> float:

View File

@ -0,0 +1,28 @@
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.runtime import GraphRuntimeState
class LLMUsageTrackingMixin:
"""Provides shared helpers for merging and recording LLM usage within workflow nodes."""
graph_runtime_state: GraphRuntimeState
@staticmethod
def _merge_usage(current: LLMUsage, new_usage: LLMUsage | None) -> LLMUsage:
"""Return a combined usage snapshot, preserving zero-value inputs."""
if new_usage is None or new_usage.total_tokens <= 0:
return current
if current.total_tokens == 0:
return new_usage
return current.plus(new_usage)
def _accumulate_usage(self, usage: LLMUsage) -> None:
"""Push usage into the graph runtime accumulator for downstream reporting."""
if usage.total_tokens <= 0:
return
current_usage = self.graph_runtime_state.llm_usage
if current_usage.total_tokens == 0:
self.graph_runtime_state.llm_usage = usage.model_copy()
else:
self.graph_runtime_state.llm_usage = current_usage.plus(usage)

View File

@ -10,10 +10,10 @@ from typing import Any
import chardet
import docx
import pandas as pd
import pypandoc # type: ignore
import pypdfium2 # type: ignore
import webvtt # type: ignore
import yaml # type: ignore
import pypandoc
import pypdfium2
import webvtt
import yaml
from docx.document import Document
from docx.oxml.table import CT_Tbl
from docx.oxml.text.paragraph import CT_P

View File

@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any, NewType, cast
from flask import Flask, current_app
from typing_extensions import TypeIs
from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables import IntegerVariable, NoneSegment
from core.variables.segments import ArrayAnySegment, ArraySegment
from core.variables.variables import VariableUnion
@ -34,6 +35,7 @@ from core.workflow.node_events import (
NodeRunResult,
StreamCompletedEvent,
)
from core.workflow.nodes.base import LLMUsageTrackingMixin
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
@ -58,7 +60,7 @@ logger = logging.getLogger(__name__)
EmptyArraySegment = NewType("EmptyArraySegment", ArraySegment)
class IterationNode(Node):
class IterationNode(LLMUsageTrackingMixin, Node):
"""
Iteration Node.
"""
@ -118,6 +120,7 @@ class IterationNode(Node):
started_at = naive_utc_now()
iter_run_map: dict[str, float] = {}
outputs: list[object] = []
usage_accumulator = [LLMUsage.empty_usage()]
yield IterationStartedEvent(
start_at=started_at,
@ -130,22 +133,27 @@ class IterationNode(Node):
iterator_list_value=iterator_list_value,
outputs=outputs,
iter_run_map=iter_run_map,
usage_accumulator=usage_accumulator,
)
self._accumulate_usage(usage_accumulator[0])
yield from self._handle_iteration_success(
started_at=started_at,
inputs=inputs,
outputs=outputs,
iterator_list_value=iterator_list_value,
iter_run_map=iter_run_map,
usage=usage_accumulator[0],
)
except IterationNodeError as e:
self._accumulate_usage(usage_accumulator[0])
yield from self._handle_iteration_failure(
started_at=started_at,
inputs=inputs,
outputs=outputs,
iterator_list_value=iterator_list_value,
iter_run_map=iter_run_map,
usage=usage_accumulator[0],
error=e,
)
@ -196,6 +204,7 @@ class IterationNode(Node):
iterator_list_value: Sequence[object],
outputs: list[object],
iter_run_map: dict[str, float],
usage_accumulator: list[LLMUsage],
) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]:
if self._node_data.is_parallel:
# Parallel mode execution
@ -203,6 +212,7 @@ class IterationNode(Node):
iterator_list_value=iterator_list_value,
outputs=outputs,
iter_run_map=iter_run_map,
usage_accumulator=usage_accumulator,
)
else:
# Sequential mode execution
@ -228,6 +238,9 @@ class IterationNode(Node):
# Update the total tokens from this iteration
self.graph_runtime_state.total_tokens += graph_engine.graph_runtime_state.total_tokens
usage_accumulator[0] = self._merge_usage(
usage_accumulator[0], graph_engine.graph_runtime_state.llm_usage
)
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
def _execute_parallel_iterations(
@ -235,6 +248,7 @@ class IterationNode(Node):
iterator_list_value: Sequence[object],
outputs: list[object],
iter_run_map: dict[str, float],
usage_accumulator: list[LLMUsage],
) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]:
# Initialize outputs list with None values to maintain order
outputs.extend([None] * len(iterator_list_value))
@ -245,7 +259,16 @@ class IterationNode(Node):
with ThreadPoolExecutor(max_workers=max_workers) as executor:
# Submit all iteration tasks
future_to_index: dict[
Future[tuple[datetime, list[GraphNodeEventBase], object | None, int, dict[str, VariableUnion]]],
Future[
tuple[
datetime,
list[GraphNodeEventBase],
object | None,
int,
dict[str, VariableUnion],
LLMUsage,
]
],
int,
] = {}
for index, item in enumerate(iterator_list_value):
@ -264,7 +287,14 @@ class IterationNode(Node):
index = future_to_index[future]
try:
result = future.result()
iter_start_at, events, output_value, tokens_used, conversation_snapshot = result
(
iter_start_at,
events,
output_value,
tokens_used,
conversation_snapshot,
iteration_usage,
) = result
# Update outputs at the correct index
outputs[index] = output_value
@ -276,6 +306,8 @@ class IterationNode(Node):
self.graph_runtime_state.total_tokens += tokens_used
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
usage_accumulator[0] = self._merge_usage(usage_accumulator[0], iteration_usage)
# Sync conversation variables after iteration completion
self._sync_conversation_variables_from_snapshot(conversation_snapshot)
@ -303,7 +335,7 @@ class IterationNode(Node):
item: object,
flask_app: Flask,
context_vars: contextvars.Context,
) -> tuple[datetime, list[GraphNodeEventBase], object | None, int, dict[str, VariableUnion]]:
) -> tuple[datetime, list[GraphNodeEventBase], object | None, int, dict[str, VariableUnion], LLMUsage]:
"""Execute a single iteration in parallel mode and return results."""
with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars):
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
@ -332,6 +364,7 @@ class IterationNode(Node):
output_value,
graph_engine.graph_runtime_state.total_tokens,
conversation_snapshot,
graph_engine.graph_runtime_state.llm_usage,
)
def _handle_iteration_success(
@ -341,6 +374,8 @@ class IterationNode(Node):
outputs: list[object],
iterator_list_value: Sequence[object],
iter_run_map: dict[str, float],
*,
usage: LLMUsage,
) -> Generator[NodeEventBase, None, None]:
# Flatten the list of lists if all outputs are lists
flattened_outputs = self._flatten_outputs_if_needed(outputs)
@ -351,7 +386,9 @@ class IterationNode(Node):
outputs={"output": flattened_outputs},
steps=len(iterator_list_value),
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
},
)
@ -362,8 +399,11 @@ class IterationNode(Node):
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"output": flattened_outputs},
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
},
llm_usage=usage,
)
)
@ -400,6 +440,8 @@ class IterationNode(Node):
outputs: list[object],
iterator_list_value: Sequence[object],
iter_run_map: dict[str, float],
*,
usage: LLMUsage,
error: IterationNodeError,
) -> Generator[NodeEventBase, None, None]:
# Flatten the list of lists if all outputs are lists (even in failure case)
@ -411,7 +453,9 @@ class IterationNode(Node):
outputs={"output": flattened_outputs},
steps=len(iterator_list_value),
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
},
error=str(error),
@ -420,6 +464,12 @@ class IterationNode(Node):
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(error),
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
},
llm_usage=usage,
)
)

View File

@ -15,14 +15,11 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti
from core.entities.agent_entities import PlanningStrategy
from core.entities.model_entities import ModelStatus
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.message_entities import (
PromptMessageRole,
)
from core.model_runtime.entities.model_entities import (
ModelFeature,
ModelType,
)
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.simple_prompt_transform import ModelMode
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.metadata_entities import Condition, MetadataCondition
@ -33,8 +30,14 @@ from core.variables import (
)
from core.variables.segments import ArrayObjectSegment
from core.workflow.entities import GraphInitParams
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.enums import (
ErrorStrategy,
NodeType,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.node_events import ModelInvokeCompletedEvent, NodeRunResult
from core.workflow.nodes.base import LLMUsageTrackingMixin
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.knowledge_retrieval.template_prompts import (
@ -80,7 +83,7 @@ default_retrieval_model = {
}
class KnowledgeRetrievalNode(Node):
class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node):
node_type = NodeType.KNOWLEDGE_RETRIEVAL
_node_data: KnowledgeRetrievalNodeData
@ -141,7 +144,7 @@ class KnowledgeRetrievalNode(Node):
def version(cls):
return "1"
def _run(self) -> NodeRunResult: # type: ignore
def _run(self) -> NodeRunResult:
# extract variables
variable = self.graph_runtime_state.variable_pool.get(self._node_data.query_variable_selector)
if not isinstance(variable, StringSegment):
@ -182,14 +185,21 @@ class KnowledgeRetrievalNode(Node):
)
# retrieve knowledge
usage = LLMUsage.empty_usage()
try:
results = self._fetch_dataset_retriever(node_data=self._node_data, query=query)
results, usage = self._fetch_dataset_retriever(node_data=self._node_data, query=query)
outputs = {"result": ArrayObjectSegment(value=results)}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=variables,
process_data={},
process_data={"usage": jsonable_encoder(usage)},
outputs=outputs, # type: ignore
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
},
llm_usage=usage,
)
except KnowledgeRetrievalNodeError as e:
@ -199,6 +209,7 @@ class KnowledgeRetrievalNode(Node):
inputs=variables,
error=str(e),
error_type=type(e).__name__,
llm_usage=usage,
)
# Temporary handle all exceptions from DatasetRetrieval class here.
except Exception as e:
@ -207,11 +218,15 @@ class KnowledgeRetrievalNode(Node):
inputs=variables,
error=str(e),
error_type=type(e).__name__,
llm_usage=usage,
)
finally:
db.session.close()
def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: str) -> list[dict[str, Any]]:
def _fetch_dataset_retriever(
self, node_data: KnowledgeRetrievalNodeData, query: str
) -> tuple[list[dict[str, Any]], LLMUsage]:
usage = LLMUsage.empty_usage()
available_datasets = []
dataset_ids = node_data.dataset_ids
@ -245,9 +260,10 @@ class KnowledgeRetrievalNode(Node):
if not dataset:
continue
available_datasets.append(dataset)
metadata_filter_document_ids, metadata_condition = self._get_metadata_filter_condition(
metadata_filter_document_ids, metadata_condition, metadata_usage = self._get_metadata_filter_condition(
[dataset.id for dataset in available_datasets], query, node_data
)
usage = self._merge_usage(usage, metadata_usage)
all_documents = []
dataset_retrieval = DatasetRetrieval()
if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
@ -330,6 +346,8 @@ class KnowledgeRetrievalNode(Node):
metadata_filter_document_ids=metadata_filter_document_ids,
metadata_condition=metadata_condition,
)
usage = self._merge_usage(usage, dataset_retrieval.llm_usage)
dify_documents = [item for item in all_documents if item.provider == "dify"]
external_documents = [item for item in all_documents if item.provider == "external"]
retrieval_resource_list = []
@ -406,11 +424,12 @@ class KnowledgeRetrievalNode(Node):
)
for position, item in enumerate(retrieval_resource_list, start=1):
item["metadata"]["position"] = position
return retrieval_resource_list
return retrieval_resource_list, usage
def _get_metadata_filter_condition(
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
) -> tuple[dict[str, list[str]] | None, MetadataCondition | None]:
) -> tuple[dict[str, list[str]] | None, MetadataCondition | None, LLMUsage]:
usage = LLMUsage.empty_usage()
document_query = db.session.query(Document).where(
Document.dataset_id.in_(dataset_ids),
Document.indexing_status == "completed",
@ -420,9 +439,12 @@ class KnowledgeRetrievalNode(Node):
filters: list[Any] = []
metadata_condition = None
if node_data.metadata_filtering_mode == "disabled":
return None, None
return None, None, usage
elif node_data.metadata_filtering_mode == "automatic":
automatic_metadata_filters = self._automatic_metadata_filter_func(dataset_ids, query, node_data)
automatic_metadata_filters, automatic_usage = self._automatic_metadata_filter_func(
dataset_ids, query, node_data
)
usage = self._merge_usage(usage, automatic_usage)
if automatic_metadata_filters:
conditions = []
for sequence, filter in enumerate(automatic_metadata_filters):
@ -443,7 +465,7 @@ class KnowledgeRetrievalNode(Node):
metadata_condition = MetadataCondition(
logical_operator=node_data.metadata_filtering_conditions.logical_operator
if node_data.metadata_filtering_conditions
else "or", # type: ignore
else "or",
conditions=conditions,
)
elif node_data.metadata_filtering_mode == "manual":
@ -457,10 +479,10 @@ class KnowledgeRetrievalNode(Node):
expected_value = self.graph_runtime_state.variable_pool.convert_template(
expected_value
).value[0]
if expected_value.value_type in {"number", "integer", "float"}: # type: ignore
expected_value = expected_value.value # type: ignore
elif expected_value.value_type == "string": # type: ignore
expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip() # type: ignore
if expected_value.value_type in {"number", "integer", "float"}:
expected_value = expected_value.value
elif expected_value.value_type == "string":
expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip()
else:
raise ValueError("Invalid expected metadata value type")
conditions.append(
@ -487,7 +509,7 @@ class KnowledgeRetrievalNode(Node):
if (
node_data.metadata_filtering_conditions
and node_data.metadata_filtering_conditions.logical_operator == "and"
): # type: ignore
):
document_query = document_query.where(and_(*filters))
else:
document_query = document_query.where(or_(*filters))
@ -496,11 +518,12 @@ class KnowledgeRetrievalNode(Node):
metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore
for document in documents:
metadata_filter_document_ids[document.dataset_id].append(document.id) # type: ignore
return metadata_filter_document_ids, metadata_condition
return metadata_filter_document_ids, metadata_condition, usage
def _automatic_metadata_filter_func(
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
) -> list[dict[str, Any]]:
) -> tuple[list[dict[str, Any]], LLMUsage]:
usage = LLMUsage.empty_usage()
# get all metadata field
stmt = select(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids))
metadata_fields = db.session.scalars(stmt).all()
@ -548,6 +571,7 @@ class KnowledgeRetrievalNode(Node):
for event in generator:
if isinstance(event, ModelInvokeCompletedEvent):
result_text = event.text
usage = self._merge_usage(usage, event.usage)
break
result_text_json = parse_and_check_json_markdown(result_text, [])
@ -564,8 +588,8 @@ class KnowledgeRetrievalNode(Node):
}
)
except Exception:
return []
return automatic_metadata_filters
return [], usage
return automatic_metadata_filters, usage
def _process_metadata_filter_func(
self, sequence: int, condition: str, metadata_name: str, value: Any, filters: list[Any]

View File

@ -452,10 +452,14 @@ class LLMNode(Node):
usage = LLMUsage.empty_usage()
finish_reason = None
full_text_buffer = io.StringIO()
collected_structured_output = None # Collect structured_output from streaming chunks
# Consume the invoke result and handle generator exception
try:
for result in invoke_result:
if isinstance(result, LLMResultChunkWithStructuredOutput):
# Collect structured_output from the chunk
if result.structured_output is not None:
collected_structured_output = dict(result.structured_output)
yield result
if isinstance(result, LLMResultChunk):
contents = result.delta.message.content
@ -503,6 +507,8 @@ class LLMNode(Node):
finish_reason=finish_reason,
# Reasoning content for workflow variables and downstream nodes
reasoning_content=reasoning_content,
# Pass structured output if collected from streaming chunks
structured_output=collected_structured_output,
)
@staticmethod

View File

@ -5,6 +5,7 @@ from collections.abc import Callable, Generator, Mapping, Sequence
from datetime import datetime
from typing import TYPE_CHECKING, Any, Literal, cast
from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables import Segment, SegmentType
from core.workflow.enums import (
ErrorStrategy,
@ -27,6 +28,7 @@ from core.workflow.node_events import (
NodeRunResult,
StreamCompletedEvent,
)
from core.workflow.nodes.base import LLMUsageTrackingMixin
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.loop.entities import LoopNodeData, LoopVariableData
@ -40,7 +42,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
class LoopNode(Node):
class LoopNode(LLMUsageTrackingMixin, Node):
"""
Loop Node.
"""
@ -108,7 +110,7 @@ class LoopNode(Node):
raise ValueError(f"Invalid value for loop variable {loop_variable.label}")
variable_selector = [self._node_id, loop_variable.label]
variable = segment_to_variable(segment=processed_segment, selector=variable_selector)
self.graph_runtime_state.variable_pool.add(variable_selector, variable)
self.graph_runtime_state.variable_pool.add(variable_selector, variable.value)
loop_variable_selectors[loop_variable.label] = variable_selector
inputs[loop_variable.label] = processed_segment.value
@ -117,6 +119,7 @@ class LoopNode(Node):
loop_duration_map: dict[str, float] = {}
single_loop_variable_map: dict[str, dict[str, Any]] = {} # single loop variable output
loop_usage = LLMUsage.empty_usage()
# Start Loop event
yield LoopStartedEvent(
@ -163,6 +166,9 @@ class LoopNode(Node):
# Update the total tokens from this iteration
cost_tokens += graph_engine.graph_runtime_state.total_tokens
# Accumulate usage from the sub-graph execution
loop_usage = self._merge_usage(loop_usage, graph_engine.graph_runtime_state.llm_usage)
# Collect loop variable values after iteration
single_loop_variable = {}
for key, selector in loop_variable_selectors.items():
@ -189,6 +195,7 @@ class LoopNode(Node):
)
self.graph_runtime_state.total_tokens += cost_tokens
self._accumulate_usage(loop_usage)
# Loop completed successfully
yield LoopSucceededEvent(
start_at=start_at,
@ -196,7 +203,9 @@ class LoopNode(Node):
outputs=self._node_data.outputs,
steps=loop_count,
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: cost_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency,
"completed_reason": "loop_break" if reach_break_condition else "loop_completed",
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
@ -207,22 +216,28 @@ class LoopNode(Node):
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency,
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
},
outputs=self._node_data.outputs,
inputs=inputs,
llm_usage=loop_usage,
)
)
except Exception as e:
self._accumulate_usage(loop_usage)
yield LoopFailedEvent(
start_at=start_at,
inputs=inputs,
steps=loop_count,
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency,
"completed_reason": "error",
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
@ -235,10 +250,13 @@ class LoopNode(Node):
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency,
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
},
llm_usage=loop_usage,
)
)

View File

@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, final
from typing_extensions import override
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType
from core.workflow.enums import NodeType
from core.workflow.graph import NodeFactory
from core.workflow.nodes.base.node import Node
from libs.typing import is_str, is_str_dict
@ -82,8 +82,4 @@ class DifyNodeFactory(NodeFactory):
raise ValueError(f"Node {node_id} missing data information")
node_instance.init_node_data(node_data)
# If node has fail branch, change execution type to branch
if node_instance.error_strategy == ErrorStrategy.FAIL_BRANCH:
node_instance.execution_type = NodeExecutionType.BRANCH
return node_instance

View File

@ -747,7 +747,7 @@ class ParameterExtractorNode(Node):
if model_mode == ModelMode.CHAT:
system_prompt_messages = ChatModelMessage(
role=PromptMessageRole.SYSTEM,
text=CHAT_GENERATE_JSON_PROMPT.format(histories=memory_str).replace("{{instructions}}", instruction),
text=CHAT_GENERATE_JSON_PROMPT.format(histories=memory_str, instructions=instruction),
)
user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text)
return [system_prompt_messages, user_prompt_message]

View File

@ -135,7 +135,7 @@ Here are the chat histories between human and assistant, inside <histories></his
### Instructions:
Some extra information are provided below, you should always follow the instructions as possible as you can.
<instructions>
{{instructions}}
{instructions}
</instructions>
"""

View File

@ -6,10 +6,13 @@ from sqlalchemy.orm import Session
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file import File, FileTransferMethod
from core.model_runtime.entities.llm_entities import LLMUsage
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.errors import ToolInvokeError
from core.tools.tool_engine import ToolEngine
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.tools.workflow_as_tool.tool import WorkflowTool
from core.variables.segments import ArrayAnySegment, ArrayFileSegment
from core.variables.variables import ArrayAnyVariable
from core.workflow.enums import (
@ -136,13 +139,14 @@ class ToolNode(Node):
try:
# convert tool messages
yield from self._transform_message(
_ = yield from self._transform_message(
messages=message_stream,
tool_info=tool_info,
parameters_for_log=parameters_for_log,
user_id=self.user_id,
tenant_id=self.tenant_id,
node_id=self._node_id,
tool_runtime=tool_runtime,
)
except ToolInvokeError as e:
yield StreamCompletedEvent(
@ -236,7 +240,8 @@ class ToolNode(Node):
user_id: str,
tenant_id: str,
node_id: str,
) -> Generator:
tool_runtime: Tool,
) -> Generator[NodeEventBase, None, LLMUsage]:
"""
Convert ToolInvokeMessages into tuple[plain_text, files]
"""
@ -424,17 +429,34 @@ class ToolNode(Node):
is_final=True,
)
usage = self._extract_tool_usage(tool_runtime)
metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = {
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
}
if usage.total_tokens > 0:
metadata[WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS] = usage.total_tokens
metadata[WorkflowNodeExecutionMetadataKey.TOTAL_PRICE] = usage.total_price
metadata[WorkflowNodeExecutionMetadataKey.CURRENCY] = usage.currency
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables},
metadata={
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
},
metadata=metadata,
inputs=parameters_for_log,
llm_usage=usage,
)
)
return usage
@staticmethod
def _extract_tool_usage(tool_runtime: Tool) -> LLMUsage:
if isinstance(tool_runtime, WorkflowTool):
return tool_runtime.latest_usage
return LLMUsage.empty_usage()
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,

View File

@ -277,7 +277,7 @@ class VariablePool(BaseModel):
# This ensures that we can keep the id of the system variables intact.
if self._has(selector):
continue
self.add(selector, value) # type: ignore
self.add(selector, value)
@classmethod
def empty(cls) -> "VariablePool":