Merge remote-tracking branch 'upstream/feat/queue-based-graph-engine' into feat/rag-2

This commit is contained in:
QuantumGhost
2025-09-17 18:00:48 +08:00
102 changed files with 2961 additions and 2025 deletions

View File

@ -33,7 +33,13 @@ from core.workflow.enums import (
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.node_events import AgentLogEvent, NodeRunResult, StreamChunkEvent, StreamCompletedEvent
from core.workflow.node_events import (
AgentLogEvent,
NodeEventBase,
NodeRunResult,
StreamChunkEvent,
StreamCompletedEvent,
)
from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
@ -93,7 +99,7 @@ class AgentNode(Node):
def version(cls) -> str:
return "1"
def _run(self) -> Generator:
def _run(self) -> Generator[NodeEventBase, None, None]:
from core.plugin.impl.exc import PluginDaemonClientSideError
try:
@ -482,7 +488,7 @@ class AgentNode(Node):
node_type: NodeType,
node_id: str,
node_execution_id: str,
) -> Generator:
) -> Generator[NodeEventBase, None, None]:
"""
Convert ToolInvokeMessages into tuple[plain_text, files]
"""

View File

@ -1,12 +1,13 @@
import logging
from abc import abstractmethod
from collections.abc import Callable, Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, ClassVar, Optional
from collections.abc import Generator, Mapping, Sequence
from functools import singledispatchmethod
from typing import Any, ClassVar
from uuid import uuid4
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities import AgentNodeStrategyInit
from core.workflow.enums import NodeExecutionType, NodeState, NodeType, WorkflowNodeExecutionStatus
from core.workflow.entities import AgentNodeStrategyInit, GraphInitParams, GraphRuntimeState
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType, WorkflowNodeExecutionStatus
from core.workflow.graph_events import (
GraphNodeEventBase,
NodeRunAgentLogEvent,
@ -45,11 +46,6 @@ from models.enums import UserFrom
from .entities import BaseNodeData, RetryConfig
if TYPE_CHECKING:
from core.workflow.entities import GraphInitParams, GraphRuntimeState
from core.workflow.enums import ErrorStrategy, NodeType
from core.workflow.node_events import NodeRunResult
logger = logging.getLogger(__name__)
@ -88,14 +84,14 @@ class Node:
def init_node_data(self, data: Mapping[str, Any]) -> None: ...
@abstractmethod
def _run(self) -> "NodeRunResult | Generator[GraphNodeEventBase, None, None]":
def _run(self) -> NodeRunResult | Generator[NodeEventBase, None, None]:
"""
Run node
:return:
"""
raise NotImplementedError
def run(self) -> "Generator[GraphNodeEventBase, None, None]":
def run(self) -> Generator[GraphNodeEventBase, None, None]:
# Generate a single node execution ID to use for all events
if not self._node_execution_id:
self._node_execution_id = str(uuid4())
@ -151,12 +147,14 @@ class Node:
# Handle event stream
for event in result:
if isinstance(event, NodeEventBase):
event = self._convert_node_event_to_graph_node_event(event)
if not event.in_iteration_id and not event.in_loop_id:
# NOTE: this is necessary because iteration and loop nodes yield GraphNodeEventBase
if isinstance(event, NodeEventBase): # pyright: ignore[reportUnnecessaryIsInstance]
yield self._dispatch(event)
elif isinstance(event, GraphNodeEventBase) and not event.in_iteration_id and not event.in_loop_id: # pyright: ignore[reportUnnecessaryIsInstance]
event.id = self._node_execution_id
yield event
yield event
else:
yield event
except Exception as e:
logger.exception("Node %s failed to run", self._node_id)
result = NodeRunResult(
@ -249,7 +247,7 @@ class Node:
return False
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
return {}
@classmethod
@ -270,7 +268,7 @@ class Node:
# to BaseNodeData properties in a type-safe way
@abstractmethod
def _get_error_strategy(self) -> Optional["ErrorStrategy"]:
def _get_error_strategy(self) -> ErrorStrategy | None:
"""Get the error strategy for this node."""
...
@ -301,7 +299,7 @@ class Node:
# Public interface properties that delegate to abstract methods
@property
def error_strategy(self) -> Optional["ErrorStrategy"]:
def error_strategy(self) -> ErrorStrategy | None:
"""Get the error strategy for this node."""
return self._get_error_strategy()
@ -344,29 +342,15 @@ class Node:
start_at=self._start_at,
node_run_result=result,
)
raise Exception(f"result status {result.status} not supported")
case _:
raise Exception(f"result status {result.status} not supported")
def _convert_node_event_to_graph_node_event(self, event: NodeEventBase) -> GraphNodeEventBase:
handler_maps: dict[type[NodeEventBase], Callable[[Any], GraphNodeEventBase]] = {
StreamChunkEvent: self._handle_stream_chunk_event,
StreamCompletedEvent: self._handle_stream_completed_event,
AgentLogEvent: self._handle_agent_log_event,
LoopStartedEvent: self._handle_loop_started_event,
LoopNextEvent: self._handle_loop_next_event,
LoopSucceededEvent: self._handle_loop_succeeded_event,
LoopFailedEvent: self._handle_loop_failed_event,
IterationStartedEvent: self._handle_iteration_started_event,
IterationNextEvent: self._handle_iteration_next_event,
IterationSucceededEvent: self._handle_iteration_succeeded_event,
IterationFailedEvent: self._handle_iteration_failed_event,
RunRetrieverResourceEvent: self._handle_run_retriever_resource_event,
}
handler = handler_maps.get(type(event))
if not handler:
raise NotImplementedError(f"Node {self._node_id} does not support event type {type(event)}")
return handler(event)
@singledispatchmethod
def _dispatch(self, event: NodeEventBase) -> GraphNodeEventBase:
    # Fallback overload: reached only when no `@_dispatch.register`
    # implementation exists for this concrete event type.
    raise NotImplementedError(f"Node {self._node_id} does not support event type {type(event)}")
def _handle_stream_chunk_event(self, event: StreamChunkEvent) -> NodeRunStreamChunkEvent:
@_dispatch.register
def _(self, event: StreamChunkEvent) -> NodeRunStreamChunkEvent:
return NodeRunStreamChunkEvent(
id=self._node_execution_id,
node_id=self._node_id,
@ -376,7 +360,8 @@ class Node:
is_final=event.is_final,
)
def _handle_stream_completed_event(self, event: StreamCompletedEvent) -> NodeRunSucceededEvent | NodeRunFailedEvent:
@_dispatch.register
def _(self, event: StreamCompletedEvent) -> NodeRunSucceededEvent | NodeRunFailedEvent:
match event.node_run_result.status:
case WorkflowNodeExecutionStatus.SUCCEEDED:
return NodeRunSucceededEvent(
@ -395,9 +380,13 @@ class Node:
node_run_result=event.node_run_result,
error=event.node_run_result.error,
)
raise NotImplementedError(f"Node {self._node_id} does not support status {event.node_run_result.status}")
case _:
raise NotImplementedError(
f"Node {self._node_id} does not support status {event.node_run_result.status}"
)
def _handle_agent_log_event(self, event: AgentLogEvent) -> NodeRunAgentLogEvent:
@_dispatch.register
def _(self, event: AgentLogEvent) -> NodeRunAgentLogEvent:
return NodeRunAgentLogEvent(
id=self._node_execution_id,
node_id=self._node_id,
@ -412,7 +401,8 @@ class Node:
metadata=event.metadata,
)
def _handle_loop_started_event(self, event: LoopStartedEvent) -> NodeRunLoopStartedEvent:
@_dispatch.register
def _(self, event: LoopStartedEvent) -> NodeRunLoopStartedEvent:
return NodeRunLoopStartedEvent(
id=self._node_execution_id,
node_id=self._node_id,
@ -424,7 +414,8 @@ class Node:
predecessor_node_id=event.predecessor_node_id,
)
def _handle_loop_next_event(self, event: LoopNextEvent) -> NodeRunLoopNextEvent:
@_dispatch.register
def _(self, event: LoopNextEvent) -> NodeRunLoopNextEvent:
return NodeRunLoopNextEvent(
id=self._node_execution_id,
node_id=self._node_id,
@ -434,7 +425,8 @@ class Node:
pre_loop_output=event.pre_loop_output,
)
def _handle_loop_succeeded_event(self, event: LoopSucceededEvent) -> NodeRunLoopSucceededEvent:
@_dispatch.register
def _(self, event: LoopSucceededEvent) -> NodeRunLoopSucceededEvent:
return NodeRunLoopSucceededEvent(
id=self._node_execution_id,
node_id=self._node_id,
@ -447,7 +439,8 @@ class Node:
steps=event.steps,
)
def _handle_loop_failed_event(self, event: LoopFailedEvent) -> NodeRunLoopFailedEvent:
@_dispatch.register
def _(self, event: LoopFailedEvent) -> NodeRunLoopFailedEvent:
return NodeRunLoopFailedEvent(
id=self._node_execution_id,
node_id=self._node_id,
@ -461,7 +454,8 @@ class Node:
error=event.error,
)
def _handle_iteration_started_event(self, event: IterationStartedEvent) -> NodeRunIterationStartedEvent:
@_dispatch.register
def _(self, event: IterationStartedEvent) -> NodeRunIterationStartedEvent:
return NodeRunIterationStartedEvent(
id=self._node_execution_id,
node_id=self._node_id,
@ -473,7 +467,8 @@ class Node:
predecessor_node_id=event.predecessor_node_id,
)
def _handle_iteration_next_event(self, event: IterationNextEvent) -> NodeRunIterationNextEvent:
@_dispatch.register
def _(self, event: IterationNextEvent) -> NodeRunIterationNextEvent:
return NodeRunIterationNextEvent(
id=self._node_execution_id,
node_id=self._node_id,
@ -483,7 +478,8 @@ class Node:
pre_iteration_output=event.pre_iteration_output,
)
def _handle_iteration_succeeded_event(self, event: IterationSucceededEvent) -> NodeRunIterationSucceededEvent:
@_dispatch.register
def _(self, event: IterationSucceededEvent) -> NodeRunIterationSucceededEvent:
return NodeRunIterationSucceededEvent(
id=self._node_execution_id,
node_id=self._node_id,
@ -496,7 +492,8 @@ class Node:
steps=event.steps,
)
def _handle_iteration_failed_event(self, event: IterationFailedEvent) -> NodeRunIterationFailedEvent:
@_dispatch.register
def _(self, event: IterationFailedEvent) -> NodeRunIterationFailedEvent:
return NodeRunIterationFailedEvent(
id=self._node_execution_id,
node_id=self._node_id,
@ -510,7 +507,8 @@ class Node:
error=event.error,
)
def _handle_run_retriever_resource_event(self, event: RunRetrieverResourceEvent) -> NodeRunRetrieverResourceEvent:
@_dispatch.register
def _(self, event: RunRetrieverResourceEvent) -> NodeRunRetrieverResourceEvent:
return NodeRunRetrieverResourceEvent(
id=self._node_execution_id,
node_id=self._node_id,

View File

@ -1,6 +1,6 @@
from collections.abc import Mapping, Sequence
from decimal import Decimal
from typing import Any
from typing import Any, cast
from configs import dify_config
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
@ -49,7 +49,7 @@ class CodeNode(Node):
return self._node_data
@classmethod
def get_default_config(cls, filters: dict | None = None):
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
"""
Get default config of node.
:param filters: filter by node config parameters.
@ -57,7 +57,7 @@ class CodeNode(Node):
"""
code_language = CodeLanguage.PYTHON3
if filters:
code_language = filters.get("code_language", CodeLanguage.PYTHON3)
code_language = cast(CodeLanguage, filters.get("code_language", CodeLanguage.PYTHON3))
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
code_provider: type[CodeNodeProvider] = next(p for p in providers if p.is_accept_language(code_language))

View File

@ -1,4 +1,4 @@
from typing import Annotated, Literal
from typing import Annotated, Literal, Self
from pydantic import AfterValidator, BaseModel
@ -34,7 +34,7 @@ class CodeNodeData(BaseNodeData):
class Output(BaseModel):
type: Annotated[SegmentType, AfterValidator(_validate_type)]
children: dict[str, "CodeNodeData.Output"] | None = None
children: dict[str, Self] | None = None
class Dependency(BaseModel):
name: str

View File

@ -58,7 +58,7 @@ class HttpRequestNode(Node):
return self._node_data
@classmethod
def get_default_config(cls, filters: dict[str, Any] | None = None):
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
return {
"type": "http-request",
"config": {

View File

@ -39,7 +39,7 @@ class IterationState(BaseIterationState):
"""
outputs: list[Any] = Field(default_factory=list)
current_output: Any | None = None
current_output: Any = None
class MetaData(BaseIterationState.MetaData):
"""
@ -48,7 +48,7 @@ class IterationState(BaseIterationState):
iterator_length: int
def get_last_output(self) -> Any | None:
def get_last_output(self) -> Any:
"""
Get last output.
"""
@ -56,7 +56,7 @@ class IterationState(BaseIterationState):
return self.outputs[-1]
return None
def get_current_output(self) -> Any | None:
def get_current_output(self) -> Any:
"""
Get current output.
"""

View File

@ -1,7 +1,10 @@
import logging
from collections.abc import Generator, Mapping, Sequence
from concurrent.futures import Future, ThreadPoolExecutor, as_completed
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, Union, cast
from typing import TYPE_CHECKING, Any, NewType, cast
from typing_extensions import TypeIs
from core.variables import IntegerVariable, NoneSegment
from core.variables.segments import ArrayAnySegment, ArraySegment
@ -23,6 +26,7 @@ from core.workflow.node_events import (
IterationNextEvent,
IterationStartedEvent,
IterationSucceededEvent,
NodeEventBase,
NodeRunResult,
StreamCompletedEvent,
)
@ -45,6 +49,8 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
EmptyArraySegment = NewType("EmptyArraySegment", ArraySegment)
class IterationNode(Node):
"""
@ -77,7 +83,7 @@ class IterationNode(Node):
return self._node_data
@classmethod
def get_default_config(cls, filters: dict | None = None):
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
return {
"type": "iteration",
"config": {
@ -91,44 +97,21 @@ class IterationNode(Node):
def version(cls) -> str:
return "1"
def _run(self) -> Generator:
variable = self.graph_runtime_state.variable_pool.get(self._node_data.iterator_selector)
def _run(self) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]: # type: ignore
variable = self._get_iterator_variable()
if not variable:
raise IteratorVariableNotFoundError(f"iterator variable {self._node_data.iterator_selector} not found")
if not isinstance(variable, ArraySegment) and not isinstance(variable, NoneSegment):
raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.")
if isinstance(variable, NoneSegment) or len(variable.value) == 0:
# Try our best to preserve the type information.
if isinstance(variable, ArraySegment):
output = variable.model_copy(update={"value": []})
else:
output = ArrayAnySegment(value=[])
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
# TODO(QuantumGhost): is it possible to compute the type of `output`
# from graph definition?
outputs={"output": output},
)
)
if self._is_empty_iteration(variable):
yield from self._handle_empty_iteration(variable)
return
iterator_list_value = variable.to_object()
if not isinstance(iterator_list_value, list):
raise InvalidIteratorValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.")
iterator_list_value = self._validate_and_get_iterator_list(variable)
inputs = {"iterator_selector": iterator_list_value}
if not self._node_data.start_node_id:
raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self._node_id} not found")
self._validate_start_node()
started_at = naive_utc_now()
iter_run_map: dict[str, float] = {}
outputs: list[Any] = []
outputs: list[object] = []
yield IterationStartedEvent(
start_at=started_at,
@ -137,6 +120,86 @@ class IterationNode(Node):
)
try:
yield from self._execute_iterations(
iterator_list_value=iterator_list_value,
outputs=outputs,
iter_run_map=iter_run_map,
)
yield from self._handle_iteration_success(
started_at=started_at,
inputs=inputs,
outputs=outputs,
iterator_list_value=iterator_list_value,
iter_run_map=iter_run_map,
)
except IterationNodeError as e:
yield from self._handle_iteration_failure(
started_at=started_at,
inputs=inputs,
outputs=outputs,
iterator_list_value=iterator_list_value,
iter_run_map=iter_run_map,
error=e,
)
def _get_iterator_variable(self) -> ArraySegment | NoneSegment:
    """Resolve and validate the variable selected as this node's iterator.

    :return: the resolved segment, guaranteed to be an ``ArraySegment``
        or a ``NoneSegment``.
    :raises IteratorVariableNotFoundError: when the selector resolves to nothing.
    :raises InvalidIteratorValueError: when the value is not list-like.
    """
    selector = self._node_data.iterator_selector
    resolved = self.graph_runtime_state.variable_pool.get(selector)
    if not resolved:
        raise IteratorVariableNotFoundError(f"iterator variable {selector} not found")
    if isinstance(resolved, (ArraySegment, NoneSegment)):
        return resolved
    raise InvalidIteratorValueError(f"invalid iterator value: {resolved}, please provide a list.")
def _is_empty_iteration(self, variable: ArraySegment | NoneSegment) -> TypeIs[NoneSegment | EmptyArraySegment]:
    # True when there is nothing to iterate over: either the variable was
    # absent (NoneSegment) or the array segment holds zero elements.
    # NOTE(review): the TypeIs narrowing to the EmptyArraySegment NewType is
    # based on a runtime len() check the type checker cannot verify —
    # presumably intentional to let callers treat the value as empty; confirm.
    return isinstance(variable, NoneSegment) or len(variable.value) == 0
def _handle_empty_iteration(self, variable: ArraySegment | NoneSegment) -> Generator[NodeEventBase, None, None]:
    """Emit a successful completion event carrying an empty output list.

    Preserves the element type of the incoming array segment when one is
    available; otherwise falls back to an untyped empty array segment.
    """
    if isinstance(variable, ArraySegment):
        empty_output = variable.model_copy(update={"value": []})
    else:
        empty_output = ArrayAnySegment(value=[])
    # TODO(QuantumGhost): is it possible to compute the type of `output`
    # from graph definition?
    run_result = NodeRunResult(
        status=WorkflowNodeExecutionStatus.SUCCEEDED,
        outputs={"output": empty_output},
    )
    yield StreamCompletedEvent(node_run_result=run_result)
def _validate_and_get_iterator_list(self, variable: ArraySegment) -> Sequence[object]:
    """Unwrap the segment to a plain Python list, rejecting any other shape.

    :raises InvalidIteratorValueError: when the unwrapped value is not a list.
    """
    raw_value = variable.to_object()
    if isinstance(raw_value, list):
        return cast(list[object], raw_value)
    raise InvalidIteratorValueError(f"Invalid iterator value: {raw_value}, please provide a list.")
def _validate_start_node(self) -> None:
    """Ensure the iteration config names a start node; raise otherwise."""
    if self._node_data.start_node_id:
        return
    raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self._node_id} not found")
def _execute_iterations(
self,
iterator_list_value: Sequence[object],
outputs: list[object],
iter_run_map: dict[str, float],
) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]:
if self._node_data.is_parallel:
# Parallel mode execution
yield from self._execute_parallel_iterations(
iterator_list_value=iterator_list_value,
outputs=outputs,
iter_run_map=iter_run_map,
)
else:
# Sequential mode execution
for index, item in enumerate(iterator_list_value):
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
yield IterationNextEvent(index=index)
@ -154,45 +217,146 @@ class IterationNode(Node):
self.graph_runtime_state.total_tokens += graph_engine.graph_runtime_state.total_tokens
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
yield IterationSucceededEvent(
start_at=started_at,
inputs=inputs,
outputs={"output": outputs},
steps=len(iterator_list_value),
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
},
)
def _execute_parallel_iterations(
self,
iterator_list_value: Sequence[object],
outputs: list[object],
iter_run_map: dict[str, float],
) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]:
# Initialize outputs list with None values to maintain order
outputs.extend([None] * len(iterator_list_value))
# Yield final success event
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"output": outputs},
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
},
# Determine the number of parallel workers
max_workers = min(self._node_data.parallel_nums, len(iterator_list_value))
with ThreadPoolExecutor(max_workers=max_workers) as executor:
# Submit all iteration tasks
future_to_index: dict[Future[tuple[datetime, list[GraphNodeEventBase], object | None, int]], int] = {}
for index, item in enumerate(iterator_list_value):
yield IterationNextEvent(index=index)
future = executor.submit(
self._execute_single_iteration_parallel,
index=index,
item=item,
)
)
except IterationNodeError as e:
yield IterationFailedEvent(
start_at=started_at,
inputs=inputs,
future_to_index[future] = index
# Process completed iterations as they finish
for future in as_completed(future_to_index):
index = future_to_index[future]
try:
result = future.result()
iter_start_at, events, output_value, tokens_used = result
# Update outputs at the correct index
outputs[index] = output_value
# Yield all events from this iteration
yield from events
# Update tokens and timing
self.graph_runtime_state.total_tokens += tokens_used
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
except Exception as e:
# Handle errors based on error_handle_mode
match self._node_data.error_handle_mode:
case ErrorHandleMode.TERMINATED:
# Cancel remaining futures and re-raise
for f in future_to_index:
if f != future:
f.cancel()
raise IterationNodeError(str(e))
case ErrorHandleMode.CONTINUE_ON_ERROR:
outputs[index] = None
case ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
outputs[index] = None # Will be filtered later
# Remove None values if in REMOVE_ABNORMAL_OUTPUT mode
if self._node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
outputs[:] = [output for output in outputs if output is not None]
def _execute_single_iteration_parallel(
    self,
    index: int,
    item: object,
) -> tuple[datetime, list[GraphNodeEventBase], object | None, int]:
    """Execute a single iteration in parallel mode and return results.

    Runs one iteration on a dedicated sub-graph engine inside the calling
    worker thread, buffering every produced event so the parent generator
    can replay them on the consuming thread.

    :param index: zero-based position of ``item`` in the iterator list.
    :param item: the element this iteration operates on.
    :return: 4-tuple of (naive-UTC start time, buffered events, the
        iteration's output value or ``None`` when none was produced,
        total tokens consumed by the sub-graph engine).
    """
    iter_start_at = datetime.now(UTC).replace(tzinfo=None)
    events: list[GraphNodeEventBase] = []
    outputs_temp: list[object] = []
    graph_engine = self._create_graph_engine(index, item)
    # Collect events instead of yielding them directly
    # (this method runs in a ThreadPoolExecutor worker, so it cannot yield
    # into the parent generator; the caller replays `events` later).
    for event in self._run_single_iter(
        variable_pool=graph_engine.graph_runtime_state.variable_pool,
        outputs=outputs_temp,
        graph_engine=graph_engine,
    ):
        events.append(event)
    # Get the output value from the temporary outputs list
    # (assumes _run_single_iter appends at most one output per run — confirm)
    output_value = outputs_temp[0] if outputs_temp else None
    return iter_start_at, events, output_value, graph_engine.graph_runtime_state.total_tokens
def _handle_iteration_success(
self,
started_at: datetime,
inputs: dict[str, Sequence[object]],
outputs: list[object],
iterator_list_value: Sequence[object],
iter_run_map: dict[str, float],
) -> Generator[NodeEventBase, None, None]:
yield IterationSucceededEvent(
start_at=started_at,
inputs=inputs,
outputs={"output": outputs},
steps=len(iterator_list_value),
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
},
)
# Yield final success event
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"output": outputs},
steps=len(iterator_list_value),
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
},
error=str(e),
)
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
)
)
def _handle_iteration_failure(
    self,
    started_at: datetime,
    inputs: dict[str, Sequence[object]],
    outputs: list[object],
    iterator_list_value: Sequence[object],
    iter_run_map: dict[str, float],
    error: IterationNodeError,
) -> Generator[NodeEventBase, None, None]:
    """Emit the failure events for an aborted iteration run.

    Yields an ``IterationFailedEvent`` carrying the partial outputs and
    timing metadata gathered before the failure, then a
    ``StreamCompletedEvent`` whose run result is FAILED with the same
    error message.

    :param started_at: when the iteration node started running.
    :param inputs: the resolved iterator input mapping.
    :param outputs: outputs collected so far (possibly partial).
    :param iterator_list_value: the full sequence being iterated.
    :param iter_run_map: per-iteration duration map (index string -> seconds).
    :param error: the error that terminated the run.
    """
    yield IterationFailedEvent(
        start_at=started_at,
        inputs=inputs,
        outputs={"output": outputs},
        steps=len(iterator_list_value),
        metadata={
            WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
            WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
        },
        error=str(error),
    )
    yield StreamCompletedEvent(
        node_run_result=NodeRunResult(
            status=WorkflowNodeExecutionStatus.FAILED,
            error=str(error),
        )
    )
@classmethod
def _extract_variable_selector_to_variable_mapping(
@ -305,9 +469,9 @@ class IterationNode(Node):
self,
*,
variable_pool: VariablePool,
outputs: list,
outputs: list[object],
graph_engine: "GraphEngine",
) -> Generator[Union[GraphNodeEventBase, StreamCompletedEvent], None, None]:
) -> Generator[GraphNodeEventBase, None, None]:
rst = graph_engine.run()
# get current iteration index
index_variable = variable_pool.get([self._node_id, "index"])
@ -338,7 +502,7 @@ class IterationNode(Node):
case ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
return
def _create_graph_engine(self, index: int, item: Any):
def _create_graph_engine(self, index: int, item: object):
# Import dependencies
from core.workflow.entities import GraphInitParams, GraphRuntimeState
from core.workflow.graph import Graph
@ -387,18 +551,9 @@ class IterationNode(Node):
# Create a new GraphEngine for this iteration
graph_engine = GraphEngine(
tenant_id=self.tenant_id,
app_id=self.app_id,
workflow_id=self.workflow_id,
user_id=self.user_id,
user_from=self.user_from,
invoke_from=self.invoke_from,
call_depth=self.workflow_call_depth,
graph=iteration_graph,
graph_config=self.graph_config,
graph_runtime_state=graph_runtime_state_copy,
max_execution_steps=10000, # Use default or config value
max_execution_time=600, # Use default or config value
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
)

View File

@ -4,7 +4,7 @@ import re
import time
from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, cast
from typing import TYPE_CHECKING, Any, cast
from sqlalchemy import Float, and_, func, or_, select, text
from sqlalchemy import cast as sqlalchemy_cast
@ -568,7 +568,7 @@ class KnowledgeRetrievalNode(Node):
return automatic_metadata_filters
def _process_metadata_filter_func(
self, sequence: int, condition: str, metadata_name: str, value: Optional[Any], filters: list[Any]
self, sequence: int, condition: str, metadata_name: str, value: Any, filters: list[Any]
) -> list[Any]:
if value is None and condition not in ("empty", "not empty"):
return filters

View File

@ -959,7 +959,7 @@ class LLMNode(Node):
return variable_mapping
@classmethod
def get_default_config(cls, filters: dict | None = None):
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
return {
"type": "llm",
"config": {

View File

@ -1,4 +1,4 @@
from typing import Annotated, Any, Literal, Optional
from typing import Annotated, Any, Literal
from pydantic import AfterValidator, BaseModel, Field, field_validator
@ -41,7 +41,7 @@ class LoopNodeData(BaseLoopNodeData):
loop_count: int # Maximum number of loops
break_conditions: list[Condition] # Conditions to break the loop
logical_operator: Literal["and", "or"]
loop_variables: Optional[list[LoopVariableData]] = Field(default_factory=list[LoopVariableData])
loop_variables: list[LoopVariableData] | None = Field(default_factory=list[LoopVariableData])
outputs: dict[str, Any] = Field(default_factory=dict)
@field_validator("outputs", mode="before")
@ -74,7 +74,7 @@ class LoopState(BaseLoopState):
"""
outputs: list[Any] = Field(default_factory=list)
current_output: Any | None = None
current_output: Any = None
class MetaData(BaseLoopState.MetaData):
"""
@ -83,7 +83,7 @@ class LoopState(BaseLoopState):
loop_length: int
def get_last_output(self) -> Any | None:
def get_last_output(self) -> Any:
"""
Get last output.
"""
@ -91,7 +91,7 @@ class LoopState(BaseLoopState):
return self.outputs[-1]
return None
def get_current_output(self) -> Any | None:
def get_current_output(self) -> Any:
"""
Get current output.
"""

View File

@ -4,7 +4,6 @@ from collections.abc import Callable, Generator, Mapping, Sequence
from datetime import datetime
from typing import TYPE_CHECKING, Any, Literal, cast
from configs import dify_config
from core.variables import Segment, SegmentType
from core.workflow.enums import (
ErrorStrategy,
@ -444,18 +443,9 @@ class LoopNode(Node):
# Create a new GraphEngine for this iteration
graph_engine = GraphEngine(
tenant_id=self.tenant_id,
app_id=self.app_id,
workflow_id=self.workflow_id,
user_id=self.user_id,
user_from=self.user_from,
invoke_from=self.invoke_from,
call_depth=self.workflow_call_depth,
graph=loop_graph,
graph_config=self.graph_config,
graph_runtime_state=graph_runtime_state_copy,
max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
)

View File

@ -118,7 +118,7 @@ class ParameterExtractorNode(Node):
_model_config: ModelConfigWithCredentialsEntity | None = None
@classmethod
def get_default_config(cls, filters: dict | None = None):
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
return {
"model": {
"prompt_templates": {

View File

@ -271,7 +271,7 @@ class QuestionClassifierNode(Node):
return variable_mapping
@classmethod
def get_default_config(cls, filters: dict | None = None):
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
"""
Get default config of node.
:param filters: filter by node config parameters (not used in this implementation).

View File

@ -39,7 +39,7 @@ class TemplateTransformNode(Node):
return self._node_data
@classmethod
def get_default_config(cls, filters: dict | None = None):
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
"""
Get default config of node.
:param filters: filter by node config parameters.

View File

@ -19,7 +19,7 @@ from core.workflow.enums import (
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
from core.workflow.node_events import NodeEventBase, NodeRunResult, StreamChunkEvent, StreamCompletedEvent
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
@ -55,7 +55,7 @@ class ToolNode(Node):
def version(cls) -> str:
return "1"
def _run(self) -> Generator:
def _run(self) -> Generator[NodeEventBase, None, None]:
"""
Run the tool node
"""

View File

@ -18,7 +18,7 @@ class VariableOperationItem(BaseModel):
# 2. For VARIABLE input_type: Initially contains the selector of the source variable.
# 3. During the variable updating procedure: The `value` field is reassigned to hold
# the resolved actual value that will be applied to the target variable.
value: Any | None = None
value: Any = None
class VariableAssignerNodeData(BaseNodeData):