mirror of
https://github.com/langgenius/dify.git
synced 2026-03-19 05:37:42 +08:00
add graph engine test
This commit is contained in:
@ -1,23 +1,31 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.graph_engine.entities.run_condition import RunCondition
|
||||
|
||||
|
||||
class RunConditionHandler(ABC):
|
||||
def __init__(self, condition: RunCondition):
|
||||
def __init__(self,
|
||||
init_params: GraphInitParams,
|
||||
graph: Graph,
|
||||
condition: RunCondition):
|
||||
self.init_params = init_params
|
||||
self.graph = graph
|
||||
self.condition = condition
|
||||
|
||||
@abstractmethod
|
||||
def check(self,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
source_node_id: str,
|
||||
target_node_id: str,
|
||||
graph: "Graph") -> bool:
|
||||
target_node_id: str) -> bool:
|
||||
"""
|
||||
Check if the condition can be executed
|
||||
|
||||
:param graph_runtime_state: graph runtime state
|
||||
:param source_node_id: source node id
|
||||
:param target_node_id: target node id
|
||||
:param graph: graph
|
||||
:return: bool
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@ -1,29 +1,33 @@
|
||||
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
|
||||
|
||||
class BranchIdentifyRunConditionHandler(RunConditionHandler):
|
||||
|
||||
def check(self,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
source_node_id: str,
|
||||
target_node_id: str,
|
||||
graph: "Graph") -> bool:
|
||||
target_node_id: str) -> bool:
|
||||
"""
|
||||
Check if the condition can be executed
|
||||
|
||||
:param graph_runtime_state: graph runtime state
|
||||
:param source_node_id: source node id
|
||||
:param target_node_id: target node id
|
||||
:param graph: graph
|
||||
:return: bool
|
||||
"""
|
||||
if not self.condition.branch_identify:
|
||||
raise Exception("Branch identify is required")
|
||||
|
||||
run_state = graph.run_state
|
||||
node_route_result = run_state.node_route_results.get(source_node_id)
|
||||
if not node_route_result:
|
||||
node_route_state = graph_runtime_state.node_run_state.node_state_mapping.get(source_node_id)
|
||||
if not node_route_state:
|
||||
return False
|
||||
|
||||
if not node_route_result.edge_source_handle:
|
||||
run_result = node_route_state.node_run_result
|
||||
if not run_result:
|
||||
return False
|
||||
|
||||
return self.condition.branch_identify == node_route_result.edge_source_handle
|
||||
if not run_result.edge_source_handle:
|
||||
return False
|
||||
|
||||
return self.condition.branch_identify == run_result.edge_source_handle
|
||||
|
||||
@ -1,18 +1,19 @@
|
||||
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.utils.condition.processor import ConditionProcessor
|
||||
|
||||
|
||||
class ConditionRunConditionHandlerHandler(RunConditionHandler):
|
||||
def check(self,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
source_node_id: str,
|
||||
target_node_id: str,
|
||||
graph: "Graph") -> bool:
|
||||
target_node_id: str) -> bool:
|
||||
"""
|
||||
Check if the condition can be executed
|
||||
|
||||
:param graph_runtime_state: graph runtime state
|
||||
:param source_node_id: source node id
|
||||
:param target_node_id: target node id
|
||||
:param graph: graph
|
||||
:return: bool
|
||||
"""
|
||||
if not self.condition.conditions:
|
||||
@ -21,10 +22,9 @@ class ConditionRunConditionHandlerHandler(RunConditionHandler):
|
||||
# process condition
|
||||
condition_processor = ConditionProcessor()
|
||||
compare_result, _ = condition_processor.process(
|
||||
variable_pool=graph.run_state.variable_pool,
|
||||
variable_pool=graph_runtime_state.variable_pool,
|
||||
logical_operator="and",
|
||||
conditions=self.condition.conditions
|
||||
)
|
||||
|
||||
return compare_result
|
||||
|
||||
|
||||
@ -1,19 +1,35 @@
|
||||
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
|
||||
from core.workflow.graph_engine.condition_handlers.branch_identify_handler import BranchIdentifyRunConditionHandler
|
||||
from core.workflow.graph_engine.condition_handlers.condition_handler import ConditionRunConditionHandlerHandler
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.run_condition import RunCondition
|
||||
|
||||
|
||||
class ConditionManager:
|
||||
@staticmethod
|
||||
def get_condition_handler(run_condition: RunCondition) -> RunConditionHandler:
|
||||
def get_condition_handler(
|
||||
init_params: GraphInitParams,
|
||||
graph: Graph,
|
||||
run_condition: RunCondition
|
||||
) -> RunConditionHandler:
|
||||
"""
|
||||
Get condition handler
|
||||
|
||||
:param init_params: init params
|
||||
:param graph: graph
|
||||
:param run_condition: run condition
|
||||
:return: condition handler
|
||||
"""
|
||||
if run_condition.type == "branch_identify":
|
||||
return BranchIdentifyRunConditionHandler(run_condition)
|
||||
return BranchIdentifyRunConditionHandler(
|
||||
init_params=init_params,
|
||||
graph=graph,
|
||||
condition=run_condition
|
||||
)
|
||||
else:
|
||||
return ConditionRunConditionHandlerHandler(run_condition)
|
||||
return ConditionRunConditionHandlerHandler(
|
||||
init_params=init_params,
|
||||
graph=graph,
|
||||
condition=run_condition
|
||||
)
|
||||
|
||||
@ -3,6 +3,7 @@ from typing import Optional
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class GraphEngineEvent(BaseModel):
|
||||
@ -50,10 +51,12 @@ class NodeRunStartedEvent(BaseNodeEvent):
|
||||
|
||||
class NodeRunStreamChunkEvent(BaseNodeEvent):
|
||||
chunk_content: str = Field(..., description="chunk content")
|
||||
from_variable_selector: list[str] = Field(..., description="from variable selector")
|
||||
|
||||
|
||||
class NodeRunRetrieverResourceEvent(BaseNodeEvent):
|
||||
retriever_resources: list[dict] = Field(..., description="retriever resources")
|
||||
context: str = Field(..., description="context")
|
||||
|
||||
|
||||
class NodeRunSucceededEvent(BaseNodeEvent):
|
||||
@ -61,7 +64,10 @@ class NodeRunSucceededEvent(BaseNodeEvent):
|
||||
|
||||
|
||||
class NodeRunFailedEvent(BaseNodeEvent):
|
||||
run_result: NodeRunResult = Field(..., description="run result")
|
||||
run_result: NodeRunResult = Field(
|
||||
default=NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED),
|
||||
description="run result"
|
||||
)
|
||||
reason: str = Field("", description="failed reason")
|
||||
|
||||
@model_validator(mode='before')
|
||||
|
||||
@ -3,13 +3,13 @@ import queue
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Optional, cast
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from flask import Flask, current_app
|
||||
|
||||
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
|
||||
@ -19,14 +19,21 @@ from core.workflow.graph_engine.entities.event import (
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunRetrieverResourceEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.nodes.base_node import UserFrom, node_classes
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||
from core.workflow.nodes import node_classes
|
||||
from core.workflow.nodes.base_node import UserFrom
|
||||
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
|
||||
from core.workflow.nodes.test.test_node import TestNode
|
||||
from extensions.ext_database import db
|
||||
from models.workflow import WorkflowType
|
||||
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
|
||||
|
||||
thread_pool = ThreadPoolExecutor(max_workers=500, thread_name_prefix="ThreadGraphParallelRun")
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -43,7 +50,8 @@ class GraphEngine:
|
||||
call_depth: int,
|
||||
graph: Graph,
|
||||
variable_pool: VariablePool,
|
||||
callbacks: list[BaseWorkflowCallback]) -> None:
|
||||
max_execution_steps: int,
|
||||
max_execution_time: int) -> None:
|
||||
self.graph = graph
|
||||
self.init_params = GraphInitParams(
|
||||
tenant_id=tenant_id,
|
||||
@ -61,12 +69,8 @@ class GraphEngine:
|
||||
start_at=time.perf_counter()
|
||||
)
|
||||
|
||||
max_execution_steps = current_app.config.get("WORKFLOW_MAX_EXECUTION_STEPS")
|
||||
self.max_execution_steps = cast(int, max_execution_steps)
|
||||
max_execution_time = current_app.config.get("WORKFLOW_MAX_EXECUTION_TIME")
|
||||
self.max_execution_time = cast(int, max_execution_time)
|
||||
|
||||
self.callbacks = callbacks
|
||||
self.max_execution_steps = max_execution_steps
|
||||
self.max_execution_time = max_execution_time
|
||||
|
||||
def run_in_block_mode(self):
|
||||
# TODO convert generator to result
|
||||
@ -92,7 +96,7 @@ class GraphEngine:
|
||||
return
|
||||
except Exception as e:
|
||||
yield GraphRunFailedEvent(reason=str(e))
|
||||
return
|
||||
raise e
|
||||
|
||||
def _run(self, start_node_id: str, in_parallel_id: Optional[str] = None) -> Generator[GraphEngineEvent, None, None]:
|
||||
next_node_id = start_node_id
|
||||
@ -118,7 +122,7 @@ class GraphEngine:
|
||||
)
|
||||
except Exception as e:
|
||||
yield NodeRunFailedEvent(node_id=next_node_id, reason=str(e))
|
||||
return
|
||||
raise e
|
||||
|
||||
previous_node_id = next_node_id
|
||||
|
||||
@ -141,11 +145,13 @@ class GraphEngine:
|
||||
for edge in edge_mappings:
|
||||
if edge.run_condition:
|
||||
result = ConditionManager.get_condition_handler(
|
||||
run_condition=edge.run_condition
|
||||
init_params=self.init_params,
|
||||
graph=self.graph,
|
||||
run_condition=edge.run_condition,
|
||||
).check(
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
source_node_id=edge.source_node_id,
|
||||
target_node_id=edge.target_node_id,
|
||||
graph=self.graph
|
||||
)
|
||||
|
||||
if result:
|
||||
@ -250,7 +256,16 @@ class GraphEngine:
|
||||
raise GraphRunFailedError(f'Node {node_id} type {node_type} not found.')
|
||||
|
||||
# init workflow run state
|
||||
node_instance = node_cls( # type: ignore
|
||||
# node_instance = node_cls( # type: ignore
|
||||
# config=node_config,
|
||||
# graph_init_params=self.init_params,
|
||||
# graph=self.graph,
|
||||
# graph_runtime_state=self.graph_runtime_state,
|
||||
# previous_node_id=previous_node_id
|
||||
# )
|
||||
|
||||
# init workflow run state
|
||||
node_instance = TestNode(
|
||||
config=node_config,
|
||||
graph_init_params=self.init_params,
|
||||
graph=self.graph,
|
||||
@ -268,24 +283,64 @@ class GraphEngine:
|
||||
self.graph_runtime_state.node_run_steps += 1
|
||||
|
||||
try:
|
||||
start_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
|
||||
# run node
|
||||
generator = node_instance.run()
|
||||
run_result = None
|
||||
for item in generator:
|
||||
if isinstance(item, RunCompletedEvent):
|
||||
run_result = item.run_result
|
||||
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
|
||||
yield NodeRunFailedEvent(
|
||||
node_id=node_id,
|
||||
parallel_id=parallel_id,
|
||||
run_result=run_result,
|
||||
reason=run_result.error
|
||||
)
|
||||
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
|
||||
yield NodeRunSucceededEvent(
|
||||
node_id=node_id,
|
||||
parallel_id=parallel_id,
|
||||
run_result=run_result
|
||||
)
|
||||
|
||||
yield from generator
|
||||
self.graph_runtime_state.node_run_state.node_state_mapping[node_id] = RouteNodeState(
|
||||
node_id=node_id,
|
||||
start_at=start_at,
|
||||
status=RouteNodeState.Status.SUCCESS if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
else RouteNodeState.Status.FAILED,
|
||||
finished_at=datetime.now(timezone.utc).replace(tzinfo=None),
|
||||
node_run_result=run_result,
|
||||
failed_reason=run_result.error
|
||||
if run_result.status == WorkflowNodeExecutionStatus.FAILED else None
|
||||
)
|
||||
|
||||
# todo append self.graph_runtime_state.node_run_state.routes
|
||||
break
|
||||
elif isinstance(item, RunStreamChunkEvent):
|
||||
yield NodeRunStreamChunkEvent(
|
||||
node_id=node_id,
|
||||
parallel_id=parallel_id,
|
||||
chunk_content=item.chunk_content,
|
||||
from_variable_selector=item.from_variable_selector,
|
||||
)
|
||||
elif isinstance(item, RunRetrieverResourceEvent):
|
||||
yield NodeRunRetrieverResourceEvent(
|
||||
node_id=node_id,
|
||||
parallel_id=parallel_id,
|
||||
retriever_resources=item.retriever_resources,
|
||||
context=item.context
|
||||
)
|
||||
|
||||
# todo record state
|
||||
|
||||
# trigger node run success event
|
||||
yield NodeRunSucceededEvent(node_id=node_id, parallel_id=parallel_id)
|
||||
except GenerateTaskStoppedException as e:
|
||||
# trigger node run failed event
|
||||
yield NodeRunFailedEvent(node_id=node_id, parallel_id=parallel_id, reason=str(e))
|
||||
return
|
||||
except Exception as e:
|
||||
# todo logger.exception(f"Node {node.node_data.title} run failed: {str(e)}")
|
||||
# trigger node run failed event
|
||||
yield NodeRunFailedEvent(node_id=node_id, parallel_id=parallel_id, reason=str(e))
|
||||
return
|
||||
logger.exception(f"Node {node_instance.node_data.title} run failed: {str(e)}")
|
||||
raise e
|
||||
finally:
|
||||
db.session.close()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user