This commit is contained in:
takatost
2024-07-17 11:26:33 +08:00
parent 16e2d00157
commit cc96acdae3
4 changed files with 199 additions and 116 deletions

View File

@ -43,6 +43,26 @@ class RouteNodeState(BaseModel):
paused_by: Optional[str] = None
"""paused by"""
def set_finished(self, run_result: NodeRunResult) -> None:
    """
    Mark this route node state as finished from a node run result.

    :param run_result: run result of the node
    :raises Exception: if this state is already SUCCESS or FAILED, or if the
        run result status is neither SUCCEEDED nor FAILED
    """
    # a route node state may only be finished once
    terminal_statuses = (RouteNodeState.Status.SUCCESS, RouteNodeState.Status.FAILED)
    if self.status in terminal_statuses:
        raise Exception(f"Route state {self.id} already finished")

    result_status = run_result.status
    if result_status == WorkflowNodeExecutionStatus.FAILED:
        self.status = RouteNodeState.Status.FAILED
        # only failed runs record a failure reason
        self.failed_reason = run_result.error
    elif result_status == WorkflowNodeExecutionStatus.SUCCEEDED:
        self.status = RouteNodeState.Status.SUCCESS
    else:
        # any other status leaves the state untouched
        raise Exception(f"Invalid route status {run_result.status}")

    self.node_run_result = run_result
    # current UTC time, stored without tzinfo
    self.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
