graph engine implement

This commit is contained in:
takatost
2024-07-15 23:40:02 +08:00
parent 821e09b259
commit 00fb23d0c9
11 changed files with 511 additions and 270 deletions

View File

@ -0,0 +1,116 @@
from typing import Optional
from pydantic import BaseModel, Field, model_validator
from core.workflow.entities.node_entities import NodeRunResult
class GraphEngineEvent(BaseModel):
    """Common pydantic base for every event defined in this module."""
###########################################
# Graph Events
###########################################
class BaseGraphEvent(GraphEngineEvent):
    """Base class for graph-level events; declares no fields of its own."""
class GraphRunStartedEvent(BaseGraphEvent):
    """Graph-level event marking the start of a graph run; carries no payload."""
class GraphRunBackToRootEvent(BaseGraphEvent):
    """Graph-level marker event without payload.

    NOTE(review): by its name this presumably signals that execution went
    back to the root node; no emitter is visible here -- confirm.
    """
class GraphRunSucceededEvent(BaseGraphEvent):
    """Graph-level event marking a successful graph run; carries no payload."""
class GraphRunFailedEvent(BaseGraphEvent):
    """Graph-level event reporting that a graph run failed.

    ``reason`` is required and holds the failure description.
    """

    reason: str = Field(
        default=...,
        description="failed reason",
    )
###########################################
# Node Events
###########################################
class BaseNodeEvent(GraphEngineEvent):
    """Base class for events that relate to a single node.

    ``node_id`` is required; ``parallel_id`` is optional and defaults to
    ``None`` (it is set when the node is in a parallel).
    """

    node_id: str = Field(
        default=...,
        description="node id",
    )
    parallel_id: Optional[str] = Field(
        default=None,
        description="parallel id if node is in parallel",
    )
    # Disabled for now -- iteration id is not part of node events yet:
    # iteration_id: Optional[str] = Field(None, description="iteration id if node is in iteration")
class NodeRunStartedEvent(BaseNodeEvent):
    """Node event marking the start of a node run; adds no fields to the base."""
class NodeRunStreamChunkEvent(BaseNodeEvent):
    """Node event carrying one chunk of content in the required ``chunk_content``."""

    chunk_content: str = Field(
        default=...,
        description="chunk content",
    )
class NodeRunRetrieverResourceEvent(BaseNodeEvent):
    """Node event carrying retriever resources as a required list of dicts."""

    retriever_resources: list[dict] = Field(
        default=...,
        description="retriever resources",
    )
class NodeRunSucceededEvent(BaseNodeEvent):
    """Node event for a successful node run; ``run_result`` is required."""

    run_result: NodeRunResult = Field(
        default=...,
        description="run result",
    )
class NodeRunFailedEvent(BaseNodeEvent):
    """Node event reporting that a node run failed.

    ``run_result`` is required. ``reason`` is optional on input: when it is
    missing or empty it is filled from ``run_result.error``, falling back to
    ``"Unknown error"``.
    """

    run_result: NodeRunResult = Field(..., description="run result")
    reason: str = Field("", description="failed reason")

    @model_validator(mode='before')
    @classmethod
    def init_reason(cls, values: dict) -> dict:
        """Derive ``reason`` from ``run_result`` when the caller gave none.

        Runs before field validation, so ``values`` is the raw input:
        ``run_result`` may be absent, a plain dict, or a ``NodeRunResult``.

        :param values: raw constructor / ``model_validate`` input
        :return: the input with ``reason`` populated (a new dict), or the
            input unchanged when it is not a dict or already has a reason
        """
        # Non-dict input is left for pydantic's own validation to reject or coerce.
        if not isinstance(values, dict) or values.get("reason"):
            return values

        run_result = values.get("run_result")
        if isinstance(run_result, dict):
            error = run_result.get("error")
        else:
            # Covers a model instance as well as a missing run_result (None);
            # the missing field is then reported by pydantic, not by an
            # AttributeError raised from here.
            error = getattr(run_result, "error", None)

        # Build a new dict so the caller's input mapping is not mutated.
        return {**values, "reason": error or "Unknown error"}
###########################################
# Parallel Events
###########################################
class BaseParallelEvent(GraphEngineEvent):
    """Base class for events that relate to a parallel; ``parallel_id`` is required."""

    parallel_id: str = Field(
        default=...,
        description="parallel id",
    )
class ParallelRunStartedEvent(BaseParallelEvent):
    """Parallel event named for the start of a parallel run; adds no fields."""
class ParallelRunSucceededEvent(BaseParallelEvent):
    """Parallel event named for a successful parallel run; adds no fields."""
class ParallelRunFailedEvent(BaseParallelEvent):
    """Parallel event for a failed parallel run; ``reason`` is required."""

    reason: str = Field(
        default=...,
        description="failed reason",
    )
