fix iteration

This commit is contained in:
takatost
2024-07-26 02:43:40 +08:00
parent ae351bd40e
commit a31feacf28
7 changed files with 283 additions and 27 deletions

View File

@ -104,9 +104,14 @@ class GraphEngine:
)
for item in generator:
yield item
if isinstance(item, NodeRunFailedEvent):
yield GraphRunFailedEvent(reason=item.route_node_state.failed_reason or 'Unknown error.')
try:
yield item
if isinstance(item, NodeRunFailedEvent):
yield GraphRunFailedEvent(reason=item.route_node_state.failed_reason or 'Unknown error.')
return
except Exception as e:
logger.exception(f"Graph run failed: {str(e)}")
yield GraphRunFailedEvent(reason=str(e))
return
# trigger graph run success event
@ -115,6 +120,7 @@ class GraphEngine:
yield GraphRunFailedEvent(reason=e.error)
return
except Exception as e:
logger.exception("Unknown Error when graph running")
yield GraphRunFailedEvent(reason=str(e))
raise e
@ -182,7 +188,22 @@ class GraphEngine:
break
if len(edge_mappings) == 1:
next_node_id = edge_mappings[0].target_node_id
edge = edge_mappings[0]
if edge.run_condition:
result = ConditionManager.get_condition_handler(
init_params=self.init_params,
graph=self.graph,
run_condition=edge.run_condition,
).check(
graph_runtime_state=self.graph_runtime_state,
previous_route_node_state=previous_route_node_state,
target_node_id=edge.target_node_id,
)
if not result:
break
next_node_id = edge.target_node_id
else:
if any(edge.run_condition for edge in edge_mappings):
# if nodes has run conditions, get node id which branch to take based on the run condition results

View File

@ -6,6 +6,7 @@ from core.file.file_obj import FileVar
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
@ -31,7 +32,12 @@ class AnswerStreamProcessor:
generator: Generator[GraphEngineEvent, None, None]
) -> Generator[GraphEngineEvent, None, None]:
for event in generator:
if isinstance(event, NodeRunStreamChunkEvent):
if isinstance(event, NodeRunStartedEvent):
if event.route_node_state.node_id == self.graph.root_node_id and not self.rest_node_ids:
self.reset()
yield event
elif isinstance(event, NodeRunStreamChunkEvent):
if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
stream_out_answer_node_ids = self.current_stream_chunk_generating_node_ids[
event.route_node_state.node_id
@ -99,6 +105,9 @@ class AnswerStreamProcessor:
def _fetch_node_ids_in_reachable_branch(self, node_id: str) -> list[str]:
node_ids = []
for edge in self.graph.edge_mapping.get(node_id, []):
if edge.target_node_id == self.graph.root_node_id:
continue
node_ids.append(edge.target_node_id)
node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
return node_ids
@ -107,6 +116,9 @@ class AnswerStreamProcessor:
"""
remove target node ids until merge
"""
if node_id not in self.rest_node_ids:
return
self.rest_node_ids.remove(node_id)
for edge in self.graph.edge_mapping.get(node_id, []):
if edge.target_node_id in reachable_node_ids:

View File

@ -4,6 +4,7 @@ from collections.abc import Generator
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
@ -26,7 +27,12 @@ class EndStreamProcessor:
generator: Generator[GraphEngineEvent, None, None]
) -> Generator[GraphEngineEvent, None, None]:
for event in generator:
if isinstance(event, NodeRunStreamChunkEvent):
if isinstance(event, NodeRunStartedEvent):
if event.route_node_state.node_id == self.graph.root_node_id and not self.rest_node_ids:
self.reset()
yield event
elif isinstance(event, NodeRunStreamChunkEvent):
if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
stream_out_end_node_ids = self.current_stream_chunk_generating_node_ids[
event.route_node_state.node_id
@ -87,6 +93,9 @@ class EndStreamProcessor:
def _fetch_node_ids_in_reachable_branch(self, node_id: str) -> list[str]:
node_ids = []
for edge in self.graph.edge_mapping.get(node_id, []):
if edge.target_node_id == self.graph.root_node_id:
continue
node_ids.append(edge.target_node_id)
node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
return node_ids
@ -95,6 +104,9 @@ class EndStreamProcessor:
"""
remove target node ids until merge
"""
if node_id not in self.rest_node_ids:
return
self.rest_node_ids.remove(node_id)
for edge in self.graph.edge_mapping.get(node_id, []):
if edge.target_node_id in reachable_node_ids:

View File

@ -5,15 +5,13 @@ from configs import dify_config
from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.entities.base_node_data_entities import BaseIterationState
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.graph_engine.entities.event import GraphRunFailedEvent, NodeRunSucceededEvent
from core.workflow.graph_engine.entities.event import BaseGraphEvent, GraphRunFailedEvent, NodeRunSucceededEvent
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.run_condition import RunCondition
from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.event import RunCompletedEvent
from core.workflow.nodes.iteration.entities import IterationNodeData
from core.workflow.utils.condition.entities import Condition
from core.workflow.workflow_entry import WorkflowRunFailedError
from models.workflow import WorkflowNodeExecutionStatus
logger = logging.getLogger(__name__)
@ -74,7 +72,7 @@ class IterationNode(BaseNode):
Condition(
variable_selector=[self.node_id, "index"],
comparison_operator="<",
value=len(iterator_list_value)
value=str(len(iterator_list_value))
)
]
)
@ -93,6 +91,7 @@ class IterationNode(BaseNode):
)
# init graph engine
from core.workflow.graph_engine.graph_engine import GraphEngine
graph_engine = GraphEngine(
tenant_id=self.tenant_id,
app_id=self.app_id,
@ -114,10 +113,11 @@ class IterationNode(BaseNode):
rst = graph_engine.run()
outputs: list[Any] = []
for event in rst:
yield event
if isinstance(event, NodeRunSucceededEvent):
yield event
# handle iteration run result
if event.node_id in iteration_leaf_node_ids:
if event.route_node_state.node_id in iteration_leaf_node_ids:
# append to iteration output variable list
outputs.append(variable_pool.get_any(self.node_data.output_selector))
@ -132,13 +132,23 @@ class IterationNode(BaseNode):
next_index
)
variable_pool.add(
[self.node_id, 'item'],
iterator_list_value[next_index]
)
elif isinstance(event, GraphRunFailedEvent):
if next_index < len(iterator_list_value):
variable_pool.add(
[self.node_id, 'item'],
iterator_list_value[next_index]
)
elif isinstance(event, BaseGraphEvent):
if isinstance(event, GraphRunFailedEvent):
# iteration run failed
raise WorkflowRunFailedError(event.reason)
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=event.reason,
)
)
break
else:
yield event
yield RunCompletedEvent(
run_result=NodeRunResult(
@ -148,14 +158,6 @@ class IterationNode(BaseNode):
}
)
)
except WorkflowRunFailedError as e:
# iteration run failed
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
)
)
except Exception as e:
# iteration run failed
logger.exception("Iteration run failed")