fix iteration start node

This commit is contained in:
takatost
2024-08-22 23:53:44 +08:00
parent d6da7b0336
commit ec4fc784f0
7 changed files with 306 additions and 4 deletions

View File

@ -9,7 +9,7 @@ class BaseNodeData(ABC, BaseModel):
desc: Optional[str] = None
class BaseIterationNodeData(BaseNodeData):
start_node_id: str
start_node_id: Optional[str] = None
class BaseIterationState(BaseModel):
iteration_node_id: str

View File

@ -28,6 +28,7 @@ class NodeType(Enum):
VARIABLE_ASSIGNER = 'variable-assigner'
LOOP = 'loop'
ITERATION = 'iteration'
ITERATION_START = 'iteration-start' # fake start node for iteration
PARAMETER_EXTRACTOR = 'parameter-extractor'
CONVERSATION_VARIABLE_ASSIGNER = 'assigner'

View File

@ -1,6 +1,6 @@
from typing import Any, Optional
from core.workflow.entities.base_node_data_entities import BaseIterationNodeData, BaseIterationState
from core.workflow.entities.base_node_data_entities import BaseIterationNodeData, BaseIterationState, BaseNodeData
class IterationNodeData(BaseIterationNodeData):
@ -11,6 +11,13 @@ class IterationNodeData(BaseIterationNodeData):
iterator_selector: list[str] # variable selector
output_selector: list[str] # output selector
class IterationStartNodeData(BaseNodeData):
"""
Iteration Start Node Data.
"""
pass
class IterationState(BaseIterationState):
"""
Iteration State.

View File

@ -50,9 +50,39 @@ class IterationNode(BaseNode):
"iterator_selector": iterator_list_value
}
root_node_id = self.node_data.start_node_id
graph_config = self.graph_config
# find nodes in current iteration and donot have source and have have start_node_in_iteration flag
# these nodes are the start nodes of the iteration (in version of parallel support)
start_node_ids = []
for node_config in graph_config['nodes']:
if (
node_config.get('data', {}).get('iteration_id')
and node_config.get('data', {}).get('iteration_id') == self.node_id
and not node_config.get('source')
and node_config.get('data', {}).get('start_node_in_iteration', False)
):
start_node_ids.append(node_config.get('id'))
if len(start_node_ids) > 1:
# add new fake iteration start node that connect to all start nodes
root_node_id = f"{self.node_id}-start"
graph_config['nodes'].append({
"id": root_node_id,
"data": {
"title": "iteration start",
"type": NodeType.ITERATION_START.value,
}
})
for start_node_id in start_node_ids:
graph_config['edges'].append({
"source": root_node_id,
"target": start_node_id
})
else:
root_node_id = self.node_data.start_node_id
# init graph
iteration_graph = Graph.init(
graph_config=graph_config,
@ -156,6 +186,9 @@ class IterationNode(BaseNode):
if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id:
event.in_iteration_id = self.node_id
if isinstance(event, BaseNodeEvent) and event.node_type == NodeType.ITERATION_START:
continue
if isinstance(event, NodeRunSucceededEvent):
if event.route_node_state.node_run_result:
metadata = event.route_node_state.node_run_result.metadata
@ -180,7 +213,11 @@ class IterationNode(BaseNode):
variable_pool.remove_node(node_id)
# move to next iteration
next_index = variable_pool.get_any([self.node_id, 'index']) + 1
current_index = variable_pool.get([self.node_id, 'index'])
if current_index is None:
raise ValueError(f'iteration {self.node_id} current index not found')
next_index = int(current_index.to_object()) + 1
variable_pool.add(
[self.node_id, 'index'],
next_index
@ -229,6 +266,7 @@ class IterationNode(BaseNode):
)
break
else:
event = cast(InNodeEvent, event)
yield event
yield IterationRunSucceededEvent(

View File

@ -0,0 +1,40 @@
import logging
from collections.abc import Mapping, Sequence
from typing import Any
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.iteration.entities import IterationNodeData, IterationStartNodeData
from models.workflow import WorkflowNodeExecutionStatus
class IterationStartNode(BaseNode):
"""
Iteration Start Node.
"""
_node_data_cls = IterationStartNodeData
_node_type = NodeType.ITERATION_START
def _run(self) -> NodeRunResult:
"""
Run the node.
"""
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED
)
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
graph_config: Mapping[str, Any],
node_id: str,
node_data: IterationNodeData
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
return {}

View File

@ -5,6 +5,7 @@ from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.http_request.http_request_node import HttpRequestNode
from core.workflow.nodes.if_else.if_else_node import IfElseNode
from core.workflow.nodes.iteration.iteration_node import IterationNode
from core.workflow.nodes.iteration.iteration_start_node import IterationStartNode
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
from core.workflow.nodes.llm.llm_node import LLMNode
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
@ -30,6 +31,7 @@ node_classes = {
NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode,
NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode, # original name of VARIABLE_AGGREGATOR
NodeType.ITERATION: IterationNode,
NodeType.ITERATION_START: IterationStartNode,
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode,
NodeType.CONVERSATION_VARIABLE_ASSIGNER: VariableAssignerNode,
}