add graph engine test

This commit is contained in:
takatost
2024-07-16 16:37:37 +08:00
parent 00fb23d0c9
commit 00ec36d47c
17 changed files with 1122 additions and 904 deletions

View File

@ -1,23 +1,31 @@
from abc import ABC, abstractmethod
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.run_condition import RunCondition
class RunConditionHandler(ABC):
def __init__(self, condition: RunCondition):
def __init__(self,
init_params: GraphInitParams,
graph: Graph,
condition: RunCondition):
self.init_params = init_params
self.graph = graph
self.condition = condition
@abstractmethod
def check(self,
graph_runtime_state: GraphRuntimeState,
source_node_id: str,
target_node_id: str,
graph: "Graph") -> bool:
target_node_id: str) -> bool:
"""
Check if the condition can be executed
:param graph_runtime_state: graph runtime state
:param source_node_id: source node id
:param target_node_id: target node id
:param graph: graph
:return: bool
"""
raise NotImplementedError

View File

@ -1,29 +1,33 @@
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
class BranchIdentifyRunConditionHandler(RunConditionHandler):
def check(self,
graph_runtime_state: GraphRuntimeState,
source_node_id: str,
target_node_id: str,
graph: "Graph") -> bool:
target_node_id: str) -> bool:
"""
Check if the condition can be executed
:param graph_runtime_state: graph runtime state
:param source_node_id: source node id
:param target_node_id: target node id
:param graph: graph
:return: bool
"""
if not self.condition.branch_identify:
raise Exception("Branch identify is required")
run_state = graph.run_state
node_route_result = run_state.node_route_results.get(source_node_id)
if not node_route_result:
node_route_state = graph_runtime_state.node_run_state.node_state_mapping.get(source_node_id)
if not node_route_state:
return False
if not node_route_result.edge_source_handle:
run_result = node_route_state.node_run_result
if not run_result:
return False
return self.condition.branch_identify == node_route_result.edge_source_handle
if not run_result.edge_source_handle:
return False
return self.condition.branch_identify == run_result.edge_source_handle

View File

@ -1,18 +1,19 @@
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.utils.condition.processor import ConditionProcessor
class ConditionRunConditionHandlerHandler(RunConditionHandler):
def check(self,
graph_runtime_state: GraphRuntimeState,
source_node_id: str,
target_node_id: str,
graph: "Graph") -> bool:
target_node_id: str) -> bool:
"""
Check if the condition can be executed
:param graph_runtime_state: graph runtime state
:param source_node_id: source node id
:param target_node_id: target node id
:param graph: graph
:return: bool
"""
if not self.condition.conditions:
@ -21,10 +22,9 @@ class ConditionRunConditionHandlerHandler(RunConditionHandler):
# process condition
condition_processor = ConditionProcessor()
compare_result, _ = condition_processor.process(
variable_pool=graph.run_state.variable_pool,
variable_pool=graph_runtime_state.variable_pool,
logical_operator="and",
conditions=self.condition.conditions
)
return compare_result

View File

@ -1,19 +1,35 @@
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
from core.workflow.graph_engine.condition_handlers.branch_identify_handler import BranchIdentifyRunConditionHandler
from core.workflow.graph_engine.condition_handlers.condition_handler import ConditionRunConditionHandlerHandler
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.run_condition import RunCondition
class ConditionManager:
@staticmethod
def get_condition_handler(run_condition: RunCondition) -> RunConditionHandler:
def get_condition_handler(
init_params: GraphInitParams,
graph: Graph,
run_condition: RunCondition
) -> RunConditionHandler:
"""
Get condition handler
:param init_params: init params
:param graph: graph
:param run_condition: run condition
:return: condition handler
"""
if run_condition.type == "branch_identify":
return BranchIdentifyRunConditionHandler(run_condition)
return BranchIdentifyRunConditionHandler(
init_params=init_params,
graph=graph,
condition=run_condition
)
else:
return ConditionRunConditionHandlerHandler(run_condition)
return ConditionRunConditionHandlerHandler(
init_params=init_params,
graph=graph,
condition=run_condition
)

View File

@ -3,6 +3,7 @@ from typing import Optional
from pydantic import BaseModel, Field, model_validator
from core.workflow.entities.node_entities import NodeRunResult
from models.workflow import WorkflowNodeExecutionStatus
class GraphEngineEvent(BaseModel):
@ -50,10 +51,12 @@ class NodeRunStartedEvent(BaseNodeEvent):
class NodeRunStreamChunkEvent(BaseNodeEvent):
chunk_content: str = Field(..., description="chunk content")
from_variable_selector: list[str] = Field(..., description="from variable selector")
class NodeRunRetrieverResourceEvent(BaseNodeEvent):
retriever_resources: list[dict] = Field(..., description="retriever resources")
context: str = Field(..., description="context")
class NodeRunSucceededEvent(BaseNodeEvent):
@ -61,7 +64,10 @@ class NodeRunSucceededEvent(BaseNodeEvent):
class NodeRunFailedEvent(BaseNodeEvent):
run_result: NodeRunResult = Field(..., description="run result")
run_result: NodeRunResult = Field(
default=NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED),
description="run result"
)
reason: str = Field("", description="failed reason")
@model_validator(mode='before')

View File

