mirror of
https://github.com/langgenius/dify.git
synced 2026-05-02 16:38:04 +08:00
graph engine implement
This commit is contained in:
116
api/core/workflow/graph_engine/entities/event.py
Normal file
116
api/core/workflow/graph_engine/entities/event.py
Normal file
@ -0,0 +1,116 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
|
||||
|
||||
class GraphEngineEvent(BaseModel):
    """Common base model shared by every event type defined for the graph engine."""
|
||||
|
||||
###########################################
|
||||
# Graph Events
|
||||
###########################################
|
||||
|
||||
|
||||
class BaseGraphEvent(GraphEngineEvent):
    """Base for events that describe the graph run as a whole (no extra fields)."""
|
||||
|
||||
|
||||
class GraphRunStartedEvent(BaseGraphEvent):
    """Graph-level event marking the start of a graph run; carries no payload."""
|
||||
|
||||
|
||||
class GraphRunBackToRootEvent(BaseGraphEvent):
    """Graph-level event with no payload.

    NOTE(review): judging by the name, this signals the run returning to the
    root node; confirm against the code that emits it.
    """
|
||||
|
||||
|
||||
class GraphRunSucceededEvent(BaseGraphEvent):
    """Graph-level event marking a successful graph run; carries no payload."""
|
||||
|
||||
|
||||
class GraphRunFailedEvent(BaseGraphEvent):
    """Graph-level event marking a failed graph run."""

    # Required: the failure message for the run.
    reason: str = Field(default=..., description="failed reason")
|
||||
|
||||
|
||||
###########################################
|
||||
# Node Events
|
||||
###########################################
|
||||
|
||||
|
||||
class BaseNodeEvent(GraphEngineEvent):
    """Base for events tied to a single node, optionally inside a parallel branch."""

    # Required: id of the node the event refers to.
    node_id: str = Field(default=..., description="node id")
    # Set only when the node belongs to a parallel; defaults to None.
    parallel_id: Optional[str] = Field(default=None, description="parallel id if node is in parallel")
    # Not enabled yet:
    # iteration_id: Optional[str] = Field(None, description="iteration id if node is in iteration")
|
||||
|
||||
|
||||
class NodeRunStartedEvent(BaseNodeEvent):
    """Node-level event marking the start of a node run; adds no fields to the base."""
|
||||
|
||||
|
||||
class NodeRunStreamChunkEvent(BaseNodeEvent):
    """Node-level event carrying one chunk of streamed content."""

    # Required: the text of this chunk.
    chunk_content: str = Field(default=..., description="chunk content")
|
||||
|
||||
|
||||
class NodeRunRetrieverResourceEvent(BaseNodeEvent):
    """Node-level event carrying retriever resources as a list of dicts."""

    # Required: the dict schema is not defined in this module.
    retriever_resources: list[dict] = Field(default=..., description="retriever resources")
|
||||
|
||||
|
||||
class NodeRunSucceededEvent(BaseNodeEvent):
    """Node-level event marking a successful node run."""

    # Required: the NodeRunResult produced by the node.
    run_result: NodeRunResult = Field(default=..., description="run result")
|
||||
|
||||
|
||||
class NodeRunFailedEvent(BaseNodeEvent):
    """Node-level event marking a failed node run.

    ``reason`` may be omitted (or empty); it is then derived from the
    ``error`` of ``run_result``, falling back to ``"Unknown error"``.
    """

    run_result: NodeRunResult = Field(..., description="run result")
    reason: str = Field("", description="failed reason")

    @model_validator(mode='before')
    @classmethod
    def init_reason(cls, values: dict) -> dict:
        """Fill in ``reason`` from ``run_result.error`` when no reason was given.

        Runs before field validation, so ``run_result`` may still be missing,
        a plain dict, or an already-built object; all three are tolerated so
        that a missing ``run_result`` is left for normal field validation to
        report instead of failing here with an ``AttributeError``.
        """
        # Non-dict input (e.g. an existing model instance) is passed through untouched.
        if not isinstance(values, dict) or values.get("reason"):
            return values

        run_result = values.get("run_result")
        if isinstance(run_result, dict):
            error = run_result.get("error")
        else:
            # getattr with a default also covers run_result being None.
            error = getattr(run_result, "error", None)

        # Return a new dict rather than mutating the caller's input.
        return {**values, "reason": error or "Unknown error"}
