feat: Human Input Node (#32060)

This commit adds the frontend and backend implementation for the human input node.

Co-authored-by: twwu <twwu@dify.ai>
Co-authored-by: JzoNg <jzongcode@gmail.com>
Co-authored-by: yyh <92089059+lyzno1@users.noreply.github.com>
Co-authored-by: zhsama <torvalds@linux.do>
This commit is contained in:
QuantumGhost
2026-02-09 14:57:23 +08:00
committed by GitHub
parent 56e3a55023
commit a1fc280102
474 changed files with 32667 additions and 2050 deletions

View File

@ -0,0 +1,187 @@
from __future__ import annotations
from contextlib import contextmanager
from datetime import datetime
from types import SimpleNamespace
from unittest import mock
import pytest
from core.app.apps.advanced_chat import generate_task_pipeline as pipeline_module
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueTextChunkEvent, QueueWorkflowPausedEvent
from core.workflow.entities.pause_reason import HumanInputRequired
from models.enums import MessageStatus
from models.execution_extra_content import HumanInputContent
from models.model import EndUser
def _build_pipeline() -> pipeline_module.AdvancedChatAppGenerateTaskPipeline:
    """Create a bare pipeline instance without running __init__.

    Only the attributes exercised by the persistence tests below are set.
    """
    cls = pipeline_module.AdvancedChatAppGenerateTaskPipeline
    instance = cls.__new__(cls)
    instance._workflow_run_id = "run-1"
    instance._message_id = "message-1"
    instance._workflow_tenant_id = "tenant-1"
    return instance
def test_persist_human_input_extra_content_adds_record(monkeypatch: pytest.MonkeyPatch) -> None:
    """A HumanInputContent row is inserted when the form exists and nothing is stored yet."""
    pipeline = _build_pipeline()
    monkeypatch.setattr(pipeline, "_load_human_input_form_id", lambda **kwargs: "form-1")

    seen: dict[str, mock.Mock] = {}

    @contextmanager
    def fake_session():
        db_session = mock.Mock()
        db_session.scalar.return_value = None  # no pre-existing record
        seen["session"] = db_session
        yield db_session

    pipeline._database_session = fake_session  # type: ignore[method-assign]
    pipeline._persist_human_input_extra_content(node_id="node-1")

    db_session = seen["session"]
    db_session.add.assert_called_once()
    persisted = db_session.add.call_args.args[0]
    assert isinstance(persisted, HumanInputContent)
    assert persisted.workflow_run_id == "run-1"
    assert persisted.message_id == "message-1"
    assert persisted.form_id == "form-1"
def test_persist_human_input_extra_content_skips_when_form_missing(monkeypatch: pytest.MonkeyPatch) -> None:
    """No database session is opened at all when the form id cannot be resolved."""
    pipeline = _build_pipeline()
    monkeypatch.setattr(pipeline, "_load_human_input_form_id", lambda **kwargs: None)

    session_opened = {"value": False}

    @contextmanager
    def fake_session():
        session_opened["value"] = True
        yield mock.Mock()

    pipeline._database_session = fake_session  # type: ignore[method-assign]
    pipeline._persist_human_input_extra_content(node_id="node-1")
    assert session_opened["value"] is False
def test_persist_human_input_extra_content_skips_when_existing(monkeypatch: pytest.MonkeyPatch) -> None:
    """No new row is added when a matching HumanInputContent already exists."""
    pipeline = _build_pipeline()
    monkeypatch.setattr(pipeline, "_load_human_input_form_id", lambda **kwargs: "form-1")

    seen: dict[str, mock.Mock] = {}

    @contextmanager
    def fake_session():
        db_session = mock.Mock()
        # Simulate a pre-existing record for this run/message/form triple.
        db_session.scalar.return_value = HumanInputContent(
            workflow_run_id="run-1",
            message_id="message-1",
            form_id="form-1",
        )
        seen["session"] = db_session
        yield db_session

    pipeline._database_session = fake_session  # type: ignore[method-assign]
    pipeline._persist_human_input_extra_content(node_id="node-1")
    seen["session"].add.assert_not_called()
def test_handle_workflow_paused_event_persists_human_input_extra_content() -> None:
    """Handling a pause event persists the human-input record and pauses the message."""
    pipeline = _build_pipeline()
    pipeline._application_generate_entity = SimpleNamespace(task_id="task-1")
    pipeline._workflow_response_converter = mock.Mock()
    pipeline._workflow_response_converter.workflow_pause_to_stream_response.return_value = []
    # Minimal runtime stats object; presumably consumed when saving the paused
    # run — TODO confirm against the pipeline implementation.
    pipeline._ensure_graph_runtime_initialized = mock.Mock(
        return_value=SimpleNamespace(
            total_tokens=0,
            node_run_steps=0,
        ),
    )
    pipeline._save_message = mock.Mock()
    message = SimpleNamespace(status=MessageStatus.NORMAL)
    pipeline._get_message = mock.Mock(return_value=message)
    pipeline._persist_human_input_extra_content = mock.Mock()
    pipeline._base_task_pipeline = mock.Mock()
    pipeline._base_task_pipeline.queue_manager = mock.Mock()
    pipeline._message_saved_on_pause = False
    @contextmanager
    def fake_session():
        session = mock.Mock()
        yield session
    pipeline._database_session = fake_session  # type: ignore[method-assign]
    reason = HumanInputRequired(
        form_id="form-1",
        form_content="content",
        inputs=[],
        actions=[],
        node_id="node-1",
        node_title="Approval",
        form_token="token-1",
        resolved_default_values={},
    )
    event = QueueWorkflowPausedEvent(reasons=[reason], outputs={}, paused_nodes=["node-1"])
    # Drain the generator so all side effects of handling the event run.
    list(pipeline._handle_workflow_paused_event(event))
    pipeline._persist_human_input_extra_content.assert_called_once_with(form_id="form-1", node_id="node-1")
    assert message.status == MessageStatus.PAUSED
def test_resume_appends_chunks_to_paused_answer() -> None:
    """Chunks streamed after resume are appended to the previously paused answer."""
    app_config = SimpleNamespace(app_id="app-1", tenant_id="tenant-1", sensitive_word_avoidance=None)
    application_generate_entity = SimpleNamespace(
        app_config=app_config,
        files=[],
        workflow_run_id="run-1",
        query="hello",
        invoke_from=InvokeFrom.WEB_APP,
        inputs={},
        task_id="task-1",
    )
    queue_manager = SimpleNamespace(graph_runtime_state=None)
    conversation = SimpleNamespace(id="conversation-1", mode="advanced-chat")
    # The message was paused mid-answer with "before" already produced.
    message = SimpleNamespace(
        id="message-1",
        created_at=datetime(2024, 1, 1),
        query="hello",
        answer="before",
        status=MessageStatus.PAUSED,
    )
    user = EndUser()
    user.id = "user-1"
    user.session_id = "session-1"
    workflow = SimpleNamespace(id="workflow-1", tenant_id="tenant-1", features_dict={})
    pipeline = pipeline_module.AdvancedChatAppGenerateTaskPipeline(
        application_generate_entity=application_generate_entity,
        workflow=workflow,
        queue_manager=queue_manager,
        conversation=conversation,
        message=message,
        user=user,
        stream=True,
        dialogue_count=1,
        draft_var_saver_factory=SimpleNamespace(),
    )
    pipeline._get_message = mock.Mock(return_value=message)
    pipeline._recorded_files = []
    # Stream one more chunk, then persist; the new text must extend the answer
    # instead of replacing it, and the status must return to NORMAL.
    list(pipeline._handle_text_chunk_event(QueueTextChunkEvent(text="after")))
    pipeline._save_message(session=mock.Mock())
    assert message.answer == "beforeafter"
    assert message.status == MessageStatus.NORMAL

View File

@ -0,0 +1,87 @@
from datetime import UTC, datetime
from types import SimpleNamespace
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueHumanInputFormFilledEvent, QueueHumanInputFormTimeoutEvent
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
def _build_converter():
    """Build a WorkflowResponseConverter wired with minimal stand-in entities.

    Only the application entity, user and system variables are needed by the
    stream-response tests below.
    """
    system_variables = SystemVariable(
        files=[],
        user_id="user-1",
        app_id="app-1",
        workflow_id="wf-1",
        workflow_execution_id="run-1",
    )
    # NOTE: the original built a GraphRuntimeState here that was never used;
    # it has been removed.
    app_entity = SimpleNamespace(
        task_id="task-1",
        app_config=SimpleNamespace(app_id="app-1", tenant_id="tenant-1"),
        invoke_from=InvokeFrom.EXPLORE,
        files=[],
        inputs={},
        workflow_execution_id="run-1",
        call_depth=0,
    )
    account = SimpleNamespace(id="acc-1", name="tester", email="tester@example.com")
    return WorkflowResponseConverter(
        application_generate_entity=app_entity,
        user=account,
        system_variables=system_variables,
    )
def test_human_input_form_filled_stream_response_contains_rendered_content():
    """Filled-form stream responses carry node metadata and the rendered content."""
    converter = _build_converter()
    # Emit a workflow-start response first; presumably this registers the
    # current workflow run id on the converter.
    converter.workflow_start_to_stream_response(
        task_id="task-1",
        workflow_run_id="run-1",
        workflow_id="wf-1",
        reason=WorkflowStartReason.INITIAL,
    )
    filled_event = QueueHumanInputFormFilledEvent(
        node_execution_id="exec-1",
        node_id="node-1",
        node_type="human-input",
        node_title="Human Input",
        rendered_content="# Title\nvalue",
        action_id="Approve",
        action_text="Approve",
    )
    response = converter.human_input_form_filled_to_stream_response(event=filled_event, task_id="task-1")
    assert response.workflow_run_id == "run-1"
    assert response.data.node_id == "node-1"
    assert response.data.node_title == "Human Input"
    assert response.data.rendered_content.startswith("# Title")
    assert response.data.action_id == "Approve"
def test_human_input_form_timeout_stream_response_contains_timeout_metadata():
    """Timeout stream responses include node metadata and an epoch expiration time."""
    converter = _build_converter()
    converter.workflow_start_to_stream_response(
        task_id="task-1",
        workflow_run_id="run-1",
        workflow_id="wf-1",
        reason=WorkflowStartReason.INITIAL,
    )
    timeout_event = QueueHumanInputFormTimeoutEvent(
        node_id="node-1",
        node_type="human-input",
        node_title="Human Input",
        expiration_time=datetime(2025, 1, 1, tzinfo=UTC),
    )
    response = converter.human_input_form_timeout_to_stream_response(event=timeout_event, task_id="task-1")
    assert response.workflow_run_id == "run-1"
    assert response.data.node_id == "node-1"
    assert response.data.node_title == "Human Input"
    # 1735689600 is 2025-01-01T00:00:00Z expressed as Unix epoch seconds.
    assert response.data.expiration_time == 1735689600

View File

@ -0,0 +1,56 @@
from types import SimpleNamespace
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
def _build_converter() -> WorkflowResponseConverter:
    """Construct a minimal WorkflowResponseConverter for testing.

    Only the application entity, user and system variables are required by
    the start-reason tests below.
    """
    system_variables = SystemVariable(
        files=[],
        user_id="user-1",
        app_id="app-1",
        workflow_id="wf-1",
        workflow_execution_id="run-1",
    )
    # NOTE: the original built a GraphRuntimeState here that was never used;
    # it has been removed.
    app_entity = SimpleNamespace(
        task_id="task-1",
        app_config=SimpleNamespace(app_id="app-1", tenant_id="tenant-1"),
        invoke_from=InvokeFrom.EXPLORE,
        files=[],
        inputs={},
        workflow_execution_id="run-1",
        call_depth=0,
    )
    account = SimpleNamespace(id="acc-1", name="tester", email="tester@example.com")
    return WorkflowResponseConverter(
        application_generate_entity=app_entity,
        user=account,
        system_variables=system_variables,
    )
def test_workflow_start_stream_response_carries_resumption_reason():
    """A resumed run's start response must surface the RESUMPTION reason."""
    response = _build_converter().workflow_start_to_stream_response(
        task_id="task-1",
        workflow_run_id="run-1",
        workflow_id="wf-1",
        reason=WorkflowStartReason.RESUMPTION,
    )
    assert response.data.reason is WorkflowStartReason.RESUMPTION
def test_workflow_start_stream_response_carries_initial_reason():
    """A first run's start response must surface the INITIAL reason."""
    response = _build_converter().workflow_start_to_stream_response(
        task_id="task-1",
        workflow_run_id="run-1",
        workflow_id="wf-1",
        reason=WorkflowStartReason.INITIAL,
    )
    assert response.data.reason is WorkflowStartReason.INITIAL

View File

@ -23,6 +23,7 @@ from core.app.entities.queue_entities import (
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
)
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
from core.workflow.enums import NodeType
from core.workflow.system_variable import SystemVariable
from libs.datetime_utils import naive_utc_now
@ -124,7 +125,12 @@ class TestWorkflowResponseConverter:
original_data = {"large_field": "x" * 10000, "metadata": "info"}
truncated_data = {"large_field": "[TRUNCATED]", "metadata": "info"}
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
converter.workflow_start_to_stream_response(
task_id="bootstrap",
workflow_run_id="run-id",
workflow_id="wf-id",
reason=WorkflowStartReason.INITIAL,
)
start_event = self.create_node_started_event()
converter.workflow_node_start_to_stream_response(
event=start_event,
@ -160,7 +166,12 @@ class TestWorkflowResponseConverter:
original_data = {"small": "data"}
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
converter.workflow_start_to_stream_response(
task_id="bootstrap",
workflow_run_id="run-id",
workflow_id="wf-id",
reason=WorkflowStartReason.INITIAL,
)
start_event = self.create_node_started_event()
converter.workflow_node_start_to_stream_response(
event=start_event,
@ -191,7 +202,12 @@ class TestWorkflowResponseConverter:
"""Test node finish response when process_data is None."""
converter = self.create_workflow_response_converter()
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
converter.workflow_start_to_stream_response(
task_id="bootstrap",
workflow_run_id="run-id",
workflow_id="wf-id",
reason=WorkflowStartReason.INITIAL,
)
start_event = self.create_node_started_event()
converter.workflow_node_start_to_stream_response(
event=start_event,
@ -225,7 +241,12 @@ class TestWorkflowResponseConverter:
original_data = {"large_field": "x" * 10000, "metadata": "info"}
truncated_data = {"large_field": "[TRUNCATED]", "metadata": "info"}
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
converter.workflow_start_to_stream_response(
task_id="bootstrap",
workflow_run_id="run-id",
workflow_id="wf-id",
reason=WorkflowStartReason.INITIAL,
)
start_event = self.create_node_started_event()
converter.workflow_node_start_to_stream_response(
event=start_event,
@ -261,7 +282,12 @@ class TestWorkflowResponseConverter:
original_data = {"small": "data"}
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
converter.workflow_start_to_stream_response(
task_id="bootstrap",
workflow_run_id="run-id",
workflow_id="wf-id",
reason=WorkflowStartReason.INITIAL,
)
start_event = self.create_node_started_event()
converter.workflow_node_start_to_stream_response(
event=start_event,
@ -400,6 +426,7 @@ class TestWorkflowResponseConverterServiceApiTruncation:
task_id="test-task-id",
workflow_run_id="test-workflow-run-id",
workflow_id="test-workflow-id",
reason=WorkflowStartReason.INITIAL,
)
return converter

View File

@ -0,0 +1,139 @@
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig
from core.app.apps import message_based_app_generator
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.app.task_pipeline import message_cycle_manager
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
from models.model import AppMode, Conversation, Message
def _make_app_config() -> WorkflowUIBasedAppConfig:
    """Return a minimal advanced-chat app config for the generator tests."""
    config_fields = dict(
        tenant_id="tenant-id",
        app_id="app-id",
        app_mode=AppMode.ADVANCED_CHAT,
        workflow_id="workflow-id",
        additional_features=AppAdditionalFeatures(),
        variables=[],
    )
    return WorkflowUIBasedAppConfig(**config_fields)
def _make_generate_entity(app_config: WorkflowUIBasedAppConfig) -> AdvancedChatAppGenerateEntity:
    """Build a streaming web-app generate entity with no prior conversation."""
    entity_fields = dict(
        task_id="task-id",
        app_config=app_config,
        file_upload_config=None,
        conversation_id=None,
        inputs={},
        query="hello",
        files=[],
        parent_message_id=None,
        user_id="user-id",
        stream=True,
        invoke_from=InvokeFrom.WEB_APP,
        extras={},
        workflow_run_id="workflow-run-id",
    )
    return AdvancedChatAppGenerateEntity(**entity_fields)
@pytest.fixture(autouse=True)
def _mock_db_session(monkeypatch):
    """Replace the generator module's `db` with a MagicMock-backed session.

    `refresh` assigns deterministic ids to newly created rows so the tests
    can assert on them.
    """
    fake_session = MagicMock()

    def assign_generated_ids(obj):
        if isinstance(obj, Conversation) and obj.id is None:
            obj.id = "generated-conversation-id"
        if isinstance(obj, Message) and obj.id is None:
            obj.id = "generated-message-id"

    fake_session.refresh.side_effect = assign_generated_ids
    fake_session.add.return_value = None
    fake_session.commit.return_value = None
    monkeypatch.setattr(message_based_app_generator, "db", SimpleNamespace(session=fake_session))
    return fake_session
def test_init_generate_records_sets_conversation_metadata():
    """A fresh conversation gets a generated id and is flagged as new."""
    entity = _make_generate_entity(_make_app_config())
    generator = AdvancedChatAppGenerator()
    conversation, _ = generator._init_generate_records(entity, conversation=None)
    assert entity.conversation_id == "generated-conversation-id"
    assert conversation.id == "generated-conversation-id"
    assert entity.is_new_conversation is True
def test_init_generate_records_marks_existing_conversation():
    """Passing an existing conversation reuses its id and clears the new flag."""
    app_config = _make_app_config()
    entity = _make_generate_entity(app_config)
    # Fully-populated Conversation row standing in for one loaded from the DB.
    existing_conversation = Conversation(
        app_id=app_config.app_id,
        app_model_config_id=None,
        model_provider=None,
        override_model_configs=None,
        model_id=None,
        mode=app_config.app_mode.value,
        name="existing",
        inputs={},
        introduction="",
        system_instruction="",
        system_instruction_tokens=0,
        status="normal",
        invoke_from=InvokeFrom.WEB_APP.value,
        from_source="api",
        from_end_user_id="user-id",
        from_account_id=None,
    )
    existing_conversation.id = "existing-conversation-id"
    generator = AdvancedChatAppGenerator()
    conversation, _ = generator._init_generate_records(entity, conversation=existing_conversation)
    # The same object is returned and the entity must not report a new conversation.
    assert entity.conversation_id == "existing-conversation-id"
    assert conversation is existing_conversation
    assert entity.is_new_conversation is False
def test_message_cycle_manager_uses_new_conversation_flag(monkeypatch):
    """generate_conversation_name starts a thread and clears the new-conversation flag."""
    entity = _make_generate_entity(_make_app_config())
    entity.conversation_id = "existing-conversation-id"
    entity.is_new_conversation = True
    entity.extras = {"auto_generate_conversation_name": True}

    created = {}

    class RecordingThread:
        """Captures constructor kwargs and records whether start() was called."""

        def __init__(self, **kwargs):
            self.kwargs = kwargs
            self.started = False

        def start(self):
            self.started = True

    def make_thread(**kwargs):
        created["thread"] = RecordingThread(**kwargs)
        return created["thread"]

    monkeypatch.setattr(message_cycle_manager, "Thread", make_thread)
    manager = MessageCycleManager(application_generate_entity=entity, task_state=MagicMock())
    thread = manager.generate_conversation_name(conversation_id="existing-conversation-id", query="hello")

    assert thread is created["thread"]
    assert thread.started is True
    assert entity.is_new_conversation is False

View File

@ -0,0 +1,127 @@
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from core.app.app_config.entities import (
AppAdditionalFeatures,
EasyUIBasedAppConfig,
EasyUIBasedAppModelConfigFrom,
ModelConfigEntity,
PromptTemplateEntity,
)
from core.app.apps import message_based_app_generator
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom
from models.model import AppMode, Conversation, Message
class DummyModelConf:
    """Lightweight stand-in for a model configuration object."""

    def __init__(self, provider: str = "mock-provider", model: str = "mock-model") -> None:
        # Expose the two attributes the generator code reads.
        self.provider, self.model = provider, model
class DummyCompletionGenerateEntity:
    """Minimal completion-mode generate entity.

    __slots__ deliberately omits conversation_id / is_new_conversation so the
    tests can prove the generator never sets conversation fields on it.
    """

    __slots__ = ("app_config", "invoke_from", "user_id", "query", "inputs", "files", "model_conf")

    app_config: EasyUIBasedAppConfig
    invoke_from: InvokeFrom
    user_id: str
    query: str
    inputs: dict
    files: list
    model_conf: DummyModelConf

    def __init__(self, app_config: EasyUIBasedAppConfig) -> None:
        self.app_config = app_config
        self.model_conf = DummyModelConf()
        self.invoke_from = InvokeFrom.WEB_APP
        self.user_id = "user-id"
        self.query = "hello"
        self.inputs = {}
        self.files = []
def _make_app_config(app_mode: AppMode) -> EasyUIBasedAppConfig:
    """Build a minimal easy-UI app config for the requested app mode."""
    prompt_template = PromptTemplateEntity(
        prompt_type=PromptTemplateEntity.PromptType.SIMPLE,
        simple_prompt_template="Hello",
    )
    return EasyUIBasedAppConfig(
        tenant_id="tenant-id",
        app_id="app-id",
        app_mode=app_mode,
        app_model_config_from=EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG,
        app_model_config_id="model-config-id",
        app_model_config_dict={},
        model=ModelConfigEntity(provider="mock-provider", model="mock-model", mode="chat"),
        prompt_template=prompt_template,
        additional_features=AppAdditionalFeatures(),
        variables=[],
    )
def _make_chat_generate_entity(app_config: EasyUIBasedAppConfig) -> ChatAppGenerateEntity:
    """Construct a chat generate entity for the tests.

    model_construct skips pydantic validation so DummyModelConf can stand in
    for the real model configuration type.
    """
    entity_fields = dict(
        task_id="task-id",
        app_config=app_config,
        model_conf=DummyModelConf(),
        file_upload_config=None,
        conversation_id=None,
        inputs={},
        query="hello",
        files=[],
        parent_message_id=None,
        user_id="user-id",
        stream=False,
        invoke_from=InvokeFrom.WEB_APP,
        extras={},
        call_depth=0,
        trace_manager=None,
    )
    return ChatAppGenerateEntity.model_construct(**entity_fields)
@pytest.fixture(autouse=True)
def _mock_db_session(monkeypatch):
    """Swap the generator's `db` for a mock whose refresh() assigns stable ids."""
    fake_session = MagicMock()

    def assign_ids(obj):
        if isinstance(obj, Conversation) and obj.id is None:
            obj.id = "generated-conversation-id"
        if isinstance(obj, Message) and obj.id is None:
            obj.id = "generated-message-id"

    fake_session.refresh.side_effect = assign_ids
    fake_session.add.return_value = None
    fake_session.commit.return_value = None
    monkeypatch.setattr(message_based_app_generator, "db", SimpleNamespace(session=fake_session))
    return fake_session
def test_init_generate_records_skips_conversation_fields_for_non_conversation_entity():
    """Entities without conversation fields are left untouched by record init."""
    entity = DummyCompletionGenerateEntity(app_config=_make_app_config(AppMode.COMPLETION))
    conversation, message = MessageBasedAppGenerator()._init_generate_records(entity, conversation=None)
    assert conversation.id == "generated-conversation-id"
    assert message.id == "generated-message-id"
    # The dummy entity's __slots__ omit these names, so they can only be
    # present if the generator wrongly tried to set them.
    assert not hasattr(entity, "conversation_id")
    assert not hasattr(entity, "is_new_conversation")
def test_init_generate_records_sets_conversation_fields_for_chat_entity():
    """Chat entities receive a generated conversation id and the new flag."""
    entity = _make_chat_generate_entity(_make_app_config(AppMode.CHAT))
    conversation, _ = MessageBasedAppGenerator()._init_generate_records(entity, conversation=None)
    assert entity.conversation_id == "generated-conversation-id"
    assert entity.is_new_conversation is True
    assert conversation.id == "generated-conversation-id"

View File

@ -0,0 +1,287 @@
import sys
import time
from pathlib import Path
from types import ModuleType, SimpleNamespace
from typing import Any
# Make the api package root (five directory levels up from this test module)
# importable when the file is run outside the usual pytest rootdir.
API_DIR = str(Path(__file__).resolve().parents[5])
if API_DIR not in sys.path:
    sys.path.insert(0, API_DIR)
import core.workflow.nodes.human_input.entities # noqa: F401
from core.app.apps.advanced_chat import app_generator as adv_app_gen_module
from core.app.apps.workflow import app_generator as wf_app_gen_module
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.workflow.node_factory import DifyNodeFactory
from core.workflow.entities import GraphInitParams
from core.workflow.entities.pause_reason import SchedulingPause
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel
from core.workflow.graph_events import (
GraphEngineEvent,
GraphRunPausedEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
NodeRunSucceededEvent,
)
from core.workflow.node_events import NodeRunResult, PauseRequestedEvent
from core.workflow.nodes.base.entities import BaseNodeData, OutputVariableEntity, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.end.entities import EndNodeData
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
# Install a no-op TraceQueueManager stub so these tests do not require the
# real ops tracing stack.
# NOTE(review): this guard runs after the app-generator imports above; if one
# of them already imported core.ops.ops_trace_manager the stub is skipped —
# confirm that ordering is intentional.
if "core.ops.ops_trace_manager" not in sys.modules:
    ops_stub = ModuleType("core.ops.ops_trace_manager")
    class _StubTraceQueueManager:
        """Accepts any constructor arguments and does nothing."""
        def __init__(self, *_, **__):
            pass
    ops_stub.TraceQueueManager = _StubTraceQueueManager
    sys.modules["core.ops.ops_trace_manager"] = ops_stub
class _StubToolNodeData(BaseNodeData):
    # When True, the stub tool node requests a pause instead of producing output.
    pause_on: bool = False
class _StubToolNode(Node[_StubToolNodeData]):
    """Tool-node stand-in: succeeds with a deterministic output, or pauses."""
    node_type = NodeType.TOOL
    @classmethod
    def version(cls) -> str:
        return "1"
    def init_node_data(self, data):
        self._node_data = _StubToolNodeData.model_validate(data)
    def _get_error_strategy(self):
        return self._node_data.error_strategy
    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config
    def _get_title(self) -> str:
        return self._node_data.title
    def _get_description(self):
        return self._node_data.desc
    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict
    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data
    def _run(self):
        # When configured to pause, emit a pause request and stop without output.
        # (self.node_data presumably resolves to the validated _node_data via
        # the Node base class — confirm against core.workflow.nodes.base.node.)
        if self.node_data.pause_on:
            yield PauseRequestedEvent(reason=SchedulingPause(message="test pause"))
            return
        # Otherwise succeed with a per-node output value ("<node-id>-done").
        result = NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            outputs={"value": f"{self.id}-done"},
        )
        yield self._convert_node_run_result_to_graph_node_event(result)
def _patch_tool_node(mocker):
    """Route TOOL-typed node configs to _StubToolNode, delegating everything else."""
    delegate = DifyNodeFactory.create_node

    def _create_node(self, node_config: dict[str, object]) -> Node:
        data = node_config.get("data", {})
        is_tool = isinstance(data, dict) and data.get("type") == NodeType.TOOL.value
        if not is_tool:
            return delegate(self, node_config)
        return _StubToolNode(
            id=str(node_config["id"]),
            config=node_config,
            graph_init_params=self.graph_init_params,
            graph_runtime_state=self.graph_runtime_state,
        )

    mocker.patch.object(DifyNodeFactory, "create_node", _create_node)
def _node_data(node_type: NodeType, data: BaseNodeData) -> dict[str, object]:
    """Serialize node data and tag it with the node-type discriminator."""
    return {**data.model_dump(), "type": node_type.value}
def _build_graph_config(*, pause_on: str | None) -> dict[str, object]:
    """Build a linear start -> tool_a -> tool_b -> tool_c -> end graph config.

    Args:
        pause_on: id of the tool node that should request a pause, or None.
    """
    end_data = EndNodeData(
        title="end",
        outputs=[OutputVariableEntity(variable="result", value_selector=["tool_c", "value"])],
        desc=None,
    )
    nodes: list[dict[str, object]] = [
        {"id": "start", "data": _node_data(NodeType.START, StartNodeData(title="start", variables=[]))},
        {"id": "end", "data": _node_data(NodeType.END, end_data)},
    ]
    # Insert the three tool nodes between start and end, in order.
    for tool_id in ("tool_a", "tool_b", "tool_c"):
        tool_data = _StubToolNodeData(title="tool", pause_on=pause_on == tool_id)
        nodes.insert(-1, {"id": tool_id, "data": _node_data(NodeType.TOOL, tool_data)})
    chain = ["start", "tool_a", "tool_b", "tool_c", "end"]
    edges = [{"source": src, "target": dst} for src, dst in zip(chain, chain[1:])]
    return {"nodes": nodes, "edges": edges}
def _build_graph(runtime_state: GraphRuntimeState, *, pause_on: str | None) -> Graph:
    """Instantiate a Graph from the demo config via the Dify node factory."""
    config = _build_graph_config(pause_on=pause_on)
    init_params = GraphInitParams(
        tenant_id="tenant",
        app_id="app",
        workflow_id="workflow",
        graph_config=config,
        user_id="user",
        user_from="account",
        invoke_from="service-api",
        call_depth=0,
    )
    factory = DifyNodeFactory(graph_init_params=init_params, graph_runtime_state=runtime_state)
    return Graph.init(graph_config=config, node_factory=factory)
def _build_runtime_state(run_id: str) -> GraphRuntimeState:
    """Create a fresh runtime state whose workflow execution id is `run_id`."""
    sys_vars = SystemVariable(user_id="user", app_id="app", workflow_id="workflow")
    pool = VariablePool(
        system_variables=sys_vars,
        user_inputs={},
        conversation_variables=[],
    )
    pool.system_variables.workflow_execution_id = run_id
    return GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter())
def _run_with_optional_pause(runtime_state: GraphRuntimeState, *, pause_on: str | None) -> list[GraphEngineEvent]:
    """Run the demo graph and collect every emitted engine event.

    Args:
        runtime_state: Pre-built runtime state shared with the graph nodes.
        pause_on: Node id that should trigger a pause, or None to run through.

    Returns:
        All events produced by the engine, in emission order.
    """
    graph = _build_graph(runtime_state, pause_on=pause_on)
    engine = GraphEngine(
        workflow_id="workflow",
        graph=graph,
        graph_runtime_state=runtime_state,
        command_channel=InMemoryChannel(),
    )
    # engine.run() is a generator; materialize it directly instead of the
    # original manual append loop.
    return list(engine.run())
def _node_successes(events: list[GraphEngineEvent]) -> list[str]:
    """Extract node ids from successful node-run events, preserving order."""
    succeeded = (e for e in events if isinstance(e, NodeRunSucceededEvent))
    return [e.node_id for e in succeeded]
def test_workflow_app_pause_resume_matches_baseline(mocker):
    """Pausing mid-run and resuming must replay the remaining nodes exactly.

    Runs the graph once uninterrupted (baseline), then with a pause at tool_a,
    snapshots the runtime state, resumes via WorkflowAppGenerator.resume, and
    checks combined node order and final outputs match the baseline.
    """
    _patch_tool_node(mocker)
    # Baseline: one uninterrupted run for comparison.
    baseline_state = _build_runtime_state("baseline")
    baseline_events = _run_with_optional_pause(baseline_state, pause_on=None)
    assert isinstance(baseline_events[-1], GraphRunSucceededEvent)
    baseline_nodes = _node_successes(baseline_events)
    baseline_outputs = baseline_state.outputs
    # Paused run stops at tool_a before it produces output.
    paused_state = _build_runtime_state("paused-run")
    paused_events = _run_with_optional_pause(paused_state, pause_on="tool_a")
    assert isinstance(paused_events[-1], GraphRunPausedEvent)
    paused_nodes = _node_successes(paused_events)
    # Round-trip the state through its serialized snapshot before resuming.
    snapshot = paused_state.dumps()
    resumed_state = GraphRuntimeState.from_snapshot(snapshot)
    generator = wf_app_gen_module.WorkflowAppGenerator()
    def _fake_generate(**kwargs):
        # Stand-in for the real _generate: run the restored state to
        # completion and report which nodes succeeded.
        state: GraphRuntimeState = kwargs["graph_runtime_state"]
        events = _run_with_optional_pause(state, pause_on=None)
        return _node_successes(events)
    mocker.patch.object(generator, "_generate", side_effect=_fake_generate)
    resumed_nodes = generator.resume(
        app_model=SimpleNamespace(mode="workflow"),
        workflow=SimpleNamespace(),
        user=SimpleNamespace(),
        application_generate_entity=SimpleNamespace(stream=False, invoke_from=InvokeFrom.SERVICE_API),
        graph_runtime_state=resumed_state,
        workflow_execution_repository=SimpleNamespace(),
        workflow_node_execution_repository=SimpleNamespace(),
    )
    assert paused_nodes + resumed_nodes == baseline_nodes
    assert resumed_state.outputs == baseline_outputs
def test_advanced_chat_pause_resume_matches_baseline(mocker):
    """Advanced-chat resume must replay remaining nodes and match baseline outputs."""
    _patch_tool_node(mocker)
    # Baseline: one uninterrupted run.
    baseline_state = _build_runtime_state("adv-baseline")
    baseline_events = _run_with_optional_pause(baseline_state, pause_on=None)
    assert isinstance(baseline_events[-1], GraphRunSucceededEvent)
    baseline_nodes = _node_successes(baseline_events)
    baseline_outputs = baseline_state.outputs
    # Paused run stops at tool_a.
    paused_state = _build_runtime_state("adv-paused")
    paused_events = _run_with_optional_pause(paused_state, pause_on="tool_a")
    assert isinstance(paused_events[-1], GraphRunPausedEvent)
    paused_nodes = _node_successes(paused_events)
    # Round-trip the paused state through serialization before resuming.
    snapshot = paused_state.dumps()
    resumed_state = GraphRuntimeState.from_snapshot(snapshot)
    generator = adv_app_gen_module.AdvancedChatAppGenerator()
    def _fake_generate(**kwargs):
        # Replace the real _generate: run the restored state to completion.
        state: GraphRuntimeState = kwargs["graph_runtime_state"]
        events = _run_with_optional_pause(state, pause_on=None)
        return _node_successes(events)
    mocker.patch.object(generator, "_generate", side_effect=_fake_generate)
    resumed_nodes = generator.resume(
        # NOTE(review): mode is "workflow" even though this exercises the
        # advanced-chat generator — confirm resume() ignores app_model.mode.
        app_model=SimpleNamespace(mode="workflow"),
        workflow=SimpleNamespace(),
        user=SimpleNamespace(),
        conversation=SimpleNamespace(id="conv"),
        message=SimpleNamespace(id="msg"),
        application_generate_entity=SimpleNamespace(stream=False, invoke_from=InvokeFrom.SERVICE_API),
        workflow_execution_repository=SimpleNamespace(),
        workflow_node_execution_repository=SimpleNamespace(),
        graph_runtime_state=resumed_state,
    )
    assert paused_nodes + resumed_nodes == baseline_nodes
    assert resumed_state.outputs == baseline_outputs
def test_resume_emits_resumption_start_reason(mocker) -> None:
    """First run starts with INITIAL; a resumed run starts with RESUMPTION."""
    _patch_tool_node(mocker)
    paused_state = _build_runtime_state("resume-reason")
    paused_events = _run_with_optional_pause(paused_state, pause_on="tool_a")
    first_start = next(e for e in paused_events if isinstance(e, GraphRunStartedEvent))
    assert first_start.reason == WorkflowStartReason.INITIAL
    restored = GraphRuntimeState.from_snapshot(paused_state.dumps())
    resumed_events = _run_with_optional_pause(restored, pause_on=None)
    second_start = next(e for e in resumed_events if isinstance(e, GraphRunStartedEvent))
    assert second_start.reason == WorkflowStartReason.RESUMPTION

View File

@ -0,0 +1,80 @@
from __future__ import annotations
import json
import queue
import pytest
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.entities.task_entities import StreamEvent
from models.model import AppMode
class FakeSubscription:
    """In-memory subscription backed by a shared queue.

    Entering the context flags the shared state as subscribed; exiting closes
    the subscription, after which receive() always returns None.
    """

    def __init__(self, message_queue: queue.Queue[bytes], state: dict[str, bool]) -> None:
        self._queue = message_queue
        self._state = state
        self._closed = False

    def __enter__(self):
        # Mark the shared state so tests can observe subscription ordering.
        self._state["subscribed"] = True
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def close(self) -> None:
        self._closed = True

    def receive(self, timeout: float | None = 0.1) -> bytes | None:
        """Pop the next payload, or None when closed or timed out."""
        if self._closed:
            return None
        try:
            # timeout=None means block indefinitely, mirroring queue semantics.
            return self._queue.get() if timeout is None else self._queue.get(timeout=timeout)
        except queue.Empty:
            return None
class FakeTopic:
    """Topic double whose subscriptions share one queue and one subscription flag."""

    def __init__(self) -> None:
        self._queue: queue.Queue[bytes] = queue.Queue()
        self._state = {"subscribed": False}

    def subscribe(self) -> FakeSubscription:
        # Every subscription observes the same queue and flag as the topic.
        return FakeSubscription(self._queue, self._state)

    def publish(self, payload: bytes) -> None:
        # Unbounded queue: put_nowait never raises Full here.
        self._queue.put_nowait(payload)

    @property
    def subscribed(self) -> bool:
        return self._state["subscribed"]
def test_retrieve_events_calls_on_subscribe_after_subscription(monkeypatch):
    """on_subscribe must fire only after the topic subscription is active.

    The callback publishes the terminal event itself, so ordering matters:
    publishing before subscribing would lose the message.
    """
    topic = FakeTopic()

    def fake_get_response_topic(cls, app_mode, workflow_run_id):
        return topic

    monkeypatch.setattr(MessageBasedAppGenerator, "get_response_topic", classmethod(fake_get_response_topic))

    def on_subscribe() -> None:
        # Proves the subscription was established before the callback ran.
        assert topic.subscribed is True
        event = {"event": StreamEvent.WORKFLOW_FINISHED.value}
        topic.publish(json.dumps(event).encode())

    generator = MessageBasedAppGenerator.retrieve_events(
        AppMode.WORKFLOW,
        "workflow-run-id",
        idle_timeout=0.5,
        on_subscribe=on_subscribe,
    )
    # First yield is a keep-alive ping, then the published terminal event.
    assert next(generator) == StreamEvent.PING.value
    event = next(generator)
    assert event["event"] == StreamEvent.WORKFLOW_FINISHED.value
    # The workflow-finished event ends the stream.
    with pytest.raises(StopIteration):
        next(generator)

View File

@ -1,3 +1,6 @@
from types import SimpleNamespace
from unittest.mock import MagicMock
from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY, WorkflowAppGenerator
@ -17,3 +20,193 @@ def test_should_prepare_user_inputs_keeps_validation_when_flag_false():
args = {"inputs": {}, SKIP_PREPARE_USER_INPUTS_KEY: False}
assert WorkflowAppGenerator()._should_prepare_user_inputs(args)
def test_resume_delegates_to_generate(mocker):
    """resume() forwards runtime state and pause config straight into _generate."""
    generator = WorkflowAppGenerator()
    generate_stub = mocker.patch.object(generator, "_generate", return_value="ok")
    entity = SimpleNamespace(stream=False, invoke_from="debugger")
    state = MagicMock(name="runtime-state")
    pause_cfg = MagicMock(name="pause-config")

    outcome = generator.resume(
        app_model=MagicMock(),
        workflow=MagicMock(),
        user=MagicMock(),
        application_generate_entity=entity,
        graph_runtime_state=state,
        workflow_execution_repository=MagicMock(),
        workflow_node_execution_repository=MagicMock(),
        graph_engine_layers=("layer",),
        pause_state_config=pause_cfg,
        variable_loader=MagicMock(),
    )

    assert outcome == "ok"
    generate_stub.assert_called_once()
    forwarded = generate_stub.call_args.kwargs
    # Streaming flag and invoke_from are lifted off the generate entity.
    assert forwarded["graph_runtime_state"] is state
    assert forwarded["pause_state_config"] is pause_cfg
    assert forwarded["streaming"] is False
    assert forwarded["invoke_from"] == "debugger"
def test_generate_appends_pause_layer_and_forwards_state(mocker):
    """_generate must append the pause-persistence layer after caller layers and
    hand the provided graph_runtime_state to the worker thread unchanged."""
    generator = WorkflowAppGenerator()
    mock_queue_manager = MagicMock()
    mocker.patch("core.app.apps.workflow.app_generator.WorkflowAppQueueManager", return_value=mock_queue_manager)
    # current_app is a Flask proxy; _get_current_object must yield a real app object.
    fake_current_app = MagicMock()
    fake_current_app._get_current_object.return_value = MagicMock()
    mocker.patch("core.app.apps.workflow.app_generator.current_app", fake_current_app)
    mocker.patch(
        "core.app.apps.workflow.app_generator.WorkflowAppGenerateResponseConverter.convert",
        return_value="converted",
    )
    mocker.patch.object(WorkflowAppGenerator, "_handle_response", return_value="response")
    mocker.patch.object(WorkflowAppGenerator, "_get_draft_var_saver_factory", return_value=MagicMock())
    pause_layer = MagicMock(name="pause-layer")
    mocker.patch(
        "core.app.apps.workflow.app_generator.PauseStatePersistenceLayer",
        return_value=pause_layer,
    )
    dummy_session = MagicMock()
    dummy_session.close = MagicMock()
    mocker.patch("core.app.apps.workflow.app_generator.db.session", dummy_session)
    worker_kwargs: dict[str, object] = {}

    class DummyThread:
        # Captures the worker target and kwargs instead of spawning a real thread.
        def __init__(self, target, kwargs):
            worker_kwargs["target"] = target
            worker_kwargs["kwargs"] = kwargs

        def start(self):
            return None

    mocker.patch("core.app.apps.workflow.app_generator.threading.Thread", DummyThread)
    app_model = SimpleNamespace(mode="workflow")
    app_config = SimpleNamespace(app_id="app", tenant_id="tenant", workflow_id="wf")
    application_generate_entity = SimpleNamespace(
        task_id="task",
        user_id="user",
        invoke_from="service-api",
        app_config=app_config,
        files=[],
        stream=True,
        workflow_execution_id="run",
    )
    graph_runtime_state = MagicMock()
    result = generator._generate(
        app_model=app_model,
        workflow=MagicMock(),
        user=MagicMock(),
        application_generate_entity=application_generate_entity,
        invoke_from="service-api",
        workflow_execution_repository=MagicMock(),
        workflow_node_execution_repository=MagicMock(),
        streaming=True,
        graph_engine_layers=("base-layer",),
        graph_runtime_state=graph_runtime_state,
        pause_state_config=SimpleNamespace(session_factory=MagicMock(), state_owner_user_id="owner"),
    )
    assert result == "converted"
    # Pause layer is appended after the caller-provided layers, in order.
    assert worker_kwargs["kwargs"]["graph_engine_layers"] == ("base-layer", pause_layer)
    assert worker_kwargs["kwargs"]["graph_runtime_state"] is graph_runtime_state
def test_resume_path_runs_worker_with_runtime_state(mocker):
    """End-to-end resume: the worker runs synchronously and the runner is
    constructed with the caller-provided graph_runtime_state."""
    generator = WorkflowAppGenerator()
    runtime_state = MagicMock(name="runtime-state")
    pause_layer = MagicMock(name="pause-layer")
    mocker.patch("core.app.apps.workflow.app_generator.PauseStatePersistenceLayer", return_value=pause_layer)
    queue_manager = MagicMock()
    mocker.patch("core.app.apps.workflow.app_generator.WorkflowAppQueueManager", return_value=queue_manager)
    mocker.patch.object(generator, "_handle_response", return_value="raw-response")
    mocker.patch(
        "core.app.apps.workflow.app_generator.WorkflowAppGenerateResponseConverter.convert",
        side_effect=lambda response, invoke_from: response,
    )
    fake_db = SimpleNamespace(session=MagicMock(), engine=MagicMock())
    mocker.patch("core.app.apps.workflow.app_generator.db", fake_db)
    workflow = SimpleNamespace(
        id="workflow", tenant_id="tenant", app_id="app", graph_dict={}, type="workflow", version="1"
    )
    end_user = SimpleNamespace(session_id="end-user-session")
    app_record = SimpleNamespace(id="app")
    session = MagicMock()
    session.__enter__.return_value = session
    session.__exit__.return_value = False
    # Order matters: the worker looks up workflow, then end user, then app.
    session.scalar.side_effect = [workflow, end_user, app_record]
    mocker.patch("core.app.apps.workflow.app_generator.session_factory", return_value=session)
    runner_instance = MagicMock()

    def runner_ctor(**kwargs):
        # The runner must be built with the resumed runtime state.
        assert kwargs["graph_runtime_state"] is runtime_state
        return runner_instance

    mocker.patch("core.app.apps.workflow.app_generator.WorkflowAppRunner", side_effect=runner_ctor)

    class ImmediateThread:
        # Runs the worker inline so assertions see its side effects synchronously.
        def __init__(self, target, kwargs):
            target(**kwargs)

        def start(self):
            return None

    mocker.patch("core.app.apps.workflow.app_generator.threading.Thread", ImmediateThread)
    mocker.patch(
        "core.app.apps.workflow.app_generator.DifyCoreRepositoryFactory.create_workflow_execution_repository",
        return_value=MagicMock(),
    )
    mocker.patch(
        "core.app.apps.workflow.app_generator.DifyCoreRepositoryFactory.create_workflow_node_execution_repository",
        return_value=MagicMock(),
    )
    pause_config = SimpleNamespace(session_factory=MagicMock(), state_owner_user_id="owner")
    app_model = SimpleNamespace(mode="workflow")
    app_config = SimpleNamespace(app_id="app", tenant_id="tenant", workflow_id="workflow")
    application_generate_entity = SimpleNamespace(
        task_id="task",
        user_id="user",
        invoke_from="service-api",
        app_config=app_config,
        files=[],
        stream=True,
        workflow_execution_id="run",
        trace_manager=MagicMock(),
    )
    result = generator.resume(
        app_model=app_model,
        workflow=workflow,
        user=MagicMock(),
        application_generate_entity=application_generate_entity,
        graph_runtime_state=runtime_state,
        workflow_execution_repository=MagicMock(),
        workflow_node_execution_repository=MagicMock(),
        pause_state_config=pause_config,
    )
    assert result == "raw-response"
    runner_instance.run.assert_called_once()
    # NOTE(review): this is an assignment, not an assertion, so it verifies nothing.
    # It looks like it was meant to be
    # `assert queue_manager.graph_runtime_state is runtime_state` — confirm intent.
    queue_manager.graph_runtime_state = runtime_state

View File

@ -0,0 +1,59 @@
from unittest.mock import MagicMock
import pytest
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
from core.app.entities.queue_entities import QueueWorkflowPausedEvent
from core.workflow.entities.pause_reason import HumanInputRequired
from core.workflow.graph_events.graph import GraphRunPausedEvent
class _DummyQueueManager:
def __init__(self):
self.published = []
def publish(self, event, _from):
self.published.append(event)
class _DummyRuntimeState:
def get_paused_nodes(self):
return ["node-1"]
class _DummyGraphEngine:
    """Graph-engine double exposing only a dummy runtime state."""

    def __init__(self):
        self.graph_runtime_state = _DummyRuntimeState()
class _DummyWorkflowEntry:
    """Workflow-entry double wrapping a dummy graph engine."""

    def __init__(self):
        self.graph_engine = _DummyGraphEngine()
def test_handle_pause_event_enqueues_email_task(monkeypatch: pytest.MonkeyPatch):
    """Handling a GraphRunPausedEvent with a HumanInputRequired reason must
    enqueue the notification-email Celery task and publish a pause queue event."""
    queue_manager = _DummyQueueManager()
    runner = WorkflowBasedAppRunner(queue_manager=queue_manager, app_id="app-id")
    workflow_entry = _DummyWorkflowEntry()
    reason = HumanInputRequired(
        form_id="form-123",
        form_content="content",
        inputs=[],
        actions=[],
        node_id="node-1",
        node_title="Review",
    )
    event = GraphRunPausedEvent(reasons=[reason], outputs={})
    email_task = MagicMock()
    monkeypatch.setattr("core.app.apps.workflow_app_runner.dispatch_human_input_email_task", email_task)
    runner._handle_event(workflow_entry, event)
    email_task.apply_async.assert_called_once()
    # The task is invoked with keyword args carried under the "kwargs" key.
    kwargs = email_task.apply_async.call_args.kwargs["kwargs"]
    assert kwargs["form_id"] == "form-123"
    assert kwargs["node_title"] == "Review"
    assert any(isinstance(evt, QueueWorkflowPausedEvent) for evt in queue_manager.published)

View File

@ -0,0 +1,183 @@
from datetime import UTC, datetime
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from core.app.apps.common import workflow_response_converter
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.apps.workflow.app_runner import WorkflowAppRunner
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueWorkflowPausedEvent
from core.app.entities.task_entities import HumanInputRequiredResponse, WorkflowPauseStreamResponse
from core.workflow.entities.pause_reason import HumanInputRequired
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
from core.workflow.graph_events.graph import GraphRunPausedEvent
from core.workflow.nodes.human_input.entities import FormInput, UserAction
from core.workflow.nodes.human_input.enums import FormInputType
from core.workflow.system_variable import SystemVariable
from models.account import Account
class _RecordingWorkflowAppRunner(WorkflowAppRunner):
    """Runner subclass that captures published events for later inspection."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.published_events = []

    def _publish_event(self, event):
        # Record instead of forwarding to the real queue manager.
        self.published_events.append(event)
class _FakeRuntimeState:
def get_paused_nodes(self):
return ["node-pause-1"]
def _build_runner():
    """Create a recording workflow runner wired with lightweight stand-ins."""
    generate_entity = SimpleNamespace(
        app_config=SimpleNamespace(app_id="app-id"),
        inputs={},
        files=[],
        invoke_from=InvokeFrom.SERVICE_API,
        single_iteration_run=None,
        single_loop_run=None,
        workflow_execution_id="run-id",
        user_id="user-id",
    )
    workflow_stub = SimpleNamespace(
        graph_dict={},
        tenant_id="tenant-id",
        environment_variables={},
        id="workflow-id",
    )
    # Queue manager that silently drops everything; the runner records its own events.
    silent_queue_manager = SimpleNamespace(publish=lambda event, pub_from: None)
    return _RecordingWorkflowAppRunner(
        application_generate_entity=generate_entity,
        queue_manager=silent_queue_manager,
        variable_loader=MagicMock(),
        workflow=workflow_stub,
        system_user_id="sys-user",
        root_node_id=None,
        workflow_execution_repository=MagicMock(),
        workflow_node_execution_repository=MagicMock(),
        graph_engine_layers=(),
        graph_runtime_state=None,
    )
def test_graph_run_paused_event_emits_queue_pause_event():
    """A GraphRunPausedEvent must be translated into a QueueWorkflowPausedEvent
    carrying the reasons, outputs, and the engine's paused node ids."""
    runner = _build_runner()
    reason = HumanInputRequired(
        form_id="form-1",
        form_content="content",
        inputs=[],
        actions=[],
        node_id="node-human",
        node_title="Human Step",
        form_token="tok",
    )
    event = GraphRunPausedEvent(reasons=[reason], outputs={"foo": "bar"})
    workflow_entry = SimpleNamespace(
        graph_engine=SimpleNamespace(graph_runtime_state=_FakeRuntimeState()),
    )
    runner._handle_event(workflow_entry, event)
    assert len(runner.published_events) == 1
    queue_event = runner.published_events[0]
    assert isinstance(queue_event, QueueWorkflowPausedEvent)
    assert queue_event.reasons == [reason]
    assert queue_event.outputs == {"foo": "bar"}
    # Paused node ids come from the engine runtime state, not the event.
    assert queue_event.paused_nodes == ["node-pause-1"]
def _build_converter():
    """Build a WorkflowResponseConverter with a mocked account and fixed system vars."""
    generate_entity = SimpleNamespace(
        inputs={},
        files=[],
        invoke_from=InvokeFrom.SERVICE_API,
        app_config=SimpleNamespace(app_id="app-id", tenant_id="tenant-id"),
    )
    sys_vars = SystemVariable(
        user_id="user",
        app_id="app-id",
        workflow_id="workflow-id",
        workflow_execution_id="run-id",
    )
    account = MagicMock(spec=Account)
    account.id = "account-id"
    account.name = "Tester"
    account.email = "tester@example.com"
    return WorkflowResponseConverter(
        application_generate_entity=generate_entity,
        user=account,
        system_variables=sys_vars,
    )
def test_queue_workflow_paused_event_to_stream_responses(monkeypatch: pytest.MonkeyPatch):
    """A pause queue event converts into a HumanInputRequiredResponse followed by
    a WorkflowPauseStreamResponse, with the form expiration read from the DB."""
    converter = _build_converter()
    # Prime the converter with a started run so the pause response can reference it.
    converter.workflow_start_to_stream_response(
        task_id="task",
        workflow_run_id="run-id",
        workflow_id="workflow-id",
        reason=WorkflowStartReason.INITIAL,
    )
    expiration_time = datetime(2024, 1, 1, tzinfo=UTC)

    class _FakeSession:
        # Returns the (form_id, expiration) row the converter queries for.
        def execute(self, _stmt):
            return [("form-1", expiration_time)]

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc, tb):
            return False

    # NOTE(review): the lambda accepts only keyword args — assumes the converter
    # calls Session(bind=...)-style with keywords; confirm against the implementation.
    monkeypatch.setattr(workflow_response_converter, "Session", lambda **_: _FakeSession())
    monkeypatch.setattr(workflow_response_converter, "db", SimpleNamespace(engine=object()))
    reason = HumanInputRequired(
        form_id="form-1",
        form_content="Rendered",
        inputs=[
            FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="field", default=None),
        ],
        actions=[UserAction(id="approve", title="Approve")],
        display_in_ui=True,
        node_id="node-id",
        node_title="Human Step",
        form_token="token",
    )
    queue_event = QueueWorkflowPausedEvent(
        reasons=[reason],
        outputs={"answer": "value"},
        paused_nodes=["node-id"],
    )
    runtime_state = SimpleNamespace(total_tokens=0, node_run_steps=0)
    responses = converter.workflow_pause_to_stream_response(
        event=queue_event,
        task_id="task",
        graph_runtime_state=runtime_state,
    )
    assert isinstance(responses[-1], WorkflowPauseStreamResponse)
    pause_resp = responses[-1]
    assert pause_resp.workflow_run_id == "run-id"
    assert pause_resp.data.paused_nodes == ["node-id"]
    # The event's outputs are not surfaced in the pause stream response.
    assert pause_resp.data.outputs == {}
    assert pause_resp.data.reasons[0]["form_id"] == "form-1"
    assert pause_resp.data.reasons[0]["display_in_ui"] is True
    assert isinstance(responses[0], HumanInputRequiredResponse)
    hi_resp = responses[0]
    assert hi_resp.data.form_id == "form-1"
    assert hi_resp.data.node_id == "node-id"
    assert hi_resp.data.node_title == "Human Step"
    assert hi_resp.data.inputs[0].output_variable_name == "field"
    assert hi_resp.data.actions[0].id == "approve"
    assert hi_resp.data.display_in_ui is True
    # Expiration is surfaced as a unix timestamp from the DB row.
    assert hi_resp.data.expiration_time == int(expiration_time.timestamp())

View File

@ -0,0 +1,96 @@
import time
from contextlib import contextmanager
from unittest.mock import MagicMock
from core.app.app_config.entities import WorkflowUIBasedAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import QueueWorkflowStartedEvent
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from models.account import Account
from models.model import AppMode
def _build_workflow_app_config() -> WorkflowUIBasedAppConfig:
    """App-config fixture for a plain workflow-mode app."""
    return WorkflowUIBasedAppConfig(
        tenant_id="tenant-id",
        app_id="app-id",
        app_mode=AppMode.WORKFLOW,
        workflow_id="workflow-id",
    )
def _build_generate_entity(run_id: str) -> WorkflowAppGenerateEntity:
    """Non-streaming service-API generate entity bound to the given run id."""
    return WorkflowAppGenerateEntity(
        task_id="task-id",
        app_config=_build_workflow_app_config(),
        inputs={},
        files=[],
        user_id="user-id",
        stream=False,
        invoke_from=InvokeFrom.SERVICE_API,
        workflow_execution_id=run_id,
    )
def _build_runtime_state(run_id: str) -> GraphRuntimeState:
    """Runtime state whose system variables carry the given execution id."""
    pool = VariablePool(
        system_variables=SystemVariable(workflow_execution_id=run_id),
        user_inputs={},
        conversation_variables=[],
    )
    return GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter())
