mirror of
https://github.com/langgenius/dify.git
synced 2026-05-02 16:38:04 +08:00
Merge branch 'feat/mcp-06-18' into deploy/dev
This commit is contained in:
@ -1,3 +1,5 @@
|
||||
from ..runtime.graph_runtime_state import GraphRuntimeState
|
||||
from ..runtime.variable_pool import VariablePool
|
||||
from .agent import AgentNodeStrategyInit
|
||||
from .graph_init_params import GraphInitParams
|
||||
from .workflow_execution import WorkflowExecution
|
||||
@ -6,6 +8,8 @@ from .workflow_node_execution import WorkflowNodeExecution
|
||||
__all__ = [
|
||||
"AgentNodeStrategyInit",
|
||||
"GraphInitParams",
|
||||
"GraphRuntimeState",
|
||||
"VariablePool",
|
||||
"WorkflowExecution",
|
||||
"WorkflowNodeExecution",
|
||||
]
|
||||
|
||||
@ -3,11 +3,12 @@ from collections import defaultdict
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Protocol, cast, final
|
||||
|
||||
from core.workflow.enums import NodeExecutionType, NodeState, NodeType
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from libs.typing import is_str, is_str_dict
|
||||
|
||||
from .edge import Edge
|
||||
from .validation import get_graph_validator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -201,6 +202,17 @@ class Graph:
|
||||
|
||||
return GraphBuilder(graph_cls=cls)
|
||||
|
||||
@classmethod
|
||||
def _promote_fail_branch_nodes(cls, nodes: dict[str, Node]) -> None:
|
||||
"""
|
||||
Promote nodes configured with FAIL_BRANCH error strategy to branch execution type.
|
||||
|
||||
:param nodes: mapping of node ID to node instance
|
||||
"""
|
||||
for node in nodes.values():
|
||||
if node.error_strategy == ErrorStrategy.FAIL_BRANCH:
|
||||
node.execution_type = NodeExecutionType.BRANCH
|
||||
|
||||
@classmethod
|
||||
def _mark_inactive_root_branches(
|
||||
cls,
|
||||
@ -307,6 +319,9 @@ class Graph:
|
||||
# Create node instances
|
||||
nodes = cls._create_node_instances(node_configs_map, node_factory)
|
||||
|
||||
# Promote fail-branch nodes to branch execution type at graph level
|
||||
cls._promote_fail_branch_nodes(nodes)
|
||||
|
||||
# Get root node instance
|
||||
root_node = nodes[root_node_id]
|
||||
|
||||
@ -314,7 +329,7 @@ class Graph:
|
||||
cls._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, root_node_id)
|
||||
|
||||
# Create and return the graph
|
||||
return cls(
|
||||
graph = cls(
|
||||
nodes=nodes,
|
||||
edges=edges,
|
||||
in_edges=in_edges,
|
||||
@ -322,6 +337,11 @@ class Graph:
|
||||
root_node=root_node,
|
||||
)
|
||||
|
||||
# Validate the graph structure using built-in validators
|
||||
get_graph_validator().validate(graph)
|
||||
|
||||
return graph
|
||||
|
||||
@property
|
||||
def node_ids(self) -> list[str]:
|
||||
"""
|
||||
|
||||
125
api/core/workflow/graph/validation.py
Normal file
125
api/core/workflow/graph/validation.py
Normal file
@ -0,0 +1,125 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
|
||||
from core.workflow.enums import NodeExecutionType, NodeType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .graph import Graph
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class GraphValidationIssue:
|
||||
"""Immutable value object describing a single validation issue."""
|
||||
|
||||
code: str
|
||||
message: str
|
||||
node_id: str | None = None
|
||||
|
||||
|
||||
class GraphValidationError(ValueError):
|
||||
"""Raised when graph validation fails."""
|
||||
|
||||
def __init__(self, issues: Sequence[GraphValidationIssue]) -> None:
|
||||
if not issues:
|
||||
raise ValueError("GraphValidationError requires at least one issue.")
|
||||
self.issues: tuple[GraphValidationIssue, ...] = tuple(issues)
|
||||
message = "; ".join(f"[{issue.code}] {issue.message}" for issue in self.issues)
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class GraphValidationRule(Protocol):
|
||||
"""Protocol that individual validation rules must satisfy."""
|
||||
|
||||
def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
|
||||
"""Validate the provided graph and return any discovered issues."""
|
||||
...
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class _EdgeEndpointValidator:
|
||||
"""Ensures all edges reference existing nodes."""
|
||||
|
||||
missing_node_code: str = "MISSING_NODE"
|
||||
|
||||
def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
|
||||
issues: list[GraphValidationIssue] = []
|
||||
for edge in graph.edges.values():
|
||||
if edge.tail not in graph.nodes:
|
||||
issues.append(
|
||||
GraphValidationIssue(
|
||||
code=self.missing_node_code,
|
||||
message=f"Edge {edge.id} references unknown source node '{edge.tail}'.",
|
||||
node_id=edge.tail,
|
||||
)
|
||||
)
|
||||
if edge.head not in graph.nodes:
|
||||
issues.append(
|
||||
GraphValidationIssue(
|
||||
code=self.missing_node_code,
|
||||
message=f"Edge {edge.id} references unknown target node '{edge.head}'.",
|
||||
node_id=edge.head,
|
||||
)
|
||||
)
|
||||
return issues
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class _RootNodeValidator:
|
||||
"""Validates root node invariants."""
|
||||
|
||||
invalid_root_code: str = "INVALID_ROOT"
|
||||
container_entry_types: tuple[NodeType, ...] = (NodeType.ITERATION_START, NodeType.LOOP_START)
|
||||
|
||||
def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
|
||||
root_node = graph.root_node
|
||||
issues: list[GraphValidationIssue] = []
|
||||
if root_node.id not in graph.nodes:
|
||||
issues.append(
|
||||
GraphValidationIssue(
|
||||
code=self.invalid_root_code,
|
||||
message=f"Root node '{root_node.id}' is missing from the node registry.",
|
||||
node_id=root_node.id,
|
||||
)
|
||||
)
|
||||
return issues
|
||||
|
||||
node_type = getattr(root_node, "node_type", None)
|
||||
if root_node.execution_type != NodeExecutionType.ROOT and node_type not in self.container_entry_types:
|
||||
issues.append(
|
||||
GraphValidationIssue(
|
||||
code=self.invalid_root_code,
|
||||
message=f"Root node '{root_node.id}' must declare execution type 'root'.",
|
||||
node_id=root_node.id,
|
||||
)
|
||||
)
|
||||
return issues
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class GraphValidator:
|
||||
"""Coordinates execution of graph validation rules."""
|
||||
|
||||
rules: tuple[GraphValidationRule, ...]
|
||||
|
||||
def validate(self, graph: Graph) -> None:
|
||||
"""Validate the graph against all configured rules."""
|
||||
issues: list[GraphValidationIssue] = []
|
||||
for rule in self.rules:
|
||||
issues.extend(rule.validate(graph))
|
||||
|
||||
if issues:
|
||||
raise GraphValidationError(issues)
|
||||
|
||||
|
||||
_DEFAULT_RULES: tuple[GraphValidationRule, ...] = (
|
||||
_EdgeEndpointValidator(),
|
||||
_RootNodeValidator(),
|
||||
)
|
||||
|
||||
|
||||
def get_graph_validator() -> GraphValidator:
|
||||
"""Construct the validator composed of default rules."""
|
||||
return GraphValidator(_DEFAULT_RULES)
|
||||
@ -26,8 +26,8 @@ class AgentNodeData(BaseNodeData):
|
||||
|
||||
|
||||
class ParamsAutoGenerated(IntEnum):
|
||||
CLOSE = auto()
|
||||
OPEN = auto()
|
||||
CLOSE = 0
|
||||
OPEN = 1
|
||||
|
||||
|
||||
class AgentOldVersionModelFeatures(StrEnum):
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from .entities import BaseIterationNodeData, BaseIterationState, BaseLoopNodeData, BaseLoopState, BaseNodeData
|
||||
from .usage_tracking_mixin import LLMUsageTrackingMixin
|
||||
|
||||
__all__ = [
|
||||
"BaseIterationNodeData",
|
||||
@ -6,4 +7,5 @@ __all__ = [
|
||||
"BaseLoopNodeData",
|
||||
"BaseLoopState",
|
||||
"BaseNodeData",
|
||||
"LLMUsageTrackingMixin",
|
||||
]
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import json
|
||||
from abc import ABC
|
||||
from builtins import type as type_
|
||||
from collections.abc import Sequence
|
||||
from enum import StrEnum
|
||||
from typing import Any, Union
|
||||
@ -58,10 +59,9 @@ class DefaultValue(BaseModel):
|
||||
raise DefaultValueTypeError(f"Invalid JSON format for value: {value}")
|
||||
|
||||
@staticmethod
|
||||
def _validate_array(value: Any, element_type: DefaultValueType) -> bool:
|
||||
def _validate_array(value: Any, element_type: type_ | tuple[type_, ...]) -> bool:
|
||||
"""Unified array type validation"""
|
||||
# FIXME, type ignore here for do not find the reason mypy complain, if find the root cause, please fix it
|
||||
return isinstance(value, list) and all(isinstance(x, element_type) for x in value) # type: ignore
|
||||
return isinstance(value, list) and all(isinstance(x, element_type) for x in value)
|
||||
|
||||
@staticmethod
|
||||
def _convert_number(value: str) -> float:
|
||||
|
||||
28
api/core/workflow/nodes/base/usage_tracking_mixin.py
Normal file
28
api/core/workflow/nodes/base/usage_tracking_mixin.py
Normal file
@ -0,0 +1,28 @@
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
|
||||
class LLMUsageTrackingMixin:
|
||||
"""Provides shared helpers for merging and recording LLM usage within workflow nodes."""
|
||||
|
||||
graph_runtime_state: GraphRuntimeState
|
||||
|
||||
@staticmethod
|
||||
def _merge_usage(current: LLMUsage, new_usage: LLMUsage | None) -> LLMUsage:
|
||||
"""Return a combined usage snapshot, preserving zero-value inputs."""
|
||||
if new_usage is None or new_usage.total_tokens <= 0:
|
||||
return current
|
||||
if current.total_tokens == 0:
|
||||
return new_usage
|
||||
return current.plus(new_usage)
|
||||
|
||||
def _accumulate_usage(self, usage: LLMUsage) -> None:
|
||||
"""Push usage into the graph runtime accumulator for downstream reporting."""
|
||||
if usage.total_tokens <= 0:
|
||||
return
|
||||
|
||||
current_usage = self.graph_runtime_state.llm_usage
|
||||
if current_usage.total_tokens == 0:
|
||||
self.graph_runtime_state.llm_usage = usage.model_copy()
|
||||
else:
|
||||
self.graph_runtime_state.llm_usage = current_usage.plus(usage)
|
||||
@ -10,10 +10,10 @@ from typing import Any
|
||||
import chardet
|
||||
import docx
|
||||
import pandas as pd
|
||||
import pypandoc # type: ignore
|
||||
import pypdfium2 # type: ignore
|
||||
import webvtt # type: ignore
|
||||
import yaml # type: ignore
|
||||
import pypandoc
|
||||
import pypdfium2
|
||||
import webvtt
|
||||
import yaml
|
||||
from docx.document import Document
|
||||
from docx.oxml.table import CT_Tbl
|
||||
from docx.oxml.text.paragraph import CT_P
|
||||
|
||||
@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any, NewType, cast
|
||||
from flask import Flask, current_app
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.variables import IntegerVariable, NoneSegment
|
||||
from core.variables.segments import ArrayAnySegment, ArraySegment
|
||||
from core.variables.variables import VariableUnion
|
||||
@ -34,6 +35,7 @@ from core.workflow.node_events import (
|
||||
NodeRunResult,
|
||||
StreamCompletedEvent,
|
||||
)
|
||||
from core.workflow.nodes.base import LLMUsageTrackingMixin
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
|
||||
@ -58,7 +60,7 @@ logger = logging.getLogger(__name__)
|
||||
EmptyArraySegment = NewType("EmptyArraySegment", ArraySegment)
|
||||
|
||||
|
||||
class IterationNode(Node):
|
||||
class IterationNode(LLMUsageTrackingMixin, Node):
|
||||
"""
|
||||
Iteration Node.
|
||||
"""
|
||||
@ -118,6 +120,7 @@ class IterationNode(Node):
|
||||
started_at = naive_utc_now()
|
||||
iter_run_map: dict[str, float] = {}
|
||||
outputs: list[object] = []
|
||||
usage_accumulator = [LLMUsage.empty_usage()]
|
||||
|
||||
yield IterationStartedEvent(
|
||||
start_at=started_at,
|
||||
@ -130,22 +133,27 @@ class IterationNode(Node):
|
||||
iterator_list_value=iterator_list_value,
|
||||
outputs=outputs,
|
||||
iter_run_map=iter_run_map,
|
||||
usage_accumulator=usage_accumulator,
|
||||
)
|
||||
|
||||
self._accumulate_usage(usage_accumulator[0])
|
||||
yield from self._handle_iteration_success(
|
||||
started_at=started_at,
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
iterator_list_value=iterator_list_value,
|
||||
iter_run_map=iter_run_map,
|
||||
usage=usage_accumulator[0],
|
||||
)
|
||||
except IterationNodeError as e:
|
||||
self._accumulate_usage(usage_accumulator[0])
|
||||
yield from self._handle_iteration_failure(
|
||||
started_at=started_at,
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
iterator_list_value=iterator_list_value,
|
||||
iter_run_map=iter_run_map,
|
||||
usage=usage_accumulator[0],
|
||||
error=e,
|
||||
)
|
||||
|
||||
@ -196,6 +204,7 @@ class IterationNode(Node):
|
||||
iterator_list_value: Sequence[object],
|
||||
outputs: list[object],
|
||||
iter_run_map: dict[str, float],
|
||||
usage_accumulator: list[LLMUsage],
|
||||
) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]:
|
||||
if self._node_data.is_parallel:
|
||||
# Parallel mode execution
|
||||
@ -203,6 +212,7 @@ class IterationNode(Node):
|
||||
iterator_list_value=iterator_list_value,
|
||||
outputs=outputs,
|
||||
iter_run_map=iter_run_map,
|
||||
usage_accumulator=usage_accumulator,
|
||||
)
|
||||
else:
|
||||
# Sequential mode execution
|
||||
@ -228,6 +238,9 @@ class IterationNode(Node):
|
||||
|
||||
# Update the total tokens from this iteration
|
||||
self.graph_runtime_state.total_tokens += graph_engine.graph_runtime_state.total_tokens
|
||||
usage_accumulator[0] = self._merge_usage(
|
||||
usage_accumulator[0], graph_engine.graph_runtime_state.llm_usage
|
||||
)
|
||||
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
|
||||
|
||||
def _execute_parallel_iterations(
|
||||
@ -235,6 +248,7 @@ class IterationNode(Node):
|
||||
iterator_list_value: Sequence[object],
|
||||
outputs: list[object],
|
||||
iter_run_map: dict[str, float],
|
||||
usage_accumulator: list[LLMUsage],
|
||||
) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]:
|
||||
# Initialize outputs list with None values to maintain order
|
||||
outputs.extend([None] * len(iterator_list_value))
|
||||
@ -245,7 +259,16 @@ class IterationNode(Node):
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
# Submit all iteration tasks
|
||||
future_to_index: dict[
|
||||
Future[tuple[datetime, list[GraphNodeEventBase], object | None, int, dict[str, VariableUnion]]],
|
||||
Future[
|
||||
tuple[
|
||||
datetime,
|
||||
list[GraphNodeEventBase],
|
||||
object | None,
|
||||
int,
|
||||
dict[str, VariableUnion],
|
||||
LLMUsage,
|
||||
]
|
||||
],
|
||||
int,
|
||||
] = {}
|
||||
for index, item in enumerate(iterator_list_value):
|
||||
@ -264,7 +287,14 @@ class IterationNode(Node):
|
||||
index = future_to_index[future]
|
||||
try:
|
||||
result = future.result()
|
||||
iter_start_at, events, output_value, tokens_used, conversation_snapshot = result
|
||||
(
|
||||
iter_start_at,
|
||||
events,
|
||||
output_value,
|
||||
tokens_used,
|
||||
conversation_snapshot,
|
||||
iteration_usage,
|
||||
) = result
|
||||
|
||||
# Update outputs at the correct index
|
||||
outputs[index] = output_value
|
||||
@ -276,6 +306,8 @@ class IterationNode(Node):
|
||||
self.graph_runtime_state.total_tokens += tokens_used
|
||||
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
|
||||
|
||||
usage_accumulator[0] = self._merge_usage(usage_accumulator[0], iteration_usage)
|
||||
|
||||
# Sync conversation variables after iteration completion
|
||||
self._sync_conversation_variables_from_snapshot(conversation_snapshot)
|
||||
|
||||
@ -303,7 +335,7 @@ class IterationNode(Node):
|
||||
item: object,
|
||||
flask_app: Flask,
|
||||
context_vars: contextvars.Context,
|
||||
) -> tuple[datetime, list[GraphNodeEventBase], object | None, int, dict[str, VariableUnion]]:
|
||||
) -> tuple[datetime, list[GraphNodeEventBase], object | None, int, dict[str, VariableUnion], LLMUsage]:
|
||||
"""Execute a single iteration in parallel mode and return results."""
|
||||
with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars):
|
||||
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
@ -332,6 +364,7 @@ class IterationNode(Node):
|
||||
output_value,
|
||||
graph_engine.graph_runtime_state.total_tokens,
|
||||
conversation_snapshot,
|
||||
graph_engine.graph_runtime_state.llm_usage,
|
||||
)
|
||||
|
||||
def _handle_iteration_success(
|
||||
@ -341,6 +374,8 @@ class IterationNode(Node):
|
||||
outputs: list[object],
|
||||
iterator_list_value: Sequence[object],
|
||||
iter_run_map: dict[str, float],
|
||||
*,
|
||||
usage: LLMUsage,
|
||||
) -> Generator[NodeEventBase, None, None]:
|
||||
# Flatten the list of lists if all outputs are lists
|
||||
flattened_outputs = self._flatten_outputs_if_needed(outputs)
|
||||
@ -351,7 +386,9 @@ class IterationNode(Node):
|
||||
outputs={"output": flattened_outputs},
|
||||
steps=len(iterator_list_value),
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
|
||||
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
|
||||
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
|
||||
},
|
||||
)
|
||||
@ -362,8 +399,11 @@ class IterationNode(Node):
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={"output": flattened_outputs},
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
|
||||
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
|
||||
},
|
||||
llm_usage=usage,
|
||||
)
|
||||
)
|
||||
|
||||
@ -400,6 +440,8 @@ class IterationNode(Node):
|
||||
outputs: list[object],
|
||||
iterator_list_value: Sequence[object],
|
||||
iter_run_map: dict[str, float],
|
||||
*,
|
||||
usage: LLMUsage,
|
||||
error: IterationNodeError,
|
||||
) -> Generator[NodeEventBase, None, None]:
|
||||
# Flatten the list of lists if all outputs are lists (even in failure case)
|
||||
@ -411,7 +453,9 @@ class IterationNode(Node):
|
||||
outputs={"output": flattened_outputs},
|
||||
steps=len(iterator_list_value),
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
|
||||
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
|
||||
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
|
||||
},
|
||||
error=str(error),
|
||||
@ -420,6 +464,12 @@ class IterationNode(Node):
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(error),
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
|
||||
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
|
||||
},
|
||||
llm_usage=usage,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@ -15,14 +15,11 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti
|
||||
from core.entities.agent_entities import PlanningStrategy
|
||||
from core.entities.model_entities import ModelStatus
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
PromptMessageRole,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import (
|
||||
ModelFeature,
|
||||
ModelType,
|
||||
)
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.entities.message_entities import PromptMessageRole
|
||||
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.prompt.simple_prompt_transform import ModelMode
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.entities.metadata_entities import Condition, MetadataCondition
|
||||
@ -33,8 +30,14 @@ from core.variables import (
|
||||
)
|
||||
from core.variables.segments import ArrayObjectSegment
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.enums import (
|
||||
ErrorStrategy,
|
||||
NodeType,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from core.workflow.node_events import ModelInvokeCompletedEvent, NodeRunResult
|
||||
from core.workflow.nodes.base import LLMUsageTrackingMixin
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.knowledge_retrieval.template_prompts import (
|
||||
@ -80,7 +83,7 @@ default_retrieval_model = {
|
||||
}
|
||||
|
||||
|
||||
class KnowledgeRetrievalNode(Node):
|
||||
class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node):
|
||||
node_type = NodeType.KNOWLEDGE_RETRIEVAL
|
||||
|
||||
_node_data: KnowledgeRetrievalNodeData
|
||||
@ -141,7 +144,7 @@ class KnowledgeRetrievalNode(Node):
|
||||
def version(cls):
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult: # type: ignore
|
||||
def _run(self) -> NodeRunResult:
|
||||
# extract variables
|
||||
variable = self.graph_runtime_state.variable_pool.get(self._node_data.query_variable_selector)
|
||||
if not isinstance(variable, StringSegment):
|
||||
@ -182,14 +185,21 @@ class KnowledgeRetrievalNode(Node):
|
||||
)
|
||||
|
||||
# retrieve knowledge
|
||||
usage = LLMUsage.empty_usage()
|
||||
try:
|
||||
results = self._fetch_dataset_retriever(node_data=self._node_data, query=query)
|
||||
results, usage = self._fetch_dataset_retriever(node_data=self._node_data, query=query)
|
||||
outputs = {"result": ArrayObjectSegment(value=results)}
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=variables,
|
||||
process_data={},
|
||||
process_data={"usage": jsonable_encoder(usage)},
|
||||
outputs=outputs, # type: ignore
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
|
||||
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
|
||||
},
|
||||
llm_usage=usage,
|
||||
)
|
||||
|
||||
except KnowledgeRetrievalNodeError as e:
|
||||
@ -199,6 +209,7 @@ class KnowledgeRetrievalNode(Node):
|
||||
inputs=variables,
|
||||
error=str(e),
|
||||
error_type=type(e).__name__,
|
||||
llm_usage=usage,
|
||||
)
|
||||
# Temporary handle all exceptions from DatasetRetrieval class here.
|
||||
except Exception as e:
|
||||
@ -207,11 +218,15 @@ class KnowledgeRetrievalNode(Node):
|
||||
inputs=variables,
|
||||
error=str(e),
|
||||
error_type=type(e).__name__,
|
||||
llm_usage=usage,
|
||||
)
|
||||
finally:
|
||||
db.session.close()
|
||||
|
||||
def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: str) -> list[dict[str, Any]]:
|
||||
def _fetch_dataset_retriever(
|
||||
self, node_data: KnowledgeRetrievalNodeData, query: str
|
||||
) -> tuple[list[dict[str, Any]], LLMUsage]:
|
||||
usage = LLMUsage.empty_usage()
|
||||
available_datasets = []
|
||||
dataset_ids = node_data.dataset_ids
|
||||
|
||||
@ -245,9 +260,10 @@ class KnowledgeRetrievalNode(Node):
|
||||
if not dataset:
|
||||
continue
|
||||
available_datasets.append(dataset)
|
||||
metadata_filter_document_ids, metadata_condition = self._get_metadata_filter_condition(
|
||||
metadata_filter_document_ids, metadata_condition, metadata_usage = self._get_metadata_filter_condition(
|
||||
[dataset.id for dataset in available_datasets], query, node_data
|
||||
)
|
||||
usage = self._merge_usage(usage, metadata_usage)
|
||||
all_documents = []
|
||||
dataset_retrieval = DatasetRetrieval()
|
||||
if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
|
||||
@ -330,6 +346,8 @@ class KnowledgeRetrievalNode(Node):
|
||||
metadata_filter_document_ids=metadata_filter_document_ids,
|
||||
metadata_condition=metadata_condition,
|
||||
)
|
||||
usage = self._merge_usage(usage, dataset_retrieval.llm_usage)
|
||||
|
||||
dify_documents = [item for item in all_documents if item.provider == "dify"]
|
||||
external_documents = [item for item in all_documents if item.provider == "external"]
|
||||
retrieval_resource_list = []
|
||||
@ -406,11 +424,12 @@ class KnowledgeRetrievalNode(Node):
|
||||
)
|
||||
for position, item in enumerate(retrieval_resource_list, start=1):
|
||||
item["metadata"]["position"] = position
|
||||
return retrieval_resource_list
|
||||
return retrieval_resource_list, usage
|
||||
|
||||
def _get_metadata_filter_condition(
|
||||
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
|
||||
) -> tuple[dict[str, list[str]] | None, MetadataCondition | None]:
|
||||
) -> tuple[dict[str, list[str]] | None, MetadataCondition | None, LLMUsage]:
|
||||
usage = LLMUsage.empty_usage()
|
||||
document_query = db.session.query(Document).where(
|
||||
Document.dataset_id.in_(dataset_ids),
|
||||
Document.indexing_status == "completed",
|
||||
@ -420,9 +439,12 @@ class KnowledgeRetrievalNode(Node):
|
||||
filters: list[Any] = []
|
||||
metadata_condition = None
|
||||
if node_data.metadata_filtering_mode == "disabled":
|
||||
return None, None
|
||||
return None, None, usage
|
||||
elif node_data.metadata_filtering_mode == "automatic":
|
||||
automatic_metadata_filters = self._automatic_metadata_filter_func(dataset_ids, query, node_data)
|
||||
automatic_metadata_filters, automatic_usage = self._automatic_metadata_filter_func(
|
||||
dataset_ids, query, node_data
|
||||
)
|
||||
usage = self._merge_usage(usage, automatic_usage)
|
||||
if automatic_metadata_filters:
|
||||
conditions = []
|
||||
for sequence, filter in enumerate(automatic_metadata_filters):
|
||||
@ -443,7 +465,7 @@ class KnowledgeRetrievalNode(Node):
|
||||
metadata_condition = MetadataCondition(
|
||||
logical_operator=node_data.metadata_filtering_conditions.logical_operator
|
||||
if node_data.metadata_filtering_conditions
|
||||
else "or", # type: ignore
|
||||
else "or",
|
||||
conditions=conditions,
|
||||
)
|
||||
elif node_data.metadata_filtering_mode == "manual":
|
||||
@ -457,10 +479,10 @@ class KnowledgeRetrievalNode(Node):
|
||||
expected_value = self.graph_runtime_state.variable_pool.convert_template(
|
||||
expected_value
|
||||
).value[0]
|
||||
if expected_value.value_type in {"number", "integer", "float"}: # type: ignore
|
||||
expected_value = expected_value.value # type: ignore
|
||||
elif expected_value.value_type == "string": # type: ignore
|
||||
expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip() # type: ignore
|
||||
if expected_value.value_type in {"number", "integer", "float"}:
|
||||
expected_value = expected_value.value
|
||||
elif expected_value.value_type == "string":
|
||||
expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip()
|
||||
else:
|
||||
raise ValueError("Invalid expected metadata value type")
|
||||
conditions.append(
|
||||
@ -487,7 +509,7 @@ class KnowledgeRetrievalNode(Node):
|
||||
if (
|
||||
node_data.metadata_filtering_conditions
|
||||
and node_data.metadata_filtering_conditions.logical_operator == "and"
|
||||
): # type: ignore
|
||||
):
|
||||
document_query = document_query.where(and_(*filters))
|
||||
else:
|
||||
document_query = document_query.where(or_(*filters))
|
||||
@ -496,11 +518,12 @@ class KnowledgeRetrievalNode(Node):
|
||||
metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore
|
||||
for document in documents:
|
||||
metadata_filter_document_ids[document.dataset_id].append(document.id) # type: ignore
|
||||
return metadata_filter_document_ids, metadata_condition
|
||||
return metadata_filter_document_ids, metadata_condition, usage
|
||||
|
||||
def _automatic_metadata_filter_func(
|
||||
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
|
||||
) -> list[dict[str, Any]]:
|
||||
) -> tuple[list[dict[str, Any]], LLMUsage]:
|
||||
usage = LLMUsage.empty_usage()
|
||||
# get all metadata field
|
||||
stmt = select(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids))
|
||||
metadata_fields = db.session.scalars(stmt).all()
|
||||
@ -548,6 +571,7 @@ class KnowledgeRetrievalNode(Node):
|
||||
for event in generator:
|
||||
if isinstance(event, ModelInvokeCompletedEvent):
|
||||
result_text = event.text
|
||||
usage = self._merge_usage(usage, event.usage)
|
||||
break
|
||||
|
||||
result_text_json = parse_and_check_json_markdown(result_text, [])
|
||||
@ -564,8 +588,8 @@ class KnowledgeRetrievalNode(Node):
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
return []
|
||||
return automatic_metadata_filters
|
||||
return [], usage
|
||||
return automatic_metadata_filters, usage
|
||||
|
||||
def _process_metadata_filter_func(
|
||||
self, sequence: int, condition: str, metadata_name: str, value: Any, filters: list[Any]
|
||||
|
||||
@ -452,10 +452,14 @@ class LLMNode(Node):
|
||||
usage = LLMUsage.empty_usage()
|
||||
finish_reason = None
|
||||
full_text_buffer = io.StringIO()
|
||||
collected_structured_output = None # Collect structured_output from streaming chunks
|
||||
# Consume the invoke result and handle generator exception
|
||||
try:
|
||||
for result in invoke_result:
|
||||
if isinstance(result, LLMResultChunkWithStructuredOutput):
|
||||
# Collect structured_output from the chunk
|
||||
if result.structured_output is not None:
|
||||
collected_structured_output = dict(result.structured_output)
|
||||
yield result
|
||||
if isinstance(result, LLMResultChunk):
|
||||
contents = result.delta.message.content
|
||||
@ -503,6 +507,8 @@ class LLMNode(Node):
|
||||
finish_reason=finish_reason,
|
||||
# Reasoning content for workflow variables and downstream nodes
|
||||
reasoning_content=reasoning_content,
|
||||
# Pass structured output if collected from streaming chunks
|
||||
structured_output=collected_structured_output,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@ -5,6 +5,7 @@ from collections.abc import Callable, Generator, Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.variables import Segment, SegmentType
|
||||
from core.workflow.enums import (
|
||||
ErrorStrategy,
|
||||
@ -27,6 +28,7 @@ from core.workflow.node_events import (
|
||||
NodeRunResult,
|
||||
StreamCompletedEvent,
|
||||
)
|
||||
from core.workflow.nodes.base import LLMUsageTrackingMixin
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.loop.entities import LoopNodeData, LoopVariableData
|
||||
@ -40,7 +42,7 @@ if TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LoopNode(Node):
|
||||
class LoopNode(LLMUsageTrackingMixin, Node):
|
||||
"""
|
||||
Loop Node.
|
||||
"""
|
||||
@ -108,7 +110,7 @@ class LoopNode(Node):
|
||||
raise ValueError(f"Invalid value for loop variable {loop_variable.label}")
|
||||
variable_selector = [self._node_id, loop_variable.label]
|
||||
variable = segment_to_variable(segment=processed_segment, selector=variable_selector)
|
||||
self.graph_runtime_state.variable_pool.add(variable_selector, variable)
|
||||
self.graph_runtime_state.variable_pool.add(variable_selector, variable.value)
|
||||
loop_variable_selectors[loop_variable.label] = variable_selector
|
||||
inputs[loop_variable.label] = processed_segment.value
|
||||
|
||||
@ -117,6 +119,7 @@ class LoopNode(Node):
|
||||
|
||||
loop_duration_map: dict[str, float] = {}
|
||||
single_loop_variable_map: dict[str, dict[str, Any]] = {} # single loop variable output
|
||||
loop_usage = LLMUsage.empty_usage()
|
||||
|
||||
# Start Loop event
|
||||
yield LoopStartedEvent(
|
||||
@ -163,6 +166,9 @@ class LoopNode(Node):
|
||||
# Update the total tokens from this iteration
|
||||
cost_tokens += graph_engine.graph_runtime_state.total_tokens
|
||||
|
||||
# Accumulate usage from the sub-graph execution
|
||||
loop_usage = self._merge_usage(loop_usage, graph_engine.graph_runtime_state.llm_usage)
|
||||
|
||||
# Collect loop variable values after iteration
|
||||
single_loop_variable = {}
|
||||
for key, selector in loop_variable_selectors.items():
|
||||
@ -189,6 +195,7 @@ class LoopNode(Node):
|
||||
)
|
||||
|
||||
self.graph_runtime_state.total_tokens += cost_tokens
|
||||
self._accumulate_usage(loop_usage)
|
||||
# Loop completed successfully
|
||||
yield LoopSucceededEvent(
|
||||
start_at=start_at,
|
||||
@ -196,7 +203,9 @@ class LoopNode(Node):
|
||||
outputs=self._node_data.outputs,
|
||||
steps=loop_count,
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: cost_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price,
|
||||
WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency,
|
||||
"completed_reason": "loop_break" if reach_break_condition else "loop_completed",
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
||||
@ -207,22 +216,28 @@ class LoopNode(Node):
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price,
|
||||
WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency,
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
||||
},
|
||||
outputs=self._node_data.outputs,
|
||||
inputs=inputs,
|
||||
llm_usage=loop_usage,
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self._accumulate_usage(loop_usage)
|
||||
yield LoopFailedEvent(
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
steps=loop_count,
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price,
|
||||
WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency,
|
||||
"completed_reason": "error",
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
||||
@ -235,10 +250,13 @@ class LoopNode(Node):
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(e),
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price,
|
||||
WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency,
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
||||
},
|
||||
llm_usage=loop_usage,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, final
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.graph import NodeFactory
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from libs.typing import is_str, is_str_dict
|
||||
@ -82,8 +82,4 @@ class DifyNodeFactory(NodeFactory):
|
||||
raise ValueError(f"Node {node_id} missing data information")
|
||||
node_instance.init_node_data(node_data)
|
||||
|
||||
# If node has fail branch, change execution type to branch
|
||||
if node_instance.error_strategy == ErrorStrategy.FAIL_BRANCH:
|
||||
node_instance.execution_type = NodeExecutionType.BRANCH
|
||||
|
||||
return node_instance
|
||||
|
||||
@ -747,7 +747,7 @@ class ParameterExtractorNode(Node):
|
||||
if model_mode == ModelMode.CHAT:
|
||||
system_prompt_messages = ChatModelMessage(
|
||||
role=PromptMessageRole.SYSTEM,
|
||||
text=CHAT_GENERATE_JSON_PROMPT.format(histories=memory_str).replace("{{instructions}}", instruction),
|
||||
text=CHAT_GENERATE_JSON_PROMPT.format(histories=memory_str, instructions=instruction),
|
||||
)
|
||||
user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text)
|
||||
return [system_prompt_messages, user_prompt_message]
|
||||
|
||||
@ -135,7 +135,7 @@ Here are the chat histories between human and assistant, inside <histories></his
|
||||
### Instructions:
|
||||
Some extra information are provided below, you should always follow the instructions as possible as you can.
|
||||
<instructions>
|
||||
{{instructions}}
|
||||
{instructions}
|
||||
</instructions>
|
||||
"""
|
||||
|
||||
|
||||
@ -6,10 +6,13 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
||||
from core.file import File, FileTransferMethod
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||
from core.tools.errors import ToolInvokeError
|
||||
from core.tools.tool_engine import ToolEngine
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||
from core.variables.segments import ArrayAnySegment, ArrayFileSegment
|
||||
from core.variables.variables import ArrayAnyVariable
|
||||
from core.workflow.enums import (
|
||||
@ -136,13 +139,14 @@ class ToolNode(Node):
|
||||
|
||||
try:
|
||||
# convert tool messages
|
||||
yield from self._transform_message(
|
||||
_ = yield from self._transform_message(
|
||||
messages=message_stream,
|
||||
tool_info=tool_info,
|
||||
parameters_for_log=parameters_for_log,
|
||||
user_id=self.user_id,
|
||||
tenant_id=self.tenant_id,
|
||||
node_id=self._node_id,
|
||||
tool_runtime=tool_runtime,
|
||||
)
|
||||
except ToolInvokeError as e:
|
||||
yield StreamCompletedEvent(
|
||||
@ -236,7 +240,8 @@ class ToolNode(Node):
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
node_id: str,
|
||||
) -> Generator:
|
||||
tool_runtime: Tool,
|
||||
) -> Generator[NodeEventBase, None, LLMUsage]:
|
||||
"""
|
||||
Convert ToolInvokeMessages into tuple[plain_text, files]
|
||||
"""
|
||||
@ -424,17 +429,34 @@ class ToolNode(Node):
|
||||
is_final=True,
|
||||
)
|
||||
|
||||
usage = self._extract_tool_usage(tool_runtime)
|
||||
|
||||
metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = {
|
||||
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
|
||||
}
|
||||
if usage.total_tokens > 0:
|
||||
metadata[WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS] = usage.total_tokens
|
||||
metadata[WorkflowNodeExecutionMetadataKey.TOTAL_PRICE] = usage.total_price
|
||||
metadata[WorkflowNodeExecutionMetadataKey.CURRENCY] = usage.currency
|
||||
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables},
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
|
||||
},
|
||||
metadata=metadata,
|
||||
inputs=parameters_for_log,
|
||||
llm_usage=usage,
|
||||
)
|
||||
)
|
||||
|
||||
return usage
|
||||
|
||||
@staticmethod
|
||||
def _extract_tool_usage(tool_runtime: Tool) -> LLMUsage:
|
||||
if isinstance(tool_runtime, WorkflowTool):
|
||||
return tool_runtime.latest_usage
|
||||
return LLMUsage.empty_usage()
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
|
||||
@ -277,7 +277,7 @@ class VariablePool(BaseModel):
|
||||
# This ensures that we can keep the id of the system variables intact.
|
||||
if self._has(selector):
|
||||
continue
|
||||
self.add(selector, value) # type: ignore
|
||||
self.add(selector, value)
|
||||
|
||||
@classmethod
|
||||
def empty(cls) -> "VariablePool":
|
||||
|
||||
Reference in New Issue
Block a user