mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 17:08:03 +08:00
Merge remote-tracking branch 'upstream/feat/queue-based-graph-engine' into feat/rag-2
This commit is contained in:
@ -33,7 +33,13 @@ from core.workflow.enums import (
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from core.workflow.node_events import AgentLogEvent, NodeRunResult, StreamChunkEvent, StreamCompletedEvent
|
||||
from core.workflow.node_events import (
|
||||
AgentLogEvent,
|
||||
NodeEventBase,
|
||||
NodeRunResult,
|
||||
StreamChunkEvent,
|
||||
StreamCompletedEvent,
|
||||
)
|
||||
from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
@ -93,7 +99,7 @@ class AgentNode(Node):
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator:
|
||||
def _run(self) -> Generator[NodeEventBase, None, None]:
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
|
||||
try:
|
||||
@ -482,7 +488,7 @@ class AgentNode(Node):
|
||||
node_type: NodeType,
|
||||
node_id: str,
|
||||
node_execution_id: str,
|
||||
) -> Generator:
|
||||
) -> Generator[NodeEventBase, None, None]:
|
||||
"""
|
||||
Convert ToolInvokeMessages into tuple[plain_text, files]
|
||||
"""
|
||||
|
||||
@ -1,12 +1,13 @@
|
||||
import logging
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Callable, Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Optional
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from functools import singledispatchmethod
|
||||
from typing import Any, ClassVar
|
||||
from uuid import uuid4
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities import AgentNodeStrategyInit
|
||||
from core.workflow.enums import NodeExecutionType, NodeState, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.entities import AgentNodeStrategyInit, GraphInitParams, GraphRuntimeState
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph_events import (
|
||||
GraphNodeEventBase,
|
||||
NodeRunAgentLogEvent,
|
||||
@ -45,11 +46,6 @@ from models.enums import UserFrom
|
||||
|
||||
from .entities import BaseNodeData, RetryConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState
|
||||
from core.workflow.enums import ErrorStrategy, NodeType
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -88,14 +84,14 @@ class Node:
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None: ...
|
||||
|
||||
@abstractmethod
|
||||
def _run(self) -> "NodeRunResult | Generator[GraphNodeEventBase, None, None]":
|
||||
def _run(self) -> NodeRunResult | Generator[NodeEventBase, None, None]:
|
||||
"""
|
||||
Run node
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def run(self) -> "Generator[GraphNodeEventBase, None, None]":
|
||||
def run(self) -> Generator[GraphNodeEventBase, None, None]:
|
||||
# Generate a single node execution ID to use for all events
|
||||
if not self._node_execution_id:
|
||||
self._node_execution_id = str(uuid4())
|
||||
@ -151,12 +147,14 @@ class Node:
|
||||
|
||||
# Handle event stream
|
||||
for event in result:
|
||||
if isinstance(event, NodeEventBase):
|
||||
event = self._convert_node_event_to_graph_node_event(event)
|
||||
|
||||
if not event.in_iteration_id and not event.in_loop_id:
|
||||
# NOTE: this is necessary because iteration and loop nodes yield GraphNodeEventBase
|
||||
if isinstance(event, NodeEventBase): # pyright: ignore[reportUnnecessaryIsInstance]
|
||||
yield self._dispatch(event)
|
||||
elif isinstance(event, GraphNodeEventBase) and not event.in_iteration_id and not event.in_loop_id: # pyright: ignore[reportUnnecessaryIsInstance]
|
||||
event.id = self._node_execution_id
|
||||
yield event
|
||||
yield event
|
||||
else:
|
||||
yield event
|
||||
except Exception as e:
|
||||
logger.exception("Node %s failed to run", self._node_id)
|
||||
result = NodeRunResult(
|
||||
@ -249,7 +247,7 @@ class Node:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
@ -270,7 +268,7 @@ class Node:
|
||||
# to BaseNodeData properties in a type-safe way
|
||||
|
||||
@abstractmethod
|
||||
def _get_error_strategy(self) -> Optional["ErrorStrategy"]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
"""Get the error strategy for this node."""
|
||||
...
|
||||
|
||||
@ -301,7 +299,7 @@ class Node:
|
||||
|
||||
# Public interface properties that delegate to abstract methods
|
||||
@property
|
||||
def error_strategy(self) -> Optional["ErrorStrategy"]:
|
||||
def error_strategy(self) -> ErrorStrategy | None:
|
||||
"""Get the error strategy for this node."""
|
||||
return self._get_error_strategy()
|
||||
|
||||
@ -344,29 +342,15 @@ class Node:
|
||||
start_at=self._start_at,
|
||||
node_run_result=result,
|
||||
)
|
||||
raise Exception(f"result status {result.status} not supported")
|
||||
case _:
|
||||
raise Exception(f"result status {result.status} not supported")
|
||||
|
||||
def _convert_node_event_to_graph_node_event(self, event: NodeEventBase) -> GraphNodeEventBase:
|
||||
handler_maps: dict[type[NodeEventBase], Callable[[Any], GraphNodeEventBase]] = {
|
||||
StreamChunkEvent: self._handle_stream_chunk_event,
|
||||
StreamCompletedEvent: self._handle_stream_completed_event,
|
||||
AgentLogEvent: self._handle_agent_log_event,
|
||||
LoopStartedEvent: self._handle_loop_started_event,
|
||||
LoopNextEvent: self._handle_loop_next_event,
|
||||
LoopSucceededEvent: self._handle_loop_succeeded_event,
|
||||
LoopFailedEvent: self._handle_loop_failed_event,
|
||||
IterationStartedEvent: self._handle_iteration_started_event,
|
||||
IterationNextEvent: self._handle_iteration_next_event,
|
||||
IterationSucceededEvent: self._handle_iteration_succeeded_event,
|
||||
IterationFailedEvent: self._handle_iteration_failed_event,
|
||||
RunRetrieverResourceEvent: self._handle_run_retriever_resource_event,
|
||||
}
|
||||
handler = handler_maps.get(type(event))
|
||||
if not handler:
|
||||
raise NotImplementedError(f"Node {self._node_id} does not support event type {type(event)}")
|
||||
return handler(event)
|
||||
@singledispatchmethod
|
||||
def _dispatch(self, event: NodeEventBase) -> GraphNodeEventBase:
|
||||
raise NotImplementedError(f"Node {self._node_id} does not support event type {type(event)}")
|
||||
|
||||
def _handle_stream_chunk_event(self, event: StreamChunkEvent) -> NodeRunStreamChunkEvent:
|
||||
@_dispatch.register
|
||||
def _(self, event: StreamChunkEvent) -> NodeRunStreamChunkEvent:
|
||||
return NodeRunStreamChunkEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
@ -376,7 +360,8 @@ class Node:
|
||||
is_final=event.is_final,
|
||||
)
|
||||
|
||||
def _handle_stream_completed_event(self, event: StreamCompletedEvent) -> NodeRunSucceededEvent | NodeRunFailedEvent:
|
||||
@_dispatch.register
|
||||
def _(self, event: StreamCompletedEvent) -> NodeRunSucceededEvent | NodeRunFailedEvent:
|
||||
match event.node_run_result.status:
|
||||
case WorkflowNodeExecutionStatus.SUCCEEDED:
|
||||
return NodeRunSucceededEvent(
|
||||
@ -395,9 +380,13 @@ class Node:
|
||||
node_run_result=event.node_run_result,
|
||||
error=event.node_run_result.error,
|
||||
)
|
||||
raise NotImplementedError(f"Node {self._node_id} does not support status {event.node_run_result.status}")
|
||||
case _:
|
||||
raise NotImplementedError(
|
||||
f"Node {self._node_id} does not support status {event.node_run_result.status}"
|
||||
)
|
||||
|
||||
def _handle_agent_log_event(self, event: AgentLogEvent) -> NodeRunAgentLogEvent:
|
||||
@_dispatch.register
|
||||
def _(self, event: AgentLogEvent) -> NodeRunAgentLogEvent:
|
||||
return NodeRunAgentLogEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
@ -412,7 +401,8 @@ class Node:
|
||||
metadata=event.metadata,
|
||||
)
|
||||
|
||||
def _handle_loop_started_event(self, event: LoopStartedEvent) -> NodeRunLoopStartedEvent:
|
||||
@_dispatch.register
|
||||
def _(self, event: LoopStartedEvent) -> NodeRunLoopStartedEvent:
|
||||
return NodeRunLoopStartedEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
@ -424,7 +414,8 @@ class Node:
|
||||
predecessor_node_id=event.predecessor_node_id,
|
||||
)
|
||||
|
||||
def _handle_loop_next_event(self, event: LoopNextEvent) -> NodeRunLoopNextEvent:
|
||||
@_dispatch.register
|
||||
def _(self, event: LoopNextEvent) -> NodeRunLoopNextEvent:
|
||||
return NodeRunLoopNextEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
@ -434,7 +425,8 @@ class Node:
|
||||
pre_loop_output=event.pre_loop_output,
|
||||
)
|
||||
|
||||
def _handle_loop_succeeded_event(self, event: LoopSucceededEvent) -> NodeRunLoopSucceededEvent:
|
||||
@_dispatch.register
|
||||
def _(self, event: LoopSucceededEvent) -> NodeRunLoopSucceededEvent:
|
||||
return NodeRunLoopSucceededEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
@ -447,7 +439,8 @@ class Node:
|
||||
steps=event.steps,
|
||||
)
|
||||
|
||||
def _handle_loop_failed_event(self, event: LoopFailedEvent) -> NodeRunLoopFailedEvent:
|
||||
@_dispatch.register
|
||||
def _(self, event: LoopFailedEvent) -> NodeRunLoopFailedEvent:
|
||||
return NodeRunLoopFailedEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
@ -461,7 +454,8 @@ class Node:
|
||||
error=event.error,
|
||||
)
|
||||
|
||||
def _handle_iteration_started_event(self, event: IterationStartedEvent) -> NodeRunIterationStartedEvent:
|
||||
@_dispatch.register
|
||||
def _(self, event: IterationStartedEvent) -> NodeRunIterationStartedEvent:
|
||||
return NodeRunIterationStartedEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
@ -473,7 +467,8 @@ class Node:
|
||||
predecessor_node_id=event.predecessor_node_id,
|
||||
)
|
||||
|
||||
def _handle_iteration_next_event(self, event: IterationNextEvent) -> NodeRunIterationNextEvent:
|
||||
@_dispatch.register
|
||||
def _(self, event: IterationNextEvent) -> NodeRunIterationNextEvent:
|
||||
return NodeRunIterationNextEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
@ -483,7 +478,8 @@ class Node:
|
||||
pre_iteration_output=event.pre_iteration_output,
|
||||
)
|
||||
|
||||
def _handle_iteration_succeeded_event(self, event: IterationSucceededEvent) -> NodeRunIterationSucceededEvent:
|
||||
@_dispatch.register
|
||||
def _(self, event: IterationSucceededEvent) -> NodeRunIterationSucceededEvent:
|
||||
return NodeRunIterationSucceededEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
@ -496,7 +492,8 @@ class Node:
|
||||
steps=event.steps,
|
||||
)
|
||||
|
||||
def _handle_iteration_failed_event(self, event: IterationFailedEvent) -> NodeRunIterationFailedEvent:
|
||||
@_dispatch.register
|
||||
def _(self, event: IterationFailedEvent) -> NodeRunIterationFailedEvent:
|
||||
return NodeRunIterationFailedEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
@ -510,7 +507,8 @@ class Node:
|
||||
error=event.error,
|
||||
)
|
||||
|
||||
def _handle_run_retriever_resource_event(self, event: RunRetrieverResourceEvent) -> NodeRunRetrieverResourceEvent:
|
||||
@_dispatch.register
|
||||
def _(self, event: RunRetrieverResourceEvent) -> NodeRunRetrieverResourceEvent:
|
||||
return NodeRunRetrieverResourceEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from decimal import Decimal
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
|
||||
@ -49,7 +49,7 @@ class CodeNode(Node):
|
||||
return self._node_data
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: dict | None = None):
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
"""
|
||||
Get default config of node.
|
||||
:param filters: filter by node config parameters.
|
||||
@ -57,7 +57,7 @@ class CodeNode(Node):
|
||||
"""
|
||||
code_language = CodeLanguage.PYTHON3
|
||||
if filters:
|
||||
code_language = filters.get("code_language", CodeLanguage.PYTHON3)
|
||||
code_language = cast(CodeLanguage, filters.get("code_language", CodeLanguage.PYTHON3))
|
||||
|
||||
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
|
||||
code_provider: type[CodeNodeProvider] = next(p for p in providers if p.is_accept_language(code_language))
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Annotated, Literal
|
||||
from typing import Annotated, Literal, Self
|
||||
|
||||
from pydantic import AfterValidator, BaseModel
|
||||
|
||||
@ -34,7 +34,7 @@ class CodeNodeData(BaseNodeData):
|
||||
|
||||
class Output(BaseModel):
|
||||
type: Annotated[SegmentType, AfterValidator(_validate_type)]
|
||||
children: dict[str, "CodeNodeData.Output"] | None = None
|
||||
children: dict[str, Self] | None = None
|
||||
|
||||
class Dependency(BaseModel):
|
||||
name: str
|
||||
|
||||
@ -58,7 +58,7 @@ class HttpRequestNode(Node):
|
||||
return self._node_data
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: dict[str, Any] | None = None):
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
return {
|
||||
"type": "http-request",
|
||||
"config": {
|
||||
|
||||
@ -39,7 +39,7 @@ class IterationState(BaseIterationState):
|
||||
"""
|
||||
|
||||
outputs: list[Any] = Field(default_factory=list)
|
||||
current_output: Any | None = None
|
||||
current_output: Any = None
|
||||
|
||||
class MetaData(BaseIterationState.MetaData):
|
||||
"""
|
||||
@ -48,7 +48,7 @@ class IterationState(BaseIterationState):
|
||||
|
||||
iterator_length: int
|
||||
|
||||
def get_last_output(self) -> Any | None:
|
||||
def get_last_output(self) -> Any:
|
||||
"""
|
||||
Get last output.
|
||||
"""
|
||||
@ -56,7 +56,7 @@ class IterationState(BaseIterationState):
|
||||
return self.outputs[-1]
|
||||
return None
|
||||
|
||||
def get_current_output(self) -> Any | None:
|
||||
def get_current_output(self) -> Any:
|
||||
"""
|
||||
Get current output.
|
||||
"""
|
||||
|
||||
@ -1,7 +1,10 @@
|
||||
import logging
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from concurrent.futures import Future, ThreadPoolExecutor, as_completed
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING, Any, Union, cast
|
||||
from typing import TYPE_CHECKING, Any, NewType, cast
|
||||
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
from core.variables import IntegerVariable, NoneSegment
|
||||
from core.variables.segments import ArrayAnySegment, ArraySegment
|
||||
@ -23,6 +26,7 @@ from core.workflow.node_events import (
|
||||
IterationNextEvent,
|
||||
IterationStartedEvent,
|
||||
IterationSucceededEvent,
|
||||
NodeEventBase,
|
||||
NodeRunResult,
|
||||
StreamCompletedEvent,
|
||||
)
|
||||
@ -45,6 +49,8 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
EmptyArraySegment = NewType("EmptyArraySegment", ArraySegment)
|
||||
|
||||
|
||||
class IterationNode(Node):
|
||||
"""
|
||||
@ -77,7 +83,7 @@ class IterationNode(Node):
|
||||
return self._node_data
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: dict | None = None):
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
return {
|
||||
"type": "iteration",
|
||||
"config": {
|
||||
@ -91,44 +97,21 @@ class IterationNode(Node):
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator:
|
||||
variable = self.graph_runtime_state.variable_pool.get(self._node_data.iterator_selector)
|
||||
def _run(self) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]: # type: ignore
|
||||
variable = self._get_iterator_variable()
|
||||
|
||||
if not variable:
|
||||
raise IteratorVariableNotFoundError(f"iterator variable {self._node_data.iterator_selector} not found")
|
||||
|
||||
if not isinstance(variable, ArraySegment) and not isinstance(variable, NoneSegment):
|
||||
raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.")
|
||||
|
||||
if isinstance(variable, NoneSegment) or len(variable.value) == 0:
|
||||
# Try our best to preserve the type informat.
|
||||
if isinstance(variable, ArraySegment):
|
||||
output = variable.model_copy(update={"value": []})
|
||||
else:
|
||||
output = ArrayAnySegment(value=[])
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
# TODO(QuantumGhost): is it possible to compute the type of `output`
|
||||
# from graph definition?
|
||||
outputs={"output": output},
|
||||
)
|
||||
)
|
||||
if self._is_empty_iteration(variable):
|
||||
yield from self._handle_empty_iteration(variable)
|
||||
return
|
||||
|
||||
iterator_list_value = variable.to_object()
|
||||
|
||||
if not isinstance(iterator_list_value, list):
|
||||
raise InvalidIteratorValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.")
|
||||
|
||||
iterator_list_value = self._validate_and_get_iterator_list(variable)
|
||||
inputs = {"iterator_selector": iterator_list_value}
|
||||
|
||||
if not self._node_data.start_node_id:
|
||||
raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self._node_id} not found")
|
||||
self._validate_start_node()
|
||||
|
||||
started_at = naive_utc_now()
|
||||
iter_run_map: dict[str, float] = {}
|
||||
outputs: list[Any] = []
|
||||
outputs: list[object] = []
|
||||
|
||||
yield IterationStartedEvent(
|
||||
start_at=started_at,
|
||||
@ -137,6 +120,86 @@ class IterationNode(Node):
|
||||
)
|
||||
|
||||
try:
|
||||
yield from self._execute_iterations(
|
||||
iterator_list_value=iterator_list_value,
|
||||
outputs=outputs,
|
||||
iter_run_map=iter_run_map,
|
||||
)
|
||||
|
||||
yield from self._handle_iteration_success(
|
||||
started_at=started_at,
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
iterator_list_value=iterator_list_value,
|
||||
iter_run_map=iter_run_map,
|
||||
)
|
||||
except IterationNodeError as e:
|
||||
yield from self._handle_iteration_failure(
|
||||
started_at=started_at,
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
iterator_list_value=iterator_list_value,
|
||||
iter_run_map=iter_run_map,
|
||||
error=e,
|
||||
)
|
||||
|
||||
def _get_iterator_variable(self) -> ArraySegment | NoneSegment:
|
||||
variable = self.graph_runtime_state.variable_pool.get(self._node_data.iterator_selector)
|
||||
|
||||
if not variable:
|
||||
raise IteratorVariableNotFoundError(f"iterator variable {self._node_data.iterator_selector} not found")
|
||||
|
||||
if not isinstance(variable, ArraySegment) and not isinstance(variable, NoneSegment):
|
||||
raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.")
|
||||
|
||||
return variable
|
||||
|
||||
def _is_empty_iteration(self, variable: ArraySegment | NoneSegment) -> TypeIs[NoneSegment | EmptyArraySegment]:
|
||||
return isinstance(variable, NoneSegment) or len(variable.value) == 0
|
||||
|
||||
def _handle_empty_iteration(self, variable: ArraySegment | NoneSegment) -> Generator[NodeEventBase, None, None]:
|
||||
# Try our best to preserve the type information.
|
||||
if isinstance(variable, ArraySegment):
|
||||
output = variable.model_copy(update={"value": []})
|
||||
else:
|
||||
output = ArrayAnySegment(value=[])
|
||||
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
# TODO(QuantumGhost): is it possible to compute the type of `output`
|
||||
# from graph definition?
|
||||
outputs={"output": output},
|
||||
)
|
||||
)
|
||||
|
||||
def _validate_and_get_iterator_list(self, variable: ArraySegment) -> Sequence[object]:
|
||||
iterator_list_value = variable.to_object()
|
||||
|
||||
if not isinstance(iterator_list_value, list):
|
||||
raise InvalidIteratorValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.")
|
||||
|
||||
return cast(list[object], iterator_list_value)
|
||||
|
||||
def _validate_start_node(self) -> None:
|
||||
if not self._node_data.start_node_id:
|
||||
raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self._node_id} not found")
|
||||
|
||||
def _execute_iterations(
|
||||
self,
|
||||
iterator_list_value: Sequence[object],
|
||||
outputs: list[object],
|
||||
iter_run_map: dict[str, float],
|
||||
) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]:
|
||||
if self._node_data.is_parallel:
|
||||
# Parallel mode execution
|
||||
yield from self._execute_parallel_iterations(
|
||||
iterator_list_value=iterator_list_value,
|
||||
outputs=outputs,
|
||||
iter_run_map=iter_run_map,
|
||||
)
|
||||
else:
|
||||
# Sequential mode execution
|
||||
for index, item in enumerate(iterator_list_value):
|
||||
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
yield IterationNextEvent(index=index)
|
||||
@ -154,45 +217,146 @@ class IterationNode(Node):
|
||||
self.graph_runtime_state.total_tokens += graph_engine.graph_runtime_state.total_tokens
|
||||
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
|
||||
|
||||
yield IterationSucceededEvent(
|
||||
start_at=started_at,
|
||||
inputs=inputs,
|
||||
outputs={"output": outputs},
|
||||
steps=len(iterator_list_value),
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
|
||||
},
|
||||
)
|
||||
def _execute_parallel_iterations(
|
||||
self,
|
||||
iterator_list_value: Sequence[object],
|
||||
outputs: list[object],
|
||||
iter_run_map: dict[str, float],
|
||||
) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]:
|
||||
# Initialize outputs list with None values to maintain order
|
||||
outputs.extend([None] * len(iterator_list_value))
|
||||
|
||||
# Yield final success event
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={"output": outputs},
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
|
||||
},
|
||||
# Determine the number of parallel workers
|
||||
max_workers = min(self._node_data.parallel_nums, len(iterator_list_value))
|
||||
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
# Submit all iteration tasks
|
||||
future_to_index: dict[Future[tuple[datetime, list[GraphNodeEventBase], object | None, int]], int] = {}
|
||||
for index, item in enumerate(iterator_list_value):
|
||||
yield IterationNextEvent(index=index)
|
||||
future = executor.submit(
|
||||
self._execute_single_iteration_parallel,
|
||||
index=index,
|
||||
item=item,
|
||||
)
|
||||
)
|
||||
except IterationNodeError as e:
|
||||
yield IterationFailedEvent(
|
||||
start_at=started_at,
|
||||
inputs=inputs,
|
||||
future_to_index[future] = index
|
||||
|
||||
# Process completed iterations as they finish
|
||||
for future in as_completed(future_to_index):
|
||||
index = future_to_index[future]
|
||||
try:
|
||||
result = future.result()
|
||||
iter_start_at, events, output_value, tokens_used = result
|
||||
|
||||
# Update outputs at the correct index
|
||||
outputs[index] = output_value
|
||||
|
||||
# Yield all events from this iteration
|
||||
yield from events
|
||||
|
||||
# Update tokens and timing
|
||||
self.graph_runtime_state.total_tokens += tokens_used
|
||||
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
|
||||
|
||||
except Exception as e:
|
||||
# Handle errors based on error_handle_mode
|
||||
match self._node_data.error_handle_mode:
|
||||
case ErrorHandleMode.TERMINATED:
|
||||
# Cancel remaining futures and re-raise
|
||||
for f in future_to_index:
|
||||
if f != future:
|
||||
f.cancel()
|
||||
raise IterationNodeError(str(e))
|
||||
case ErrorHandleMode.CONTINUE_ON_ERROR:
|
||||
outputs[index] = None
|
||||
case ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
|
||||
outputs[index] = None # Will be filtered later
|
||||
|
||||
# Remove None values if in REMOVE_ABNORMAL_OUTPUT mode
|
||||
if self._node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
|
||||
outputs[:] = [output for output in outputs if output is not None]
|
||||
|
||||
def _execute_single_iteration_parallel(
|
||||
self,
|
||||
index: int,
|
||||
item: object,
|
||||
) -> tuple[datetime, list[GraphNodeEventBase], object | None, int]:
|
||||
"""Execute a single iteration in parallel mode and return results."""
|
||||
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
events: list[GraphNodeEventBase] = []
|
||||
outputs_temp: list[object] = []
|
||||
|
||||
graph_engine = self._create_graph_engine(index, item)
|
||||
|
||||
# Collect events instead of yielding them directly
|
||||
for event in self._run_single_iter(
|
||||
variable_pool=graph_engine.graph_runtime_state.variable_pool,
|
||||
outputs=outputs_temp,
|
||||
graph_engine=graph_engine,
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
# Get the output value from the temporary outputs list
|
||||
output_value = outputs_temp[0] if outputs_temp else None
|
||||
|
||||
return iter_start_at, events, output_value, graph_engine.graph_runtime_state.total_tokens
|
||||
|
||||
def _handle_iteration_success(
|
||||
self,
|
||||
started_at: datetime,
|
||||
inputs: dict[str, Sequence[object]],
|
||||
outputs: list[object],
|
||||
iterator_list_value: Sequence[object],
|
||||
iter_run_map: dict[str, float],
|
||||
) -> Generator[NodeEventBase, None, None]:
|
||||
yield IterationSucceededEvent(
|
||||
start_at=started_at,
|
||||
inputs=inputs,
|
||||
outputs={"output": outputs},
|
||||
steps=len(iterator_list_value),
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
|
||||
},
|
||||
)
|
||||
|
||||
# Yield final success event
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={"output": outputs},
|
||||
steps=len(iterator_list_value),
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
|
||||
},
|
||||
error=str(e),
|
||||
)
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(e),
|
||||
)
|
||||
)
|
||||
|
||||
def _handle_iteration_failure(
|
||||
self,
|
||||
started_at: datetime,
|
||||
inputs: dict[str, Sequence[object]],
|
||||
outputs: list[object],
|
||||
iterator_list_value: Sequence[object],
|
||||
iter_run_map: dict[str, float],
|
||||
error: IterationNodeError,
|
||||
) -> Generator[NodeEventBase, None, None]:
|
||||
yield IterationFailedEvent(
|
||||
start_at=started_at,
|
||||
inputs=inputs,
|
||||
outputs={"output": outputs},
|
||||
steps=len(iterator_list_value),
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
|
||||
},
|
||||
error=str(error),
|
||||
)
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(error),
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
@ -305,9 +469,9 @@ class IterationNode(Node):
|
||||
self,
|
||||
*,
|
||||
variable_pool: VariablePool,
|
||||
outputs: list,
|
||||
outputs: list[object],
|
||||
graph_engine: "GraphEngine",
|
||||
) -> Generator[Union[GraphNodeEventBase, StreamCompletedEvent], None, None]:
|
||||
) -> Generator[GraphNodeEventBase, None, None]:
|
||||
rst = graph_engine.run()
|
||||
# get current iteration index
|
||||
index_variable = variable_pool.get([self._node_id, "index"])
|
||||
@ -338,7 +502,7 @@ class IterationNode(Node):
|
||||
case ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
|
||||
return
|
||||
|
||||
def _create_graph_engine(self, index: int, item: Any):
|
||||
def _create_graph_engine(self, index: int, item: object):
|
||||
# Import dependencies
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState
|
||||
from core.workflow.graph import Graph
|
||||
@ -387,18 +551,9 @@ class IterationNode(Node):
|
||||
|
||||
# Create a new GraphEngine for this iteration
|
||||
graph_engine = GraphEngine(
|
||||
tenant_id=self.tenant_id,
|
||||
app_id=self.app_id,
|
||||
workflow_id=self.workflow_id,
|
||||
user_id=self.user_id,
|
||||
user_from=self.user_from,
|
||||
invoke_from=self.invoke_from,
|
||||
call_depth=self.workflow_call_depth,
|
||||
graph=iteration_graph,
|
||||
graph_config=self.graph_config,
|
||||
graph_runtime_state=graph_runtime_state_copy,
|
||||
max_execution_steps=10000, # Use default or config value
|
||||
max_execution_time=600, # Use default or config value
|
||||
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
|
||||
)
|
||||
|
||||
|
||||
@ -4,7 +4,7 @@ import re
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Optional, cast
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from sqlalchemy import Float, and_, func, or_, select, text
|
||||
from sqlalchemy import cast as sqlalchemy_cast
|
||||
@ -568,7 +568,7 @@ class KnowledgeRetrievalNode(Node):
|
||||
return automatic_metadata_filters
|
||||
|
||||
def _process_metadata_filter_func(
|
||||
self, sequence: int, condition: str, metadata_name: str, value: Optional[Any], filters: list[Any]
|
||||
self, sequence: int, condition: str, metadata_name: str, value: Any, filters: list[Any]
|
||||
) -> list[Any]:
|
||||
if value is None and condition not in ("empty", "not empty"):
|
||||
return filters
|
||||
|
||||
@ -959,7 +959,7 @@ class LLMNode(Node):
|
||||
return variable_mapping
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: dict | None = None):
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
return {
|
||||
"type": "llm",
|
||||
"config": {
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Annotated, Any, Literal, Optional
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from pydantic import AfterValidator, BaseModel, Field, field_validator
|
||||
|
||||
@ -41,7 +41,7 @@ class LoopNodeData(BaseLoopNodeData):
|
||||
loop_count: int # Maximum number of loops
|
||||
break_conditions: list[Condition] # Conditions to break the loop
|
||||
logical_operator: Literal["and", "or"]
|
||||
loop_variables: Optional[list[LoopVariableData]] = Field(default_factory=list[LoopVariableData])
|
||||
loop_variables: list[LoopVariableData] | None = Field(default_factory=list[LoopVariableData])
|
||||
outputs: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
@field_validator("outputs", mode="before")
|
||||
@ -74,7 +74,7 @@ class LoopState(BaseLoopState):
|
||||
"""
|
||||
|
||||
outputs: list[Any] = Field(default_factory=list)
|
||||
current_output: Any | None = None
|
||||
current_output: Any = None
|
||||
|
||||
class MetaData(BaseLoopState.MetaData):
|
||||
"""
|
||||
@ -83,7 +83,7 @@ class LoopState(BaseLoopState):
|
||||
|
||||
loop_length: int
|
||||
|
||||
def get_last_output(self) -> Any | None:
|
||||
def get_last_output(self) -> Any:
|
||||
"""
|
||||
Get last output.
|
||||
"""
|
||||
@ -91,7 +91,7 @@ class LoopState(BaseLoopState):
|
||||
return self.outputs[-1]
|
||||
return None
|
||||
|
||||
def get_current_output(self) -> Any | None:
|
||||
def get_current_output(self) -> Any:
|
||||
"""
|
||||
Get current output.
|
||||
"""
|
||||
|
||||
@ -4,7 +4,6 @@ from collections.abc import Callable, Generator, Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
|
||||
from configs import dify_config
|
||||
from core.variables import Segment, SegmentType
|
||||
from core.workflow.enums import (
|
||||
ErrorStrategy,
|
||||
@ -444,18 +443,9 @@ class LoopNode(Node):
|
||||
|
||||
# Create a new GraphEngine for this iteration
|
||||
graph_engine = GraphEngine(
|
||||
tenant_id=self.tenant_id,
|
||||
app_id=self.app_id,
|
||||
workflow_id=self.workflow_id,
|
||||
user_id=self.user_id,
|
||||
user_from=self.user_from,
|
||||
invoke_from=self.invoke_from,
|
||||
call_depth=self.workflow_call_depth,
|
||||
graph=loop_graph,
|
||||
graph_config=self.graph_config,
|
||||
graph_runtime_state=graph_runtime_state_copy,
|
||||
max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
|
||||
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
|
||||
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
|
||||
)
|
||||
|
||||
|
||||
@ -118,7 +118,7 @@ class ParameterExtractorNode(Node):
|
||||
_model_config: ModelConfigWithCredentialsEntity | None = None
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: dict | None = None):
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
return {
|
||||
"model": {
|
||||
"prompt_templates": {
|
||||
|
||||
@ -271,7 +271,7 @@ class QuestionClassifierNode(Node):
|
||||
return variable_mapping
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: dict | None = None):
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
"""
|
||||
Get default config of node.
|
||||
:param filters: filter by node config parameters (not used in this implementation).
|
||||
|
||||
@ -39,7 +39,7 @@ class TemplateTransformNode(Node):
|
||||
return self._node_data
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: dict | None = None):
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
"""
|
||||
Get default config of node.
|
||||
:param filters: filter by node config parameters.
|
||||
|
||||
@ -19,7 +19,7 @@ from core.workflow.enums import (
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
|
||||
from core.workflow.node_events import NodeEventBase, NodeRunResult, StreamChunkEvent, StreamCompletedEvent
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
@ -55,7 +55,7 @@ class ToolNode(Node):
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator:
|
||||
def _run(self) -> Generator[NodeEventBase, None, None]:
|
||||
"""
|
||||
Run the tool node
|
||||
"""
|
||||
|
||||
@ -18,7 +18,7 @@ class VariableOperationItem(BaseModel):
|
||||
# 2. For VARIABLE input_type: Initially contains the selector of the source variable.
|
||||
# 3. During the variable updating procedure: The `value` field is reassigned to hold
|
||||
# the resolved actual value that will be applied to the target variable.
|
||||
value: Any | None = None
|
||||
value: Any = None
|
||||
|
||||
|
||||
class VariableAssignerNodeData(BaseNodeData):
|
||||
|
||||
Reference in New Issue
Block a user