@contextmanager
def _noop_session():
yield MagicMock()
def _build_pipeline(run_id: str) -> WorkflowAppGenerateTaskPipeline:
    """Assemble a pipeline with a mocked queue manager/workflow and a no-op DB session."""
    queue_manager = MagicMock(spec=AppQueueManager)
    queue_manager.invoke_from = InvokeFrom.SERVICE_API
    queue_manager.graph_runtime_state = _build_runtime_state(run_id)
    workflow = MagicMock()
    workflow.id = "workflow-id"
    workflow.features_dict = {}
    user = Account(name="user", email="user@example.com")
    pipeline = WorkflowAppGenerateTaskPipeline(
        application_generate_entity=_build_generate_entity(run_id),
        workflow=workflow,
        queue_manager=queue_manager,
        user=user,
        stream=False,
        draft_var_saver_factory=MagicMock(),
    )
    # Bypass real database access during event handling.
    pipeline._database_session = _noop_session
    return pipeline
def test_workflow_app_log_saved_only_on_initial_start() -> None:
    """An INITIAL start event persists the app log exactly once with the run id."""
    run_id = "run-initial"
    pipeline = _build_pipeline(run_id)
    pipeline._save_workflow_app_log = MagicMock()

    start_event = QueueWorkflowStartedEvent(reason=WorkflowStartReason.INITIAL)
    list(pipeline._handle_workflow_started_event(start_event))

    pipeline._save_workflow_app_log.assert_called_once()
    _, call_kwargs = pipeline._save_workflow_app_log.call_args
    assert call_kwargs["workflow_run_id"] == run_id
    assert pipeline._workflow_execution_id == run_id
