Merge remote-tracking branch 'upstream/feat/queue-based-graph-engine' into feat/rag-2

QuantumGhost
2025-09-17 18:00:48 +08:00
102 changed files with 2961 additions and 2025 deletions

View File

@@ -17,7 +17,6 @@ def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_field
         identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"),
         parameters=[],
         description=None,
-        output_schema=None,
         has_runtime_parameters=False,
     )
     runtime = ToolRuntime(tenant_id="test_tool", invoke_from=InvokeFrom.EXPLORE)

View File

@@ -95,17 +95,3 @@ class TestGraphRuntimeState:
         # Test add_tokens validation
         with pytest.raises(ValueError):
             state.add_tokens(-1)
-
-    def test_deep_copy_for_nested_objects(self):
-        variable_pool = VariablePool()
-        state = GraphRuntimeState(variable_pool=variable_pool, start_at=time())
-
-        # Test deep copy for nested dict
-        nested_data = {"level1": {"level2": {"value": "test"}}}
-        state.set_output("nested", nested_data)
-
-        retrieved = state.get_output("nested")
-        retrieved["level1"]["level2"]["value"] = "modified"
-
-        # Original should remain unchanged
-        assert state.get_output("nested")["level1"]["level2"]["value"] == "test"

View File

@@ -3,14 +3,12 @@
 import time
 from unittest.mock import MagicMock
 
-from core.app.entities.app_invoke_entities import InvokeFrom
 from core.workflow.entities import GraphRuntimeState, VariablePool
 from core.workflow.graph import Graph
 from core.workflow.graph_engine import GraphEngine
 from core.workflow.graph_engine.command_channels import InMemoryChannel
 from core.workflow.graph_engine.entities.commands import AbortCommand
 from core.workflow.graph_events import GraphRunAbortedEvent, GraphRunStartedEvent
-from models.enums import UserFrom
 
 
 def test_abort_command():
@@ -42,18 +40,9 @@ def test_abort_command():
     # Create GraphEngine with same shared runtime state
     engine = GraphEngine(
         tenant_id="test",
-        app_id="test",
-        workflow_id="test_workflow",
-        user_id="test",
-        user_from=UserFrom.ACCOUNT,
-        invoke_from=InvokeFrom.WEB_APP,
-        call_depth=0,
         graph=mock_graph,
-        graph_config={},
         graph_runtime_state=shared_runtime_state,  # Use shared instance
-        max_execution_steps=100,
-        max_execution_time=10,
         command_channel=command_channel,
     )
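
Every call site updated in this commit follows the same shape: GraphEngine is now constructed from four arguments instead of roughly a dozen. A minimal sketch of the new-style call, inferred from the kept lines of this diff (where the removed identity fields and execution limits now live is an assumption, not shown here):

    # New-style construction used throughout the updated tests.
    # app_id, user_id, invoke_from, call_depth and the execution limits no
    # longer appear at the call site; presumably they travel with the graph
    # or GraphRuntimeState objects (an inference from this diff).
    engine = GraphEngine(
        tenant_id="test",
        graph=mock_graph,
        graph_runtime_state=shared_runtime_state,
        command_channel=command_channel,
    )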

View File

@@ -6,7 +6,6 @@ This test validates that:
 - When blocking != 1: NodeRunStreamChunkEvent present (direct LLM to End output)
 """
 
-from core.app.entities.app_invoke_entities import InvokeFrom
 from core.workflow.enums import NodeType
 from core.workflow.graph_engine import GraphEngine
 from core.workflow.graph_engine.command_channels import InMemoryChannel
@@ -16,7 +15,6 @@ from core.workflow.graph_events import (
     NodeRunStreamChunkEvent,
     NodeRunSucceededEvent,
 )
-from models.enums import UserFrom
 
 from .test_table_runner import TableTestRunner
@@ -40,23 +38,11 @@ def test_streaming_output_with_blocking_equals_one():
         use_mock_factory=True,
     )
-    workflow_config = fixture_data.get("workflow", {})
-    graph_config = workflow_config.get("graph", {})
 
     # Create and run the engine
     engine = GraphEngine(
         tenant_id="test_tenant",
-        app_id="test_app",
-        workflow_id="test_workflow",
-        user_id="test_user",
-        user_from=UserFrom.ACCOUNT,
-        invoke_from=InvokeFrom.DEBUGGER,
-        call_depth=0,
         graph=graph,
-        graph_config=graph_config,
         graph_runtime_state=graph_runtime_state,
-        max_execution_steps=500,
-        max_execution_time=30,
         command_channel=InMemoryChannel(),
     )
@@ -147,23 +133,11 @@ def test_streaming_output_with_blocking_not_equals_one():
         use_mock_factory=True,
     )
-    workflow_config = fixture_data.get("workflow", {})
-    graph_config = workflow_config.get("graph", {})
 
     # Create and run the engine
    engine = GraphEngine(
         tenant_id="test_tenant",
-        app_id="test_app",
-        workflow_id="test_workflow",
-        user_id="test_user",
-        user_from=UserFrom.ACCOUNT,
-        invoke_from=InvokeFrom.DEBUGGER,
-        call_depth=0,
         graph=graph,
-        graph_config=graph_config,
         graph_runtime_state=graph_runtime_state,
-        max_execution_steps=500,
-        max_execution_time=30,
         command_channel=InMemoryChannel(),
     )

View File

@@ -9,7 +9,6 @@ import contextvars
 import queue
 import threading
 import time
-from typing import Optional
 from unittest.mock import MagicMock
 
 from flask import Flask, g
@@ -59,7 +58,7 @@ class TestContextPreservation:
         context = contextvars.copy_context()
 
         # Variable to store value from worker
-        worker_value: Optional[str] = None
+        worker_value: str | None = None
 
         def worker_task() -> None:
             nonlocal worker_value
@@ -120,7 +119,7 @@ class TestContextPreservation:
         test_node = MagicMock(spec=Node)
 
         # Variable to capture context inside node execution
-        captured_value: Optional[str] = None
+        captured_value: str | None = None
         context_available_in_node = False
 
         def mock_run() -> list[GraphNodeEventBase]:
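
The worker_value and captured_value annotations above, like the Optional rewrites in the files that follow, adopt PEP 604 union syntax (X | None), accepted directly in annotations on Python 3.10+. A tiny self-contained illustration with a hypothetical function:

    # PEP 604: `str | None` is equivalent to Optional[str] / Union[str, None].
    def greet(name: str | None = None) -> str:
        # Fall back to a generic greeting when no name is supplied.
        return f"Hello, {name or 'world'}!"

No runtime behavior changes; only the spelling of the annotation does.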

View File

@@ -10,11 +10,9 @@ import time
 from hypothesis import HealthCheck, given, settings
 from hypothesis import strategies as st
 
-from core.app.entities.app_invoke_entities import InvokeFrom
 from core.workflow.graph_engine import GraphEngine
 from core.workflow.graph_engine.command_channels import InMemoryChannel
 from core.workflow.graph_events import GraphRunStartedEvent, GraphRunSucceededEvent
-from models.enums import UserFrom
 
 # Import the test framework from the new module
 from .test_table_runner import TableTestRunner, WorkflowRunner, WorkflowTestCase
@@ -460,18 +458,9 @@ def test_layer_system_basic():
     # Create engine with layer
     engine = GraphEngine(
         tenant_id="test_tenant",
-        app_id="test_app",
-        workflow_id="test_workflow",
-        user_id="test_user",
-        user_from=UserFrom.ACCOUNT,
-        invoke_from=InvokeFrom.WEB_APP,
-        call_depth=0,
         graph=graph,
-        graph_config=fixture_data.get("workflow", {}).get("graph", {}),
         graph_runtime_state=graph_runtime_state,
-        max_execution_steps=300,
-        max_execution_time=60,
         command_channel=InMemoryChannel(),
     )
@@ -525,18 +514,9 @@ def test_layer_chaining():
     # Create engine
     engine = GraphEngine(
         tenant_id="test_tenant",
-        app_id="test_app",
-        workflow_id="test_workflow",
-        user_id="test_user",
-        user_from=UserFrom.ACCOUNT,
-        invoke_from=InvokeFrom.WEB_APP,
-        call_depth=0,
         graph=graph,
-        graph_config=fixture_data.get("workflow", {}).get("graph", {}),
         graph_runtime_state=graph_runtime_state,
-        max_execution_steps=300,
-        max_execution_time=60,
         command_channel=InMemoryChannel(),
     )
@@ -581,18 +561,9 @@ def test_layer_error_handling():
     # Create engine with faulty layer
     engine = GraphEngine(
         tenant_id="test_tenant",
-        app_id="test_app",
-        workflow_id="test_workflow",
-        user_id="test_user",
-        user_from=UserFrom.ACCOUNT,
-        invoke_from=InvokeFrom.WEB_APP,
-        call_depth=0,
         graph=graph,
-        graph_config=fixture_data.get("workflow", {}).get("graph", {}),
         graph_runtime_state=graph_runtime_state,
-        max_execution_steps=300,
-        max_execution_time=60,
         command_channel=InMemoryChannel(),
     )

View File

@@ -0,0 +1,194 @@
+"""Unit tests for GraphExecution serialization helpers."""
+
+from __future__ import annotations
+
+import json
+from collections import deque
+from unittest.mock import MagicMock
+
+from core.workflow.enums import NodeExecutionType, NodeState, NodeType
+from core.workflow.graph_engine.domain import GraphExecution
+from core.workflow.graph_engine.response_coordinator import ResponseStreamCoordinator
+from core.workflow.graph_engine.response_coordinator.path import Path
+from core.workflow.graph_engine.response_coordinator.session import ResponseSession
+from core.workflow.graph_events import NodeRunStreamChunkEvent
+from core.workflow.nodes.base.template import Template, TextSegment, VariableSegment
+
+
+class CustomGraphExecutionError(Exception):
+    """Custom exception used to verify error serialization."""
+
+
+def test_graph_execution_serialization_round_trip() -> None:
+    """GraphExecution serialization restores full aggregate state."""
+    # Arrange
+    execution = GraphExecution(workflow_id="wf-1")
+    execution.start()
+
+    node_a = execution.get_or_create_node_execution("node-a")
+    node_a.mark_started(execution_id="exec-1")
+    node_a.increment_retry()
+    node_a.mark_failed("boom")
+
+    node_b = execution.get_or_create_node_execution("node-b")
+    node_b.mark_skipped()
+
+    execution.fail(CustomGraphExecutionError("serialization failure"))
+
+    # Act
+    serialized = execution.dumps()
+    payload = json.loads(serialized)
+    restored = GraphExecution(workflow_id="wf-1")
+    restored.loads(serialized)
+
+    # Assert
+    assert payload["type"] == "GraphExecution"
+    assert payload["version"] == "1.0"
+    assert restored.workflow_id == "wf-1"
+    assert restored.started is True
+    assert restored.completed is True
+    assert restored.aborted is False
+    assert isinstance(restored.error, CustomGraphExecutionError)
+    assert str(restored.error) == "serialization failure"
+    assert set(restored.node_executions) == {"node-a", "node-b"}
+
+    restored_node_a = restored.node_executions["node-a"]
+    assert restored_node_a.state is NodeState.TAKEN
+    assert restored_node_a.retry_count == 1
+    assert restored_node_a.execution_id == "exec-1"
+    assert restored_node_a.error == "boom"
+
+    restored_node_b = restored.node_executions["node-b"]
+    assert restored_node_b.state is NodeState.SKIPPED
+    assert restored_node_b.retry_count == 0
+    assert restored_node_b.execution_id is None
+    assert restored_node_b.error is None
+
+
+def test_graph_execution_loads_replaces_existing_state() -> None:
+    """loads replaces existing runtime data with serialized snapshot."""
+    # Arrange
+    source = GraphExecution(workflow_id="wf-2")
+    source.start()
+    source_node = source.get_or_create_node_execution("node-source")
+    source_node.mark_taken()
+    serialized = source.dumps()
+
+    target = GraphExecution(workflow_id="wf-2")
+    target.start()
+    target.abort("pre-existing abort")
+    temp_node = target.get_or_create_node_execution("node-temp")
+    temp_node.increment_retry()
+    temp_node.mark_failed("temp error")
+
+    # Act
+    target.loads(serialized)
+
+    # Assert
+    assert target.aborted is False
+    assert target.error is None
+    assert target.started is True
+    assert target.completed is False
+    assert set(target.node_executions) == {"node-source"}
+
+    restored_node = target.node_executions["node-source"]
+    assert restored_node.state is NodeState.TAKEN
+    assert restored_node.retry_count == 0
+    assert restored_node.execution_id is None
+    assert restored_node.error is None
+
+
+def test_response_stream_coordinator_serialization_round_trip(monkeypatch) -> None:
+    """ResponseStreamCoordinator serialization restores coordinator internals."""
+    template_main = Template(segments=[TextSegment(text="Hi "), VariableSegment(selector=["node-source", "text"])])
+    template_secondary = Template(segments=[TextSegment(text="secondary")])
+
+    class DummyNode:
+        def __init__(self, node_id: str, template: Template, execution_type: NodeExecutionType) -> None:
+            self.id = node_id
+            self.node_type = NodeType.ANSWER if execution_type == NodeExecutionType.RESPONSE else NodeType.LLM
+            self.execution_type = execution_type
+            self.state = NodeState.UNKNOWN
+            self.title = node_id
+            self.template = template
+
+        def blocks_variable_output(self, *_args) -> bool:
+            return False
+
+    response_node1 = DummyNode("response-1", template_main, NodeExecutionType.RESPONSE)
+    response_node2 = DummyNode("response-2", template_main, NodeExecutionType.RESPONSE)
+    response_node3 = DummyNode("response-3", template_main, NodeExecutionType.RESPONSE)
+    source_node = DummyNode("node-source", template_secondary, NodeExecutionType.EXECUTABLE)
+
+    class DummyGraph:
+        def __init__(self) -> None:
+            self.nodes = {
+                response_node1.id: response_node1,
+                response_node2.id: response_node2,
+                response_node3.id: response_node3,
+                source_node.id: source_node,
+            }
+            self.edges: dict[str, object] = {}
+            self.root_node = response_node1
+
+        def get_outgoing_edges(self, _node_id: str):  # pragma: no cover - not exercised
+            return []
+
+        def get_incoming_edges(self, _node_id: str):  # pragma: no cover - not exercised
+            return []
+
+    graph = DummyGraph()
+
+    def fake_from_node(cls, node: DummyNode) -> ResponseSession:
+        return ResponseSession(node_id=node.id, template=node.template)
+
+    monkeypatch.setattr(ResponseSession, "from_node", classmethod(fake_from_node))
+
+    coordinator = ResponseStreamCoordinator(variable_pool=MagicMock(), graph=graph)  # type: ignore[arg-type]
+    coordinator._response_nodes = {"response-1", "response-2", "response-3"}
+    coordinator._paths_maps = {
+        "response-1": [Path(edges=["edge-1"])],
+        "response-2": [Path(edges=[])],
+        "response-3": [Path(edges=["edge-2", "edge-3"])],
+    }
+
+    active_session = ResponseSession(node_id="response-1", template=response_node1.template)
+    active_session.index = 1
+    coordinator._active_session = active_session
+
+    waiting_session = ResponseSession(node_id="response-2", template=response_node2.template)
+    coordinator._waiting_sessions = deque([waiting_session])
+
+    pending_session = ResponseSession(node_id="response-3", template=response_node3.template)
+    pending_session.index = 2
+    coordinator._response_sessions = {"response-3": pending_session}
+
+    coordinator._node_execution_ids = {"response-1": "exec-1"}
+
+    event = NodeRunStreamChunkEvent(
+        id="exec-1",
+        node_id="response-1",
+        node_type=NodeType.ANSWER,
+        selector=["node-source", "text"],
+        chunk="chunk-1",
+        is_final=False,
+    )
+    coordinator._stream_buffers = {("node-source", "text"): [event]}
+    coordinator._stream_positions = {("node-source", "text"): 1}
+    coordinator._closed_streams = {("node-source", "text")}
+
+    serialized = coordinator.dumps()
+
+    restored = ResponseStreamCoordinator(variable_pool=MagicMock(), graph=graph)  # type: ignore[arg-type]
+    monkeypatch.setattr(ResponseSession, "from_node", classmethod(fake_from_node))
+    restored.loads(serialized)
+
+    assert restored._response_nodes == {"response-1", "response-2", "response-3"}
+    assert restored._paths_maps["response-1"][0].edges == ["edge-1"]
+
+    assert restored._active_session is not None
+    assert restored._active_session.node_id == "response-1"
+    assert restored._active_session.index == 1
+
+    waiting_restored = list(restored._waiting_sessions)
+    assert len(waiting_restored) == 1
+    assert waiting_restored[0].node_id == "response-2"
+    assert waiting_restored[0].index == 0
+
+    assert set(restored._response_sessions) == {"response-3"}
+    assert restored._response_sessions["response-3"].index == 2
+
+    assert restored._node_execution_ids == {"response-1": "exec-1"}
+
+    assert ("node-source", "text") in restored._stream_buffers
+    restored_event = restored._stream_buffers[("node-source", "text")][0]
+    assert restored_event.chunk == "chunk-1"
+    assert restored._stream_positions[("node-source", "text")] == 1
+    assert ("node-source", "text") in restored._closed_streams
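
The assertions in this new file pin down the serialization contract: dumps() produces a JSON string whose envelope carries "type" and "version" tags, loads() fully replaces any pre-existing in-memory state, and failures round-trip with their concrete exception class. The core round-trip, distilled from the tests above:

    # Round-trip pattern exercised by the tests above; the envelope carries
    # {"type": "GraphExecution", "version": "1.0", ...} per the assertions.
    execution = GraphExecution(workflow_id="wf-1")
    execution.start()
    snapshot = execution.dumps()

    restored = GraphExecution(workflow_id="wf-1")
    restored.loads(snapshot)  # discards whatever state `restored` held before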

View File

@@ -7,7 +7,7 @@ the behavior of mock nodes during testing.
 from collections.abc import Callable
 from dataclasses import dataclass, field
-from typing import Any, Optional
+from typing import Any
 
 from core.workflow.enums import NodeType
@@ -18,9 +18,9 @@ class NodeMockConfig:
     node_id: str
     outputs: dict[str, Any] = field(default_factory=dict)
-    error: Optional[str] = None
+    error: str | None = None
     delay: float = 0.0  # Simulated execution delay in seconds
-    custom_handler: Optional[Callable[..., dict[str, Any]]] = None
+    custom_handler: Callable[..., dict[str, Any]] | None = None
 
 
 @dataclass
@@ -51,7 +51,7 @@ class MockConfig:
     default_template_transform_response: str = "This is mocked template transform output"
     default_code_response: dict[str, Any] = field(default_factory=lambda: {"result": "mocked code execution result"})
 
-    def get_node_config(self, node_id: str) -> Optional[NodeMockConfig]:
+    def get_node_config(self, node_id: str) -> NodeMockConfig | None:
         """Get configuration for a specific node."""
         return self.node_configs.get(node_id)

View File

@@ -64,7 +64,7 @@ class MockNodeMixin:
         return default_outputs
 
-    def _should_simulate_error(self) -> Optional[str]:
+    def _should_simulate_error(self) -> str | None:
         """Check if this node should simulate an error."""
         if not self.mock_config:
             return None
@@ -615,18 +615,9 @@ class MockIterationNode(MockNodeMixin, IterationNode):
         # Create a new GraphEngine for this iteration
         graph_engine = GraphEngine(
             tenant_id=self.tenant_id,
-            app_id=self.app_id,
-            workflow_id=self.workflow_id,
-            user_id=self.user_id,
-            user_from=self.user_from,
-            invoke_from=self.invoke_from,
-            call_depth=self.workflow_call_depth,
             graph=iteration_graph,
-            graph_config=self.graph_config,
             graph_runtime_state=graph_runtime_state_copy,
-            max_execution_steps=10000,  # Use default or config value
-            max_execution_time=600,  # Use default or config value
             command_channel=InMemoryChannel(),  # Use InMemoryChannel for sub-graphs
         )
@@ -685,18 +676,9 @@ class MockLoopNode(MockNodeMixin, LoopNode):
         # Create a new GraphEngine for this iteration
         graph_engine = GraphEngine(
             tenant_id=self.tenant_id,
-            app_id=self.app_id,
-            workflow_id=self.workflow_id,
-            user_id=self.user_id,
-            user_from=self.user_from,
-            invoke_from=self.invoke_from,
-            call_depth=self.workflow_call_depth,
             graph=loop_graph,
-            graph_config=self.graph_config,
             graph_runtime_state=graph_runtime_state_copy,
-            max_execution_steps=10000,  # Use default or config value
-            max_execution_time=600,  # Use default or config value
             command_channel=InMemoryChannel(),  # Use InMemoryChannel for sub-graphs
         )

View File

@@ -118,18 +118,9 @@ def test_parallel_streaming_workflow():
     # Create the graph engine
     engine = GraphEngine(
         tenant_id="test_tenant",
-        app_id="test_app",
-        workflow_id="test_workflow",
-        user_id="test_user",
-        user_from=UserFrom.ACCOUNT,
-        invoke_from=InvokeFrom.WEB_APP,
-        call_depth=0,
         graph=graph,
-        graph_config=graph_config,
         graph_runtime_state=graph_runtime_state,
-        max_execution_steps=500,
-        max_execution_time=30,
         command_channel=InMemoryChannel(),
     )

View File

@@ -17,9 +17,8 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
 from dataclasses import dataclass, field
 from functools import lru_cache
 from pathlib import Path
-from typing import Any, Optional
+from typing import Any
 
-from core.app.entities.app_invoke_entities import InvokeFrom
 from core.tools.utils.yaml_utils import _load_yaml_file
 from core.variables import (
     ArrayNumberVariable,
@@ -42,7 +41,6 @@ from core.workflow.graph_events import (
 )
 from core.workflow.nodes.node_factory import DifyNodeFactory
 from core.workflow.system_variable import SystemVariable
-from models.enums import UserFrom
 
 from .test_mock_config import MockConfig
 from .test_mock_factory import MockNodeFactory
@@ -60,14 +58,14 @@ class WorkflowTestCase:
     query: str = ""
     description: str = ""
     timeout: float = 30.0
-    mock_config: Optional[MockConfig] = None
+    mock_config: MockConfig | None = None
     use_auto_mock: bool = False
-    expected_event_sequence: Optional[Sequence[type[GraphEngineEvent]]] = None
+    expected_event_sequence: Sequence[type[GraphEngineEvent]] | None = None
     tags: list[str] = field(default_factory=list)
     skip: bool = False
     skip_reason: str = ""
     retry_count: int = 0
-    custom_validator: Optional[Callable[[dict[str, Any]], bool]] = None
+    custom_validator: Callable[[dict[str, Any]], bool] | None = None
 
 
 @dataclass
@@ -76,14 +74,14 @@ class WorkflowTestResult:
     test_case: WorkflowTestCase
     success: bool
-    error: Optional[Exception] = None
-    actual_outputs: Optional[dict[str, Any]] = None
+    error: Exception | None = None
+    actual_outputs: dict[str, Any] | None = None
     execution_time: float = 0.0
-    event_sequence_match: Optional[bool] = None
-    event_mismatch_details: Optional[str] = None
+    event_sequence_match: bool | None = None
+    event_mismatch_details: str | None = None
     events: list[GraphEngineEvent] = field(default_factory=list)
     retry_attempts: int = 0
-    validation_details: Optional[str] = None
+    validation_details: str | None = None
 
 
 @dataclass
@@ -116,7 +114,7 @@
 class WorkflowRunner:
     """Core workflow execution engine for tests."""
 
-    def __init__(self, fixtures_dir: Optional[Path] = None):
+    def __init__(self, fixtures_dir: Path | None = None):
         """Initialize the workflow runner."""
         if fixtures_dir is None:
             # Use the new central fixtures location
@@ -147,9 +145,9 @@
         self,
         fixture_data: dict[str, Any],
         query: str = "",
-        inputs: Optional[dict[str, Any]] = None,
+        inputs: dict[str, Any] | None = None,
         use_mock_factory: bool = False,
-        mock_config: Optional[MockConfig] = None,
+        mock_config: MockConfig | None = None,
     ) -> tuple[Graph, GraphRuntimeState]:
         """Create a Graph instance from fixture data."""
         workflow_config = fixture_data.get("workflow", {})
@@ -240,7 +238,7 @@
     def __init__(
         self,
-        fixtures_dir: Optional[Path] = None,
+        fixtures_dir: Path | None = None,
         max_workers: int = 4,
         enable_logging: bool = False,
         log_level: str = "INFO",
@@ -373,23 +371,11 @@
             mock_config=test_case.mock_config,
         )
-        workflow_config = fixture_data.get("workflow", {})
-        graph_config = workflow_config.get("graph", {})
 
         # Create and run the engine with configured worker settings
         engine = GraphEngine(
             tenant_id="test_tenant",
-            app_id="test_app",
-            workflow_id="test_workflow",
-            user_id="test_user",
-            user_from=UserFrom.ACCOUNT,
-            invoke_from=InvokeFrom.DEBUGGER,  # Use DEBUGGER to avoid conversation_id requirement
-            call_depth=0,
             graph=graph,
-            graph_config=graph_config,
             graph_runtime_state=graph_runtime_state,
-            max_execution_steps=500,
-            max_execution_time=int(test_case.timeout),
             command_channel=InMemoryChannel(),
             min_workers=self.graph_engine_min_workers,
             max_workers=self.graph_engine_max_workers,
@@ -469,8 +455,8 @@
         self,
         expected_outputs: dict[str, Any],
         actual_outputs: dict[str, Any],
-        custom_validator: Optional[Callable[[dict[str, Any]], bool]] = None,
-    ) -> tuple[bool, Optional[str]]:
+        custom_validator: Callable[[dict[str, Any]], bool] | None = None,
+    ) -> tuple[bool, str | None]:
         """
         Validate actual outputs against expected outputs.
@@ -519,7 +505,7 @@
     def _validate_event_sequence(
         self, expected_sequence: list[type[GraphEngineEvent]], actual_events: list[GraphEngineEvent]
-    ) -> tuple[bool, Optional[str]]:
+    ) -> tuple[bool, str | None]:
         """
         Validate that actual events match the expected event sequence.
@@ -551,7 +537,7 @@
         self,
         test_cases: list[WorkflowTestCase],
         parallel: bool = False,
-        tags_filter: Optional[list[str]] = None,
+        tags_filter: list[str] | None = None,
         fail_fast: bool = False,
     ) -> TestSuiteResult:
         """

View File

@@ -1,11 +1,9 @@
-from core.app.entities.app_invoke_entities import InvokeFrom
 from core.workflow.graph_engine import GraphEngine
 from core.workflow.graph_engine.command_channels import InMemoryChannel
 from core.workflow.graph_events import (
     GraphRunSucceededEvent,
     NodeRunStreamChunkEvent,
 )
-from models.enums import UserFrom
 
 from .test_table_runner import TableTestRunner
@@ -23,23 +21,11 @@ def test_tool_in_chatflow():
         use_mock_factory=True,
     )
-    workflow_config = fixture_data.get("workflow", {})
-    graph_config = workflow_config.get("graph", {})
 
     # Create and run the engine
     engine = GraphEngine(
         tenant_id="test_tenant",
-        app_id="test_app",
-        workflow_id="test_workflow",
-        user_id="test_user",
-        user_from=UserFrom.ACCOUNT,
-        invoke_from=InvokeFrom.DEBUGGER,
-        call_depth=0,
         graph=graph,
-        graph_config=graph_config,
         graph_runtime_state=graph_runtime_state,
-        max_execution_steps=500,
-        max_execution_time=30,
         command_channel=InMemoryChannel(),
     )