refactor runtime

This commit is contained in:
takatost
2024-07-08 16:29:13 +08:00
parent 1adaf42f9d
commit 0e885a3cae
12 changed files with 410 additions and 998 deletions

View File

@ -3,9 +3,7 @@ from typing import Optional, cast
from pydantic import BaseModel, Field
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
from core.workflow.entities.node_entities import NodeType
from core.workflow.graph_engine.entities.run_condition import RunCondition
@ -28,33 +26,6 @@ class GraphParallel(BaseModel):
"""parent parallel id if exists"""
class GraphStateRoute(BaseModel):
    """One entry in a graph run route: which node a route step landed on."""

    # identifier of this route entry
    route_id: str
    # identifier of the graph node this route entry refers to
    node_id: str
class GraphState(BaseModel):
    """Mutable state accumulated while a graph is being run."""

    # routes taken so far, keyed by source node id
    routes: dict[str, list[GraphStateRoute]] = Field(default_factory=dict)
    # variable pool shared across the whole run
    variable_pool: VariablePool
    # per-node run results, keyed by node id
    node_route_results: dict[str, NodeRunResult] = Field(default_factory=dict)
class NextGraphNode(BaseModel):
    """A candidate next node to execute, with its parallel context if any."""

    # id of the next node to run
    node_id: str
    # parallel branch the node belongs to, or None when it is not in a parallel
    parallel: Optional[GraphParallel] = None