@ -3,13 +3,13 @@ import queue
import time
from collections.abc import Generator
from concurrent.futures import ThreadPoolExecutor
from typing import Optional, cast
from datetime import datetime, timezone
from typing import Optional
from flask import Flask, current_app
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
from core.workflow.entities.node_entities import NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
@ -19,14 +19,21 @@ from core.workflow.graph_engine.entities.event import (
GraphRunStartedEvent,
GraphRunSucceededEvent,
NodeRunFailedEvent,
NodeRunRetrieverResourceEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.base_node import UserFrom, node_classes
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.nodes import node_classes
from core.workflow.nodes.base_node import UserFrom
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.nodes.test.test_node import TestNode
from extensions.ext_database import db
from models.workflow import WorkflowType
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
thread_pool = ThreadPoolExecutor(max_workers=500, thread_name_prefix="ThreadGraphParallelRun")
logger = logging.getLogger(__name__)
@ -43,7 +50,8 @@ class GraphEngine:
call_depth: int,
graph: Graph,
variable_pool: VariablePool,
callbacks: list[BaseWorkflowCallback]) -> None:
max_execution_steps: int,
max_execution_time: int) -> None:
self.graph = graph
self.init_params = GraphInitParams(
tenant_id=tenant_id,
@ -61,12 +69,8 @@ class GraphEngine:
start_at=time.perf_counter()
)
max_execution_steps = current_app.config.get("WORKFLOW_MAX_EXECUTION_STEPS")
self.max_execution_steps = cast(int, max_execution_steps)
max_execution_time = current_app.config.get("WORKFLOW_MAX_EXECUTION_TIME")
self.max_execution_time = cast(int, max_execution_time)
self.callbacks = callbacks
self.max_execution_steps = max_execution_steps
self.max_execution_time = max_execution_time
def run_in_block_mode(self):
# TODO convert generator to result
@ -92,7 +96,7 @@ class GraphEngine:
return
except Exception as e:
yield GraphRunFailedEvent(reason=str(e))
return
raise e
def _run(self, start_node_id: str, in_parallel_id: Optional[str] = None) -> Generator[GraphEngineEvent, None, None]:
next_node_id = start_node_id
@ -118,7 +122,7 @@ class GraphEngine:
)
except Exception as e:
yield NodeRunFailedEvent(node_id=next_node_id, reason=str(e))
return
raise e
previous_node_id = next_node_id
@ -141,11 +145,13 @@ class GraphEngine:
for edge in edge_mappings:
if edge.run_condition:
result = ConditionManager.get_condition_handler(
run_condition=edge.run_condition
init_params=self.init_params,
graph=self.graph,
run_condition=edge.run_condition,
).check(
graph_runtime_state=self.graph_runtime_state,
source_node_id=edge.source_node_id,
target_node_id=edge.target_node_id,
graph=self.graph
)
if result:
@ -250,7 +256,16 @@ class GraphEngine:
raise GraphRunFailedError(f'Node {node_id} type {node_type} not found.')
# init workflow run state
node_instance = node_cls( # type: ignore
# node_instance = node_cls( # type: ignore
# config=node_config,
# graph_init_params=self.init_params,
# graph=self.graph,
# graph_runtime_state=self.graph_runtime_state,
# previous_node_id=previous_node_id
# )
# init workflow run state
node_instance = TestNode(
config=node_config,
graph_init_params=self.init_params,
graph=self.graph,
@ -268,24 +283,64 @@ class GraphEngine:
self.graph_runtime_state.node_run_steps += 1
try:
start_at = datetime.now(timezone.utc).replace(tzinfo=None)
# run node
generator = node_instance.run()
run_result = None
for item in generator:
if isinstance(item, RunCompletedEvent):
run_result = item.run_result
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
yield NodeRunFailedEvent(
node_id=node_id,
parallel_id=parallel_id,
run_result=run_result,
reason=run_result.error
)
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
yield NodeRunSucceededEvent(
node_id=node_id,
parallel_id=parallel_id,
run_result=run_result
)
yield from generator
self.graph_runtime_state.node_run_state.node_state_mapping[node_id] = RouteNodeState(
node_id=node_id,
start_at=start_at,
status=RouteNodeState.Status.SUCCESS if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
else RouteNodeState.Status.FAILED,
finished_at=datetime.now(timezone.utc).replace(tzinfo=None),
node_run_result=run_result,
failed_reason=run_result.error
if run_result.status == WorkflowNodeExecutionStatus.FAILED else None
)
# todo append self.graph_runtime_state.node_run_state.routes
break
elif isinstance(item, RunStreamChunkEvent):
yield NodeRunStreamChunkEvent(
node_id=node_id,
parallel_id=parallel_id,
chunk_content=item.chunk_content,
from_variable_selector=item.from_variable_selector,
)
elif isinstance(item, RunRetrieverResourceEvent):
yield NodeRunRetrieverResourceEvent(
node_id=node_id,
parallel_id=parallel_id,
retriever_resources=item.retriever_resources,
context=item.context
)
# todo record state
# trigger node run success event
yield NodeRunSucceededEvent(node_id=node_id, parallel_id=parallel_id)
except GenerateTaskStoppedException as e:
# trigger node run failed event
yield NodeRunFailedEvent(node_id=node_id, parallel_id=parallel_id, reason=str(e))
return
except Exception as e:
# todo logger.exception(f"Node {node.node_data.title} run failed: {str(e)}")
# trigger node run failed event
yield NodeRunFailedEvent(node_id=node_id, parallel_id=parallel_id, reason=str(e))
return
logger.exception(f"Node {node_instance.node_data.title} run failed: {str(e)}")
raise e
finally:
db.session.close()