completed graph init test

This commit is contained in:
takatost
2024-07-04 15:40:20 +08:00
parent 0f19b2a986
commit 1b6cd975f3
3 changed files with 531 additions and 36 deletions

View File

@ -1,6 +1,6 @@
from typing import Optional
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_validator
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
@ -33,7 +33,8 @@ class GraphNode(BaseModel):
"""sub graph of the node, e.g. iteration/loop sub graph"""
def add_child(self, node_id: str) -> None:
self.descendant_node_ids.append(node_id)
if node_id not in self.descendant_node_ids:
self.descendant_node_ids.append(node_id)
def get_run_condition_handler(self) -> Optional[RunConditionHandler]:
"""
@ -56,6 +57,12 @@ class Graph(BaseModel):
root_node: GraphNode
"""root node of the graph"""
@model_validator(mode='after')
def add_root_node(cls, values):
root_node = values.root_node
values.graph_nodes[root_node.id] = root_node
return values
@classmethod
def init(cls, root_node_config: dict, run_condition: Optional[RunCondition] = None) -> "Graph":
"""
@ -76,10 +83,7 @@ class Graph(BaseModel):
run_condition=run_condition
)
graph = cls(root_node=root_node)
graph._add_graph_node(graph.root_node)
return graph
return cls(root_node=root_node)
def add_edge(self, edge_config: dict,
source_node_config: dict,
@ -106,7 +110,10 @@ class Graph(BaseModel):
if not target_node_id:
return
source_node = self.graph_nodes[source_node_id]
source_node = self.graph_nodes.get(source_node_id)
if not source_node:
return
source_node.add_child(target_node_id)
if target_node_id not in self.graph_nodes:
@ -120,45 +127,66 @@ class Graph(BaseModel):
sub_graph=target_node_sub_graph
)
self._add_graph_node(target_graph_node)
self.add_graph_node(target_graph_node)
else:
target_node = self.graph_nodes[target_node_id]
target_node = self.graph_nodes.get(target_node_id)
if not target_node:
return
target_node.predecessor_node_id = source_node_id
target_node.run_condition = run_condition
target_node.source_edge_config = edge_config
target_node.sub_graph = target_node_sub_graph
def get_root_node(self) -> Optional[GraphNode]:
def get_leaf_nodes(self) -> list[GraphNode]:
"""
Get root node of the graph
Get leaf nodes of the graph
:return: root node
:return: leaf nodes
"""
return self.root_node
leaf_nodes = []
for node_id, graph_node in self.graph_nodes.items():
if (
not graph_node.descendant_node_ids # has no child
or # or has only one child and the child is the root node
(
graph_node.descendant_node_ids
and graph_node.descendant_node_ids[0] == self.root_node.id
)
):
leaf_nodes.append(graph_node)
def get_descendants_graph(self, node_id: str) -> Optional["Graph"]:
return leaf_nodes
def get_descendant_graphs(self, node_id: str) -> list["Graph"]:
"""
Get descendants graph of the specific node
Get descendant graphs of the specific node
:param node_id: node id
:return: descendants graph
:return: descendant graphs
"""
if node_id not in self.graph_nodes:
return None
return []
graph_node = self.graph_nodes[node_id]
if not graph_node.descendant_node_ids:
return None
graph_node = self.graph_nodes.get(node_id)
if not graph_node or not graph_node.descendant_node_ids:
return []
descendants_graph = Graph(root_node=graph_node)
descendants_graph._add_graph_node(graph_node)
descendant_graphs: list[Graph] = []
for descendant_node_id in graph_node.descendant_node_ids:
descendant_graph_node = self.graph_nodes.get(descendant_node_id)
if not descendant_graph_node:
continue
for child_node_id in graph_node.descendant_node_ids:
self._add_descendants_graph_nodes(descendants_graph, child_node_id)
descendants_graph = Graph(root_node=descendant_graph_node)
for sub_descendant_node_id in descendant_graph_node.descendant_node_ids:
descendants_graph.add_descendants_graph_nodes(self, sub_descendant_node_id)
return descendants_graph
descendant_graphs.append(descendants_graph)
def _add_graph_node(self, graph_node: GraphNode) -> None:
return descendant_graphs
def add_graph_node(self, graph_node: GraphNode) -> None:
"""
Add graph node to the graph
@ -167,23 +195,24 @@ class Graph(BaseModel):
if graph_node.id in self.graph_nodes:
return
if len(self.graph_nodes) == 0:
self.root_node = graph_node
self.graph_nodes[graph_node.id] = graph_node
def _add_descendants_graph_nodes(self, descendants_graph: "Graph", node_id: str) -> None:
def add_descendants_graph_nodes(self, predecessor_graph: "Graph", node_id: str) -> None:
"""
Add descendants graph nodes
:param descendants_graph: descendants graph
:param predecessor_graph: predecessor graph
:param node_id: node id
"""
if node_id not in self.graph_nodes:
if node_id not in predecessor_graph.graph_nodes:
return
graph_node = self.graph_nodes[node_id]
descendants_graph._add_graph_node(graph_node)
graph_node = predecessor_graph.graph_nodes.get(node_id)
if not graph_node:
return
for child_node_id in graph_node.descendant_node_ids:
self._add_descendants_graph_nodes(descendants_graph, child_node_id)
if graph_node.id not in self.graph_nodes:
self.add_graph_node(graph_node)
for child_node_id in graph_node.descendant_node_ids:
self.add_descendants_graph_nodes(predecessor_graph, child_node_id)

View File

@ -298,6 +298,11 @@ class WorkflowEntry:
root_node_configs=root_node_configs
)
# add edge from end node to first node of sub graph
sub_graph_root_node_id = sub_graph.root_node.id
for leaf_node in sub_graph.get_leaf_nodes():
leaf_node.add_child(sub_graph_root_node_id)
# parse run condition
run_condition = None
if edge_config.get('sourceHandle'):