mirror of
https://github.com/langgenius/dify.git
synced 2026-05-05 09:58:04 +08:00
fix(api): ensure node and edge states are properly persisted while pausing
This commit is contained in:
@ -0,0 +1,189 @@
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMMode
|
||||
from core.model_runtime.entities.message_entities import PromptMessageRole
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.enums import NodeState
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine.graph_state_manager import GraphStateManager
|
||||
from core.workflow.graph_engine.ready_queue import InMemoryReadyQueue
|
||||
from core.workflow.nodes.end.end_node import EndNode
|
||||
from core.workflow.nodes.end.entities import EndNodeData
|
||||
from core.workflow.nodes.llm.entities import (
|
||||
ContextConfig,
|
||||
LLMNodeChatModelMessage,
|
||||
LLMNodeData,
|
||||
ModelConfig,
|
||||
VisionConfig,
|
||||
)
|
||||
from core.workflow.nodes.start.entities import StartNodeData
|
||||
from core.workflow.nodes.start.start_node import StartNode
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
|
||||
from .test_mock_config import MockConfig
|
||||
from .test_mock_nodes import MockLLMNode
|
||||
|
||||
|
||||
def _build_runtime_state() -> GraphRuntimeState:
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(
|
||||
user_id="user",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
workflow_execution_id="exec-1",
|
||||
),
|
||||
user_inputs={},
|
||||
conversation_variables=[],
|
||||
)
|
||||
return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
|
||||
def _build_llm_node(
|
||||
*,
|
||||
node_id: str,
|
||||
runtime_state: GraphRuntimeState,
|
||||
graph_init_params: GraphInitParams,
|
||||
mock_config: MockConfig,
|
||||
) -> MockLLMNode:
|
||||
llm_data = LLMNodeData(
|
||||
title=f"LLM {node_id}",
|
||||
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
|
||||
prompt_template=[
|
||||
LLMNodeChatModelMessage(
|
||||
text=f"Prompt {node_id}",
|
||||
role=PromptMessageRole.USER,
|
||||
edition_type="basic",
|
||||
)
|
||||
],
|
||||
context=ContextConfig(enabled=False, variable_selector=None),
|
||||
vision=VisionConfig(enabled=False),
|
||||
reasoning_format="tagged",
|
||||
)
|
||||
llm_config = {"id": node_id, "data": llm_data.model_dump()}
|
||||
return MockLLMNode(
|
||||
id=llm_config["id"],
|
||||
config=llm_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
mock_config=mock_config,
|
||||
)
|
||||
|
||||
|
||||
def _build_graph(runtime_state: GraphRuntimeState) -> Graph:
|
||||
graph_config: dict[str, object] = {"nodes": [], "edges": []}
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
graph_config=graph_config,
|
||||
user_id="user",
|
||||
user_from="account",
|
||||
invoke_from="debugger",
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()}
|
||||
start_node = StartNode(
|
||||
id=start_config["id"],
|
||||
config=start_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
mock_config = MockConfig()
|
||||
llm_a = _build_llm_node(
|
||||
node_id="llm_a",
|
||||
runtime_state=runtime_state,
|
||||
graph_init_params=graph_init_params,
|
||||
mock_config=mock_config,
|
||||
)
|
||||
llm_b = _build_llm_node(
|
||||
node_id="llm_b",
|
||||
runtime_state=runtime_state,
|
||||
graph_init_params=graph_init_params,
|
||||
mock_config=mock_config,
|
||||
)
|
||||
|
||||
end_data = EndNodeData(title="End", outputs=[], desc=None)
|
||||
end_config = {"id": "end", "data": end_data.model_dump()}
|
||||
end_node = EndNode(
|
||||
id=end_config["id"],
|
||||
config=end_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
builder = (
|
||||
Graph.new()
|
||||
.add_root(start_node)
|
||||
.add_node(llm_a, from_node_id="start")
|
||||
.add_node(llm_b, from_node_id="start")
|
||||
.add_node(end_node, from_node_id="llm_a")
|
||||
)
|
||||
return builder.connect(tail="llm_b", head="end").build()
|
||||
|
||||
|
||||
def _edge_state_map(graph: Graph) -> Mapping[tuple[str, str, str], NodeState]:
|
||||
return {(edge.tail, edge.head, edge.source_handle): edge.state for edge in graph.edges.values()}
|
||||
|
||||
|
||||
def test_runtime_state_snapshot_restores_graph_states() -> None:
|
||||
runtime_state = _build_runtime_state()
|
||||
graph = _build_graph(runtime_state)
|
||||
runtime_state.attach_graph(graph)
|
||||
|
||||
graph.nodes["llm_a"].state = NodeState.TAKEN
|
||||
graph.nodes["llm_b"].state = NodeState.SKIPPED
|
||||
|
||||
for edge in graph.edges.values():
|
||||
if edge.tail == "start" and edge.head == "llm_a":
|
||||
edge.state = NodeState.TAKEN
|
||||
elif edge.tail == "start" and edge.head == "llm_b":
|
||||
edge.state = NodeState.SKIPPED
|
||||
elif edge.head == "end" and edge.tail == "llm_a":
|
||||
edge.state = NodeState.TAKEN
|
||||
elif edge.head == "end" and edge.tail == "llm_b":
|
||||
edge.state = NodeState.SKIPPED
|
||||
|
||||
snapshot = runtime_state.dumps()
|
||||
|
||||
resumed_state = GraphRuntimeState.from_snapshot(snapshot)
|
||||
resumed_graph = _build_graph(resumed_state)
|
||||
resumed_state.attach_graph(resumed_graph)
|
||||
|
||||
assert resumed_graph.nodes["llm_a"].state == NodeState.TAKEN
|
||||
assert resumed_graph.nodes["llm_b"].state == NodeState.SKIPPED
|
||||
assert _edge_state_map(resumed_graph) == _edge_state_map(graph)
|
||||
|
||||
|
||||
def test_join_readiness_uses_restored_edge_states() -> None:
|
||||
runtime_state = _build_runtime_state()
|
||||
graph = _build_graph(runtime_state)
|
||||
runtime_state.attach_graph(graph)
|
||||
|
||||
ready_queue = InMemoryReadyQueue()
|
||||
state_manager = GraphStateManager(graph, ready_queue)
|
||||
|
||||
for edge in graph.get_incoming_edges("end"):
|
||||
if edge.tail == "llm_a":
|
||||
edge.state = NodeState.TAKEN
|
||||
if edge.tail == "llm_b":
|
||||
edge.state = NodeState.UNKNOWN
|
||||
|
||||
assert state_manager.is_node_ready("end") is False
|
||||
|
||||
for edge in graph.get_incoming_edges("end"):
|
||||
if edge.tail == "llm_b":
|
||||
edge.state = NodeState.TAKEN
|
||||
|
||||
assert state_manager.is_node_ready("end") is True
|
||||
|
||||
snapshot = runtime_state.dumps()
|
||||
resumed_state = GraphRuntimeState.from_snapshot(snapshot)
|
||||
resumed_graph = _build_graph(resumed_state)
|
||||
resumed_state.attach_graph(resumed_graph)
|
||||
|
||||
resumed_state_manager = GraphStateManager(resumed_graph, InMemoryReadyQueue())
|
||||
assert resumed_state_manager.is_node_ready("end") is True
|
||||
@ -0,0 +1,269 @@
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Protocol
|
||||
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel
|
||||
from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||
from core.workflow.graph_events import (
|
||||
GraphRunPausedEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.nodes.base.entities import OutputVariableEntity
|
||||
from core.workflow.nodes.end.end_node import EndNode
|
||||
from core.workflow.nodes.end.entities import EndNodeData
|
||||
from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction
|
||||
from core.workflow.nodes.human_input.enums import HumanInputFormStatus
|
||||
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
|
||||
from core.workflow.nodes.start.entities import StartNodeData
|
||||
from core.workflow.nodes.start.start_node import StartNode
|
||||
from core.workflow.repositories.human_input_form_repository import (
|
||||
FormCreateParams,
|
||||
HumanInputFormEntity,
|
||||
HumanInputFormRepository,
|
||||
)
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
|
||||
class PauseStateStore(Protocol):
|
||||
def save(self, runtime_state: GraphRuntimeState) -> None: ...
|
||||
|
||||
def load(self) -> GraphRuntimeState: ...
|
||||
|
||||
|
||||
class InMemoryPauseStore:
|
||||
def __init__(self) -> None:
|
||||
self._snapshot: str | None = None
|
||||
|
||||
def save(self, runtime_state: GraphRuntimeState) -> None:
|
||||
self._snapshot = runtime_state.dumps()
|
||||
|
||||
def load(self) -> GraphRuntimeState:
|
||||
assert self._snapshot is not None
|
||||
return GraphRuntimeState.from_snapshot(self._snapshot)
|
||||
|
||||
|
||||
@dataclass
|
||||
class StaticForm(HumanInputFormEntity):
|
||||
form_id: str
|
||||
rendered: str
|
||||
is_submitted: bool
|
||||
action_id: str | None = None
|
||||
data: Mapping[str, Any] | None = None
|
||||
status_value: HumanInputFormStatus = HumanInputFormStatus.WAITING
|
||||
expiration: datetime = naive_utc_now() + timedelta(days=1)
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
return self.form_id
|
||||
|
||||
@property
|
||||
def web_app_token(self) -> str | None:
|
||||
return "token"
|
||||
|
||||
@property
|
||||
def recipients(self) -> list:
|
||||
return []
|
||||
|
||||
@property
|
||||
def rendered_content(self) -> str:
|
||||
return self.rendered
|
||||
|
||||
@property
|
||||
def selected_action_id(self) -> str | None:
|
||||
return self.action_id
|
||||
|
||||
@property
|
||||
def submitted_data(self) -> Mapping[str, Any] | None:
|
||||
return self.data
|
||||
|
||||
@property
|
||||
def submitted(self) -> bool:
|
||||
return self.is_submitted
|
||||
|
||||
@property
|
||||
def status(self) -> HumanInputFormStatus:
|
||||
return self.status_value
|
||||
|
||||
@property
|
||||
def expiration_time(self) -> datetime:
|
||||
return self.expiration
|
||||
|
||||
|
||||
class StaticRepo(HumanInputFormRepository):
|
||||
def __init__(self, forms_by_node_id: Mapping[str, HumanInputFormEntity]) -> None:
|
||||
self._forms_by_node_id = dict(forms_by_node_id)
|
||||
|
||||
def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None:
|
||||
return self._forms_by_node_id.get(node_id)
|
||||
|
||||
def create_form(self, params: FormCreateParams) -> HumanInputFormEntity:
|
||||
raise AssertionError("create_form should not be called in resume scenario")
|
||||
|
||||
|
||||
def _build_runtime_state() -> GraphRuntimeState:
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(
|
||||
user_id="user",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
workflow_execution_id="exec-1",
|
||||
),
|
||||
user_inputs={},
|
||||
conversation_variables=[],
|
||||
)
|
||||
return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
|
||||
def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepository) -> Graph:
|
||||
graph_config: dict[str, object] = {"nodes": [], "edges": []}
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
graph_config=graph_config,
|
||||
user_id="user",
|
||||
user_from="account",
|
||||
invoke_from="debugger",
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()}
|
||||
start_node = StartNode(
|
||||
id=start_config["id"],
|
||||
config=start_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
human_data = HumanInputNodeData(
|
||||
title="Human Input",
|
||||
form_content="Human input required",
|
||||
inputs=[],
|
||||
user_actions=[UserAction(id="approve", title="Approve")],
|
||||
)
|
||||
|
||||
human_a_config = {"id": "human_a", "data": human_data.model_dump()}
|
||||
human_a = HumanInputNode(
|
||||
id=human_a_config["id"],
|
||||
config=human_a_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
form_repository=repo,
|
||||
)
|
||||
|
||||
human_b_config = {"id": "human_b", "data": human_data.model_dump()}
|
||||
human_b = HumanInputNode(
|
||||
id=human_b_config["id"],
|
||||
config=human_b_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
form_repository=repo,
|
||||
)
|
||||
|
||||
end_data = EndNodeData(
|
||||
title="End",
|
||||
outputs=[
|
||||
OutputVariableEntity(variable="res_a", value_selector=["human_a", "__action_id"]),
|
||||
OutputVariableEntity(variable="res_b", value_selector=["human_b", "__action_id"]),
|
||||
],
|
||||
desc=None,
|
||||
)
|
||||
end_config = {"id": "end", "data": end_data.model_dump()}
|
||||
end_node = EndNode(
|
||||
id=end_config["id"],
|
||||
config=end_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
builder = (
|
||||
Graph.new()
|
||||
.add_root(start_node)
|
||||
.add_node(human_a, from_node_id="start")
|
||||
.add_node(human_b, from_node_id="start")
|
||||
.add_node(end_node, from_node_id="human_a", source_handle="approve")
|
||||
)
|
||||
return builder.connect(tail="human_b", head="end", source_handle="approve").build()
|
||||
|
||||
|
||||
def _run_graph(graph: Graph, runtime_state: GraphRuntimeState) -> list[object]:
|
||||
engine = GraphEngine(
|
||||
workflow_id="workflow",
|
||||
graph=graph,
|
||||
graph_runtime_state=runtime_state,
|
||||
command_channel=InMemoryChannel(),
|
||||
min_workers=2,
|
||||
max_workers=2,
|
||||
scale_up_threshold=1,
|
||||
scale_down_idle_time=30.0,
|
||||
)
|
||||
return list(engine.run())
|
||||
|
||||
|
||||
def _form(submitted: bool, action_id: str | None) -> StaticForm:
|
||||
return StaticForm(
|
||||
form_id="form",
|
||||
rendered="rendered",
|
||||
is_submitted=submitted,
|
||||
action_id=action_id,
|
||||
data={},
|
||||
status_value=HumanInputFormStatus.SUBMITTED if submitted else HumanInputFormStatus.WAITING,
|
||||
)
|
||||
|
||||
|
||||
def test_parallel_human_input_join_completes_after_second_resume() -> None:
|
||||
pause_store: PauseStateStore = InMemoryPauseStore()
|
||||
|
||||
initial_state = _build_runtime_state()
|
||||
initial_repo = StaticRepo(
|
||||
{
|
||||
"human_a": _form(submitted=False, action_id=None),
|
||||
"human_b": _form(submitted=False, action_id=None),
|
||||
}
|
||||
)
|
||||
initial_graph = _build_graph(initial_state, initial_repo)
|
||||
initial_events = _run_graph(initial_graph, initial_state)
|
||||
|
||||
assert isinstance(initial_events[-1], GraphRunPausedEvent)
|
||||
pause_store.save(initial_state)
|
||||
|
||||
first_resume_state = pause_store.load()
|
||||
first_resume_repo = StaticRepo(
|
||||
{
|
||||
"human_a": _form(submitted=True, action_id="approve"),
|
||||
"human_b": _form(submitted=False, action_id=None),
|
||||
}
|
||||
)
|
||||
first_resume_graph = _build_graph(first_resume_state, first_resume_repo)
|
||||
first_resume_events = _run_graph(first_resume_graph, first_resume_state)
|
||||
|
||||
assert isinstance(first_resume_events[0], GraphRunStartedEvent)
|
||||
assert first_resume_events[0].reason is WorkflowStartReason.RESUMPTION
|
||||
assert isinstance(first_resume_events[-1], GraphRunPausedEvent)
|
||||
pause_store.save(first_resume_state)
|
||||
|
||||
second_resume_state = pause_store.load()
|
||||
second_resume_repo = StaticRepo(
|
||||
{
|
||||
"human_a": _form(submitted=True, action_id="approve"),
|
||||
"human_b": _form(submitted=True, action_id="approve"),
|
||||
}
|
||||
)
|
||||
second_resume_graph = _build_graph(second_resume_state, second_resume_repo)
|
||||
second_resume_events = _run_graph(second_resume_graph, second_resume_state)
|
||||
|
||||
assert isinstance(second_resume_events[0], GraphRunStartedEvent)
|
||||
assert second_resume_events[0].reason is WorkflowStartReason.RESUMPTION
|
||||
assert isinstance(second_resume_events[-1], GraphRunSucceededEvent)
|
||||
assert any(
|
||||
isinstance(event, NodeRunSucceededEvent) and event.node_id == "end" for event in second_resume_events
|
||||
)
|
||||
Reference in New Issue
Block a user