mirror of
https://github.com/langgenius/dify.git
synced 2026-03-03 14:56:20 +08:00
test(api): fix broken tests
This commit is contained in:
@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.nodes.human_input.enums import HumanInputFormStatus
|
||||
@ -114,6 +114,7 @@ class InMemoryHumanInputFormRepository(HumanInputFormRepository):
|
||||
entity.data = form_data or {}
|
||||
entity.is_submitted = True
|
||||
entity.status_value = HumanInputFormStatus.SUBMITTED
|
||||
entity.expiration = naive_utc_now() + timedelta(days=1)
|
||||
|
||||
def clear_submission(self) -> None:
|
||||
if not self.created_forms:
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import datetime
|
||||
import time
|
||||
from collections.abc import Iterable
|
||||
from unittest.mock import MagicMock
|
||||
@ -15,6 +16,7 @@ from core.workflow.graph_events import (
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.graph_events.node import NodeRunHumanInputFormFilledEvent
|
||||
from core.workflow.nodes.base.entities import OutputVariableEntity, OutputVariableType
|
||||
from core.workflow.nodes.end.end_node import EndNode
|
||||
from core.workflow.nodes.end.entities import EndNodeData
|
||||
@ -32,6 +34,7 @@ from core.workflow.nodes.start.start_node import StartNode
|
||||
from core.workflow.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
from .test_mock_config import MockConfig
|
||||
from .test_mock_nodes import MockLLMNode
|
||||
@ -272,12 +275,14 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None:
|
||||
|
||||
pre_chunk_count = sum(len(chunks) for _, chunks in scenario["expected_pre_chunks"])
|
||||
post_chunk_count = sum(len(chunks) for _, chunks in scenario["expected_post_chunks"])
|
||||
expected_pre_chunk_events_in_resumption = [
|
||||
GraphRunStartedEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunHumanInputFormFilledEvent,
|
||||
]
|
||||
|
||||
expected_resume_sequence: list[type] = (
|
||||
[
|
||||
GraphRunStartedEvent,
|
||||
NodeRunStartedEvent,
|
||||
]
|
||||
expected_pre_chunk_events_in_resumption
|
||||
+ [NodeRunStreamChunkEvent] * pre_chunk_count
|
||||
+ [
|
||||
NodeRunSucceededEvent,
|
||||
@ -301,6 +306,7 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None:
|
||||
submitted_form.submitted = True
|
||||
submitted_form.selected_action_id = scenario["handle"]
|
||||
submitted_form.submitted_data = {}
|
||||
submitted_form.expiration_time = naive_utc_now() + datetime.timedelta(days=1)
|
||||
mock_get_repo.get_form.return_value = submitted_form
|
||||
|
||||
def resume_graph_factory(
|
||||
@ -353,7 +359,8 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None:
|
||||
for index, event in enumerate(resume_events)
|
||||
if isinstance(event, NodeRunStreamChunkEvent) and index < human_success_index
|
||||
]
|
||||
assert pre_indices == list(range(2, 2 + pre_chunk_count))
|
||||
expected_pre_chunk_events_count_in_resumption = len(expected_pre_chunk_events_in_resumption)
|
||||
assert pre_indices == list(range(expected_pre_chunk_events_count_in_resumption, human_success_index))
|
||||
|
||||
resume_chunk_indices = [
|
||||
index
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import datetime
|
||||
import time
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
@ -14,6 +15,7 @@ from core.workflow.graph_events import (
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.graph_events.node import NodeRunHumanInputFormFilledEvent
|
||||
from core.workflow.nodes.base.entities import OutputVariableEntity, OutputVariableType
|
||||
from core.workflow.nodes.end.end_node import EndNode
|
||||
from core.workflow.nodes.end.entities import EndNodeData
|
||||
@ -31,6 +33,7 @@ from core.workflow.nodes.start.start_node import StartNode
|
||||
from core.workflow.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
from .test_mock_config import MockConfig
|
||||
from .test_mock_nodes import MockLLMNode
|
||||
@ -235,6 +238,8 @@ def test_human_input_llm_streaming_order_across_pause() -> None:
|
||||
expected_resume_sequence: list[type] = [
|
||||
GraphRunStartedEvent, # resumed graph run begins
|
||||
NodeRunStartedEvent, # human node restarts
|
||||
# Form Filled should be generated first, then the node execution ends and stream chunk is generated.
|
||||
NodeRunHumanInputFormFilledEvent,
|
||||
NodeRunStreamChunkEvent, # cached llm_initial chunk 1
|
||||
NodeRunStreamChunkEvent, # cached llm_initial chunk 2
|
||||
NodeRunStreamChunkEvent, # cached llm_initial final chunk
|
||||
@ -259,6 +264,7 @@ def test_human_input_llm_streaming_order_across_pause() -> None:
|
||||
submitted_form.submitted = True
|
||||
submitted_form.selected_action_id = "accept"
|
||||
submitted_form.submitted_data = {}
|
||||
submitted_form.expiration_time = naive_utc_now() + datetime.timedelta(days=1)
|
||||
mock_get_repo.get_form.return_value = submitted_form
|
||||
|
||||
def resume_graph_factory() -> tuple[Graph, GraphRuntimeState]:
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import datetime
|
||||
import time
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
@ -10,6 +11,7 @@ from core.workflow.graph_events import (
|
||||
GraphEngineEvent,
|
||||
GraphRunPausedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.graph_events.graph import GraphRunStartedEvent
|
||||
@ -26,6 +28,7 @@ from core.workflow.repositories.human_input_form_repository import (
|
||||
)
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
|
||||
def _build_runtime_state() -> GraphRuntimeState:
|
||||
@ -52,6 +55,7 @@ def _mock_form_repository_with_submission(action_id: str) -> HumanInputFormRepos
|
||||
form_entity.submitted = True
|
||||
form_entity.selected_action_id = action_id
|
||||
form_entity.submitted_data = {}
|
||||
form_entity.expiration_time = naive_utc_now() + datetime.timedelta(days=1)
|
||||
repo.get_form.return_value = form_entity
|
||||
return repo
|
||||
|
||||
@ -146,6 +150,13 @@ def _node_successes(events: list[GraphEngineEvent]) -> list[str]:
|
||||
return [event.node_id for event in events if isinstance(event, NodeRunSucceededEvent)]
|
||||
|
||||
|
||||
def _node_start_event(events: list[GraphEngineEvent], node_id: str) -> NodeRunStartedEvent | None:
|
||||
for event in events:
|
||||
if isinstance(event, NodeRunStartedEvent) and event.node_id == node_id:
|
||||
return event
|
||||
return None
|
||||
|
||||
|
||||
def _segment_value(variable_pool: VariablePool, selector: tuple[str, str]) -> Any:
|
||||
segment = variable_pool.get(selector)
|
||||
assert segment is not None
|
||||
@ -191,6 +202,12 @@ def test_engine_resume_restores_state_and_completion():
|
||||
combined_success_nodes = _node_successes(paused_events) + _node_successes(resumed_events)
|
||||
assert combined_success_nodes == baseline_success_nodes
|
||||
|
||||
paused_human_started = _node_start_event(paused_events, "human")
|
||||
resumed_human_started = _node_start_event(resumed_events, "human")
|
||||
assert paused_human_started is not None
|
||||
assert resumed_human_started is not None
|
||||
assert paused_human_started.id == resumed_human_started.id
|
||||
|
||||
assert baseline_state.outputs == resumed_state.outputs
|
||||
assert _segment_value(baseline_state.variable_pool, ("human", "__action_id")) == _segment_value(
|
||||
resumed_state.variable_pool, ("human", "__action_id")
|
||||
|
||||
@ -10,6 +10,7 @@ from pydantic import ValidationError
|
||||
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.node_events import NodeRunResult, PauseRequestedEvent
|
||||
from core.workflow.node_events.node import StreamCompletedEvent
|
||||
from core.workflow.nodes.human_input.entities import (
|
||||
EmailDeliveryConfig,
|
||||
EmailDeliveryMethod,
|
||||
@ -435,6 +436,8 @@ class TestHumanInputNodeRenderedContent:
|
||||
|
||||
form_repository.set_submission(action_id="approve", form_data={"name": "Alice"})
|
||||
|
||||
result = node._run()
|
||||
assert isinstance(result, NodeRunResult)
|
||||
assert result.outputs["__rendered_content"] == "Name: Alice"
|
||||
events = list(node._run())
|
||||
last_event = events[-1]
|
||||
assert isinstance(last_event, StreamCompletedEvent)
|
||||
node_run_result = last_event.node_run_result
|
||||
assert node_run_result.outputs["__rendered_content"] == "Name: Alice"
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import datetime
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
|
||||
@ -8,9 +9,11 @@ from core.workflow.graph_events import (
|
||||
NodeRunHumanInputFormFilledEvent,
|
||||
NodeRunStartedEvent,
|
||||
)
|
||||
from core.workflow.nodes.human_input.enums import HumanInputFormStatus
|
||||
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.enums import UserFrom
|
||||
|
||||
|
||||
@ -69,6 +72,8 @@ def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name#
|
||||
submitted=True,
|
||||
selected_action_id="Accept",
|
||||
submitted_data={"name": "Alice"},
|
||||
status=HumanInputFormStatus.SUBMITTED,
|
||||
expiration_time=naive_utc_now() + datetime.timedelta(days=1),
|
||||
)
|
||||
|
||||
repo = _FakeFormRepository(fake_form)
|
||||
|
||||
Reference in New Issue
Block a user