def test_workflow_app_log_skipped_on_resumption_start() -> None:
    """A RESUMPTION start event must not write a second app log entry."""
    run_id = "run-resume"
    pipeline = _build_pipeline(run_id)
    pipeline._save_workflow_app_log = MagicMock()

    start_event = QueueWorkflowStartedEvent(reason=WorkflowStartReason.RESUMPTION)
    list(pipeline._handle_workflow_started_event(start_event))

    pipeline._save_workflow_app_log.assert_not_called()
    # The execution id is still tracked even when no log is written.
    assert pipeline._workflow_execution_id == run_id

View File

@ -0,0 +1,143 @@
import json
from collections.abc import Callable
from dataclasses import dataclass
import pytest
from core.app.app_config.entities import WorkflowUIBasedAppConfig
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
InvokeFrom,
WorkflowAppGenerateEntity,
)
from core.app.layers.pause_state_persist_layer import (
WorkflowResumptionContext,
_AdvancedChatAppGenerateEntityWrapper,
_WorkflowGenerateEntityWrapper,
)
from core.ops.ops_trace_manager import TraceQueueManager
from models.model import AppMode
class TraceQueueManagerStub(TraceQueueManager):
    """Minimal TraceQueueManager stub that avoids Flask dependencies.

    Used only so generate entities can carry a non-None trace_manager; the
    serialization tests then verify it is excluded from the JSON dump.
    """

    def __init__(self):
        # Skip parent initialization to avoid starting timers or accessing Flask globals.
        pass
def _build_workflow_app_config(app_mode: AppMode) -> WorkflowUIBasedAppConfig:
    """Workflow app config whose workflow id is derived from the app mode."""
    return WorkflowUIBasedAppConfig(
        tenant_id="tenant-id",
        app_id="app-id",
        app_mode=app_mode,
        workflow_id=f"{app_mode.value}-workflow-id",
    )
def _create_workflow_generate_entity(trace_manager: TraceQueueManager | None = None) -> WorkflowAppGenerateEntity:
    """Workflow generate entity with representative values for roundtrip tests."""
    return WorkflowAppGenerateEntity(
        task_id="workflow-task",
        app_config=_build_workflow_app_config(AppMode.WORKFLOW),
        inputs={"topic": "serialization"},
        files=[],
        user_id="user-workflow",
        stream=True,
        invoke_from=InvokeFrom.DEBUGGER,
        call_depth=1,
        trace_manager=trace_manager,
        workflow_execution_id="workflow-exec-id",
        extras={"external_trace_id": "trace-id"},
    )
def _create_advanced_chat_generate_entity(
    trace_manager: TraceQueueManager | None = None,
) -> AdvancedChatAppGenerateEntity:
    """Advanced-chat generate entity with representative values for roundtrip tests."""
    return AdvancedChatAppGenerateEntity(
        task_id="advanced-task",
        app_config=_build_workflow_app_config(AppMode.ADVANCED_CHAT),
        conversation_id="conversation-id",
        inputs={"topic": "roundtrip"},
        files=[],
        user_id="user-advanced",
        stream=False,
        invoke_from=InvokeFrom.DEBUGGER,
        query="Explain serialization",
        extras={"auto_generate_conversation_name": True},
        trace_manager=trace_manager,
        workflow_run_id="workflow-run-id",
    )
def test_workflow_app_generate_entity_roundtrip_excludes_trace_manager():
    """trace_manager is excluded from JSON and comes back as None after a roundtrip."""
    original = _create_workflow_generate_entity(trace_manager=TraceQueueManagerStub())
    as_json = original.model_dump_json()
    assert "trace_manager" not in json.loads(as_json)
    rebuilt = WorkflowAppGenerateEntity.model_validate_json(as_json)
    assert rebuilt.model_dump() == original.model_dump()
    assert rebuilt.trace_manager is None
def test_advanced_chat_generate_entity_roundtrip_excludes_trace_manager():
    """Same exclusion contract as the workflow entity, for advanced chat."""
    original = _create_advanced_chat_generate_entity(trace_manager=TraceQueueManagerStub())
    as_json = original.model_dump_json()
    assert "trace_manager" not in json.loads(as_json)
    rebuilt = AdvancedChatAppGenerateEntity.model_validate_json(as_json)
    assert rebuilt.model_dump() == original.model_dump()
    assert rebuilt.trace_manager is None
@dataclass(frozen=True)
class ResumptionContextCase:
    """Named parametrization case for resumption-context roundtrip tests."""

    # Human-readable case name (also used as the pytest param id).
    name: str
    # Zero-arg factory returning (context, expected generate-entity type).
    context_factory: Callable[[], tuple[WorkflowResumptionContext, type]]
def _workflow_resumption_case() -> tuple[WorkflowResumptionContext, type]:
    """Build a workflow-app resumption context plus its expected entity type."""
    generate_entity = _create_workflow_generate_entity(trace_manager=TraceQueueManagerStub())
    context = WorkflowResumptionContext(
        serialized_graph_runtime_state=json.dumps({"state": "workflow"}),
        generate_entity=_WorkflowGenerateEntityWrapper(entity=generate_entity),
    )
    return context, WorkflowAppGenerateEntity
def _advanced_chat_resumption_case() -> tuple[WorkflowResumptionContext, type]:
    """Build an advanced-chat resumption context plus its expected entity type."""
    generate_entity = _create_advanced_chat_generate_entity(trace_manager=TraceQueueManagerStub())
    context = WorkflowResumptionContext(
        serialized_graph_runtime_state=json.dumps({"state": "advanced"}),
        generate_entity=_AdvancedChatAppGenerateEntityWrapper(entity=generate_entity),
    )
    return context, AdvancedChatAppGenerateEntity
@pytest.mark.parametrize(
    "case",
    [
        pytest.param(ResumptionContextCase("workflow", _workflow_resumption_case), id="workflow"),
        pytest.param(ResumptionContextCase("advanced_chat", _advanced_chat_resumption_case), id="advanced_chat"),
    ],
)
def test_workflow_resumption_context_roundtrip(case: ResumptionContextCase):
    """dumps()/loads() round-trips the state and entity while dropping the trace manager."""
    original_context, expected_entity_type = case.context_factory()
    restored_context = WorkflowResumptionContext.loads(original_context.dumps())
    assert restored_context.serialized_graph_runtime_state == original_context.serialized_graph_runtime_state
    restored_entity = restored_context.get_generate_entity()
    assert isinstance(restored_entity, expected_entity_type)
    assert restored_entity.model_dump() == original_context.get_generate_entity().model_dump()
    # The trace manager is excluded from serialization by design.
    assert restored_entity.trace_manager is None

View File

@ -0,0 +1,72 @@
from types import SimpleNamespace
from unittest.mock import MagicMock
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig
from core.plugin.backwards_invocation.app import PluginAppBackwardsInvocation
from models.model import AppMode
def test_invoke_chat_app_advanced_chat_injects_pause_state_config(mocker):
    """Plugin backwards invocation of an advanced-chat app must inject a
    PauseStateLayerConfig owned by the workflow's creator."""
    workflow = MagicMock()
    workflow.created_by = "owner-id"
    app = MagicMock()
    app.mode = AppMode.ADVANCED_CHAT
    app.workflow = workflow
    # The config is built from db.engine, so db must be patched to a stub.
    mocker.patch(
        "core.plugin.backwards_invocation.app.db",
        SimpleNamespace(engine=MagicMock()),
    )
    generator_spy = mocker.patch(
        "core.plugin.backwards_invocation.app.AdvancedChatAppGenerator.generate",
        return_value={"result": "ok"},
    )
    result = PluginAppBackwardsInvocation.invoke_chat_app(
        app=app,
        user=MagicMock(),
        conversation_id="conv-1",
        query="hello",
        stream=False,
        inputs={"k": "v"},
        files=[],
    )
    assert result == {"result": "ok"}
    call_kwargs = generator_spy.call_args.kwargs
    pause_state_config = call_kwargs.get("pause_state_config")
    assert isinstance(pause_state_config, PauseStateLayerConfig)
    # Pause state ownership falls to the user who created the workflow.
    assert pause_state_config.state_owner_user_id == "owner-id"
def test_invoke_workflow_app_injects_pause_state_config(mocker):
    """Plugin backwards invocation of a workflow app must inject a
    PauseStateLayerConfig owned by the workflow's creator."""
    workflow = MagicMock()
    workflow.created_by = "owner-id"
    app = MagicMock()
    app.mode = AppMode.WORKFLOW
    app.workflow = workflow
    # The config is built from db.engine, so db must be patched to a stub.
    mocker.patch(
        "core.plugin.backwards_invocation.app.db",
        SimpleNamespace(engine=MagicMock()),
    )
    generator_spy = mocker.patch(
        "core.plugin.backwards_invocation.app.WorkflowAppGenerator.generate",
        return_value={"result": "ok"},
    )
    result = PluginAppBackwardsInvocation.invoke_workflow_app(
        app=app,
        user=MagicMock(),
        stream=False,
        inputs={"k": "v"},
        files=[],
    )
    assert result == {"result": "ok"}
    call_kwargs = generator_spy.call_args.kwargs
    pause_state_config = call_kwargs.get("pause_state_config")
    assert isinstance(pause_state_config, PauseStateLayerConfig)
    # Pause state ownership falls to the user who created the workflow.
    assert pause_state_config.state_owner_user_id == "owner-id"

View File

@ -0,0 +1,574 @@
"""Unit tests for HumanInputFormRepositoryImpl private helpers."""
from __future__ import annotations
import dataclasses
from datetime import datetime
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from core.repositories.human_input_repository import (
HumanInputFormRecord,
HumanInputFormRepositoryImpl,
HumanInputFormSubmissionRepository,
_WorkspaceMemberInfo,
)
from core.workflow.nodes.human_input.entities import (
EmailDeliveryConfig,
EmailDeliveryMethod,
EmailRecipients,
ExternalRecipient,
FormDefinition,
MemberRecipient,
UserAction,
)
from core.workflow.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus
from libs.datetime_utils import naive_utc_now
from models.human_input import (
EmailExternalRecipientPayload,
EmailMemberRecipientPayload,
HumanInputFormRecipient,
RecipientType,
)
def _build_repository() -> HumanInputFormRepositoryImpl:
    """Repository instance with a mocked session factory for a fixed tenant."""
    return HumanInputFormRepositoryImpl(session_factory=MagicMock(), tenant_id="tenant-id")
