WIP: resume

This commit is contained in:
QuantumGhost
2025-11-21 10:13:20 +08:00
parent c0e15b9e1b
commit c0f1aeddbe
49 changed files with 2160 additions and 1445 deletions

View File

@ -0,0 +1,278 @@
import sys
import time
from pathlib import Path
from types import ModuleType, SimpleNamespace
from typing import Any
API_DIR = str(Path(__file__).resolve().parents[5])
if API_DIR not in sys.path:
sys.path.insert(0, API_DIR)
import core.workflow.nodes.human_input.entities # noqa: F401
from core.app.apps.advanced_chat import app_generator as adv_app_gen_module
from core.app.apps.workflow import app_generator as wf_app_gen_module
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities import GraphInitParams
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel
from core.workflow.graph_engine.entities.commands import PauseCommand
from core.workflow.graph_events import (
GraphEngineEvent,
GraphRunPausedEvent,
GraphRunSucceededEvent,
NodeRunSucceededEvent,
)
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig, VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.end.entities import EndNodeData
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
if "core.ops.ops_trace_manager" not in sys.modules:
ops_stub = ModuleType("core.ops.ops_trace_manager")
class _StubTraceQueueManager:
def __init__(self, *_, **__):
pass
ops_stub.TraceQueueManager = _StubTraceQueueManager
sys.modules["core.ops.ops_trace_manager"] = ops_stub
class _StubToolNodeData(BaseNodeData):
pass
class _StubToolNode(Node):
node_type = NodeType.TOOL
def init_node_data(self, data):
self._node_data = _StubToolNodeData.model_validate(data)
def _get_error_strategy(self):
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self):
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
def _run(self):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"value": f"{self.id}-done"},
)
def _patch_tool_node(mocker):
from core.workflow.nodes import node_factory
custom_mapping = dict(node_factory.NODE_TYPE_CLASSES_MAPPING)
custom_versions = dict(custom_mapping[NodeType.TOOL])
custom_versions[node_factory.LATEST_VERSION] = _StubToolNode
custom_mapping[NodeType.TOOL] = custom_versions
mocker.patch("core.workflow.nodes.node_factory.NODE_TYPE_CLASSES_MAPPING", custom_mapping)
def _build_graph(runtime_state: GraphRuntimeState) -> Graph:
params = GraphInitParams(
tenant_id="tenant",
app_id="app",
workflow_id="workflow",
graph_config={},
user_id="user",
user_from="account",
invoke_from="service-api",
call_depth=0,
)
start_data = StartNodeData(title="start", variables=[])
start_node = StartNode(
id="start",
config={"id": "start", "data": start_data.model_dump()},
graph_init_params=params,
graph_runtime_state=runtime_state,
)
start_node.init_node_data(start_data.model_dump())
tool_data = _StubToolNodeData(title="tool")
tool_a = _StubToolNode(
id="tool_a",
config={"id": "tool_a", "data": tool_data.model_dump()},
graph_init_params=params,
graph_runtime_state=runtime_state,
)
tool_a.init_node_data(tool_data.model_dump())
tool_b = _StubToolNode(
id="tool_b",
config={"id": "tool_b", "data": tool_data.model_dump()},
graph_init_params=params,
graph_runtime_state=runtime_state,
)
tool_b.init_node_data(tool_data.model_dump())
tool_c = _StubToolNode(
id="tool_c",
config={"id": "tool_c", "data": tool_data.model_dump()},
graph_init_params=params,
graph_runtime_state=runtime_state,
)
tool_c.init_node_data(tool_data.model_dump())
end_data = EndNodeData(
title="end",
outputs=[VariableSelector(variable="result", value_selector=["tool_c", "value"])],
desc=None,
)
end_node = EndNode(
id="end",
config={"id": "end", "data": end_data.model_dump()},
graph_init_params=params,
graph_runtime_state=runtime_state,
)
end_node.init_node_data(end_data.model_dump())
return (
Graph.new()
.add_root(start_node)
.add_node(tool_a)
.add_node(tool_b)
.add_node(tool_c)
.add_node(end_node)
.add_edge("tool_a", "tool_b")
.add_edge("tool_b", "tool_c")
.add_edge("tool_c", "end")
.build()
)
def _build_runtime_state(run_id: str) -> GraphRuntimeState:
variable_pool = VariablePool(
system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"),
user_inputs={},
conversation_variables=[],
)
variable_pool.system_variables.workflow_execution_id = run_id
return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
def _run_with_optional_pause(runtime_state: GraphRuntimeState, *, pause_on: str | None) -> list[GraphEngineEvent]:
command_channel = InMemoryChannel()
graph = _build_graph(runtime_state)
engine = GraphEngine(
workflow_id="workflow",
graph=graph,
graph_runtime_state=runtime_state,
command_channel=command_channel,
)
events: list[GraphEngineEvent] = []
for event in engine.run():
if isinstance(event, NodeRunSucceededEvent) and pause_on and event.node_id == pause_on:
command_channel.send_command(PauseCommand(reason="test pause"))
engine._command_processor.process_commands() # type: ignore[attr-defined]
events.append(event)
return events
def _node_successes(events: list[GraphEngineEvent]) -> list[str]:
return [evt.node_id for evt in events if isinstance(evt, NodeRunSucceededEvent)]
def test_workflow_app_pause_resume_matches_baseline(mocker):
_patch_tool_node(mocker)
baseline_state = _build_runtime_state("baseline")
baseline_events = _run_with_optional_pause(baseline_state, pause_on=None)
assert isinstance(baseline_events[-1], GraphRunSucceededEvent)
baseline_nodes = _node_successes(baseline_events)
baseline_outputs = baseline_state.outputs
paused_state = _build_runtime_state("paused-run")
paused_events = _run_with_optional_pause(paused_state, pause_on="tool_a")
assert isinstance(paused_events[-1], GraphRunPausedEvent)
paused_nodes = _node_successes(paused_events)
snapshot = paused_state.dumps()
resumed_state = GraphRuntimeState.from_snapshot(snapshot)
generator = wf_app_gen_module.WorkflowAppGenerator()
def _fake_generate(**kwargs):
state: GraphRuntimeState = kwargs["graph_runtime_state"]
events = _run_with_optional_pause(state, pause_on=None)
return _node_successes(events)
mocker.patch.object(generator, "_generate", side_effect=_fake_generate)
resumed_nodes = generator.resume(
app_model=SimpleNamespace(mode="workflow"),
workflow=SimpleNamespace(),
user=SimpleNamespace(),
application_generate_entity=SimpleNamespace(stream=False, invoke_from=InvokeFrom.SERVICE_API),
graph_runtime_state=resumed_state,
workflow_execution_repository=SimpleNamespace(),
workflow_node_execution_repository=SimpleNamespace(),
)
assert paused_nodes + resumed_nodes == baseline_nodes
assert resumed_state.outputs == baseline_outputs
def test_advanced_chat_pause_resume_matches_baseline(mocker):
_patch_tool_node(mocker)
baseline_state = _build_runtime_state("adv-baseline")
baseline_events = _run_with_optional_pause(baseline_state, pause_on=None)
assert isinstance(baseline_events[-1], GraphRunSucceededEvent)
baseline_nodes = _node_successes(baseline_events)
baseline_outputs = baseline_state.outputs
paused_state = _build_runtime_state("adv-paused")
paused_events = _run_with_optional_pause(paused_state, pause_on="tool_a")
assert isinstance(paused_events[-1], GraphRunPausedEvent)
paused_nodes = _node_successes(paused_events)
snapshot = paused_state.dumps()
resumed_state = GraphRuntimeState.from_snapshot(snapshot)
generator = adv_app_gen_module.AdvancedChatAppGenerator()
def _fake_generate(**kwargs):
state: GraphRuntimeState = kwargs["graph_runtime_state"]
events = _run_with_optional_pause(state, pause_on=None)
return _node_successes(events)
mocker.patch.object(generator, "_generate", side_effect=_fake_generate)
resumed_nodes = generator.resume(
app_model=SimpleNamespace(mode="workflow"),
workflow=SimpleNamespace(),
user=SimpleNamespace(),
conversation=SimpleNamespace(id="conv"),
message=SimpleNamespace(id="msg"),
application_generate_entity=SimpleNamespace(stream=False, invoke_from=InvokeFrom.SERVICE_API),
workflow_execution_repository=SimpleNamespace(),
workflow_node_execution_repository=SimpleNamespace(),
graph_runtime_state=resumed_state,
)
assert paused_nodes + resumed_nodes == baseline_nodes
assert resumed_state.outputs == baseline_outputs

View File

