mirror of
https://github.com/langgenius/dify.git
synced 2026-04-30 07:28:05 +08:00
completed graph init test
This commit is contained in:
@ -1,6 +1,6 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
|
||||
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
|
||||
@ -33,7 +33,8 @@ class GraphNode(BaseModel):
|
||||
"""sub graph of the node, e.g. iteration/loop sub graph"""
|
||||
|
||||
def add_child(self, node_id: str) -> None:
|
||||
self.descendant_node_ids.append(node_id)
|
||||
if node_id not in self.descendant_node_ids:
|
||||
self.descendant_node_ids.append(node_id)
|
||||
|
||||
def get_run_condition_handler(self) -> Optional[RunConditionHandler]:
|
||||
"""
|
||||
@ -56,6 +57,12 @@ class Graph(BaseModel):
|
||||
root_node: GraphNode
|
||||
"""root node of the graph"""
|
||||
|
||||
@model_validator(mode='after')
|
||||
def add_root_node(cls, values):
|
||||
root_node = values.root_node
|
||||
values.graph_nodes[root_node.id] = root_node
|
||||
return values
|
||||
|
||||
@classmethod
|
||||
def init(cls, root_node_config: dict, run_condition: Optional[RunCondition] = None) -> "Graph":
|
||||
"""
|
||||
@ -76,10 +83,7 @@ class Graph(BaseModel):
|
||||
run_condition=run_condition
|
||||
)
|
||||
|
||||
graph = cls(root_node=root_node)
|
||||
|
||||
graph._add_graph_node(graph.root_node)
|
||||
return graph
|
||||
return cls(root_node=root_node)
|
||||
|
||||
def add_edge(self, edge_config: dict,
|
||||
source_node_config: dict,
|
||||
@ -106,7 +110,10 @@ class Graph(BaseModel):
|
||||
if not target_node_id:
|
||||
return
|
||||
|
||||
source_node = self.graph_nodes[source_node_id]
|
||||
source_node = self.graph_nodes.get(source_node_id)
|
||||
if not source_node:
|
||||
return
|
||||
|
||||
source_node.add_child(target_node_id)
|
||||
|
||||
if target_node_id not in self.graph_nodes:
|
||||
@ -120,45 +127,66 @@ class Graph(BaseModel):
|
||||
sub_graph=target_node_sub_graph
|
||||
)
|
||||
|
||||
self._add_graph_node(target_graph_node)
|
||||
self.add_graph_node(target_graph_node)
|
||||
else:
|
||||
target_node = self.graph_nodes[target_node_id]
|
||||
target_node = self.graph_nodes.get(target_node_id)
|
||||
if not target_node:
|
||||
return
|
||||
|
||||
target_node.predecessor_node_id = source_node_id
|
||||
target_node.run_condition = run_condition
|
||||
target_node.source_edge_config = edge_config
|
||||
target_node.sub_graph = target_node_sub_graph
|
||||
|
||||
def get_root_node(self) -> Optional[GraphNode]:
|
||||
def get_leaf_nodes(self) -> list[GraphNode]:
|
||||
"""
|
||||
Get root node of the graph
|
||||
Get leaf nodes of the graph
|
||||
|
||||
:return: root node
|
||||
:return: leaf nodes
|
||||
"""
|
||||
return self.root_node
|
||||
leaf_nodes = []
|
||||
for node_id, graph_node in self.graph_nodes.items():
|
||||
if (
|
||||
not graph_node.descendant_node_ids # has no child
|
||||
or # or has only one child and the child is the root node
|
||||
(
|
||||
graph_node.descendant_node_ids
|
||||
and graph_node.descendant_node_ids[0] == self.root_node.id
|
||||
)
|
||||
):
|
||||
leaf_nodes.append(graph_node)
|
||||
|
||||
def get_descendants_graph(self, node_id: str) -> Optional["Graph"]:
|
||||
return leaf_nodes
|
||||
|
||||
def get_descendant_graphs(self, node_id: str) -> list["Graph"]:
|
||||
"""
|
||||
Get descendants graph of the specific node
|
||||
Get descendant graphs of the specific node
|
||||
|
||||
:param node_id: node id
|
||||
:return: descendants graph
|
||||
:return: descendant graphs
|
||||
"""
|
||||
if node_id not in self.graph_nodes:
|
||||
return None
|
||||
return []
|
||||
|
||||
graph_node = self.graph_nodes[node_id]
|
||||
if not graph_node.descendant_node_ids:
|
||||
return None
|
||||
graph_node = self.graph_nodes.get(node_id)
|
||||
if not graph_node or not graph_node.descendant_node_ids:
|
||||
return []
|
||||
|
||||
descendants_graph = Graph(root_node=graph_node)
|
||||
descendants_graph._add_graph_node(graph_node)
|
||||
descendant_graphs: list[Graph] = []
|
||||
for descendant_node_id in graph_node.descendant_node_ids:
|
||||
descendant_graph_node = self.graph_nodes.get(descendant_node_id)
|
||||
if not descendant_graph_node:
|
||||
continue
|
||||
|
||||
for child_node_id in graph_node.descendant_node_ids:
|
||||
self._add_descendants_graph_nodes(descendants_graph, child_node_id)
|
||||
descendants_graph = Graph(root_node=descendant_graph_node)
|
||||
for sub_descendant_node_id in descendant_graph_node.descendant_node_ids:
|
||||
descendants_graph.add_descendants_graph_nodes(self, sub_descendant_node_id)
|
||||
|
||||
return descendants_graph
|
||||
descendant_graphs.append(descendants_graph)
|
||||
|
||||
def _add_graph_node(self, graph_node: GraphNode) -> None:
|
||||
return descendant_graphs
|
||||
|
||||
def add_graph_node(self, graph_node: GraphNode) -> None:
|
||||
"""
|
||||
Add graph node to the graph
|
||||
|
||||
@ -167,23 +195,24 @@ class Graph(BaseModel):
|
||||
if graph_node.id in self.graph_nodes:
|
||||
return
|
||||
|
||||
if len(self.graph_nodes) == 0:
|
||||
self.root_node = graph_node
|
||||
|
||||
self.graph_nodes[graph_node.id] = graph_node
|
||||
|
||||
def _add_descendants_graph_nodes(self, descendants_graph: "Graph", node_id: str) -> None:
|
||||
def add_descendants_graph_nodes(self, predecessor_graph: "Graph", node_id: str) -> None:
|
||||
"""
|
||||
Add descendants graph nodes
|
||||
|
||||
:param descendants_graph: descendants graph
|
||||
:param predecessor_graph: predecessor graph
|
||||
:param node_id: node id
|
||||
"""
|
||||
if node_id not in self.graph_nodes:
|
||||
if node_id not in predecessor_graph.graph_nodes:
|
||||
return
|
||||
|
||||
graph_node = self.graph_nodes[node_id]
|
||||
descendants_graph._add_graph_node(graph_node)
|
||||
graph_node = predecessor_graph.graph_nodes.get(node_id)
|
||||
if not graph_node:
|
||||
return
|
||||
|
||||
for child_node_id in graph_node.descendant_node_ids:
|
||||
self._add_descendants_graph_nodes(descendants_graph, child_node_id)
|
||||
if graph_node.id not in self.graph_nodes:
|
||||
self.add_graph_node(graph_node)
|
||||
|
||||
for child_node_id in graph_node.descendant_node_ids:
|
||||
self.add_descendants_graph_nodes(predecessor_graph, child_node_id)
|
||||
|
||||
@ -298,6 +298,11 @@ class WorkflowEntry:
|
||||
root_node_configs=root_node_configs
|
||||
)
|
||||
|
||||
# add edge from end node to first node of sub graph
|
||||
sub_graph_root_node_id = sub_graph.root_node.id
|
||||
for leaf_node in sub_graph.get_leaf_nodes():
|
||||
leaf_node.add_child(sub_graph_root_node_id)
|
||||
|
||||
# parse run condition
|
||||
run_condition = None
|
||||
if edge_config.get('sourceHandle'):
|
||||
|
||||
Reference in New Issue
Block a user