def _patch_recipient_factory(monkeypatch: pytest.MonkeyPatch) -> list[SimpleNamespace]:
    """Replace HumanInputFormRecipient.new with a recorder; return the records.

    Each created recipient is a SimpleNamespace mirroring the fields the real
    factory would set, so tests can inspect them without a database.
    """
    recorded: list[SimpleNamespace] = []

    def fake_new(cls, form_id: str, delivery_id: str, payload):  # type: ignore[no-untyped-def]
        stub = SimpleNamespace(
            form_id=form_id,
            delivery_id=delivery_id,
            recipient_type=payload.TYPE,
            recipient_payload=payload.model_dump_json(),
        )
        recorded.append(stub)
        return stub

    monkeypatch.setattr(HumanInputFormRecipient, "new", classmethod(fake_new))
    return recorded
@pytest.fixture(autouse=True)
def _stub_selectinload(monkeypatch: pytest.MonkeyPatch) -> None:
    """Avoid SQLAlchemy mapper configuration in tests using fake sessions."""

    class _FakeSelect:
        # Chainable no-op standing in for a SQLAlchemy Select object.
        def options(self, *_args, **_kwargs):  # type: ignore[no-untyped-def]
            return self

        def where(self, *_args, **_kwargs):  # type: ignore[no-untyped-def]
            return self

    # selectinload would trigger mapper configuration; replace with a dummy marker.
    monkeypatch.setattr(
        "core.repositories.human_input_repository.selectinload", lambda *args, **kwargs: "_loader_option"
    )
    monkeypatch.setattr("core.repositories.human_input_repository.select", lambda *args, **kwargs: _FakeSelect())