class RuntimeRouteState(BaseModel):
routes: dict[str, list[str]] = Field(
@ -87,29 +107,3 @@ class RuntimeRouteState(BaseModel):
"""
return [self.node_state_mapping[target_state_id]
for target_state_id in self.routes.get(source_node_state_id, [])]
def set_node_state_finished(self, node_state_id: str, run_result: NodeRunResult) -> None:
    """
    Mark the route node state stored under ``node_state_id`` as finished.

    :param node_state_id: route node state id
    :param run_result: run result of the node
    :raises Exception: if the id is not in ``node_state_mapping``, if the state
        is already SUCCESS or FAILED, or if the run result status is neither
        SUCCEEDED nor FAILED
    """
    if node_state_id not in self.node_state_mapping:
        raise Exception(f"Route state {node_state_id} not found")

    node_state = self.node_state_mapping[node_state_id]

    # a route node state may only be finished once
    terminal_statuses = (RouteNodeState.Status.SUCCESS, RouteNodeState.Status.FAILED)
    if node_state.status in terminal_statuses:
        raise Exception(f"Route state {node_state_id} already finished")

    result_status = run_result.status
    if result_status == WorkflowNodeExecutionStatus.FAILED:
        node_state.status = RouteNodeState.Status.FAILED
        # only failed runs record a failure reason
        node_state.failed_reason = run_result.error
    elif result_status == WorkflowNodeExecutionStatus.SUCCEEDED:
        node_state.status = RouteNodeState.Status.SUCCESS
    else:
        # any other status leaves the state untouched
        raise Exception(f"Invalid route status {run_result.status}")

    node_state.node_run_result = run_result
    # current UTC time, stored without tzinfo
    node_state.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)

View File

@ -3,14 +3,14 @@ import queue
import time
from collections.abc import Generator
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timezone
from typing import Optional
from flask import Flask, current_app
from uritemplate.variable import VariableValue
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import NodeType, UserFrom
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType, UserFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
from core.workflow.graph_engine.entities.event import (
@ -89,7 +89,7 @@ class GraphEngine:
# trigger graph run success event
yield GraphRunSucceededEvent()
except (GraphRunFailedError, NodeRunFailedError) as e:
except GraphRunFailedError as e:
yield GraphRunFailedEvent(reason=e.error)
return
except Exception as e:
@ -112,7 +112,7 @@ class GraphEngine:
raise GraphRunFailedError('Max execution time {}s reached.'.format(self.max_execution_time))
# init route node state
route_node_state = self.graph_runtime_state.create_node_state(
route_node_state = self.graph_runtime_state.node_run_state.create_node_state(
node_id=next_node_id
)
@ -128,13 +128,13 @@ class GraphEngine:
# append route
if previous_route_node_state:
if previous_route_node_state.id not in self.graph_runtime_state.node_run_state.routes:
self.graph_runtime_state.node_run_state.routes[previous_route_node_state.id] = []
self.graph_runtime_state.node_run_state.routes[previous_route_node_state.id].append(
route_node_state.id
self.graph_runtime_state.node_run_state.add_route(
source_node_state_id=previous_route_node_state.id,
target_node_state_id=route_node_state.id
)
except Exception as e:
route_node_state.status = RouteNodeState.Status.FAILED
route_node_state.failed_reason = str(e)
yield NodeRunFailedEvent(
route_node_state=route_node_state,
parallel_id=in_parallel_id
@ -181,9 +181,9 @@ class GraphEngine:
next_node_id = final_node_id
else:
# if nodes has no run conditions, parallel run all nodes
parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].source_node_id)
parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id)
if not parallel_id:
raise GraphRunFailedError(f'Node {edge_mappings[0].source_node_id} related parallel not found.')
raise GraphRunFailedError(f'Node {edge_mappings[0].target_node_id} related parallel not found.')
parallel = self.graph.parallel_mapping.get(parallel_id)
if not parallel:
@ -199,18 +199,27 @@ class GraphEngine:
self._run_parallel_node,
flask_app=current_app._get_current_object(), # type: ignore
parallel_id=parallel_id,
parallel_start_node_id=edge.source_node_id, # source_node_id is start nodes in parallel
parallel_start_node_id=edge.target_node_id,
q=q
))
succeeded_count = 0
while True:
try:
event = q.get(timeout=1)
if event is None:
break
# TODO tag event with parallel id
yield event
if isinstance(event, GraphRunSucceededEvent):
succeeded_count += 1
if succeeded_count == len(edge_mappings):
break
continue
elif isinstance(event, GraphRunFailedEvent):
raise GraphRunFailedError(event.reason)
else:
yield event
except queue.Empty:
continue
@ -246,19 +255,15 @@ class GraphEngine:
for item in generator:
q.put(item)
if isinstance(item, NodeRunFailedEvent):
q.put(GraphRunFailedEvent(reason=item.route_node_state.failed_reason or 'Unknown error.'))
return
# trigger graph run success event
q.put(GraphRunSucceededEvent())
except (GraphRunFailedError, NodeRunFailedError) as e:
except GraphRunFailedError as e:
q.put(GraphRunFailedEvent(reason=e.error))
except Exception as e:
logger.exception("Unknown Error when generating in parallel")
q.put(GraphRunFailedEvent(reason=str(e)))
finally:
q.put(None)
db.session.remove()
def _run_node(self,
@ -268,17 +273,35 @@ class GraphEngine:
"""
Run node
"""
# trigger node run start event
yield NodeRunStartedEvent(
route_node_state=route_node_state,
parallel_id=parallel_id
)
# get node config
node_id = route_node_state.node_id
node_config = self.graph.node_id_config_mapping.get(node_id)
if not node_config:
raise GraphRunFailedError(f'Node {node_id} config not found.')
route_node_state.status = RouteNodeState.Status.FAILED
route_node_state.failed_reason = f'Node {node_id} config not found.'
yield NodeRunFailedEvent(
route_node_state=route_node_state,
parallel_id=parallel_id
)
return
# convert to specific node
node_type = NodeType.value_of(node_config.get('data', {}).get('type'))
node_cls = node_classes.get(node_type)
if not node_cls:
raise GraphRunFailedError(f'Node {node_id} type {node_type} not found.')
route_node_state.status = RouteNodeState.Status.FAILED
route_node_state.failed_reason = f'Node {node_id} type {node_type} not found.'
yield NodeRunFailedEvent(
route_node_state=route_node_state,
parallel_id=parallel_id
)
return
# init workflow run state
node_instance = node_cls( # type: ignore
@ -289,12 +312,6 @@ class GraphEngine:
previous_node_id=previous_node_id
)
# trigger node run start event
yield NodeRunStartedEvent(
route_node_state=route_node_state,
parallel_id=parallel_id
)
db.session.close()
# TODO reference from core.workflow.workflow_entry.WorkflowEntry._run_workflow_node
@ -307,13 +324,7 @@ class GraphEngine:
for item in generator:
if isinstance(item, RunCompletedEvent):
run_result = item.run_result
route_node_state.status = RouteNodeState.Status.SUCCESS \
if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED \
else RouteNodeState.Status.FAILED
route_node_state.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
route_node_state.node_run_result = run_result
route_node_state.failed_reason = run_result.error \
if run_result.status == WorkflowNodeExecutionStatus.FAILED else None
route_node_state.set_finished(run_result=run_result)
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
yield NodeRunFailedEvent(
@ -321,10 +332,27 @@ class GraphEngine:
route_node_state=route_node_state
)
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
# plus state total_tokens
self.graph_runtime_state.total_tokens += int(
run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)
)
# append node output variables to variable pool
if run_result.outputs:
for variable_key, variable_value in run_result.outputs.items():
# append variables to variable pool recursively
self._append_variables_recursively(
node_id=node_id,
variable_key_list=[variable_key],
variable_value=variable_value
)
yield NodeRunSucceededEvent(
parallel_id=parallel_id,
route_node_state=route_node_state
)
break
elif isinstance(item, RunStreamChunkEvent):
yield NodeRunStreamChunkEvent(
@ -340,8 +368,10 @@ class GraphEngine:
retriever_resources=item.retriever_resources,
context=item.context
)
except GenerateTaskStoppedException as e:
except GenerateTaskStoppedException:
# trigger node run failed event
route_node_state.status = RouteNodeState.Status.FAILED
route_node_state.failed_reason = "Workflow stopped."
yield NodeRunFailedEvent(
route_node_state=route_node_state,
parallel_id=parallel_id
@ -353,6 +383,34 @@ class GraphEngine:
finally:
db.session.close()
def _append_variables_recursively(self,
                                  node_id: str,
                                  variable_key_list: list[str],
                                  variable_value: VariableValue):
    """
    Append a variable to the variable pool and, when its value is a dict,
    append every nested entry under an extended key list as well.

    :param node_id: node id
    :param variable_key_list: variable key list
    :param variable_value: variable value
    :return:
    """
    pool = self.graph_runtime_state.variable_pool
    pool.append_variable(
        node_id=node_id,
        variable_key_list=variable_key_list,
        value=variable_value
    )

    # only dict values have nested entries to descend into
    if not isinstance(variable_value, dict):
        return

    for child_key, child_value in variable_value.items():
        # the parent is appended first, then each child with its key added to the path
        self._append_variables_recursively(
            node_id=node_id,
            variable_key_list=variable_key_list + [child_key],
            variable_value=child_value
        )
def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool:
"""
Check timeout
@ -366,8 +424,3 @@ class GraphEngine:
class GraphRunFailedError(Exception):
    """
    Error raised when a graph run cannot continue.

    The reason is kept in ``error``; handlers in this file read it to build
    the ``GraphRunFailedEvent`` reason.

    :param error: human-readable failure reason
    """

    def __init__(self, error: str):
        # forward the message to Exception so str(e) and e.args carry it
        # even when the reason is passed as a keyword argument
        super().__init__(error)
        self.error = error
class NodeRunFailedError(Exception):
    """
    Error carrying the reason a node run failed.

    The reason is kept in ``error``, which is the attribute the handlers in
    this file read when they catch this error.

    :param error: human-readable failure reason
    """

    def __init__(self, error: str):
        # forward the message to Exception so str(e) and e.args carry it
        # even when the reason is passed as a keyword argument
        super().__init__(error)
        self.error = error