@ -61,8 +61,8 @@ class ConcurrentPublisher:
messages.append(message)
if self.delay > 0:
time.sleep(self.delay)
except Exception as e:
_logger.error("Publisher %s error: %s", thread_id, e)
except Exception:
_logger.exception("Pubmsg=lisher %s", thread_id)
with self._lock:
self.published_messages.append(messages)
@ -308,8 +308,8 @@ def measure_throughput(
try:
operation()
count += 1
except Exception as e:
_logger.error("Operation failed: %s", e)
except Exception:
_logger.exception("Operation failed")
break
elapsed = time.time() - start_time

View File

@ -1,4 +1,26 @@
import sys
from types import ModuleType, SimpleNamespace
from unittest.mock import MagicMock
if "core.ops.ops_trace_manager" not in sys.modules:
stub_module = ModuleType("core.ops.ops_trace_manager")
class _StubTraceQueueManager:
def __init__(self, *_, **__):
pass
stub_module.TraceQueueManager = _StubTraceQueueManager
sys.modules["core.ops.ops_trace_manager"] = stub_module
from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY, WorkflowAppGenerator
from tests.unit_tests.core.workflow.graph_engine.test_pause_resume_state import (
_build_pausing_graph,
_build_runtime_state,
_node_successes,
_PausingNode,
_PausingNodeData,
_run_graph,
)
def test_should_prepare_user_inputs_defaults_to_true():
@ -17,3 +39,193 @@ def test_should_prepare_user_inputs_keeps_validation_when_flag_false():
args = {"inputs": {}, SKIP_PREPARE_USER_INPUTS_KEY: False}
assert WorkflowAppGenerator()._should_prepare_user_inputs(args)
def test_resume_delegates_to_generate(mocker):
generator = WorkflowAppGenerator()
mock_generate = mocker.patch.object(generator, "_generate", return_value="ok")
application_generate_entity = SimpleNamespace(stream=False, invoke_from="debugger")
runtime_state = MagicMock(name="runtime-state")
pause_config = MagicMock(name="pause-config")
result = generator.resume(
app_model=MagicMock(),
workflow=MagicMock(),
user=MagicMock(),
application_generate_entity=application_generate_entity,
graph_runtime_state=runtime_state,
workflow_execution_repository=MagicMock(),
workflow_node_execution_repository=MagicMock(),
graph_engine_layers=("layer",),
pause_state_config=pause_config,
variable_loader=MagicMock(),
)
assert result == "ok"
mock_generate.assert_called_once()
kwargs = mock_generate.call_args.kwargs
assert kwargs["graph_runtime_state"] is runtime_state
assert kwargs["pause_state_config"] is pause_config
assert kwargs["streaming"] is False
assert kwargs["invoke_from"] == "debugger"
def test_generate_appends_pause_layer_and_forwards_state(mocker):
generator = WorkflowAppGenerator()
mock_queue_manager = MagicMock()
mocker.patch("core.app.apps.workflow.app_generator.WorkflowAppQueueManager", return_value=mock_queue_manager)
fake_current_app = MagicMock()
fake_current_app._get_current_object.return_value = MagicMock()
mocker.patch("core.app.apps.workflow.app_generator.current_app", fake_current_app)
mocker.patch(
"core.app.apps.workflow.app_generator.WorkflowAppGenerateResponseConverter.convert",
return_value="converted",
)
mocker.patch.object(WorkflowAppGenerator, "_handle_response", return_value="response")
mocker.patch.object(WorkflowAppGenerator, "_get_draft_var_saver_factory", return_value=MagicMock())
pause_layer = MagicMock(name="pause-layer")
mocker.patch(
"core.app.apps.workflow.app_generator.PauseStatePersistenceLayer",
return_value=pause_layer,
)
dummy_session = MagicMock()
dummy_session.close = MagicMock()
mocker.patch("core.app.apps.workflow.app_generator.db.session", dummy_session)
worker_kwargs: dict[str, object] = {}
class DummyThread:
def __init__(self, target, kwargs):
worker_kwargs["target"] = target
worker_kwargs["kwargs"] = kwargs
def start(self):
return None
mocker.patch("core.app.apps.workflow.app_generator.threading.Thread", DummyThread)
app_model = SimpleNamespace(mode="workflow")
app_config = SimpleNamespace(app_id="app", tenant_id="tenant", workflow_id="wf")
application_generate_entity = SimpleNamespace(
task_id="task",
user_id="user",
invoke_from="service-api",
app_config=app_config,
files=[],
stream=True,
workflow_execution_id="run",
)
graph_runtime_state = MagicMock()
result = generator._generate(
app_model=app_model,
workflow=MagicMock(),
user=MagicMock(),
application_generate_entity=application_generate_entity,
invoke_from="service-api",
workflow_execution_repository=MagicMock(),
workflow_node_execution_repository=MagicMock(),
streaming=True,
graph_engine_layers=("base-layer",),
graph_runtime_state=graph_runtime_state,
pause_state_config=SimpleNamespace(session_factory=MagicMock(), state_owner_user_id="owner"),
)
assert result == "converted"
assert worker_kwargs["kwargs"]["graph_engine_layers"] == ("base-layer", pause_layer)
assert worker_kwargs["kwargs"]["graph_runtime_state"] is graph_runtime_state
def test_resume_path_runs_worker_with_runtime_state(mocker):
generator = WorkflowAppGenerator()
runtime_state = MagicMock(name="runtime-state")
pause_layer = MagicMock(name="pause-layer")
mocker.patch("core.app.apps.workflow.app_generator.PauseStatePersistenceLayer", return_value=pause_layer)
queue_manager = MagicMock()
mocker.patch("core.app.apps.workflow.app_generator.WorkflowAppQueueManager", return_value=queue_manager)
mocker.patch.object(generator, "_handle_response", return_value="raw-response")
mocker.patch(
"core.app.apps.workflow.app_generator.WorkflowAppGenerateResponseConverter.convert",
side_effect=lambda response, invoke_from: response,
)
fake_db = SimpleNamespace(session=MagicMock(), engine=MagicMock())
mocker.patch("core.app.apps.workflow.app_generator.db", fake_db)
workflow = SimpleNamespace(
id="workflow", tenant_id="tenant", app_id="app", graph_dict={}, type="workflow", version="1"
)
end_user = SimpleNamespace(session_id="end-user-session")
app_record = SimpleNamespace(id="app")
session = MagicMock()
session.__enter__.return_value = session
session.__exit__.return_value = False
session.scalar.side_effect = [workflow, end_user, app_record]
mocker.patch("core.app.apps.workflow.app_generator.Session", return_value=session)
runner_instance = MagicMock()
def runner_ctor(**kwargs):
assert kwargs["graph_runtime_state"] is runtime_state
return runner_instance
mocker.patch("core.app.apps.workflow.app_generator.WorkflowAppRunner", side_effect=runner_ctor)
class ImmediateThread:
def __init__(self, target, kwargs):
target(**kwargs)
def start(self):
return None
mocker.patch("core.app.apps.workflow.app_generator.threading.Thread", ImmediateThread)
mocker.patch(
"core.app.apps.workflow.app_generator.DifyCoreRepositoryFactory.create_workflow_execution_repository",
return_value=MagicMock(),
)
mocker.patch(
"core.app.apps.workflow.app_generator.DifyCoreRepositoryFactory.create_workflow_node_execution_repository",
return_value=MagicMock(),
)
pause_config = SimpleNamespace(session_factory=MagicMock(), state_owner_user_id="owner")
app_model = SimpleNamespace(mode="workflow")
app_config = SimpleNamespace(app_id="app", tenant_id="tenant", workflow_id="workflow")
application_generate_entity = SimpleNamespace(
task_id="task",
user_id="user",
invoke_from="service-api",
app_config=app_config,
files=[],
stream=True,
workflow_execution_id="run",
trace_manager=MagicMock(),
)
result = generator.resume(
app_model=app_model,
workflow=workflow,
user=MagicMock(),
application_generate_entity=application_generate_entity,
graph_runtime_state=runtime_state,
workflow_execution_repository=MagicMock(),
workflow_node_execution_repository=MagicMock(),
pause_state_config=pause_config,
)
assert result == "raw-response"
runner_instance.run.assert_called_once()
queue_manager.graph_runtime_state = runtime_state

View File

@ -1,317 +0,0 @@
"""
Tests for HumanInputForm domain model and repository.
"""
import json
from datetime import datetime
from unittest.mock import MagicMock, patch
import pytest
from core.repositories.sqlalchemy_human_input_form_repository import SQLAlchemyHumanInputFormRepository
from core.workflow.entities.human_input_form import HumanInputForm, HumanInputFormStatus
class TestHumanInputForm:
"""Test cases for HumanInputForm domain model."""
def test_create_form(self):
"""Test creating a new form."""
form = HumanInputForm.create(
id_="test-form-id",
workflow_run_id="test-workflow-run",
form_definition={"inputs": [], "user_actions": [{"id": "submit", "title": "Submit"}]},
rendered_content="<form>Test form</form>",
)
assert form.id_ == "test-form-id"
assert form.workflow_run_id == "test-workflow-run"
assert form.status == HumanInputFormStatus.WAITING
assert form.can_be_submitted
assert not form.is_submitted
assert not form.is_expired
assert form.is_waiting
def test_submit_form(self):
"""Test submitting a form."""
form = HumanInputForm.create(
id_="test-form-id",
workflow_run_id="test-workflow-run",
form_definition={"inputs": [], "user_actions": [{"id": "submit", "title": "Submit"}]},
rendered_content="<form>Test form</form>",
)
form.submit(
data={"field1": "value1"},
action="submit",
submission_user_id="user123",
)
assert form.is_submitted
assert not form.can_be_submitted
assert form.status == HumanInputFormStatus.SUBMITTED
assert form.submission is not None
assert form.submission.data == {"field1": "value1"}
assert form.submission.action == "submit"
assert form.submission.submission_user_id == "user123"
def test_submit_form_invalid_action(self):
"""Test submitting a form with invalid action."""
form = HumanInputForm.create(
id_="test-form-id",
workflow_run_id="test-workflow-run",
form_definition={"inputs": [], "user_actions": [{"id": "submit", "title": "Submit"}]},
rendered_content="<form>Test form</form>",
)
with pytest.raises(ValueError, match="Invalid action: invalid_action"):
form.submit(data={}, action="invalid_action")
def test_submit_expired_form(self):
"""Test submitting an expired form should fail."""
form = HumanInputForm.create(
id_="test-form-id",
workflow_run_id="test-workflow-run",
form_definition={"inputs": [], "user_actions": [{"id": "submit", "title": "Submit"}]},
rendered_content="<form>Test form</form>",
)
form.expire()
with pytest.raises(ValueError, match="Form cannot be submitted in status: expired"):
form.submit(data={}, action="submit")
def test_expire_form(self):
"""Test expiring a form."""
form = HumanInputForm.create(
id_="test-form-id",
workflow_run_id="test-workflow-run",
form_definition={"inputs": [], "user_actions": [{"id": "submit", "title": "Submit"}]},
rendered_content="<form>Test form</form>",
)
form.expire()
assert form.is_expired
assert not form.can_be_submitted
assert form.status == HumanInputFormStatus.EXPIRED
def test_expire_already_submitted_form(self):
"""Test expiring an already submitted form should fail."""
form = HumanInputForm.create(
id_="test-form-id",
workflow_run_id="test-workflow-run",
form_definition={"inputs": [], "user_actions": [{"id": "submit", "title": "Submit"}]},
rendered_content="<form>Test form</form>",
)
form.submit(data={}, action="submit")
with pytest.raises(ValueError, match="Form cannot be expired in status: submitted"):
form.expire()
def test_get_form_definition_for_display(self):
"""Test getting form definition for display."""
form_definition = {
"inputs": [{"type": "text", "name": "field1"}],
"user_actions": [{"id": "submit", "title": "Submit"}],
}
form = HumanInputForm.create(
id_="test-form-id",
workflow_run_id="test-workflow-run",
form_definition=form_definition,
rendered_content="<form>Test form</form>",
)
result = form.get_form_definition_for_display()
assert result["form_content"] == "<form>Test form</form>"
assert result["inputs"] == form_definition["inputs"]
assert result["user_actions"] == form_definition["user_actions"]
assert "site" not in result
def test_get_form_definition_for_display_with_site_info(self):
"""Test getting form definition for display with site info."""
form = HumanInputForm.create(
id_="test-form-id",
workflow_run_id="test-workflow-run",
form_definition={"inputs": [], "user_actions": []},
rendered_content="<form>Test form</form>",
)
result = form.get_form_definition_for_display(include_site_info=True)
assert "site" in result
assert result["site"]["title"] == "Workflow Form"
def test_get_form_definition_expired_form(self):
"""Test getting form definition for expired form should fail."""
form = HumanInputForm.create(
id_="test-form-id",
workflow_run_id="test-workflow-run",
form_definition={"inputs": [], "user_actions": []},
rendered_content="<form>Test form</form>",
)
form.expire()
with pytest.raises(ValueError, match="Form has expired"):
form.get_form_definition_for_display()
def test_get_form_definition_submitted_form(self):
"""Test getting form definition for submitted form should fail."""
form = HumanInputForm.create(
id_="test-form-id",
workflow_run_id="test-workflow-run",
form_definition={"inputs": [], "user_actions": [{"id": "submit", "title": "Submit"}]},
rendered_content="<form>Test form</form>",
)
form.submit(data={}, action="submit")
with pytest.raises(ValueError, match="Form has already been submitted"):
form.get_form_definition_for_display()
class TestSQLAlchemyHumanInputFormRepository:
"""Test cases for SQLAlchemyHumanInputFormRepository."""
@pytest.fixture
def mock_session_factory(self):
"""Create a mock session factory."""
session = MagicMock()
session_factory = MagicMock()
session_factory.return_value.__enter__.return_value = session
session_factory.return_value.__exit__.return_value = None
return session_factory
@pytest.fixture
def mock_user(self):
"""Create a mock user."""
user = MagicMock()
user.current_tenant_id = "test-tenant-id"
user.id = "test-user-id"
return user
@pytest.fixture
def repository(self, mock_session_factory, mock_user):
"""Create a repository instance."""
return SQLAlchemyHumanInputFormRepository(
session_factory=mock_session_factory,
user=mock_user,
app_id="test-app-id",
)
def test_to_domain_model(self, repository):
"""Test converting DB model to domain model."""
from models.human_input import (
HumanInputForm as DBForm,
)
from models.human_input import (
HumanInputFormStatus as DBStatus,
)
from models.human_input import (
HumanInputSubmissionType as DBSubmissionType,
)
db_form = DBForm()
db_form.id = "test-id"
db_form.workflow_run_id = "test-workflow"
db_form.form_definition = json.dumps({"inputs": [], "user_actions": []})
db_form.rendered_content = "<form>Test</form>"
db_form.status = DBStatus.WAITING
db_form.web_app_token = "test-token"
db_form.created_at = datetime.utcnow()
db_form.submitted_data = json.dumps({"field": "value"})
db_form.submitted_at = datetime.utcnow()
db_form.submission_type = DBSubmissionType.web_form
db_form.submission_user_id = "user123"
domain_form = repository._to_domain_model(db_form)
assert domain_form.id_ == "test-id"
assert domain_form.workflow_run_id == "test-workflow"
assert domain_form.form_definition == {"inputs": [], "user_actions": []}
assert domain_form.rendered_content == "<form>Test</form>"
assert domain_form.status == HumanInputFormStatus.WAITING
assert domain_form.web_app_token == "test-token"
assert domain_form.submission is not None
assert domain_form.submission.data == {"field": "value"}
assert domain_form.submission.submission_user_id == "user123"
def test_to_db_model(self, repository):
"""Test converting domain model to DB model."""
from models.human_input import (
HumanInputFormStatus as DBStatus,
)
domain_form = HumanInputForm.create(
id_="test-id",
workflow_run_id="test-workflow",
form_definition={"inputs": [], "user_actions": []},
rendered_content="<form>Test</form>",
web_app_token="test-token",
)
db_form = repository._to_db_model(domain_form)
assert db_form.id == "test-id"
assert db_form.tenant_id == "test-tenant-id"
assert db_form.app_id == "test-app-id"
assert db_form.workflow_run_id == "test-workflow"
assert json.loads(db_form.form_definition) == {"inputs": [], "user_actions": []}
assert db_form.rendered_content == "<form>Test</form>"
assert db_form.status == DBStatus.WAITING
assert db_form.web_app_token == "test-token"
def test_save(self, repository, mock_session_factory):
"""Test saving a form."""
session = mock_session_factory.return_value.__enter__.return_value
domain_form = HumanInputForm.create(
id_="test-id",
workflow_run_id="test-workflow",
form_definition={"inputs": []},
rendered_content="<form>Test</form>",
)
repository.save(domain_form)
session.merge.assert_called_once()
session.commit.assert_called_once()
def test_get_by_id(self, repository, mock_session_factory):
"""Test getting a form by ID."""
session = mock_session_factory.return_value.__enter__.return_value
mock_db_form = MagicMock()
mock_db_form.id = "test-id"
session.scalar.return_value = mock_db_form
with patch.object(repository, "_to_domain_model") as mock_convert:
domain_form = HumanInputForm.create(
id_="test-id",
workflow_run_id="test-workflow",
form_definition={"inputs": []},
rendered_content="<form>Test</form>",
)
mock_convert.return_value = domain_form
result = repository.get_by_id("test-id")
assert result == domain_form
session.scalar.assert_called_once()
mock_convert.assert_called_once_with(mock_db_form)
def test_get_by_id_not_found(self, repository, mock_session_factory):
"""Test getting a non-existent form by ID."""
session = mock_session_factory.return_value.__enter__.return_value
session.scalar.return_value = None
with pytest.raises(ValueError, match="Human input form not found: test-id"):
repository.get_by_id("test-id")
def test_mark_expired_forms(self, repository, mock_session_factory):
"""Test marking expired forms."""
session = mock_session_factory.return_value.__enter__.return_value
mock_forms = [MagicMock(), MagicMock(), MagicMock()]
session.scalars.return_value.all.return_value = mock_forms
result = repository.mark_expired_forms(expiry_hours=24)
assert result == 3
for form in mock_forms:
assert hasattr(form, "status")
session.commit.assert_called_once()

View File

@ -2,16 +2,27 @@
from __future__ import annotations
import dataclasses
from datetime import datetime
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from core.repositories.human_input_reposotiry import (
HumanInputFormReadRepository,
HumanInputFormRecord,
HumanInputFormRepositoryImpl,
_WorkspaceMemberInfo,
)
from core.workflow.nodes.human_input.entities import ExternalRecipient, MemberRecipient
from core.workflow.nodes.human_input.entities import (
ExternalRecipient,
FormDefinition,
MemberRecipient,
TimeoutUnit,
UserAction,
)
from libs.datetime_utils import naive_utc_now
from models.human_input import (
EmailExternalRecipientPayload,
EmailMemberRecipientPayload,
@ -41,6 +52,23 @@ def _patch_recipient_factory(monkeypatch: pytest.MonkeyPatch) -> list[SimpleName
return created
@pytest.fixture(autouse=True)
def _stub_selectinload(monkeypatch: pytest.MonkeyPatch) -> None:
"""Avoid SQLAlchemy mapper configuration in tests using fake sessions."""
class _FakeSelect:
def options(self, *_args, **_kwargs): # type: ignore[no-untyped-def]
return self
def where(self, *_args, **_kwargs): # type: ignore[no-untyped-def]
return self
monkeypatch.setattr(
"core.repositories.human_input_reposotiry.selectinload", lambda *args, **kwargs: "_loader_option"
)
monkeypatch.setattr("core.repositories.human_input_reposotiry.select", lambda *args, **kwargs: _FakeSelect())
class TestHumanInputFormRepositoryImplHelpers:
def test_create_email_recipients_with_member_and_external(self, monkeypatch: pytest.MonkeyPatch) -> None:
repo = _build_repository()
@ -125,3 +153,201 @@ class TestHumanInputFormRepositoryImplHelpers:
assert len(recipients) == 2
emails = {EmailMemberRecipientPayload.model_validate_json(r.recipient_payload).email for r in recipients}
assert emails == {"member1@example.com", "member2@example.com"}
def _make_form_definition() -> str:
return FormDefinition(
form_content="hello",
inputs=[],
user_actions=[UserAction(id="submit", title="Submit")],
rendered_content="<p>hello</p>",
timeout=1,
timeout_unit=TimeoutUnit.HOUR,
).model_dump_json()
@dataclasses.dataclass
class _DummyForm:
id: str
workflow_run_id: str
node_id: str
tenant_id: str
form_definition: str
rendered_content: str
expiration_time: datetime
selected_action_id: str | None = None
submitted_data: str | None = None
submitted_at: datetime | None = None
submission_user_id: str | None = None
submission_end_user_id: str | None = None
completed_by_recipient_id: str | None = None
@dataclasses.dataclass
class _DummyRecipient:
id: str
form_id: str
recipient_type: RecipientType
access_token: str
form: _DummyForm | None = None
class _FakeScalarResult:
def __init__(self, obj):
self._obj = obj
def first(self):
return self._obj
class _FakeSession:
def __init__(
self,
*,
scalars_result=None,
forms: dict[str, _DummyForm] | None = None,
recipients: dict[str, _DummyRecipient] | None = None,
):
self._scalars_result = scalars_result
self.forms = forms or {}
self.recipients = recipients or {}
def scalars(self, _query):
return _FakeScalarResult(self._scalars_result)
def get(self, model_cls, obj_id): # type: ignore[no-untyped-def]
if getattr(model_cls, "__name__", None) == "HumanInputForm":
return self.forms.get(obj_id)
if getattr(model_cls, "__name__", None) == "HumanInputFormRecipient":
return self.recipients.get(obj_id)
return None
def add(self, _obj):
return None
def flush(self):
return None
def refresh(self, _obj):
return None
def begin(self):
return self
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return None
def _session_factory(session: _FakeSession):
class _SessionContext:
def __enter__(self):
return session
def __exit__(self, exc_type, exc, tb):
return None
def _factory(*_args, **_kwargs):
return _SessionContext()
return _factory
class TestHumanInputFormReadRepository:
def test_get_by_token_returns_record(self):
form = _DummyForm(
id="form-1",
workflow_run_id="run-1",
node_id="node-1",
tenant_id="tenant-1",
form_definition=_make_form_definition(),
rendered_content="<p>hello</p>",
expiration_time=naive_utc_now(),
)
recipient = _DummyRecipient(
id="recipient-1",
form_id=form.id,
recipient_type=RecipientType.WEBAPP,
access_token="token-123",
form=form,
)
session = _FakeSession(scalars_result=recipient)
repo = HumanInputFormReadRepository(_session_factory(session))
record = repo.get_by_token("token-123")
assert record is not None
assert record.form_id == form.id
assert record.recipient_type == RecipientType.WEBAPP
assert record.submitted is False
def test_get_by_form_id_and_recipient_type_uses_recipient(self):
form = _DummyForm(
id="form-1",
workflow_run_id="run-1",
node_id="node-1",
tenant_id="tenant-1",
form_definition=_make_form_definition(),
rendered_content="<p>hello</p>",
expiration_time=naive_utc_now(),
)
recipient = _DummyRecipient(
id="recipient-1",
form_id=form.id,
recipient_type=RecipientType.WEBAPP,
access_token="token-123",
form=form,
)
session = _FakeSession(scalars_result=recipient)
repo = HumanInputFormReadRepository(_session_factory(session))
record = repo.get_by_form_id_and_recipient_type(form_id=form.id, recipient_type=RecipientType.WEBAPP)
assert record is not None
assert record.recipient_id == recipient.id
assert record.access_token == recipient.access_token
def test_mark_submitted_updates_fields(self, monkeypatch: pytest.MonkeyPatch):
fixed_now = datetime(2024, 1, 1, 0, 0, 0)
monkeypatch.setattr("core.repositories.human_input_reposotiry.naive_utc_now", lambda: fixed_now)
form = _DummyForm(
id="form-1",
workflow_run_id="run-1",
node_id="node-1",
tenant_id="tenant-1",
form_definition=_make_form_definition(),
rendered_content="<p>hello</p>",
expiration_time=fixed_now,
)
recipient = _DummyRecipient(
id="recipient-1",
form_id="form-1",
recipient_type=RecipientType.WEBAPP,
access_token="token-123",
)
session = _FakeSession(
forms={form.id: form},
recipients={recipient.id: recipient},
)
repo = HumanInputFormReadRepository(_session_factory(session))
record: HumanInputFormRecord = repo.mark_submitted(
form_id=form.id,
recipient_id=recipient.id,
selected_action_id="approve",
form_data={"field": "value"},
submission_user_id="user-1",
submission_end_user_id="end-user-1",
)
assert form.selected_action_id == "approve"
assert form.completed_by_recipient_id == recipient.id
assert form.submission_user_id == "user-1"
assert form.submission_end_user_id == "end-user-1"
assert form.submitted_at == fixed_now
assert record.submitted is True
assert record.selected_action_id == "approve"
assert record.submitted_data == {"field": "value"}

View File

@ -28,10 +28,11 @@ class TestPauseReasonDiscriminator:
{
"reason": {
"TYPE": "human_input_required",
"human_input_form_id": "form_id",
"form_id": "form_id",
"form_content": "form_content",
},
},
HumanInputRequired(human_input_form_id="form_id"),
HumanInputRequired(form_id="form_id", form_content="form_content"),
id="HumanInputRequired",
),
pytest.param(
@ -56,7 +57,7 @@ class TestPauseReasonDiscriminator:
@pytest.mark.parametrize(
"reason",
[
HumanInputRequired(human_input_form_id="form_id"),
HumanInputRequired(form_id="form_id", form_content="form_content"),
SchedulingPause(message="Hold on"),
],
ids=lambda x: type(x).__name__,

View File

@ -764,203 +764,3 @@ def test_graph_run_emits_partial_success_when_node_failure_recovered():
assert partial_event.outputs.get("answer") == "fallback response"
assert not any(isinstance(event, GraphRunSucceededEvent) for event in events)
def test_suspend_and_resume():
graph_config = {
"edges": [
{
"data": {"isInLoop": False, "sourceType": "start", "targetType": "if-else"},
"id": "1753041723554-source-1753041730748-target",
"source": "1753041723554",
"sourceHandle": "source",
"target": "1753041730748",
"targetHandle": "target",
"type": "custom",
"zIndex": 0,
},
{
"data": {"isInLoop": False, "sourceType": "if-else", "targetType": "answer"},
"id": "1753041730748-true-answer-target",
"source": "1753041730748",
"sourceHandle": "true",
"target": "answer",
"targetHandle": "target",
"type": "custom",
"zIndex": 0,
},
{
"data": {
"isInIteration": False,
"isInLoop": False,
"sourceType": "if-else",
"targetType": "answer",
},
"id": "1753041730748-false-1753041952799-target",
"source": "1753041730748",
"sourceHandle": "false",
"target": "1753041952799",
"targetHandle": "target",
"type": "custom",
"zIndex": 0,
},
],
"nodes": [
{
"data": {"desc": "", "selected": False, "title": "Start", "type": "start", "variables": []},
"height": 54,
"id": "1753041723554",
"position": {"x": 32, "y": 282},
"positionAbsolute": {"x": 32, "y": 282},
"selected": False,
"sourcePosition": "right",
"targetPosition": "left",
"type": "custom",
"width": 244,
},
{
"data": {
"cases": [
{
"case_id": "true",
"conditions": [
{
"comparison_operator": "contains",
"id": "5db4103a-7e62-4e71-a0a6-c45ac11c0b3d",
"value": "a",
"varType": "string",
"variable_selector": ["sys", "query"],
}
],
"id": "true",
"logical_operator": "and",
}
],
"desc": "",
"selected": False,
"title": "IF/ELSE",
"type": "if-else",
},
"height": 126,
"id": "1753041730748",
"position": {"x": 368, "y": 282},
"positionAbsolute": {"x": 368, "y": 282},
"selected": False,
"sourcePosition": "right",
"targetPosition": "left",
"type": "custom",
"width": 244,
},
{
"data": {
"answer": "A",
"desc": "",
"selected": False,
"title": "Answer A",
"type": "answer",
"variables": [],
},
"height": 102,
"id": "answer",
"position": {"x": 746, "y": 282},
"positionAbsolute": {"x": 746, "y": 282},
"selected": False,
"sourcePosition": "right",
"targetPosition": "left",
"type": "custom",
"width": 244,
},
{
"data": {
"answer": "Else",
"desc": "",
"selected": False,
"title": "Answer Else",
"type": "answer",
"variables": [],
},
"height": 102,
"id": "1753041952799",
"position": {"x": 746, "y": 426},
"positionAbsolute": {"x": 746, "y": 426},
"selected": True,
"sourcePosition": "right",
"targetPosition": "left",
"type": "custom",
"width": 244,
},
],
"viewport": {"x": -420, "y": -76.5, "zoom": 1},
}
graph = Graph.init(graph_config)
variable_pool = VariablePool(
system_variables=SystemVariable(
user_id="aaa",
files=[],
query="hello",
conversation_id="abababa",
),
user_inputs={"uid": "takato"},
)
graph_runtime_state = GraphRuntimeState(
variable_pool=variable_pool,
)
graph_init_params = GraphInitParams(
tenant_id="111",
app_id="222",
workflow_type=WorkflowType.CHAT,
workflow_id="333",
graph_config=graph_config,
user_id="444",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.WEB_APP,
call_depth=0,
max_execution_steps=500,
max_execution_time=1200,
)
_IF_ELSE_NODE_ID = "1753041730748"
def command_source(params: CommandParams) -> CommandTypes:
# requires the engine to suspend before the execution
# of If-Else node.
if params.next_node.node_id == _IF_ELSE_NODE_ID:
return SuspendCommand()
else:
return ContinueCommand()
graph_engine = GraphEngine(
graph=graph,
graph_runtime_state=graph_runtime_state,
graph_init_params=graph_init_params,
command_source=command_source,
)
events = list(graph_engine.run())
last_event = events[-1]
assert isinstance(last_event, GraphRunSuspendedEvent)
assert last_event.current_node_id == _IF_ELSE_NODE_ID
state = graph_engine.save()
assert state != ""
engine2 = GraphEngine.resume(
state=state,
graph=graph,
)
events = list(engine2.run())
assert isinstance(events[-1], GraphRunSucceededEvent)
node_run_succeeded_events = [i for i in events if isinstance(i, NodeRunSucceededEvent)]
assert node_run_succeeded_events
start_events = [i for i in node_run_succeeded_events if i.node_id == "1753041723554"]
assert not start_events
ifelse_succeeded_events = [i for i in node_run_succeeded_events if i.node_id == _IF_ELSE_NODE_ID]
assert ifelse_succeeded_events
answer_else_events = [i for i in node_run_succeeded_events if i.node_id == "1753041952799"]
assert answer_else_events
assert answer_else_events[0].route_node_state.node_run_result.outputs == {
"answer": "Else",
"files": ArrayFileSegment(value=[]),
}
answer_a_events = [i for i in node_run_succeeded_events if i.node_id == "answer"]
assert not answer_a_events

View File

@ -17,8 +17,8 @@ from core.workflow.graph_events import (
from core.workflow.nodes.base.entities import OutputVariableEntity, OutputVariableType
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.end.entities import EndNodeData
from core.workflow.nodes.human_input import HumanInputNode
from core.workflow.nodes.human_input.entities import HumanInputNodeData
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
from core.workflow.nodes.llm.entities import (
ContextConfig,
LLMNodeChatModelMessage,

View File

@ -16,8 +16,8 @@ from core.workflow.graph_events import (
from core.workflow.nodes.base.entities import OutputVariableEntity, OutputVariableType
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.end.entities import EndNodeData
from core.workflow.nodes.human_input import HumanInputNode
from core.workflow.nodes.human_input.entities import HumanInputNodeData
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
from core.workflow.nodes.llm.entities import (
ContextConfig,
LLMNodeChatModelMessage,

View File

@ -0,0 +1,185 @@
import time
from collections.abc import Generator, Mapping
from typing import Any
import core.workflow.nodes.human_input.entities # noqa: F401
from core.workflow.entities import GraphInitParams
from core.workflow.entities.pause_reason import SchedulingPause
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel
from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.graph_events import (
GraphEngineEvent,
GraphRunPausedEvent,
GraphRunSucceededEvent,
NodeRunSucceededEvent,
)
from core.workflow.node_events import NodeEventBase, NodeRunResult, PauseRequestedEvent
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig, VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.end.entities import EndNodeData
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
class _PausingNodeData(BaseNodeData):
pass
class _PausingNode(Node):
node_type = NodeType.TOOL
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = _PausingNodeData.model_validate(data)
def _get_error_strategy(self):
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@staticmethod
def _pause_generator(event: PauseRequestedEvent) -> Generator[NodeEventBase, None, None]:
yield event
def _run(self) -> NodeRunResult | Generator[NodeEventBase, None, None]:
resumed_flag = self.graph_runtime_state.variable_pool.get((self.id, "resumed"))
if resumed_flag is None:
# mark as resumed and request pause
self.graph_runtime_state.variable_pool.add((self.id, "resumed"), True)
return self._pause_generator(PauseRequestedEvent(reason=SchedulingPause(message="test pause")))
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"value": "completed"},
)
def _build_runtime_state() -> GraphRuntimeState:
variable_pool = VariablePool(
system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"),
user_inputs={},
conversation_variables=[],
)
return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
def _build_pausing_graph(runtime_state: GraphRuntimeState) -> Graph:
graph_config: dict[str, object] = {"nodes": [], "edges": []}
params = GraphInitParams(
tenant_id="tenant",
app_id="app",
workflow_id="workflow",
graph_config=graph_config,
user_id="user",
user_from="account",
invoke_from="service-api",
call_depth=0,
)
start_data = StartNodeData(title="start", variables=[])
start_node = StartNode(
id="start",
config={"id": "start", "data": start_data.model_dump()},
graph_init_params=params,
graph_runtime_state=runtime_state,
)
start_node.init_node_data(start_data.model_dump())
pause_data = _PausingNodeData(title="pausing")
pause_node = _PausingNode(
id="pausing",
config={"id": "pausing", "data": pause_data.model_dump()},
graph_init_params=params,
graph_runtime_state=runtime_state,
)
pause_node.init_node_data(pause_data.model_dump())
end_data = EndNodeData(
title="end",
outputs=[
VariableSelector(variable="result", value_selector=["pausing", "value"]),
],
desc=None,
)
end_node = EndNode(
id="end",
config={"id": "end", "data": end_data.model_dump()},
graph_init_params=params,
graph_runtime_state=runtime_state,
)
end_node.init_node_data(end_data.model_dump())
return Graph.new().add_root(start_node).add_node(pause_node).add_node(end_node).build()
def _run_graph(graph: Graph, runtime_state: GraphRuntimeState) -> list[GraphEngineEvent]:
engine = GraphEngine(
workflow_id="workflow",
graph=graph,
graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(),
)
return list(engine.run())
def _node_successes(events: list[GraphEngineEvent]) -> list[str]:
return [event.node_id for event in events if isinstance(event, NodeRunSucceededEvent)]
def _segment_value(variable_pool: VariablePool, selector: tuple[str, str]) -> Any:
segment = variable_pool.get(selector)
assert segment is not None
return getattr(segment, "value", segment)
def test_engine_resume_restores_state_and_completion():
# Baseline run without pausing
baseline_state = _build_runtime_state()
baseline_graph = _build_pausing_graph(baseline_state)
baseline_state.variable_pool.add(("pausing", "resumed"), True)
baseline_events = _run_graph(baseline_graph, baseline_state)
assert isinstance(baseline_events[-1], GraphRunSucceededEvent)
baseline_success_nodes = _node_successes(baseline_events)
# Run with pause
paused_state = _build_runtime_state()
paused_graph = _build_pausing_graph(paused_state)
paused_events = _run_graph(paused_graph, paused_state)
assert isinstance(paused_events[-1], GraphRunPausedEvent)
snapshot = paused_state.dumps()
# Resume from snapshot
resumed_state = GraphRuntimeState.from_snapshot(snapshot)
resumed_graph = _build_pausing_graph(resumed_state)
resumed_events = _run_graph(resumed_graph, resumed_state)
assert isinstance(resumed_events[-1], GraphRunSucceededEvent)
combined_success_nodes = _node_successes(paused_events) + _node_successes(resumed_events)
assert combined_success_nodes == baseline_success_nodes
assert baseline_state.outputs == resumed_state.outputs
assert _segment_value(baseline_state.variable_pool, ("pausing", "resumed")) == _segment_value(
resumed_state.variable_pool, ("pausing", "resumed")
)
assert _segment_value(baseline_state.variable_pool, ("pausing", "value")) == _segment_value(
resumed_state.variable_pool, ("pausing", "value")
)
assert baseline_state.graph_execution.completed
assert resumed_state.graph_execution.completed

View File

@ -1,283 +0,0 @@
"""
Unit tests for human input node implementation.
"""
import uuid
from unittest.mock import Mock, patch
import pytest
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.human_input import HumanInputNode, HumanInputNodeData
class TestHumanInputNode:
"""Test HumanInputNode implementation."""
@pytest.fixture
def mock_graph_init_params(self):
"""Create mock graph initialization parameters."""
mock_params = Mock()
mock_params.tenant_id = "tenant-123"
mock_params.app_id = "app-456"
mock_params.user_id = "user-789"
mock_params.user_from = "web"
mock_params.invoke_from = "web_app"
mock_params.call_depth = 0
return mock_params
@pytest.fixture
def mock_graph(self):
"""Create mock graph."""
return Mock()
@pytest.fixture
def mock_graph_runtime_state(self):
"""Create mock graph runtime state."""
return Mock()
@pytest.fixture
def sample_node_config(self):
"""Create sample node configuration."""
return {
"id": "human_input_123",
"data": {
"title": "User Confirmation",
"desc": "Please confirm the action",
"delivery_methods": [{"type": "webapp", "enabled": True, "config": {}}],
"form_content": "# Confirmation\n\nPlease confirm: {{#$output.confirmation#}}",
"inputs": [
{
"type": "text-input",
"output_variable_name": "confirmation",
"placeholder": {"type": "constant", "value": "Type 'yes' to confirm"},
}
],
"user_actions": [
{"id": "confirm", "title": "Confirm", "button_style": "primary"},
{"id": "cancel", "title": "Cancel", "button_style": "default"},
],
"timeout": 24,
"timeout_unit": "hour",
},
}
@pytest.fixture
def human_input_node(self, sample_node_config, mock_graph_init_params, mock_graph, mock_graph_runtime_state):
"""Create HumanInputNode instance."""
node = HumanInputNode(
id="node_123",
config=sample_node_config,
graph_init_params=mock_graph_init_params,
graph=mock_graph,
graph_runtime_state=mock_graph_runtime_state,
)
return node
def test_node_initialization(self, human_input_node):
"""Test node initialization."""
assert human_input_node.node_id == "human_input_123"
assert human_input_node.tenant_id == "tenant-123"
assert human_input_node.app_id == "app-456"
assert isinstance(human_input_node.node_data, HumanInputNodeData)
assert human_input_node.node_data.title == "User Confirmation"
def test_node_type_and_version(self, human_input_node):
"""Test node type and version."""
assert human_input_node.type_.value == "human_input"
assert human_input_node.version() == "1"
def test_node_properties(self, human_input_node):
"""Test node properties access."""
assert human_input_node.title == "User Confirmation"
assert human_input_node.description == "Please confirm the action"
assert human_input_node.error_strategy is None
assert human_input_node.retry_config.retry_enabled is False
@patch("uuid.uuid4")
def test_node_run_success(self, mock_uuid, human_input_node):
"""Test successful node execution."""
# Setup mocks
mock_form_id = uuid.UUID("12345678-1234-5678-9abc-123456789012")
mock_token = uuid.UUID("87654321-4321-8765-cba9-876543210987")
mock_uuid.side_effect = [mock_form_id, mock_token]
# Execute the node
result = human_input_node._run()
# Verify result
assert result.status == WorkflowNodeExecutionStatus.RUNNING
assert result.metadata["suspended"] is True
assert result.metadata["form_id"] == str(mock_form_id)
assert result.metadata["web_app_form_token"] == str(mock_token).replace("-", "")
# Verify event data in metadata
human_input_event = result.metadata["human_input_event"]
assert human_input_event["form_id"] == str(mock_form_id)
assert human_input_event["node_id"] == "human_input_123"
assert human_input_event["form_content"] == "# Confirmation\n\nPlease confirm: {{#$output.confirmation#}}"
assert len(human_input_event["inputs"]) == 1
suspended_event = result.metadata["suspended_event"]
assert suspended_event["suspended_at_node_ids"] == ["human_input_123"]
def test_node_run_without_webapp_delivery(self, human_input_node):
"""Test node execution without webapp delivery method."""
# Modify node data to disable webapp delivery
human_input_node.node_data.delivery_methods[0].enabled = False
result = human_input_node._run()
# Should still work, but without web app token
assert result.status == WorkflowNodeExecutionStatus.RUNNING
assert result.metadata["web_app_form_token"] is None
def test_resume_from_human_input_success(self, human_input_node):
"""Test successful resume from human input."""
form_submission_data = {"inputs": {"confirmation": "yes"}, "action": "confirm"}
result = human_input_node.resume_from_human_input(form_submission_data)
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs["confirmation"] == "yes"
assert result.outputs["_action"] == "confirm"
assert result.metadata["form_submitted"] is True
assert result.metadata["submitted_action"] == "confirm"
def test_resume_from_human_input_partial_inputs(self, human_input_node):
"""Test resume with partial inputs."""
form_submission_data = {
"inputs": {}, # Empty inputs
"action": "cancel",
}
result = human_input_node.resume_from_human_input(form_submission_data)
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert "confirmation" not in result.outputs # Field not provided
assert result.outputs["_action"] == "cancel"
def test_resume_from_human_input_missing_data(self, human_input_node):
"""Test resume with missing submission data."""
form_submission_data = {} # Missing required fields
result = human_input_node.resume_from_human_input(form_submission_data)
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs["_action"] == "" # Default empty action
def test_get_default_config(self):
"""Test getting default configuration."""
config = HumanInputNode.get_default_config()
assert config["type"] == "human_input"
assert "config" in config
config_data = config["config"]
assert len(config_data["delivery_methods"]) == 1
assert config_data["delivery_methods"][0]["type"] == "webapp"
assert config_data["delivery_methods"][0]["enabled"] is True
assert config_data["form_content"] == "# Human Input\n\nPlease provide your input:\n\n{{#$output.input#}}"
assert len(config_data["inputs"]) == 1
assert config_data["inputs"][0]["output_variable_name"] == "input"
assert len(config_data["user_actions"]) == 1
assert config_data["user_actions"][0]["id"] == "submit"
assert config_data["timeout"] == 24
assert config_data["timeout_unit"] == "hour"
def test_process_form_content(self, human_input_node):
"""Test form content processing."""
# This is a placeholder test since the actual variable substitution
# logic is marked as TODO in the implementation
processed_content = human_input_node._process_form_content()
# For now, should return the raw content
expected_content = "# Confirmation\n\nPlease confirm: {{#$output.confirmation#}}"
assert processed_content == expected_content
def test_extract_variable_selector_mapping(self):
"""Test variable selector extraction."""
graph_config = {}
node_data = {
"form_content": "Hello {{#node_123.output#}}",
"inputs": [
{
"type": "text-input",
"output_variable_name": "test",
"placeholder": {"type": "variable", "selector": ["node_456", "var_name"]},
}
],
}
# This is a placeholder test since the actual extraction logic
# is marked as TODO in the implementation
mapping = HumanInputNode._extract_variable_selector_to_variable_mapping(
graph_config=graph_config, node_id="test_node", node_data=node_data
)
# For now, should return empty dict
assert mapping == {}
class TestHumanInputNodeValidation:
"""Test validation scenarios for HumanInputNode."""
def test_node_with_invalid_config(self):
"""Test node creation with invalid configuration."""
invalid_config = {
"id": "test_node",
"data": {
"title": "Test",
"delivery_methods": [
{
"type": "invalid_type", # Invalid delivery method type
"enabled": True,
"config": {},
}
],
},
}
mock_params = Mock()
mock_params.tenant_id = "tenant-123"
mock_params.app_id = "app-456"
mock_params.user_id = "user-789"
mock_params.user_from = "web"
mock_params.invoke_from = "web_app"
mock_params.call_depth = 0
with pytest.raises(ValueError):
HumanInputNode(
id="node_123",
config=invalid_config,
graph_init_params=mock_params,
graph=Mock(),
graph_runtime_state=Mock(),
)
def test_node_with_missing_node_id(self):
"""Test node creation with missing node ID in config."""
invalid_config = {
# Missing "id" field
"data": {"title": "Test"}
}
mock_params = Mock()
mock_params.tenant_id = "tenant-123"
mock_params.app_id = "app-456"
mock_params.user_id = "user-789"
mock_params.user_from = "web"
mock_params.invoke_from = "web_app"
mock_params.call_depth = 0
with pytest.raises(ValueError, match="Node ID is required"):
HumanInputNode(
id="node_123",
config=invalid_config,
graph_init_params=mock_params,
graph=Mock(),
graph_runtime_state=Mock(),
)

View File

@ -1 +1 @@
# Unit tests for human input library
# Treat this directory as a package so support modules can be imported relatively.

View File

@ -0,0 +1,248 @@
from __future__ import annotations
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import Any, Optional
from core.workflow.nodes.human_input.entities import FormInput, TimeoutUnit
# Exceptions
class HumanInputError(Exception):
error_code: str = "unknown"
def __init__(self, message: str = "", error_code: str | None = None):
super().__init__(message)
self.message = message or self.__class__.__name__
if error_code:
self.error_code = error_code
class FormNotFoundError(HumanInputError):
error_code = "form_not_found"
class FormExpiredError(HumanInputError):
error_code = "human_input_form_expired"
class FormAlreadySubmittedError(HumanInputError):
error_code = "human_input_form_submitted"
class InvalidFormDataError(HumanInputError):
error_code = "invalid_form_data"
# Models
@dataclass
class HumanInputForm:
form_id: str
workflow_run_id: str
node_id: str
tenant_id: str
app_id: str | None
form_content: str
inputs: list[FormInput]
user_actions: list[dict[str, Any]]
timeout: int
timeout_unit: TimeoutUnit
web_app_form_token: str | None = None
created_at: datetime = field(default_factory=datetime.utcnow)
expires_at: datetime | None = None
submitted_at: datetime | None = None
submitted_data: dict[str, Any] | None = None
submitted_action: str | None = None
def __post_init__(self) -> None:
if self.expires_at is None:
self.calculate_expiration()
@property
def is_expired(self) -> bool:
return self.expires_at is not None and datetime.utcnow() > self.expires_at
@property
def is_submitted(self) -> bool:
return self.submitted_at is not None
def mark_submitted(self, inputs: dict[str, Any], action: str) -> None:
self.submitted_data = inputs
self.submitted_action = action
self.submitted_at = datetime.utcnow()
def submit(self, inputs: dict[str, Any], action: str) -> None:
self.mark_submitted(inputs, action)
def calculate_expiration(self) -> None:
start = self.created_at
if self.timeout_unit == TimeoutUnit.HOUR:
self.expires_at = start + timedelta(hours=self.timeout)
elif self.timeout_unit == TimeoutUnit.DAY:
self.expires_at = start + timedelta(days=self.timeout)
else:
raise ValueError(f"Unsupported timeout unit {self.timeout_unit}")
def to_response_dict(self, *, include_site_info: bool) -> dict[str, Any]:
inputs_response = [
{
"type": form_input.type.name.lower().replace("_", "-"),
"output_variable_name": form_input.output_variable_name,
}
for form_input in self.inputs
]
response = {
"form_content": self.form_content,
"inputs": inputs_response,
"user_actions": self.user_actions,
}
if include_site_info:
response["site"] = {"app_id": self.app_id, "title": "Workflow Form"}
return response
@dataclass
class FormSubmissionData:
form_id: str
inputs: dict[str, Any]
action: str
submitted_at: datetime = field(default_factory=datetime.utcnow)
@classmethod
def from_request(cls, form_id: str, request: FormSubmissionRequest) -> FormSubmissionData: # type: ignore
return cls(form_id=form_id, inputs=request.inputs, action=request.action)
@dataclass
class FormSubmissionRequest:
inputs: dict[str, Any]
action: str
# Repository
class InMemoryFormRepository:
"""
Simple in-memory repository used by unit tests.
"""
def __init__(self):
self._forms: dict[str, HumanInputForm] = {}
@property
def forms(self) -> dict[str, HumanInputForm]:
return self._forms
def save(self, form: HumanInputForm) -> None:
self._forms[form.form_id] = form
def get_by_id(self, form_id: str) -> Optional[HumanInputForm]:
return self._forms.get(form_id)
def get_by_token(self, token: str) -> Optional[HumanInputForm]:
for form in self._forms.values():
if form.web_app_form_token == token:
return form
return None
def delete(self, form_id: str) -> None:
self._forms.pop(form_id, None)
# Service
class FormService:
"""Service layer for managing human input forms in tests."""
def __init__(self, repository: InMemoryFormRepository):
self.repository = repository
def create_form(
self,
*,
form_id: str,
workflow_run_id: str,
node_id: str,
tenant_id: str,
app_id: str | None,
form_content: str,
inputs,
user_actions,
timeout: int,
timeout_unit: TimeoutUnit,
web_app_form_token: str | None = None,
) -> HumanInputForm:
form = HumanInputForm(
form_id=form_id,
workflow_run_id=workflow_run_id,
node_id=node_id,
tenant_id=tenant_id,
app_id=app_id,
form_content=form_content,
inputs=list(inputs),
user_actions=[{"id": action.id, "title": action.title} for action in user_actions],
timeout=timeout,
timeout_unit=timeout_unit,
web_app_form_token=web_app_form_token,
)
form.calculate_expiration()
self.repository.save(form)
return form
def get_form_by_id(self, form_id: str) -> HumanInputForm:
form = self.repository.get_by_id(form_id)
if form is None:
raise FormNotFoundError()
return form
def get_form_by_token(self, token: str) -> HumanInputForm:
form = self.repository.get_by_token(token)
if form is None:
raise FormNotFoundError()
return form
def get_form_definition(self, form_id: str, *, is_token: bool) -> dict:
form = self.get_form_by_token(form_id) if is_token else self.get_form_by_id(form_id)
if form.is_expired:
raise FormExpiredError()
if form.is_submitted:
raise FormAlreadySubmittedError()
definition = {
"form_content": form.form_content,
"inputs": form.inputs,
"user_actions": form.user_actions,
}
if is_token:
definition["site"] = {"title": "Workflow Form"}
return definition
def submit_form(self, form_id: str, submission_data: FormSubmissionData, *, is_token: bool) -> None:
form = self.get_form_by_token(form_id) if is_token else self.get_form_by_id(form_id)
if form.is_expired:
raise FormExpiredError()
if form.is_submitted:
raise FormAlreadySubmittedError()
self._validate_submission(form=form, submission_data=submission_data)
form.mark_submitted(inputs=submission_data.inputs, action=submission_data.action)
self.repository.save(form)
def cleanup_expired_forms(self) -> int:
expired_ids = [form_id for form_id, form in list(self.repository.forms.items()) if form.is_expired]
for form_id in expired_ids:
self.repository.delete(form_id)
return len(expired_ids)
def _validate_submission(self, form: HumanInputForm, submission_data: FormSubmissionData) -> None:
defined_actions = {action["id"] for action in form.user_actions}
if submission_data.action not in defined_actions:
raise InvalidFormDataError(f"Invalid action: {submission_data.action}")
missing_inputs = []
for form_input in form.inputs:
if form_input.output_variable_name not in submission_data.inputs:
missing_inputs.append(form_input.output_variable_name)
if missing_inputs:
raise InvalidFormDataError(f"Missing required inputs: {', '.join(missing_inputs)}")
# Extra inputs are allowed; no further validation required.

View File

@ -5,15 +5,6 @@ Unit tests for FormService.
from datetime import datetime, timedelta
import pytest
from libs._human_input.exceptions import (
FormAlreadySubmittedError,
FormExpiredError,
FormNotFoundError,
InvalidFormDataError,
)
from libs._human_input.form_service import FormService
from libs._human_input.models import FormSubmissionData
from libs._human_input.repository import InMemoryFormRepository
from core.workflow.nodes.human_input.entities import (
FormInput,
@ -22,6 +13,16 @@ from core.workflow.nodes.human_input.entities import (
UserAction,
)
from .support import (
FormAlreadySubmittedError,
FormExpiredError,
FormNotFoundError,
FormService,
FormSubmissionData,
InMemoryFormRepository,
InvalidFormDataError,
)
class TestFormService:
"""Test FormService functionality."""

View File

@ -5,16 +5,16 @@ Unit tests for human input form models.
from datetime import datetime, timedelta
import pytest
from libs._human_input.models import FormSubmissionData, HumanInputForm
from core.workflow.nodes.human_input.entities import (
FormInput,
FormInputType,
FormSubmissionRequest,
TimeoutUnit,
UserAction,
)
from .support import FormSubmissionData, FormSubmissionRequest, HumanInputForm
class TestHumanInputForm:
"""Test HumanInputForm model."""

View File

@ -0,0 +1,205 @@
import dataclasses
from datetime import datetime
from unittest.mock import MagicMock
import pytest
from core.repositories.human_input_reposotiry import HumanInputFormReadRepository, HumanInputFormRecord
from core.workflow.nodes.human_input.entities import FormDefinition, TimeoutUnit, UserAction
from models.account import Account
from models.human_input import RecipientType
from services.human_input_service import FormSubmittedError, HumanInputService
@pytest.fixture
def mock_session_factory():
session = MagicMock()
session_cm = MagicMock()
session_cm.__enter__.return_value = session
session_cm.__exit__.return_value = None
factory = MagicMock()
factory.return_value = session_cm
return factory, session
@pytest.fixture
def sample_form_record():
return HumanInputFormRecord(
form_id="form-id",
workflow_run_id="workflow-run-id",
node_id="node-id",
tenant_id="tenant-id",
definition=FormDefinition(
form_content="hello",
inputs=[],
user_actions=[UserAction(id="submit", title="Submit")],
rendered_content="<p>hello</p>",
timeout=1,
timeout_unit=TimeoutUnit.HOUR,
),
rendered_content="<p>hello</p>",
expiration_time=datetime(2024, 1, 1),
selected_action_id=None,
submitted_data=None,
submitted_at=None,
submission_user_id=None,
submission_end_user_id=None,
completed_by_recipient_id=None,
recipient_id="recipient-id",
recipient_type=RecipientType.WEBAPP,
access_token="token",
)
def test_enqueue_resume_dispatches_task(mocker, mock_session_factory):
session_factory, session = mock_session_factory
service = HumanInputService(session_factory)
trigger_log = MagicMock()
trigger_log.id = "trigger-log-id"
trigger_log.queue_name = "workflow_queue"
repo_cls = mocker.patch(
"services.human_input_service.SQLAlchemyWorkflowTriggerLogRepository",
autospec=True,
)
repo = repo_cls.return_value
repo.get_by_workflow_run_id.return_value = trigger_log
resume_task = mocker.patch("services.human_input_service.resume_workflow_execution")
service._enqueue_resume("workflow-run-id")
repo_cls.assert_called_once_with(session)
resume_task.apply_async.assert_called_once()
call_kwargs = resume_task.apply_async.call_args.kwargs
assert call_kwargs["queue"] == "workflow_queue"
payload = call_kwargs["kwargs"]["task_data_dict"]
assert payload["workflow_trigger_log_id"] == "trigger-log-id"
assert payload["workflow_run_id"] == "workflow-run-id"
def test_enqueue_resume_no_trigger_log(mocker, mock_session_factory):
session_factory, session = mock_session_factory
service = HumanInputService(session_factory)
repo_cls = mocker.patch(
"services.human_input_service.SQLAlchemyWorkflowTriggerLogRepository",
autospec=True,
)
repo = repo_cls.return_value
repo.get_by_workflow_run_id.return_value = None
resume_task = mocker.patch("services.human_input_service.resume_workflow_execution")
service._enqueue_resume("workflow-run-id")
repo_cls.assert_called_once_with(session)
resume_task.apply_async.assert_not_called()
def test_enqueue_resume_chatflow_fallback(mocker, mock_session_factory):
session_factory, session = mock_session_factory
service = HumanInputService(session_factory)
repo_cls = mocker.patch(
"services.human_input_service.SQLAlchemyWorkflowTriggerLogRepository",
autospec=True,
)
repo = repo_cls.return_value
repo.get_by_workflow_run_id.return_value = None
workflow_run = MagicMock()
workflow_run.app_id = "app-id"
app = MagicMock()
app.mode = "advanced-chat"
session.get.side_effect = [workflow_run, app]
resume_task = mocker.patch("services.human_input_service.resume_chatflow_execution")
service._enqueue_resume("workflow-run-id")
resume_task.apply_async.assert_called_once()
call_kwargs = resume_task.apply_async.call_args.kwargs
assert call_kwargs["queue"] == "chatflow_execute"
assert call_kwargs["kwargs"]["payload"]["workflow_run_id"] == "workflow-run-id"
def test_get_form_definition_by_id_uses_repository(sample_form_record, mock_session_factory):
session_factory, _ = mock_session_factory
repo = MagicMock(spec=HumanInputFormReadRepository)
repo.get_by_form_id_and_recipient_type.return_value = sample_form_record
service = HumanInputService(session_factory, form_repository=repo)
form = service.get_form_definition_by_id("form-id")
repo.get_by_form_id_and_recipient_type.assert_called_once_with(
form_id="form-id",
recipient_type=RecipientType.WEBAPP,
)
assert form is not None
assert form.get_definition() == sample_form_record.definition
def test_get_form_definition_by_id_raises_on_submitted(sample_form_record, mock_session_factory):
session_factory, _ = mock_session_factory
submitted_record = dataclasses.replace(sample_form_record, submitted_at=datetime(2024, 1, 1))
repo = MagicMock(spec=HumanInputFormReadRepository)
repo.get_by_form_id_and_recipient_type.return_value = submitted_record
service = HumanInputService(session_factory, form_repository=repo)
with pytest.raises(FormSubmittedError):
service.get_form_definition_by_id("form-id")
def test_submit_form_by_token_calls_repository_and_enqueue(sample_form_record, mock_session_factory, mocker):
session_factory, _ = mock_session_factory
repo = MagicMock(spec=HumanInputFormReadRepository)
repo.get_by_token.return_value = sample_form_record
repo.mark_submitted.return_value = sample_form_record
service = HumanInputService(session_factory, form_repository=repo)
enqueue_spy = mocker.patch.object(service, "_enqueue_resume")
service.submit_form_by_token(
recipient_type=RecipientType.WEBAPP,
form_token="token",
selected_action_id="approve",
form_data={"field": "value"},
submission_end_user_id="end-user-id",
)
repo.get_by_token.assert_called_once_with("token")
repo.mark_submitted.assert_called_once()
call_kwargs = repo.mark_submitted.call_args.kwargs
assert call_kwargs["form_id"] == sample_form_record.form_id
assert call_kwargs["recipient_id"] == sample_form_record.recipient_id
assert call_kwargs["selected_action_id"] == "approve"
assert call_kwargs["form_data"] == {"field": "value"}
assert call_kwargs["submission_end_user_id"] == "end-user-id"
enqueue_spy.assert_called_once_with(sample_form_record.workflow_run_id)
def test_submit_form_by_id_passes_account(sample_form_record, mock_session_factory, mocker):
session_factory, _ = mock_session_factory
repo = MagicMock(spec=HumanInputFormReadRepository)
repo.get_by_form_id_and_recipient_type.return_value = sample_form_record
repo.mark_submitted.return_value = sample_form_record
service = HumanInputService(session_factory, form_repository=repo)
enqueue_spy = mocker.patch.object(service, "_enqueue_resume")
account = MagicMock(spec=Account)
account.id = "account-id"
service.submit_form_by_id(
form_id="form-id",
selected_action_id="approve",
form_data={"x": 1},
user=account,
)
repo.get_by_form_id_and_recipient_type.assert_called_once()
repo.mark_submitted.assert_called_once()
assert repo.mark_submitted.call_args.kwargs["submission_user_id"] == "account-id"
enqueue_spy.assert_called_once_with(sample_form_record.workflow_run_id)

View File

@ -35,7 +35,6 @@ class TestDataFactory:
app_id: str = "app-789",
workflow_id: str = "workflow-101",
status: str | WorkflowExecutionStatus = "paused",
pause_id: str | None = None,
**kwargs,
) -> MagicMock:
"""Create a mock WorkflowRun object."""
@ -45,7 +44,6 @@ class TestDataFactory:
mock_run.app_id = app_id
mock_run.workflow_id = workflow_id
mock_run.status = status
mock_run.pause_id = pause_id
for key, value in kwargs.items():
setattr(mock_run, key, value)

View File

@ -161,292 +161,3 @@ class TestWorkflowService:
assert workflows == []
assert has_more is False
mock_session.scalars.assert_called_once()
class TestWorkflowServiceHumanInputValidation:
@pytest.fixture
def workflow_service(self):
# Mock sessionmaker to avoid database dependency
mock_session_maker = MagicMock()
return WorkflowService(mock_session_maker)
def test_validate_graph_structure_valid_human_input(self, workflow_service):
"""Test validation of valid HumanInput node data."""
graph = {
"nodes": [
{
"id": "node-1",
"data": {
"type": "human_input",
"title": "Human Input",
"delivery_methods": [{"type": "webapp", "enabled": True, "config": {}}],
"form_content": "Please provide your input",
"inputs": [
{
"type": "text-input",
"output_variable_name": "user_input",
"placeholder": {"type": "constant", "value": "Enter text here"},
}
],
"user_actions": [{"id": "submit", "title": "Submit", "button_style": "primary"}],
"timeout": 24,
"timeout_unit": "hour",
},
}
]
}
# Should not raise any exception
workflow_service.validate_graph_structure(graph)
def test_validate_graph_structure_empty_graph(self, workflow_service):
"""Test validation of empty graph."""
graph = {}
# Should not raise any exception
workflow_service.validate_graph_structure(graph)
def test_validate_graph_structure_no_nodes(self, workflow_service):
"""Test validation of graph with no nodes."""
graph = {"nodes": []}
# Should not raise any exception
workflow_service.validate_graph_structure(graph)
def test_validate_graph_structure_non_human_input_node(self, workflow_service):
"""Test validation ignores non-HumanInput nodes."""
graph = {"nodes": [{"id": "node-1", "data": {"type": "start", "title": "Start"}}]}
# Should not raise any exception
workflow_service.validate_graph_structure(graph)
def test_validate_human_input_node_data_invalid_delivery_method_type(self, workflow_service):
"""Test validation fails with invalid delivery method type."""
graph = {
"nodes": [
{
"id": "node-1",
"data": {
"type": "human_input",
"title": "Human Input",
"delivery_methods": [{"type": "invalid_type", "enabled": True, "config": {}}],
"form_content": "Please provide your input",
"inputs": [],
"user_actions": [],
"timeout": 24,
"timeout_unit": "hour",
},
}
]
}
with pytest.raises(ValueError, match="Invalid HumanInput node data"):
workflow_service.validate_graph_structure(graph)
def test_validate_human_input_node_data_invalid_form_input_type(self, workflow_service):
"""Test validation fails with invalid form input type."""
graph = {
"nodes": [
{
"id": "node-1",
"data": {
"type": "human_input",
"title": "Human Input",
"delivery_methods": [{"type": "webapp", "enabled": True, "config": {}}],
"form_content": "Please provide your input",
"inputs": [{"type": "invalid-input-type", "output_variable_name": "user_input"}],
"user_actions": [],
"timeout": 24,
"timeout_unit": "hour",
},
}
]
}
with pytest.raises(ValueError, match="Invalid HumanInput node data"):
workflow_service.validate_graph_structure(graph)
def test_validate_human_input_node_data_missing_required_fields(self, workflow_service):
"""Test validation fails with missing required fields."""
graph = {
"nodes": [
{
"id": "node-1",
"data": {
"type": "human_input",
# Missing required fields like title
"delivery_methods": [],
"form_content": "",
"inputs": [],
"user_actions": [],
},
}
]
}
with pytest.raises(ValueError, match="Invalid HumanInput node data"):
workflow_service.validate_graph_structure(graph)
def test_validate_human_input_node_data_invalid_timeout_unit(self, workflow_service):
"""Test validation fails with invalid timeout unit."""
graph = {
"nodes": [
{
"id": "node-1",
"data": {
"type": "human_input",
"title": "Human Input",
"delivery_methods": [{"type": "webapp", "enabled": True, "config": {}}],
"form_content": "Please provide your input",
"inputs": [],
"user_actions": [],
"timeout": 24,
"timeout_unit": "invalid_unit",
},
}
]
}
with pytest.raises(ValueError, match="Invalid HumanInput node data"):
workflow_service.validate_graph_structure(graph)
def test_validate_human_input_node_data_invalid_button_style(self, workflow_service):
"""Test validation fails with invalid button style."""
graph = {
"nodes": [
{
"id": "node-1",
"data": {
"type": "human_input",
"title": "Human Input",
"delivery_methods": [{"type": "webapp", "enabled": True, "config": {}}],
"form_content": "Please provide your input",
"inputs": [],
"user_actions": [{"id": "submit", "title": "Submit", "button_style": "invalid_style"}],
"timeout": 24,
"timeout_unit": "hour",
},
}
]
}
with pytest.raises(ValueError, match="Invalid HumanInput node data"):
workflow_service.validate_graph_structure(graph)
def test_validate_human_input_node_data_email_delivery_config(self, workflow_service):
"""Test validation of HumanInput node with email delivery configuration."""
graph = {
"nodes": [
{
"id": "node-1",
"data": {
"type": "human_input",
"title": "Human Input",
"delivery_methods": [
{
"type": "email",
"enabled": True,
"config": {
"recipients": {
"whole_workspace": False,
"items": [{"type": "external", "email": "user@example.com"}],
},
"subject": "Input Required",
"body": "Please provide your input",
},
}
],
"form_content": "Please provide your input",
"inputs": [
{
"type": "paragraph",
"output_variable_name": "feedback",
"placeholder": {"type": "variable", "selector": ["node", "output"]},
}
],
"user_actions": [
{"id": "approve", "title": "Approve", "button_style": "accent"},
{"id": "reject", "title": "Reject", "button_style": "ghost"},
],
"timeout": 7,
"timeout_unit": "day",
},
}
]
}
# Should not raise any exception
workflow_service.validate_graph_structure(graph)
def test_validate_human_input_node_data_invalid_email_recipient(self, workflow_service):
"""Test validation fails with invalid email recipient."""
graph = {
"nodes": [
{
"id": "node-1",
"data": {
"type": "human_input",
"title": "Human Input",
"delivery_methods": [
{
"type": "email",
"enabled": True,
"config": {
"recipients": {
"whole_workspace": False,
"items": [{"type": "invalid_recipient_type", "email": "user@example.com"}],
},
"subject": "Input Required",
"body": "Please provide your input",
},
}
],
"form_content": "Please provide your input",
"inputs": [],
"user_actions": [],
"timeout": 24,
"timeout_unit": "hour",
},
}
]
}
with pytest.raises(ValueError, match="Invalid HumanInput node data"):
workflow_service.validate_graph_structure(graph)
def test_validate_human_input_node_data_multiple_nodes_mixed_valid_invalid(self, workflow_service):
"""Test validation with multiple nodes where some are valid and some invalid."""
graph = {
"nodes": [
{"id": "node-1", "data": {"type": "start", "title": "Start"}},
{
"id": "node-2",
"data": {
"type": "human_input",
"title": "Valid Human Input",
"delivery_methods": [{"type": "webapp", "enabled": True, "config": {}}],
"form_content": "Valid input",
"inputs": [],
"user_actions": [],
"timeout": 24,
"timeout_unit": "hour",
},
},
{
"id": "node-3",
"data": {
"type": "human_input",
"title": "Invalid Human Input",
"delivery_methods": [{"type": "invalid_method", "enabled": True}],
"form_content": "Invalid input",
"inputs": [],
"user_actions": [],
"timeout": 24,
"timeout_unit": "hour",
},
},
]
}
with pytest.raises(ValueError, match="Invalid HumanInput node data"):
workflow_service.validate_graph_structure(graph)