|
||||
|
||||
|
||||
###########################################
|
||||
# Parallel Events
|
||||
###########################################
|
||||
|
||||
|
||||
class BaseParallelEvent(GraphEngineEvent):
    """Base for events that refer to a parallel as a whole."""

    # Required: id of the parallel the event refers to.
    parallel_id: str = Field(default=..., description="parallel id")
|
||||
|
||||
|
||||
class ParallelRunStartedEvent(BaseParallelEvent):
    """Parallel-level event marking the start of a parallel run; adds no fields."""
|
||||
|
||||
|
||||
class ParallelRunSucceededEvent(BaseParallelEvent):
    """Parallel-level event marking a successful parallel run; adds no fields."""
|
||||
|
||||
|
||||
class ParallelRunFailedEvent(BaseParallelEvent):
    """Parallel-level event marking a failed parallel run."""

    # Required: the failure message for the parallel.
    reason: str = Field(default=..., description="failed reason")
|
||||
|
||||
|
||||
###########################################
|
||||
# Iteration Events
|
||||
###########################################
|
||||
|
||||
|
||||
class BaseIterationEvent(GraphEngineEvent):
    """Base for events that refer to an iteration as a whole."""

    # Required: id of the iteration the event refers to.
    iteration_id: str = Field(default=..., description="iteration id")
|
||||
|
||||
|
||||
class IterationRunStartedEvent(BaseIterationEvent):
    """Iteration-level event marking the start of an iteration run; adds no fields."""
|
||||
|
||||
|
||||
class IterationRunSucceededEvent(BaseIterationEvent):
    """Iteration-level event marking a successful iteration run; adds no fields."""
|
||||
|
||||
|
||||
class IterationRunFailedEvent(BaseIterationEvent):
    """Iteration-level event marking a failed iteration run."""

    # Required: the failure message for the iteration.
    reason: str = Field(default=..., description="failed reason")