class Graph(BaseModel):
root_node_id: str
"""root node id of the graph"""
@ -71,19 +42,14 @@ class Graph(BaseModel):
node_parallel_mapping: dict[str, str] = Field(default_factory=dict)
"""graph node parallel mapping (node id: parallel id)"""
run_state: GraphState
"""graph run state"""
@classmethod
def init(cls,
graph_config: dict,
variable_pool: VariablePool,
root_node_id: Optional[str] = None) -> "Graph":
"""
Init graph
:param graph_config: graph config
:param variable_pool: variable pool
:param root_node_id: root node id
:return: graph
"""
@ -149,7 +115,7 @@ class Graph(BaseModel):
# fetch root node
if not root_node_id:
# if no root node id, use the START type node as root node
root_node_id = next((node_config for node_config in root_node_configs
root_node_id = next((node_config.get("id") for node_config in root_node_configs
if node_config.get('data', {}).get('type', '') == NodeType.START.value), None)
if not root_node_id or root_node_id not in root_node_ids:
@ -178,80 +144,12 @@ class Graph(BaseModel):
root_node_id=root_node_id,
node_ids=node_ids,
edge_mapping=edge_mapping,
run_state=GraphState(
variable_pool=variable_pool
),
parallel_mapping=parallel_mapping,
node_parallel_mapping=node_parallel_mapping
)
return graph
@classmethod
def _recursively_add_node_ids(cls,
                              node_ids: list[str],
                              edge_mapping: dict[str, list[GraphEdge]],
                              node_id: str) -> None:
    """
    Depth-first walk from *node_id*, appending every reachable target node id
    to *node_ids* in place; ids already collected are not revisited.

    :param node_ids: accumulator of collected node ids (mutated)
    :param edge_mapping: outgoing edges per source node id
    :param node_id: node to expand from
    """
    for edge in edge_mapping.get(node_id, []):
        target = edge.target_node_id
        if target not in node_ids:
            node_ids.append(target)
            cls._recursively_add_node_ids(
                node_ids=node_ids,
                edge_mapping=edge_mapping,
                node_id=target
            )
def next_node_ids(self) -> list[NextGraphNode]:
    """
    Compute the nodes that should run next.

    :return: next runnable nodes; just the root node when nothing has run yet
    """
    # nothing has run yet -> start from the root node
    if not self.run_state.routes:
        return [NextGraphNode(node_id=self.root_node_id)]

    # collect edges leaving the frontier of the routes taken from the root.
    # BUGFIX: use .get() — direct indexing raised KeyError when routes were
    # recorded but none of them started at the root node.
    route_final_graph_edges: list[GraphEdge] = []
    for route in self.run_state.routes.get(self.root_node_id, []):
        graph_edges = self.edge_mapping.get(route.node_id)
        if not graph_edges:
            continue

        for edge in graph_edges:
            # an edge whose target has no route entry yet is still pending
            if edge.target_node_id not in self.run_state.routes:
                route_final_graph_edges.append(edge)

    next_graph_nodes = []
    for route_final_graph_edge in route_final_graph_edges:
        node_id = route_final_graph_edge.target_node_id

        # conditional edge: only follow it when its condition evaluates true
        if route_final_graph_edge.run_condition:
            result = ConditionManager.get_condition_handler(
                run_condition=route_final_graph_edge.run_condition
            ).check(
                source_node_id=route_final_graph_edge.source_node_id,
                target_node_id=route_final_graph_edge.target_node_id,
                graph=self
            )

            if not result:
                continue

        # attach the parallel context of the target node, if it is in one
        parallel = None
        if node_id in self.node_parallel_mapping:
            parallel = self.parallel_mapping[self.node_parallel_mapping[node_id]]

        next_graph_nodes.append(NextGraphNode(node_id=node_id, parallel=parallel))

    return next_graph_nodes
def add_extra_edge(self, source_node_id: str,
target_node_id: str,
run_condition: Optional[RunCondition] = None) -> None:
@ -295,6 +193,29 @@ class Graph(BaseModel):
return leaf_node_ids
@classmethod
def _recursively_add_node_ids(cls,
                              node_ids: list[str],
                              edge_mapping: dict[str, list[GraphEdge]],
                              node_id: str) -> None:
    """
    Recursive DFS that collects, into *node_ids*, the id of every node
    reachable from *node_id* via *edge_mapping*, skipping duplicates.

    :param node_ids: accumulator list, mutated in place
    :param edge_mapping: outgoing edges keyed by source node id
    :param node_id: node whose successors are expanded
    """
    outgoing = edge_mapping.get(node_id) or []
    for graph_edge in outgoing:
        # already visited -> do not append or recurse again
        if graph_edge.target_node_id in node_ids:
            continue

        node_ids.append(graph_edge.target_node_id)
        cls._recursively_add_node_ids(
            node_ids=node_ids,
            edge_mapping=edge_mapping,
            node_id=graph_edge.target_node_id
        )
@classmethod
def _recursively_add_parallels(cls,
edge_mapping: dict[str, list[GraphEdge]],

View File

@ -4,12 +4,12 @@ from pydantic import BaseModel, Field
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.runtime_graph import RuntimeGraph
from core.workflow.graph_engine.entities.runtime_route_state import RuntimeRouteState
from core.workflow.nodes.base_node import UserFrom
class GraphRuntimeState(BaseModel):
# init params
tenant_id: str
app_id: str
user_id: str
@ -17,10 +17,10 @@ class GraphRuntimeState(BaseModel):
invoke_from: InvokeFrom
call_depth: int
graph: Graph
variable_pool: VariablePool
start_at: Optional[float] = None
total_tokens: int = 0
node_run_steps: int = 0
runtime_graph: RuntimeGraph = Field(default_factory=RuntimeGraph)
node_run_state: RuntimeRouteState = Field(default_factory=RuntimeRouteState)

View File

@ -0,0 +1,13 @@
from typing import Optional
from pydantic import BaseModel
from core.workflow.graph_engine.entities.graph import GraphParallel
class NextGraphNode(BaseModel):
    """Next node selected to run, together with its parallel context."""

    # id of the next node to run
    node_id: str
    # the parallel branch containing the node; None when not inside a parallel
    parallel: Optional[GraphParallel] = None

View File

@ -1,38 +0,0 @@
from datetime import datetime, timezone
from typing import Optional
from pydantic import BaseModel
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.graph_engine.entities.runtime_node import RuntimeNode
from models.workflow import WorkflowNodeExecutionStatus
class RuntimeGraph(BaseModel):
    """Registry of runtime node instances for a running graph, with status transitions."""

    # runtime nodes keyed by runtime node id (pydantic copies this default per instance)
    runtime_nodes: dict[str, RuntimeNode] = {}

    def add_runtime_node(self, runtime_node: RuntimeNode) -> None:
        """Register *runtime_node* under its own id."""
        self.runtime_nodes[runtime_node.id] = runtime_node

    def add_link(self, source_runtime_node_id: str, target_runtime_node_id: str) -> None:
        """
        Record that the source node precedes the target node.

        Silently ignored unless both ids are already registered.
        """
        if source_runtime_node_id in self.runtime_nodes and target_runtime_node_id in self.runtime_nodes:
            target_runtime_node = self.runtime_nodes[target_runtime_node_id]
            target_runtime_node.predecessor_runtime_node_id = source_runtime_node_id

    def runtime_node_finished(self, runtime_node_id: str, node_run_result: NodeRunResult) -> None:
        """
        Mark a runtime node as finished based on its run result.

        :param runtime_node_id: id of the runtime node
        :param node_run_result: result produced by the node run
        """
        if runtime_node_id in self.runtime_nodes:
            runtime_node = self.runtime_nodes[runtime_node_id]
            # BUGFIX: success must be derived from a SUCCEEDED result; the previous
            # comparison against RUNNING marked every successful node as FAILED.
            runtime_node.status = RuntimeNode.Status.SUCCESS \
                if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED \
                else RuntimeNode.Status.FAILED
            runtime_node.node_run_result = node_run_result
            # naive UTC timestamp, matching the rest of the module
            runtime_node.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
            runtime_node.failed_reason = node_run_result.error

    def runtime_node_paused(self, runtime_node_id: str, paused_by: Optional[str] = None) -> None:
        """
        Mark a runtime node as paused, recording when and (optionally) by whom.

        :param runtime_node_id: id of the runtime node
        :param paused_by: identifier of the pauser, if known
        """
        if runtime_node_id in self.runtime_nodes:
            runtime_node = self.runtime_nodes[runtime_node_id]
            runtime_node.status = RuntimeNode.Status.PAUSED
            runtime_node.paused_at = datetime.now(timezone.utc).replace(tzinfo=None)
            runtime_node.paused_by = paused_by

View File

@ -1,48 +0,0 @@
import uuid
from datetime import datetime
from enum import Enum
from typing import Optional
from pydantic import BaseModel, Field
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.graph_engine.entities.graph import GraphNode
class RuntimeNode(BaseModel):
    """Runtime record of a single graph node execution."""

    class Status(Enum):
        PENDING = "pending"
        RUNNING = "running"
        SUCCESS = "success"
        FAILED = "failed"
        PAUSED = "paused"

    # random id for this runtime node instance
    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    # the graph node being executed
    graph_node: GraphNode
    # result of the node run, once available
    node_run_result: Optional[NodeRunResult] = None
    # current lifecycle status
    status: Status = Status.PENDING
    # lifecycle timestamps
    start_at: Optional[datetime] = None
    paused_at: Optional[datetime] = None
    finished_at: Optional[datetime] = None
    # reason recorded when the node failed
    failed_reason: Optional[str] = None
    # who paused the node, if anyone
    paused_by: Optional[str] = None
    # runtime id of the node that ran before this one
    predecessor_runtime_node_id: Optional[str] = None

View File

@ -0,0 +1,111 @@
import uuid
from datetime import datetime, timezone
from enum import Enum
from typing import Optional
from pydantic import BaseModel, Field
from core.workflow.entities.node_entities import NodeRunResult
from models.workflow import WorkflowNodeExecutionStatus
class RouteNodeState(BaseModel):
    """Execution state of one node within a run route."""

    class Status(Enum):
        RUNNING = "running"
        SUCCESS = "success"
        FAILED = "failed"
        PAUSED = "paused"

    # random id for this state record
    state_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    # id of the graph node this state belongs to
    node_id: str
    # result of the node run, once available
    node_run_result: Optional[NodeRunResult] = None
    # states start out running
    status: Status = Status.RUNNING
    # when the node started (required)
    start_at: datetime
    # lifecycle timestamps, set as the node pauses/finishes
    paused_at: Optional[datetime] = None
    finished_at: Optional[datetime] = None
    # reason recorded when the node failed
    failed_reason: Optional[str] = None
    # who paused the node, if anyone
    paused_by: Optional[str] = None
class RuntimeRouteState(BaseModel):
    """Tracks per-run node states and the routes taken between them."""

    # adjacency of state ids: source_node_state_id -> list of target_node_state_ids
    routes: dict[str, list[str]] = Field(default_factory=dict)
    # all known node states, keyed by their state_id
    node_state_mapping: dict[str, RouteNodeState] = Field(default_factory=dict)

    def create_node_state(self, node_id: str) -> RouteNodeState:
        """
        Create and register a fresh RouteNodeState for *node_id*.

        :param node_id: node id
        :return: the newly created state
        """
        node_state = RouteNodeState(
            node_id=node_id,
            # naive UTC timestamp, consistent with the rest of the module
            start_at=datetime.now(timezone.utc).replace(tzinfo=None)
        )
        self.node_state_mapping[node_state.state_id] = node_state
        return node_state

    def add_route(self, source_node_state_id: str, target_node_state_id: str) -> None:
        """
        Record a route from one node state to another.

        :param source_node_state_id: source node state id
        :param target_node_state_id: target node state id
        """
        self.routes.setdefault(source_node_state_id, []).append(target_node_state_id)

    def get_routes_with_node_state_by_source_node_state_id(self, source_node_state_id: str) \
            -> list[RouteNodeState]:
        """
        Resolve the target node states reachable from *source_node_state_id*.

        :param source_node_state_id: source node state id
        :return: target node states (empty list when none recorded)
        """
        target_ids = self.routes.get(source_node_state_id, [])
        return [self.node_state_mapping[target_id] for target_id in target_ids]

    def set_node_state_finished(self, node_state_id: str, run_result: NodeRunResult) -> None:
        """
        Mark a node state finished according to *run_result*.

        :param node_state_id: route node state id
        :param run_result: run result
        :raises Exception: unknown state id, state already finished,
            or a run result status that is neither SUCCEEDED nor FAILED
        """
        if node_state_id not in self.node_state_mapping:
            raise Exception(f"Route state {node_state_id} not found")

        route = self.node_state_mapping[node_state_id]
        # finishing twice is a programming error
        if route.status in (RouteNodeState.Status.SUCCESS, RouteNodeState.Status.FAILED):
            raise Exception(f"Route state {node_state_id} already finished")

        if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
            route.status = RouteNodeState.Status.SUCCESS
        elif run_result.status == WorkflowNodeExecutionStatus.FAILED:
            route.status = RouteNodeState.Status.FAILED
            route.failed_reason = run_result.error
        else:
            raise Exception(f"Invalid route status {run_result.status}")

        route.node_run_result = run_result
        route.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)

View File

@ -22,6 +22,7 @@ class GraphEngine:
graph: Graph,
variable_pool: VariablePool,
callbacks: list[BaseWorkflowCallback]) -> None:
self.graph = graph
self.graph_runtime_state = GraphRuntimeState(
tenant_id=tenant_id,
app_id=app_id,
@ -29,7 +30,6 @@ class GraphEngine:
user_from=user_from,
invoke_from=invoke_from,
call_depth=call_depth,
graph=graph,
variable_pool=variable_pool
)
@ -43,3 +43,49 @@ class GraphEngine:
def run(self) -> Generator:
    """
    Start the graph run.

    Currently only records the run start time (monotonic, via
    ``time.perf_counter``); the actual execution loop is not yet implemented.
    """
    self.graph_runtime_state.start_at = time.perf_counter()
    pass
