mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 01:18:05 +08:00
fix iteration start node
This commit is contained in:
@ -9,7 +9,7 @@ class BaseNodeData(ABC, BaseModel):
|
||||
desc: Optional[str] = None
|
||||
|
||||
class BaseIterationNodeData(BaseNodeData):
|
||||
start_node_id: str
|
||||
start_node_id: Optional[str] = None
|
||||
|
||||
class BaseIterationState(BaseModel):
|
||||
iteration_node_id: str
|
||||
|
||||
@ -28,6 +28,7 @@ class NodeType(Enum):
|
||||
VARIABLE_ASSIGNER = 'variable-assigner'
|
||||
LOOP = 'loop'
|
||||
ITERATION = 'iteration'
|
||||
ITERATION_START = 'iteration-start' # fake start node for iteration
|
||||
PARAMETER_EXTRACTOR = 'parameter-extractor'
|
||||
CONVERSATION_VARIABLE_ASSIGNER = 'assigner'
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseIterationNodeData, BaseIterationState
|
||||
from core.workflow.entities.base_node_data_entities import BaseIterationNodeData, BaseIterationState, BaseNodeData
|
||||
|
||||
|
||||
class IterationNodeData(BaseIterationNodeData):
|
||||
@ -11,6 +11,13 @@ class IterationNodeData(BaseIterationNodeData):
|
||||
iterator_selector: list[str] # variable selector
|
||||
output_selector: list[str] # output selector
|
||||
|
||||
|
||||
class IterationStartNodeData(BaseNodeData):
|
||||
"""
|
||||
Iteration Start Node Data.
|
||||
"""
|
||||
pass
|
||||
|
||||
class IterationState(BaseIterationState):
|
||||
"""
|
||||
Iteration State.
|
||||
|
||||
@ -50,9 +50,39 @@ class IterationNode(BaseNode):
|
||||
"iterator_selector": iterator_list_value
|
||||
}
|
||||
|
||||
root_node_id = self.node_data.start_node_id
|
||||
graph_config = self.graph_config
|
||||
|
||||
# find nodes in current iteration and donot have source and have have start_node_in_iteration flag
|
||||
# these nodes are the start nodes of the iteration (in version of parallel support)
|
||||
start_node_ids = []
|
||||
for node_config in graph_config['nodes']:
|
||||
if (
|
||||
node_config.get('data', {}).get('iteration_id')
|
||||
and node_config.get('data', {}).get('iteration_id') == self.node_id
|
||||
and not node_config.get('source')
|
||||
and node_config.get('data', {}).get('start_node_in_iteration', False)
|
||||
):
|
||||
start_node_ids.append(node_config.get('id'))
|
||||
|
||||
if len(start_node_ids) > 1:
|
||||
# add new fake iteration start node that connect to all start nodes
|
||||
root_node_id = f"{self.node_id}-start"
|
||||
graph_config['nodes'].append({
|
||||
"id": root_node_id,
|
||||
"data": {
|
||||
"title": "iteration start",
|
||||
"type": NodeType.ITERATION_START.value,
|
||||
}
|
||||
})
|
||||
|
||||
for start_node_id in start_node_ids:
|
||||
graph_config['edges'].append({
|
||||
"source": root_node_id,
|
||||
"target": start_node_id
|
||||
})
|
||||
else:
|
||||
root_node_id = self.node_data.start_node_id
|
||||
|
||||
# init graph
|
||||
iteration_graph = Graph.init(
|
||||
graph_config=graph_config,
|
||||
@ -156,6 +186,9 @@ class IterationNode(BaseNode):
|
||||
if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id:
|
||||
event.in_iteration_id = self.node_id
|
||||
|
||||
if isinstance(event, BaseNodeEvent) and event.node_type == NodeType.ITERATION_START:
|
||||
continue
|
||||
|
||||
if isinstance(event, NodeRunSucceededEvent):
|
||||
if event.route_node_state.node_run_result:
|
||||
metadata = event.route_node_state.node_run_result.metadata
|
||||
@ -180,7 +213,11 @@ class IterationNode(BaseNode):
|
||||
variable_pool.remove_node(node_id)
|
||||
|
||||
# move to next iteration
|
||||
next_index = variable_pool.get_any([self.node_id, 'index']) + 1
|
||||
current_index = variable_pool.get([self.node_id, 'index'])
|
||||
if current_index is None:
|
||||
raise ValueError(f'iteration {self.node_id} current index not found')
|
||||
|
||||
next_index = int(current_index.to_object()) + 1
|
||||
variable_pool.add(
|
||||
[self.node_id, 'index'],
|
||||
next_index
|
||||
@ -229,6 +266,7 @@ class IterationNode(BaseNode):
|
||||
)
|
||||
break
|
||||
else:
|
||||
event = cast(InNodeEvent, event)
|
||||
yield event
|
||||
|
||||
yield IterationRunSucceededEvent(
|
||||
|
||||
40
api/core/workflow/nodes/iteration/iteration_start_node.py
Normal file
40
api/core/workflow/nodes/iteration/iteration_start_node.py
Normal file
@ -0,0 +1,40 @@
|
||||
import logging
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.iteration.entities import IterationNodeData, IterationStartNodeData
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class IterationStartNode(BaseNode):
|
||||
"""
|
||||
Iteration Start Node.
|
||||
"""
|
||||
_node_data_cls = IterationStartNodeData
|
||||
_node_type = NodeType.ITERATION_START
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
"""
|
||||
Run the node.
|
||||
"""
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: IterationNodeData
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
:param graph_config: graph config
|
||||
:param node_id: node id
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
return {}
|
||||
@ -5,6 +5,7 @@ from core.workflow.nodes.end.end_node import EndNode
|
||||
from core.workflow.nodes.http_request.http_request_node import HttpRequestNode
|
||||
from core.workflow.nodes.if_else.if_else_node import IfElseNode
|
||||
from core.workflow.nodes.iteration.iteration_node import IterationNode
|
||||
from core.workflow.nodes.iteration.iteration_start_node import IterationStartNode
|
||||
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
|
||||
from core.workflow.nodes.llm.llm_node import LLMNode
|
||||
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
|
||||
@ -30,6 +31,7 @@ node_classes = {
|
||||
NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode,
|
||||
NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode, # original name of VARIABLE_AGGREGATOR
|
||||
NodeType.ITERATION: IterationNode,
|
||||
NodeType.ITERATION_START: IterationStartNode,
|
||||
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode,
|
||||
NodeType.CONVERSATION_VARIABLE_ASSIGNER: VariableAssignerNode,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user