|
||||
|
||||
|
||||
# Union alias over the three event families keyed by a node, parallel, or iteration id.
# Graph-level events (BaseGraphEvent subclasses) are deliberately not part of it.
InNodeEvent = BaseNodeEvent | BaseParallelEvent | BaseIterationEvent
|
||||
@ -8,48 +8,37 @@ from core.workflow.graph_engine.entities.run_condition import RunCondition
|
||||
|
||||
|
||||
class GraphEdge(BaseModel):
|
||||
source_node_id: str
|
||||
"""source node id"""
|
||||
|
||||
target_node_id: str
|
||||
"""target node id"""
|
||||
|
||||
run_condition: Optional[RunCondition] = None
|
||||
"""condition to run the edge"""
|
||||
source_node_id: str = Field(..., description="source node id")
|
||||
target_node_id: str = Field(..., description="target node id")
|
||||
run_condition: Optional[RunCondition] = Field(None, description="run condition")
|
||||
|
||||
|
||||
class GraphParallel(BaseModel):
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||
"""random uuid parallel id"""
|
||||
|
||||
start_from_node_id: str
|
||||
"""start from node id"""
|
||||
|
||||
end_to_node_id: Optional[str] = None
|
||||
"""end to node id"""
|
||||
|
||||
parent_parallel_id: Optional[str] = None
|
||||
"""parent parallel id if exists"""
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="random uuid parallel id")
|
||||
start_from_node_id: str = Field(..., description="start from node id")
|
||||
parent_parallel_id: Optional[str] = Field(None, description="parent parallel id")
|
||||
end_to_node_id: Optional[str] = Field(None, description="end to node id")
|
||||
|
||||
|
||||
class Graph(BaseModel):
|
||||
root_node_id: str
|
||||
"""root node id of the graph"""
|
||||
|
||||
node_ids: list[str] = Field(default_factory=list)
|
||||
"""graph node ids"""
|
||||
|
||||
node_id_config_mapping: dict[str, dict] = Field(default_factory=list)
|
||||
"""node configs mapping (node id: node config)"""
|
||||
|
||||
edge_mapping: dict[str, list[GraphEdge]] = Field(default_factory=dict)
|
||||
"""graph edge mapping (source node id: edges)"""
|
||||
|
||||
parallel_mapping: dict[str, GraphParallel] = Field(default_factory=dict)
|
||||
"""graph parallel mapping (parallel id: parallel)"""
|
||||
|
||||
node_parallel_mapping: dict[str, str] = Field(default_factory=dict)
|
||||
"""graph node parallel mapping (node id: parallel id)"""
|
||||
root_node_id: str = Field(..., description="root node id of the graph")
|
||||
node_ids: list[str] = Field(default_factory=list, description="graph node ids")
|
||||
# The default must be a dict, not a list: the graph engine looks node configs up
# with node_id_config_mapping.get(node_id), which a list default cannot serve.
node_id_config_mapping: dict[str, dict] = Field(
    default_factory=dict,
    description="node configs mapping (node id: node config)"
)
|
||||
edge_mapping: dict[str, list[GraphEdge]] = Field(
|
||||
default_factory=dict,
|
||||
description="graph edge mapping (source node id: edges)"
|
||||
)
|
||||
parallel_mapping: dict[str, GraphParallel] = Field(
|
||||
default_factory=dict,
|
||||
description="graph parallel mapping (parallel id: parallel)"
|
||||
)
|
||||
node_parallel_mapping: dict[str, str] = Field(
|
||||
default_factory=dict,
|
||||
description="graph node parallel mapping (node id: parallel id)"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def init(cls,
|
||||
|
||||
17
api/core/workflow/graph_engine/entities/graph_init_params.py
Normal file
17
api/core/workflow/graph_engine/entities/graph_init_params.py
Normal file
@ -0,0 +1,17 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.node_entities import UserFrom
|
||||
from models.workflow import WorkflowType
|
||||
|
||||
|
||||
class GraphInitParams(BaseModel):
    """Parameters supplied when a graph run is initialised.

    Every field is required; there are no defaults.
    """

    # init params
    tenant_id: str = Field(default=..., description="tenant / workspace id")
    app_id: str = Field(default=..., description="app id")
    workflow_type: WorkflowType = Field(default=..., description="workflow type")
    workflow_id: str = Field(default=..., description="workflow id")
    user_id: str = Field(default=..., description="user id")
    user_from: UserFrom = Field(default=..., description="user from, account or end-user")
    invoke_from: InvokeFrom = Field(
        default=...,
        description="invoke from, service-api, web-app, explore or debugger",
    )
    call_depth: int = Field(default=..., description="call depth")
|
||||
@ -1,25 +1,14 @@
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RuntimeRouteState
|
||||
from core.workflow.nodes.base_node import UserFrom
|
||||
|
||||
|
||||
class GraphRuntimeState(BaseModel):
|
||||
# init params
|
||||
tenant_id: str
|
||||
app_id: str
|
||||
user_id: str
|
||||
user_from: UserFrom
|
||||
invoke_from: InvokeFrom
|
||||
call_depth: int
|
||||
variable_pool: VariablePool = Field(..., description="variable pool")
|
||||
|
||||
variable_pool: VariablePool
|
||||
start_at: float = Field(..., description="start time")
|
||||
total_tokens: int = Field(0, description="total tokens")
|
||||
node_run_steps: int = Field(0, description="node run steps")
|
||||
|
||||
start_at: float
|
||||
total_tokens: int = 0
|
||||
node_run_steps: int = 0
|
||||
|
||||
node_run_state: RuntimeRouteState = Field(default_factory=RuntimeRouteState)
|
||||
node_run_state: RuntimeRouteState = Field(default_factory=RuntimeRouteState, description="node run state")
|
||||
|
||||
@ -2,7 +2,7 @@ import logging
|
||||
import queue
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Optional, cast
|
||||
|
||||
from flask import Flask, current_app
|
||||
@ -13,10 +13,20 @@ from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
GraphEngineEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunStartedEvent,
|
||||
)
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.nodes.base_node import UserFrom
|
||||
from core.workflow.nodes.base_node import UserFrom, node_classes
|
||||
from extensions.ext_database import db
|
||||
from models.workflow import WorkflowType
|
||||
|
||||
thread_pool = ThreadPoolExecutor(max_workers=500, thread_name_prefix="ThreadGraphParallelRun")
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -25,6 +35,8 @@ logger = logging.getLogger(__name__)
|
||||
class GraphEngine:
|
||||
def __init__(self, tenant_id: str,
|
||||
app_id: str,
|
||||
workflow_type: WorkflowType,
|
||||
workflow_id: str,
|
||||
user_id: str,
|
||||
user_from: UserFrom,
|
||||
invoke_from: InvokeFrom,
|
||||
@ -33,13 +45,18 @@ class GraphEngine:
|
||||
variable_pool: VariablePool,
|
||||
callbacks: list[BaseWorkflowCallback]) -> None:
|
||||
self.graph = graph
|
||||
self.graph_runtime_state = GraphRuntimeState(
|
||||
self.init_params = GraphInitParams(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
workflow_type=workflow_type,
|
||||
workflow_id=workflow_id,
|
||||
user_id=user_id,
|
||||
user_from=user_from,
|
||||
invoke_from=invoke_from,
|
||||
call_depth=call_depth,
|
||||
call_depth=call_depth
|
||||
)
|
||||
|
||||
self.graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=variable_pool,
|
||||
start_at=time.perf_counter()
|
||||
)
|
||||
@ -55,31 +72,31 @@ class GraphEngine:
|
||||
# TODO convert generator to result
|
||||
pass
|
||||
|
||||
def run(self) -> Generator:
|
||||
# TODO trigger graph run start event
|
||||
def run(self) -> Generator[GraphEngineEvent, None, None]:
|
||||
# trigger graph run start event
|
||||
yield GraphRunStartedEvent()
|
||||
|
||||
try:
|
||||
# TODO run graph
|
||||
rst = self._run(start_node_id=self.graph.root_node_id)
|
||||
except GraphRunFailedError as e:
|
||||
# TODO self._graph_run_failed(
|
||||
# error=e.error,
|
||||
# callbacks=callbacks
|
||||
# )
|
||||
pass
|
||||
# run graph
|
||||
generator = self._run(start_node_id=self.graph.root_node_id)
|
||||
for item in generator:
|
||||
yield item
|
||||
if isinstance(item, NodeRunFailedEvent):
|
||||
yield GraphRunFailedEvent(reason=item.reason)
|
||||
return
|
||||
|
||||
# trigger graph run success event
|
||||
yield GraphRunSucceededEvent()
|
||||
except (GraphRunFailedError, NodeRunFailedError) as e:
|
||||
yield GraphRunFailedEvent(reason=e.error)
|
||||
return
|
||||
except Exception as e:
|
||||
# TODO self._workflow_run_failed(
|
||||
# error=str(e),
|
||||
# callbacks=callbacks
|
||||
# )
|
||||
pass
|
||||
yield GraphRunFailedEvent(reason=str(e))
|
||||
return
|
||||
|
||||
# TODO trigger graph run success event
|
||||
|
||||
yield rst
|
||||
|
||||
def _run(self, start_node_id: str, in_parallel_id: Optional[str] = None):
|
||||
def _run(self, start_node_id: str, in_parallel_id: Optional[str] = None) -> Generator[GraphEngineEvent, None, None]:
|
||||
next_node_id = start_node_id
|
||||
previous_node_id = None
|
||||
while True:
|
||||
# max steps reached
|
||||
if self.graph_runtime_state.node_run_steps > self.max_execution_steps:
|
||||
@ -92,10 +109,18 @@ class GraphEngine:
|
||||
):
|
||||
raise GraphRunFailedError('Max execution time {}s reached.'.format(self.max_execution_time))
|
||||
|
||||
# run node TODO generator
|
||||
yield from self._run_node(node_id=next_node_id)
|
||||
try:
|
||||
# run node
|
||||
yield from self._run_node(
|
||||
node_id=next_node_id,
|
||||
previous_node_id=previous_node_id,
|
||||
parallel_id=in_parallel_id
|
||||
)
|
||||
except Exception as e:
|
||||
yield NodeRunFailedEvent(node_id=next_node_id, reason=str(e))
|
||||
return
|
||||
|
||||
# todo if failed, break
|
||||
previous_node_id = next_node_id
|
||||
|
||||
# get next node ids
|
||||
edge_mappings = self.graph.edge_mapping.get(next_node_id)
|
||||
@ -135,11 +160,11 @@ class GraphEngine:
|
||||
# if nodes has no run conditions, parallel run all nodes
|
||||
parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].source_node_id)
|
||||
if not parallel_id:
|
||||
raise GraphRunFailedError('Node related parallel not found.')
|
||||
raise GraphRunFailedError(f'Node {edge_mappings[0].source_node_id} related parallel not found.')
|
||||
|
||||
parallel = self.graph.parallel_mapping.get(parallel_id)
|
||||
if not parallel:
|
||||
raise GraphRunFailedError('Parallel not found.')
|
||||
raise GraphRunFailedError(f'Parallel {parallel_id} not found.')
|
||||
|
||||
# run parallel nodes, run in new thread and use queue to get results
|
||||
q: queue.Queue = queue.Queue()
|
||||
@ -149,8 +174,9 @@ class GraphEngine:
|
||||
for edge in edge_mappings:
|
||||
futures.append(thread_pool.submit(
|
||||
self._run_parallel_node,
|
||||
flask_app=current_app._get_current_object(),
|
||||
parallel_start_node_id=edge.source_node_id,
|
||||
flask_app=current_app._get_current_object(), # type: ignore
|
||||
parallel_id=parallel_id,
|
||||
parallel_start_node_id=edge.source_node_id, # source_node_id is start nodes in parallel
|
||||
q=q
|
||||
))
|
||||
|
||||
@ -165,8 +191,9 @@ class GraphEngine:
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
for future in as_completed(futures):
|
||||
future.result()
|
||||
# not necessary
|
||||
# for future in as_completed(futures):
|
||||
# future.result()
|
||||
|
||||
# get final node id
|
||||
final_node_id = parallel.end_to_node_id
|
||||
@ -178,48 +205,61 @@ class GraphEngine:
|
||||
if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, '') == in_parallel_id:
|
||||
break
|
||||
|
||||
def _run_parallel_node(self, flask_app: Flask, parallel_start_node_id: str, q: queue.Queue) -> None:
|
||||
def _run_parallel_node(self,
|
||||
flask_app: Flask,
|
||||
parallel_id: str,
|
||||
parallel_start_node_id: str,
|
||||
q: queue.Queue) -> None:
|
||||
"""
|
||||
Run parallel nodes
|
||||
"""
|
||||
with flask_app.app_context():
|
||||
try:
|
||||
in_parallel_id = self.graph.node_parallel_mapping.get(parallel_start_node_id)
|
||||
if not in_parallel_id:
|
||||
q.put(None)
|
||||
return
|
||||
|
||||
# run node TODO generator
|
||||
rst = self._run(
|
||||
# run node
|
||||
generator = self._run(
|
||||
start_node_id=parallel_start_node_id,
|
||||
in_parallel_id=in_parallel_id
|
||||
in_parallel_id=parallel_id
|
||||
)
|
||||
|
||||
if not rst:
|
||||
q.put(None)
|
||||
return
|
||||
|
||||
for item in rst:
|
||||
q.put(item)
|
||||
|
||||
q.put(None)
|
||||
if generator:
|
||||
for item in generator:
|
||||
q.put(item)
|
||||
except Exception:
|
||||
logger.exception("Unknown Error when generating in parallel")
|
||||
finally:
|
||||
q.put(None)
|
||||
db.session.remove()
|
||||
|
||||
def _run_node(self, node_id: str) -> Generator:
|
||||
def _run_node(self,
|
||||
node_id: str,
|
||||
previous_node_id: Optional[str] = None,
|
||||
parallel_id: Optional[str] = None
|
||||
) -> Generator[GraphEngineEvent, None, None]:
|
||||
"""
|
||||
Run node
|
||||
"""
|
||||
# get node config
|
||||
node_config = self.graph.node_id_config_mapping.get(node_id)
|
||||
if not node_config:
|
||||
raise GraphRunFailedError('Node not found.')
|
||||
raise GraphRunFailedError(f'Node {node_id} config not found.')
|
||||
|
||||
# todo convert to specific node
|
||||
# convert to specific node
|
||||
node_type = NodeType.value_of(node_config.get('data', {}).get('type'))
|
||||
node_cls = node_classes.get(node_type)
|
||||
if not node_cls:
|
||||
raise GraphRunFailedError(f'Node {node_id} type {node_type} not found.')
|
||||
|
||||
# todo trigger node run start event
|
||||
# init workflow run state
|
||||
node_instance = node_cls( # type: ignore
|
||||
config=node_config,
|
||||
graph_init_params=self.init_params,
|
||||
graph=self.graph,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
previous_node_id=previous_node_id
|
||||
)
|
||||
|
||||
# trigger node run start event
|
||||
yield NodeRunStartedEvent(node_id=node_id, parallel_id=parallel_id)
|
||||
|
||||
db.session.close()
|
||||
|
||||
@ -229,28 +269,25 @@ class GraphEngine:
|
||||
|
||||
try:
|
||||
# run node
|
||||
rst = node.run(
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
graph=self.graph,
|
||||
callbacks=self.callbacks
|
||||
)
|
||||
generator = node_instance.run()
|
||||
|
||||
yield from rst
|
||||
yield from generator
|
||||
|
||||
# todo record state
|
||||
|
||||
# trigger node run success event
|
||||
yield NodeRunSucceededEvent(node_id=node_id, parallel_id=parallel_id)
|
||||
except GenerateTaskStoppedException as e:
|
||||
# TODO yield failed
|
||||
# todo trigger node run failed event
|
||||
pass
|
||||
# trigger node run failed event
|
||||
yield NodeRunFailedEvent(node_id=node_id, parallel_id=parallel_id, reason=str(e))
|
||||
return
|
||||
except Exception as e:
|
||||
# logger.exception(f"Node {node.node_data.title} run failed: {str(e)}")
|
||||
# TODO yield failed
|
||||
# todo trigger node run failed event
|
||||
pass
|
||||
|
||||
# todo trigger node run success event
|
||||
|
||||
db.session.close()
|
||||
# todo logger.exception(f"Node {node.node_data.title} run failed: {str(e)}")
|
||||
# trigger node run failed event
|
||||
yield NodeRunFailedEvent(node_id=node_id, parallel_id=parallel_id, reason=str(e))
|
||||
return
|
||||
finally:
|
||||
db.session.close()
|
||||
|
||||
def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool:
|
||||
"""
|
||||
@ -265,3 +302,8 @@ class GraphEngine:
|
||||
class GraphRunFailedError(Exception):
    """Raised to abort a graph run.

    Attributes:
        error: Human-readable failure message; also used as the exception's
            string form.
    """

    def __init__(self, error: str):
        # Forward the message to Exception so that str(exc) and exc.args are
        # populated even when the error is passed as a keyword argument.
        super().__init__(error)
        self.error = error
|
||||
|
||||
|
||||
class NodeRunFailedError(Exception):
    """Error type for a failed node run.

    Attributes:
        error: Human-readable failure message; also used as the exception's
            string form.
    """

    def __init__(self, error: str):
        # Forward the message to Exception so that str(exc) and exc.args are
        # populated even when the error is passed as a keyword argument.
        super().__init__(error)
        self.error = error
|
||||
|
||||
Reference in New Issue
Block a user