# def next_node_ids(self, node_state_id: str) -> list[NextGraphNode]:
# """
# Get next node ids
#
# :param node_state_id: source node state id
# """
# # get current node ids in state
# node_run_state = self.graph_runtime_state.node_run_state
# graph = self.graph
# if not node_run_state.routes:
# return [NextGraphNode(node_id=graph.root_node_id)]
#
# route_final_graph_edges: list[GraphEdge] = []
# for route in route_state.routes[graph.root_node_id]:
# graph_edges = graph.edge_mapping.get(route.node_id)
# if not graph_edges:
# continue
#
# for edge in graph_edges:
# if edge.target_node_id not in route_state.routes:
# route_final_graph_edges.append(edge)
#
# next_graph_nodes = []
# for route_final_graph_edge in route_final_graph_edges:
# node_id = route_final_graph_edge.target_node_id
# # check condition
# if route_final_graph_edge.run_condition:
# result = ConditionManager.get_condition_handler(
# run_condition=route_final_graph_edge.run_condition
# ).check(
# source_node_id=route_final_graph_edge.source_node_id,
# target_node_id=route_final_graph_edge.target_node_id,
# graph=self
# )
#
# if not result:
# continue
#
# parallel = None
# if route_final_graph_edge.target_node_id in graph.node_parallel_mapping:
# parallel = graph.parallel_mapping[graph.node_parallel_mapping[node_id]]
#
# next_graph_nodes.append(NextGraphNode(node_id=node_id, parallel=parallel))
#
# return next_graph_nodes