test(api): fix broken tests

This commit is contained in:
QuantumGhost
2026-01-04 23:23:58 +08:00
parent e6eb879c61
commit 77dc8a6edb
6 changed files with 48 additions and 9 deletions

View File

@ -4,7 +4,7 @@ from __future__ import annotations
from collections.abc import Mapping
from dataclasses import dataclass
from datetime import datetime
from datetime import datetime, timedelta
from typing import Any
from core.workflow.nodes.human_input.enums import HumanInputFormStatus
@ -114,6 +114,7 @@ class InMemoryHumanInputFormRepository(HumanInputFormRepository):
entity.data = form_data or {}
entity.is_submitted = True
entity.status_value = HumanInputFormStatus.SUBMITTED
entity.expiration = naive_utc_now() + timedelta(days=1)
def clear_submission(self) -> None:
if not self.created_forms:

View File

@ -1,3 +1,4 @@
import datetime
import time
from collections.abc import Iterable
from unittest.mock import MagicMock
@ -15,6 +16,7 @@ from core.workflow.graph_events import (
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
from core.workflow.graph_events.node import NodeRunHumanInputFormFilledEvent
from core.workflow.nodes.base.entities import OutputVariableEntity, OutputVariableType
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.end.entities import EndNodeData
@ -32,6 +34,7 @@ from core.workflow.nodes.start.start_node import StartNode
from core.workflow.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from libs.datetime_utils import naive_utc_now
from .test_mock_config import MockConfig
from .test_mock_nodes import MockLLMNode
@ -272,12 +275,14 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None:
pre_chunk_count = sum(len(chunks) for _, chunks in scenario["expected_pre_chunks"])
post_chunk_count = sum(len(chunks) for _, chunks in scenario["expected_post_chunks"])
expected_pre_chunk_events_in_resumption = [
GraphRunStartedEvent,
NodeRunStartedEvent,
NodeRunHumanInputFormFilledEvent,
]
expected_resume_sequence: list[type] = (
[
GraphRunStartedEvent,
NodeRunStartedEvent,
]
expected_pre_chunk_events_in_resumption
+ [NodeRunStreamChunkEvent] * pre_chunk_count
+ [
NodeRunSucceededEvent,
@ -301,6 +306,7 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None:
submitted_form.submitted = True
submitted_form.selected_action_id = scenario["handle"]
submitted_form.submitted_data = {}
submitted_form.expiration_time = naive_utc_now() + datetime.timedelta(days=1)
mock_get_repo.get_form.return_value = submitted_form
def resume_graph_factory(
@ -353,7 +359,8 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None:
for index, event in enumerate(resume_events)
if isinstance(event, NodeRunStreamChunkEvent) and index < human_success_index
]
assert pre_indices == list(range(2, 2 + pre_chunk_count))
expected_pre_chunk_events_count_in_resumption = len(expected_pre_chunk_events_in_resumption)
assert pre_indices == list(range(expected_pre_chunk_events_count_in_resumption, human_success_index))
resume_chunk_indices = [
index

View File

@ -1,3 +1,4 @@
import datetime
import time
from unittest.mock import MagicMock
@ -14,6 +15,7 @@ from core.workflow.graph_events import (
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
from core.workflow.graph_events.node import NodeRunHumanInputFormFilledEvent
from core.workflow.nodes.base.entities import OutputVariableEntity, OutputVariableType
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.end.entities import EndNodeData
@ -31,6 +33,7 @@ from core.workflow.nodes.start.start_node import StartNode
from core.workflow.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from libs.datetime_utils import naive_utc_now
from .test_mock_config import MockConfig
from .test_mock_nodes import MockLLMNode
@ -235,6 +238,8 @@ def test_human_input_llm_streaming_order_across_pause() -> None:
expected_resume_sequence: list[type] = [
GraphRunStartedEvent, # resumed graph run begins
NodeRunStartedEvent, # human node restarts
# Form Filled should be generated first, then the node execution ends and stream chunk is generated.
NodeRunHumanInputFormFilledEvent,
NodeRunStreamChunkEvent, # cached llm_initial chunk 1
NodeRunStreamChunkEvent, # cached llm_initial chunk 2
NodeRunStreamChunkEvent, # cached llm_initial final chunk
@ -259,6 +264,7 @@ def test_human_input_llm_streaming_order_across_pause() -> None:
submitted_form.submitted = True
submitted_form.selected_action_id = "accept"
submitted_form.submitted_data = {}
submitted_form.expiration_time = naive_utc_now() + datetime.timedelta(days=1)
mock_get_repo.get_form.return_value = submitted_form
def resume_graph_factory() -> tuple[Graph, GraphRuntimeState]:

View File

@ -1,3 +1,4 @@
import datetime
import time
from typing import Any
from unittest.mock import MagicMock
@ -10,6 +11,7 @@ from core.workflow.graph_events import (
GraphEngineEvent,
GraphRunPausedEvent,
GraphRunSucceededEvent,
NodeRunStartedEvent,
NodeRunSucceededEvent,
)
from core.workflow.graph_events.graph import GraphRunStartedEvent
@ -26,6 +28,7 @@ from core.workflow.repositories.human_input_form_repository import (
)
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from libs.datetime_utils import naive_utc_now
def _build_runtime_state() -> GraphRuntimeState:
@ -52,6 +55,7 @@ def _mock_form_repository_with_submission(action_id: str) -> HumanInputFormRepos
form_entity.submitted = True
form_entity.selected_action_id = action_id
form_entity.submitted_data = {}
form_entity.expiration_time = naive_utc_now() + datetime.timedelta(days=1)
repo.get_form.return_value = form_entity
return repo
@ -146,6 +150,13 @@ def _node_successes(events: list[GraphEngineEvent]) -> list[str]:
return [event.node_id for event in events if isinstance(event, NodeRunSucceededEvent)]
def _node_start_event(events: list[GraphEngineEvent], node_id: str) -> NodeRunStartedEvent | None:
for event in events:
if isinstance(event, NodeRunStartedEvent) and event.node_id == node_id:
return event
return None
def _segment_value(variable_pool: VariablePool, selector: tuple[str, str]) -> Any:
segment = variable_pool.get(selector)
assert segment is not None
@ -191,6 +202,12 @@ def test_engine_resume_restores_state_and_completion():
combined_success_nodes = _node_successes(paused_events) + _node_successes(resumed_events)
assert combined_success_nodes == baseline_success_nodes
paused_human_started = _node_start_event(paused_events, "human")
resumed_human_started = _node_start_event(resumed_events, "human")
assert paused_human_started is not None
assert resumed_human_started is not None
assert paused_human_started.id == resumed_human_started.id
assert baseline_state.outputs == resumed_state.outputs
assert _segment_value(baseline_state.variable_pool, ("human", "__action_id")) == _segment_value(
resumed_state.variable_pool, ("human", "__action_id")

View File

@ -10,6 +10,7 @@ from pydantic import ValidationError
from core.workflow.entities import GraphInitParams
from core.workflow.node_events import NodeRunResult, PauseRequestedEvent
from core.workflow.node_events.node import StreamCompletedEvent
from core.workflow.nodes.human_input.entities import (
EmailDeliveryConfig,
EmailDeliveryMethod,
@ -435,6 +436,8 @@ class TestHumanInputNodeRenderedContent:
form_repository.set_submission(action_id="approve", form_data={"name": "Alice"})
result = node._run()
assert isinstance(result, NodeRunResult)
assert result.outputs["__rendered_content"] == "Name: Alice"
events = list(node._run())
last_event = events[-1]
assert isinstance(last_event, StreamCompletedEvent)
node_run_result = last_event.node_run_result
assert node_run_result.outputs["__rendered_content"] == "Name: Alice"

View File

@ -1,3 +1,4 @@
import datetime
import uuid
from types import SimpleNamespace
@ -8,9 +9,11 @@ from core.workflow.graph_events import (
NodeRunHumanInputFormFilledEvent,
NodeRunStartedEvent,
)
from core.workflow.nodes.human_input.enums import HumanInputFormStatus
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from libs.datetime_utils import naive_utc_now
from models.enums import UserFrom
@ -69,6 +72,8 @@ def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name#
submitted=True,
selected_action_id="Accept",
submitted_data={"name": "Alice"},
status=HumanInputFormStatus.SUBMITTED,
expiration_time=naive_utc_now() + datetime.timedelta(days=1),
)
repo = _FakeFormRepository(fake_form)