mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 17:08:03 +08:00
Merge remote-tracking branch 'upstream/feat/queue-based-graph-engine' into feat/rag-2
This commit is contained in:
@ -17,7 +17,6 @@ def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_fiel
|
||||
identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"),
|
||||
parameters=[],
|
||||
description=None,
|
||||
output_schema=None,
|
||||
has_runtime_parameters=False,
|
||||
)
|
||||
runtime = ToolRuntime(tenant_id="test_tool", invoke_from=InvokeFrom.EXPLORE)
|
||||
|
||||
@ -95,17 +95,3 @@ class TestGraphRuntimeState:
|
||||
# Test add_tokens validation
|
||||
with pytest.raises(ValueError):
|
||||
state.add_tokens(-1)
|
||||
|
||||
def test_deep_copy_for_nested_objects(self):
|
||||
variable_pool = VariablePool()
|
||||
state = GraphRuntimeState(variable_pool=variable_pool, start_at=time())
|
||||
|
||||
# Test deep copy for nested dict
|
||||
nested_data = {"level1": {"level2": {"value": "test"}}}
|
||||
state.set_output("nested", nested_data)
|
||||
|
||||
retrieved = state.get_output("nested")
|
||||
retrieved["level1"]["level2"]["value"] = "modified"
|
||||
|
||||
# Original should remain unchanged
|
||||
assert state.get_output("nested")["level1"]["level2"]["value"] == "test"
|
||||
|
||||
@ -3,14 +3,12 @@
|
||||
import time
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities import GraphRuntimeState, VariablePool
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine import GraphEngine
|
||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||
from core.workflow.graph_engine.entities.commands import AbortCommand
|
||||
from core.workflow.graph_events import GraphRunAbortedEvent, GraphRunStartedEvent
|
||||
from models.enums import UserFrom
|
||||
|
||||
|
||||
def test_abort_command():
|
||||
@ -42,18 +40,9 @@ def test_abort_command():
|
||||
|
||||
# Create GraphEngine with same shared runtime state
|
||||
engine = GraphEngine(
|
||||
tenant_id="test",
|
||||
app_id="test",
|
||||
workflow_id="test_workflow",
|
||||
user_id="test",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
call_depth=0,
|
||||
graph=mock_graph,
|
||||
graph_config={},
|
||||
graph_runtime_state=shared_runtime_state, # Use shared instance
|
||||
max_execution_steps=100,
|
||||
max_execution_time=10,
|
||||
command_channel=command_channel,
|
||||
)
|
||||
|
||||
|
||||
@ -6,7 +6,6 @@ This test validates that:
|
||||
- When blocking != 1: NodeRunStreamChunkEvent present (direct LLM to End output)
|
||||
"""
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.graph_engine import GraphEngine
|
||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||
@ -16,7 +15,6 @@ from core.workflow.graph_events import (
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from models.enums import UserFrom
|
||||
|
||||
from .test_table_runner import TableTestRunner
|
||||
|
||||
@ -40,23 +38,11 @@ def test_streaming_output_with_blocking_equals_one():
|
||||
use_mock_factory=True,
|
||||
)
|
||||
|
||||
workflow_config = fixture_data.get("workflow", {})
|
||||
graph_config = workflow_config.get("graph", {})
|
||||
|
||||
# Create and run the engine
|
||||
engine = GraphEngine(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
workflow_id="test_workflow",
|
||||
user_id="test_user",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
graph=graph,
|
||||
graph_config=graph_config,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
max_execution_steps=500,
|
||||
max_execution_time=30,
|
||||
command_channel=InMemoryChannel(),
|
||||
)
|
||||
|
||||
@ -147,23 +133,11 @@ def test_streaming_output_with_blocking_not_equals_one():
|
||||
use_mock_factory=True,
|
||||
)
|
||||
|
||||
workflow_config = fixture_data.get("workflow", {})
|
||||
graph_config = workflow_config.get("graph", {})
|
||||
|
||||
# Create and run the engine
|
||||
engine = GraphEngine(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
workflow_id="test_workflow",
|
||||
user_id="test_user",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
graph=graph,
|
||||
graph_config=graph_config,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
max_execution_steps=500,
|
||||
max_execution_time=30,
|
||||
command_channel=InMemoryChannel(),
|
||||
)
|
||||
|
||||
|
||||
@ -9,7 +9,6 @@ import contextvars
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
from typing import Optional
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from flask import Flask, g
|
||||
@ -59,7 +58,7 @@ class TestContextPreservation:
|
||||
context = contextvars.copy_context()
|
||||
|
||||
# Variable to store value from worker
|
||||
worker_value: Optional[str] = None
|
||||
worker_value: str | None = None
|
||||
|
||||
def worker_task() -> None:
|
||||
nonlocal worker_value
|
||||
@ -120,7 +119,7 @@ class TestContextPreservation:
|
||||
test_node = MagicMock(spec=Node)
|
||||
|
||||
# Variable to capture context inside node execution
|
||||
captured_value: Optional[str] = None
|
||||
captured_value: str | None = None
|
||||
context_available_in_node = False
|
||||
|
||||
def mock_run() -> list[GraphNodeEventBase]:
|
||||
|
||||
@ -10,11 +10,9 @@ import time
|
||||
from hypothesis import HealthCheck, given, settings
|
||||
from hypothesis import strategies as st
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.graph_engine import GraphEngine
|
||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||
from core.workflow.graph_events import GraphRunStartedEvent, GraphRunSucceededEvent
|
||||
from models.enums import UserFrom
|
||||
|
||||
# Import the test framework from the new module
|
||||
from .test_table_runner import TableTestRunner, WorkflowRunner, WorkflowTestCase
|
||||
@ -460,18 +458,9 @@ def test_layer_system_basic():
|
||||
|
||||
# Create engine with layer
|
||||
engine = GraphEngine(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
workflow_id="test_workflow",
|
||||
user_id="test_user",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
call_depth=0,
|
||||
graph=graph,
|
||||
graph_config=fixture_data.get("workflow", {}).get("graph", {}),
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
max_execution_steps=300,
|
||||
max_execution_time=60,
|
||||
command_channel=InMemoryChannel(),
|
||||
)
|
||||
|
||||
@ -525,18 +514,9 @@ def test_layer_chaining():
|
||||
|
||||
# Create engine
|
||||
engine = GraphEngine(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
workflow_id="test_workflow",
|
||||
user_id="test_user",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
call_depth=0,
|
||||
graph=graph,
|
||||
graph_config=fixture_data.get("workflow", {}).get("graph", {}),
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
max_execution_steps=300,
|
||||
max_execution_time=60,
|
||||
command_channel=InMemoryChannel(),
|
||||
)
|
||||
|
||||
@ -581,18 +561,9 @@ def test_layer_error_handling():
|
||||
|
||||
# Create engine with faulty layer
|
||||
engine = GraphEngine(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
workflow_id="test_workflow",
|
||||
user_id="test_user",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
call_depth=0,
|
||||
graph=graph,
|
||||
graph_config=fixture_data.get("workflow", {}).get("graph", {}),
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
max_execution_steps=300,
|
||||
max_execution_time=60,
|
||||
command_channel=InMemoryChannel(),
|
||||
)
|
||||
|
||||
|
||||
@ -0,0 +1,194 @@
|
||||
"""Unit tests for GraphExecution serialization helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections import deque
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.workflow.enums import NodeExecutionType, NodeState, NodeType
|
||||
from core.workflow.graph_engine.domain import GraphExecution
|
||||
from core.workflow.graph_engine.response_coordinator import ResponseStreamCoordinator
|
||||
from core.workflow.graph_engine.response_coordinator.path import Path
|
||||
from core.workflow.graph_engine.response_coordinator.session import ResponseSession
|
||||
from core.workflow.graph_events import NodeRunStreamChunkEvent
|
||||
from core.workflow.nodes.base.template import Template, TextSegment, VariableSegment
|
||||
|
||||
|
||||
class CustomGraphExecutionError(Exception):
|
||||
"""Custom exception used to verify error serialization."""
|
||||
|
||||
|
||||
def test_graph_execution_serialization_round_trip() -> None:
|
||||
"""GraphExecution serialization restores full aggregate state."""
|
||||
# Arrange
|
||||
execution = GraphExecution(workflow_id="wf-1")
|
||||
execution.start()
|
||||
node_a = execution.get_or_create_node_execution("node-a")
|
||||
node_a.mark_started(execution_id="exec-1")
|
||||
node_a.increment_retry()
|
||||
node_a.mark_failed("boom")
|
||||
node_b = execution.get_or_create_node_execution("node-b")
|
||||
node_b.mark_skipped()
|
||||
execution.fail(CustomGraphExecutionError("serialization failure"))
|
||||
|
||||
# Act
|
||||
serialized = execution.dumps()
|
||||
payload = json.loads(serialized)
|
||||
restored = GraphExecution(workflow_id="wf-1")
|
||||
restored.loads(serialized)
|
||||
|
||||
# Assert
|
||||
assert payload["type"] == "GraphExecution"
|
||||
assert payload["version"] == "1.0"
|
||||
assert restored.workflow_id == "wf-1"
|
||||
assert restored.started is True
|
||||
assert restored.completed is True
|
||||
assert restored.aborted is False
|
||||
assert isinstance(restored.error, CustomGraphExecutionError)
|
||||
assert str(restored.error) == "serialization failure"
|
||||
assert set(restored.node_executions) == {"node-a", "node-b"}
|
||||
restored_node_a = restored.node_executions["node-a"]
|
||||
assert restored_node_a.state is NodeState.TAKEN
|
||||
assert restored_node_a.retry_count == 1
|
||||
assert restored_node_a.execution_id == "exec-1"
|
||||
assert restored_node_a.error == "boom"
|
||||
restored_node_b = restored.node_executions["node-b"]
|
||||
assert restored_node_b.state is NodeState.SKIPPED
|
||||
assert restored_node_b.retry_count == 0
|
||||
assert restored_node_b.execution_id is None
|
||||
assert restored_node_b.error is None
|
||||
|
||||
|
||||
def test_graph_execution_loads_replaces_existing_state() -> None:
|
||||
"""loads replaces existing runtime data with serialized snapshot."""
|
||||
# Arrange
|
||||
source = GraphExecution(workflow_id="wf-2")
|
||||
source.start()
|
||||
source_node = source.get_or_create_node_execution("node-source")
|
||||
source_node.mark_taken()
|
||||
serialized = source.dumps()
|
||||
|
||||
target = GraphExecution(workflow_id="wf-2")
|
||||
target.start()
|
||||
target.abort("pre-existing abort")
|
||||
temp_node = target.get_or_create_node_execution("node-temp")
|
||||
temp_node.increment_retry()
|
||||
temp_node.mark_failed("temp error")
|
||||
|
||||
# Act
|
||||
target.loads(serialized)
|
||||
|
||||
# Assert
|
||||
assert target.aborted is False
|
||||
assert target.error is None
|
||||
assert target.started is True
|
||||
assert target.completed is False
|
||||
assert set(target.node_executions) == {"node-source"}
|
||||
restored_node = target.node_executions["node-source"]
|
||||
assert restored_node.state is NodeState.TAKEN
|
||||
assert restored_node.retry_count == 0
|
||||
assert restored_node.execution_id is None
|
||||
assert restored_node.error is None
|
||||
|
||||
|
||||
def test_response_stream_coordinator_serialization_round_trip(monkeypatch) -> None:
|
||||
"""ResponseStreamCoordinator serialization restores coordinator internals."""
|
||||
|
||||
template_main = Template(segments=[TextSegment(text="Hi "), VariableSegment(selector=["node-source", "text"])])
|
||||
template_secondary = Template(segments=[TextSegment(text="secondary")])
|
||||
|
||||
class DummyNode:
|
||||
def __init__(self, node_id: str, template: Template, execution_type: NodeExecutionType) -> None:
|
||||
self.id = node_id
|
||||
self.node_type = NodeType.ANSWER if execution_type == NodeExecutionType.RESPONSE else NodeType.LLM
|
||||
self.execution_type = execution_type
|
||||
self.state = NodeState.UNKNOWN
|
||||
self.title = node_id
|
||||
self.template = template
|
||||
|
||||
def blocks_variable_output(self, *_args) -> bool:
|
||||
return False
|
||||
|
||||
response_node1 = DummyNode("response-1", template_main, NodeExecutionType.RESPONSE)
|
||||
response_node2 = DummyNode("response-2", template_main, NodeExecutionType.RESPONSE)
|
||||
response_node3 = DummyNode("response-3", template_main, NodeExecutionType.RESPONSE)
|
||||
source_node = DummyNode("node-source", template_secondary, NodeExecutionType.EXECUTABLE)
|
||||
|
||||
class DummyGraph:
|
||||
def __init__(self) -> None:
|
||||
self.nodes = {
|
||||
response_node1.id: response_node1,
|
||||
response_node2.id: response_node2,
|
||||
response_node3.id: response_node3,
|
||||
source_node.id: source_node,
|
||||
}
|
||||
self.edges: dict[str, object] = {}
|
||||
self.root_node = response_node1
|
||||
|
||||
def get_outgoing_edges(self, _node_id: str): # pragma: no cover - not exercised
|
||||
return []
|
||||
|
||||
def get_incoming_edges(self, _node_id: str): # pragma: no cover - not exercised
|
||||
return []
|
||||
|
||||
graph = DummyGraph()
|
||||
|
||||
def fake_from_node(cls, node: DummyNode) -> ResponseSession:
|
||||
return ResponseSession(node_id=node.id, template=node.template)
|
||||
|
||||
monkeypatch.setattr(ResponseSession, "from_node", classmethod(fake_from_node))
|
||||
|
||||
coordinator = ResponseStreamCoordinator(variable_pool=MagicMock(), graph=graph) # type: ignore[arg-type]
|
||||
coordinator._response_nodes = {"response-1", "response-2", "response-3"}
|
||||
coordinator._paths_maps = {
|
||||
"response-1": [Path(edges=["edge-1"])],
|
||||
"response-2": [Path(edges=[])],
|
||||
"response-3": [Path(edges=["edge-2", "edge-3"])],
|
||||
}
|
||||
|
||||
active_session = ResponseSession(node_id="response-1", template=response_node1.template)
|
||||
active_session.index = 1
|
||||
coordinator._active_session = active_session
|
||||
waiting_session = ResponseSession(node_id="response-2", template=response_node2.template)
|
||||
coordinator._waiting_sessions = deque([waiting_session])
|
||||
pending_session = ResponseSession(node_id="response-3", template=response_node3.template)
|
||||
pending_session.index = 2
|
||||
coordinator._response_sessions = {"response-3": pending_session}
|
||||
|
||||
coordinator._node_execution_ids = {"response-1": "exec-1"}
|
||||
event = NodeRunStreamChunkEvent(
|
||||
id="exec-1",
|
||||
node_id="response-1",
|
||||
node_type=NodeType.ANSWER,
|
||||
selector=["node-source", "text"],
|
||||
chunk="chunk-1",
|
||||
is_final=False,
|
||||
)
|
||||
coordinator._stream_buffers = {("node-source", "text"): [event]}
|
||||
coordinator._stream_positions = {("node-source", "text"): 1}
|
||||
coordinator._closed_streams = {("node-source", "text")}
|
||||
|
||||
serialized = coordinator.dumps()
|
||||
|
||||
restored = ResponseStreamCoordinator(variable_pool=MagicMock(), graph=graph) # type: ignore[arg-type]
|
||||
monkeypatch.setattr(ResponseSession, "from_node", classmethod(fake_from_node))
|
||||
restored.loads(serialized)
|
||||
|
||||
assert restored._response_nodes == {"response-1", "response-2", "response-3"}
|
||||
assert restored._paths_maps["response-1"][0].edges == ["edge-1"]
|
||||
assert restored._active_session is not None
|
||||
assert restored._active_session.node_id == "response-1"
|
||||
assert restored._active_session.index == 1
|
||||
waiting_restored = list(restored._waiting_sessions)
|
||||
assert len(waiting_restored) == 1
|
||||
assert waiting_restored[0].node_id == "response-2"
|
||||
assert waiting_restored[0].index == 0
|
||||
assert set(restored._response_sessions) == {"response-3"}
|
||||
assert restored._response_sessions["response-3"].index == 2
|
||||
assert restored._node_execution_ids == {"response-1": "exec-1"}
|
||||
assert ("node-source", "text") in restored._stream_buffers
|
||||
restored_event = restored._stream_buffers[("node-source", "text")][0]
|
||||
assert restored_event.chunk == "chunk-1"
|
||||
assert restored._stream_positions[("node-source", "text")] == 1
|
||||
assert ("node-source", "text") in restored._closed_streams
|
||||
@ -7,7 +7,7 @@ the behavior of mock nodes during testing.
|
||||
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.enums import NodeType
|
||||
|
||||
@ -18,9 +18,9 @@ class NodeMockConfig:
|
||||
|
||||
node_id: str
|
||||
outputs: dict[str, Any] = field(default_factory=dict)
|
||||
error: Optional[str] = None
|
||||
error: str | None = None
|
||||
delay: float = 0.0 # Simulated execution delay in seconds
|
||||
custom_handler: Optional[Callable[..., dict[str, Any]]] = None
|
||||
custom_handler: Callable[..., dict[str, Any]] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -51,7 +51,7 @@ class MockConfig:
|
||||
default_template_transform_response: str = "This is mocked template transform output"
|
||||
default_code_response: dict[str, Any] = field(default_factory=lambda: {"result": "mocked code execution result"})
|
||||
|
||||
def get_node_config(self, node_id: str) -> Optional[NodeMockConfig]:
|
||||
def get_node_config(self, node_id: str) -> NodeMockConfig | None:
|
||||
"""Get configuration for a specific node."""
|
||||
return self.node_configs.get(node_id)
|
||||
|
||||
|
||||
@ -64,7 +64,7 @@ class MockNodeMixin:
|
||||
|
||||
return default_outputs
|
||||
|
||||
def _should_simulate_error(self) -> Optional[str]:
|
||||
def _should_simulate_error(self) -> str | None:
|
||||
"""Check if this node should simulate an error."""
|
||||
if not self.mock_config:
|
||||
return None
|
||||
@ -615,18 +615,9 @@ class MockIterationNode(MockNodeMixin, IterationNode):
|
||||
|
||||
# Create a new GraphEngine for this iteration
|
||||
graph_engine = GraphEngine(
|
||||
tenant_id=self.tenant_id,
|
||||
app_id=self.app_id,
|
||||
workflow_id=self.workflow_id,
|
||||
user_id=self.user_id,
|
||||
user_from=self.user_from,
|
||||
invoke_from=self.invoke_from,
|
||||
call_depth=self.workflow_call_depth,
|
||||
graph=iteration_graph,
|
||||
graph_config=self.graph_config,
|
||||
graph_runtime_state=graph_runtime_state_copy,
|
||||
max_execution_steps=10000, # Use default or config value
|
||||
max_execution_time=600, # Use default or config value
|
||||
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
|
||||
)
|
||||
|
||||
@ -685,18 +676,9 @@ class MockLoopNode(MockNodeMixin, LoopNode):
|
||||
|
||||
# Create a new GraphEngine for this iteration
|
||||
graph_engine = GraphEngine(
|
||||
tenant_id=self.tenant_id,
|
||||
app_id=self.app_id,
|
||||
workflow_id=self.workflow_id,
|
||||
user_id=self.user_id,
|
||||
user_from=self.user_from,
|
||||
invoke_from=self.invoke_from,
|
||||
call_depth=self.workflow_call_depth,
|
||||
graph=loop_graph,
|
||||
graph_config=self.graph_config,
|
||||
graph_runtime_state=graph_runtime_state_copy,
|
||||
max_execution_steps=10000, # Use default or config value
|
||||
max_execution_time=600, # Use default or config value
|
||||
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
|
||||
)
|
||||
|
||||
|
||||
@ -118,18 +118,9 @@ def test_parallel_streaming_workflow():
|
||||
|
||||
# Create the graph engine
|
||||
engine = GraphEngine(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
workflow_id="test_workflow",
|
||||
user_id="test_user",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
call_depth=0,
|
||||
graph=graph,
|
||||
graph_config=graph_config,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
max_execution_steps=500,
|
||||
max_execution_time=30,
|
||||
command_channel=InMemoryChannel(),
|
||||
)
|
||||
|
||||
|
||||
@ -17,9 +17,8 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from dataclasses import dataclass, field
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.tools.utils.yaml_utils import _load_yaml_file
|
||||
from core.variables import (
|
||||
ArrayNumberVariable,
|
||||
@ -42,7 +41,6 @@ from core.workflow.graph_events import (
|
||||
)
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from models.enums import UserFrom
|
||||
|
||||
from .test_mock_config import MockConfig
|
||||
from .test_mock_factory import MockNodeFactory
|
||||
@ -60,14 +58,14 @@ class WorkflowTestCase:
|
||||
query: str = ""
|
||||
description: str = ""
|
||||
timeout: float = 30.0
|
||||
mock_config: Optional[MockConfig] = None
|
||||
mock_config: MockConfig | None = None
|
||||
use_auto_mock: bool = False
|
||||
expected_event_sequence: Optional[Sequence[type[GraphEngineEvent]]] = None
|
||||
expected_event_sequence: Sequence[type[GraphEngineEvent]] | None = None
|
||||
tags: list[str] = field(default_factory=list)
|
||||
skip: bool = False
|
||||
skip_reason: str = ""
|
||||
retry_count: int = 0
|
||||
custom_validator: Optional[Callable[[dict[str, Any]], bool]] = None
|
||||
custom_validator: Callable[[dict[str, Any]], bool] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -76,14 +74,14 @@ class WorkflowTestResult:
|
||||
|
||||
test_case: WorkflowTestCase
|
||||
success: bool
|
||||
error: Optional[Exception] = None
|
||||
actual_outputs: Optional[dict[str, Any]] = None
|
||||
error: Exception | None = None
|
||||
actual_outputs: dict[str, Any] | None = None
|
||||
execution_time: float = 0.0
|
||||
event_sequence_match: Optional[bool] = None
|
||||
event_mismatch_details: Optional[str] = None
|
||||
event_sequence_match: bool | None = None
|
||||
event_mismatch_details: str | None = None
|
||||
events: list[GraphEngineEvent] = field(default_factory=list)
|
||||
retry_attempts: int = 0
|
||||
validation_details: Optional[str] = None
|
||||
validation_details: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -116,7 +114,7 @@ class TestSuiteResult:
|
||||
class WorkflowRunner:
|
||||
"""Core workflow execution engine for tests."""
|
||||
|
||||
def __init__(self, fixtures_dir: Optional[Path] = None):
|
||||
def __init__(self, fixtures_dir: Path | None = None):
|
||||
"""Initialize the workflow runner."""
|
||||
if fixtures_dir is None:
|
||||
# Use the new central fixtures location
|
||||
@ -147,9 +145,9 @@ class WorkflowRunner:
|
||||
self,
|
||||
fixture_data: dict[str, Any],
|
||||
query: str = "",
|
||||
inputs: Optional[dict[str, Any]] = None,
|
||||
inputs: dict[str, Any] | None = None,
|
||||
use_mock_factory: bool = False,
|
||||
mock_config: Optional[MockConfig] = None,
|
||||
mock_config: MockConfig | None = None,
|
||||
) -> tuple[Graph, GraphRuntimeState]:
|
||||
"""Create a Graph instance from fixture data."""
|
||||
workflow_config = fixture_data.get("workflow", {})
|
||||
@ -240,7 +238,7 @@ class TableTestRunner:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fixtures_dir: Optional[Path] = None,
|
||||
fixtures_dir: Path | None = None,
|
||||
max_workers: int = 4,
|
||||
enable_logging: bool = False,
|
||||
log_level: str = "INFO",
|
||||
@ -373,23 +371,11 @@ class TableTestRunner:
|
||||
mock_config=test_case.mock_config,
|
||||
)
|
||||
|
||||
workflow_config = fixture_data.get("workflow", {})
|
||||
graph_config = workflow_config.get("graph", {})
|
||||
|
||||
# Create and run the engine with configured worker settings
|
||||
engine = GraphEngine(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
workflow_id="test_workflow",
|
||||
user_id="test_user",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER, # Use DEBUGGER to avoid conversation_id requirement
|
||||
call_depth=0,
|
||||
graph=graph,
|
||||
graph_config=graph_config,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
max_execution_steps=500,
|
||||
max_execution_time=int(test_case.timeout),
|
||||
command_channel=InMemoryChannel(),
|
||||
min_workers=self.graph_engine_min_workers,
|
||||
max_workers=self.graph_engine_max_workers,
|
||||
@ -469,8 +455,8 @@ class TableTestRunner:
|
||||
self,
|
||||
expected_outputs: dict[str, Any],
|
||||
actual_outputs: dict[str, Any],
|
||||
custom_validator: Optional[Callable[[dict[str, Any]], bool]] = None,
|
||||
) -> tuple[bool, Optional[str]]:
|
||||
custom_validator: Callable[[dict[str, Any]], bool] | None = None,
|
||||
) -> tuple[bool, str | None]:
|
||||
"""
|
||||
Validate actual outputs against expected outputs.
|
||||
|
||||
@ -519,7 +505,7 @@ class TableTestRunner:
|
||||
|
||||
def _validate_event_sequence(
|
||||
self, expected_sequence: list[type[GraphEngineEvent]], actual_events: list[GraphEngineEvent]
|
||||
) -> tuple[bool, Optional[str]]:
|
||||
) -> tuple[bool, str | None]:
|
||||
"""
|
||||
Validate that actual events match the expected event sequence.
|
||||
|
||||
@ -551,7 +537,7 @@ class TableTestRunner:
|
||||
self,
|
||||
test_cases: list[WorkflowTestCase],
|
||||
parallel: bool = False,
|
||||
tags_filter: Optional[list[str]] = None,
|
||||
tags_filter: list[str] | None = None,
|
||||
fail_fast: bool = False,
|
||||
) -> TestSuiteResult:
|
||||
"""
|
||||
|
||||
@ -1,11 +1,9 @@
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.graph_engine import GraphEngine
|
||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||
from core.workflow.graph_events import (
|
||||
GraphRunSucceededEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
)
|
||||
from models.enums import UserFrom
|
||||
|
||||
from .test_table_runner import TableTestRunner
|
||||
|
||||
@ -23,23 +21,11 @@ def test_tool_in_chatflow():
|
||||
use_mock_factory=True,
|
||||
)
|
||||
|
||||
workflow_config = fixture_data.get("workflow", {})
|
||||
graph_config = workflow_config.get("graph", {})
|
||||
|
||||
# Create and run the engine
|
||||
engine = GraphEngine(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
workflow_id="test_workflow",
|
||||
user_id="test_user",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
graph=graph,
|
||||
graph_config=graph_config,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
max_execution_steps=500,
|
||||
max_execution_time=30,
|
||||
command_channel=InMemoryChannel(),
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user