mirror of
https://github.com/langgenius/dify.git
synced 2026-04-25 13:16:16 +08:00
refactor runtime
This commit is contained in:
@ -3,9 +3,7 @@ from typing import Optional, cast
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.graph_engine.entities.run_condition import RunCondition
|
||||
|
||||
|
||||
@ -28,33 +26,6 @@ class GraphParallel(BaseModel):
|
||||
"""parent parallel id if exists"""
|
||||
|
||||
|
||||
class GraphStateRoute(BaseModel):
|
||||
route_id: str
|
||||
"""route id"""
|
||||
|
||||
node_id: str
|
||||
"""node id"""
|
||||
|
||||
|
||||
class GraphState(BaseModel):
|
||||
routes: dict[str, list[GraphStateRoute]] = Field(default_factory=dict)
|
||||
"""graph state routes (source_node_id: routes)"""
|
||||
|
||||
variable_pool: VariablePool
|
||||
"""variable pool"""
|
||||
|
||||
node_route_results: dict[str, NodeRunResult] = Field(default_factory=dict)
|
||||
"""node results in route (node_id: run_result)"""
|
||||
|
||||
|
||||
class NextGraphNode(BaseModel):
|
||||
node_id: str
|
||||
"""next node id"""
|
||||
|
||||
parallel: Optional[GraphParallel] = None
|
||||
"""parallel"""
|
||||
|
||||
|
||||
class Graph(BaseModel):
|
||||
root_node_id: str
|
||||
"""root node id of the graph"""
|
||||
@ -71,19 +42,14 @@ class Graph(BaseModel):
|
||||
node_parallel_mapping: dict[str, str] = Field(default_factory=dict)
|
||||
"""graph node parallel mapping (node id: parallel id)"""
|
||||
|
||||
run_state: GraphState
|
||||
"""graph run state"""
|
||||
|
||||
@classmethod
|
||||
def init(cls,
|
||||
graph_config: dict,
|
||||
variable_pool: VariablePool,
|
||||
root_node_id: Optional[str] = None) -> "Graph":
|
||||
"""
|
||||
Init graph
|
||||
|
||||
:param graph_config: graph config
|
||||
:param variable_pool: variable pool
|
||||
:param root_node_id: root node id
|
||||
:return: graph
|
||||
"""
|
||||
@ -149,7 +115,7 @@ class Graph(BaseModel):
|
||||
# fetch root node
|
||||
if not root_node_id:
|
||||
# if no root node id, use the START type node as root node
|
||||
root_node_id = next((node_config for node_config in root_node_configs
|
||||
root_node_id = next((node_config.get("id") for node_config in root_node_configs
|
||||
if node_config.get('data', {}).get('type', '') == NodeType.START.value), None)
|
||||
|
||||
if not root_node_id or root_node_id not in root_node_ids:
|
||||
@ -178,80 +144,12 @@ class Graph(BaseModel):
|
||||
root_node_id=root_node_id,
|
||||
node_ids=node_ids,
|
||||
edge_mapping=edge_mapping,
|
||||
run_state=GraphState(
|
||||
variable_pool=variable_pool
|
||||
),
|
||||
parallel_mapping=parallel_mapping,
|
||||
node_parallel_mapping=node_parallel_mapping
|
||||
)
|
||||
|
||||
return graph
|
||||
|
||||
@classmethod
|
||||
def _recursively_add_node_ids(cls,
|
||||
node_ids: list[str],
|
||||
edge_mapping: dict[str, list[GraphEdge]],
|
||||
node_id: str) -> None:
|
||||
"""
|
||||
Recursively add node ids
|
||||
|
||||
:param node_ids: node ids
|
||||
:param edge_mapping: edge mapping
|
||||
:param node_id: node id
|
||||
"""
|
||||
for graph_edge in edge_mapping.get(node_id, []):
|
||||
if graph_edge.target_node_id in node_ids:
|
||||
continue
|
||||
|
||||
node_ids.append(graph_edge.target_node_id)
|
||||
cls._recursively_add_node_ids(
|
||||
node_ids=node_ids,
|
||||
edge_mapping=edge_mapping,
|
||||
node_id=graph_edge.target_node_id
|
||||
)
|
||||
|
||||
def next_node_ids(self) -> list[NextGraphNode]:
|
||||
"""
|
||||
Get next node ids
|
||||
"""
|
||||
# get current node ids in state
|
||||
if not self.run_state.routes:
|
||||
return [NextGraphNode(node_id=self.root_node_id)]
|
||||
|
||||
route_final_graph_edges: list[GraphEdge] = []
|
||||
for route in self.run_state.routes[self.root_node_id]:
|
||||
graph_edges = self.edge_mapping.get(route.node_id)
|
||||
if not graph_edges:
|
||||
continue
|
||||
|
||||
for edge in graph_edges:
|
||||
if edge.target_node_id not in self.run_state.routes:
|
||||
route_final_graph_edges.append(edge)
|
||||
|
||||
next_graph_nodes = []
|
||||
for route_final_graph_edge in route_final_graph_edges:
|
||||
node_id = route_final_graph_edge.target_node_id
|
||||
# check condition
|
||||
if route_final_graph_edge.run_condition:
|
||||
result = ConditionManager.get_condition_handler(
|
||||
run_condition=route_final_graph_edge.run_condition
|
||||
).check(
|
||||
source_node_id=route_final_graph_edge.source_node_id,
|
||||
target_node_id=route_final_graph_edge.target_node_id,
|
||||
graph=self
|
||||
)
|
||||
|
||||
if not result:
|
||||
continue
|
||||
|
||||
parallel = None
|
||||
if route_final_graph_edge.target_node_id in self.node_parallel_mapping:
|
||||
parallel = self.parallel_mapping[self.node_parallel_mapping[node_id]]
|
||||
|
||||
next_graph_nodes.append(NextGraphNode(node_id=node_id, parallel=parallel))
|
||||
|
||||
return next_graph_nodes
|
||||
|
||||
def add_extra_edge(self, source_node_id: str,
|
||||
target_node_id: str,
|
||||
run_condition: Optional[RunCondition] = None) -> None:
|
||||
@ -295,6 +193,29 @@ class Graph(BaseModel):
|
||||
|
||||
return leaf_node_ids
|
||||
|
||||
@classmethod
|
||||
def _recursively_add_node_ids(cls,
|
||||
node_ids: list[str],
|
||||
edge_mapping: dict[str, list[GraphEdge]],
|
||||
node_id: str) -> None:
|
||||
"""
|
||||
Recursively add node ids
|
||||
|
||||
:param node_ids: node ids
|
||||
:param edge_mapping: edge mapping
|
||||
:param node_id: node id
|
||||
"""
|
||||
for graph_edge in edge_mapping.get(node_id, []):
|
||||
if graph_edge.target_node_id in node_ids:
|
||||
continue
|
||||
|
||||
node_ids.append(graph_edge.target_node_id)
|
||||
cls._recursively_add_node_ids(
|
||||
node_ids=node_ids,
|
||||
edge_mapping=edge_mapping,
|
||||
node_id=graph_edge.target_node_id
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _recursively_add_parallels(cls,
|
||||
edge_mapping: dict[str, list[GraphEdge]],
|
||||
|
||||
@ -4,12 +4,12 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.runtime_graph import RuntimeGraph
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RuntimeRouteState
|
||||
from core.workflow.nodes.base_node import UserFrom
|
||||
|
||||
|
||||
class GraphRuntimeState(BaseModel):
|
||||
# init params
|
||||
tenant_id: str
|
||||
app_id: str
|
||||
user_id: str
|
||||
@ -17,10 +17,10 @@ class GraphRuntimeState(BaseModel):
|
||||
invoke_from: InvokeFrom
|
||||
call_depth: int
|
||||
|
||||
graph: Graph
|
||||
variable_pool: VariablePool
|
||||
|
||||
|
||||
start_at: Optional[float] = None
|
||||
total_tokens: int = 0
|
||||
node_run_steps: int = 0
|
||||
runtime_graph: RuntimeGraph = Field(default_factory=RuntimeGraph)
|
||||
|
||||
node_run_state: RuntimeRouteState = Field(default_factory=RuntimeRouteState)
|
||||
|
||||
13
api/core/workflow/graph_engine/entities/next_graph_node.py
Normal file
13
api/core/workflow/graph_engine/entities/next_graph_node.py
Normal file
@ -0,0 +1,13 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.workflow.graph_engine.entities.graph import GraphParallel
|
||||
|
||||
|
||||
class NextGraphNode(BaseModel):
|
||||
node_id: str
|
||||
"""next node id"""
|
||||
|
||||
parallel: Optional[GraphParallel] = None
|
||||
"""parallel"""
|
||||
@ -1,38 +0,0 @@
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.graph_engine.entities.runtime_node import RuntimeNode
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class RuntimeGraph(BaseModel):
|
||||
runtime_nodes: dict[str, RuntimeNode] = {}
|
||||
"""runtime nodes"""
|
||||
|
||||
def add_runtime_node(self, runtime_node: RuntimeNode) -> None:
|
||||
self.runtime_nodes[runtime_node.id] = runtime_node
|
||||
|
||||
def add_link(self, source_runtime_node_id: str, target_runtime_node_id: str) -> None:
|
||||
if source_runtime_node_id in self.runtime_nodes and target_runtime_node_id in self.runtime_nodes:
|
||||
target_runtime_node = self.runtime_nodes[target_runtime_node_id]
|
||||
target_runtime_node.predecessor_runtime_node_id = source_runtime_node_id
|
||||
|
||||
def runtime_node_finished(self, runtime_node_id: str, node_run_result: NodeRunResult) -> None:
|
||||
if runtime_node_id in self.runtime_nodes:
|
||||
runtime_node = self.runtime_nodes[runtime_node_id]
|
||||
runtime_node.status = RuntimeNode.Status.SUCCESS \
|
||||
if node_run_result.status == WorkflowNodeExecutionStatus.RUNNING \
|
||||
else RuntimeNode.Status.FAILED
|
||||
runtime_node.node_run_result = node_run_result
|
||||
runtime_node.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
runtime_node.failed_reason = node_run_result.error
|
||||
|
||||
def runtime_node_paused(self, runtime_node_id: str, paused_by: Optional[str] = None) -> None:
|
||||
if runtime_node_id in self.runtime_nodes:
|
||||
runtime_node = self.runtime_nodes[runtime_node_id]
|
||||
runtime_node.status = RuntimeNode.Status.PAUSED
|
||||
runtime_node.paused_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
runtime_node.paused_by = paused_by
|
||||
@ -1,48 +0,0 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.graph_engine.entities.graph import GraphNode
|
||||
|
||||
|
||||
class RuntimeNode(BaseModel):
|
||||
class Status(Enum):
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
SUCCESS = "success"
|
||||
FAILED = "failed"
|
||||
PAUSED = "paused"
|
||||
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||
"""random id for current runtime node"""
|
||||
|
||||
graph_node: GraphNode
|
||||
"""graph node"""
|
||||
|
||||
node_run_result: Optional[NodeRunResult] = None
|
||||
"""node run result"""
|
||||
|
||||
status: Status = Status.PENDING
|
||||
"""node status"""
|
||||
|
||||
start_at: Optional[datetime] = None
|
||||
"""start time"""
|
||||
|
||||
paused_at: Optional[datetime] = None
|
||||
"""paused time"""
|
||||
|
||||
finished_at: Optional[datetime] = None
|
||||
"""finished time"""
|
||||
|
||||
failed_reason: Optional[str] = None
|
||||
"""failed reason"""
|
||||
|
||||
paused_by: Optional[str] = None
|
||||
"""paused by"""
|
||||
|
||||
predecessor_runtime_node_id: Optional[str] = None
|
||||
"""predecessor runtime node id"""
|
||||
111
api/core/workflow/graph_engine/entities/runtime_route_state.py
Normal file
111
api/core/workflow/graph_engine/entities/runtime_route_state.py
Normal file
@ -0,0 +1,111 @@
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class RouteNodeState(BaseModel):
|
||||
class Status(Enum):
|
||||
RUNNING = "running"
|
||||
SUCCESS = "success"
|
||||
FAILED = "failed"
|
||||
PAUSED = "paused"
|
||||
|
||||
state_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||
"""node state id"""
|
||||
|
||||
node_id: str
|
||||
"""node id"""
|
||||
|
||||
node_run_result: Optional[NodeRunResult] = None
|
||||
"""node run result"""
|
||||
|
||||
status: Status = Status.RUNNING
|
||||
"""node status"""
|
||||
|
||||
start_at: datetime
|
||||
"""start time"""
|
||||
|
||||
paused_at: Optional[datetime] = None
|
||||
"""paused time"""
|
||||
|
||||
finished_at: Optional[datetime] = None
|
||||
"""finished time"""
|
||||
|
||||
failed_reason: Optional[str] = None
|
||||
"""failed reason"""
|
||||
|
||||
paused_by: Optional[str] = None
|
||||
"""paused by"""
|
||||
|
||||
|
||||
class RuntimeRouteState(BaseModel):
|
||||
routes: dict[str, list[str]] = Field(default_factory=dict)
|
||||
"""graph state routes (source_node_state_id: target_node_state_id)"""
|
||||
|
||||
node_state_mapping: dict[str, RouteNodeState] = Field(default_factory=dict)
|
||||
"""node state mapping (route_node_state_id: route_node_state)"""
|
||||
|
||||
def create_node_state(self, node_id: str) -> RouteNodeState:
|
||||
"""
|
||||
Create node state
|
||||
|
||||
:param node_id: node id
|
||||
"""
|
||||
state = RouteNodeState(node_id=node_id, start_at=datetime.now(timezone.utc).replace(tzinfo=None))
|
||||
self.node_state_mapping[state.state_id] = state
|
||||
return state
|
||||
|
||||
def add_route(self, source_node_state_id: str, target_node_state_id: str) -> None:
|
||||
"""
|
||||
Add route to the graph state
|
||||
|
||||
:param source_node_state_id: source node state id
|
||||
:param target_node_state_id: target node state id
|
||||
"""
|
||||
if source_node_state_id not in self.routes:
|
||||
self.routes[source_node_state_id] = []
|
||||
|
||||
self.routes[source_node_state_id].append(target_node_state_id)
|
||||
|
||||
def get_routes_with_node_state_by_source_node_state_id(self, source_node_state_id: str) \
|
||||
-> list[RouteNodeState]:
|
||||
"""
|
||||
Get routes with node state by source node id
|
||||
|
||||
:param source_node_state_id: source node state id
|
||||
:return: routes with node state
|
||||
"""
|
||||
return [self.node_state_mapping[target_state_id]
|
||||
for target_state_id in self.routes.get(source_node_state_id, [])]
|
||||
|
||||
def set_node_state_finished(self, node_state_id: str, run_result: NodeRunResult) -> None:
|
||||
"""
|
||||
Node finished
|
||||
|
||||
:param node_state_id: route node state id
|
||||
:param run_result: run result
|
||||
"""
|
||||
if node_state_id not in self.node_state_mapping:
|
||||
raise Exception(f"Route state {node_state_id} not found")
|
||||
|
||||
route = self.node_state_mapping[node_state_id]
|
||||
|
||||
if route.status in [RouteNodeState.Status.SUCCESS, RouteNodeState.Status.FAILED]:
|
||||
raise Exception(f"Route state {node_state_id} already finished")
|
||||
|
||||
if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
|
||||
route.status = RouteNodeState.Status.SUCCESS
|
||||
elif run_result.status == WorkflowNodeExecutionStatus.FAILED:
|
||||
route.status = RouteNodeState.Status.FAILED
|
||||
route.failed_reason = run_result.error
|
||||
else:
|
||||
raise Exception(f"Invalid route status {run_result.status}")
|
||||
|
||||
route.node_run_result = run_result
|
||||
route.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
@ -22,6 +22,7 @@ class GraphEngine:
|
||||
graph: Graph,
|
||||
variable_pool: VariablePool,
|
||||
callbacks: list[BaseWorkflowCallback]) -> None:
|
||||
self.graph = graph
|
||||
self.graph_runtime_state = GraphRuntimeState(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
@ -29,7 +30,6 @@ class GraphEngine:
|
||||
user_from=user_from,
|
||||
invoke_from=invoke_from,
|
||||
call_depth=call_depth,
|
||||
graph=graph,
|
||||
variable_pool=variable_pool
|
||||
)
|
||||
|
||||
@ -43,3 +43,49 @@ class GraphEngine:
|
||||
def run(self) -> Generator:
|
||||
self.graph_runtime_state.start_at = time.perf_counter()
|
||||
pass
|
||||
|
||||
# def next_node_ids(self, node_state_id: str) -> list[NextGraphNode]:
|
||||
# """
|
||||
# Get next node ids
|
||||
#
|
||||
# :param node_state_id: source node state id
|
||||
# """
|
||||
# # get current node ids in state
|
||||
# node_run_state = self.graph_runtime_state.node_run_state
|
||||
# graph = self.graph
|
||||
# if not node_run_state.routes:
|
||||
# return [NextGraphNode(node_id=graph.root_node_id)]
|
||||
#
|
||||
# route_final_graph_edges: list[GraphEdge] = []
|
||||
# for route in route_state.routes[graph.root_node_id]:
|
||||
# graph_edges = graph.edge_mapping.get(route.node_id)
|
||||
# if not graph_edges:
|
||||
# continue
|
||||
#
|
||||
# for edge in graph_edges:
|
||||
# if edge.target_node_id not in route_state.routes:
|
||||
# route_final_graph_edges.append(edge)
|
||||
#
|
||||
# next_graph_nodes = []
|
||||
# for route_final_graph_edge in route_final_graph_edges:
|
||||
# node_id = route_final_graph_edge.target_node_id
|
||||
# # check condition
|
||||
# if route_final_graph_edge.run_condition:
|
||||
# result = ConditionManager.get_condition_handler(
|
||||
# run_condition=route_final_graph_edge.run_condition
|
||||
# ).check(
|
||||
# source_node_id=route_final_graph_edge.source_node_id,
|
||||
# target_node_id=route_final_graph_edge.target_node_id,
|
||||
# graph=self
|
||||
# )
|
||||
#
|
||||
# if not result:
|
||||
# continue
|
||||
#
|
||||
# parallel = None
|
||||
# if route_final_graph_edge.target_node_id in graph.node_parallel_mapping:
|
||||
# parallel = graph.parallel_mapping[graph.node_parallel_mapping[node_id]]
|
||||
#
|
||||
# next_graph_nodes.append(NextGraphNode(node_id=node_id, parallel=parallel))
|
||||
#
|
||||
# return next_graph_nodes
|
||||
|
||||
Reference in New Issue
Block a user