###########################################
# Iteration Events
###########################################
class BaseIterationEvent(GraphEngineEvent):
    """Base class for events that relate to an iteration; ``iteration_id`` is required."""

    iteration_id: str = Field(
        default=...,
        description="iteration id",
    )
class IterationRunStartedEvent(BaseIterationEvent):
    """Iteration event named for the start of an iteration run; adds no fields."""
class IterationRunSucceededEvent(BaseIterationEvent):
    """Iteration event named for a successful iteration run; adds no fields."""
class IterationRunFailedEvent(BaseIterationEvent):
    """Iteration event for a failed iteration run; ``reason`` is required."""

    reason: str = Field(
        default=...,
        description="failed reason",
    )
# Union alias over the three scoped event bases: node, parallel and iteration events.
InNodeEvent = BaseNodeEvent | BaseParallelEvent | BaseIterationEvent

View File

@ -8,48 +8,37 @@ from core.workflow.graph_engine.entities.run_condition import RunCondition
class GraphEdge(BaseModel):
source_node_id: str
"""source node id"""
target_node_id: str
"""target node id"""
run_condition: Optional[RunCondition] = None
"""condition to run the edge"""
source_node_id: str = Field(..., description="source node id")
target_node_id: str = Field(..., description="target node id")
run_condition: Optional[RunCondition] = Field(None, description="run condition")
class GraphParallel(BaseModel):
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
"""random uuid parallel id"""
start_from_node_id: str
"""start from node id"""
end_to_node_id: Optional[str] = None
"""end to node id"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if exists"""
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="random uuid parallel id")
start_from_node_id: str = Field(..., description="start from node id")
parent_parallel_id: Optional[str] = Field(None, description="parent parallel id")
end_to_node_id: Optional[str] = Field(None, description="end to node id")
class Graph(BaseModel):
root_node_id: str
"""root node id of the graph"""
node_ids: list[str] = Field(default_factory=list)
"""graph node ids"""
node_id_config_mapping: dict[str, dict] = Field(default_factory=list)
"""node configs mapping (node id: node config)"""
edge_mapping: dict[str, list[GraphEdge]] = Field(default_factory=dict)
"""graph edge mapping (source node id: edges)"""
parallel_mapping: dict[str, GraphParallel] = Field(default_factory=dict)
"""graph parallel mapping (parallel id: parallel)"""
node_parallel_mapping: dict[str, str] = Field(default_factory=dict)
"""graph node parallel mapping (node id: parallel id)"""
root_node_id: str = Field(..., description="root node id of the graph")
node_ids: list[str] = Field(default_factory=list, description="graph node ids")
# The default must be an empty dict, not a list: the field is annotated as a
# mapping and is read with ``.get(node_id)``, which a list default would break.
node_id_config_mapping: dict[str, dict] = Field(
    default_factory=dict,
    description="node configs mapping (node id: node config)"
)
edge_mapping: dict[str, list[GraphEdge]] = Field(
default_factory=dict,
description="graph edge mapping (source node id: edges)"
)
parallel_mapping: dict[str, GraphParallel] = Field(
default_factory=dict,
description="graph parallel mapping (parallel id: parallel)"
)
node_parallel_mapping: dict[str, str] = Field(
default_factory=dict,
description="graph node parallel mapping (node id: parallel id)"
)
@classmethod
def init(cls,

View File

@ -0,0 +1,17 @@
from pydantic import BaseModel, Field
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import UserFrom
from models.workflow import WorkflowType
class GraphInitParams(BaseModel):
    """Bundle of the parameters a graph run is initialised with.

    Every field is required; see each field's ``description`` for its meaning.
    """

    # --- init params ---
    tenant_id: str = Field(
        default=...,
        description="tenant / workspace id",
    )
    app_id: str = Field(
        default=...,
        description="app id",
    )
    workflow_type: WorkflowType = Field(
        default=...,
        description="workflow type",
    )
    workflow_id: str = Field(
        default=...,
        description="workflow id",
    )
    user_id: str = Field(
        default=...,
        description="user id",
    )
    user_from: UserFrom = Field(
        default=...,
        description="user from, account or end-user",
    )
    invoke_from: InvokeFrom = Field(
        default=...,
        description="invoke from, service-api, web-app, explore or debugger",
    )
    call_depth: int = Field(
        default=...,
        description="call depth",
    )

View File

@ -1,25 +1,14 @@
from pydantic import BaseModel, Field
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.runtime_route_state import RuntimeRouteState
from core.workflow.nodes.base_node import UserFrom
class GraphRuntimeState(BaseModel):
# init params
tenant_id: str
app_id: str
user_id: str
user_from: UserFrom
invoke_from: InvokeFrom
call_depth: int
variable_pool: VariablePool = Field(..., description="variable pool")
variable_pool: VariablePool
start_at: float = Field(..., description="start time")
total_tokens: int = Field(0, description="total tokens")
node_run_steps: int = Field(0, description="node run steps")
start_at: float
total_tokens: int = 0
node_run_steps: int = 0
node_run_state: RuntimeRouteState = Field(default_factory=RuntimeRouteState)
node_run_state: RuntimeRouteState = Field(default_factory=RuntimeRouteState, description="node run state")

View File

@ -2,7 +2,7 @@ import logging
import queue
import time
from collections.abc import Generator
from concurrent.futures import ThreadPoolExecutor, as_completed
from concurrent.futures import ThreadPoolExecutor
from typing import Optional, cast
from flask import Flask, current_app
@ -13,10 +13,20 @@ from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
from core.workflow.entities.node_entities import NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
GraphRunFailedEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
NodeRunFailedEvent,
NodeRunStartedEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.base_node import UserFrom
from core.workflow.nodes.base_node import UserFrom, node_classes
from extensions.ext_database import db
from models.workflow import WorkflowType
thread_pool = ThreadPoolExecutor(max_workers=500, thread_name_prefix="ThreadGraphParallelRun")
logger = logging.getLogger(__name__)
@ -25,6 +35,8 @@ logger = logging.getLogger(__name__)
class GraphEngine:
def __init__(self, tenant_id: str,
app_id: str,
workflow_type: WorkflowType,
workflow_id: str,
user_id: str,
user_from: UserFrom,
invoke_from: InvokeFrom,
@ -33,13 +45,18 @@ class GraphEngine:
variable_pool: VariablePool,
callbacks: list[BaseWorkflowCallback]) -> None:
self.graph = graph
self.graph_runtime_state = GraphRuntimeState(
self.init_params = GraphInitParams(
tenant_id=tenant_id,
app_id=app_id,
workflow_type=workflow_type,
workflow_id=workflow_id,
user_id=user_id,
user_from=user_from,
invoke_from=invoke_from,
call_depth=call_depth,
call_depth=call_depth
)
self.graph_runtime_state = GraphRuntimeState(
variable_pool=variable_pool,
start_at=time.perf_counter()
)
@ -55,31 +72,31 @@ class GraphEngine:
# TODO convert generator to result
pass
def run(self) -> Generator:
# TODO trigger graph run start event
def run(self) -> Generator[GraphEngineEvent, None, None]:
# trigger graph run start event
yield GraphRunStartedEvent()
try:
# TODO run graph
rst = self._run(start_node_id=self.graph.root_node_id)
except GraphRunFailedError as e:
# TODO self._graph_run_failed(
# error=e.error,
# callbacks=callbacks
# )
pass
# run graph
generator = self._run(start_node_id=self.graph.root_node_id)
for item in generator:
yield item
if isinstance(item, NodeRunFailedEvent):
yield GraphRunFailedEvent(reason=item.reason)
return
# trigger graph run success event
yield GraphRunSucceededEvent()
except (GraphRunFailedError, NodeRunFailedError) as e:
yield GraphRunFailedEvent(reason=e.error)
return
except Exception as e:
# TODO self._workflow_run_failed(
# error=str(e),
# callbacks=callbacks
# )
pass
yield GraphRunFailedEvent(reason=str(e))
return
# TODO trigger graph run success event
yield rst
def _run(self, start_node_id: str, in_parallel_id: Optional[str] = None):
def _run(self, start_node_id: str, in_parallel_id: Optional[str] = None) -> Generator[GraphEngineEvent, None, None]:
next_node_id = start_node_id
previous_node_id = None
while True:
# max steps reached
if self.graph_runtime_state.node_run_steps > self.max_execution_steps:
@ -92,10 +109,18 @@ class GraphEngine:
):
raise GraphRunFailedError('Max execution time {}s reached.'.format(self.max_execution_time))
# run node TODO generator
yield from self._run_node(node_id=next_node_id)
try:
# run node
yield from self._run_node(
node_id=next_node_id,
previous_node_id=previous_node_id,
parallel_id=in_parallel_id
)
except Exception as e:
yield NodeRunFailedEvent(node_id=next_node_id, reason=str(e))
return
# todo if failed, break
previous_node_id = next_node_id
# get next node ids
edge_mappings = self.graph.edge_mapping.get(next_node_id)
@ -135,11 +160,11 @@ class GraphEngine:
# if nodes has no run conditions, parallel run all nodes
parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].source_node_id)
if not parallel_id:
raise GraphRunFailedError('Node related parallel not found.')
raise GraphRunFailedError(f'Node {edge_mappings[0].source_node_id} related parallel not found.')
parallel = self.graph.parallel_mapping.get(parallel_id)
if not parallel:
raise GraphRunFailedError('Parallel not found.')
raise GraphRunFailedError(f'Parallel {parallel_id} not found.')
# run parallel nodes, run in new thread and use queue to get results
q: queue.Queue = queue.Queue()
@ -149,8 +174,9 @@ class GraphEngine:
for edge in edge_mappings:
futures.append(thread_pool.submit(
self._run_parallel_node,
flask_app=current_app._get_current_object(),
parallel_start_node_id=edge.source_node_id,
flask_app=current_app._get_current_object(), # type: ignore
parallel_id=parallel_id,
parallel_start_node_id=edge.source_node_id, # source_node_id is start nodes in parallel
q=q
))
@ -165,8 +191,9 @@ class GraphEngine:
except queue.Empty:
continue
for future in as_completed(futures):
future.result()
# not necessary
# for future in as_completed(futures):
# future.result()
# get final node id
final_node_id = parallel.end_to_node_id
@ -178,48 +205,61 @@ class GraphEngine:
if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, '') == in_parallel_id:
break
def _run_parallel_node(self, flask_app: Flask, parallel_start_node_id: str, q: queue.Queue) -> None:
def _run_parallel_node(self,
flask_app: Flask,
parallel_id: str,
parallel_start_node_id: str,
q: queue.Queue) -> None:
"""
Run parallel nodes
"""
with flask_app.app_context():
try:
in_parallel_id = self.graph.node_parallel_mapping.get(parallel_start_node_id)
if not in_parallel_id:
q.put(None)
return
# run node TODO generator
rst = self._run(
# run node
generator = self._run(
start_node_id=parallel_start_node_id,
in_parallel_id=in_parallel_id
in_parallel_id=parallel_id
)
if not rst:
q.put(None)
return
for item in rst:
q.put(item)
q.put(None)
if generator:
for item in generator:
q.put(item)
except Exception:
logger.exception("Unknown Error when generating in parallel")
finally:
q.put(None)
db.session.remove()
def _run_node(self, node_id: str) -> Generator:
def _run_node(self,
node_id: str,
previous_node_id: Optional[str] = None,
parallel_id: Optional[str] = None
) -> Generator[GraphEngineEvent, None, None]:
"""
Run node
"""
# get node config
node_config = self.graph.node_id_config_mapping.get(node_id)
if not node_config:
raise GraphRunFailedError('Node not found.')
raise GraphRunFailedError(f'Node {node_id} config not found.')
# todo convert to specific node
# convert to specific node
node_type = NodeType.value_of(node_config.get('data', {}).get('type'))
node_cls = node_classes.get(node_type)
if not node_cls:
raise GraphRunFailedError(f'Node {node_id} type {node_type} not found.')
# todo trigger node run start event
# init workflow run state
node_instance = node_cls( # type: ignore
config=node_config,
graph_init_params=self.init_params,
graph=self.graph,
graph_runtime_state=self.graph_runtime_state,
previous_node_id=previous_node_id
)
# trigger node run start event
yield NodeRunStartedEvent(node_id=node_id, parallel_id=parallel_id)
db.session.close()
@ -229,28 +269,25 @@ class GraphEngine:
try:
# run node
rst = node.run(
graph_runtime_state=self.graph_runtime_state,
graph=self.graph,
callbacks=self.callbacks
)
generator = node_instance.run()
yield from rst
yield from generator
# todo record state
# trigger node run success event
yield NodeRunSucceededEvent(node_id=node_id, parallel_id=parallel_id)
except GenerateTaskStoppedException as e:
# TODO yield failed
# todo trigger node run failed event
pass
# trigger node run failed event
yield NodeRunFailedEvent(node_id=node_id, parallel_id=parallel_id, reason=str(e))
return
except Exception as e:
# logger.exception(f"Node {node.node_data.title} run failed: {str(e)}")
# TODO yield failed
# todo trigger node run failed event
pass
# todo trigger node run success event
db.session.close()
# todo logger.exception(f"Node {node.node_data.title} run failed: {str(e)}")
# trigger node run failed event
yield NodeRunFailedEvent(node_id=node_id, parallel_id=parallel_id, reason=str(e))
return
finally:
db.session.close()
def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool:
"""
@ -265,3 +302,8 @@ class GraphEngine:
class GraphRunFailedError(Exception):
    """Raised when a graph run cannot continue.

    :param error: human-readable failure message, also exposed as ``.error``
    """

    def __init__(self, error: str):
        # Forward the message to Exception so str(exc) and exc.args carry it
        # even when ``error`` is supplied by keyword.
        super().__init__(error)
        self.error = error
class NodeRunFailedError(Exception):
    """Raised when a node run fails.

    :param error: human-readable failure message, also exposed as ``.error``
    """

    def __init__(self, error: str):
        # Forward the message to Exception so str(exc) and exc.args carry it
        # even when ``error`` is supplied by keyword.
        super().__init__(error)
        self.error = error