class TestHumanInputFormRepositoryImplHelpers:
    """Tests for the private email-recipient helpers of HumanInputFormRepositoryImpl."""

    def test_build_email_recipients_with_member_and_external(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """A member id and an external address each yield a recipient of the matching type."""
        repo = _build_repository()
        session_stub = object()
        _patch_recipient_factory(monkeypatch)

        def fake_query(self, session, restrict_to_user_ids):  # type: ignore[no-untyped-def]
            # The helper must forward the session and only the member ids.
            assert session is session_stub
            assert restrict_to_user_ids == ["member-1"]
            return [_WorkspaceMemberInfo(user_id="member-1", email="member@example.com")]

        monkeypatch.setattr(HumanInputFormRepositoryImpl, "_query_workspace_members_by_ids", fake_query)
        recipients = repo._build_email_recipients(
            session=session_stub,
            form_id="form-id",
            delivery_id="delivery-id",
            recipients_config=EmailRecipients(
                whole_workspace=False,
                items=[
                    MemberRecipient(user_id="member-1"),
                    ExternalRecipient(email="external@example.com"),
                ],
            ),
        )
        assert len(recipients) == 2
        member_recipient = next(r for r in recipients if r.recipient_type == RecipientType.EMAIL_MEMBER)
        external_recipient = next(r for r in recipients if r.recipient_type == RecipientType.EMAIL_EXTERNAL)
        member_payload = EmailMemberRecipientPayload.model_validate_json(member_recipient.recipient_payload)
        assert member_payload.user_id == "member-1"
        # The member's email is resolved from the workspace lookup, not the config.
        assert member_payload.email == "member@example.com"
        external_payload = EmailExternalRecipientPayload.model_validate_json(external_recipient.recipient_payload)
        assert external_payload.email == "external@example.com"

    def test_build_email_recipients_skips_unknown_members(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """Member ids with no matching workspace member are silently dropped."""
        repo = _build_repository()
        session_stub = object()
        created = _patch_recipient_factory(monkeypatch)

        def fake_query(self, session, restrict_to_user_ids):  # type: ignore[no-untyped-def]
            assert session is session_stub
            assert restrict_to_user_ids == ["missing-member"]
            # Simulate a member id that no longer exists in the workspace.
            return []

        monkeypatch.setattr(HumanInputFormRepositoryImpl, "_query_workspace_members_by_ids", fake_query)
        recipients = repo._build_email_recipients(
            session=session_stub,
            form_id="form-id",
            delivery_id="delivery-id",
            recipients_config=EmailRecipients(
                whole_workspace=False,
                items=[
                    MemberRecipient(user_id="missing-member"),
                    ExternalRecipient(email="external@example.com"),
                ],
            ),
        )
        assert len(recipients) == 1
        assert recipients[0].recipient_type == RecipientType.EMAIL_EXTERNAL
        assert len(created) == 1  # only external recipient created via factory

    def test_build_email_recipients_whole_workspace_uses_all_members(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """whole_workspace=True fans out to every workspace member."""
        repo = _build_repository()
        session_stub = object()
        _patch_recipient_factory(monkeypatch)

        def fake_query(self, session):  # type: ignore[no-untyped-def]
            assert session is session_stub
            return [
                _WorkspaceMemberInfo(user_id="member-1", email="member1@example.com"),
                _WorkspaceMemberInfo(user_id="member-2", email="member2@example.com"),
            ]

        monkeypatch.setattr(HumanInputFormRepositoryImpl, "_query_all_workspace_members", fake_query)
        recipients = repo._build_email_recipients(
            session=session_stub,
            form_id="form-id",
            delivery_id="delivery-id",
            recipients_config=EmailRecipients(
                whole_workspace=True,
                items=[],
            ),
        )
        assert len(recipients) == 2
        emails = {EmailMemberRecipientPayload.model_validate_json(r.recipient_payload).email for r in recipients}
        assert emails == {"member1@example.com", "member2@example.com"}

    def test_build_email_recipients_dedupes_external_by_email(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """Duplicate external addresses collapse to a single recipient."""
        repo = _build_repository()
        session_stub = object()
        created = _patch_recipient_factory(monkeypatch)

        def fake_query(self, session, restrict_to_user_ids):  # type: ignore[no-untyped-def]
            assert session is session_stub
            assert restrict_to_user_ids == []
            return []

        monkeypatch.setattr(HumanInputFormRepositoryImpl, "_query_workspace_members_by_ids", fake_query)
        recipients = repo._build_email_recipients(
            session=session_stub,
            form_id="form-id",
            delivery_id="delivery-id",
            recipients_config=EmailRecipients(
                whole_workspace=False,
                items=[
                    ExternalRecipient(email="external@example.com"),
                    ExternalRecipient(email="external@example.com"),
                ],
            ),
        )
        assert len(recipients) == 1
        assert len(created) == 1

    def test_build_email_recipients_prefers_member_over_external_by_email(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """When a member and an external entry share an email, only the member recipient is kept."""
        repo = _build_repository()
        session_stub = object()
        _patch_recipient_factory(monkeypatch)

        def fake_query(self, session, restrict_to_user_ids):  # type: ignore[no-untyped-def]
            assert session is session_stub
            assert restrict_to_user_ids == ["member-1"]
            return [_WorkspaceMemberInfo(user_id="member-1", email="shared@example.com")]

        monkeypatch.setattr(HumanInputFormRepositoryImpl, "_query_workspace_members_by_ids", fake_query)
        recipients = repo._build_email_recipients(
            session=session_stub,
            form_id="form-id",
            delivery_id="delivery-id",
            recipients_config=EmailRecipients(
                whole_workspace=False,
                items=[
                    MemberRecipient(user_id="member-1"),
                    ExternalRecipient(email="shared@example.com"),
                ],
            ),
        )
        assert len(recipients) == 1
        assert recipients[0].recipient_type == RecipientType.EMAIL_MEMBER

    def test_delivery_method_to_model_includes_external_recipients_with_whole_workspace(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """whole_workspace plus explicit external entries yields members + externals."""
        repo = _build_repository()
        session_stub = object()
        _patch_recipient_factory(monkeypatch)

        def fake_query(self, session):  # type: ignore[no-untyped-def]
            assert session is session_stub
            return [
                _WorkspaceMemberInfo(user_id="member-1", email="member1@example.com"),
                _WorkspaceMemberInfo(user_id="member-2", email="member2@example.com"),
            ]

        monkeypatch.setattr(HumanInputFormRepositoryImpl, "_query_all_workspace_members", fake_query)
        method = EmailDeliveryMethod(
            config=EmailDeliveryConfig(
                recipients=EmailRecipients(
                    whole_workspace=True,
                    items=[ExternalRecipient(email="external@example.com")],
                ),
                subject="subject",
                body="body",
            )
        )
        result = repo._delivery_method_to_model(session=session_stub, form_id="form-id", delivery_method=method)
        assert len(result.recipients) == 3
        member_emails = {
            EmailMemberRecipientPayload.model_validate_json(r.recipient_payload).email
            for r in result.recipients
            if r.recipient_type == RecipientType.EMAIL_MEMBER
        }
        assert member_emails == {"member1@example.com", "member2@example.com"}
        external_payload = EmailExternalRecipientPayload.model_validate_json(
            next(r for r in result.recipients if r.recipient_type == RecipientType.EMAIL_EXTERNAL).recipient_payload
        )
        assert external_payload.email == "external@example.com"
def _make_form_definition() -> str:
    """Serialize a minimal ``FormDefinition`` fixture to its JSON representation.

    Uses ``naive_utc_now()`` (consistent with the rest of this module, e.g.
    ``_DummyForm.created_at``) instead of the deprecated ``datetime.utcnow()``
    for the expiration timestamp.
    """
    return FormDefinition(
        form_content="hello",
        inputs=[],
        user_actions=[UserAction(id="submit", title="Submit")],
        rendered_content="<p>hello</p>",
        expiration_time=naive_utc_now(),
    ).model_dump_json()
@dataclasses.dataclass
class _DummyForm:
    """Plain stand-in for the HumanInputForm ORM model used by the fake session."""

    id: str
    workflow_run_id: str
    node_id: str
    tenant_id: str
    app_id: str
    form_definition: str  # JSON-serialized FormDefinition (see _make_form_definition)
    rendered_content: str
    expiration_time: datetime
    form_kind: HumanInputFormKind = HumanInputFormKind.RUNTIME
    # default_factory gives each instance its own creation timestamp.
    created_at: datetime = dataclasses.field(default_factory=naive_utc_now)
    # Submission-related fields stay unset until a (fake) submission happens.
    selected_action_id: str | None = None
    submitted_data: str | None = None
    submitted_at: datetime | None = None
    submission_user_id: str | None = None
    submission_end_user_id: str | None = None
    completed_by_recipient_id: str | None = None
    status: HumanInputFormStatus = HumanInputFormStatus.WAITING
@dataclasses.dataclass
class _DummyRecipient:
    """Plain stand-in for the HumanInputFormRecipient ORM model."""

    id: str
    form_id: str
    recipient_type: RecipientType
    access_token: str
    # Back-reference to the owning form; only set where a test needs it.
    form: _DummyForm | None = None
class _FakeScalarResult:
def __init__(self, obj):
self._obj = obj
def first(self):
if isinstance(self._obj, list):
return self._obj[0] if self._obj else None
return self._obj
def all(self):
if isinstance(self._obj, list):
return list(self._obj)
if self._obj is None:
return []
return [self._obj]
class _FakeSession:
    """In-memory stand-in for a SQLAlchemy session used by repository tests.

    ``scalars()`` pops canned results from a queue; ``get()`` resolves objects
    from the ``forms`` / ``recipients`` dicts by model class name.
    """

    def __init__(
        self,
        *,
        scalars_result=None,
        scalars_results: list[object] | None = None,
        forms: dict[str, _DummyForm] | None = None,
        recipients: dict[str, _DummyRecipient] | None = None,
    ):
        # Build the queue of canned results returned by successive scalars() calls.
        if scalars_results is not None:
            queued: list[object] = list(scalars_results)
        elif scalars_result is not None:
            queued = [scalars_result]
        else:
            queued = []
        self._scalars_queue = queued
        self.forms = forms or {}
        self.recipients = recipients or {}

    def scalars(self, _query):
        """Pop the next canned result (or ``None`` once exhausted)."""
        payload = self._scalars_queue.pop(0) if self._scalars_queue else None
        return _FakeScalarResult(payload)

    def get(self, model_cls, obj_id):  # type: ignore[no-untyped-def]
        """Emulate ``Session.get`` by dispatching on the model class name."""
        model_name = getattr(model_cls, "__name__", None)
        if model_name == "HumanInputForm":
            return self.forms.get(obj_id)
        if model_name == "HumanInputFormRecipient":
            return self.recipients.get(obj_id)
        return None

    # The write-side API is a no-op: tests inspect the objects directly.
    def add(self, _obj):
        return None

    def flush(self):
        return None

    def refresh(self, _obj):
        return None

    def begin(self):
        # The session doubles as its own transaction context manager.
        return self

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        return None
def _session_factory(session: _FakeSession):
    """Return a session-factory callable whose context manager yields *session*."""

    class _Ctx:
        def __enter__(self):
            return session

        def __exit__(self, _exc_type, _exc, _tb):
            return None

    def _make(*_args, **_kwargs):
        return _Ctx()

    return _make
class TestHumanInputFormRepositoryImplPublicMethods:
    """``HumanInputFormRepositoryImpl.get_form`` behavior over fake sessions."""

    def test_get_form_returns_entity_and_recipients(self):
        """A form row plus its recipients maps to an entity carrying the web token."""
        form = _DummyForm(
            id="form-1",
            workflow_run_id="run-1",
            node_id="node-1",
            tenant_id="tenant-id",
            app_id="app-id",
            form_definition=_make_form_definition(),
            rendered_content="<p>hello</p>",
            expiration_time=naive_utc_now(),
        )
        recipient = _DummyRecipient(
            id="recipient-1",
            form_id=form.id,
            recipient_type=RecipientType.STANDALONE_WEB_APP,
            access_token="token-123",
        )
        # First scalars() call returns the form row, second the recipient list.
        session = _FakeSession(scalars_results=[form, [recipient]])
        repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id")
        entity = repo.get_form(form.workflow_run_id, form.node_id)
        assert entity is not None
        assert entity.id == form.id
        assert entity.web_app_token == "token-123"
        assert len(entity.recipients) == 1
        assert entity.recipients[0].token == "token-123"

    def test_get_form_returns_none_when_missing(self):
        """A missing form row yields ``None`` instead of raising."""
        session = _FakeSession(scalars_results=[None])
        repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id")
        assert repo.get_form("run-1", "node-1") is None

    def test_get_form_returns_unsubmitted_state(self):
        """A waiting form reports submitted=False and no submission payload."""
        form = _DummyForm(
            id="form-1",
            workflow_run_id="run-1",
            node_id="node-1",
            tenant_id="tenant-id",
            app_id="app-id",
            form_definition=_make_form_definition(),
            rendered_content="<p>hello</p>",
            expiration_time=naive_utc_now(),
        )
        session = _FakeSession(scalars_results=[form, []])
        repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id")
        entity = repo.get_form(form.workflow_run_id, form.node_id)
        assert entity is not None
        assert entity.submitted is False
        assert entity.selected_action_id is None
        assert entity.submitted_data is None

    def test_get_form_returns_submission_when_completed(self):
        """A submitted form exposes the action id and the parsed submitted data."""
        form = _DummyForm(
            id="form-1",
            workflow_run_id="run-1",
            node_id="node-1",
            tenant_id="tenant-id",
            app_id="app-id",
            form_definition=_make_form_definition(),
            rendered_content="<p>hello</p>",
            expiration_time=naive_utc_now(),
            selected_action_id="approve",
            submitted_data='{"field": "value"}',
            submitted_at=naive_utc_now(),
        )
        session = _FakeSession(scalars_results=[form, []])
        repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id")
        entity = repo.get_form(form.workflow_run_id, form.node_id)
        assert entity is not None
        assert entity.submitted is True
        assert entity.selected_action_id == "approve"
        # JSON-encoded submitted_data is parsed back into a mapping.
        assert entity.submitted_data == {"field": "value"}
class TestHumanInputFormSubmissionRepository:
    """Token lookup and submission flows of ``HumanInputFormSubmissionRepository``."""

    def test_get_by_token_returns_record(self):
        """An access token resolves to the recipient's form record."""
        form = _DummyForm(
            id="form-1",
            workflow_run_id="run-1",
            node_id="node-1",
            tenant_id="tenant-1",
            app_id="app-1",
            form_definition=_make_form_definition(),
            rendered_content="<p>hello</p>",
            expiration_time=naive_utc_now(),
        )
        recipient = _DummyRecipient(
            id="recipient-1",
            form_id=form.id,
            recipient_type=RecipientType.STANDALONE_WEB_APP,
            access_token="token-123",
            form=form,
        )
        session = _FakeSession(scalars_result=recipient)
        repo = HumanInputFormSubmissionRepository(_session_factory(session))
        record = repo.get_by_token("token-123")
        assert record is not None
        assert record.form_id == form.id
        assert record.recipient_type == RecipientType.STANDALONE_WEB_APP
        assert record.submitted is False

    def test_get_by_form_id_and_recipient_type_uses_recipient(self):
        """Lookup by (form_id, recipient_type) surfaces the recipient details."""
        form = _DummyForm(
            id="form-1",
            workflow_run_id="run-1",
            node_id="node-1",
            tenant_id="tenant-1",
            app_id="app-1",
            form_definition=_make_form_definition(),
            rendered_content="<p>hello</p>",
            expiration_time=naive_utc_now(),
        )
        recipient = _DummyRecipient(
            id="recipient-1",
            form_id=form.id,
            recipient_type=RecipientType.STANDALONE_WEB_APP,
            access_token="token-123",
            form=form,
        )
        session = _FakeSession(scalars_result=recipient)
        repo = HumanInputFormSubmissionRepository(_session_factory(session))
        record = repo.get_by_form_id_and_recipient_type(
            form_id=form.id,
            recipient_type=RecipientType.STANDALONE_WEB_APP,
        )
        assert record is not None
        assert record.recipient_id == recipient.id
        assert record.access_token == recipient.access_token

    def test_mark_submitted_updates_fields(self, monkeypatch: pytest.MonkeyPatch):
        """``mark_submitted`` persists action, submitter ids, and the timestamp."""
        # Pin "now" so the submitted_at assertion below is deterministic.
        fixed_now = datetime(2024, 1, 1, 0, 0, 0)
        monkeypatch.setattr("core.repositories.human_input_repository.naive_utc_now", lambda: fixed_now)
        form = _DummyForm(
            id="form-1",
            workflow_run_id="run-1",
            node_id="node-1",
            tenant_id="tenant-1",
            app_id="app-1",
            form_definition=_make_form_definition(),
            rendered_content="<p>hello</p>",
            expiration_time=fixed_now,
        )
        recipient = _DummyRecipient(
            id="recipient-1",
            form_id="form-1",
            recipient_type=RecipientType.STANDALONE_WEB_APP,
            access_token="token-123",
        )
        session = _FakeSession(
            forms={form.id: form},
            recipients={recipient.id: recipient},
        )
        repo = HumanInputFormSubmissionRepository(_session_factory(session))
        record: HumanInputFormRecord = repo.mark_submitted(
            form_id=form.id,
            recipient_id=recipient.id,
            selected_action_id="approve",
            form_data={"field": "value"},
            submission_user_id="user-1",
            submission_end_user_id="end-user-1",
        )
        # The ORM object is mutated in place...
        assert form.selected_action_id == "approve"
        assert form.completed_by_recipient_id == recipient.id
        assert form.submission_user_id == "user-1"
        assert form.submission_end_user_id == "end-user-1"
        assert form.submitted_at == fixed_now
        # ...and the returned record reflects the submission.
        assert record.submitted is True
        assert record.selected_action_id == "approve"
        assert record.submitted_data == {"field": "value"}

View File

@ -0,0 +1,33 @@
import pytest
from core.tools.errors import WorkflowToolHumanInputNotSupportedError
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
def test_ensure_no_human_input_nodes_passes_for_non_human_input():
    """A graph containing only non human-input nodes validates without raising."""
    start_node = {
        "id": "start_node",
        "data": {"type": "start"},
    }
    graph = {"nodes": [start_node]}
    # Must not raise for ordinary node types.
    WorkflowToolConfigurationUtils.ensure_no_human_input_nodes(graph)
def test_ensure_no_human_input_nodes_raises_for_human_input():
    """A graph with a human-input node is rejected for workflow-as-tool usage."""
    human_node = {
        "id": "human_input_node",
        "data": {"type": "human-input"},
    }
    graph = {"nodes": [human_node]}
    with pytest.raises(WorkflowToolHumanInputNotSupportedError) as exc_info:
        WorkflowToolConfigurationUtils.ensure_no_human_input_nodes(graph)
    # The machine-readable error code is part of the public contract.
    assert exc_info.value.error_code == "workflow_tool_human_input_not_supported"

View File

@ -55,6 +55,43 @@ def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_fiel
assert exc_info.value.args == ("oops",)
def test_workflow_tool_does_not_use_pause_state_config(monkeypatch: pytest.MonkeyPatch):
    """WorkflowTool must always pass ``pause_state_config=None`` to the generator."""
    entity = ToolEntity(
        identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"),
        parameters=[],
        description=None,
        has_runtime_parameters=False,
    )
    runtime = ToolRuntime(tenant_id="test_tool", invoke_from=InvokeFrom.EXPLORE)
    tool = WorkflowTool(
        workflow_app_id="",
        workflow_as_tool_id="",
        version="1",
        workflow_entities={},
        workflow_call_depth=1,
        entity=entity,
        runtime=runtime,
    )
    # Short-circuit the DB lookups; the generator itself is mocked below.
    monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None)
    monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None)
    from unittest.mock import MagicMock, Mock

    mock_user = Mock()
    monkeypatch.setattr(tool, "_resolve_user", lambda *args, **kwargs: mock_user)
    generate_mock = MagicMock(return_value={"data": {}})
    monkeypatch.setattr("core.app.apps.workflow.app_generator.WorkflowAppGenerator.generate", generate_mock)
    monkeypatch.setattr("libs.login.current_user", lambda *args, **kwargs: None)
    # Drain the generator so the mocked generate() is actually invoked.
    list(tool.invoke("test_user", {}))
    call_kwargs = generate_mock.call_args.kwargs
    # The tool must not propagate any pause/resume configuration.
    assert "pause_state_config" in call_kwargs
    assert call_kwargs["pause_state_config"] is None
def test_workflow_tool_should_generate_variable_messages_for_outputs(monkeypatch: pytest.MonkeyPatch):
"""Test that WorkflowTool should generate variable messages when there are outputs"""
entity = ToolEntity(

View File

@ -118,7 +118,6 @@ class TestGraphRuntimeState:
from core.workflow.graph_engine.ready_queue import InMemoryReadyQueue
assert isinstance(queue, InMemoryReadyQueue)
assert state.ready_queue is queue
def test_graph_execution_lazy_instantiation(self):
state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time())

View File

@ -0,0 +1,88 @@
"""
Tests for PauseReason discriminated union serialization/deserialization.
"""
import pytest
from pydantic import BaseModel, ValidationError
from core.workflow.entities.pause_reason import (
HumanInputRequired,
PauseReason,
SchedulingPause,
)
class _Holder(BaseModel):
    """Helper model that embeds PauseReason for union tests."""

    # Discriminated union under test; resolved via the payload's "TYPE" key.
    reason: PauseReason
class TestPauseReasonDiscriminator:
    """Test suite for PauseReason union discriminator."""

    @pytest.mark.parametrize(
        ("dict_value", "expected"),
        [
            pytest.param(
                {
                    "reason": {
                        "TYPE": "human_input_required",
                        "form_id": "form_id",
                        "form_content": "form_content",
                        "node_id": "node_id",
                        "node_title": "node_title",
                    },
                },
                HumanInputRequired(
                    form_id="form_id",
                    form_content="form_content",
                    node_id="node_id",
                    node_title="node_title",
                ),
                id="HumanInputRequired",
            ),
            pytest.param(
                {
                    "reason": {
                        "TYPE": "scheduled_pause",
                        "message": "Hold on",
                    }
                },
                SchedulingPause(message="Hold on"),
                id="SchedulingPause",
            ),
        ],
    )
    def test_model_validate(self, dict_value, expected):
        """Payloads keyed by the "TYPE" discriminator deserialize to the right member."""
        holder = _Holder.model_validate(dict_value)
        # `is` compares class identity of the resolved union member (E721-clean).
        assert type(holder.reason) is type(expected)
        assert holder.reason == expected

    @pytest.mark.parametrize(
        "reason",
        [
            HumanInputRequired(
                form_id="form_id",
                form_content="form_content",
                node_id="node_id",
                node_title="node_title",
            ),
            SchedulingPause(message="Hold on"),
        ],
        ids=lambda x: type(x).__name__,
    )
    def test_model_construct(self, reason):
        """Already-constructed union members are accepted as-is."""
        holder = _Holder(reason=reason)
        assert holder.reason == reason

    def test_model_construct_with_invalid_type(self):
        """Objects outside the union must be rejected at construction time."""
        with pytest.raises(ValidationError):
            _Holder(reason=object())  # type: ignore

    def test_unknown_type_fails_validation(self):
        """Unknown TYPE values should raise a validation error."""
        with pytest.raises(ValidationError):
            _Holder.model_validate({"reason": {"TYPE": "UNKNOWN"}})

View File

@ -0,0 +1,131 @@
"""Utilities for testing HumanInputNode without database dependencies."""
from __future__ import annotations

from collections.abc import Mapping
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import Any

from core.workflow.nodes.human_input.enums import HumanInputFormStatus
from core.workflow.repositories.human_input_form_repository import (
    FormCreateParams,
    HumanInputFormEntity,
    HumanInputFormRecipientEntity,
    HumanInputFormRepository,
)
from libs.datetime_utils import naive_utc_now
class _InMemoryFormRecipient(HumanInputFormRecipientEntity):
    """Minimal recipient entity required by the repository interface."""

    def __init__(self, recipient_id: str, token: str) -> None:
        self._recipient_id = recipient_id
        self._access_token = token

    @property
    def id(self) -> str:
        """Opaque recipient identifier."""
        return self._recipient_id

    @property
    def token(self) -> str:
        """Access token associated with this recipient."""
        return self._access_token
@dataclass
class _InMemoryFormEntity(HumanInputFormEntity):
    """Mutable, dataclass-backed implementation of the form entity interface."""

    form_id: str
    rendered: str
    token: str | None = None
    action_id: str | None = None
    data: Mapping[str, Any] | None = None
    is_submitted: bool = False
    status_value: HumanInputFormStatus = HumanInputFormStatus.WAITING
    # Use default_factory so each instance gets a fresh timestamp; the previous
    # `= naive_utc_now()` default was evaluated once at class-definition time and
    # silently shared by every instance created afterwards.
    expiration: datetime = field(default_factory=naive_utc_now)

    @property
    def id(self) -> str:
        return self.form_id

    @property
    def web_app_token(self) -> str | None:
        return self.token

    @property
    def recipients(self) -> list[HumanInputFormRecipientEntity]:
        # In-memory forms never track recipients.
        return []

    @property
    def rendered_content(self) -> str:
        return self.rendered

    @property
    def selected_action_id(self) -> str | None:
        return self.action_id

    @property
    def submitted_data(self) -> Mapping[str, Any] | None:
        return self.data

    @property
    def submitted(self) -> bool:
        return self.is_submitted

    @property
    def status(self) -> HumanInputFormStatus:
        return self.status_value

    @property
    def expiration_time(self) -> datetime:
        return self.expiration
class InMemoryHumanInputFormRepository(HumanInputFormRepository):
    """Pure in-memory repository used by workflow graph engine tests."""

    def __init__(self) -> None:
        self._form_counter = 0
        self.created_params: list[FormCreateParams] = []
        self.created_forms: list[_InMemoryFormEntity] = []
        self._forms_by_key: dict[tuple[str, str], _InMemoryFormEntity] = {}

    def create_form(self, params: FormCreateParams) -> HumanInputFormEntity:
        """Record *params* and return a freshly numbered in-memory form entity."""
        self.created_params.append(params)
        self._form_counter += 1
        form_id = f"form-{self._form_counter}"
        # Console recipients get a distinguishable token prefix.
        token_prefix = "console" if params.console_recipient_required else "token"
        entity = _InMemoryFormEntity(
            form_id=form_id,
            rendered=params.rendered_content,
            token=f"{token_prefix}-{form_id}",
        )
        self.created_forms.append(entity)
        self._forms_by_key[(params.workflow_execution_id, params.node_id)] = entity
        return entity

    def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None:
        """Look up a previously created form by (execution id, node id)."""
        return self._forms_by_key.get((workflow_execution_id, node_id))

    # Convenience helpers for tests -------------------------------------

    def set_submission(self, *, action_id: str, form_data: Mapping[str, Any] | None = None) -> None:
        """Simulate a human submission for the next repository lookup."""
        if not self.created_forms:
            raise AssertionError("no form has been created to attach submission data")
        latest = self.created_forms[-1]
        latest.action_id = action_id
        latest.data = form_data or {}
        latest.is_submitted = True
        latest.status_value = HumanInputFormStatus.SUBMITTED
        latest.expiration = naive_utc_now() + timedelta(days=1)

    def clear_submission(self) -> None:
        """Reset every created form back to the unsubmitted WAITING state."""
        for entity in self.created_forms:
            entity.action_id = None
            entity.data = None
            entity.is_submitted = False
            entity.status_value = HumanInputFormStatus.WAITING

View File

@ -0,0 +1,74 @@
import queue
import threading
from datetime import datetime
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.graph_engine.orchestration.dispatcher import Dispatcher
from core.workflow.graph_events import NodeRunSucceededEvent
from core.workflow.node_events import NodeRunResult
class StubExecutionCoordinator:
    """Minimal execution-coordinator double recording completion/failure calls."""

    def __init__(self, paused: bool) -> None:
        self._is_paused = paused
        # Observable side effects for assertions.
        self.mark_complete_called = False
        self.failed_error: Exception | None = None

    # --- state queried by the dispatcher loop ---

    @property
    def paused(self) -> bool:
        return self._is_paused

    @property
    def aborted(self) -> bool:
        # The stub never reports an abort.
        return False

    @property
    def execution_complete(self) -> bool:
        return False

    # --- no-op hooks ---

    def check_scaling(self) -> None:
        return None

    def process_commands(self) -> None:
        return None

    # --- terminal transitions ---

    def mark_complete(self) -> None:
        self.mark_complete_called = True

    def mark_failed(self, error: Exception) -> None:
        self.failed_error = error
class StubEventHandler:
    """Collects dispatched events in order for later inspection."""

    def __init__(self) -> None:
        # Every dispatched event, in dispatch order.
        self.events: list[object] = []

    def dispatch(self, event: object) -> None:
        """Record *event* without any further processing."""
        self.events.append(event)
def test_dispatcher_drains_events_when_paused() -> None:
    """A paused dispatcher still drains queued events before marking complete."""
    event_queue: queue.Queue = queue.Queue()
    event = NodeRunSucceededEvent(
        id="exec-1",
        node_id="node-1",
        node_type=NodeType.START,
        start_at=datetime.utcnow(),
        node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED),
    )
    event_queue.put(event)
    handler = StubEventHandler()
    # paused=True makes the loop exit once the queue is drained.
    coordinator = StubExecutionCoordinator(paused=True)
    dispatcher = Dispatcher(
        event_queue=event_queue,
        event_handler=handler,
        execution_coordinator=coordinator,
        event_emitter=None,
        stop_event=threading.Event(),
    )
    # Run the loop body directly rather than via a worker thread.
    dispatcher._dispatcher_loop()
    # The queued event was delivered, then completion was signalled.
    assert handler.events == [event]
    assert coordinator.mark_complete_called is True

View File

@ -2,6 +2,8 @@
from unittest.mock import MagicMock
import pytest
from core.workflow.graph_engine.command_processing.command_processor import CommandProcessor
from core.workflow.graph_engine.domain.graph_execution import GraphExecution
from core.workflow.graph_engine.graph_state_manager import GraphStateManager
@ -48,3 +50,13 @@ def test_handle_pause_noop_when_execution_running() -> None:
worker_pool.stop.assert_not_called()
state_manager.clear_executing.assert_not_called()
def test_has_executing_nodes_requires_pause() -> None:
    """``has_executing_nodes()`` may only be queried while execution is paused."""
    graph_execution = GraphExecution(workflow_id="workflow")
    graph_execution.start()
    coordinator, _, _ = _build_coordinator(graph_execution)
    # Execution is running (not paused), so the precondition assertion fires.
    with pytest.raises(AssertionError):
        coordinator.has_executing_nodes()

View File

@ -0,0 +1,189 @@
import time
from collections.abc import Mapping
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.workflow.entities import GraphInitParams
from core.workflow.enums import NodeState
from core.workflow.graph import Graph
from core.workflow.graph_engine.graph_state_manager import GraphStateManager
from core.workflow.graph_engine.ready_queue import InMemoryReadyQueue
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.end.entities import EndNodeData
from core.workflow.nodes.llm.entities import (
ContextConfig,
LLMNodeChatModelMessage,
LLMNodeData,
ModelConfig,
VisionConfig,
)
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from .test_mock_config import MockConfig
from .test_mock_nodes import MockLLMNode
def _build_runtime_state() -> GraphRuntimeState:
    """Create a GraphRuntimeState with a minimal variable pool for execution "exec-1"."""
    variable_pool = VariablePool(
        system_variables=SystemVariable(
            user_id="user",
            app_id="app",
            workflow_id="workflow",
            workflow_execution_id="exec-1",
        ),
        user_inputs={},
        conversation_variables=[],
    )
    return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
def _build_llm_node(
    *,
    node_id: str,
    runtime_state: GraphRuntimeState,
    graph_init_params: GraphInitParams,
    mock_config: MockConfig,
) -> MockLLMNode:
    """Create a MockLLMNode with a minimal chat prompt derived from *node_id*."""
    llm_data = LLMNodeData(
        title=f"LLM {node_id}",
        model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
        prompt_template=[
            LLMNodeChatModelMessage(
                text=f"Prompt {node_id}",
                role=PromptMessageRole.USER,
                edition_type="basic",
            )
        ],
        context=ContextConfig(enabled=False, variable_selector=None),
        vision=VisionConfig(enabled=False),
        reasoning_format="tagged",
    )
    llm_config = {"id": node_id, "data": llm_data.model_dump()}
    return MockLLMNode(
        id=llm_config["id"],
        config=llm_config,
        graph_init_params=graph_init_params,
        graph_runtime_state=runtime_state,
        mock_config=mock_config,
    )
def _build_graph(runtime_state: GraphRuntimeState) -> Graph:
    """Build a diamond graph: start fans out to llm_a / llm_b, both join at end."""
    graph_config: dict[str, object] = {"nodes": [], "edges": []}
    graph_init_params = GraphInitParams(
        tenant_id="tenant",
        app_id="app",
        workflow_id="workflow",
        graph_config=graph_config,
        user_id="user",
        user_from="account",
        invoke_from="debugger",
        call_depth=0,
    )
    start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()}
    start_node = StartNode(
        id=start_config["id"],
        config=start_config,
        graph_init_params=graph_init_params,
        graph_runtime_state=runtime_state,
    )
    mock_config = MockConfig()
    llm_a = _build_llm_node(
        node_id="llm_a",
        runtime_state=runtime_state,
        graph_init_params=graph_init_params,
        mock_config=mock_config,
    )
    llm_b = _build_llm_node(
        node_id="llm_b",
        runtime_state=runtime_state,
        graph_init_params=graph_init_params,
        mock_config=mock_config,
    )
    end_data = EndNodeData(title="End", outputs=[], desc=None)
    end_config = {"id": "end", "data": end_data.model_dump()}
    end_node = EndNode(
        id=end_config["id"],
        config=end_config,
        graph_init_params=graph_init_params,
        graph_runtime_state=runtime_state,
    )
    builder = (
        Graph.new()
        .add_root(start_node)
        .add_node(llm_a, from_node_id="start")
        .add_node(llm_b, from_node_id="start")
        .add_node(end_node, from_node_id="llm_a")
    )
    # llm_b -> end is wired explicitly so "end" becomes a join of both branches.
    return builder.connect(tail="llm_b", head="end").build()
def _edge_state_map(graph: Graph) -> Mapping[tuple[str, str, str], NodeState]:
    """Index each edge's state by its (tail, head, source_handle) triple."""
    states: dict[tuple[str, str, str], NodeState] = {}
    for edge in graph.edges.values():
        states[(edge.tail, edge.head, edge.source_handle)] = edge.state
    return states
def test_runtime_state_snapshot_restores_graph_states() -> None:
    """``dumps()``/``from_snapshot()`` must round-trip node and edge states."""
    runtime_state = _build_runtime_state()
    graph = _build_graph(runtime_state)
    runtime_state.attach_graph(graph)
    # Take the llm_a branch and skip the llm_b branch.
    graph.nodes["llm_a"].state = NodeState.TAKEN
    graph.nodes["llm_b"].state = NodeState.SKIPPED
    for edge in graph.edges.values():
        if edge.tail == "start" and edge.head == "llm_a":
            edge.state = NodeState.TAKEN
        elif edge.tail == "start" and edge.head == "llm_b":
            edge.state = NodeState.SKIPPED
        elif edge.head == "end" and edge.tail == "llm_a":
            edge.state = NodeState.TAKEN
        elif edge.head == "end" and edge.tail == "llm_b":
            edge.state = NodeState.SKIPPED
    snapshot = runtime_state.dumps()
    # Rebuild an identical graph and rehydrate it from the serialized state.
    resumed_state = GraphRuntimeState.from_snapshot(snapshot)
    resumed_graph = _build_graph(resumed_state)
    resumed_state.attach_graph(resumed_graph)
    assert resumed_graph.nodes["llm_a"].state == NodeState.TAKEN
    assert resumed_graph.nodes["llm_b"].state == NodeState.SKIPPED
    assert _edge_state_map(resumed_graph) == _edge_state_map(graph)
def test_join_readiness_uses_restored_edge_states() -> None:
    """A join node's readiness decision must survive a snapshot round-trip."""
    runtime_state = _build_runtime_state()
    graph = _build_graph(runtime_state)
    runtime_state.attach_graph(graph)
    ready_queue = InMemoryReadyQueue()
    state_manager = GraphStateManager(graph, ready_queue)
    # With one incoming edge still undecided, the join node is not ready.
    for edge in graph.get_incoming_edges("end"):
        if edge.tail == "llm_a":
            edge.state = NodeState.TAKEN
        if edge.tail == "llm_b":
            edge.state = NodeState.UNKNOWN
    assert state_manager.is_node_ready("end") is False
    # Once both incoming edges are TAKEN, the join becomes ready.
    for edge in graph.get_incoming_edges("end"):
        if edge.tail == "llm_b":
            edge.state = NodeState.TAKEN
    assert state_manager.is_node_ready("end") is True
    snapshot = runtime_state.dumps()
    resumed_state = GraphRuntimeState.from_snapshot(snapshot)
    resumed_graph = _build_graph(resumed_state)
    resumed_state.attach_graph(resumed_graph)
    # A fresh state manager over the restored graph sees the same readiness.
    resumed_state_manager = GraphStateManager(resumed_graph, InMemoryReadyQueue())
    assert resumed_state_manager.is_node_ready("end") is True

View File

@ -1,5 +1,7 @@
import datetime
import time
from collections.abc import Iterable
from unittest.mock import MagicMock
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.message_entities import PromptMessageRole
@ -14,11 +16,12 @@ from core.workflow.graph_events import (
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
from core.workflow.graph_events.node import NodeRunHumanInputFormFilledEvent
from core.workflow.nodes.base.entities import OutputVariableEntity, OutputVariableType
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.end.entities import EndNodeData
from core.workflow.nodes.human_input import HumanInputNode
from core.workflow.nodes.human_input.entities import HumanInputNodeData
from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
from core.workflow.nodes.llm.entities import (
ContextConfig,
LLMNodeChatModelMessage,
@ -28,15 +31,21 @@ from core.workflow.nodes.llm.entities import (
)
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from libs.datetime_utils import naive_utc_now
from .test_mock_config import MockConfig
from .test_mock_nodes import MockLLMNode
from .test_table_runner import TableTestRunner, WorkflowTestCase
def _build_branching_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntimeState]:
def _build_branching_graph(
mock_config: MockConfig,
form_repository: HumanInputFormRepository,
graph_runtime_state: GraphRuntimeState | None = None,
) -> tuple[Graph, GraphRuntimeState]:
graph_config: dict[str, object] = {"nodes": [], "edges": []}
graph_init_params = GraphInitParams(
tenant_id="tenant",
@ -49,12 +58,18 @@ def _build_branching_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntime
call_depth=0,
)
variable_pool = VariablePool(
system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"),
user_inputs={},
conversation_variables=[],
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
if graph_runtime_state is None:
variable_pool = VariablePool(
system_variables=SystemVariable(
user_id="user",
app_id="app",
workflow_id="workflow",
workflow_execution_id="test-execution-id",
),
user_inputs={},
conversation_variables=[],
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()}
start_node = StartNode(
@ -93,15 +108,21 @@ def _build_branching_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntime
human_data = HumanInputNodeData(
title="Human Input",
required_variables=["human.input_ready"],
pause_reason="Awaiting human input",
form_content="Human input required",
inputs=[],
user_actions=[
UserAction(id="primary", title="Primary"),
UserAction(id="secondary", title="Secondary"),
],
)
human_config = {"id": "human", "data": human_data.model_dump()}
human_node = HumanInputNode(
id=human_config["id"],
config=human_config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
form_repository=form_repository,
)
llm_primary = _create_llm_node("llm_primary", "Primary LLM", "Primary stream output")
@ -219,8 +240,18 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None:
for scenario in branch_scenarios:
runner = TableTestRunner()
def initial_graph_factory() -> tuple[Graph, GraphRuntimeState]:
return _build_branching_graph(mock_config)
mock_create_repo = MagicMock(spec=HumanInputFormRepository)
mock_create_repo.get_form.return_value = None
mock_form_entity = MagicMock(spec=HumanInputFormEntity)
mock_form_entity.id = "test_form_id"
mock_form_entity.web_app_token = "test_web_app_token"
mock_form_entity.recipients = []
mock_form_entity.rendered_content = "rendered"
mock_form_entity.submitted = False
mock_create_repo.create_form.return_value = mock_form_entity
def initial_graph_factory(mock_create_repo=mock_create_repo) -> tuple[Graph, GraphRuntimeState]:
return _build_branching_graph(mock_config, mock_create_repo)
initial_case = WorkflowTestCase(
description="HumanInput pause before branching decision",
@ -242,23 +273,16 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None:
assert initial_result.success, initial_result.event_mismatch_details
assert not any(isinstance(event, NodeRunStreamChunkEvent) for event in initial_result.events)
graph_runtime_state = initial_result.graph_runtime_state
graph = initial_result.graph
assert graph_runtime_state is not None
assert graph is not None
graph_runtime_state.variable_pool.add(("human", "input_ready"), True)
graph_runtime_state.variable_pool.add(("human", "edge_source_handle"), scenario["handle"])
graph_runtime_state.graph_execution.pause_reason = None
pre_chunk_count = sum(len(chunks) for _, chunks in scenario["expected_pre_chunks"])
post_chunk_count = sum(len(chunks) for _, chunks in scenario["expected_post_chunks"])
expected_pre_chunk_events_in_resumption = [
GraphRunStartedEvent,
NodeRunStartedEvent,
NodeRunHumanInputFormFilledEvent,
]
expected_resume_sequence: list[type] = (
[
GraphRunStartedEvent,
NodeRunStartedEvent,
]
expected_pre_chunk_events_in_resumption
+ [NodeRunStreamChunkEvent] * pre_chunk_count
+ [
NodeRunSucceededEvent,
@ -273,11 +297,25 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None:
]
)
mock_get_repo = MagicMock(spec=HumanInputFormRepository)
submitted_form = MagicMock(spec=HumanInputFormEntity)
submitted_form.id = mock_form_entity.id
submitted_form.web_app_token = mock_form_entity.web_app_token
submitted_form.recipients = []
submitted_form.rendered_content = mock_form_entity.rendered_content
submitted_form.submitted = True
submitted_form.selected_action_id = scenario["handle"]
submitted_form.submitted_data = {}
submitted_form.expiration_time = naive_utc_now() + datetime.timedelta(days=1)
mock_get_repo.get_form.return_value = submitted_form
def resume_graph_factory(
graph_snapshot: Graph = graph,
state_snapshot: GraphRuntimeState = graph_runtime_state,
initial_result=initial_result, mock_get_repo=mock_get_repo
) -> tuple[Graph, GraphRuntimeState]:
return graph_snapshot, state_snapshot
assert initial_result.graph_runtime_state is not None
serialized_runtime_state = initial_result.graph_runtime_state.dumps()
resume_runtime_state = GraphRuntimeState.from_snapshot(serialized_runtime_state)
return _build_branching_graph(mock_config, mock_get_repo, resume_runtime_state)
resume_case = WorkflowTestCase(
description=f"HumanInput resumes via {scenario['handle']} branch",
@ -321,7 +359,8 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None:
for index, event in enumerate(resume_events)
if isinstance(event, NodeRunStreamChunkEvent) and index < human_success_index
]
assert pre_indices == list(range(2, 2 + pre_chunk_count))
expected_pre_chunk_events_count_in_resumption = len(expected_pre_chunk_events_in_resumption)
assert pre_indices == list(range(expected_pre_chunk_events_count_in_resumption, human_success_index))
resume_chunk_indices = [
index

View File

@ -1,4 +1,6 @@
import datetime
import time
from unittest.mock import MagicMock
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.message_entities import PromptMessageRole
@ -13,11 +15,12 @@ from core.workflow.graph_events import (
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
from core.workflow.graph_events.node import NodeRunHumanInputFormFilledEvent
from core.workflow.nodes.base.entities import OutputVariableEntity, OutputVariableType
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.end.entities import EndNodeData
from core.workflow.nodes.human_input import HumanInputNode
from core.workflow.nodes.human_input.entities import HumanInputNodeData
from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
from core.workflow.nodes.llm.entities import (
ContextConfig,
LLMNodeChatModelMessage,
@ -27,15 +30,21 @@ from core.workflow.nodes.llm.entities import (
)
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from libs.datetime_utils import naive_utc_now
from .test_mock_config import MockConfig
from .test_mock_nodes import MockLLMNode
from .test_table_runner import TableTestRunner, WorkflowTestCase
def _build_llm_human_llm_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntimeState]:
def _build_llm_human_llm_graph(
mock_config: MockConfig,
form_repository: HumanInputFormRepository,
graph_runtime_state: GraphRuntimeState | None = None,
) -> tuple[Graph, GraphRuntimeState]:
graph_config: dict[str, object] = {"nodes": [], "edges": []}
graph_init_params = GraphInitParams(
tenant_id="tenant",
@ -48,12 +57,15 @@ def _build_llm_human_llm_graph(mock_config: MockConfig) -> tuple[Graph, GraphRun
call_depth=0,
)
variable_pool = VariablePool(
system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"),
user_inputs={},
conversation_variables=[],
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
if graph_runtime_state is None:
variable_pool = VariablePool(
system_variables=SystemVariable(
user_id="user", app_id="app", workflow_id="workflow", workflow_execution_id="test-execution-id,"
),
user_inputs={},
conversation_variables=[],
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()}
start_node = StartNode(
@ -92,15 +104,21 @@ def _build_llm_human_llm_graph(mock_config: MockConfig) -> tuple[Graph, GraphRun
human_data = HumanInputNodeData(
title="Human Input",
required_variables=["human.input_ready"],
pause_reason="Awaiting human input",
form_content="Human input required",
inputs=[],
user_actions=[
UserAction(id="accept", title="Accept"),
UserAction(id="reject", title="Reject"),
],
)
human_config = {"id": "human", "data": human_data.model_dump()}
human_node = HumanInputNode(
id=human_config["id"],
config=human_config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
form_repository=form_repository,
)
llm_second = _create_llm_node("llm_resume", "Follow-up LLM", "Follow-up prompt")
@ -130,7 +148,7 @@ def _build_llm_human_llm_graph(mock_config: MockConfig) -> tuple[Graph, GraphRun
.add_root(start_node)
.add_node(llm_first)
.add_node(human_node)
.add_node(llm_second)
.add_node(llm_second, source_handle="accept")
.add_node(end_node)
.build()
)
@ -167,8 +185,18 @@ def test_human_input_llm_streaming_order_across_pause() -> None:
GraphRunPausedEvent, # graph run pauses awaiting resume
]
mock_create_repo = MagicMock(spec=HumanInputFormRepository)
mock_create_repo.get_form.return_value = None
mock_form_entity = MagicMock(spec=HumanInputFormEntity)
mock_form_entity.id = "test_form_id"
mock_form_entity.web_app_token = "test_web_app_token"
mock_form_entity.recipients = []
mock_form_entity.rendered_content = "rendered"
mock_form_entity.submitted = False
mock_create_repo.create_form.return_value = mock_form_entity
def graph_factory() -> tuple[Graph, GraphRuntimeState]:
return _build_llm_human_llm_graph(mock_config)
return _build_llm_human_llm_graph(mock_config, mock_create_repo)
initial_case = WorkflowTestCase(
description="HumanInput pause preserves LLM streaming order",
@ -210,6 +238,8 @@ def test_human_input_llm_streaming_order_across_pause() -> None:
expected_resume_sequence: list[type] = [
GraphRunStartedEvent, # resumed graph run begins
NodeRunStartedEvent, # human node restarts
# Form Filled should be generated first, then the node execution ends and stream chunk is generated.
NodeRunHumanInputFormFilledEvent,
NodeRunStreamChunkEvent, # cached llm_initial chunk 1
NodeRunStreamChunkEvent, # cached llm_initial chunk 2
NodeRunStreamChunkEvent, # cached llm_initial final chunk
@ -225,12 +255,27 @@ def test_human_input_llm_streaming_order_across_pause() -> None:
GraphRunSucceededEvent, # graph run succeeds after resume
]
mock_get_repo = MagicMock(spec=HumanInputFormRepository)
submitted_form = MagicMock(spec=HumanInputFormEntity)
submitted_form.id = mock_form_entity.id
submitted_form.web_app_token = mock_form_entity.web_app_token
submitted_form.recipients = []
submitted_form.rendered_content = mock_form_entity.rendered_content
submitted_form.submitted = True
submitted_form.selected_action_id = "accept"
submitted_form.submitted_data = {}
submitted_form.expiration_time = naive_utc_now() + datetime.timedelta(days=1)
mock_get_repo.get_form.return_value = submitted_form
def resume_graph_factory() -> tuple[Graph, GraphRuntimeState]:
assert graph_runtime_state is not None
assert graph is not None
graph_runtime_state.variable_pool.add(("human", "input_ready"), True)
graph_runtime_state.graph_execution.pause_reason = None
return graph, graph_runtime_state
# restruct the graph runtime state
serialized_runtime_state = initial_result.graph_runtime_state.dumps()
resume_runtime_state = GraphRuntimeState.from_snapshot(serialized_runtime_state)
return _build_llm_human_llm_graph(
mock_config,
mock_get_repo,
resume_runtime_state,
)
resume_case = WorkflowTestCase(
description="HumanInput resume continues LLM streaming order",

View File

@ -0,0 +1,270 @@
import time
from collections.abc import Mapping
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Any, Protocol
from core.workflow.entities import GraphInitParams
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
from core.workflow.graph import Graph
from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel
from core.workflow.graph_engine.config import GraphEngineConfig
from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.graph_events import (
GraphRunPausedEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
NodeRunSucceededEvent,
)
from core.workflow.nodes.base.entities import OutputVariableEntity
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.end.entities import EndNodeData
from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction
from core.workflow.nodes.human_input.enums import HumanInputFormStatus
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.repositories.human_input_form_repository import (
FormCreateParams,
HumanInputFormEntity,
HumanInputFormRepository,
)
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from libs.datetime_utils import naive_utc_now
class PauseStateStore(Protocol):
    """Structural interface for persisting a paused run's runtime state between resumes."""

    def save(self, runtime_state: GraphRuntimeState) -> None: ...
    def load(self) -> GraphRuntimeState: ...
class InMemoryPauseStore:
    """PauseStateStore backed by a single in-process serialized snapshot."""

    def __init__(self) -> None:
        # Holds the serialized runtime state; None until the first save().
        self._snapshot: str | None = None

    def save(self, runtime_state: GraphRuntimeState) -> None:
        self._snapshot = runtime_state.dumps()

    def load(self) -> GraphRuntimeState:
        snapshot = self._snapshot
        assert snapshot is not None
        return GraphRuntimeState.from_snapshot(snapshot)
@dataclass
class StaticForm(HumanInputFormEntity):
    """In-memory HumanInputFormEntity stub with fixed values for pause/resume tests."""

    form_id: str
    rendered: str
    is_submitted: bool
    action_id: str | None = None
    data: Mapping[str, Any] | None = None
    status_value: HumanInputFormStatus = HumanInputFormStatus.WAITING
    # None means "one day from now", resolved lazily in expiration_time. The
    # previous default was computed once at import time, so it could grow stale
    # over a long-running test session.
    expiration: datetime | None = None

    @property
    def id(self) -> str:
        return self.form_id

    @property
    def web_app_token(self) -> str | None:
        return "token"

    @property
    def recipients(self) -> list:
        return []

    @property
    def rendered_content(self) -> str:
        return self.rendered

    @property
    def selected_action_id(self) -> str | None:
        return self.action_id

    @property
    def submitted_data(self) -> Mapping[str, Any] | None:
        return self.data

    @property
    def submitted(self) -> bool:
        return self.is_submitted

    @property
    def status(self) -> HumanInputFormStatus:
        return self.status_value

    @property
    def expiration_time(self) -> datetime:
        # Lazy default keeps the form "not yet expired" relative to the current run.
        return self.expiration if self.expiration is not None else naive_utc_now() + timedelta(days=1)
class StaticRepo(HumanInputFormRepository):
    """Repository stub serving one fixed form per node id; creation is forbidden."""

    def __init__(self, forms_by_node_id: Mapping[str, HumanInputFormEntity]) -> None:
        # Copy so later mutation of the caller's mapping cannot leak in.
        self._forms_by_node_id: dict[str, HumanInputFormEntity] = {}
        self._forms_by_node_id.update(forms_by_node_id)

    def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None:
        try:
            return self._forms_by_node_id[node_id]
        except KeyError:
            return None

    def create_form(self, params: FormCreateParams) -> HumanInputFormEntity:
        raise AssertionError("create_form should not be called in resume scenario")
def _build_runtime_state() -> GraphRuntimeState:
    """Create a fresh runtime state whose variable pool carries the test system variables."""
    system_variables = SystemVariable(
        user_id="user",
        app_id="app",
        workflow_id="workflow",
        workflow_execution_id="exec-1",
    )
    pool = VariablePool(
        system_variables=system_variables,
        user_inputs={},
        conversation_variables=[],
    )
    return GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter())
def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepository) -> Graph:
    """Build start -> (human_a, human_b) where both approve branches join at one end node."""
    graph_config: dict[str, object] = {"nodes": [], "edges": []}
    graph_init_params = GraphInitParams(
        tenant_id="tenant",
        app_id="app",
        workflow_id="workflow",
        graph_config=graph_config,
        user_id="user",
        user_from="account",
        invoke_from="debugger",
        call_depth=0,
    )
    start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()}
    start_node = StartNode(
        id=start_config["id"],
        config=start_config,
        graph_init_params=graph_init_params,
        graph_runtime_state=runtime_state,
    )
    # Both human nodes share one form definition with a single "approve" action.
    human_data = HumanInputNodeData(
        title="Human Input",
        form_content="Human input required",
        inputs=[],
        user_actions=[UserAction(id="approve", title="Approve")],
    )
    human_a_config = {"id": "human_a", "data": human_data.model_dump()}
    human_a = HumanInputNode(
        id=human_a_config["id"],
        config=human_a_config,
        graph_init_params=graph_init_params,
        graph_runtime_state=runtime_state,
        form_repository=repo,
    )
    human_b_config = {"id": "human_b", "data": human_data.model_dump()}
    human_b = HumanInputNode(
        id=human_b_config["id"],
        config=human_b_config,
        graph_init_params=graph_init_params,
        graph_runtime_state=runtime_state,
        form_repository=repo,
    )
    # End node surfaces each human node's selected action id as an output.
    end_data = EndNodeData(
        title="End",
        outputs=[
            OutputVariableEntity(variable="res_a", value_selector=["human_a", "__action_id"]),
            OutputVariableEntity(variable="res_b", value_selector=["human_b", "__action_id"]),
        ],
        desc=None,
    )
    end_config = {"id": "end", "data": end_data.model_dump()}
    end_node = EndNode(
        id=end_config["id"],
        config=end_config,
        graph_init_params=graph_init_params,
        graph_runtime_state=runtime_state,
    )
    builder = (
        Graph.new()
        .add_root(start_node)
        .add_node(human_a, from_node_id="start")
        .add_node(human_b, from_node_id="start")
        .add_node(end_node, from_node_id="human_a", source_handle="approve")
    )
    # human_b's approve branch also feeds the shared end node, making it a join point.
    return builder.connect(tail="human_b", head="end", source_handle="approve").build()
def _run_graph(graph: Graph, runtime_state: GraphRuntimeState) -> list[object]:
    """Run the graph engine to completion and collect every emitted event."""
    engine_config = GraphEngineConfig(
        min_workers=2,
        max_workers=2,
        scale_up_threshold=1,
        scale_down_idle_time=30.0,
    )
    engine = GraphEngine(
        workflow_id="workflow",
        graph=graph,
        graph_runtime_state=runtime_state,
        command_channel=InMemoryChannel(),
        config=engine_config,
    )
    return list(engine.run())
def _form(submitted: bool, action_id: str | None) -> StaticForm:
    """Build a StaticForm in either the submitted or the still-waiting state."""
    if submitted:
        status = HumanInputFormStatus.SUBMITTED
    else:
        status = HumanInputFormStatus.WAITING
    return StaticForm(
        form_id="form",
        rendered="rendered",
        is_submitted=submitted,
        action_id=action_id,
        data={},
        status_value=status,
    )
def test_parallel_human_input_join_completes_after_second_resume() -> None:
    """Two parallel human-input nodes joining at "end" only complete after both forms are submitted."""
    pause_store: PauseStateStore = InMemoryPauseStore()
    initial_state = _build_runtime_state()
    # Initial run: neither form submitted -> the run must pause.
    initial_repo = StaticRepo(
        {
            "human_a": _form(submitted=False, action_id=None),
            "human_b": _form(submitted=False, action_id=None),
        }
    )
    initial_graph = _build_graph(initial_state, initial_repo)
    initial_events = _run_graph(initial_graph, initial_state)
    assert isinstance(initial_events[-1], GraphRunPausedEvent)
    pause_store.save(initial_state)
    # First resume: only human_a submitted -> still paused waiting on human_b.
    first_resume_state = pause_store.load()
    first_resume_repo = StaticRepo(
        {
            "human_a": _form(submitted=True, action_id="approve"),
            "human_b": _form(submitted=False, action_id=None),
        }
    )
    first_resume_graph = _build_graph(first_resume_state, first_resume_repo)
    first_resume_events = _run_graph(first_resume_graph, first_resume_state)
    assert isinstance(first_resume_events[0], GraphRunStartedEvent)
    assert first_resume_events[0].reason is WorkflowStartReason.RESUMPTION
    assert isinstance(first_resume_events[-1], GraphRunPausedEvent)
    pause_store.save(first_resume_state)
    # Second resume: both submitted -> join at "end" succeeds and the run finishes.
    second_resume_state = pause_store.load()
    second_resume_repo = StaticRepo(
        {
            "human_a": _form(submitted=True, action_id="approve"),
            "human_b": _form(submitted=True, action_id="approve"),
        }
    )
    second_resume_graph = _build_graph(second_resume_state, second_resume_repo)
    second_resume_events = _run_graph(second_resume_graph, second_resume_state)
    assert isinstance(second_resume_events[0], GraphRunStartedEvent)
    assert second_resume_events[0].reason is WorkflowStartReason.RESUMPTION
    assert isinstance(second_resume_events[-1], GraphRunSucceededEvent)
    assert any(isinstance(event, NodeRunSucceededEvent) and event.node_id == "end" for event in second_resume_events)

View File

@ -0,0 +1,333 @@
import time
from collections.abc import Mapping
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Any
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.workflow.entities import GraphInitParams
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
from core.workflow.graph import Graph
from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel
from core.workflow.graph_engine.config import GraphEngineConfig
from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.graph_events import (
GraphRunPausedEvent,
GraphRunStartedEvent,
NodeRunPauseRequestedEvent,
NodeRunStartedEvent,
NodeRunSucceededEvent,
)
from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction
from core.workflow.nodes.human_input.enums import HumanInputFormStatus
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
from core.workflow.nodes.llm.entities import (
ContextConfig,
LLMNodeChatModelMessage,
LLMNodeData,
ModelConfig,
VisionConfig,
)
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.repositories.human_input_form_repository import (
FormCreateParams,
HumanInputFormEntity,
HumanInputFormRepository,
)
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from libs.datetime_utils import naive_utc_now
from .test_mock_config import MockConfig, NodeMockConfig
from .test_mock_nodes import MockLLMNode
@dataclass
class StaticForm(HumanInputFormEntity):
    """In-memory HumanInputFormEntity stub with fixed values for pause/resume tests."""

    form_id: str
    rendered: str
    is_submitted: bool
    action_id: str | None = None
    data: Mapping[str, Any] | None = None
    status_value: HumanInputFormStatus = HumanInputFormStatus.WAITING
    # None means "one day from now", resolved lazily in expiration_time. The
    # previous default was computed once at import time, so it could grow stale
    # over a long-running test session.
    expiration: datetime | None = None

    @property
    def id(self) -> str:
        return self.form_id

    @property
    def web_app_token(self) -> str | None:
        return "token"

    @property
    def recipients(self) -> list:
        return []

    @property
    def rendered_content(self) -> str:
        return self.rendered

    @property
    def selected_action_id(self) -> str | None:
        return self.action_id

    @property
    def submitted_data(self) -> Mapping[str, Any] | None:
        return self.data

    @property
    def submitted(self) -> bool:
        return self.is_submitted

    @property
    def status(self) -> HumanInputFormStatus:
        return self.status_value

    @property
    def expiration_time(self) -> datetime:
        # Lazy default keeps the form "not yet expired" relative to the current run.
        return self.expiration if self.expiration is not None else naive_utc_now() + timedelta(days=1)
class StaticRepo(HumanInputFormRepository):
    """Repository stub serving one fixed form per node id; creation is forbidden."""

    def __init__(self, forms_by_node_id: Mapping[str, HumanInputFormEntity]) -> None:
        # Copy so later mutation of the caller's mapping cannot leak in.
        self._forms_by_node_id: dict[str, HumanInputFormEntity] = {}
        self._forms_by_node_id.update(forms_by_node_id)

    def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None:
        try:
            return self._forms_by_node_id[node_id]
        except KeyError:
            return None

    def create_form(self, params: FormCreateParams) -> HumanInputFormEntity:
        raise AssertionError("create_form should not be called in resume scenario")
class DelayedHumanInputNode(HumanInputNode):
    """HumanInputNode that sleeps before running, to force a deterministic event order in tests."""

    def __init__(self, delay_seconds: float, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        # How long to sleep before delegating to the real node; 0 disables the delay.
        self._delay_seconds = delay_seconds

    def _run(self):
        # The delay lets sibling branches (the other human node / the LLM) progress first.
        if self._delay_seconds > 0:
            time.sleep(self._delay_seconds)
        yield from super()._run()
def _build_runtime_state() -> GraphRuntimeState:
    """Create a fresh runtime state whose variable pool carries the test system variables."""
    system_variables = SystemVariable(
        user_id="user",
        app_id="app",
        workflow_id="workflow",
        workflow_execution_id="exec-1",
    )
    pool = VariablePool(
        system_variables=system_variables,
        user_inputs={},
        conversation_variables=[],
    )
    return GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter())
def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepository, mock_config: MockConfig) -> Graph:
    """Build start -> {human_a, delayed human_b}; human_a's approve branch feeds a mocked LLM."""
    graph_config: dict[str, object] = {"nodes": [], "edges": []}
    graph_init_params = GraphInitParams(
        tenant_id="tenant",
        app_id="app",
        workflow_id="workflow",
        graph_config=graph_config,
        user_id="user",
        user_from="account",
        invoke_from="debugger",
        call_depth=0,
    )
    start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()}
    start_node = StartNode(
        id=start_config["id"],
        config=start_config,
        graph_init_params=graph_init_params,
        graph_runtime_state=runtime_state,
    )
    # Both human nodes share one form definition with a single "approve" action.
    human_data = HumanInputNodeData(
        title="Human Input",
        form_content="Human input required",
        inputs=[],
        user_actions=[UserAction(id="approve", title="Approve")],
    )
    human_a_config = {"id": "human_a", "data": human_data.model_dump()}
    human_a = HumanInputNode(
        id=human_a_config["id"],
        config=human_a_config,
        graph_init_params=graph_init_params,
        graph_runtime_state=runtime_state,
        form_repository=repo,
    )
    human_b_config = {"id": "human_b", "data": human_data.model_dump()}
    # human_b is delayed so that human_a (and the LLM behind it) always reacts first.
    human_b = DelayedHumanInputNode(
        id=human_b_config["id"],
        config=human_b_config,
        graph_init_params=graph_init_params,
        graph_runtime_state=runtime_state,
        form_repository=repo,
        delay_seconds=0.2,
    )
    llm_data = LLMNodeData(
        title="LLM A",
        model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
        prompt_template=[
            LLMNodeChatModelMessage(
                text="Prompt A",
                role=PromptMessageRole.USER,
                edition_type="basic",
            )
        ],
        context=ContextConfig(enabled=False, variable_selector=None),
        vision=VisionConfig(enabled=False),
        reasoning_format="tagged",
        structured_output_enabled=False,
    )
    llm_config = {"id": "llm_a", "data": llm_data.model_dump()}
    llm_a = MockLLMNode(
        id=llm_config["id"],
        config=llm_config,
        graph_init_params=graph_init_params,
        graph_runtime_state=runtime_state,
        mock_config=mock_config,
    )
    return (
        Graph.new()
        .add_root(start_node)
        .add_node(human_a, from_node_id="start")
        .add_node(human_b, from_node_id="start")
        .add_node(llm_a, from_node_id="human_a", source_handle="approve")
        .build()
    )
def test_parallel_human_input_pause_preserves_node_finished() -> None:
    """A still-pending human node must not prevent work unlocked by a submitted sibling from finishing."""
    runtime_state = _build_runtime_state()
    runtime_state.graph_execution.start()
    # Pretend a previous run paused on both human nodes.
    runtime_state.register_paused_node("human_a")
    runtime_state.register_paused_node("human_b")
    submitted = StaticForm(
        form_id="form-a",
        rendered="rendered",
        is_submitted=True,
        action_id="approve",
        data={},
        status_value=HumanInputFormStatus.SUBMITTED,
    )
    pending = StaticForm(
        form_id="form-b",
        rendered="rendered",
        is_submitted=False,
        action_id=None,
        data=None,
        status_value=HumanInputFormStatus.WAITING,
    )
    repo = StaticRepo({"human_a": submitted, "human_b": pending})
    mock_config = MockConfig()
    mock_config.simulate_delays = True
    # The LLM delay keeps it running while human_b re-pauses, exercising the race.
    mock_config.set_node_config(
        "llm_a",
        NodeMockConfig(node_id="llm_a", outputs={"text": "LLM A output"}, delay=0.5),
    )
    graph = _build_graph(runtime_state, repo, mock_config)
    engine = GraphEngine(
        workflow_id="workflow",
        graph=graph,
        graph_runtime_state=runtime_state,
        command_channel=InMemoryChannel(),
        config=GraphEngineConfig(
            min_workers=2,
            max_workers=2,
            scale_up_threshold=1,
            scale_down_idle_time=30.0,
        ),
    )
    events = list(engine.run())
    llm_started = any(isinstance(e, NodeRunStartedEvent) and e.node_id == "llm_a" for e in events)
    llm_succeeded = any(isinstance(e, NodeRunSucceededEvent) and e.node_id == "llm_a" for e in events)
    human_b_pause = any(isinstance(e, NodeRunPauseRequestedEvent) and e.node_id == "human_b" for e in events)
    graph_paused = any(isinstance(e, GraphRunPausedEvent) for e in events)
    graph_started = any(isinstance(e, GraphRunStartedEvent) for e in events)
    assert graph_started
    assert graph_paused
    assert human_b_pause
    # llm_a (behind the submitted human_a) must both start and finish despite the pause.
    assert llm_started
    assert llm_succeeded
def test_parallel_human_input_pause_preserves_node_finished_after_snapshot_resume() -> None:
    """A paused state resumed from a dumps()/from_snapshot round trip still lets unlocked work finish."""
    base_state = _build_runtime_state()
    base_state.graph_execution.start()
    base_state.register_paused_node("human_a")
    base_state.register_paused_node("human_b")
    # Round-trip the paused state through serialization before resuming.
    snapshot = base_state.dumps()
    resumed_state = GraphRuntimeState.from_snapshot(snapshot)
    submitted = StaticForm(
        form_id="form-a",
        rendered="rendered",
        is_submitted=True,
        action_id="approve",
        data={},
        status_value=HumanInputFormStatus.SUBMITTED,
    )
    pending = StaticForm(
        form_id="form-b",
        rendered="rendered",
        is_submitted=False,
        action_id=None,
        data=None,
        status_value=HumanInputFormStatus.WAITING,
    )
    repo = StaticRepo({"human_a": submitted, "human_b": pending})
    mock_config = MockConfig()
    mock_config.simulate_delays = True
    mock_config.set_node_config(
        "llm_a",
        NodeMockConfig(node_id="llm_a", outputs={"text": "LLM A output"}, delay=0.5),
    )
    graph = _build_graph(resumed_state, repo, mock_config)
    engine = GraphEngine(
        workflow_id="workflow",
        graph=graph,
        graph_runtime_state=resumed_state,
        command_channel=InMemoryChannel(),
        config=GraphEngineConfig(
            min_workers=2,
            max_workers=2,
            scale_up_threshold=1,
            scale_down_idle_time=30.0,
        ),
    )
    events = list(engine.run())
    # The deserialized state must be recognized as a resumption, not a fresh start.
    start_event = next(e for e in events if isinstance(e, GraphRunStartedEvent))
    assert start_event.reason is WorkflowStartReason.RESUMPTION
    llm_started = any(isinstance(e, NodeRunStartedEvent) and e.node_id == "llm_a" for e in events)
    llm_succeeded = any(isinstance(e, NodeRunSucceededEvent) and e.node_id == "llm_a" for e in events)
    human_b_pause = any(isinstance(e, NodeRunPauseRequestedEvent) and e.node_id == "human_b" for e in events)
    graph_paused = any(isinstance(e, GraphRunPausedEvent) for e in events)
    assert graph_paused
    assert human_b_pause
    assert llm_started
    assert llm_succeeded

View File

@ -0,0 +1,309 @@
import time
from collections.abc import Mapping
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Any
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.workflow.entities import GraphInitParams
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
from core.workflow.graph import Graph
from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel
from core.workflow.graph_engine.config import GraphEngineConfig
from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.graph_events import (
GraphRunPausedEvent,
GraphRunStartedEvent,
NodeRunStartedEvent,
NodeRunSucceededEvent,
)
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.end.entities import EndNodeData
from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction
from core.workflow.nodes.human_input.enums import HumanInputFormStatus
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
from core.workflow.nodes.llm.entities import (
ContextConfig,
LLMNodeChatModelMessage,
LLMNodeData,
ModelConfig,
VisionConfig,
)
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.repositories.human_input_form_repository import (
FormCreateParams,
HumanInputFormEntity,
HumanInputFormRepository,
)
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from libs.datetime_utils import naive_utc_now
from .test_mock_config import MockConfig, NodeMockConfig
from .test_mock_nodes import MockLLMNode
@dataclass
class StaticForm(HumanInputFormEntity):
    """In-memory HumanInputFormEntity stub with fixed values for pause/resume tests."""

    form_id: str
    rendered: str
    is_submitted: bool
    action_id: str | None = None
    data: Mapping[str, Any] | None = None
    status_value: HumanInputFormStatus = HumanInputFormStatus.WAITING
    # None means "one day from now", resolved lazily in expiration_time. The
    # previous default was computed once at import time, so it could grow stale
    # over a long-running test session.
    expiration: datetime | None = None

    @property
    def id(self) -> str:
        return self.form_id

    @property
    def web_app_token(self) -> str | None:
        return "token"

    @property
    def recipients(self) -> list:
        return []

    @property
    def rendered_content(self) -> str:
        return self.rendered

    @property
    def selected_action_id(self) -> str | None:
        return self.action_id

    @property
    def submitted_data(self) -> Mapping[str, Any] | None:
        return self.data

    @property
    def submitted(self) -> bool:
        return self.is_submitted

    @property
    def status(self) -> HumanInputFormStatus:
        return self.status_value

    @property
    def expiration_time(self) -> datetime:
        # Lazy default keeps the form "not yet expired" relative to the current run.
        return self.expiration if self.expiration is not None else naive_utc_now() + timedelta(days=1)
class StaticRepo(HumanInputFormRepository):
    """Repository stub exposing one fixed form for the "human_pause" node only."""

    def __init__(self, form: HumanInputFormEntity) -> None:
        self._form = form

    def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None:
        # Any node other than the designated pause node has no form.
        if node_id == "human_pause":
            return self._form
        return None

    def create_form(self, params: FormCreateParams) -> HumanInputFormEntity:
        raise AssertionError("create_form should not be called in this test")
def _build_runtime_state() -> GraphRuntimeState:
    """Create a fresh runtime state whose variable pool carries the test system variables."""
    system_variables = SystemVariable(
        user_id="user",
        app_id="app",
        workflow_id="workflow",
        workflow_execution_id="exec-1",
    )
    pool = VariablePool(
        system_variables=system_variables,
        user_inputs={},
        conversation_variables=[],
    )
    return GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter())
def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepository, mock_config: MockConfig) -> Graph:
    """Build start -> {llm_a -> llm_b, human_pause -> end_human} with mocked LLM nodes."""
    graph_config: dict[str, object] = {"nodes": [], "edges": []}
    graph_init_params = GraphInitParams(
        tenant_id="tenant",
        app_id="app",
        workflow_id="workflow",
        graph_config=graph_config,
        user_id="user",
        user_from="account",
        invoke_from="debugger",
        call_depth=0,
    )
    start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()}
    start_node = StartNode(
        id=start_config["id"],
        config=start_config,
        graph_init_params=graph_init_params,
        graph_runtime_state=runtime_state,
    )
    llm_a_data = LLMNodeData(
        title="LLM A",
        model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
        prompt_template=[
            LLMNodeChatModelMessage(
                text="Prompt A",
                role=PromptMessageRole.USER,
                edition_type="basic",
            )
        ],
        context=ContextConfig(enabled=False, variable_selector=None),
        vision=VisionConfig(enabled=False),
        reasoning_format="tagged",
        structured_output_enabled=False,
    )
    llm_a_config = {"id": "llm_a", "data": llm_a_data.model_dump()}
    llm_a = MockLLMNode(
        id=llm_a_config["id"],
        config=llm_a_config,
        graph_init_params=graph_init_params,
        graph_runtime_state=runtime_state,
        mock_config=mock_config,
    )
    llm_b_data = LLMNodeData(
        title="LLM B",
        model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
        prompt_template=[
            LLMNodeChatModelMessage(
                text="Prompt B",
                role=PromptMessageRole.USER,
                edition_type="basic",
            )
        ],
        context=ContextConfig(enabled=False, variable_selector=None),
        vision=VisionConfig(enabled=False),
        reasoning_format="tagged",
        structured_output_enabled=False,
    )
    llm_b_config = {"id": "llm_b", "data": llm_b_data.model_dump()}
    llm_b = MockLLMNode(
        id=llm_b_config["id"],
        config=llm_b_config,
        graph_init_params=graph_init_params,
        graph_runtime_state=runtime_state,
        mock_config=mock_config,
    )
    # The human node is the pause point; its only action is "approve".
    human_data = HumanInputNodeData(
        title="Human Input",
        form_content="Pause here",
        inputs=[],
        user_actions=[UserAction(id="approve", title="Approve")],
    )
    human_config = {"id": "human_pause", "data": human_data.model_dump()}
    human_node = HumanInputNode(
        id=human_config["id"],
        config=human_config,
        graph_init_params=graph_init_params,
        graph_runtime_state=runtime_state,
        form_repository=repo,
    )
    end_human_data = EndNodeData(title="End Human", outputs=[], desc=None)
    end_human_config = {"id": "end_human", "data": end_human_data.model_dump()}
    end_human = EndNode(
        id=end_human_config["id"],
        config=end_human_config,
        graph_init_params=graph_init_params,
        graph_runtime_state=runtime_state,
    )
    # llm_b only becomes ready after llm_a finishes; the human branch pauses in parallel.
    return (
        Graph.new()
        .add_root(start_node)
        .add_node(llm_a, from_node_id="start")
        .add_node(human_node, from_node_id="start")
        .add_node(llm_b, from_node_id="llm_a")
        .add_node(end_human, from_node_id="human_pause", source_handle="approve")
        .build()
    )
def _get_node_started_event(events: list[object], node_id: str) -> NodeRunStartedEvent | None:
    """Return the first NodeRunStartedEvent for *node_id*, or None when absent."""
    matches = (
        event
        for event in events
        if isinstance(event, NodeRunStartedEvent) and event.node_id == node_id
    )
    return next(matches, None)
def test_pause_defers_ready_nodes_until_resume() -> None:
    """While paused on a human node, already-ready downstream nodes (llm_b) must wait for resume."""
    runtime_state = _build_runtime_state()
    paused_form = StaticForm(
        form_id="form-pause",
        rendered="rendered",
        is_submitted=False,
        status_value=HumanInputFormStatus.WAITING,
    )
    pause_repo = StaticRepo(paused_form)
    mock_config = MockConfig()
    mock_config.simulate_delays = True
    # llm_a is slower than the human pause, so llm_b becomes ready only after the run pauses.
    mock_config.set_node_config(
        "llm_a",
        NodeMockConfig(node_id="llm_a", outputs={"text": "LLM A output"}, delay=0.5),
    )
    mock_config.set_node_config(
        "llm_b",
        NodeMockConfig(node_id="llm_b", outputs={"text": "LLM B output"}, delay=0.0),
    )
    graph = _build_graph(runtime_state, pause_repo, mock_config)
    engine = GraphEngine(
        workflow_id="workflow",
        graph=graph,
        graph_runtime_state=runtime_state,
        command_channel=InMemoryChannel(),
        config=GraphEngineConfig(
            min_workers=2,
            max_workers=2,
            scale_up_threshold=1,
            scale_down_idle_time=30.0,
        ),
    )
    paused_events = list(engine.run())
    assert any(isinstance(e, GraphRunPausedEvent) for e in paused_events)
    assert any(isinstance(e, NodeRunSucceededEvent) and e.node_id == "llm_a" for e in paused_events)
    # llm_b was unlocked by llm_a but must not start while the run is paused.
    assert _get_node_started_event(paused_events, "llm_b") is None
    # Resume from a serialized snapshot with the form now submitted.
    snapshot = runtime_state.dumps()
    resumed_state = GraphRuntimeState.from_snapshot(snapshot)
    submitted_form = StaticForm(
        form_id="form-pause",
        rendered="rendered",
        is_submitted=True,
        action_id="approve",
        data={},
        status_value=HumanInputFormStatus.SUBMITTED,
    )
    resume_repo = StaticRepo(submitted_form)
    resumed_graph = _build_graph(resumed_state, resume_repo, mock_config)
    resumed_engine = GraphEngine(
        workflow_id="workflow",
        graph=resumed_graph,
        graph_runtime_state=resumed_state,
        command_channel=InMemoryChannel(),
        config=GraphEngineConfig(
            min_workers=2,
            max_workers=2,
            scale_up_threshold=1,
            scale_down_idle_time=30.0,
        ),
    )
    resumed_events = list(resumed_engine.run())
    start_event = next(e for e in resumed_events if isinstance(e, GraphRunStartedEvent))
    assert start_event.reason is WorkflowStartReason.RESUMPTION
    # The deferred llm_b now runs to completion.
    llm_b_started = _get_node_started_event(resumed_events, "llm_b")
    assert llm_b_started is not None
    assert any(isinstance(e, NodeRunSucceededEvent) and e.node_id == "llm_b" for e in resumed_events)

View File

@ -0,0 +1,217 @@
import datetime
import time
from typing import Any
from unittest.mock import MagicMock
from core.workflow.entities import GraphInitParams
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
from core.workflow.graph import Graph
from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel
from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.graph_events import (
GraphEngineEvent,
GraphRunPausedEvent,
GraphRunSucceededEvent,
NodeRunStartedEvent,
NodeRunSucceededEvent,
)
from core.workflow.graph_events.graph import GraphRunStartedEvent
from core.workflow.nodes.base.entities import OutputVariableEntity
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.end.entities import EndNodeData
from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.repositories.human_input_form_repository import (
HumanInputFormEntity,
HumanInputFormRepository,
)
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from libs.datetime_utils import naive_utc_now
def _build_runtime_state() -> GraphRuntimeState:
    """Create a fresh runtime state backed by an empty variable pool."""
    system_vars = SystemVariable(
        user_id="user",
        app_id="app",
        workflow_id="workflow",
        workflow_execution_id="test-execution-id",
    )
    pool = VariablePool(
        system_variables=system_vars,
        user_inputs={},
        conversation_variables=[],
    )
    return GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter())
def _mock_form_repository_with_submission(action_id: str) -> HumanInputFormRepository:
    """Build a repository mock whose form was already submitted with *action_id*."""
    form = MagicMock(spec=HumanInputFormEntity)
    form.id = "test-form-id"
    form.web_app_token = "test-form-token"
    form.recipients = []
    form.rendered_content = "rendered"
    form.submitted = True
    form.selected_action_id = action_id
    form.submitted_data = {}
    # Not yet expired: the node should treat the form as validly submitted.
    form.expiration_time = naive_utc_now() + datetime.timedelta(days=1)
    repository = MagicMock(spec=HumanInputFormRepository)
    repository.get_form.return_value = form
    return repository
def _mock_form_repository_without_submission() -> HumanInputFormRepository:
    """Build a repository mock with no stored form yet and an unsubmitted creation result."""
    pending_form = MagicMock(spec=HumanInputFormEntity)
    pending_form.id = "test-form-id"
    pending_form.web_app_token = "test-form-token"
    pending_form.recipients = []
    pending_form.rendered_content = "rendered"
    pending_form.submitted = False
    repository = MagicMock(spec=HumanInputFormRepository)
    # get_form -> None forces the node to create a new form and pause.
    repository.create_form.return_value = pending_form
    repository.get_form.return_value = None
    return repository
def _build_human_input_graph(
    runtime_state: GraphRuntimeState,
    form_repository: HumanInputFormRepository,
) -> Graph:
    """Assemble a three-node graph: start -> human input -> end.

    The edge leaving the human node is bound to the "continue" action, and the
    end node exposes the selected action id as the graph result.
    """
    init_params = GraphInitParams(
        tenant_id="tenant",
        app_id="app",
        workflow_id="workflow",
        graph_config={"nodes": [], "edges": []},
        user_id="user",
        user_from="account",
        invoke_from="service-api",
        call_depth=0,
    )

    start_node = StartNode(
        id="start",
        config={"id": "start", "data": StartNodeData(title="start", variables=[]).model_dump()},
        graph_init_params=init_params,
        graph_runtime_state=runtime_state,
    )

    human_node_data = HumanInputNodeData(
        title="human",
        form_content="Awaiting human input",
        inputs=[],
        user_actions=[UserAction(id="continue", title="Continue")],
    )
    human_node = HumanInputNode(
        id="human",
        config={"id": "human", "data": human_node_data.model_dump()},
        graph_init_params=init_params,
        graph_runtime_state=runtime_state,
        form_repository=form_repository,
    )

    end_node_data = EndNodeData(
        title="end",
        outputs=[OutputVariableEntity(variable="result", value_selector=["human", "action_id"])],
        desc=None,
    )
    end_node = EndNode(
        id="end",
        config={"id": "end", "data": end_node_data.model_dump()},
        graph_init_params=init_params,
        graph_runtime_state=runtime_state,
    )

    builder = Graph.new().add_root(start_node).add_node(human_node)
    builder = builder.add_node(end_node, from_node_id="human", source_handle="continue")
    return builder.build()
def _run_graph(graph: Graph, runtime_state: GraphRuntimeState) -> list[GraphEngineEvent]:
    """Execute *graph* to completion (or pause) and collect every emitted event."""
    engine = GraphEngine(
        workflow_id="workflow",
        graph=graph,
        graph_runtime_state=runtime_state,
        command_channel=InMemoryChannel(),
    )
    collected: list[GraphEngineEvent] = []
    for event in engine.run():
        collected.append(event)
    return collected
def _node_successes(events: list[GraphEngineEvent]) -> list[str]:
    """Return the node ids of all success events, preserving emission order."""
    succeeded: list[str] = []
    for candidate in events:
        if isinstance(candidate, NodeRunSucceededEvent):
            succeeded.append(candidate.node_id)
    return succeeded
def _node_start_event(events: list[GraphEngineEvent], node_id: str) -> NodeRunStartedEvent | None:
    """Find the first start event for *node_id*, or None if it never started."""
    return next(
        (e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_id == node_id),
        None,
    )
def _segment_value(variable_pool: VariablePool, selector: tuple[str, str]) -> Any:
    """Resolve *selector* in the pool and unwrap the segment's value if it has one."""
    found = variable_pool.get(selector)
    assert found is not None
    # Fall back to the segment itself when it carries no `.value` attribute.
    return getattr(found, "value", found)
def test_engine_resume_restores_state_and_completion():
    """Pause + snapshot + resume must converge to the same final state as an uninterrupted run."""
    # Baseline run without pausing: the form is already submitted, so the graph completes in one go.
    baseline_state = _build_runtime_state()
    baseline_repo = _mock_form_repository_with_submission(action_id="continue")
    baseline_graph = _build_human_input_graph(baseline_state, baseline_repo)
    baseline_events = _run_graph(baseline_graph, baseline_state)
    assert baseline_events
    # Fix: this event belongs to the baseline run, not a paused run — name it accordingly.
    first_baseline_event = baseline_events[0]
    assert isinstance(first_baseline_event, GraphRunStartedEvent)
    assert first_baseline_event.reason is WorkflowStartReason.INITIAL
    assert isinstance(baseline_events[-1], GraphRunSucceededEvent)
    baseline_success_nodes = _node_successes(baseline_events)
    # Run with pause: no submission exists yet, so the engine stops at the human input node.
    paused_state = _build_runtime_state()
    pause_repo = _mock_form_repository_without_submission()
    paused_graph = _build_human_input_graph(paused_state, pause_repo)
    paused_events = _run_graph(paused_graph, paused_state)
    assert paused_events
    first_paused_event = paused_events[0]
    assert isinstance(first_paused_event, GraphRunStartedEvent)
    assert first_paused_event.reason is WorkflowStartReason.INITIAL
    assert isinstance(paused_events[-1], GraphRunPausedEvent)
    snapshot = paused_state.dumps()
    # Resume from the snapshot with the submission now present.
    resumed_state = GraphRuntimeState.from_snapshot(snapshot)
    resume_repo = _mock_form_repository_with_submission(action_id="continue")
    resumed_graph = _build_human_input_graph(resumed_state, resume_repo)
    resumed_events = _run_graph(resumed_graph, resumed_state)
    assert resumed_events
    first_resumed_event = resumed_events[0]
    assert isinstance(first_resumed_event, GraphRunStartedEvent)
    assert first_resumed_event.reason is WorkflowStartReason.RESUMPTION
    assert isinstance(resumed_events[-1], GraphRunSucceededEvent)
    # Paused + resumed runs together must succeed exactly the same nodes as the baseline.
    combined_success_nodes = _node_successes(paused_events) + _node_successes(resumed_events)
    assert combined_success_nodes == baseline_success_nodes
    paused_human_started = _node_start_event(paused_events, "human")
    resumed_human_started = _node_start_event(resumed_events, "human")
    assert paused_human_started is not None
    assert resumed_human_started is not None
    # The resumed run replays the very same node execution, not a fresh one.
    assert paused_human_started.id == resumed_human_started.id
    assert baseline_state.outputs == resumed_state.outputs
    assert _segment_value(baseline_state.variable_pool, ("human", "__action_id")) == _segment_value(
        resumed_state.variable_pool, ("human", "__action_id")
    )
    assert baseline_state.graph_execution.completed
    assert resumed_state.graph_execution.completed

View File

@ -7,6 +7,7 @@ from core.workflow.nodes.base.node import Node
# Ensures that all node classes are imported.
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
# Ensure `NODE_TYPE_CLASSES_MAPPING` is used and not automatically removed.
_ = NODE_TYPE_CLASSES_MAPPING
@ -45,7 +46,9 @@ def test_ensure_subclasses_of_base_node_has_node_type_and_version_method_defined
assert isinstance(cls.node_type, NodeType)
assert isinstance(node_version, str)
node_type_and_version = (node_type, node_version)
assert node_type_and_version not in type_version_set
assert node_type_and_version not in type_version_set, (
f"Duplicate node type and version for class: {cls=} {node_type_and_version=}"
)
type_version_set.add(node_type_and_version)

View File

@ -0,0 +1 @@
# Unit tests for human input node

View File

@ -0,0 +1,16 @@
from core.workflow.nodes.human_input.entities import EmailDeliveryConfig, EmailRecipients
from core.workflow.runtime import VariablePool
def test_render_body_template_replaces_variable_values():
    """Both node-variable and url placeholders are substituted into the body."""
    pool = VariablePool()
    pool.add(["node1", "value"], "World")
    email_config = EmailDeliveryConfig(
        recipients=EmailRecipients(),
        subject="Subject",
        body="Hello {{#node1.value#}} {{#url#}}",
    )
    rendered = email_config.render_body_template(
        body=email_config.body, url="https://example.com", variable_pool=pool
    )
    assert rendered == "Hello World https://example.com"

View File

@ -0,0 +1,597 @@
"""
Unit tests for human input node entities.
"""
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from pydantic import ValidationError
from core.workflow.entities import GraphInitParams
from core.workflow.node_events import PauseRequestedEvent
from core.workflow.node_events.node import StreamCompletedEvent
from core.workflow.nodes.human_input.entities import (
EmailDeliveryConfig,
EmailDeliveryMethod,
EmailRecipients,
ExternalRecipient,
FormInput,
FormInputDefault,
HumanInputNodeData,
MemberRecipient,
UserAction,
WebAppDeliveryMethod,
_WebAppDeliveryConfig,
)
from core.workflow.nodes.human_input.enums import (
ButtonStyle,
DeliveryMethodType,
EmailRecipientType,
FormInputType,
PlaceholderType,
TimeoutUnit,
)
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
from core.workflow.repositories.human_input_form_repository import HumanInputFormRepository
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from tests.unit_tests.core.workflow.graph_engine.human_input_test_utils import InMemoryHumanInputFormRepository
class TestDeliveryMethod:
    """Tests for the delivery-method entities."""

    def test_webapp_delivery_method(self):
        """A webapp delivery method carries the webapp type tag and its config."""
        method = WebAppDeliveryMethod(enabled=True, config=_WebAppDeliveryConfig())
        assert method.type == DeliveryMethodType.WEBAPP
        assert method.enabled is True
        assert isinstance(method.config, _WebAppDeliveryConfig)

    def test_email_delivery_method(self):
        """An email delivery method keeps its recipients, subject and body config."""
        recipient_items = [
            MemberRecipient(type=EmailRecipientType.MEMBER, user_id="test-user-123"),
            ExternalRecipient(type=EmailRecipientType.EXTERNAL, email="test@example.com"),
        ]
        email_config = EmailDeliveryConfig(
            recipients=EmailRecipients(whole_workspace=False, items=recipient_items),
            subject="Test Subject",
            body="Test body with {{#url#}} placeholder",
        )
        method = EmailDeliveryMethod(enabled=True, config=email_config)
        assert method.type == DeliveryMethodType.EMAIL
        assert method.enabled is True
        assert isinstance(method.config, EmailDeliveryConfig)
        assert method.config.subject == "Test Subject"
        assert len(method.config.recipients.items) == 2
class TestFormInput:
    """Tests for the FormInput entity."""

    def test_text_input_with_constant_default(self):
        """A constant default keeps its literal value."""
        field = FormInput(
            type=FormInputType.TEXT_INPUT,
            output_variable_name="user_input",
            default=FormInputDefault(type=PlaceholderType.CONSTANT, value="Enter your response here..."),
        )
        assert field.type == FormInputType.TEXT_INPUT
        assert field.output_variable_name == "user_input"
        assert field.default.type == PlaceholderType.CONSTANT
        assert field.default.value == "Enter your response here..."

    def test_text_input_with_variable_default(self):
        """A variable default records the source variable selector."""
        field = FormInput(
            type=FormInputType.TEXT_INPUT,
            output_variable_name="user_input",
            default=FormInputDefault(type=PlaceholderType.VARIABLE, selector=["node_123", "output_var"]),
        )
        assert field.default.type == PlaceholderType.VARIABLE
        assert field.default.selector == ["node_123", "output_var"]

    def test_form_input_without_default(self):
        """Omitting the default leaves it as None."""
        field = FormInput(type=FormInputType.PARAGRAPH, output_variable_name="description")
        assert field.type == FormInputType.PARAGRAPH
        assert field.output_variable_name == "description"
        assert field.default is None
class TestUserAction:
    """Tests for the UserAction entity."""

    def test_user_action_creation(self):
        """All explicitly supplied fields round-trip."""
        action = UserAction(id="approve", title="Approve", button_style=ButtonStyle.PRIMARY)
        assert action.id == "approve"
        assert action.title == "Approve"
        assert action.button_style == ButtonStyle.PRIMARY

    def test_user_action_default_button_style(self):
        """When unspecified, the button style falls back to DEFAULT."""
        assert UserAction(id="cancel", title="Cancel").button_style == ButtonStyle.DEFAULT

    def test_user_action_length_boundaries(self):
        """20 characters is accepted for both id and title."""
        longest_id = "a" * 20
        longest_title = "b" * 20
        action = UserAction(id=longest_id, title=longest_title)
        assert action.id == longest_id
        assert action.title == longest_title

    @pytest.mark.parametrize(
        ("field_name", "value"),
        [
            ("id", "a" * 21),
            ("title", "b" * 21),
        ],
    )
    def test_user_action_length_limits(self, field_name: str, value: str):
        """User action fields should enforce max length (21 chars is rejected)."""
        payload = {"id": "approve", "title": "Approve"}
        payload[field_name] = value
        with pytest.raises(ValidationError) as exc_info:
            UserAction(**payload)
        assert any(
            err["loc"] == (field_name,) and err["type"] == "string_too_long"
            for err in exc_info.value.errors()
        )
class TestHumanInputNodeData:
    """Test HumanInputNodeData entity."""

    def test_valid_node_data_creation(self):
        """Test creating valid human input node data."""
        delivery_methods = [WebAppDeliveryMethod(enabled=True, config=_WebAppDeliveryConfig())]
        inputs = [
            FormInput(
                type=FormInputType.TEXT_INPUT,
                output_variable_name="content",
                default=FormInputDefault(type=PlaceholderType.CONSTANT, value="Enter content..."),
            )
        ]
        user_actions = [UserAction(id="submit", title="Submit", button_style=ButtonStyle.PRIMARY)]
        node_data = HumanInputNodeData(
            title="Human Input Test",
            desc="Test node description",
            delivery_methods=delivery_methods,
            form_content="# Test Form\n\nPlease provide input:\n\n{{#$output.content#}}",
            inputs=inputs,
            user_actions=user_actions,
            timeout=24,
            timeout_unit=TimeoutUnit.HOUR,
        )
        assert node_data.title == "Human Input Test"
        assert node_data.desc == "Test node description"
        assert len(node_data.delivery_methods) == 1
        assert node_data.form_content.startswith("# Test Form")
        assert len(node_data.inputs) == 1
        assert len(node_data.user_actions) == 1
        assert node_data.timeout == 24
        assert node_data.timeout_unit == TimeoutUnit.HOUR

    def test_node_data_with_multiple_delivery_methods(self):
        """Test node data with multiple delivery methods."""
        delivery_methods = [
            WebAppDeliveryMethod(enabled=True, config=_WebAppDeliveryConfig()),
            EmailDeliveryMethod(
                enabled=False,  # Disabled method should be fine
                config=EmailDeliveryConfig(
                    subject="Hi there", body="", recipients=EmailRecipients(whole_workspace=True)
                ),
            ),
        ]
        node_data = HumanInputNodeData(
            title="Test Node", delivery_methods=delivery_methods, timeout=1, timeout_unit=TimeoutUnit.DAY
        )
        assert len(node_data.delivery_methods) == 2
        assert node_data.timeout == 1
        assert node_data.timeout_unit == TimeoutUnit.DAY

    def test_node_data_defaults(self):
        """Test node data with default values."""
        node_data = HumanInputNodeData(title="Test Node")
        assert node_data.title == "Test Node"
        assert node_data.desc is None
        assert node_data.delivery_methods == []
        assert node_data.form_content == ""
        assert node_data.inputs == []
        assert node_data.user_actions == []
        # Default timeout is 36 hours.
        assert node_data.timeout == 36
        assert node_data.timeout_unit == TimeoutUnit.HOUR

    def test_duplicate_input_output_variable_name_raises_validation_error(self):
        """Duplicate form input output_variable_name should raise validation error."""
        duplicate_inputs = [
            FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="content"),
            FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="content"),
        ]
        with pytest.raises(ValidationError, match="duplicated output_variable_name 'content'"):
            HumanInputNodeData(title="Test Node", inputs=duplicate_inputs)

    def test_duplicate_user_action_ids_raise_validation_error(self):
        """Duplicate user action ids should raise validation error."""
        duplicate_actions = [
            UserAction(id="submit", title="Submit"),
            UserAction(id="submit", title="Submit Again"),
        ]
        with pytest.raises(ValidationError, match="duplicated user action id 'submit'"):
            HumanInputNodeData(title="Test Node", user_actions=duplicate_actions)

    def test_extract_outputs_field_names(self):
        """Only {{#$output.*#}} placeholders contribute output field names; plain variable refs do not."""
        # NOTE: the "titile" misspelling is deliberate test content, not a typo to fix.
        content = r"""This is titile {{#start.title#}}
A content is required:
{{#$output.content#}}
A ending is required:
{{#$output.ending#}}
"""
        node_data = HumanInputNodeData(title="Human Input", form_content=content)
        field_names = node_data.outputs_field_names()
        assert field_names == ["content", "ending"]
class TestRecipients:
    """Tests for the email recipient entities."""

    def test_member_recipient(self):
        """A workspace-member recipient stores its user id."""
        member = MemberRecipient(type=EmailRecipientType.MEMBER, user_id="user-123")
        assert member.type == EmailRecipientType.MEMBER
        assert member.user_id == "user-123"

    def test_external_recipient(self):
        """An external recipient stores its email address."""
        external = ExternalRecipient(type=EmailRecipientType.EXTERNAL, email="test@example.com")
        assert external.type == EmailRecipientType.EXTERNAL
        assert external.email == "test@example.com"

    def test_email_recipients_whole_workspace(self):
        """Explicit items survive even when the whole workspace is targeted."""
        recipients = EmailRecipients(
            whole_workspace=True,
            items=[MemberRecipient(type=EmailRecipientType.MEMBER, user_id="user-123")],
        )
        assert recipients.whole_workspace is True
        assert len(recipients.items) == 1

    def test_email_recipients_specific_users(self):
        """A mixed member/external list is preserved in order."""
        member = MemberRecipient(type=EmailRecipientType.MEMBER, user_id="user-123")
        external = ExternalRecipient(type=EmailRecipientType.EXTERNAL, email="external@example.com")
        recipients = EmailRecipients(whole_workspace=False, items=[member, external])
        assert recipients.whole_workspace is False
        assert len(recipients.items) == 2
        assert recipients.items[0].user_id == "user-123"
        assert recipients.items[1].email == "external@example.com"
class TestHumanInputNodeVariableResolution:
    """Tests for resolving variable-based defaults in HumanInputNode."""

    def test_resolves_variable_defaults(self):
        """Variable-backed defaults are resolved from the pool; constant defaults are not included."""
        variable_pool = VariablePool(
            system_variables=SystemVariable(
                user_id="user",
                app_id="app",
                workflow_id="workflow",
                workflow_execution_id="exec-1",
            ),
            user_inputs={},
            conversation_variables=[],
        )
        # Seed the value the variable-backed default should resolve to.
        variable_pool.add(("start", "name"), "Jane Doe")
        runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
        graph_init_params = GraphInitParams(
            tenant_id="tenant",
            app_id="app",
            workflow_id="workflow",
            graph_config={"nodes": [], "edges": []},
            user_id="user",
            user_from="account",
            invoke_from="debugger",
            call_depth=0,
        )
        node_data = HumanInputNodeData(
            title="Human Input",
            form_content="Provide your name",
            inputs=[
                FormInput(
                    type=FormInputType.TEXT_INPUT,
                    output_variable_name="user_name",
                    default=FormInputDefault(type=PlaceholderType.VARIABLE, selector=["start", "name"]),
                ),
                FormInput(
                    type=FormInputType.TEXT_INPUT,
                    output_variable_name="user_email",
                    default=FormInputDefault(type=PlaceholderType.CONSTANT, value="foo@example.com"),
                ),
            ],
            user_actions=[UserAction(id="submit", title="Submit")],
        )
        config = {"id": "human", "data": node_data.model_dump()}
        mock_repo = MagicMock(spec=HumanInputFormRepository)
        # No existing form forces a fresh create_form call followed by a pause.
        mock_repo.get_form.return_value = None
        mock_repo.create_form.return_value = SimpleNamespace(
            id="form-1",
            rendered_content="Provide your name",
            web_app_token="token",
            recipients=[],
            submitted=False,
        )
        node = HumanInputNode(
            id=config["id"],
            config=config,
            graph_init_params=graph_init_params,
            graph_runtime_state=runtime_state,
            form_repository=mock_repo,
        )
        run_result = node._run()
        pause_event = next(run_result)
        assert isinstance(pause_event, PauseRequestedEvent)
        # Only the variable-backed default appears in the resolved values.
        expected_values = {"user_name": "Jane Doe"}
        assert pause_event.reason.resolved_default_values == expected_values
        # The same resolved values are forwarded to the repository's create_form params.
        params = mock_repo.create_form.call_args.args[0]
        assert params.resolved_default_values == expected_values

    def test_debugger_falls_back_to_recipient_token_when_webapp_disabled(self):
        """In debugger mode, the pause reason carries the console web-app token for the form."""
        variable_pool = VariablePool(
            system_variables=SystemVariable(
                user_id="user",
                app_id="app",
                workflow_id="workflow",
                workflow_execution_id="exec-2",
            ),
            user_inputs={},
            conversation_variables=[],
        )
        runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
        graph_init_params = GraphInitParams(
            tenant_id="tenant",
            app_id="app",
            workflow_id="workflow",
            graph_config={"nodes": [], "edges": []},
            user_id="user",
            user_from="account",
            invoke_from="debugger",
            call_depth=0,
        )
        node_data = HumanInputNodeData(
            title="Human Input",
            form_content="Provide your name",
            inputs=[],
            user_actions=[UserAction(id="submit", title="Submit")],
        )
        config = {"id": "human", "data": node_data.model_dump()}
        mock_repo = MagicMock(spec=HumanInputFormRepository)
        mock_repo.get_form.return_value = None
        # The created form exposes both a console token and one recipient token.
        mock_repo.create_form.return_value = SimpleNamespace(
            id="form-2",
            rendered_content="Provide your name",
            web_app_token="console-token",
            recipients=[SimpleNamespace(token="recipient-token")],
            submitted=False,
        )
        node = HumanInputNode(
            id=config["id"],
            config=config,
            graph_init_params=graph_init_params,
            graph_runtime_state=runtime_state,
            form_repository=mock_repo,
        )
        run_result = node._run()
        pause_event = next(run_result)
        assert isinstance(pause_event, PauseRequestedEvent)
        assert pause_event.reason.form_token == "console-token"

    def test_debugger_debug_mode_overrides_email_recipients(self):
        """With email debug_mode on in the debugger, recipients collapse to the invoking member."""
        variable_pool = VariablePool(
            system_variables=SystemVariable(
                user_id="user-123",
                app_id="app",
                workflow_id="workflow",
                workflow_execution_id="exec-3",
            ),
            user_inputs={},
            conversation_variables=[],
        )
        runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
        graph_init_params = GraphInitParams(
            tenant_id="tenant",
            app_id="app",
            workflow_id="workflow",
            graph_config={"nodes": [], "edges": []},
            user_id="user-123",
            user_from="account",
            invoke_from="debugger",
            call_depth=0,
        )
        node_data = HumanInputNodeData(
            title="Human Input",
            form_content="Provide your name",
            inputs=[],
            user_actions=[UserAction(id="submit", title="Submit")],
            delivery_methods=[
                EmailDeliveryMethod(
                    enabled=True,
                    config=EmailDeliveryConfig(
                        recipients=EmailRecipients(
                            whole_workspace=False,
                            items=[ExternalRecipient(type=EmailRecipientType.EXTERNAL, email="target@example.com")],
                        ),
                        subject="Subject",
                        body="Body",
                        debug_mode=True,
                    ),
                )
            ],
        )
        config = {"id": "human", "data": node_data.model_dump()}
        mock_repo = MagicMock(spec=HumanInputFormRepository)
        mock_repo.get_form.return_value = None
        mock_repo.create_form.return_value = SimpleNamespace(
            id="form-3",
            rendered_content="Provide your name",
            web_app_token="token",
            recipients=[],
            submitted=False,
        )
        node = HumanInputNode(
            id=config["id"],
            config=config,
            graph_init_params=graph_init_params,
            graph_runtime_state=runtime_state,
            form_repository=mock_repo,
        )
        run_result = node._run()
        pause_event = next(run_result)
        assert isinstance(pause_event, PauseRequestedEvent)
        # Inspect the delivery methods actually passed to create_form: the external
        # recipient must have been replaced by the debugging user as a member recipient.
        params = mock_repo.create_form.call_args.args[0]
        assert len(params.delivery_methods) == 1
        method = params.delivery_methods[0]
        assert isinstance(method, EmailDeliveryMethod)
        assert method.config.debug_mode is True
        assert method.config.recipients.whole_workspace is False
        assert len(method.config.recipients.items) == 1
        recipient = method.config.recipients.items[0]
        assert isinstance(recipient, MemberRecipient)
        assert recipient.user_id == "user-123"
class TestValidation:
    """Negative validation scenarios for the entity models."""

    def test_invalid_form_input_type(self):
        """An unknown form input type is rejected."""
        with pytest.raises(ValidationError):
            FormInput(type="invalid-type", output_variable_name="test")

    def test_invalid_button_style(self):
        """An unknown button style is rejected."""
        with pytest.raises(ValidationError):
            UserAction(id="test", title="Test", button_style="invalid-style")

    def test_invalid_timeout_unit(self):
        """An unknown timeout unit is rejected."""
        with pytest.raises(ValidationError):
            HumanInputNodeData(title="Test", timeout_unit="invalid-unit")
class TestHumanInputNodeRenderedContent:
    """Tests for rendering submitted content."""

    def test_replaces_outputs_placeholders_after_submission(self):
        """After submission, {{#$output.*#}} placeholders render with the submitted values."""
        variable_pool = VariablePool(
            system_variables=SystemVariable(
                user_id="user",
                app_id="app",
                workflow_id="workflow",
                workflow_execution_id="exec-1",
            ),
            user_inputs={},
            conversation_variables=[],
        )
        runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
        graph_init_params = GraphInitParams(
            tenant_id="tenant",
            app_id="app",
            workflow_id="workflow",
            graph_config={"nodes": [], "edges": []},
            user_id="user",
            user_from="account",
            invoke_from="debugger",
            call_depth=0,
        )
        node_data = HumanInputNodeData(
            title="Human Input",
            form_content="Name: {{#$output.name#}}",
            inputs=[
                FormInput(
                    type=FormInputType.TEXT_INPUT,
                    output_variable_name="name",
                )
            ],
            user_actions=[UserAction(id="approve", title="Approve")],
        )
        config = {"id": "human", "data": node_data.model_dump()}
        form_repository = InMemoryHumanInputFormRepository()
        node = HumanInputNode(
            id=config["id"],
            config=config,
            graph_init_params=graph_init_params,
            graph_runtime_state=runtime_state,
            form_repository=form_repository,
        )
        # First run: no submission exists, so the node requests a pause and stops.
        pause_gen = node._run()
        pause_event = next(pause_gen)
        assert isinstance(pause_event, PauseRequestedEvent)
        with pytest.raises(StopIteration):
            next(pause_gen)
        # Submit the form, then run again: the node should complete.
        form_repository.set_submission(action_id="approve", form_data={"name": "Alice"})
        events = list(node._run())
        last_event = events[-1]
        assert isinstance(last_event, StreamCompletedEvent)
        node_run_result = last_event.node_run_result
        # The rendered content output has the placeholder replaced by the submitted value.
        assert node_run_result.outputs["__rendered_content"] == "Name: Alice"

View File

@ -0,0 +1,172 @@
import datetime
from types import SimpleNamespace
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.graph_init_params import GraphInitParams
from core.workflow.enums import NodeType
from core.workflow.graph_events import (
NodeRunHumanInputFormFilledEvent,
NodeRunHumanInputFormTimeoutEvent,
NodeRunStartedEvent,
)
from core.workflow.nodes.human_input.enums import HumanInputFormStatus
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from libs.datetime_utils import naive_utc_now
from models.enums import UserFrom
class _FakeFormRepository:
    """Minimal stand-in for a form repository: always returns one pre-built form."""

    def __init__(self, form):
        # The single form object returned by every get_form call.
        self._form = form

    def get_form(self, *_args, **_kwargs):
        """Return the pre-built form regardless of the lookup arguments."""
        return self._form
def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name#}}") -> HumanInputNode:
    """Build a HumanInputNode whose backing form is already submitted ("Accept", name=Alice)."""
    system_variables = SystemVariable.default()
    graph_runtime_state = GraphRuntimeState(
        variable_pool=VariablePool(system_variables=system_variables, user_inputs={}, environment_variables=[]),
        start_at=0.0,
    )
    graph_init_params = GraphInitParams(
        tenant_id="tenant",
        app_id="app",
        workflow_id="workflow",
        graph_config={"nodes": [], "edges": []},
        user_id="user",
        user_from=UserFrom.ACCOUNT,
        invoke_from=InvokeFrom.SERVICE_API,
        call_depth=0,
    )
    # Raw node config: one text input bound to output "name" and a single "Accept" action.
    config = {
        "id": "node-1",
        "type": NodeType.HUMAN_INPUT.value,
        "data": {
            "title": "Human Input",
            "form_content": form_content,
            "inputs": [
                {
                    "type": "text_input",
                    "output_variable_name": "name",
                    "default": {"type": "constant", "value": ""},
                }
            ],
            "user_actions": [
                {
                    "id": "Accept",
                    "title": "Approve",
                    "button_style": "default",
                }
            ],
        },
    }
    # A form that was already filled in: submitted via "Accept" with name=Alice,
    # and still within its expiration window.
    fake_form = SimpleNamespace(
        id="form-1",
        rendered_content=form_content,
        submitted=True,
        selected_action_id="Accept",
        submitted_data={"name": "Alice"},
        status=HumanInputFormStatus.SUBMITTED,
        expiration_time=naive_utc_now() + datetime.timedelta(days=1),
    )
    repo = _FakeFormRepository(fake_form)
    return HumanInputNode(
        id="node-1",
        config=config,
        graph_init_params=graph_init_params,
        graph_runtime_state=graph_runtime_state,
        form_repository=repo,
    )
def _build_timeout_node() -> HumanInputNode:
    """Build a HumanInputNode whose backing form expired without any submission."""
    system_variables = SystemVariable.default()
    graph_runtime_state = GraphRuntimeState(
        variable_pool=VariablePool(system_variables=system_variables, user_inputs={}, environment_variables=[]),
        start_at=0.0,
    )
    graph_init_params = GraphInitParams(
        tenant_id="tenant",
        app_id="app",
        workflow_id="workflow",
        graph_config={"nodes": [], "edges": []},
        user_id="user",
        user_from=UserFrom.ACCOUNT,
        invoke_from=InvokeFrom.SERVICE_API,
        call_depth=0,
    )
    # Same node shape as _build_node: one text input and a single "Accept" action.
    config = {
        "id": "node-1",
        "type": NodeType.HUMAN_INPUT.value,
        "data": {
            "title": "Human Input",
            "form_content": "Please enter your name:\n\n{{#$output.name#}}",
            "inputs": [
                {
                    "type": "text_input",
                    "output_variable_name": "name",
                    "default": {"type": "constant", "value": ""},
                }
            ],
            "user_actions": [
                {
                    "id": "Accept",
                    "title": "Approve",
                    "button_style": "default",
                }
            ],
        },
    }
    # An unsubmitted form whose expiration time is already in the past (TIMEOUT status).
    fake_form = SimpleNamespace(
        id="form-1",
        rendered_content="content",
        submitted=False,
        selected_action_id=None,
        submitted_data=None,
        status=HumanInputFormStatus.TIMEOUT,
        expiration_time=naive_utc_now() - datetime.timedelta(minutes=1),
    )
    repo = _FakeFormRepository(fake_form)
    return HumanInputNode(
        id="node-1",
        config=config,
        graph_init_params=graph_init_params,
        graph_runtime_state=graph_runtime_state,
        form_repository=repo,
    )
def test_human_input_node_emits_form_filled_event_before_succeeded():
    """A submitted form yields a form-filled event immediately after the start event."""
    events = list(_build_node().run())
    started = events[0]
    filled = events[1]
    assert isinstance(started, NodeRunStartedEvent)
    assert isinstance(filled, NodeRunHumanInputFormFilledEvent)
    assert filled.node_title == "Human Input"
    # The rendered content ends with the submitted value for "name".
    assert filled.rendered_content.endswith("Alice")
    assert filled.action_id == "Accept"
    assert filled.action_text == "Approve"
def test_human_input_node_emits_timeout_event_before_succeeded():
node = _build_timeout_node()
events = list(node.run())
assert isinstance(events[0], NodeRunStartedEvent)
assert isinstance(events[1], NodeRunHumanInputFormTimeoutEvent)
timeout_event = events[1]
assert timeout_event.node_title == "Human Input"