mirror of
https://github.com/langgenius/dify.git
synced 2026-02-22 19:15:47 +08:00
feat: Human Input Node (#32060)
The frontend and backend implementation for the human input node. Co-authored-by: twwu <twwu@dify.ai> Co-authored-by: JzoNg <jzongcode@gmail.com> Co-authored-by: yyh <92089059+lyzno1@users.noreply.github.com> Co-authored-by: zhsama <torvalds@linux.do>
This commit is contained in:
@ -0,0 +1,187 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.apps.advanced_chat import generate_task_pipeline as pipeline_module
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import QueueTextChunkEvent, QueueWorkflowPausedEvent
|
||||
from core.workflow.entities.pause_reason import HumanInputRequired
|
||||
from models.enums import MessageStatus
|
||||
from models.execution_extra_content import HumanInputContent
|
||||
from models.model import EndUser
|
||||
|
||||
|
||||
def _build_pipeline() -> pipeline_module.AdvancedChatAppGenerateTaskPipeline:
|
||||
pipeline = pipeline_module.AdvancedChatAppGenerateTaskPipeline.__new__(
|
||||
pipeline_module.AdvancedChatAppGenerateTaskPipeline
|
||||
)
|
||||
pipeline._workflow_run_id = "run-1"
|
||||
pipeline._message_id = "message-1"
|
||||
pipeline._workflow_tenant_id = "tenant-1"
|
||||
return pipeline
|
||||
|
||||
|
||||
def test_persist_human_input_extra_content_adds_record(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
pipeline = _build_pipeline()
|
||||
monkeypatch.setattr(pipeline, "_load_human_input_form_id", lambda **kwargs: "form-1")
|
||||
|
||||
captured_session: dict[str, mock.Mock] = {}
|
||||
|
||||
@contextmanager
|
||||
def fake_session():
|
||||
session = mock.Mock()
|
||||
session.scalar.return_value = None
|
||||
captured_session["session"] = session
|
||||
yield session
|
||||
|
||||
pipeline._database_session = fake_session # type: ignore[method-assign]
|
||||
|
||||
pipeline._persist_human_input_extra_content(node_id="node-1")
|
||||
|
||||
session = captured_session["session"]
|
||||
session.add.assert_called_once()
|
||||
content = session.add.call_args.args[0]
|
||||
assert isinstance(content, HumanInputContent)
|
||||
assert content.workflow_run_id == "run-1"
|
||||
assert content.message_id == "message-1"
|
||||
assert content.form_id == "form-1"
|
||||
|
||||
|
||||
def test_persist_human_input_extra_content_skips_when_form_missing(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
pipeline = _build_pipeline()
|
||||
monkeypatch.setattr(pipeline, "_load_human_input_form_id", lambda **kwargs: None)
|
||||
|
||||
called = {"value": False}
|
||||
|
||||
@contextmanager
|
||||
def fake_session():
|
||||
called["value"] = True
|
||||
session = mock.Mock()
|
||||
yield session
|
||||
|
||||
pipeline._database_session = fake_session # type: ignore[method-assign]
|
||||
|
||||
pipeline._persist_human_input_extra_content(node_id="node-1")
|
||||
|
||||
assert called["value"] is False
|
||||
|
||||
|
||||
def test_persist_human_input_extra_content_skips_when_existing(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
pipeline = _build_pipeline()
|
||||
monkeypatch.setattr(pipeline, "_load_human_input_form_id", lambda **kwargs: "form-1")
|
||||
|
||||
captured_session: dict[str, mock.Mock] = {}
|
||||
|
||||
@contextmanager
|
||||
def fake_session():
|
||||
session = mock.Mock()
|
||||
session.scalar.return_value = HumanInputContent(
|
||||
workflow_run_id="run-1",
|
||||
message_id="message-1",
|
||||
form_id="form-1",
|
||||
)
|
||||
captured_session["session"] = session
|
||||
yield session
|
||||
|
||||
pipeline._database_session = fake_session # type: ignore[method-assign]
|
||||
|
||||
pipeline._persist_human_input_extra_content(node_id="node-1")
|
||||
|
||||
session = captured_session["session"]
|
||||
session.add.assert_not_called()
|
||||
|
||||
|
||||
def test_handle_workflow_paused_event_persists_human_input_extra_content() -> None:
|
||||
pipeline = _build_pipeline()
|
||||
pipeline._application_generate_entity = SimpleNamespace(task_id="task-1")
|
||||
pipeline._workflow_response_converter = mock.Mock()
|
||||
pipeline._workflow_response_converter.workflow_pause_to_stream_response.return_value = []
|
||||
pipeline._ensure_graph_runtime_initialized = mock.Mock(
|
||||
return_value=SimpleNamespace(
|
||||
total_tokens=0,
|
||||
node_run_steps=0,
|
||||
),
|
||||
)
|
||||
pipeline._save_message = mock.Mock()
|
||||
message = SimpleNamespace(status=MessageStatus.NORMAL)
|
||||
pipeline._get_message = mock.Mock(return_value=message)
|
||||
pipeline._persist_human_input_extra_content = mock.Mock()
|
||||
pipeline._base_task_pipeline = mock.Mock()
|
||||
pipeline._base_task_pipeline.queue_manager = mock.Mock()
|
||||
pipeline._message_saved_on_pause = False
|
||||
|
||||
@contextmanager
|
||||
def fake_session():
|
||||
session = mock.Mock()
|
||||
yield session
|
||||
|
||||
pipeline._database_session = fake_session # type: ignore[method-assign]
|
||||
|
||||
reason = HumanInputRequired(
|
||||
form_id="form-1",
|
||||
form_content="content",
|
||||
inputs=[],
|
||||
actions=[],
|
||||
node_id="node-1",
|
||||
node_title="Approval",
|
||||
form_token="token-1",
|
||||
resolved_default_values={},
|
||||
)
|
||||
event = QueueWorkflowPausedEvent(reasons=[reason], outputs={}, paused_nodes=["node-1"])
|
||||
|
||||
list(pipeline._handle_workflow_paused_event(event))
|
||||
|
||||
pipeline._persist_human_input_extra_content.assert_called_once_with(form_id="form-1", node_id="node-1")
|
||||
assert message.status == MessageStatus.PAUSED
|
||||
|
||||
|
||||
def test_resume_appends_chunks_to_paused_answer() -> None:
|
||||
app_config = SimpleNamespace(app_id="app-1", tenant_id="tenant-1", sensitive_word_avoidance=None)
|
||||
application_generate_entity = SimpleNamespace(
|
||||
app_config=app_config,
|
||||
files=[],
|
||||
workflow_run_id="run-1",
|
||||
query="hello",
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
inputs={},
|
||||
task_id="task-1",
|
||||
)
|
||||
queue_manager = SimpleNamespace(graph_runtime_state=None)
|
||||
conversation = SimpleNamespace(id="conversation-1", mode="advanced-chat")
|
||||
message = SimpleNamespace(
|
||||
id="message-1",
|
||||
created_at=datetime(2024, 1, 1),
|
||||
query="hello",
|
||||
answer="before",
|
||||
status=MessageStatus.PAUSED,
|
||||
)
|
||||
user = EndUser()
|
||||
user.id = "user-1"
|
||||
user.session_id = "session-1"
|
||||
workflow = SimpleNamespace(id="workflow-1", tenant_id="tenant-1", features_dict={})
|
||||
|
||||
pipeline = pipeline_module.AdvancedChatAppGenerateTaskPipeline(
|
||||
application_generate_entity=application_generate_entity,
|
||||
workflow=workflow,
|
||||
queue_manager=queue_manager,
|
||||
conversation=conversation,
|
||||
message=message,
|
||||
user=user,
|
||||
stream=True,
|
||||
dialogue_count=1,
|
||||
draft_var_saver_factory=SimpleNamespace(),
|
||||
)
|
||||
|
||||
pipeline._get_message = mock.Mock(return_value=message)
|
||||
pipeline._recorded_files = []
|
||||
|
||||
list(pipeline._handle_text_chunk_event(QueueTextChunkEvent(text="after")))
|
||||
pipeline._save_message(session=mock.Mock())
|
||||
|
||||
assert message.answer == "beforeafter"
|
||||
assert message.status == MessageStatus.NORMAL
|
||||
@ -0,0 +1,87 @@
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
|
||||
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import QueueHumanInputFormFilledEvent, QueueHumanInputFormTimeoutEvent
|
||||
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
|
||||
|
||||
def _build_converter():
|
||||
system_variables = SystemVariable(
|
||||
files=[],
|
||||
user_id="user-1",
|
||||
app_id="app-1",
|
||||
workflow_id="wf-1",
|
||||
workflow_execution_id="run-1",
|
||||
)
|
||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0)
|
||||
app_entity = SimpleNamespace(
|
||||
task_id="task-1",
|
||||
app_config=SimpleNamespace(app_id="app-1", tenant_id="tenant-1"),
|
||||
invoke_from=InvokeFrom.EXPLORE,
|
||||
files=[],
|
||||
inputs={},
|
||||
workflow_execution_id="run-1",
|
||||
call_depth=0,
|
||||
)
|
||||
account = SimpleNamespace(id="acc-1", name="tester", email="tester@example.com")
|
||||
return WorkflowResponseConverter(
|
||||
application_generate_entity=app_entity,
|
||||
user=account,
|
||||
system_variables=system_variables,
|
||||
)
|
||||
|
||||
|
||||
def test_human_input_form_filled_stream_response_contains_rendered_content():
|
||||
converter = _build_converter()
|
||||
converter.workflow_start_to_stream_response(
|
||||
task_id="task-1",
|
||||
workflow_run_id="run-1",
|
||||
workflow_id="wf-1",
|
||||
reason=WorkflowStartReason.INITIAL,
|
||||
)
|
||||
|
||||
queue_event = QueueHumanInputFormFilledEvent(
|
||||
node_execution_id="exec-1",
|
||||
node_id="node-1",
|
||||
node_type="human-input",
|
||||
node_title="Human Input",
|
||||
rendered_content="# Title\nvalue",
|
||||
action_id="Approve",
|
||||
action_text="Approve",
|
||||
)
|
||||
|
||||
resp = converter.human_input_form_filled_to_stream_response(event=queue_event, task_id="task-1")
|
||||
|
||||
assert resp.workflow_run_id == "run-1"
|
||||
assert resp.data.node_id == "node-1"
|
||||
assert resp.data.node_title == "Human Input"
|
||||
assert resp.data.rendered_content.startswith("# Title")
|
||||
assert resp.data.action_id == "Approve"
|
||||
|
||||
|
||||
def test_human_input_form_timeout_stream_response_contains_timeout_metadata():
|
||||
converter = _build_converter()
|
||||
converter.workflow_start_to_stream_response(
|
||||
task_id="task-1",
|
||||
workflow_run_id="run-1",
|
||||
workflow_id="wf-1",
|
||||
reason=WorkflowStartReason.INITIAL,
|
||||
)
|
||||
|
||||
queue_event = QueueHumanInputFormTimeoutEvent(
|
||||
node_id="node-1",
|
||||
node_type="human-input",
|
||||
node_title="Human Input",
|
||||
expiration_time=datetime(2025, 1, 1, tzinfo=UTC),
|
||||
)
|
||||
|
||||
resp = converter.human_input_form_timeout_to_stream_response(event=queue_event, task_id="task-1")
|
||||
|
||||
assert resp.workflow_run_id == "run-1"
|
||||
assert resp.data.node_id == "node-1"
|
||||
assert resp.data.node_title == "Human Input"
|
||||
assert resp.data.expiration_time == 1735689600
|
||||
@ -0,0 +1,56 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
|
||||
|
||||
def _build_converter() -> WorkflowResponseConverter:
|
||||
"""Construct a minimal WorkflowResponseConverter for testing."""
|
||||
system_variables = SystemVariable(
|
||||
files=[],
|
||||
user_id="user-1",
|
||||
app_id="app-1",
|
||||
workflow_id="wf-1",
|
||||
workflow_execution_id="run-1",
|
||||
)
|
||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0)
|
||||
app_entity = SimpleNamespace(
|
||||
task_id="task-1",
|
||||
app_config=SimpleNamespace(app_id="app-1", tenant_id="tenant-1"),
|
||||
invoke_from=InvokeFrom.EXPLORE,
|
||||
files=[],
|
||||
inputs={},
|
||||
workflow_execution_id="run-1",
|
||||
call_depth=0,
|
||||
)
|
||||
account = SimpleNamespace(id="acc-1", name="tester", email="tester@example.com")
|
||||
return WorkflowResponseConverter(
|
||||
application_generate_entity=app_entity,
|
||||
user=account,
|
||||
system_variables=system_variables,
|
||||
)
|
||||
|
||||
|
||||
def test_workflow_start_stream_response_carries_resumption_reason():
|
||||
converter = _build_converter()
|
||||
resp = converter.workflow_start_to_stream_response(
|
||||
task_id="task-1",
|
||||
workflow_run_id="run-1",
|
||||
workflow_id="wf-1",
|
||||
reason=WorkflowStartReason.RESUMPTION,
|
||||
)
|
||||
assert resp.data.reason is WorkflowStartReason.RESUMPTION
|
||||
|
||||
|
||||
def test_workflow_start_stream_response_carries_initial_reason():
|
||||
converter = _build_converter()
|
||||
resp = converter.workflow_start_to_stream_response(
|
||||
task_id="task-1",
|
||||
workflow_run_id="run-1",
|
||||
workflow_id="wf-1",
|
||||
reason=WorkflowStartReason.INITIAL,
|
||||
)
|
||||
assert resp.data.reason is WorkflowStartReason.INITIAL
|
||||
@ -23,6 +23,7 @@ from core.app.entities.queue_entities import (
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
)
|
||||
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
@ -124,7 +125,12 @@ class TestWorkflowResponseConverter:
|
||||
original_data = {"large_field": "x" * 10000, "metadata": "info"}
|
||||
truncated_data = {"large_field": "[TRUNCATED]", "metadata": "info"}
|
||||
|
||||
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
|
||||
converter.workflow_start_to_stream_response(
|
||||
task_id="bootstrap",
|
||||
workflow_run_id="run-id",
|
||||
workflow_id="wf-id",
|
||||
reason=WorkflowStartReason.INITIAL,
|
||||
)
|
||||
start_event = self.create_node_started_event()
|
||||
converter.workflow_node_start_to_stream_response(
|
||||
event=start_event,
|
||||
@ -160,7 +166,12 @@ class TestWorkflowResponseConverter:
|
||||
|
||||
original_data = {"small": "data"}
|
||||
|
||||
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
|
||||
converter.workflow_start_to_stream_response(
|
||||
task_id="bootstrap",
|
||||
workflow_run_id="run-id",
|
||||
workflow_id="wf-id",
|
||||
reason=WorkflowStartReason.INITIAL,
|
||||
)
|
||||
start_event = self.create_node_started_event()
|
||||
converter.workflow_node_start_to_stream_response(
|
||||
event=start_event,
|
||||
@ -191,7 +202,12 @@ class TestWorkflowResponseConverter:
|
||||
"""Test node finish response when process_data is None."""
|
||||
converter = self.create_workflow_response_converter()
|
||||
|
||||
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
|
||||
converter.workflow_start_to_stream_response(
|
||||
task_id="bootstrap",
|
||||
workflow_run_id="run-id",
|
||||
workflow_id="wf-id",
|
||||
reason=WorkflowStartReason.INITIAL,
|
||||
)
|
||||
start_event = self.create_node_started_event()
|
||||
converter.workflow_node_start_to_stream_response(
|
||||
event=start_event,
|
||||
@ -225,7 +241,12 @@ class TestWorkflowResponseConverter:
|
||||
original_data = {"large_field": "x" * 10000, "metadata": "info"}
|
||||
truncated_data = {"large_field": "[TRUNCATED]", "metadata": "info"}
|
||||
|
||||
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
|
||||
converter.workflow_start_to_stream_response(
|
||||
task_id="bootstrap",
|
||||
workflow_run_id="run-id",
|
||||
workflow_id="wf-id",
|
||||
reason=WorkflowStartReason.INITIAL,
|
||||
)
|
||||
start_event = self.create_node_started_event()
|
||||
converter.workflow_node_start_to_stream_response(
|
||||
event=start_event,
|
||||
@ -261,7 +282,12 @@ class TestWorkflowResponseConverter:
|
||||
|
||||
original_data = {"small": "data"}
|
||||
|
||||
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
|
||||
converter.workflow_start_to_stream_response(
|
||||
task_id="bootstrap",
|
||||
workflow_run_id="run-id",
|
||||
workflow_id="wf-id",
|
||||
reason=WorkflowStartReason.INITIAL,
|
||||
)
|
||||
start_event = self.create_node_started_event()
|
||||
converter.workflow_node_start_to_stream_response(
|
||||
event=start_event,
|
||||
@ -400,6 +426,7 @@ class TestWorkflowResponseConverterServiceApiTruncation:
|
||||
task_id="test-task-id",
|
||||
workflow_run_id="test-workflow-run-id",
|
||||
workflow_id="test-workflow-id",
|
||||
reason=WorkflowStartReason.INITIAL,
|
||||
)
|
||||
return converter
|
||||
|
||||
|
||||
@ -0,0 +1,139 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig
|
||||
from core.app.apps import message_based_app_generator
|
||||
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
|
||||
from core.app.task_pipeline import message_cycle_manager
|
||||
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
|
||||
from models.model import AppMode, Conversation, Message
|
||||
|
||||
|
||||
def _make_app_config() -> WorkflowUIBasedAppConfig:
|
||||
return WorkflowUIBasedAppConfig(
|
||||
tenant_id="tenant-id",
|
||||
app_id="app-id",
|
||||
app_mode=AppMode.ADVANCED_CHAT,
|
||||
workflow_id="workflow-id",
|
||||
additional_features=AppAdditionalFeatures(),
|
||||
variables=[],
|
||||
)
|
||||
|
||||
|
||||
def _make_generate_entity(app_config: WorkflowUIBasedAppConfig) -> AdvancedChatAppGenerateEntity:
|
||||
return AdvancedChatAppGenerateEntity(
|
||||
task_id="task-id",
|
||||
app_config=app_config,
|
||||
file_upload_config=None,
|
||||
conversation_id=None,
|
||||
inputs={},
|
||||
query="hello",
|
||||
files=[],
|
||||
parent_message_id=None,
|
||||
user_id="user-id",
|
||||
stream=True,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
extras={},
|
||||
workflow_run_id="workflow-run-id",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _mock_db_session(monkeypatch):
|
||||
session = MagicMock()
|
||||
|
||||
def refresh_side_effect(obj):
|
||||
if isinstance(obj, Conversation) and obj.id is None:
|
||||
obj.id = "generated-conversation-id"
|
||||
if isinstance(obj, Message) and obj.id is None:
|
||||
obj.id = "generated-message-id"
|
||||
|
||||
session.refresh.side_effect = refresh_side_effect
|
||||
session.add.return_value = None
|
||||
session.commit.return_value = None
|
||||
|
||||
monkeypatch.setattr(message_based_app_generator, "db", SimpleNamespace(session=session))
|
||||
return session
|
||||
|
||||
|
||||
def test_init_generate_records_sets_conversation_metadata():
|
||||
app_config = _make_app_config()
|
||||
entity = _make_generate_entity(app_config)
|
||||
|
||||
generator = AdvancedChatAppGenerator()
|
||||
|
||||
conversation, _ = generator._init_generate_records(entity, conversation=None)
|
||||
|
||||
assert entity.conversation_id == "generated-conversation-id"
|
||||
assert conversation.id == "generated-conversation-id"
|
||||
assert entity.is_new_conversation is True
|
||||
|
||||
|
||||
def test_init_generate_records_marks_existing_conversation():
|
||||
app_config = _make_app_config()
|
||||
entity = _make_generate_entity(app_config)
|
||||
|
||||
existing_conversation = Conversation(
|
||||
app_id=app_config.app_id,
|
||||
app_model_config_id=None,
|
||||
model_provider=None,
|
||||
override_model_configs=None,
|
||||
model_id=None,
|
||||
mode=app_config.app_mode.value,
|
||||
name="existing",
|
||||
inputs={},
|
||||
introduction="",
|
||||
system_instruction="",
|
||||
system_instruction_tokens=0,
|
||||
status="normal",
|
||||
invoke_from=InvokeFrom.WEB_APP.value,
|
||||
from_source="api",
|
||||
from_end_user_id="user-id",
|
||||
from_account_id=None,
|
||||
)
|
||||
existing_conversation.id = "existing-conversation-id"
|
||||
|
||||
generator = AdvancedChatAppGenerator()
|
||||
|
||||
conversation, _ = generator._init_generate_records(entity, conversation=existing_conversation)
|
||||
|
||||
assert entity.conversation_id == "existing-conversation-id"
|
||||
assert conversation is existing_conversation
|
||||
assert entity.is_new_conversation is False
|
||||
|
||||
|
||||
def test_message_cycle_manager_uses_new_conversation_flag(monkeypatch):
|
||||
app_config = _make_app_config()
|
||||
entity = _make_generate_entity(app_config)
|
||||
entity.conversation_id = "existing-conversation-id"
|
||||
entity.is_new_conversation = True
|
||||
entity.extras = {"auto_generate_conversation_name": True}
|
||||
|
||||
captured = {}
|
||||
|
||||
class DummyThread:
|
||||
def __init__(self, **kwargs):
|
||||
self.kwargs = kwargs
|
||||
self.started = False
|
||||
|
||||
def start(self):
|
||||
self.started = True
|
||||
|
||||
def fake_thread(**kwargs):
|
||||
thread = DummyThread(**kwargs)
|
||||
captured["thread"] = thread
|
||||
return thread
|
||||
|
||||
monkeypatch.setattr(message_cycle_manager, "Thread", fake_thread)
|
||||
|
||||
manager = MessageCycleManager(application_generate_entity=entity, task_state=MagicMock())
|
||||
thread = manager.generate_conversation_name(conversation_id="existing-conversation-id", query="hello")
|
||||
|
||||
assert thread is captured["thread"]
|
||||
assert thread.started is True
|
||||
assert entity.is_new_conversation is False
|
||||
@ -0,0 +1,127 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.app_config.entities import (
|
||||
AppAdditionalFeatures,
|
||||
EasyUIBasedAppConfig,
|
||||
EasyUIBasedAppModelConfigFrom,
|
||||
ModelConfigEntity,
|
||||
PromptTemplateEntity,
|
||||
)
|
||||
from core.app.apps import message_based_app_generator
|
||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom
|
||||
from models.model import AppMode, Conversation, Message
|
||||
|
||||
|
||||
class DummyModelConf:
|
||||
def __init__(self, provider: str = "mock-provider", model: str = "mock-model") -> None:
|
||||
self.provider = provider
|
||||
self.model = model
|
||||
|
||||
|
||||
class DummyCompletionGenerateEntity:
|
||||
__slots__ = ("app_config", "invoke_from", "user_id", "query", "inputs", "files", "model_conf")
|
||||
app_config: EasyUIBasedAppConfig
|
||||
invoke_from: InvokeFrom
|
||||
user_id: str
|
||||
query: str
|
||||
inputs: dict
|
||||
files: list
|
||||
model_conf: DummyModelConf
|
||||
|
||||
def __init__(self, app_config: EasyUIBasedAppConfig) -> None:
|
||||
self.app_config = app_config
|
||||
self.invoke_from = InvokeFrom.WEB_APP
|
||||
self.user_id = "user-id"
|
||||
self.query = "hello"
|
||||
self.inputs = {}
|
||||
self.files = []
|
||||
self.model_conf = DummyModelConf()
|
||||
|
||||
|
||||
def _make_app_config(app_mode: AppMode) -> EasyUIBasedAppConfig:
|
||||
return EasyUIBasedAppConfig(
|
||||
tenant_id="tenant-id",
|
||||
app_id="app-id",
|
||||
app_mode=app_mode,
|
||||
app_model_config_from=EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG,
|
||||
app_model_config_id="model-config-id",
|
||||
app_model_config_dict={},
|
||||
model=ModelConfigEntity(provider="mock-provider", model="mock-model", mode="chat"),
|
||||
prompt_template=PromptTemplateEntity(
|
||||
prompt_type=PromptTemplateEntity.PromptType.SIMPLE,
|
||||
simple_prompt_template="Hello",
|
||||
),
|
||||
additional_features=AppAdditionalFeatures(),
|
||||
variables=[],
|
||||
)
|
||||
|
||||
|
||||
def _make_chat_generate_entity(app_config: EasyUIBasedAppConfig) -> ChatAppGenerateEntity:
|
||||
return ChatAppGenerateEntity.model_construct(
|
||||
task_id="task-id",
|
||||
app_config=app_config,
|
||||
model_conf=DummyModelConf(),
|
||||
file_upload_config=None,
|
||||
conversation_id=None,
|
||||
inputs={},
|
||||
query="hello",
|
||||
files=[],
|
||||
parent_message_id=None,
|
||||
user_id="user-id",
|
||||
stream=False,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
extras={},
|
||||
call_depth=0,
|
||||
trace_manager=None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _mock_db_session(monkeypatch):
|
||||
session = MagicMock()
|
||||
|
||||
def refresh_side_effect(obj):
|
||||
if isinstance(obj, Conversation) and obj.id is None:
|
||||
obj.id = "generated-conversation-id"
|
||||
if isinstance(obj, Message) and obj.id is None:
|
||||
obj.id = "generated-message-id"
|
||||
|
||||
session.refresh.side_effect = refresh_side_effect
|
||||
session.add.return_value = None
|
||||
session.commit.return_value = None
|
||||
|
||||
monkeypatch.setattr(message_based_app_generator, "db", SimpleNamespace(session=session))
|
||||
return session
|
||||
|
||||
|
||||
def test_init_generate_records_skips_conversation_fields_for_non_conversation_entity():
|
||||
app_config = _make_app_config(AppMode.COMPLETION)
|
||||
entity = DummyCompletionGenerateEntity(app_config=app_config)
|
||||
|
||||
generator = MessageBasedAppGenerator()
|
||||
|
||||
conversation, message = generator._init_generate_records(entity, conversation=None)
|
||||
|
||||
assert conversation.id == "generated-conversation-id"
|
||||
assert message.id == "generated-message-id"
|
||||
assert hasattr(entity, "conversation_id") is False
|
||||
assert hasattr(entity, "is_new_conversation") is False
|
||||
|
||||
|
||||
def test_init_generate_records_sets_conversation_fields_for_chat_entity():
|
||||
app_config = _make_app_config(AppMode.CHAT)
|
||||
entity = _make_chat_generate_entity(app_config)
|
||||
|
||||
generator = MessageBasedAppGenerator()
|
||||
|
||||
conversation, _ = generator._init_generate_records(entity, conversation=None)
|
||||
|
||||
assert entity.conversation_id == "generated-conversation-id"
|
||||
assert entity.is_new_conversation is True
|
||||
assert conversation.id == "generated-conversation-id"
|
||||
287
api/tests/unit_tests/core/app/apps/test_pause_resume.py
Normal file
287
api/tests/unit_tests/core/app/apps/test_pause_resume.py
Normal file
@ -0,0 +1,287 @@
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from types import ModuleType, SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
API_DIR = str(Path(__file__).resolve().parents[5])
|
||||
if API_DIR not in sys.path:
|
||||
sys.path.insert(0, API_DIR)
|
||||
|
||||
import core.workflow.nodes.human_input.entities # noqa: F401
|
||||
from core.app.apps.advanced_chat import app_generator as adv_app_gen_module
|
||||
from core.app.apps.workflow import app_generator as wf_app_gen_module
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.workflow.node_factory import DifyNodeFactory
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.entities.pause_reason import SchedulingPause
|
||||
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine import GraphEngine
|
||||
from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel
|
||||
from core.workflow.graph_events import (
|
||||
GraphEngineEvent,
|
||||
GraphRunPausedEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.node_events import NodeRunResult, PauseRequestedEvent
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, OutputVariableEntity, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.end.entities import EndNodeData
|
||||
from core.workflow.nodes.start.entities import StartNodeData
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
|
||||
if "core.ops.ops_trace_manager" not in sys.modules:
|
||||
ops_stub = ModuleType("core.ops.ops_trace_manager")
|
||||
|
||||
class _StubTraceQueueManager:
|
||||
def __init__(self, *_, **__):
|
||||
pass
|
||||
|
||||
ops_stub.TraceQueueManager = _StubTraceQueueManager
|
||||
sys.modules["core.ops.ops_trace_manager"] = ops_stub
|
||||
|
||||
|
||||
class _StubToolNodeData(BaseNodeData):
|
||||
pause_on: bool = False
|
||||
|
||||
|
||||
class _StubToolNode(Node[_StubToolNodeData]):
|
||||
node_type = NodeType.TOOL
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def init_node_data(self, data):
|
||||
self._node_data = _StubToolNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self):
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
return self._node_data.retry_config
|
||||
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self):
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
return self._node_data.default_value_dict
|
||||
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
return self._node_data
|
||||
|
||||
def _run(self):
|
||||
if self.node_data.pause_on:
|
||||
yield PauseRequestedEvent(reason=SchedulingPause(message="test pause"))
|
||||
return
|
||||
|
||||
result = NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={"value": f"{self.id}-done"},
|
||||
)
|
||||
yield self._convert_node_run_result_to_graph_node_event(result)
|
||||
|
||||
|
||||
def _patch_tool_node(mocker):
|
||||
original_create_node = DifyNodeFactory.create_node
|
||||
|
||||
def _patched_create_node(self, node_config: dict[str, object]) -> Node:
|
||||
node_data = node_config.get("data", {})
|
||||
if isinstance(node_data, dict) and node_data.get("type") == NodeType.TOOL.value:
|
||||
return _StubToolNode(
|
||||
id=str(node_config["id"]),
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
)
|
||||
return original_create_node(self, node_config)
|
||||
|
||||
mocker.patch.object(DifyNodeFactory, "create_node", _patched_create_node)
|
||||
|
||||
|
||||
def _node_data(node_type: NodeType, data: BaseNodeData) -> dict[str, object]:
|
||||
node_data = data.model_dump()
|
||||
node_data["type"] = node_type.value
|
||||
return node_data
|
||||
|
||||
|
||||
def _build_graph_config(*, pause_on: str | None) -> dict[str, object]:
|
||||
start_data = StartNodeData(title="start", variables=[])
|
||||
tool_data_a = _StubToolNodeData(title="tool", pause_on=pause_on == "tool_a")
|
||||
tool_data_b = _StubToolNodeData(title="tool", pause_on=pause_on == "tool_b")
|
||||
tool_data_c = _StubToolNodeData(title="tool", pause_on=pause_on == "tool_c")
|
||||
end_data = EndNodeData(
|
||||
title="end",
|
||||
outputs=[OutputVariableEntity(variable="result", value_selector=["tool_c", "value"])],
|
||||
desc=None,
|
||||
)
|
||||
|
||||
nodes = [
|
||||
{"id": "start", "data": _node_data(NodeType.START, start_data)},
|
||||
{"id": "tool_a", "data": _node_data(NodeType.TOOL, tool_data_a)},
|
||||
{"id": "tool_b", "data": _node_data(NodeType.TOOL, tool_data_b)},
|
||||
{"id": "tool_c", "data": _node_data(NodeType.TOOL, tool_data_c)},
|
||||
{"id": "end", "data": _node_data(NodeType.END, end_data)},
|
||||
]
|
||||
edges = [
|
||||
{"source": "start", "target": "tool_a"},
|
||||
{"source": "tool_a", "target": "tool_b"},
|
||||
{"source": "tool_b", "target": "tool_c"},
|
||||
{"source": "tool_c", "target": "end"},
|
||||
]
|
||||
return {"nodes": nodes, "edges": edges}
|
||||
|
||||
|
||||
def _build_graph(runtime_state: GraphRuntimeState, *, pause_on: str | None) -> Graph:
|
||||
graph_config = _build_graph_config(pause_on=pause_on)
|
||||
params = GraphInitParams(
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
graph_config=graph_config,
|
||||
user_id="user",
|
||||
user_from="account",
|
||||
invoke_from="service-api",
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
return Graph.init(graph_config=graph_config, node_factory=node_factory)
|
||||
|
||||
|
||||
def _build_runtime_state(run_id: str) -> GraphRuntimeState:
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"),
|
||||
user_inputs={},
|
||||
conversation_variables=[],
|
||||
)
|
||||
variable_pool.system_variables.workflow_execution_id = run_id
|
||||
return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
|
||||
def _run_with_optional_pause(runtime_state: GraphRuntimeState, *, pause_on: str | None) -> list[GraphEngineEvent]:
|
||||
command_channel = InMemoryChannel()
|
||||
graph = _build_graph(runtime_state, pause_on=pause_on)
|
||||
engine = GraphEngine(
|
||||
workflow_id="workflow",
|
||||
graph=graph,
|
||||
graph_runtime_state=runtime_state,
|
||||
command_channel=command_channel,
|
||||
)
|
||||
|
||||
events: list[GraphEngineEvent] = []
|
||||
for event in engine.run():
|
||||
events.append(event)
|
||||
return events
|
||||
|
||||
|
||||
def _node_successes(events: list[GraphEngineEvent]) -> list[str]:
|
||||
return [evt.node_id for evt in events if isinstance(evt, NodeRunSucceededEvent)]
|
||||
|
||||
|
||||
def test_workflow_app_pause_resume_matches_baseline(mocker):
|
||||
_patch_tool_node(mocker)
|
||||
|
||||
baseline_state = _build_runtime_state("baseline")
|
||||
baseline_events = _run_with_optional_pause(baseline_state, pause_on=None)
|
||||
assert isinstance(baseline_events[-1], GraphRunSucceededEvent)
|
||||
baseline_nodes = _node_successes(baseline_events)
|
||||
baseline_outputs = baseline_state.outputs
|
||||
|
||||
paused_state = _build_runtime_state("paused-run")
|
||||
paused_events = _run_with_optional_pause(paused_state, pause_on="tool_a")
|
||||
assert isinstance(paused_events[-1], GraphRunPausedEvent)
|
||||
paused_nodes = _node_successes(paused_events)
|
||||
snapshot = paused_state.dumps()
|
||||
|
||||
resumed_state = GraphRuntimeState.from_snapshot(snapshot)
|
||||
|
||||
generator = wf_app_gen_module.WorkflowAppGenerator()
|
||||
|
||||
def _fake_generate(**kwargs):
|
||||
state: GraphRuntimeState = kwargs["graph_runtime_state"]
|
||||
events = _run_with_optional_pause(state, pause_on=None)
|
||||
return _node_successes(events)
|
||||
|
||||
mocker.patch.object(generator, "_generate", side_effect=_fake_generate)
|
||||
|
||||
resumed_nodes = generator.resume(
|
||||
app_model=SimpleNamespace(mode="workflow"),
|
||||
workflow=SimpleNamespace(),
|
||||
user=SimpleNamespace(),
|
||||
application_generate_entity=SimpleNamespace(stream=False, invoke_from=InvokeFrom.SERVICE_API),
|
||||
graph_runtime_state=resumed_state,
|
||||
workflow_execution_repository=SimpleNamespace(),
|
||||
workflow_node_execution_repository=SimpleNamespace(),
|
||||
)
|
||||
|
||||
assert paused_nodes + resumed_nodes == baseline_nodes
|
||||
assert resumed_state.outputs == baseline_outputs
|
||||
|
||||
|
||||
def test_advanced_chat_pause_resume_matches_baseline(mocker):
|
||||
_patch_tool_node(mocker)
|
||||
|
||||
baseline_state = _build_runtime_state("adv-baseline")
|
||||
baseline_events = _run_with_optional_pause(baseline_state, pause_on=None)
|
||||
assert isinstance(baseline_events[-1], GraphRunSucceededEvent)
|
||||
baseline_nodes = _node_successes(baseline_events)
|
||||
baseline_outputs = baseline_state.outputs
|
||||
|
||||
paused_state = _build_runtime_state("adv-paused")
|
||||
paused_events = _run_with_optional_pause(paused_state, pause_on="tool_a")
|
||||
assert isinstance(paused_events[-1], GraphRunPausedEvent)
|
||||
paused_nodes = _node_successes(paused_events)
|
||||
snapshot = paused_state.dumps()
|
||||
|
||||
resumed_state = GraphRuntimeState.from_snapshot(snapshot)
|
||||
|
||||
generator = adv_app_gen_module.AdvancedChatAppGenerator()
|
||||
|
||||
def _fake_generate(**kwargs):
|
||||
state: GraphRuntimeState = kwargs["graph_runtime_state"]
|
||||
events = _run_with_optional_pause(state, pause_on=None)
|
||||
return _node_successes(events)
|
||||
|
||||
mocker.patch.object(generator, "_generate", side_effect=_fake_generate)
|
||||
|
||||
resumed_nodes = generator.resume(
|
||||
app_model=SimpleNamespace(mode="workflow"),
|
||||
workflow=SimpleNamespace(),
|
||||
user=SimpleNamespace(),
|
||||
conversation=SimpleNamespace(id="conv"),
|
||||
message=SimpleNamespace(id="msg"),
|
||||
application_generate_entity=SimpleNamespace(stream=False, invoke_from=InvokeFrom.SERVICE_API),
|
||||
workflow_execution_repository=SimpleNamespace(),
|
||||
workflow_node_execution_repository=SimpleNamespace(),
|
||||
graph_runtime_state=resumed_state,
|
||||
)
|
||||
|
||||
assert paused_nodes + resumed_nodes == baseline_nodes
|
||||
assert resumed_state.outputs == baseline_outputs
|
||||
|
||||
|
||||
def test_resume_emits_resumption_start_reason(mocker) -> None:
|
||||
_patch_tool_node(mocker)
|
||||
|
||||
paused_state = _build_runtime_state("resume-reason")
|
||||
paused_events = _run_with_optional_pause(paused_state, pause_on="tool_a")
|
||||
initial_start = next(event for event in paused_events if isinstance(event, GraphRunStartedEvent))
|
||||
assert initial_start.reason == WorkflowStartReason.INITIAL
|
||||
|
||||
resumed_state = GraphRuntimeState.from_snapshot(paused_state.dumps())
|
||||
resumed_events = _run_with_optional_pause(resumed_state, pause_on=None)
|
||||
resume_start = next(event for event in resumed_events if isinstance(event, GraphRunStartedEvent))
|
||||
assert resume_start.reason == WorkflowStartReason.RESUMPTION
|
||||
80
api/tests/unit_tests/core/app/apps/test_streaming_utils.py
Normal file
80
api/tests/unit_tests/core/app/apps/test_streaming_utils.py
Normal file
@ -0,0 +1,80 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import queue
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.entities.task_entities import StreamEvent
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
class FakeSubscription:
|
||||
def __init__(self, message_queue: queue.Queue[bytes], state: dict[str, bool]) -> None:
|
||||
self._queue = message_queue
|
||||
self._state = state
|
||||
self._closed = False
|
||||
|
||||
def __enter__(self):
|
||||
self._state["subscribed"] = True
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
self.close()
|
||||
|
||||
def close(self) -> None:
|
||||
self._closed = True
|
||||
|
||||
def receive(self, timeout: float | None = 0.1) -> bytes | None:
|
||||
if self._closed:
|
||||
return None
|
||||
try:
|
||||
if timeout is None:
|
||||
return self._queue.get()
|
||||
return self._queue.get(timeout=timeout)
|
||||
except queue.Empty:
|
||||
return None
|
||||
|
||||
|
||||
class FakeTopic:
|
||||
def __init__(self) -> None:
|
||||
self._queue: queue.Queue[bytes] = queue.Queue()
|
||||
self._state = {"subscribed": False}
|
||||
|
||||
def subscribe(self) -> FakeSubscription:
|
||||
return FakeSubscription(self._queue, self._state)
|
||||
|
||||
def publish(self, payload: bytes) -> None:
|
||||
self._queue.put(payload)
|
||||
|
||||
@property
|
||||
def subscribed(self) -> bool:
|
||||
return self._state["subscribed"]
|
||||
|
||||
|
||||
def test_retrieve_events_calls_on_subscribe_after_subscription(monkeypatch):
|
||||
topic = FakeTopic()
|
||||
|
||||
def fake_get_response_topic(cls, app_mode, workflow_run_id):
|
||||
return topic
|
||||
|
||||
monkeypatch.setattr(MessageBasedAppGenerator, "get_response_topic", classmethod(fake_get_response_topic))
|
||||
|
||||
def on_subscribe() -> None:
|
||||
assert topic.subscribed is True
|
||||
event = {"event": StreamEvent.WORKFLOW_FINISHED.value}
|
||||
topic.publish(json.dumps(event).encode())
|
||||
|
||||
generator = MessageBasedAppGenerator.retrieve_events(
|
||||
AppMode.WORKFLOW,
|
||||
"workflow-run-id",
|
||||
idle_timeout=0.5,
|
||||
on_subscribe=on_subscribe,
|
||||
)
|
||||
|
||||
assert next(generator) == StreamEvent.PING.value
|
||||
event = next(generator)
|
||||
assert event["event"] == StreamEvent.WORKFLOW_FINISHED.value
|
||||
with pytest.raises(StopIteration):
|
||||
next(generator)
|
||||
@ -1,3 +1,6 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY, WorkflowAppGenerator
|
||||
|
||||
|
||||
@ -17,3 +20,193 @@ def test_should_prepare_user_inputs_keeps_validation_when_flag_false():
|
||||
args = {"inputs": {}, SKIP_PREPARE_USER_INPUTS_KEY: False}
|
||||
|
||||
assert WorkflowAppGenerator()._should_prepare_user_inputs(args)
|
||||
|
||||
|
||||
def test_resume_delegates_to_generate(mocker):
|
||||
generator = WorkflowAppGenerator()
|
||||
mock_generate = mocker.patch.object(generator, "_generate", return_value="ok")
|
||||
|
||||
application_generate_entity = SimpleNamespace(stream=False, invoke_from="debugger")
|
||||
runtime_state = MagicMock(name="runtime-state")
|
||||
pause_config = MagicMock(name="pause-config")
|
||||
|
||||
result = generator.resume(
|
||||
app_model=MagicMock(),
|
||||
workflow=MagicMock(),
|
||||
user=MagicMock(),
|
||||
application_generate_entity=application_generate_entity,
|
||||
graph_runtime_state=runtime_state,
|
||||
workflow_execution_repository=MagicMock(),
|
||||
workflow_node_execution_repository=MagicMock(),
|
||||
graph_engine_layers=("layer",),
|
||||
pause_state_config=pause_config,
|
||||
variable_loader=MagicMock(),
|
||||
)
|
||||
|
||||
assert result == "ok"
|
||||
mock_generate.assert_called_once()
|
||||
kwargs = mock_generate.call_args.kwargs
|
||||
assert kwargs["graph_runtime_state"] is runtime_state
|
||||
assert kwargs["pause_state_config"] is pause_config
|
||||
assert kwargs["streaming"] is False
|
||||
assert kwargs["invoke_from"] == "debugger"
|
||||
|
||||
|
||||
def test_generate_appends_pause_layer_and_forwards_state(mocker):
|
||||
generator = WorkflowAppGenerator()
|
||||
|
||||
mock_queue_manager = MagicMock()
|
||||
mocker.patch("core.app.apps.workflow.app_generator.WorkflowAppQueueManager", return_value=mock_queue_manager)
|
||||
|
||||
fake_current_app = MagicMock()
|
||||
fake_current_app._get_current_object.return_value = MagicMock()
|
||||
mocker.patch("core.app.apps.workflow.app_generator.current_app", fake_current_app)
|
||||
|
||||
mocker.patch(
|
||||
"core.app.apps.workflow.app_generator.WorkflowAppGenerateResponseConverter.convert",
|
||||
return_value="converted",
|
||||
)
|
||||
mocker.patch.object(WorkflowAppGenerator, "_handle_response", return_value="response")
|
||||
mocker.patch.object(WorkflowAppGenerator, "_get_draft_var_saver_factory", return_value=MagicMock())
|
||||
|
||||
pause_layer = MagicMock(name="pause-layer")
|
||||
mocker.patch(
|
||||
"core.app.apps.workflow.app_generator.PauseStatePersistenceLayer",
|
||||
return_value=pause_layer,
|
||||
)
|
||||
|
||||
dummy_session = MagicMock()
|
||||
dummy_session.close = MagicMock()
|
||||
mocker.patch("core.app.apps.workflow.app_generator.db.session", dummy_session)
|
||||
|
||||
worker_kwargs: dict[str, object] = {}
|
||||
|
||||
class DummyThread:
|
||||
def __init__(self, target, kwargs):
|
||||
worker_kwargs["target"] = target
|
||||
worker_kwargs["kwargs"] = kwargs
|
||||
|
||||
def start(self):
|
||||
return None
|
||||
|
||||
mocker.patch("core.app.apps.workflow.app_generator.threading.Thread", DummyThread)
|
||||
|
||||
app_model = SimpleNamespace(mode="workflow")
|
||||
app_config = SimpleNamespace(app_id="app", tenant_id="tenant", workflow_id="wf")
|
||||
application_generate_entity = SimpleNamespace(
|
||||
task_id="task",
|
||||
user_id="user",
|
||||
invoke_from="service-api",
|
||||
app_config=app_config,
|
||||
files=[],
|
||||
stream=True,
|
||||
workflow_execution_id="run",
|
||||
)
|
||||
|
||||
graph_runtime_state = MagicMock()
|
||||
|
||||
result = generator._generate(
|
||||
app_model=app_model,
|
||||
workflow=MagicMock(),
|
||||
user=MagicMock(),
|
||||
application_generate_entity=application_generate_entity,
|
||||
invoke_from="service-api",
|
||||
workflow_execution_repository=MagicMock(),
|
||||
workflow_node_execution_repository=MagicMock(),
|
||||
streaming=True,
|
||||
graph_engine_layers=("base-layer",),
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
pause_state_config=SimpleNamespace(session_factory=MagicMock(), state_owner_user_id="owner"),
|
||||
)
|
||||
|
||||
assert result == "converted"
|
||||
assert worker_kwargs["kwargs"]["graph_engine_layers"] == ("base-layer", pause_layer)
|
||||
assert worker_kwargs["kwargs"]["graph_runtime_state"] is graph_runtime_state
|
||||
|
||||
|
||||
def test_resume_path_runs_worker_with_runtime_state(mocker):
|
||||
generator = WorkflowAppGenerator()
|
||||
runtime_state = MagicMock(name="runtime-state")
|
||||
|
||||
pause_layer = MagicMock(name="pause-layer")
|
||||
mocker.patch("core.app.apps.workflow.app_generator.PauseStatePersistenceLayer", return_value=pause_layer)
|
||||
|
||||
queue_manager = MagicMock()
|
||||
mocker.patch("core.app.apps.workflow.app_generator.WorkflowAppQueueManager", return_value=queue_manager)
|
||||
|
||||
mocker.patch.object(generator, "_handle_response", return_value="raw-response")
|
||||
mocker.patch(
|
||||
"core.app.apps.workflow.app_generator.WorkflowAppGenerateResponseConverter.convert",
|
||||
side_effect=lambda response, invoke_from: response,
|
||||
)
|
||||
|
||||
fake_db = SimpleNamespace(session=MagicMock(), engine=MagicMock())
|
||||
mocker.patch("core.app.apps.workflow.app_generator.db", fake_db)
|
||||
|
||||
workflow = SimpleNamespace(
|
||||
id="workflow", tenant_id="tenant", app_id="app", graph_dict={}, type="workflow", version="1"
|
||||
)
|
||||
end_user = SimpleNamespace(session_id="end-user-session")
|
||||
app_record = SimpleNamespace(id="app")
|
||||
|
||||
session = MagicMock()
|
||||
session.__enter__.return_value = session
|
||||
session.__exit__.return_value = False
|
||||
session.scalar.side_effect = [workflow, end_user, app_record]
|
||||
mocker.patch("core.app.apps.workflow.app_generator.session_factory", return_value=session)
|
||||
|
||||
runner_instance = MagicMock()
|
||||
|
||||
def runner_ctor(**kwargs):
|
||||
assert kwargs["graph_runtime_state"] is runtime_state
|
||||
return runner_instance
|
||||
|
||||
mocker.patch("core.app.apps.workflow.app_generator.WorkflowAppRunner", side_effect=runner_ctor)
|
||||
|
||||
class ImmediateThread:
|
||||
def __init__(self, target, kwargs):
|
||||
target(**kwargs)
|
||||
|
||||
def start(self):
|
||||
return None
|
||||
|
||||
mocker.patch("core.app.apps.workflow.app_generator.threading.Thread", ImmediateThread)
|
||||
|
||||
mocker.patch(
|
||||
"core.app.apps.workflow.app_generator.DifyCoreRepositoryFactory.create_workflow_execution_repository",
|
||||
return_value=MagicMock(),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.workflow.app_generator.DifyCoreRepositoryFactory.create_workflow_node_execution_repository",
|
||||
return_value=MagicMock(),
|
||||
)
|
||||
|
||||
pause_config = SimpleNamespace(session_factory=MagicMock(), state_owner_user_id="owner")
|
||||
|
||||
app_model = SimpleNamespace(mode="workflow")
|
||||
app_config = SimpleNamespace(app_id="app", tenant_id="tenant", workflow_id="workflow")
|
||||
application_generate_entity = SimpleNamespace(
|
||||
task_id="task",
|
||||
user_id="user",
|
||||
invoke_from="service-api",
|
||||
app_config=app_config,
|
||||
files=[],
|
||||
stream=True,
|
||||
workflow_execution_id="run",
|
||||
trace_manager=MagicMock(),
|
||||
)
|
||||
|
||||
result = generator.resume(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=MagicMock(),
|
||||
application_generate_entity=application_generate_entity,
|
||||
graph_runtime_state=runtime_state,
|
||||
workflow_execution_repository=MagicMock(),
|
||||
workflow_node_execution_repository=MagicMock(),
|
||||
pause_state_config=pause_config,
|
||||
)
|
||||
|
||||
assert result == "raw-response"
|
||||
runner_instance.run.assert_called_once()
|
||||
queue_manager.graph_runtime_state = runtime_state
|
||||
|
||||
@ -0,0 +1,59 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
|
||||
from core.app.entities.queue_entities import QueueWorkflowPausedEvent
|
||||
from core.workflow.entities.pause_reason import HumanInputRequired
|
||||
from core.workflow.graph_events.graph import GraphRunPausedEvent
|
||||
|
||||
|
||||
class _DummyQueueManager:
|
||||
def __init__(self):
|
||||
self.published = []
|
||||
|
||||
def publish(self, event, _from):
|
||||
self.published.append(event)
|
||||
|
||||
|
||||
class _DummyRuntimeState:
|
||||
def get_paused_nodes(self):
|
||||
return ["node-1"]
|
||||
|
||||
|
||||
class _DummyGraphEngine:
|
||||
def __init__(self):
|
||||
self.graph_runtime_state = _DummyRuntimeState()
|
||||
|
||||
|
||||
class _DummyWorkflowEntry:
|
||||
def __init__(self):
|
||||
self.graph_engine = _DummyGraphEngine()
|
||||
|
||||
|
||||
def test_handle_pause_event_enqueues_email_task(monkeypatch: pytest.MonkeyPatch):
|
||||
queue_manager = _DummyQueueManager()
|
||||
runner = WorkflowBasedAppRunner(queue_manager=queue_manager, app_id="app-id")
|
||||
workflow_entry = _DummyWorkflowEntry()
|
||||
|
||||
reason = HumanInputRequired(
|
||||
form_id="form-123",
|
||||
form_content="content",
|
||||
inputs=[],
|
||||
actions=[],
|
||||
node_id="node-1",
|
||||
node_title="Review",
|
||||
)
|
||||
event = GraphRunPausedEvent(reasons=[reason], outputs={})
|
||||
|
||||
email_task = MagicMock()
|
||||
monkeypatch.setattr("core.app.apps.workflow_app_runner.dispatch_human_input_email_task", email_task)
|
||||
|
||||
runner._handle_event(workflow_entry, event)
|
||||
|
||||
email_task.apply_async.assert_called_once()
|
||||
kwargs = email_task.apply_async.call_args.kwargs["kwargs"]
|
||||
assert kwargs["form_id"] == "form-123"
|
||||
assert kwargs["node_title"] == "Review"
|
||||
|
||||
assert any(isinstance(evt, QueueWorkflowPausedEvent) for evt in queue_manager.published)
|
||||
183
api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py
Normal file
183
api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py
Normal file
@ -0,0 +1,183 @@
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.apps.common import workflow_response_converter
|
||||
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
|
||||
from core.app.apps.workflow.app_runner import WorkflowAppRunner
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import QueueWorkflowPausedEvent
|
||||
from core.app.entities.task_entities import HumanInputRequiredResponse, WorkflowPauseStreamResponse
|
||||
from core.workflow.entities.pause_reason import HumanInputRequired
|
||||
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
|
||||
from core.workflow.graph_events.graph import GraphRunPausedEvent
|
||||
from core.workflow.nodes.human_input.entities import FormInput, UserAction
|
||||
from core.workflow.nodes.human_input.enums import FormInputType
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from models.account import Account
|
||||
|
||||
|
||||
class _RecordingWorkflowAppRunner(WorkflowAppRunner):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.published_events = []
|
||||
|
||||
def _publish_event(self, event):
|
||||
self.published_events.append(event)
|
||||
|
||||
|
||||
class _FakeRuntimeState:
|
||||
def get_paused_nodes(self):
|
||||
return ["node-pause-1"]
|
||||
|
||||
|
||||
def _build_runner():
|
||||
app_entity = SimpleNamespace(
|
||||
app_config=SimpleNamespace(app_id="app-id"),
|
||||
inputs={},
|
||||
files=[],
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
single_iteration_run=None,
|
||||
single_loop_run=None,
|
||||
workflow_execution_id="run-id",
|
||||
user_id="user-id",
|
||||
)
|
||||
workflow = SimpleNamespace(
|
||||
graph_dict={},
|
||||
tenant_id="tenant-id",
|
||||
environment_variables={},
|
||||
id="workflow-id",
|
||||
)
|
||||
queue_manager = SimpleNamespace(publish=lambda event, pub_from: None)
|
||||
return _RecordingWorkflowAppRunner(
|
||||
application_generate_entity=app_entity,
|
||||
queue_manager=queue_manager,
|
||||
variable_loader=MagicMock(),
|
||||
workflow=workflow,
|
||||
system_user_id="sys-user",
|
||||
root_node_id=None,
|
||||
workflow_execution_repository=MagicMock(),
|
||||
workflow_node_execution_repository=MagicMock(),
|
||||
graph_engine_layers=(),
|
||||
graph_runtime_state=None,
|
||||
)
|
||||
|
||||
|
||||
def test_graph_run_paused_event_emits_queue_pause_event():
|
||||
runner = _build_runner()
|
||||
reason = HumanInputRequired(
|
||||
form_id="form-1",
|
||||
form_content="content",
|
||||
inputs=[],
|
||||
actions=[],
|
||||
node_id="node-human",
|
||||
node_title="Human Step",
|
||||
form_token="tok",
|
||||
)
|
||||
event = GraphRunPausedEvent(reasons=[reason], outputs={"foo": "bar"})
|
||||
workflow_entry = SimpleNamespace(
|
||||
graph_engine=SimpleNamespace(graph_runtime_state=_FakeRuntimeState()),
|
||||
)
|
||||
|
||||
runner._handle_event(workflow_entry, event)
|
||||
|
||||
assert len(runner.published_events) == 1
|
||||
queue_event = runner.published_events[0]
|
||||
assert isinstance(queue_event, QueueWorkflowPausedEvent)
|
||||
assert queue_event.reasons == [reason]
|
||||
assert queue_event.outputs == {"foo": "bar"}
|
||||
assert queue_event.paused_nodes == ["node-pause-1"]
|
||||
|
||||
|
||||
def _build_converter():
|
||||
application_generate_entity = SimpleNamespace(
|
||||
inputs={},
|
||||
files=[],
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
app_config=SimpleNamespace(app_id="app-id", tenant_id="tenant-id"),
|
||||
)
|
||||
system_variables = SystemVariable(
|
||||
user_id="user",
|
||||
app_id="app-id",
|
||||
workflow_id="workflow-id",
|
||||
workflow_execution_id="run-id",
|
||||
)
|
||||
user = MagicMock(spec=Account)
|
||||
user.id = "account-id"
|
||||
user.name = "Tester"
|
||||
user.email = "tester@example.com"
|
||||
return WorkflowResponseConverter(
|
||||
application_generate_entity=application_generate_entity,
|
||||
user=user,
|
||||
system_variables=system_variables,
|
||||
)
|
||||
|
||||
|
||||
def test_queue_workflow_paused_event_to_stream_responses(monkeypatch: pytest.MonkeyPatch):
|
||||
converter = _build_converter()
|
||||
converter.workflow_start_to_stream_response(
|
||||
task_id="task",
|
||||
workflow_run_id="run-id",
|
||||
workflow_id="workflow-id",
|
||||
reason=WorkflowStartReason.INITIAL,
|
||||
)
|
||||
|
||||
expiration_time = datetime(2024, 1, 1, tzinfo=UTC)
|
||||
|
||||
class _FakeSession:
|
||||
def execute(self, _stmt):
|
||||
return [("form-1", expiration_time)]
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
monkeypatch.setattr(workflow_response_converter, "Session", lambda **_: _FakeSession())
|
||||
monkeypatch.setattr(workflow_response_converter, "db", SimpleNamespace(engine=object()))
|
||||
|
||||
reason = HumanInputRequired(
|
||||
form_id="form-1",
|
||||
form_content="Rendered",
|
||||
inputs=[
|
||||
FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="field", default=None),
|
||||
],
|
||||
actions=[UserAction(id="approve", title="Approve")],
|
||||
display_in_ui=True,
|
||||
node_id="node-id",
|
||||
node_title="Human Step",
|
||||
form_token="token",
|
||||
)
|
||||
queue_event = QueueWorkflowPausedEvent(
|
||||
reasons=[reason],
|
||||
outputs={"answer": "value"},
|
||||
paused_nodes=["node-id"],
|
||||
)
|
||||
|
||||
runtime_state = SimpleNamespace(total_tokens=0, node_run_steps=0)
|
||||
responses = converter.workflow_pause_to_stream_response(
|
||||
event=queue_event,
|
||||
task_id="task",
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
assert isinstance(responses[-1], WorkflowPauseStreamResponse)
|
||||
pause_resp = responses[-1]
|
||||
assert pause_resp.workflow_run_id == "run-id"
|
||||
assert pause_resp.data.paused_nodes == ["node-id"]
|
||||
assert pause_resp.data.outputs == {}
|
||||
assert pause_resp.data.reasons[0]["form_id"] == "form-1"
|
||||
assert pause_resp.data.reasons[0]["display_in_ui"] is True
|
||||
|
||||
assert isinstance(responses[0], HumanInputRequiredResponse)
|
||||
hi_resp = responses[0]
|
||||
assert hi_resp.data.form_id == "form-1"
|
||||
assert hi_resp.data.node_id == "node-id"
|
||||
assert hi_resp.data.node_title == "Human Step"
|
||||
assert hi_resp.data.inputs[0].output_variable_name == "field"
|
||||
assert hi_resp.data.actions[0].id == "approve"
|
||||
assert hi_resp.data.display_in_ui is True
|
||||
assert hi_resp.data.expiration_time == int(expiration_time.timestamp())
|
||||
@ -0,0 +1,96 @@
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.app.app_config.entities import WorkflowUIBasedAppConfig
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.app.entities.queue_entities import QueueWorkflowStartedEvent
|
||||
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from models.account import Account
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
def _build_workflow_app_config() -> WorkflowUIBasedAppConfig:
|
||||
return WorkflowUIBasedAppConfig(
|
||||
tenant_id="tenant-id",
|
||||
app_id="app-id",
|
||||
app_mode=AppMode.WORKFLOW,
|
||||
workflow_id="workflow-id",
|
||||
)
|
||||
|
||||
|
||||
def _build_generate_entity(run_id: str) -> WorkflowAppGenerateEntity:
|
||||
return WorkflowAppGenerateEntity(
|
||||
task_id="task-id",
|
||||
app_config=_build_workflow_app_config(),
|
||||
inputs={},
|
||||
files=[],
|
||||
user_id="user-id",
|
||||
stream=False,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
workflow_execution_id=run_id,
|
||||
)
|
||||
|
||||
|
||||
def _build_runtime_state(run_id: str) -> GraphRuntimeState:
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(workflow_execution_id=run_id),
|
||||
user_inputs={},
|
||||
conversation_variables=[],
|
||||
)
|
||||
return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _noop_session():
|
||||
yield MagicMock()
|
||||
|
||||
|
||||
def _build_pipeline(run_id: str) -> WorkflowAppGenerateTaskPipeline:
|
||||
queue_manager = MagicMock(spec=AppQueueManager)
|
||||
queue_manager.invoke_from = InvokeFrom.SERVICE_API
|
||||
queue_manager.graph_runtime_state = _build_runtime_state(run_id)
|
||||
workflow = MagicMock()
|
||||
workflow.id = "workflow-id"
|
||||
workflow.features_dict = {}
|
||||
user = Account(name="user", email="user@example.com")
|
||||
pipeline = WorkflowAppGenerateTaskPipeline(
|
||||
application_generate_entity=_build_generate_entity(run_id),
|
||||
workflow=workflow,
|
||||
queue_manager=queue_manager,
|
||||
user=user,
|
||||
stream=False,
|
||||
draft_var_saver_factory=MagicMock(),
|
||||
)
|
||||
pipeline._database_session = _noop_session
|
||||
return pipeline
|
||||
|
||||
|
||||
def test_workflow_app_log_saved_only_on_initial_start() -> None:
|
||||
run_id = "run-initial"
|
||||
pipeline = _build_pipeline(run_id)
|
||||
pipeline._save_workflow_app_log = MagicMock()
|
||||
|
||||
event = QueueWorkflowStartedEvent(reason=WorkflowStartReason.INITIAL)
|
||||
list(pipeline._handle_workflow_started_event(event))
|
||||
|
||||
pipeline._save_workflow_app_log.assert_called_once()
|
||||
_, kwargs = pipeline._save_workflow_app_log.call_args
|
||||
assert kwargs["workflow_run_id"] == run_id
|
||||
assert pipeline._workflow_execution_id == run_id
|
||||
|
||||
|
||||
def test_workflow_app_log_skipped_on_resumption_start() -> None:
|
||||
run_id = "run-resume"
|
||||
pipeline = _build_pipeline(run_id)
|
||||
pipeline._save_workflow_app_log = MagicMock()
|
||||
|
||||
event = QueueWorkflowStartedEvent(reason=WorkflowStartReason.RESUMPTION)
|
||||
list(pipeline._handle_workflow_started_event(event))
|
||||
|
||||
pipeline._save_workflow_app_log.assert_not_called()
|
||||
assert pipeline._workflow_execution_id == run_id
|
||||
@ -0,0 +1,143 @@
|
||||
import json
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.app_config.entities import WorkflowUIBasedAppConfig
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
AdvancedChatAppGenerateEntity,
|
||||
InvokeFrom,
|
||||
WorkflowAppGenerateEntity,
|
||||
)
|
||||
from core.app.layers.pause_state_persist_layer import (
|
||||
WorkflowResumptionContext,
|
||||
_AdvancedChatAppGenerateEntityWrapper,
|
||||
_WorkflowGenerateEntityWrapper,
|
||||
)
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
class TraceQueueManagerStub(TraceQueueManager):
|
||||
"""Minimal TraceQueueManager stub that avoids Flask dependencies."""
|
||||
|
||||
def __init__(self):
|
||||
# Skip parent initialization to avoid starting timers or accessing Flask globals.
|
||||
pass
|
||||
|
||||
|
||||
def _build_workflow_app_config(app_mode: AppMode) -> WorkflowUIBasedAppConfig:
|
||||
return WorkflowUIBasedAppConfig(
|
||||
tenant_id="tenant-id",
|
||||
app_id="app-id",
|
||||
app_mode=app_mode,
|
||||
workflow_id=f"{app_mode.value}-workflow-id",
|
||||
)
|
||||
|
||||
|
||||
def _create_workflow_generate_entity(trace_manager: TraceQueueManager | None = None) -> WorkflowAppGenerateEntity:
|
||||
return WorkflowAppGenerateEntity(
|
||||
task_id="workflow-task",
|
||||
app_config=_build_workflow_app_config(AppMode.WORKFLOW),
|
||||
inputs={"topic": "serialization"},
|
||||
files=[],
|
||||
user_id="user-workflow",
|
||||
stream=True,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=1,
|
||||
trace_manager=trace_manager,
|
||||
workflow_execution_id="workflow-exec-id",
|
||||
extras={"external_trace_id": "trace-id"},
|
||||
)
|
||||
|
||||
|
||||
def _create_advanced_chat_generate_entity(
|
||||
trace_manager: TraceQueueManager | None = None,
|
||||
) -> AdvancedChatAppGenerateEntity:
|
||||
return AdvancedChatAppGenerateEntity(
|
||||
task_id="advanced-task",
|
||||
app_config=_build_workflow_app_config(AppMode.ADVANCED_CHAT),
|
||||
conversation_id="conversation-id",
|
||||
inputs={"topic": "roundtrip"},
|
||||
files=[],
|
||||
user_id="user-advanced",
|
||||
stream=False,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
query="Explain serialization",
|
||||
extras={"auto_generate_conversation_name": True},
|
||||
trace_manager=trace_manager,
|
||||
workflow_run_id="workflow-run-id",
|
||||
)
|
||||
|
||||
|
||||
def test_workflow_app_generate_entity_roundtrip_excludes_trace_manager():
|
||||
entity = _create_workflow_generate_entity(trace_manager=TraceQueueManagerStub())
|
||||
|
||||
serialized = entity.model_dump_json()
|
||||
payload = json.loads(serialized)
|
||||
|
||||
assert "trace_manager" not in payload
|
||||
|
||||
restored = WorkflowAppGenerateEntity.model_validate_json(serialized)
|
||||
|
||||
assert restored.model_dump() == entity.model_dump()
|
||||
assert restored.trace_manager is None
|
||||
|
||||
|
||||
def test_advanced_chat_generate_entity_roundtrip_excludes_trace_manager():
|
||||
entity = _create_advanced_chat_generate_entity(trace_manager=TraceQueueManagerStub())
|
||||
|
||||
serialized = entity.model_dump_json()
|
||||
payload = json.loads(serialized)
|
||||
|
||||
assert "trace_manager" not in payload
|
||||
|
||||
restored = AdvancedChatAppGenerateEntity.model_validate_json(serialized)
|
||||
|
||||
assert restored.model_dump() == entity.model_dump()
|
||||
assert restored.trace_manager is None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ResumptionContextCase:
|
||||
name: str
|
||||
context_factory: Callable[[], tuple[WorkflowResumptionContext, type]]
|
||||
|
||||
|
||||
def _workflow_resumption_case() -> tuple[WorkflowResumptionContext, type]:
|
||||
entity = _create_workflow_generate_entity(trace_manager=TraceQueueManagerStub())
|
||||
context = WorkflowResumptionContext(
|
||||
serialized_graph_runtime_state=json.dumps({"state": "workflow"}),
|
||||
generate_entity=_WorkflowGenerateEntityWrapper(entity=entity),
|
||||
)
|
||||
return context, WorkflowAppGenerateEntity
|
||||
|
||||
|
||||
def _advanced_chat_resumption_case() -> tuple[WorkflowResumptionContext, type]:
|
||||
entity = _create_advanced_chat_generate_entity(trace_manager=TraceQueueManagerStub())
|
||||
context = WorkflowResumptionContext(
|
||||
serialized_graph_runtime_state=json.dumps({"state": "advanced"}),
|
||||
generate_entity=_AdvancedChatAppGenerateEntityWrapper(entity=entity),
|
||||
)
|
||||
return context, AdvancedChatAppGenerateEntity
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
[
|
||||
pytest.param(ResumptionContextCase("workflow", _workflow_resumption_case), id="workflow"),
|
||||
pytest.param(ResumptionContextCase("advanced_chat", _advanced_chat_resumption_case), id="advanced_chat"),
|
||||
],
|
||||
)
|
||||
def test_workflow_resumption_context_roundtrip(case: ResumptionContextCase):
|
||||
context, expected_type = case.context_factory()
|
||||
|
||||
serialized = context.dumps()
|
||||
restored = WorkflowResumptionContext.loads(serialized)
|
||||
|
||||
assert restored.serialized_graph_runtime_state == context.serialized_graph_runtime_state
|
||||
entity = restored.get_generate_entity()
|
||||
assert isinstance(entity, expected_type)
|
||||
assert entity.model_dump() == context.get_generate_entity().model_dump()
|
||||
assert entity.trace_manager is None
|
||||
@ -0,0 +1,72 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig
|
||||
from core.plugin.backwards_invocation.app import PluginAppBackwardsInvocation
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
def test_invoke_chat_app_advanced_chat_injects_pause_state_config(mocker):
|
||||
workflow = MagicMock()
|
||||
workflow.created_by = "owner-id"
|
||||
|
||||
app = MagicMock()
|
||||
app.mode = AppMode.ADVANCED_CHAT
|
||||
app.workflow = workflow
|
||||
|
||||
mocker.patch(
|
||||
"core.plugin.backwards_invocation.app.db",
|
||||
SimpleNamespace(engine=MagicMock()),
|
||||
)
|
||||
generator_spy = mocker.patch(
|
||||
"core.plugin.backwards_invocation.app.AdvancedChatAppGenerator.generate",
|
||||
return_value={"result": "ok"},
|
||||
)
|
||||
|
||||
result = PluginAppBackwardsInvocation.invoke_chat_app(
|
||||
app=app,
|
||||
user=MagicMock(),
|
||||
conversation_id="conv-1",
|
||||
query="hello",
|
||||
stream=False,
|
||||
inputs={"k": "v"},
|
||||
files=[],
|
||||
)
|
||||
|
||||
assert result == {"result": "ok"}
|
||||
call_kwargs = generator_spy.call_args.kwargs
|
||||
pause_state_config = call_kwargs.get("pause_state_config")
|
||||
assert isinstance(pause_state_config, PauseStateLayerConfig)
|
||||
assert pause_state_config.state_owner_user_id == "owner-id"
|
||||
|
||||
|
||||
def test_invoke_workflow_app_injects_pause_state_config(mocker):
|
||||
workflow = MagicMock()
|
||||
workflow.created_by = "owner-id"
|
||||
|
||||
app = MagicMock()
|
||||
app.mode = AppMode.WORKFLOW
|
||||
app.workflow = workflow
|
||||
|
||||
mocker.patch(
|
||||
"core.plugin.backwards_invocation.app.db",
|
||||
SimpleNamespace(engine=MagicMock()),
|
||||
)
|
||||
generator_spy = mocker.patch(
|
||||
"core.plugin.backwards_invocation.app.WorkflowAppGenerator.generate",
|
||||
return_value={"result": "ok"},
|
||||
)
|
||||
|
||||
result = PluginAppBackwardsInvocation.invoke_workflow_app(
|
||||
app=app,
|
||||
user=MagicMock(),
|
||||
stream=False,
|
||||
inputs={"k": "v"},
|
||||
files=[],
|
||||
)
|
||||
|
||||
assert result == {"result": "ok"}
|
||||
call_kwargs = generator_spy.call_args.kwargs
|
||||
pause_state_config = call_kwargs.get("pause_state_config")
|
||||
assert isinstance(pause_state_config, PauseStateLayerConfig)
|
||||
assert pause_state_config.state_owner_user_id == "owner-id"
|
||||
@ -0,0 +1,574 @@
|
||||
"""Unit tests for HumanInputFormRepositoryImpl private helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.repositories.human_input_repository import (
|
||||
HumanInputFormRecord,
|
||||
HumanInputFormRepositoryImpl,
|
||||
HumanInputFormSubmissionRepository,
|
||||
_WorkspaceMemberInfo,
|
||||
)
|
||||
from core.workflow.nodes.human_input.entities import (
|
||||
EmailDeliveryConfig,
|
||||
EmailDeliveryMethod,
|
||||
EmailRecipients,
|
||||
ExternalRecipient,
|
||||
FormDefinition,
|
||||
MemberRecipient,
|
||||
UserAction,
|
||||
)
|
||||
from core.workflow.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.human_input import (
|
||||
EmailExternalRecipientPayload,
|
||||
EmailMemberRecipientPayload,
|
||||
HumanInputFormRecipient,
|
||||
RecipientType,
|
||||
)
|
||||
|
||||
|
||||
def _build_repository() -> HumanInputFormRepositoryImpl:
|
||||
return HumanInputFormRepositoryImpl(session_factory=MagicMock(), tenant_id="tenant-id")
|
||||
|
||||
|
||||
def _patch_recipient_factory(monkeypatch: pytest.MonkeyPatch) -> list[SimpleNamespace]:
|
||||
created: list[SimpleNamespace] = []
|
||||
|
||||
def fake_new(cls, form_id: str, delivery_id: str, payload): # type: ignore[no-untyped-def]
|
||||
recipient = SimpleNamespace(
|
||||
form_id=form_id,
|
||||
delivery_id=delivery_id,
|
||||
recipient_type=payload.TYPE,
|
||||
recipient_payload=payload.model_dump_json(),
|
||||
)
|
||||
created.append(recipient)
|
||||
return recipient
|
||||
|
||||
monkeypatch.setattr(HumanInputFormRecipient, "new", classmethod(fake_new))
|
||||
return created
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _stub_selectinload(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Avoid SQLAlchemy mapper configuration in tests using fake sessions."""
|
||||
|
||||
class _FakeSelect:
|
||||
def options(self, *_args, **_kwargs): # type: ignore[no-untyped-def]
|
||||
return self
|
||||
|
||||
def where(self, *_args, **_kwargs): # type: ignore[no-untyped-def]
|
||||
return self
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.human_input_repository.selectinload", lambda *args, **kwargs: "_loader_option"
|
||||
)
|
||||
monkeypatch.setattr("core.repositories.human_input_repository.select", lambda *args, **kwargs: _FakeSelect())
|
||||
|
||||
|
||||
class TestHumanInputFormRepositoryImplHelpers:
|
||||
def test_build_email_recipients_with_member_and_external(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
repo = _build_repository()
|
||||
session_stub = object()
|
||||
_patch_recipient_factory(monkeypatch)
|
||||
|
||||
def fake_query(self, session, restrict_to_user_ids): # type: ignore[no-untyped-def]
|
||||
assert session is session_stub
|
||||
assert restrict_to_user_ids == ["member-1"]
|
||||
return [_WorkspaceMemberInfo(user_id="member-1", email="member@example.com")]
|
||||
|
||||
monkeypatch.setattr(HumanInputFormRepositoryImpl, "_query_workspace_members_by_ids", fake_query)
|
||||
|
||||
recipients = repo._build_email_recipients(
|
||||
session=session_stub,
|
||||
form_id="form-id",
|
||||
delivery_id="delivery-id",
|
||||
recipients_config=EmailRecipients(
|
||||
whole_workspace=False,
|
||||
items=[
|
||||
MemberRecipient(user_id="member-1"),
|
||||
ExternalRecipient(email="external@example.com"),
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
assert len(recipients) == 2
|
||||
member_recipient = next(r for r in recipients if r.recipient_type == RecipientType.EMAIL_MEMBER)
|
||||
external_recipient = next(r for r in recipients if r.recipient_type == RecipientType.EMAIL_EXTERNAL)
|
||||
|
||||
member_payload = EmailMemberRecipientPayload.model_validate_json(member_recipient.recipient_payload)
|
||||
assert member_payload.user_id == "member-1"
|
||||
assert member_payload.email == "member@example.com"
|
||||
|
||||
external_payload = EmailExternalRecipientPayload.model_validate_json(external_recipient.recipient_payload)
|
||||
assert external_payload.email == "external@example.com"
|
||||
|
||||
def test_build_email_recipients_skips_unknown_members(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
repo = _build_repository()
|
||||
session_stub = object()
|
||||
created = _patch_recipient_factory(monkeypatch)
|
||||
|
||||
def fake_query(self, session, restrict_to_user_ids): # type: ignore[no-untyped-def]
|
||||
assert session is session_stub
|
||||
assert restrict_to_user_ids == ["missing-member"]
|
||||
return []
|
||||
|
||||
monkeypatch.setattr(HumanInputFormRepositoryImpl, "_query_workspace_members_by_ids", fake_query)
|
||||
|
||||
recipients = repo._build_email_recipients(
|
||||
session=session_stub,
|
||||
form_id="form-id",
|
||||
delivery_id="delivery-id",
|
||||
recipients_config=EmailRecipients(
|
||||
whole_workspace=False,
|
||||
items=[
|
||||
MemberRecipient(user_id="missing-member"),
|
||||
ExternalRecipient(email="external@example.com"),
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
assert len(recipients) == 1
|
||||
assert recipients[0].recipient_type == RecipientType.EMAIL_EXTERNAL
|
||||
assert len(created) == 1 # only external recipient created via factory
|
||||
|
||||
def test_build_email_recipients_whole_workspace_uses_all_members(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
repo = _build_repository()
|
||||
session_stub = object()
|
||||
_patch_recipient_factory(monkeypatch)
|
||||
|
||||
def fake_query(self, session): # type: ignore[no-untyped-def]
|
||||
assert session is session_stub
|
||||
return [
|
||||
_WorkspaceMemberInfo(user_id="member-1", email="member1@example.com"),
|
||||
_WorkspaceMemberInfo(user_id="member-2", email="member2@example.com"),
|
||||
]
|
||||
|
||||
monkeypatch.setattr(HumanInputFormRepositoryImpl, "_query_all_workspace_members", fake_query)
|
||||
|
||||
recipients = repo._build_email_recipients(
|
||||
session=session_stub,
|
||||
form_id="form-id",
|
||||
delivery_id="delivery-id",
|
||||
recipients_config=EmailRecipients(
|
||||
whole_workspace=True,
|
||||
items=[],
|
||||
),
|
||||
)
|
||||
|
||||
assert len(recipients) == 2
|
||||
emails = {EmailMemberRecipientPayload.model_validate_json(r.recipient_payload).email for r in recipients}
|
||||
assert emails == {"member1@example.com", "member2@example.com"}
|
||||
|
||||
def test_build_email_recipients_dedupes_external_by_email(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
repo = _build_repository()
|
||||
session_stub = object()
|
||||
created = _patch_recipient_factory(monkeypatch)
|
||||
|
||||
def fake_query(self, session, restrict_to_user_ids): # type: ignore[no-untyped-def]
|
||||
assert session is session_stub
|
||||
assert restrict_to_user_ids == []
|
||||
return []
|
||||
|
||||
monkeypatch.setattr(HumanInputFormRepositoryImpl, "_query_workspace_members_by_ids", fake_query)
|
||||
|
||||
recipients = repo._build_email_recipients(
|
||||
session=session_stub,
|
||||
form_id="form-id",
|
||||
delivery_id="delivery-id",
|
||||
recipients_config=EmailRecipients(
|
||||
whole_workspace=False,
|
||||
items=[
|
||||
ExternalRecipient(email="external@example.com"),
|
||||
ExternalRecipient(email="external@example.com"),
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
assert len(recipients) == 1
|
||||
assert len(created) == 1
|
||||
|
||||
def test_build_email_recipients_prefers_member_over_external_by_email(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
repo = _build_repository()
|
||||
session_stub = object()
|
||||
_patch_recipient_factory(monkeypatch)
|
||||
|
||||
def fake_query(self, session, restrict_to_user_ids): # type: ignore[no-untyped-def]
|
||||
assert session is session_stub
|
||||
assert restrict_to_user_ids == ["member-1"]
|
||||
return [_WorkspaceMemberInfo(user_id="member-1", email="shared@example.com")]
|
||||
|
||||
monkeypatch.setattr(HumanInputFormRepositoryImpl, "_query_workspace_members_by_ids", fake_query)
|
||||
|
||||
recipients = repo._build_email_recipients(
|
||||
session=session_stub,
|
||||
form_id="form-id",
|
||||
delivery_id="delivery-id",
|
||||
recipients_config=EmailRecipients(
|
||||
whole_workspace=False,
|
||||
items=[
|
||||
MemberRecipient(user_id="member-1"),
|
||||
ExternalRecipient(email="shared@example.com"),
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
assert len(recipients) == 1
|
||||
assert recipients[0].recipient_type == RecipientType.EMAIL_MEMBER
|
||||
|
||||
def test_delivery_method_to_model_includes_external_recipients_with_whole_workspace(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
repo = _build_repository()
|
||||
session_stub = object()
|
||||
_patch_recipient_factory(monkeypatch)
|
||||
|
||||
def fake_query(self, session): # type: ignore[no-untyped-def]
|
||||
assert session is session_stub
|
||||
return [
|
||||
_WorkspaceMemberInfo(user_id="member-1", email="member1@example.com"),
|
||||
_WorkspaceMemberInfo(user_id="member-2", email="member2@example.com"),
|
||||
]
|
||||
|
||||
monkeypatch.setattr(HumanInputFormRepositoryImpl, "_query_all_workspace_members", fake_query)
|
||||
|
||||
method = EmailDeliveryMethod(
|
||||
config=EmailDeliveryConfig(
|
||||
recipients=EmailRecipients(
|
||||
whole_workspace=True,
|
||||
items=[ExternalRecipient(email="external@example.com")],
|
||||
),
|
||||
subject="subject",
|
||||
body="body",
|
||||
)
|
||||
)
|
||||
|
||||
result = repo._delivery_method_to_model(session=session_stub, form_id="form-id", delivery_method=method)
|
||||
|
||||
assert len(result.recipients) == 3
|
||||
member_emails = {
|
||||
EmailMemberRecipientPayload.model_validate_json(r.recipient_payload).email
|
||||
for r in result.recipients
|
||||
if r.recipient_type == RecipientType.EMAIL_MEMBER
|
||||
}
|
||||
assert member_emails == {"member1@example.com", "member2@example.com"}
|
||||
external_payload = EmailExternalRecipientPayload.model_validate_json(
|
||||
next(r for r in result.recipients if r.recipient_type == RecipientType.EMAIL_EXTERNAL).recipient_payload
|
||||
)
|
||||
assert external_payload.email == "external@example.com"
|
||||
|
||||
|
||||
def _make_form_definition() -> str:
|
||||
return FormDefinition(
|
||||
form_content="hello",
|
||||
inputs=[],
|
||||
user_actions=[UserAction(id="submit", title="Submit")],
|
||||
rendered_content="<p>hello</p>",
|
||||
expiration_time=datetime.utcnow(),
|
||||
).model_dump_json()
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _DummyForm:
|
||||
id: str
|
||||
workflow_run_id: str
|
||||
node_id: str
|
||||
tenant_id: str
|
||||
app_id: str
|
||||
form_definition: str
|
||||
rendered_content: str
|
||||
expiration_time: datetime
|
||||
form_kind: HumanInputFormKind = HumanInputFormKind.RUNTIME
|
||||
created_at: datetime = dataclasses.field(default_factory=naive_utc_now)
|
||||
selected_action_id: str | None = None
|
||||
submitted_data: str | None = None
|
||||
submitted_at: datetime | None = None
|
||||
submission_user_id: str | None = None
|
||||
submission_end_user_id: str | None = None
|
||||
completed_by_recipient_id: str | None = None
|
||||
status: HumanInputFormStatus = HumanInputFormStatus.WAITING
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _DummyRecipient:
|
||||
id: str
|
||||
form_id: str
|
||||
recipient_type: RecipientType
|
||||
access_token: str
|
||||
form: _DummyForm | None = None
|
||||
|
||||
|
||||
class _FakeScalarResult:
|
||||
def __init__(self, obj):
|
||||
self._obj = obj
|
||||
|
||||
def first(self):
|
||||
if isinstance(self._obj, list):
|
||||
return self._obj[0] if self._obj else None
|
||||
return self._obj
|
||||
|
||||
def all(self):
|
||||
if isinstance(self._obj, list):
|
||||
return list(self._obj)
|
||||
if self._obj is None:
|
||||
return []
|
||||
return [self._obj]
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
scalars_result=None,
|
||||
scalars_results: list[object] | None = None,
|
||||
forms: dict[str, _DummyForm] | None = None,
|
||||
recipients: dict[str, _DummyRecipient] | None = None,
|
||||
):
|
||||
if scalars_results is not None:
|
||||
self._scalars_queue = list(scalars_results)
|
||||
elif scalars_result is not None:
|
||||
self._scalars_queue = [scalars_result]
|
||||
else:
|
||||
self._scalars_queue = []
|
||||
self.forms = forms or {}
|
||||
self.recipients = recipients or {}
|
||||
|
||||
def scalars(self, _query):
|
||||
if self._scalars_queue:
|
||||
result = self._scalars_queue.pop(0)
|
||||
else:
|
||||
result = None
|
||||
return _FakeScalarResult(result)
|
||||
|
||||
def get(self, model_cls, obj_id): # type: ignore[no-untyped-def]
|
||||
if getattr(model_cls, "__name__", None) == "HumanInputForm":
|
||||
return self.forms.get(obj_id)
|
||||
if getattr(model_cls, "__name__", None) == "HumanInputFormRecipient":
|
||||
return self.recipients.get(obj_id)
|
||||
return None
|
||||
|
||||
def add(self, _obj):
|
||||
return None
|
||||
|
||||
def flush(self):
|
||||
return None
|
||||
|
||||
def refresh(self, _obj):
|
||||
return None
|
||||
|
||||
def begin(self):
|
||||
return self
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return None
|
||||
|
||||
|
||||
def _session_factory(session: _FakeSession):
|
||||
class _SessionContext:
|
||||
def __enter__(self):
|
||||
return session
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return None
|
||||
|
||||
def _factory(*_args, **_kwargs):
|
||||
return _SessionContext()
|
||||
|
||||
return _factory
|
||||
|
||||
|
||||
class TestHumanInputFormRepositoryImplPublicMethods:
|
||||
def test_get_form_returns_entity_and_recipients(self):
|
||||
form = _DummyForm(
|
||||
id="form-1",
|
||||
workflow_run_id="run-1",
|
||||
node_id="node-1",
|
||||
tenant_id="tenant-id",
|
||||
app_id="app-id",
|
||||
form_definition=_make_form_definition(),
|
||||
rendered_content="<p>hello</p>",
|
||||
expiration_time=naive_utc_now(),
|
||||
)
|
||||
recipient = _DummyRecipient(
|
||||
id="recipient-1",
|
||||
form_id=form.id,
|
||||
recipient_type=RecipientType.STANDALONE_WEB_APP,
|
||||
access_token="token-123",
|
||||
)
|
||||
session = _FakeSession(scalars_results=[form, [recipient]])
|
||||
repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id")
|
||||
|
||||
entity = repo.get_form(form.workflow_run_id, form.node_id)
|
||||
|
||||
assert entity is not None
|
||||
assert entity.id == form.id
|
||||
assert entity.web_app_token == "token-123"
|
||||
assert len(entity.recipients) == 1
|
||||
assert entity.recipients[0].token == "token-123"
|
||||
|
||||
def test_get_form_returns_none_when_missing(self):
|
||||
session = _FakeSession(scalars_results=[None])
|
||||
repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id")
|
||||
|
||||
assert repo.get_form("run-1", "node-1") is None
|
||||
|
||||
def test_get_form_returns_unsubmitted_state(self):
|
||||
form = _DummyForm(
|
||||
id="form-1",
|
||||
workflow_run_id="run-1",
|
||||
node_id="node-1",
|
||||
tenant_id="tenant-id",
|
||||
app_id="app-id",
|
||||
form_definition=_make_form_definition(),
|
||||
rendered_content="<p>hello</p>",
|
||||
expiration_time=naive_utc_now(),
|
||||
)
|
||||
session = _FakeSession(scalars_results=[form, []])
|
||||
repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id")
|
||||
|
||||
entity = repo.get_form(form.workflow_run_id, form.node_id)
|
||||
|
||||
assert entity is not None
|
||||
assert entity.submitted is False
|
||||
assert entity.selected_action_id is None
|
||||
assert entity.submitted_data is None
|
||||
|
||||
def test_get_form_returns_submission_when_completed(self):
|
||||
form = _DummyForm(
|
||||
id="form-1",
|
||||
workflow_run_id="run-1",
|
||||
node_id="node-1",
|
||||
tenant_id="tenant-id",
|
||||
app_id="app-id",
|
||||
form_definition=_make_form_definition(),
|
||||
rendered_content="<p>hello</p>",
|
||||
expiration_time=naive_utc_now(),
|
||||
selected_action_id="approve",
|
||||
submitted_data='{"field": "value"}',
|
||||
submitted_at=naive_utc_now(),
|
||||
)
|
||||
session = _FakeSession(scalars_results=[form, []])
|
||||
repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id")
|
||||
|
||||
entity = repo.get_form(form.workflow_run_id, form.node_id)
|
||||
|
||||
assert entity is not None
|
||||
assert entity.submitted is True
|
||||
assert entity.selected_action_id == "approve"
|
||||
assert entity.submitted_data == {"field": "value"}
|
||||
|
||||
|
||||
class TestHumanInputFormSubmissionRepository:
|
||||
def test_get_by_token_returns_record(self):
|
||||
form = _DummyForm(
|
||||
id="form-1",
|
||||
workflow_run_id="run-1",
|
||||
node_id="node-1",
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
form_definition=_make_form_definition(),
|
||||
rendered_content="<p>hello</p>",
|
||||
expiration_time=naive_utc_now(),
|
||||
)
|
||||
recipient = _DummyRecipient(
|
||||
id="recipient-1",
|
||||
form_id=form.id,
|
||||
recipient_type=RecipientType.STANDALONE_WEB_APP,
|
||||
access_token="token-123",
|
||||
form=form,
|
||||
)
|
||||
session = _FakeSession(scalars_result=recipient)
|
||||
repo = HumanInputFormSubmissionRepository(_session_factory(session))
|
||||
|
||||
record = repo.get_by_token("token-123")
|
||||
|
||||
assert record is not None
|
||||
assert record.form_id == form.id
|
||||
assert record.recipient_type == RecipientType.STANDALONE_WEB_APP
|
||||
assert record.submitted is False
|
||||
|
||||
def test_get_by_form_id_and_recipient_type_uses_recipient(self):
|
||||
form = _DummyForm(
|
||||
id="form-1",
|
||||
workflow_run_id="run-1",
|
||||
node_id="node-1",
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
form_definition=_make_form_definition(),
|
||||
rendered_content="<p>hello</p>",
|
||||
expiration_time=naive_utc_now(),
|
||||
)
|
||||
recipient = _DummyRecipient(
|
||||
id="recipient-1",
|
||||
form_id=form.id,
|
||||
recipient_type=RecipientType.STANDALONE_WEB_APP,
|
||||
access_token="token-123",
|
||||
form=form,
|
||||
)
|
||||
session = _FakeSession(scalars_result=recipient)
|
||||
repo = HumanInputFormSubmissionRepository(_session_factory(session))
|
||||
|
||||
record = repo.get_by_form_id_and_recipient_type(
|
||||
form_id=form.id,
|
||||
recipient_type=RecipientType.STANDALONE_WEB_APP,
|
||||
)
|
||||
|
||||
assert record is not None
|
||||
assert record.recipient_id == recipient.id
|
||||
assert record.access_token == recipient.access_token
|
||||
|
||||
def test_mark_submitted_updates_fields(self, monkeypatch: pytest.MonkeyPatch):
|
||||
fixed_now = datetime(2024, 1, 1, 0, 0, 0)
|
||||
monkeypatch.setattr("core.repositories.human_input_repository.naive_utc_now", lambda: fixed_now)
|
||||
|
||||
form = _DummyForm(
|
||||
id="form-1",
|
||||
workflow_run_id="run-1",
|
||||
node_id="node-1",
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
form_definition=_make_form_definition(),
|
||||
rendered_content="<p>hello</p>",
|
||||
expiration_time=fixed_now,
|
||||
)
|
||||
recipient = _DummyRecipient(
|
||||
id="recipient-1",
|
||||
form_id="form-1",
|
||||
recipient_type=RecipientType.STANDALONE_WEB_APP,
|
||||
access_token="token-123",
|
||||
)
|
||||
session = _FakeSession(
|
||||
forms={form.id: form},
|
||||
recipients={recipient.id: recipient},
|
||||
)
|
||||
repo = HumanInputFormSubmissionRepository(_session_factory(session))
|
||||
|
||||
record: HumanInputFormRecord = repo.mark_submitted(
|
||||
form_id=form.id,
|
||||
recipient_id=recipient.id,
|
||||
selected_action_id="approve",
|
||||
form_data={"field": "value"},
|
||||
submission_user_id="user-1",
|
||||
submission_end_user_id="end-user-1",
|
||||
)
|
||||
|
||||
assert form.selected_action_id == "approve"
|
||||
assert form.completed_by_recipient_id == recipient.id
|
||||
assert form.submission_user_id == "user-1"
|
||||
assert form.submission_end_user_id == "end-user-1"
|
||||
assert form.submitted_at == fixed_now
|
||||
assert record.submitted is True
|
||||
assert record.selected_action_id == "approve"
|
||||
assert record.submitted_data == {"field": "value"}
|
||||
@ -0,0 +1,33 @@
|
||||
import pytest
|
||||
|
||||
from core.tools.errors import WorkflowToolHumanInputNotSupportedError
|
||||
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
|
||||
|
||||
|
||||
def test_ensure_no_human_input_nodes_passes_for_non_human_input():
|
||||
graph = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "start_node",
|
||||
"data": {"type": "start"},
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
WorkflowToolConfigurationUtils.ensure_no_human_input_nodes(graph)
|
||||
|
||||
|
||||
def test_ensure_no_human_input_nodes_raises_for_human_input():
|
||||
graph = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "human_input_node",
|
||||
"data": {"type": "human-input"},
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
with pytest.raises(WorkflowToolHumanInputNotSupportedError) as exc_info:
|
||||
WorkflowToolConfigurationUtils.ensure_no_human_input_nodes(graph)
|
||||
|
||||
assert exc_info.value.error_code == "workflow_tool_human_input_not_supported"
|
||||
@ -55,6 +55,43 @@ def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_fiel
|
||||
assert exc_info.value.args == ("oops",)
|
||||
|
||||
|
||||
def test_workflow_tool_does_not_use_pause_state_config(monkeypatch: pytest.MonkeyPatch):
|
||||
entity = ToolEntity(
|
||||
identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"),
|
||||
parameters=[],
|
||||
description=None,
|
||||
has_runtime_parameters=False,
|
||||
)
|
||||
runtime = ToolRuntime(tenant_id="test_tool", invoke_from=InvokeFrom.EXPLORE)
|
||||
tool = WorkflowTool(
|
||||
workflow_app_id="",
|
||||
workflow_as_tool_id="",
|
||||
version="1",
|
||||
workflow_entities={},
|
||||
workflow_call_depth=1,
|
||||
entity=entity,
|
||||
runtime=runtime,
|
||||
)
|
||||
|
||||
monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None)
|
||||
monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None)
|
||||
|
||||
from unittest.mock import MagicMock, Mock
|
||||
|
||||
mock_user = Mock()
|
||||
monkeypatch.setattr(tool, "_resolve_user", lambda *args, **kwargs: mock_user)
|
||||
|
||||
generate_mock = MagicMock(return_value={"data": {}})
|
||||
monkeypatch.setattr("core.app.apps.workflow.app_generator.WorkflowAppGenerator.generate", generate_mock)
|
||||
monkeypatch.setattr("libs.login.current_user", lambda *args, **kwargs: None)
|
||||
|
||||
list(tool.invoke("test_user", {}))
|
||||
|
||||
call_kwargs = generate_mock.call_args.kwargs
|
||||
assert "pause_state_config" in call_kwargs
|
||||
assert call_kwargs["pause_state_config"] is None
|
||||
|
||||
|
||||
def test_workflow_tool_should_generate_variable_messages_for_outputs(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Test that WorkflowTool should generate variable messages when there are outputs"""
|
||||
entity = ToolEntity(
|
||||
|
||||
@ -118,7 +118,6 @@ class TestGraphRuntimeState:
|
||||
from core.workflow.graph_engine.ready_queue import InMemoryReadyQueue
|
||||
|
||||
assert isinstance(queue, InMemoryReadyQueue)
|
||||
assert state.ready_queue is queue
|
||||
|
||||
def test_graph_execution_lazy_instantiation(self):
|
||||
state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time())
|
||||
|
||||
@ -0,0 +1,88 @@
|
||||
"""
|
||||
Tests for PauseReason discriminated union serialization/deserialization.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from core.workflow.entities.pause_reason import (
|
||||
HumanInputRequired,
|
||||
PauseReason,
|
||||
SchedulingPause,
|
||||
)
|
||||
|
||||
|
||||
class _Holder(BaseModel):
|
||||
"""Helper model that embeds PauseReason for union tests."""
|
||||
|
||||
reason: PauseReason
|
||||
|
||||
|
||||
class TestPauseReasonDiscriminator:
|
||||
"""Test suite for PauseReason union discriminator."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("dict_value", "expected"),
|
||||
[
|
||||
pytest.param(
|
||||
{
|
||||
"reason": {
|
||||
"TYPE": "human_input_required",
|
||||
"form_id": "form_id",
|
||||
"form_content": "form_content",
|
||||
"node_id": "node_id",
|
||||
"node_title": "node_title",
|
||||
},
|
||||
},
|
||||
HumanInputRequired(
|
||||
form_id="form_id",
|
||||
form_content="form_content",
|
||||
node_id="node_id",
|
||||
node_title="node_title",
|
||||
),
|
||||
id="HumanInputRequired",
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
"reason": {
|
||||
"TYPE": "scheduled_pause",
|
||||
"message": "Hold on",
|
||||
}
|
||||
},
|
||||
SchedulingPause(message="Hold on"),
|
||||
id="SchedulingPause",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_model_validate(self, dict_value, expected):
|
||||
"""Ensure scheduled pause payloads with lowercase TYPE deserialize."""
|
||||
holder = _Holder.model_validate(dict_value)
|
||||
|
||||
assert type(holder.reason) == type(expected)
|
||||
assert holder.reason == expected
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"reason",
|
||||
[
|
||||
HumanInputRequired(
|
||||
form_id="form_id",
|
||||
form_content="form_content",
|
||||
node_id="node_id",
|
||||
node_title="node_title",
|
||||
),
|
||||
SchedulingPause(message="Hold on"),
|
||||
],
|
||||
ids=lambda x: type(x).__name__,
|
||||
)
|
||||
def test_model_construct(self, reason):
|
||||
holder = _Holder(reason=reason)
|
||||
assert holder.reason == reason
|
||||
|
||||
def test_model_construct_with_invalid_type(self):
|
||||
with pytest.raises(ValidationError):
|
||||
holder = _Holder(reason=object()) # type: ignore
|
||||
|
||||
def test_unknown_type_fails_validation(self):
|
||||
"""Unknown TYPE values should raise a validation error."""
|
||||
with pytest.raises(ValidationError):
|
||||
_Holder.model_validate({"reason": {"TYPE": "UNKNOWN"}})
|
||||
@ -0,0 +1,131 @@
|
||||
"""Utilities for testing HumanInputNode without database dependencies."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.nodes.human_input.enums import HumanInputFormStatus
|
||||
from core.workflow.repositories.human_input_form_repository import (
|
||||
FormCreateParams,
|
||||
HumanInputFormEntity,
|
||||
HumanInputFormRecipientEntity,
|
||||
HumanInputFormRepository,
|
||||
)
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
|
||||
class _InMemoryFormRecipient(HumanInputFormRecipientEntity):
|
||||
"""Minimal recipient entity required by the repository interface."""
|
||||
|
||||
def __init__(self, recipient_id: str, token: str) -> None:
|
||||
self._id = recipient_id
|
||||
self._token = token
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
return self._id
|
||||
|
||||
@property
|
||||
def token(self) -> str:
|
||||
return self._token
|
||||
|
||||
|
||||
@dataclass
|
||||
class _InMemoryFormEntity(HumanInputFormEntity):
|
||||
form_id: str
|
||||
rendered: str
|
||||
token: str | None = None
|
||||
action_id: str | None = None
|
||||
data: Mapping[str, Any] | None = None
|
||||
is_submitted: bool = False
|
||||
status_value: HumanInputFormStatus = HumanInputFormStatus.WAITING
|
||||
expiration: datetime = naive_utc_now()
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
return self.form_id
|
||||
|
||||
@property
|
||||
def web_app_token(self) -> str | None:
|
||||
return self.token
|
||||
|
||||
@property
|
||||
def recipients(self) -> list[HumanInputFormRecipientEntity]:
|
||||
return []
|
||||
|
||||
@property
|
||||
def rendered_content(self) -> str:
|
||||
return self.rendered
|
||||
|
||||
@property
|
||||
def selected_action_id(self) -> str | None:
|
||||
return self.action_id
|
||||
|
||||
@property
|
||||
def submitted_data(self) -> Mapping[str, Any] | None:
|
||||
return self.data
|
||||
|
||||
@property
|
||||
def submitted(self) -> bool:
|
||||
return self.is_submitted
|
||||
|
||||
@property
|
||||
def status(self) -> HumanInputFormStatus:
|
||||
return self.status_value
|
||||
|
||||
@property
|
||||
def expiration_time(self) -> datetime:
|
||||
return self.expiration
|
||||
|
||||
|
||||
class InMemoryHumanInputFormRepository(HumanInputFormRepository):
|
||||
"""Pure in-memory repository used by workflow graph engine tests."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._form_counter = 0
|
||||
self.created_params: list[FormCreateParams] = []
|
||||
self.created_forms: list[_InMemoryFormEntity] = []
|
||||
self._forms_by_key: dict[tuple[str, str], _InMemoryFormEntity] = {}
|
||||
|
||||
def create_form(self, params: FormCreateParams) -> HumanInputFormEntity:
|
||||
self.created_params.append(params)
|
||||
self._form_counter += 1
|
||||
form_id = f"form-{self._form_counter}"
|
||||
token = f"console-{form_id}" if params.console_recipient_required else f"token-{form_id}"
|
||||
entity = _InMemoryFormEntity(
|
||||
form_id=form_id,
|
||||
rendered=params.rendered_content,
|
||||
token=token,
|
||||
)
|
||||
self.created_forms.append(entity)
|
||||
self._forms_by_key[(params.workflow_execution_id, params.node_id)] = entity
|
||||
return entity
|
||||
|
||||
def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None:
|
||||
return self._forms_by_key.get((workflow_execution_id, node_id))
|
||||
|
||||
# Convenience helpers for tests -------------------------------------
|
||||
|
||||
def set_submission(self, *, action_id: str, form_data: Mapping[str, Any] | None = None) -> None:
|
||||
"""Simulate a human submission for the next repository lookup."""
|
||||
|
||||
if not self.created_forms:
|
||||
raise AssertionError("no form has been created to attach submission data")
|
||||
entity = self.created_forms[-1]
|
||||
entity.action_id = action_id
|
||||
entity.data = form_data or {}
|
||||
entity.is_submitted = True
|
||||
entity.status_value = HumanInputFormStatus.SUBMITTED
|
||||
entity.expiration = naive_utc_now() + timedelta(days=1)
|
||||
|
||||
def clear_submission(self) -> None:
|
||||
if not self.created_forms:
|
||||
return
|
||||
for form in self.created_forms:
|
||||
form.action_id = None
|
||||
form.data = None
|
||||
form.is_submitted = False
|
||||
form.status_value = HumanInputFormStatus.WAITING
|
||||
@ -0,0 +1,74 @@
|
||||
import queue
|
||||
import threading
|
||||
from datetime import datetime
|
||||
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph_engine.orchestration.dispatcher import Dispatcher
|
||||
from core.workflow.graph_events import NodeRunSucceededEvent
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
|
||||
|
||||
class StubExecutionCoordinator:
|
||||
def __init__(self, paused: bool) -> None:
|
||||
self._paused = paused
|
||||
self.mark_complete_called = False
|
||||
self.failed_error: Exception | None = None
|
||||
|
||||
@property
|
||||
def aborted(self) -> bool:
|
||||
return False
|
||||
|
||||
@property
|
||||
def paused(self) -> bool:
|
||||
return self._paused
|
||||
|
||||
@property
|
||||
def execution_complete(self) -> bool:
|
||||
return False
|
||||
|
||||
def check_scaling(self) -> None:
|
||||
return None
|
||||
|
||||
def process_commands(self) -> None:
|
||||
return None
|
||||
|
||||
def mark_complete(self) -> None:
|
||||
self.mark_complete_called = True
|
||||
|
||||
def mark_failed(self, error: Exception) -> None:
|
||||
self.failed_error = error
|
||||
|
||||
|
||||
class StubEventHandler:
|
||||
def __init__(self) -> None:
|
||||
self.events: list[object] = []
|
||||
|
||||
def dispatch(self, event: object) -> None:
|
||||
self.events.append(event)
|
||||
|
||||
|
||||
def test_dispatcher_drains_events_when_paused() -> None:
|
||||
event_queue: queue.Queue = queue.Queue()
|
||||
event = NodeRunSucceededEvent(
|
||||
id="exec-1",
|
||||
node_id="node-1",
|
||||
node_type=NodeType.START,
|
||||
start_at=datetime.utcnow(),
|
||||
node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED),
|
||||
)
|
||||
event_queue.put(event)
|
||||
|
||||
handler = StubEventHandler()
|
||||
coordinator = StubExecutionCoordinator(paused=True)
|
||||
dispatcher = Dispatcher(
|
||||
event_queue=event_queue,
|
||||
event_handler=handler,
|
||||
execution_coordinator=coordinator,
|
||||
event_emitter=None,
|
||||
stop_event=threading.Event(),
|
||||
)
|
||||
|
||||
dispatcher._dispatcher_loop()
|
||||
|
||||
assert handler.events == [event]
|
||||
assert coordinator.mark_complete_called is True
|
||||
@ -2,6 +2,8 @@
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.workflow.graph_engine.command_processing.command_processor import CommandProcessor
|
||||
from core.workflow.graph_engine.domain.graph_execution import GraphExecution
|
||||
from core.workflow.graph_engine.graph_state_manager import GraphStateManager
|
||||
@ -48,3 +50,13 @@ def test_handle_pause_noop_when_execution_running() -> None:
|
||||
|
||||
worker_pool.stop.assert_not_called()
|
||||
state_manager.clear_executing.assert_not_called()
|
||||
|
||||
|
||||
def test_has_executing_nodes_requires_pause() -> None:
|
||||
graph_execution = GraphExecution(workflow_id="workflow")
|
||||
graph_execution.start()
|
||||
|
||||
coordinator, _, _ = _build_coordinator(graph_execution)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
coordinator.has_executing_nodes()
|
||||
|
||||
@ -0,0 +1,189 @@
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMMode
|
||||
from core.model_runtime.entities.message_entities import PromptMessageRole
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.enums import NodeState
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine.graph_state_manager import GraphStateManager
|
||||
from core.workflow.graph_engine.ready_queue import InMemoryReadyQueue
|
||||
from core.workflow.nodes.end.end_node import EndNode
|
||||
from core.workflow.nodes.end.entities import EndNodeData
|
||||
from core.workflow.nodes.llm.entities import (
|
||||
ContextConfig,
|
||||
LLMNodeChatModelMessage,
|
||||
LLMNodeData,
|
||||
ModelConfig,
|
||||
VisionConfig,
|
||||
)
|
||||
from core.workflow.nodes.start.entities import StartNodeData
|
||||
from core.workflow.nodes.start.start_node import StartNode
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
|
||||
from .test_mock_config import MockConfig
|
||||
from .test_mock_nodes import MockLLMNode
|
||||
|
||||
|
||||
def _build_runtime_state() -> GraphRuntimeState:
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(
|
||||
user_id="user",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
workflow_execution_id="exec-1",
|
||||
),
|
||||
user_inputs={},
|
||||
conversation_variables=[],
|
||||
)
|
||||
return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
|
||||
def _build_llm_node(
|
||||
*,
|
||||
node_id: str,
|
||||
runtime_state: GraphRuntimeState,
|
||||
graph_init_params: GraphInitParams,
|
||||
mock_config: MockConfig,
|
||||
) -> MockLLMNode:
|
||||
llm_data = LLMNodeData(
|
||||
title=f"LLM {node_id}",
|
||||
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
|
||||
prompt_template=[
|
||||
LLMNodeChatModelMessage(
|
||||
text=f"Prompt {node_id}",
|
||||
role=PromptMessageRole.USER,
|
||||
edition_type="basic",
|
||||
)
|
||||
],
|
||||
context=ContextConfig(enabled=False, variable_selector=None),
|
||||
vision=VisionConfig(enabled=False),
|
||||
reasoning_format="tagged",
|
||||
)
|
||||
llm_config = {"id": node_id, "data": llm_data.model_dump()}
|
||||
return MockLLMNode(
|
||||
id=llm_config["id"],
|
||||
config=llm_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
mock_config=mock_config,
|
||||
)
|
||||
|
||||
|
||||
def _build_graph(runtime_state: GraphRuntimeState) -> Graph:
|
||||
graph_config: dict[str, object] = {"nodes": [], "edges": []}
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
graph_config=graph_config,
|
||||
user_id="user",
|
||||
user_from="account",
|
||||
invoke_from="debugger",
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()}
|
||||
start_node = StartNode(
|
||||
id=start_config["id"],
|
||||
config=start_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
mock_config = MockConfig()
|
||||
llm_a = _build_llm_node(
|
||||
node_id="llm_a",
|
||||
runtime_state=runtime_state,
|
||||
graph_init_params=graph_init_params,
|
||||
mock_config=mock_config,
|
||||
)
|
||||
llm_b = _build_llm_node(
|
||||
node_id="llm_b",
|
||||
runtime_state=runtime_state,
|
||||
graph_init_params=graph_init_params,
|
||||
mock_config=mock_config,
|
||||
)
|
||||
|
||||
end_data = EndNodeData(title="End", outputs=[], desc=None)
|
||||
end_config = {"id": "end", "data": end_data.model_dump()}
|
||||
end_node = EndNode(
|
||||
id=end_config["id"],
|
||||
config=end_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
builder = (
|
||||
Graph.new()
|
||||
.add_root(start_node)
|
||||
.add_node(llm_a, from_node_id="start")
|
||||
.add_node(llm_b, from_node_id="start")
|
||||
.add_node(end_node, from_node_id="llm_a")
|
||||
)
|
||||
return builder.connect(tail="llm_b", head="end").build()
|
||||
|
||||
|
||||
def _edge_state_map(graph: Graph) -> Mapping[tuple[str, str, str], NodeState]:
|
||||
return {(edge.tail, edge.head, edge.source_handle): edge.state for edge in graph.edges.values()}
|
||||
|
||||
|
||||
def test_runtime_state_snapshot_restores_graph_states() -> None:
|
||||
runtime_state = _build_runtime_state()
|
||||
graph = _build_graph(runtime_state)
|
||||
runtime_state.attach_graph(graph)
|
||||
|
||||
graph.nodes["llm_a"].state = NodeState.TAKEN
|
||||
graph.nodes["llm_b"].state = NodeState.SKIPPED
|
||||
|
||||
for edge in graph.edges.values():
|
||||
if edge.tail == "start" and edge.head == "llm_a":
|
||||
edge.state = NodeState.TAKEN
|
||||
elif edge.tail == "start" and edge.head == "llm_b":
|
||||
edge.state = NodeState.SKIPPED
|
||||
elif edge.head == "end" and edge.tail == "llm_a":
|
||||
edge.state = NodeState.TAKEN
|
||||
elif edge.head == "end" and edge.tail == "llm_b":
|
||||
edge.state = NodeState.SKIPPED
|
||||
|
||||
snapshot = runtime_state.dumps()
|
||||
|
||||
resumed_state = GraphRuntimeState.from_snapshot(snapshot)
|
||||
resumed_graph = _build_graph(resumed_state)
|
||||
resumed_state.attach_graph(resumed_graph)
|
||||
|
||||
assert resumed_graph.nodes["llm_a"].state == NodeState.TAKEN
|
||||
assert resumed_graph.nodes["llm_b"].state == NodeState.SKIPPED
|
||||
assert _edge_state_map(resumed_graph) == _edge_state_map(graph)
|
||||
|
||||
|
||||
def test_join_readiness_uses_restored_edge_states() -> None:
|
||||
runtime_state = _build_runtime_state()
|
||||
graph = _build_graph(runtime_state)
|
||||
runtime_state.attach_graph(graph)
|
||||
|
||||
ready_queue = InMemoryReadyQueue()
|
||||
state_manager = GraphStateManager(graph, ready_queue)
|
||||
|
||||
for edge in graph.get_incoming_edges("end"):
|
||||
if edge.tail == "llm_a":
|
||||
edge.state = NodeState.TAKEN
|
||||
if edge.tail == "llm_b":
|
||||
edge.state = NodeState.UNKNOWN
|
||||
|
||||
assert state_manager.is_node_ready("end") is False
|
||||
|
||||
for edge in graph.get_incoming_edges("end"):
|
||||
if edge.tail == "llm_b":
|
||||
edge.state = NodeState.TAKEN
|
||||
|
||||
assert state_manager.is_node_ready("end") is True
|
||||
|
||||
snapshot = runtime_state.dumps()
|
||||
resumed_state = GraphRuntimeState.from_snapshot(snapshot)
|
||||
resumed_graph = _build_graph(resumed_state)
|
||||
resumed_state.attach_graph(resumed_graph)
|
||||
|
||||
resumed_state_manager = GraphStateManager(resumed_graph, InMemoryReadyQueue())
|
||||
assert resumed_state_manager.is_node_ready("end") is True
|
||||
@ -1,5 +1,7 @@
|
||||
import datetime
|
||||
import time
|
||||
from collections.abc import Iterable
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMMode
|
||||
from core.model_runtime.entities.message_entities import PromptMessageRole
|
||||
@ -14,11 +16,12 @@ from core.workflow.graph_events import (
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.graph_events.node import NodeRunHumanInputFormFilledEvent
|
||||
from core.workflow.nodes.base.entities import OutputVariableEntity, OutputVariableType
|
||||
from core.workflow.nodes.end.end_node import EndNode
|
||||
from core.workflow.nodes.end.entities import EndNodeData
|
||||
from core.workflow.nodes.human_input import HumanInputNode
|
||||
from core.workflow.nodes.human_input.entities import HumanInputNodeData
|
||||
from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction
|
||||
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
|
||||
from core.workflow.nodes.llm.entities import (
|
||||
ContextConfig,
|
||||
LLMNodeChatModelMessage,
|
||||
@ -28,15 +31,21 @@ from core.workflow.nodes.llm.entities import (
|
||||
)
|
||||
from core.workflow.nodes.start.entities import StartNodeData
|
||||
from core.workflow.nodes.start.start_node import StartNode
|
||||
from core.workflow.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
from .test_mock_config import MockConfig
|
||||
from .test_mock_nodes import MockLLMNode
|
||||
from .test_table_runner import TableTestRunner, WorkflowTestCase
|
||||
|
||||
|
||||
def _build_branching_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntimeState]:
|
||||
def _build_branching_graph(
|
||||
mock_config: MockConfig,
|
||||
form_repository: HumanInputFormRepository,
|
||||
graph_runtime_state: GraphRuntimeState | None = None,
|
||||
) -> tuple[Graph, GraphRuntimeState]:
|
||||
graph_config: dict[str, object] = {"nodes": [], "edges": []}
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="tenant",
|
||||
@ -49,12 +58,18 @@ def _build_branching_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntime
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"),
|
||||
user_inputs={},
|
||||
conversation_variables=[],
|
||||
)
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
if graph_runtime_state is None:
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(
|
||||
user_id="user",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
workflow_execution_id="test-execution-id",
|
||||
),
|
||||
user_inputs={},
|
||||
conversation_variables=[],
|
||||
)
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()}
|
||||
start_node = StartNode(
|
||||
@ -93,15 +108,21 @@ def _build_branching_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntime
|
||||
|
||||
human_data = HumanInputNodeData(
|
||||
title="Human Input",
|
||||
required_variables=["human.input_ready"],
|
||||
pause_reason="Awaiting human input",
|
||||
form_content="Human input required",
|
||||
inputs=[],
|
||||
user_actions=[
|
||||
UserAction(id="primary", title="Primary"),
|
||||
UserAction(id="secondary", title="Secondary"),
|
||||
],
|
||||
)
|
||||
|
||||
human_config = {"id": "human", "data": human_data.model_dump()}
|
||||
human_node = HumanInputNode(
|
||||
id=human_config["id"],
|
||||
config=human_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
form_repository=form_repository,
|
||||
)
|
||||
|
||||
llm_primary = _create_llm_node("llm_primary", "Primary LLM", "Primary stream output")
|
||||
@ -219,8 +240,18 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None:
|
||||
for scenario in branch_scenarios:
|
||||
runner = TableTestRunner()
|
||||
|
||||
def initial_graph_factory() -> tuple[Graph, GraphRuntimeState]:
|
||||
return _build_branching_graph(mock_config)
|
||||
mock_create_repo = MagicMock(spec=HumanInputFormRepository)
|
||||
mock_create_repo.get_form.return_value = None
|
||||
mock_form_entity = MagicMock(spec=HumanInputFormEntity)
|
||||
mock_form_entity.id = "test_form_id"
|
||||
mock_form_entity.web_app_token = "test_web_app_token"
|
||||
mock_form_entity.recipients = []
|
||||
mock_form_entity.rendered_content = "rendered"
|
||||
mock_form_entity.submitted = False
|
||||
mock_create_repo.create_form.return_value = mock_form_entity
|
||||
|
||||
def initial_graph_factory(mock_create_repo=mock_create_repo) -> tuple[Graph, GraphRuntimeState]:
|
||||
return _build_branching_graph(mock_config, mock_create_repo)
|
||||
|
||||
initial_case = WorkflowTestCase(
|
||||
description="HumanInput pause before branching decision",
|
||||
@ -242,23 +273,16 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None:
|
||||
assert initial_result.success, initial_result.event_mismatch_details
|
||||
assert not any(isinstance(event, NodeRunStreamChunkEvent) for event in initial_result.events)
|
||||
|
||||
graph_runtime_state = initial_result.graph_runtime_state
|
||||
graph = initial_result.graph
|
||||
assert graph_runtime_state is not None
|
||||
assert graph is not None
|
||||
|
||||
graph_runtime_state.variable_pool.add(("human", "input_ready"), True)
|
||||
graph_runtime_state.variable_pool.add(("human", "edge_source_handle"), scenario["handle"])
|
||||
graph_runtime_state.graph_execution.pause_reason = None
|
||||
|
||||
pre_chunk_count = sum(len(chunks) for _, chunks in scenario["expected_pre_chunks"])
|
||||
post_chunk_count = sum(len(chunks) for _, chunks in scenario["expected_post_chunks"])
|
||||
expected_pre_chunk_events_in_resumption = [
|
||||
GraphRunStartedEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunHumanInputFormFilledEvent,
|
||||
]
|
||||
|
||||
expected_resume_sequence: list[type] = (
|
||||
[
|
||||
GraphRunStartedEvent,
|
||||
NodeRunStartedEvent,
|
||||
]
|
||||
expected_pre_chunk_events_in_resumption
|
||||
+ [NodeRunStreamChunkEvent] * pre_chunk_count
|
||||
+ [
|
||||
NodeRunSucceededEvent,
|
||||
@ -273,11 +297,25 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None:
|
||||
]
|
||||
)
|
||||
|
||||
mock_get_repo = MagicMock(spec=HumanInputFormRepository)
|
||||
submitted_form = MagicMock(spec=HumanInputFormEntity)
|
||||
submitted_form.id = mock_form_entity.id
|
||||
submitted_form.web_app_token = mock_form_entity.web_app_token
|
||||
submitted_form.recipients = []
|
||||
submitted_form.rendered_content = mock_form_entity.rendered_content
|
||||
submitted_form.submitted = True
|
||||
submitted_form.selected_action_id = scenario["handle"]
|
||||
submitted_form.submitted_data = {}
|
||||
submitted_form.expiration_time = naive_utc_now() + datetime.timedelta(days=1)
|
||||
mock_get_repo.get_form.return_value = submitted_form
|
||||
|
||||
def resume_graph_factory(
|
||||
graph_snapshot: Graph = graph,
|
||||
state_snapshot: GraphRuntimeState = graph_runtime_state,
|
||||
initial_result=initial_result, mock_get_repo=mock_get_repo
|
||||
) -> tuple[Graph, GraphRuntimeState]:
|
||||
return graph_snapshot, state_snapshot
|
||||
assert initial_result.graph_runtime_state is not None
|
||||
serialized_runtime_state = initial_result.graph_runtime_state.dumps()
|
||||
resume_runtime_state = GraphRuntimeState.from_snapshot(serialized_runtime_state)
|
||||
return _build_branching_graph(mock_config, mock_get_repo, resume_runtime_state)
|
||||
|
||||
resume_case = WorkflowTestCase(
|
||||
description=f"HumanInput resumes via {scenario['handle']} branch",
|
||||
@ -321,7 +359,8 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None:
|
||||
for index, event in enumerate(resume_events)
|
||||
if isinstance(event, NodeRunStreamChunkEvent) and index < human_success_index
|
||||
]
|
||||
assert pre_indices == list(range(2, 2 + pre_chunk_count))
|
||||
expected_pre_chunk_events_count_in_resumption = len(expected_pre_chunk_events_in_resumption)
|
||||
assert pre_indices == list(range(expected_pre_chunk_events_count_in_resumption, human_success_index))
|
||||
|
||||
resume_chunk_indices = [
|
||||
index
|
||||
|
||||
@ -1,4 +1,6 @@
|
||||
import datetime
|
||||
import time
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMMode
|
||||
from core.model_runtime.entities.message_entities import PromptMessageRole
|
||||
@ -13,11 +15,12 @@ from core.workflow.graph_events import (
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.graph_events.node import NodeRunHumanInputFormFilledEvent
|
||||
from core.workflow.nodes.base.entities import OutputVariableEntity, OutputVariableType
|
||||
from core.workflow.nodes.end.end_node import EndNode
|
||||
from core.workflow.nodes.end.entities import EndNodeData
|
||||
from core.workflow.nodes.human_input import HumanInputNode
|
||||
from core.workflow.nodes.human_input.entities import HumanInputNodeData
|
||||
from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction
|
||||
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
|
||||
from core.workflow.nodes.llm.entities import (
|
||||
ContextConfig,
|
||||
LLMNodeChatModelMessage,
|
||||
@ -27,15 +30,21 @@ from core.workflow.nodes.llm.entities import (
|
||||
)
|
||||
from core.workflow.nodes.start.entities import StartNodeData
|
||||
from core.workflow.nodes.start.start_node import StartNode
|
||||
from core.workflow.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
from .test_mock_config import MockConfig
|
||||
from .test_mock_nodes import MockLLMNode
|
||||
from .test_table_runner import TableTestRunner, WorkflowTestCase
|
||||
|
||||
|
||||
def _build_llm_human_llm_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntimeState]:
|
||||
def _build_llm_human_llm_graph(
|
||||
mock_config: MockConfig,
|
||||
form_repository: HumanInputFormRepository,
|
||||
graph_runtime_state: GraphRuntimeState | None = None,
|
||||
) -> tuple[Graph, GraphRuntimeState]:
|
||||
graph_config: dict[str, object] = {"nodes": [], "edges": []}
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="tenant",
|
||||
@ -48,12 +57,15 @@ def _build_llm_human_llm_graph(mock_config: MockConfig) -> tuple[Graph, GraphRun
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"),
|
||||
user_inputs={},
|
||||
conversation_variables=[],
|
||||
)
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
if graph_runtime_state is None:
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(
|
||||
user_id="user", app_id="app", workflow_id="workflow", workflow_execution_id="test-execution-id,"
|
||||
),
|
||||
user_inputs={},
|
||||
conversation_variables=[],
|
||||
)
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()}
|
||||
start_node = StartNode(
|
||||
@ -92,15 +104,21 @@ def _build_llm_human_llm_graph(mock_config: MockConfig) -> tuple[Graph, GraphRun
|
||||
|
||||
human_data = HumanInputNodeData(
|
||||
title="Human Input",
|
||||
required_variables=["human.input_ready"],
|
||||
pause_reason="Awaiting human input",
|
||||
form_content="Human input required",
|
||||
inputs=[],
|
||||
user_actions=[
|
||||
UserAction(id="accept", title="Accept"),
|
||||
UserAction(id="reject", title="Reject"),
|
||||
],
|
||||
)
|
||||
|
||||
human_config = {"id": "human", "data": human_data.model_dump()}
|
||||
human_node = HumanInputNode(
|
||||
id=human_config["id"],
|
||||
config=human_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
form_repository=form_repository,
|
||||
)
|
||||
|
||||
llm_second = _create_llm_node("llm_resume", "Follow-up LLM", "Follow-up prompt")
|
||||
@ -130,7 +148,7 @@ def _build_llm_human_llm_graph(mock_config: MockConfig) -> tuple[Graph, GraphRun
|
||||
.add_root(start_node)
|
||||
.add_node(llm_first)
|
||||
.add_node(human_node)
|
||||
.add_node(llm_second)
|
||||
.add_node(llm_second, source_handle="accept")
|
||||
.add_node(end_node)
|
||||
.build()
|
||||
)
|
||||
@ -167,8 +185,18 @@ def test_human_input_llm_streaming_order_across_pause() -> None:
|
||||
GraphRunPausedEvent, # graph run pauses awaiting resume
|
||||
]
|
||||
|
||||
mock_create_repo = MagicMock(spec=HumanInputFormRepository)
|
||||
mock_create_repo.get_form.return_value = None
|
||||
mock_form_entity = MagicMock(spec=HumanInputFormEntity)
|
||||
mock_form_entity.id = "test_form_id"
|
||||
mock_form_entity.web_app_token = "test_web_app_token"
|
||||
mock_form_entity.recipients = []
|
||||
mock_form_entity.rendered_content = "rendered"
|
||||
mock_form_entity.submitted = False
|
||||
mock_create_repo.create_form.return_value = mock_form_entity
|
||||
|
||||
def graph_factory() -> tuple[Graph, GraphRuntimeState]:
|
||||
return _build_llm_human_llm_graph(mock_config)
|
||||
return _build_llm_human_llm_graph(mock_config, mock_create_repo)
|
||||
|
||||
initial_case = WorkflowTestCase(
|
||||
description="HumanInput pause preserves LLM streaming order",
|
||||
@ -210,6 +238,8 @@ def test_human_input_llm_streaming_order_across_pause() -> None:
|
||||
expected_resume_sequence: list[type] = [
|
||||
GraphRunStartedEvent, # resumed graph run begins
|
||||
NodeRunStartedEvent, # human node restarts
|
||||
# Form Filled should be generated first, then the node execution ends and stream chunk is generated.
|
||||
NodeRunHumanInputFormFilledEvent,
|
||||
NodeRunStreamChunkEvent, # cached llm_initial chunk 1
|
||||
NodeRunStreamChunkEvent, # cached llm_initial chunk 2
|
||||
NodeRunStreamChunkEvent, # cached llm_initial final chunk
|
||||
@ -225,12 +255,27 @@ def test_human_input_llm_streaming_order_across_pause() -> None:
|
||||
GraphRunSucceededEvent, # graph run succeeds after resume
|
||||
]
|
||||
|
||||
mock_get_repo = MagicMock(spec=HumanInputFormRepository)
|
||||
submitted_form = MagicMock(spec=HumanInputFormEntity)
|
||||
submitted_form.id = mock_form_entity.id
|
||||
submitted_form.web_app_token = mock_form_entity.web_app_token
|
||||
submitted_form.recipients = []
|
||||
submitted_form.rendered_content = mock_form_entity.rendered_content
|
||||
submitted_form.submitted = True
|
||||
submitted_form.selected_action_id = "accept"
|
||||
submitted_form.submitted_data = {}
|
||||
submitted_form.expiration_time = naive_utc_now() + datetime.timedelta(days=1)
|
||||
mock_get_repo.get_form.return_value = submitted_form
|
||||
|
||||
def resume_graph_factory() -> tuple[Graph, GraphRuntimeState]:
|
||||
assert graph_runtime_state is not None
|
||||
assert graph is not None
|
||||
graph_runtime_state.variable_pool.add(("human", "input_ready"), True)
|
||||
graph_runtime_state.graph_execution.pause_reason = None
|
||||
return graph, graph_runtime_state
|
||||
# restruct the graph runtime state
|
||||
serialized_runtime_state = initial_result.graph_runtime_state.dumps()
|
||||
resume_runtime_state = GraphRuntimeState.from_snapshot(serialized_runtime_state)
|
||||
return _build_llm_human_llm_graph(
|
||||
mock_config,
|
||||
mock_get_repo,
|
||||
resume_runtime_state,
|
||||
)
|
||||
|
||||
resume_case = WorkflowTestCase(
|
||||
description="HumanInput resume continues LLM streaming order",
|
||||
|
||||
@ -0,0 +1,270 @@
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Protocol
|
||||
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel
|
||||
from core.workflow.graph_engine.config import GraphEngineConfig
|
||||
from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||
from core.workflow.graph_events import (
|
||||
GraphRunPausedEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.nodes.base.entities import OutputVariableEntity
|
||||
from core.workflow.nodes.end.end_node import EndNode
|
||||
from core.workflow.nodes.end.entities import EndNodeData
|
||||
from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction
|
||||
from core.workflow.nodes.human_input.enums import HumanInputFormStatus
|
||||
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
|
||||
from core.workflow.nodes.start.entities import StartNodeData
|
||||
from core.workflow.nodes.start.start_node import StartNode
|
||||
from core.workflow.repositories.human_input_form_repository import (
|
||||
FormCreateParams,
|
||||
HumanInputFormEntity,
|
||||
HumanInputFormRepository,
|
||||
)
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
|
||||
class PauseStateStore(Protocol):
|
||||
def save(self, runtime_state: GraphRuntimeState) -> None: ...
|
||||
|
||||
def load(self) -> GraphRuntimeState: ...
|
||||
|
||||
|
||||
class InMemoryPauseStore:
|
||||
def __init__(self) -> None:
|
||||
self._snapshot: str | None = None
|
||||
|
||||
def save(self, runtime_state: GraphRuntimeState) -> None:
|
||||
self._snapshot = runtime_state.dumps()
|
||||
|
||||
def load(self) -> GraphRuntimeState:
|
||||
assert self._snapshot is not None
|
||||
return GraphRuntimeState.from_snapshot(self._snapshot)
|
||||
|
||||
|
||||
@dataclass
|
||||
class StaticForm(HumanInputFormEntity):
|
||||
form_id: str
|
||||
rendered: str
|
||||
is_submitted: bool
|
||||
action_id: str | None = None
|
||||
data: Mapping[str, Any] | None = None
|
||||
status_value: HumanInputFormStatus = HumanInputFormStatus.WAITING
|
||||
expiration: datetime = naive_utc_now() + timedelta(days=1)
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
return self.form_id
|
||||
|
||||
@property
|
||||
def web_app_token(self) -> str | None:
|
||||
return "token"
|
||||
|
||||
@property
|
||||
def recipients(self) -> list:
|
||||
return []
|
||||
|
||||
@property
|
||||
def rendered_content(self) -> str:
|
||||
return self.rendered
|
||||
|
||||
@property
|
||||
def selected_action_id(self) -> str | None:
|
||||
return self.action_id
|
||||
|
||||
@property
|
||||
def submitted_data(self) -> Mapping[str, Any] | None:
|
||||
return self.data
|
||||
|
||||
@property
|
||||
def submitted(self) -> bool:
|
||||
return self.is_submitted
|
||||
|
||||
@property
|
||||
def status(self) -> HumanInputFormStatus:
|
||||
return self.status_value
|
||||
|
||||
@property
|
||||
def expiration_time(self) -> datetime:
|
||||
return self.expiration
|
||||
|
||||
|
||||
class StaticRepo(HumanInputFormRepository):
|
||||
def __init__(self, forms_by_node_id: Mapping[str, HumanInputFormEntity]) -> None:
|
||||
self._forms_by_node_id = dict(forms_by_node_id)
|
||||
|
||||
def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None:
|
||||
return self._forms_by_node_id.get(node_id)
|
||||
|
||||
def create_form(self, params: FormCreateParams) -> HumanInputFormEntity:
|
||||
raise AssertionError("create_form should not be called in resume scenario")
|
||||
|
||||
|
||||
def _build_runtime_state() -> GraphRuntimeState:
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(
|
||||
user_id="user",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
workflow_execution_id="exec-1",
|
||||
),
|
||||
user_inputs={},
|
||||
conversation_variables=[],
|
||||
)
|
||||
return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
|
||||
def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepository) -> Graph:
|
||||
graph_config: dict[str, object] = {"nodes": [], "edges": []}
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
graph_config=graph_config,
|
||||
user_id="user",
|
||||
user_from="account",
|
||||
invoke_from="debugger",
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()}
|
||||
start_node = StartNode(
|
||||
id=start_config["id"],
|
||||
config=start_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
human_data = HumanInputNodeData(
|
||||
title="Human Input",
|
||||
form_content="Human input required",
|
||||
inputs=[],
|
||||
user_actions=[UserAction(id="approve", title="Approve")],
|
||||
)
|
||||
|
||||
human_a_config = {"id": "human_a", "data": human_data.model_dump()}
|
||||
human_a = HumanInputNode(
|
||||
id=human_a_config["id"],
|
||||
config=human_a_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
form_repository=repo,
|
||||
)
|
||||
|
||||
human_b_config = {"id": "human_b", "data": human_data.model_dump()}
|
||||
human_b = HumanInputNode(
|
||||
id=human_b_config["id"],
|
||||
config=human_b_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
form_repository=repo,
|
||||
)
|
||||
|
||||
end_data = EndNodeData(
|
||||
title="End",
|
||||
outputs=[
|
||||
OutputVariableEntity(variable="res_a", value_selector=["human_a", "__action_id"]),
|
||||
OutputVariableEntity(variable="res_b", value_selector=["human_b", "__action_id"]),
|
||||
],
|
||||
desc=None,
|
||||
)
|
||||
end_config = {"id": "end", "data": end_data.model_dump()}
|
||||
end_node = EndNode(
|
||||
id=end_config["id"],
|
||||
config=end_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
builder = (
|
||||
Graph.new()
|
||||
.add_root(start_node)
|
||||
.add_node(human_a, from_node_id="start")
|
||||
.add_node(human_b, from_node_id="start")
|
||||
.add_node(end_node, from_node_id="human_a", source_handle="approve")
|
||||
)
|
||||
return builder.connect(tail="human_b", head="end", source_handle="approve").build()
|
||||
|
||||
|
||||
def _run_graph(graph: Graph, runtime_state: GraphRuntimeState) -> list[object]:
|
||||
engine = GraphEngine(
|
||||
workflow_id="workflow",
|
||||
graph=graph,
|
||||
graph_runtime_state=runtime_state,
|
||||
command_channel=InMemoryChannel(),
|
||||
config=GraphEngineConfig(
|
||||
min_workers=2,
|
||||
max_workers=2,
|
||||
scale_up_threshold=1,
|
||||
scale_down_idle_time=30.0,
|
||||
),
|
||||
)
|
||||
return list(engine.run())
|
||||
|
||||
|
||||
def _form(submitted: bool, action_id: str | None) -> StaticForm:
|
||||
return StaticForm(
|
||||
form_id="form",
|
||||
rendered="rendered",
|
||||
is_submitted=submitted,
|
||||
action_id=action_id,
|
||||
data={},
|
||||
status_value=HumanInputFormStatus.SUBMITTED if submitted else HumanInputFormStatus.WAITING,
|
||||
)
|
||||
|
||||
|
||||
def test_parallel_human_input_join_completes_after_second_resume() -> None:
|
||||
pause_store: PauseStateStore = InMemoryPauseStore()
|
||||
|
||||
initial_state = _build_runtime_state()
|
||||
initial_repo = StaticRepo(
|
||||
{
|
||||
"human_a": _form(submitted=False, action_id=None),
|
||||
"human_b": _form(submitted=False, action_id=None),
|
||||
}
|
||||
)
|
||||
initial_graph = _build_graph(initial_state, initial_repo)
|
||||
initial_events = _run_graph(initial_graph, initial_state)
|
||||
|
||||
assert isinstance(initial_events[-1], GraphRunPausedEvent)
|
||||
pause_store.save(initial_state)
|
||||
|
||||
first_resume_state = pause_store.load()
|
||||
first_resume_repo = StaticRepo(
|
||||
{
|
||||
"human_a": _form(submitted=True, action_id="approve"),
|
||||
"human_b": _form(submitted=False, action_id=None),
|
||||
}
|
||||
)
|
||||
first_resume_graph = _build_graph(first_resume_state, first_resume_repo)
|
||||
first_resume_events = _run_graph(first_resume_graph, first_resume_state)
|
||||
|
||||
assert isinstance(first_resume_events[0], GraphRunStartedEvent)
|
||||
assert first_resume_events[0].reason is WorkflowStartReason.RESUMPTION
|
||||
assert isinstance(first_resume_events[-1], GraphRunPausedEvent)
|
||||
pause_store.save(first_resume_state)
|
||||
|
||||
second_resume_state = pause_store.load()
|
||||
second_resume_repo = StaticRepo(
|
||||
{
|
||||
"human_a": _form(submitted=True, action_id="approve"),
|
||||
"human_b": _form(submitted=True, action_id="approve"),
|
||||
}
|
||||
)
|
||||
second_resume_graph = _build_graph(second_resume_state, second_resume_repo)
|
||||
second_resume_events = _run_graph(second_resume_graph, second_resume_state)
|
||||
|
||||
assert isinstance(second_resume_events[0], GraphRunStartedEvent)
|
||||
assert second_resume_events[0].reason is WorkflowStartReason.RESUMPTION
|
||||
assert isinstance(second_resume_events[-1], GraphRunSucceededEvent)
|
||||
assert any(isinstance(event, NodeRunSucceededEvent) and event.node_id == "end" for event in second_resume_events)
|
||||
@ -0,0 +1,333 @@
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMMode
|
||||
from core.model_runtime.entities.message_entities import PromptMessageRole
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel
|
||||
from core.workflow.graph_engine.config import GraphEngineConfig
|
||||
from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||
from core.workflow.graph_events import (
|
||||
GraphRunPausedEvent,
|
||||
GraphRunStartedEvent,
|
||||
NodeRunPauseRequestedEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction
|
||||
from core.workflow.nodes.human_input.enums import HumanInputFormStatus
|
||||
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
|
||||
from core.workflow.nodes.llm.entities import (
|
||||
ContextConfig,
|
||||
LLMNodeChatModelMessage,
|
||||
LLMNodeData,
|
||||
ModelConfig,
|
||||
VisionConfig,
|
||||
)
|
||||
from core.workflow.nodes.start.entities import StartNodeData
|
||||
from core.workflow.nodes.start.start_node import StartNode
|
||||
from core.workflow.repositories.human_input_form_repository import (
|
||||
FormCreateParams,
|
||||
HumanInputFormEntity,
|
||||
HumanInputFormRepository,
|
||||
)
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
from .test_mock_config import MockConfig, NodeMockConfig
|
||||
from .test_mock_nodes import MockLLMNode
|
||||
|
||||
|
||||
@dataclass
|
||||
class StaticForm(HumanInputFormEntity):
|
||||
form_id: str
|
||||
rendered: str
|
||||
is_submitted: bool
|
||||
action_id: str | None = None
|
||||
data: Mapping[str, Any] | None = None
|
||||
status_value: HumanInputFormStatus = HumanInputFormStatus.WAITING
|
||||
expiration: datetime = naive_utc_now() + timedelta(days=1)
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
return self.form_id
|
||||
|
||||
@property
|
||||
def web_app_token(self) -> str | None:
|
||||
return "token"
|
||||
|
||||
@property
|
||||
def recipients(self) -> list:
|
||||
return []
|
||||
|
||||
@property
|
||||
def rendered_content(self) -> str:
|
||||
return self.rendered
|
||||
|
||||
@property
|
||||
def selected_action_id(self) -> str | None:
|
||||
return self.action_id
|
||||
|
||||
@property
|
||||
def submitted_data(self) -> Mapping[str, Any] | None:
|
||||
return self.data
|
||||
|
||||
@property
|
||||
def submitted(self) -> bool:
|
||||
return self.is_submitted
|
||||
|
||||
@property
|
||||
def status(self) -> HumanInputFormStatus:
|
||||
return self.status_value
|
||||
|
||||
@property
|
||||
def expiration_time(self) -> datetime:
|
||||
return self.expiration
|
||||
|
||||
|
||||
class StaticRepo(HumanInputFormRepository):
|
||||
def __init__(self, forms_by_node_id: Mapping[str, HumanInputFormEntity]) -> None:
|
||||
self._forms_by_node_id = dict(forms_by_node_id)
|
||||
|
||||
def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None:
|
||||
return self._forms_by_node_id.get(node_id)
|
||||
|
||||
def create_form(self, params: FormCreateParams) -> HumanInputFormEntity:
|
||||
raise AssertionError("create_form should not be called in resume scenario")
|
||||
|
||||
|
||||
class DelayedHumanInputNode(HumanInputNode):
|
||||
def __init__(self, delay_seconds: float, **kwargs: Any) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self._delay_seconds = delay_seconds
|
||||
|
||||
def _run(self):
|
||||
if self._delay_seconds > 0:
|
||||
time.sleep(self._delay_seconds)
|
||||
yield from super()._run()
|
||||
|
||||
|
||||
def _build_runtime_state() -> GraphRuntimeState:
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(
|
||||
user_id="user",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
workflow_execution_id="exec-1",
|
||||
),
|
||||
user_inputs={},
|
||||
conversation_variables=[],
|
||||
)
|
||||
return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
|
||||
def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepository, mock_config: MockConfig) -> Graph:
|
||||
graph_config: dict[str, object] = {"nodes": [], "edges": []}
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
graph_config=graph_config,
|
||||
user_id="user",
|
||||
user_from="account",
|
||||
invoke_from="debugger",
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()}
|
||||
start_node = StartNode(
|
||||
id=start_config["id"],
|
||||
config=start_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
human_data = HumanInputNodeData(
|
||||
title="Human Input",
|
||||
form_content="Human input required",
|
||||
inputs=[],
|
||||
user_actions=[UserAction(id="approve", title="Approve")],
|
||||
)
|
||||
|
||||
human_a_config = {"id": "human_a", "data": human_data.model_dump()}
|
||||
human_a = HumanInputNode(
|
||||
id=human_a_config["id"],
|
||||
config=human_a_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
form_repository=repo,
|
||||
)
|
||||
|
||||
human_b_config = {"id": "human_b", "data": human_data.model_dump()}
|
||||
human_b = DelayedHumanInputNode(
|
||||
id=human_b_config["id"],
|
||||
config=human_b_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
form_repository=repo,
|
||||
delay_seconds=0.2,
|
||||
)
|
||||
|
||||
llm_data = LLMNodeData(
|
||||
title="LLM A",
|
||||
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
|
||||
prompt_template=[
|
||||
LLMNodeChatModelMessage(
|
||||
text="Prompt A",
|
||||
role=PromptMessageRole.USER,
|
||||
edition_type="basic",
|
||||
)
|
||||
],
|
||||
context=ContextConfig(enabled=False, variable_selector=None),
|
||||
vision=VisionConfig(enabled=False),
|
||||
reasoning_format="tagged",
|
||||
structured_output_enabled=False,
|
||||
)
|
||||
llm_config = {"id": "llm_a", "data": llm_data.model_dump()}
|
||||
llm_a = MockLLMNode(
|
||||
id=llm_config["id"],
|
||||
config=llm_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
mock_config=mock_config,
|
||||
)
|
||||
|
||||
return (
|
||||
Graph.new()
|
||||
.add_root(start_node)
|
||||
.add_node(human_a, from_node_id="start")
|
||||
.add_node(human_b, from_node_id="start")
|
||||
.add_node(llm_a, from_node_id="human_a", source_handle="approve")
|
||||
.build()
|
||||
)
|
||||
|
||||
|
||||
def test_parallel_human_input_pause_preserves_node_finished() -> None:
|
||||
runtime_state = _build_runtime_state()
|
||||
|
||||
runtime_state.graph_execution.start()
|
||||
runtime_state.register_paused_node("human_a")
|
||||
runtime_state.register_paused_node("human_b")
|
||||
|
||||
submitted = StaticForm(
|
||||
form_id="form-a",
|
||||
rendered="rendered",
|
||||
is_submitted=True,
|
||||
action_id="approve",
|
||||
data={},
|
||||
status_value=HumanInputFormStatus.SUBMITTED,
|
||||
)
|
||||
pending = StaticForm(
|
||||
form_id="form-b",
|
||||
rendered="rendered",
|
||||
is_submitted=False,
|
||||
action_id=None,
|
||||
data=None,
|
||||
status_value=HumanInputFormStatus.WAITING,
|
||||
)
|
||||
repo = StaticRepo({"human_a": submitted, "human_b": pending})
|
||||
|
||||
mock_config = MockConfig()
|
||||
mock_config.simulate_delays = True
|
||||
mock_config.set_node_config(
|
||||
"llm_a",
|
||||
NodeMockConfig(node_id="llm_a", outputs={"text": "LLM A output"}, delay=0.5),
|
||||
)
|
||||
|
||||
graph = _build_graph(runtime_state, repo, mock_config)
|
||||
engine = GraphEngine(
|
||||
workflow_id="workflow",
|
||||
graph=graph,
|
||||
graph_runtime_state=runtime_state,
|
||||
command_channel=InMemoryChannel(),
|
||||
config=GraphEngineConfig(
|
||||
min_workers=2,
|
||||
max_workers=2,
|
||||
scale_up_threshold=1,
|
||||
scale_down_idle_time=30.0,
|
||||
),
|
||||
)
|
||||
|
||||
events = list(engine.run())
|
||||
|
||||
llm_started = any(isinstance(e, NodeRunStartedEvent) and e.node_id == "llm_a" for e in events)
|
||||
llm_succeeded = any(isinstance(e, NodeRunSucceededEvent) and e.node_id == "llm_a" for e in events)
|
||||
human_b_pause = any(isinstance(e, NodeRunPauseRequestedEvent) and e.node_id == "human_b" for e in events)
|
||||
graph_paused = any(isinstance(e, GraphRunPausedEvent) for e in events)
|
||||
graph_started = any(isinstance(e, GraphRunStartedEvent) for e in events)
|
||||
|
||||
assert graph_started
|
||||
assert graph_paused
|
||||
assert human_b_pause
|
||||
assert llm_started
|
||||
assert llm_succeeded
|
||||
|
||||
|
||||
def test_parallel_human_input_pause_preserves_node_finished_after_snapshot_resume() -> None:
|
||||
base_state = _build_runtime_state()
|
||||
base_state.graph_execution.start()
|
||||
base_state.register_paused_node("human_a")
|
||||
base_state.register_paused_node("human_b")
|
||||
snapshot = base_state.dumps()
|
||||
|
||||
resumed_state = GraphRuntimeState.from_snapshot(snapshot)
|
||||
|
||||
submitted = StaticForm(
|
||||
form_id="form-a",
|
||||
rendered="rendered",
|
||||
is_submitted=True,
|
||||
action_id="approve",
|
||||
data={},
|
||||
status_value=HumanInputFormStatus.SUBMITTED,
|
||||
)
|
||||
pending = StaticForm(
|
||||
form_id="form-b",
|
||||
rendered="rendered",
|
||||
is_submitted=False,
|
||||
action_id=None,
|
||||
data=None,
|
||||
status_value=HumanInputFormStatus.WAITING,
|
||||
)
|
||||
repo = StaticRepo({"human_a": submitted, "human_b": pending})
|
||||
|
||||
mock_config = MockConfig()
|
||||
mock_config.simulate_delays = True
|
||||
mock_config.set_node_config(
|
||||
"llm_a",
|
||||
NodeMockConfig(node_id="llm_a", outputs={"text": "LLM A output"}, delay=0.5),
|
||||
)
|
||||
|
||||
graph = _build_graph(resumed_state, repo, mock_config)
|
||||
engine = GraphEngine(
|
||||
workflow_id="workflow",
|
||||
graph=graph,
|
||||
graph_runtime_state=resumed_state,
|
||||
command_channel=InMemoryChannel(),
|
||||
config=GraphEngineConfig(
|
||||
min_workers=2,
|
||||
max_workers=2,
|
||||
scale_up_threshold=1,
|
||||
scale_down_idle_time=30.0,
|
||||
),
|
||||
)
|
||||
|
||||
events = list(engine.run())
|
||||
|
||||
start_event = next(e for e in events if isinstance(e, GraphRunStartedEvent))
|
||||
assert start_event.reason is WorkflowStartReason.RESUMPTION
|
||||
|
||||
llm_started = any(isinstance(e, NodeRunStartedEvent) and e.node_id == "llm_a" for e in events)
|
||||
llm_succeeded = any(isinstance(e, NodeRunSucceededEvent) and e.node_id == "llm_a" for e in events)
|
||||
human_b_pause = any(isinstance(e, NodeRunPauseRequestedEvent) and e.node_id == "human_b" for e in events)
|
||||
graph_paused = any(isinstance(e, GraphRunPausedEvent) for e in events)
|
||||
|
||||
assert graph_paused
|
||||
assert human_b_pause
|
||||
assert llm_started
|
||||
assert llm_succeeded
|
||||
@ -0,0 +1,309 @@
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMMode
|
||||
from core.model_runtime.entities.message_entities import PromptMessageRole
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel
|
||||
from core.workflow.graph_engine.config import GraphEngineConfig
|
||||
from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||
from core.workflow.graph_events import (
|
||||
GraphRunPausedEvent,
|
||||
GraphRunStartedEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.nodes.end.end_node import EndNode
|
||||
from core.workflow.nodes.end.entities import EndNodeData
|
||||
from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction
|
||||
from core.workflow.nodes.human_input.enums import HumanInputFormStatus
|
||||
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
|
||||
from core.workflow.nodes.llm.entities import (
|
||||
ContextConfig,
|
||||
LLMNodeChatModelMessage,
|
||||
LLMNodeData,
|
||||
ModelConfig,
|
||||
VisionConfig,
|
||||
)
|
||||
from core.workflow.nodes.start.entities import StartNodeData
|
||||
from core.workflow.nodes.start.start_node import StartNode
|
||||
from core.workflow.repositories.human_input_form_repository import (
|
||||
FormCreateParams,
|
||||
HumanInputFormEntity,
|
||||
HumanInputFormRepository,
|
||||
)
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
from .test_mock_config import MockConfig, NodeMockConfig
|
||||
from .test_mock_nodes import MockLLMNode
|
||||
|
||||
|
||||
@dataclass
|
||||
class StaticForm(HumanInputFormEntity):
|
||||
form_id: str
|
||||
rendered: str
|
||||
is_submitted: bool
|
||||
action_id: str | None = None
|
||||
data: Mapping[str, Any] | None = None
|
||||
status_value: HumanInputFormStatus = HumanInputFormStatus.WAITING
|
||||
expiration: datetime = naive_utc_now() + timedelta(days=1)
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
return self.form_id
|
||||
|
||||
@property
|
||||
def web_app_token(self) -> str | None:
|
||||
return "token"
|
||||
|
||||
@property
|
||||
def recipients(self) -> list:
|
||||
return []
|
||||
|
||||
@property
|
||||
def rendered_content(self) -> str:
|
||||
return self.rendered
|
||||
|
||||
@property
|
||||
def selected_action_id(self) -> str | None:
|
||||
return self.action_id
|
||||
|
||||
@property
|
||||
def submitted_data(self) -> Mapping[str, Any] | None:
|
||||
return self.data
|
||||
|
||||
@property
|
||||
def submitted(self) -> bool:
|
||||
return self.is_submitted
|
||||
|
||||
@property
|
||||
def status(self) -> HumanInputFormStatus:
|
||||
return self.status_value
|
||||
|
||||
@property
|
||||
def expiration_time(self) -> datetime:
|
||||
return self.expiration
|
||||
|
||||
|
||||
class StaticRepo(HumanInputFormRepository):
|
||||
def __init__(self, form: HumanInputFormEntity) -> None:
|
||||
self._form = form
|
||||
|
||||
def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None:
|
||||
if node_id != "human_pause":
|
||||
return None
|
||||
return self._form
|
||||
|
||||
def create_form(self, params: FormCreateParams) -> HumanInputFormEntity:
|
||||
raise AssertionError("create_form should not be called in this test")
|
||||
|
||||
|
||||
def _build_runtime_state() -> GraphRuntimeState:
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(
|
||||
user_id="user",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
workflow_execution_id="exec-1",
|
||||
),
|
||||
user_inputs={},
|
||||
conversation_variables=[],
|
||||
)
|
||||
return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
|
||||
def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepository, mock_config: MockConfig) -> Graph:
|
||||
graph_config: dict[str, object] = {"nodes": [], "edges": []}
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
graph_config=graph_config,
|
||||
user_id="user",
|
||||
user_from="account",
|
||||
invoke_from="debugger",
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()}
|
||||
start_node = StartNode(
|
||||
id=start_config["id"],
|
||||
config=start_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
llm_a_data = LLMNodeData(
|
||||
title="LLM A",
|
||||
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
|
||||
prompt_template=[
|
||||
LLMNodeChatModelMessage(
|
||||
text="Prompt A",
|
||||
role=PromptMessageRole.USER,
|
||||
edition_type="basic",
|
||||
)
|
||||
],
|
||||
context=ContextConfig(enabled=False, variable_selector=None),
|
||||
vision=VisionConfig(enabled=False),
|
||||
reasoning_format="tagged",
|
||||
structured_output_enabled=False,
|
||||
)
|
||||
llm_a_config = {"id": "llm_a", "data": llm_a_data.model_dump()}
|
||||
llm_a = MockLLMNode(
|
||||
id=llm_a_config["id"],
|
||||
config=llm_a_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
mock_config=mock_config,
|
||||
)
|
||||
|
||||
llm_b_data = LLMNodeData(
|
||||
title="LLM B",
|
||||
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
|
||||
prompt_template=[
|
||||
LLMNodeChatModelMessage(
|
||||
text="Prompt B",
|
||||
role=PromptMessageRole.USER,
|
||||
edition_type="basic",
|
||||
)
|
||||
],
|
||||
context=ContextConfig(enabled=False, variable_selector=None),
|
||||
vision=VisionConfig(enabled=False),
|
||||
reasoning_format="tagged",
|
||||
structured_output_enabled=False,
|
||||
)
|
||||
llm_b_config = {"id": "llm_b", "data": llm_b_data.model_dump()}
|
||||
llm_b = MockLLMNode(
|
||||
id=llm_b_config["id"],
|
||||
config=llm_b_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
mock_config=mock_config,
|
||||
)
|
||||
|
||||
human_data = HumanInputNodeData(
|
||||
title="Human Input",
|
||||
form_content="Pause here",
|
||||
inputs=[],
|
||||
user_actions=[UserAction(id="approve", title="Approve")],
|
||||
)
|
||||
human_config = {"id": "human_pause", "data": human_data.model_dump()}
|
||||
human_node = HumanInputNode(
|
||||
id=human_config["id"],
|
||||
config=human_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
form_repository=repo,
|
||||
)
|
||||
|
||||
end_human_data = EndNodeData(title="End Human", outputs=[], desc=None)
|
||||
end_human_config = {"id": "end_human", "data": end_human_data.model_dump()}
|
||||
end_human = EndNode(
|
||||
id=end_human_config["id"],
|
||||
config=end_human_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
return (
|
||||
Graph.new()
|
||||
.add_root(start_node)
|
||||
.add_node(llm_a, from_node_id="start")
|
||||
.add_node(human_node, from_node_id="start")
|
||||
.add_node(llm_b, from_node_id="llm_a")
|
||||
.add_node(end_human, from_node_id="human_pause", source_handle="approve")
|
||||
.build()
|
||||
)
|
||||
|
||||
|
||||
def _get_node_started_event(events: list[object], node_id: str) -> NodeRunStartedEvent | None:
|
||||
for event in events:
|
||||
if isinstance(event, NodeRunStartedEvent) and event.node_id == node_id:
|
||||
return event
|
||||
return None
|
||||
|
||||
|
||||
def test_pause_defers_ready_nodes_until_resume() -> None:
|
||||
runtime_state = _build_runtime_state()
|
||||
|
||||
paused_form = StaticForm(
|
||||
form_id="form-pause",
|
||||
rendered="rendered",
|
||||
is_submitted=False,
|
||||
status_value=HumanInputFormStatus.WAITING,
|
||||
)
|
||||
pause_repo = StaticRepo(paused_form)
|
||||
|
||||
mock_config = MockConfig()
|
||||
mock_config.simulate_delays = True
|
||||
mock_config.set_node_config(
|
||||
"llm_a",
|
||||
NodeMockConfig(node_id="llm_a", outputs={"text": "LLM A output"}, delay=0.5),
|
||||
)
|
||||
mock_config.set_node_config(
|
||||
"llm_b",
|
||||
NodeMockConfig(node_id="llm_b", outputs={"text": "LLM B output"}, delay=0.0),
|
||||
)
|
||||
|
||||
graph = _build_graph(runtime_state, pause_repo, mock_config)
|
||||
engine = GraphEngine(
|
||||
workflow_id="workflow",
|
||||
graph=graph,
|
||||
graph_runtime_state=runtime_state,
|
||||
command_channel=InMemoryChannel(),
|
||||
config=GraphEngineConfig(
|
||||
min_workers=2,
|
||||
max_workers=2,
|
||||
scale_up_threshold=1,
|
||||
scale_down_idle_time=30.0,
|
||||
),
|
||||
)
|
||||
|
||||
paused_events = list(engine.run())
|
||||
|
||||
assert any(isinstance(e, GraphRunPausedEvent) for e in paused_events)
|
||||
assert any(isinstance(e, NodeRunSucceededEvent) and e.node_id == "llm_a" for e in paused_events)
|
||||
assert _get_node_started_event(paused_events, "llm_b") is None
|
||||
|
||||
snapshot = runtime_state.dumps()
|
||||
resumed_state = GraphRuntimeState.from_snapshot(snapshot)
|
||||
|
||||
submitted_form = StaticForm(
|
||||
form_id="form-pause",
|
||||
rendered="rendered",
|
||||
is_submitted=True,
|
||||
action_id="approve",
|
||||
data={},
|
||||
status_value=HumanInputFormStatus.SUBMITTED,
|
||||
)
|
||||
resume_repo = StaticRepo(submitted_form)
|
||||
|
||||
resumed_graph = _build_graph(resumed_state, resume_repo, mock_config)
|
||||
resumed_engine = GraphEngine(
|
||||
workflow_id="workflow",
|
||||
graph=resumed_graph,
|
||||
graph_runtime_state=resumed_state,
|
||||
command_channel=InMemoryChannel(),
|
||||
config=GraphEngineConfig(
|
||||
min_workers=2,
|
||||
max_workers=2,
|
||||
scale_up_threshold=1,
|
||||
scale_down_idle_time=30.0,
|
||||
),
|
||||
)
|
||||
|
||||
resumed_events = list(resumed_engine.run())
|
||||
|
||||
start_event = next(e for e in resumed_events if isinstance(e, GraphRunStartedEvent))
|
||||
assert start_event.reason is WorkflowStartReason.RESUMPTION
|
||||
|
||||
llm_b_started = _get_node_started_event(resumed_events, "llm_b")
|
||||
assert llm_b_started is not None
|
||||
assert any(isinstance(e, NodeRunSucceededEvent) and e.node_id == "llm_b" for e in resumed_events)
|
||||
@ -0,0 +1,217 @@
|
||||
import datetime
|
||||
import time
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel
|
||||
from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||
from core.workflow.graph_events import (
|
||||
GraphEngineEvent,
|
||||
GraphRunPausedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.graph_events.graph import GraphRunStartedEvent
|
||||
from core.workflow.nodes.base.entities import OutputVariableEntity
|
||||
from core.workflow.nodes.end.end_node import EndNode
|
||||
from core.workflow.nodes.end.entities import EndNodeData
|
||||
from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction
|
||||
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
|
||||
from core.workflow.nodes.start.entities import StartNodeData
|
||||
from core.workflow.nodes.start.start_node import StartNode
|
||||
from core.workflow.repositories.human_input_form_repository import (
|
||||
HumanInputFormEntity,
|
||||
HumanInputFormRepository,
|
||||
)
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
|
||||
def _build_runtime_state() -> GraphRuntimeState:
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(
|
||||
user_id="user",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
workflow_execution_id="test-execution-id",
|
||||
),
|
||||
user_inputs={},
|
||||
conversation_variables=[],
|
||||
)
|
||||
return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
|
||||
def _mock_form_repository_with_submission(action_id: str) -> HumanInputFormRepository:
|
||||
repo = MagicMock(spec=HumanInputFormRepository)
|
||||
form_entity = MagicMock(spec=HumanInputFormEntity)
|
||||
form_entity.id = "test-form-id"
|
||||
form_entity.web_app_token = "test-form-token"
|
||||
form_entity.recipients = []
|
||||
form_entity.rendered_content = "rendered"
|
||||
form_entity.submitted = True
|
||||
form_entity.selected_action_id = action_id
|
||||
form_entity.submitted_data = {}
|
||||
form_entity.expiration_time = naive_utc_now() + datetime.timedelta(days=1)
|
||||
repo.get_form.return_value = form_entity
|
||||
return repo
|
||||
|
||||
|
||||
def _mock_form_repository_without_submission() -> HumanInputFormRepository:
|
||||
repo = MagicMock(spec=HumanInputFormRepository)
|
||||
form_entity = MagicMock(spec=HumanInputFormEntity)
|
||||
form_entity.id = "test-form-id"
|
||||
form_entity.web_app_token = "test-form-token"
|
||||
form_entity.recipients = []
|
||||
form_entity.rendered_content = "rendered"
|
||||
form_entity.submitted = False
|
||||
repo.create_form.return_value = form_entity
|
||||
repo.get_form.return_value = None
|
||||
return repo
|
||||
|
||||
|
||||
def _build_human_input_graph(
|
||||
runtime_state: GraphRuntimeState,
|
||||
form_repository: HumanInputFormRepository,
|
||||
) -> Graph:
|
||||
graph_config: dict[str, object] = {"nodes": [], "edges": []}
|
||||
params = GraphInitParams(
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
graph_config=graph_config,
|
||||
user_id="user",
|
||||
user_from="account",
|
||||
invoke_from="service-api",
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
start_data = StartNodeData(title="start", variables=[])
|
||||
start_node = StartNode(
|
||||
id="start",
|
||||
config={"id": "start", "data": start_data.model_dump()},
|
||||
graph_init_params=params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
human_data = HumanInputNodeData(
|
||||
title="human",
|
||||
form_content="Awaiting human input",
|
||||
inputs=[],
|
||||
user_actions=[
|
||||
UserAction(id="continue", title="Continue"),
|
||||
],
|
||||
)
|
||||
human_node = HumanInputNode(
|
||||
id="human",
|
||||
config={"id": "human", "data": human_data.model_dump()},
|
||||
graph_init_params=params,
|
||||
graph_runtime_state=runtime_state,
|
||||
form_repository=form_repository,
|
||||
)
|
||||
|
||||
end_data = EndNodeData(
|
||||
title="end",
|
||||
outputs=[
|
||||
OutputVariableEntity(variable="result", value_selector=["human", "action_id"]),
|
||||
],
|
||||
desc=None,
|
||||
)
|
||||
end_node = EndNode(
|
||||
id="end",
|
||||
config={"id": "end", "data": end_data.model_dump()},
|
||||
graph_init_params=params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
return (
|
||||
Graph.new()
|
||||
.add_root(start_node)
|
||||
.add_node(human_node)
|
||||
.add_node(end_node, from_node_id="human", source_handle="continue")
|
||||
.build()
|
||||
)
|
||||
|
||||
|
||||
def _run_graph(graph: Graph, runtime_state: GraphRuntimeState) -> list[GraphEngineEvent]:
|
||||
engine = GraphEngine(
|
||||
workflow_id="workflow",
|
||||
graph=graph,
|
||||
graph_runtime_state=runtime_state,
|
||||
command_channel=InMemoryChannel(),
|
||||
)
|
||||
return list(engine.run())
|
||||
|
||||
|
||||
def _node_successes(events: list[GraphEngineEvent]) -> list[str]:
|
||||
return [event.node_id for event in events if isinstance(event, NodeRunSucceededEvent)]
|
||||
|
||||
|
||||
def _node_start_event(events: list[GraphEngineEvent], node_id: str) -> NodeRunStartedEvent | None:
|
||||
for event in events:
|
||||
if isinstance(event, NodeRunStartedEvent) and event.node_id == node_id:
|
||||
return event
|
||||
return None
|
||||
|
||||
|
||||
def _segment_value(variable_pool: VariablePool, selector: tuple[str, str]) -> Any:
|
||||
segment = variable_pool.get(selector)
|
||||
assert segment is not None
|
||||
return getattr(segment, "value", segment)
|
||||
|
||||
|
||||
def test_engine_resume_restores_state_and_completion():
|
||||
# Baseline run without pausing
|
||||
baseline_state = _build_runtime_state()
|
||||
baseline_repo = _mock_form_repository_with_submission(action_id="continue")
|
||||
baseline_graph = _build_human_input_graph(baseline_state, baseline_repo)
|
||||
baseline_events = _run_graph(baseline_graph, baseline_state)
|
||||
assert baseline_events
|
||||
first_paused_event = baseline_events[0]
|
||||
assert isinstance(first_paused_event, GraphRunStartedEvent)
|
||||
assert first_paused_event.reason is WorkflowStartReason.INITIAL
|
||||
assert isinstance(baseline_events[-1], GraphRunSucceededEvent)
|
||||
baseline_success_nodes = _node_successes(baseline_events)
|
||||
|
||||
# Run with pause
|
||||
paused_state = _build_runtime_state()
|
||||
pause_repo = _mock_form_repository_without_submission()
|
||||
paused_graph = _build_human_input_graph(paused_state, pause_repo)
|
||||
paused_events = _run_graph(paused_graph, paused_state)
|
||||
assert paused_events
|
||||
first_paused_event = paused_events[0]
|
||||
assert isinstance(first_paused_event, GraphRunStartedEvent)
|
||||
assert first_paused_event.reason is WorkflowStartReason.INITIAL
|
||||
assert isinstance(paused_events[-1], GraphRunPausedEvent)
|
||||
snapshot = paused_state.dumps()
|
||||
|
||||
# Resume from snapshot
|
||||
resumed_state = GraphRuntimeState.from_snapshot(snapshot)
|
||||
resume_repo = _mock_form_repository_with_submission(action_id="continue")
|
||||
resumed_graph = _build_human_input_graph(resumed_state, resume_repo)
|
||||
resumed_events = _run_graph(resumed_graph, resumed_state)
|
||||
assert resumed_events
|
||||
first_resumed_event = resumed_events[0]
|
||||
assert isinstance(first_resumed_event, GraphRunStartedEvent)
|
||||
assert first_resumed_event.reason is WorkflowStartReason.RESUMPTION
|
||||
assert isinstance(resumed_events[-1], GraphRunSucceededEvent)
|
||||
|
||||
combined_success_nodes = _node_successes(paused_events) + _node_successes(resumed_events)
|
||||
assert combined_success_nodes == baseline_success_nodes
|
||||
|
||||
paused_human_started = _node_start_event(paused_events, "human")
|
||||
resumed_human_started = _node_start_event(resumed_events, "human")
|
||||
assert paused_human_started is not None
|
||||
assert resumed_human_started is not None
|
||||
assert paused_human_started.id == resumed_human_started.id
|
||||
|
||||
assert baseline_state.outputs == resumed_state.outputs
|
||||
assert _segment_value(baseline_state.variable_pool, ("human", "__action_id")) == _segment_value(
|
||||
resumed_state.variable_pool, ("human", "__action_id")
|
||||
)
|
||||
assert baseline_state.graph_execution.completed
|
||||
assert resumed_state.graph_execution.completed
|
||||
@ -7,6 +7,7 @@ from core.workflow.nodes.base.node import Node
|
||||
# Ensures that all node classes are imported.
|
||||
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
|
||||
# Ensure `NODE_TYPE_CLASSES_MAPPING` is used and not automatically removed.
|
||||
_ = NODE_TYPE_CLASSES_MAPPING
|
||||
|
||||
|
||||
@ -45,7 +46,9 @@ def test_ensure_subclasses_of_base_node_has_node_type_and_version_method_defined
|
||||
assert isinstance(cls.node_type, NodeType)
|
||||
assert isinstance(node_version, str)
|
||||
node_type_and_version = (node_type, node_version)
|
||||
assert node_type_and_version not in type_version_set
|
||||
assert node_type_and_version not in type_version_set, (
|
||||
f"Duplicate node type and version for class: {cls=} {node_type_and_version=}"
|
||||
)
|
||||
type_version_set.add(node_type_and_version)
|
||||
|
||||
|
||||
|
||||
@ -0,0 +1 @@
|
||||
# Unit tests for human input node
|
||||
@ -0,0 +1,16 @@
|
||||
from core.workflow.nodes.human_input.entities import EmailDeliveryConfig, EmailRecipients
|
||||
from core.workflow.runtime import VariablePool
|
||||
|
||||
|
||||
def test_render_body_template_replaces_variable_values():
|
||||
config = EmailDeliveryConfig(
|
||||
recipients=EmailRecipients(),
|
||||
subject="Subject",
|
||||
body="Hello {{#node1.value#}} {{#url#}}",
|
||||
)
|
||||
variable_pool = VariablePool()
|
||||
variable_pool.add(["node1", "value"], "World")
|
||||
|
||||
result = config.render_body_template(body=config.body, url="https://example.com", variable_pool=variable_pool)
|
||||
|
||||
assert result == "Hello World https://example.com"
|
||||
@ -0,0 +1,597 @@
|
||||
"""
|
||||
Unit tests for human input node entities.
|
||||
"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.node_events import PauseRequestedEvent
|
||||
from core.workflow.node_events.node import StreamCompletedEvent
|
||||
from core.workflow.nodes.human_input.entities import (
|
||||
EmailDeliveryConfig,
|
||||
EmailDeliveryMethod,
|
||||
EmailRecipients,
|
||||
ExternalRecipient,
|
||||
FormInput,
|
||||
FormInputDefault,
|
||||
HumanInputNodeData,
|
||||
MemberRecipient,
|
||||
UserAction,
|
||||
WebAppDeliveryMethod,
|
||||
_WebAppDeliveryConfig,
|
||||
)
|
||||
from core.workflow.nodes.human_input.enums import (
|
||||
ButtonStyle,
|
||||
DeliveryMethodType,
|
||||
EmailRecipientType,
|
||||
FormInputType,
|
||||
PlaceholderType,
|
||||
TimeoutUnit,
|
||||
)
|
||||
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
|
||||
from core.workflow.repositories.human_input_form_repository import HumanInputFormRepository
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from tests.unit_tests.core.workflow.graph_engine.human_input_test_utils import InMemoryHumanInputFormRepository
|
||||
|
||||
|
||||
class TestDeliveryMethod:
|
||||
"""Test DeliveryMethod entity."""
|
||||
|
||||
def test_webapp_delivery_method(self):
|
||||
"""Test webapp delivery method creation."""
|
||||
delivery_method = WebAppDeliveryMethod(enabled=True, config=_WebAppDeliveryConfig())
|
||||
|
||||
assert delivery_method.type == DeliveryMethodType.WEBAPP
|
||||
assert delivery_method.enabled is True
|
||||
assert isinstance(delivery_method.config, _WebAppDeliveryConfig)
|
||||
|
||||
def test_email_delivery_method(self):
|
||||
"""Test email delivery method creation."""
|
||||
recipients = EmailRecipients(
|
||||
whole_workspace=False,
|
||||
items=[
|
||||
MemberRecipient(type=EmailRecipientType.MEMBER, user_id="test-user-123"),
|
||||
ExternalRecipient(type=EmailRecipientType.EXTERNAL, email="test@example.com"),
|
||||
],
|
||||
)
|
||||
|
||||
config = EmailDeliveryConfig(
|
||||
recipients=recipients, subject="Test Subject", body="Test body with {{#url#}} placeholder"
|
||||
)
|
||||
|
||||
delivery_method = EmailDeliveryMethod(enabled=True, config=config)
|
||||
|
||||
assert delivery_method.type == DeliveryMethodType.EMAIL
|
||||
assert delivery_method.enabled is True
|
||||
assert isinstance(delivery_method.config, EmailDeliveryConfig)
|
||||
assert delivery_method.config.subject == "Test Subject"
|
||||
assert len(delivery_method.config.recipients.items) == 2
|
||||
|
||||
|
||||
class TestFormInput:
|
||||
"""Test FormInput entity."""
|
||||
|
||||
def test_text_input_with_constant_default(self):
|
||||
"""Test text input with constant default value."""
|
||||
default = FormInputDefault(type=PlaceholderType.CONSTANT, value="Enter your response here...")
|
||||
|
||||
form_input = FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="user_input", default=default)
|
||||
|
||||
assert form_input.type == FormInputType.TEXT_INPUT
|
||||
assert form_input.output_variable_name == "user_input"
|
||||
assert form_input.default.type == PlaceholderType.CONSTANT
|
||||
assert form_input.default.value == "Enter your response here..."
|
||||
|
||||
def test_text_input_with_variable_default(self):
|
||||
"""Test text input with variable default value."""
|
||||
default = FormInputDefault(type=PlaceholderType.VARIABLE, selector=["node_123", "output_var"])
|
||||
|
||||
form_input = FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="user_input", default=default)
|
||||
|
||||
assert form_input.default.type == PlaceholderType.VARIABLE
|
||||
assert form_input.default.selector == ["node_123", "output_var"]
|
||||
|
||||
def test_form_input_without_default(self):
|
||||
"""Test form input without default value."""
|
||||
form_input = FormInput(type=FormInputType.PARAGRAPH, output_variable_name="description")
|
||||
|
||||
assert form_input.type == FormInputType.PARAGRAPH
|
||||
assert form_input.output_variable_name == "description"
|
||||
assert form_input.default is None
|
||||
|
||||
|
||||
class TestUserAction:
|
||||
"""Test UserAction entity."""
|
||||
|
||||
def test_user_action_creation(self):
|
||||
"""Test user action creation."""
|
||||
action = UserAction(id="approve", title="Approve", button_style=ButtonStyle.PRIMARY)
|
||||
|
||||
assert action.id == "approve"
|
||||
assert action.title == "Approve"
|
||||
assert action.button_style == ButtonStyle.PRIMARY
|
||||
|
||||
def test_user_action_default_button_style(self):
|
||||
"""Test user action with default button style."""
|
||||
action = UserAction(id="cancel", title="Cancel")
|
||||
|
||||
assert action.button_style == ButtonStyle.DEFAULT
|
||||
|
||||
def test_user_action_length_boundaries(self):
|
||||
"""Test user action id and title length boundaries."""
|
||||
action = UserAction(id="a" * 20, title="b" * 20)
|
||||
|
||||
assert action.id == "a" * 20
|
||||
assert action.title == "b" * 20
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("field_name", "value"),
|
||||
[
|
||||
("id", "a" * 21),
|
||||
("title", "b" * 21),
|
||||
],
|
||||
)
|
||||
def test_user_action_length_limits(self, field_name: str, value: str):
|
||||
"""User action fields should enforce max length."""
|
||||
data = {"id": "approve", "title": "Approve"}
|
||||
data[field_name] = value
|
||||
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
UserAction(**data)
|
||||
|
||||
errors = exc_info.value.errors()
|
||||
assert any(error["loc"] == (field_name,) and error["type"] == "string_too_long" for error in errors)
|
||||
|
||||
|
||||
class TestHumanInputNodeData:
|
||||
"""Test HumanInputNodeData entity."""
|
||||
|
||||
def test_valid_node_data_creation(self):
|
||||
"""Test creating valid human input node data."""
|
||||
delivery_methods = [WebAppDeliveryMethod(enabled=True, config=_WebAppDeliveryConfig())]
|
||||
|
||||
inputs = [
|
||||
FormInput(
|
||||
type=FormInputType.TEXT_INPUT,
|
||||
output_variable_name="content",
|
||||
default=FormInputDefault(type=PlaceholderType.CONSTANT, value="Enter content..."),
|
||||
)
|
||||
]
|
||||
|
||||
user_actions = [UserAction(id="submit", title="Submit", button_style=ButtonStyle.PRIMARY)]
|
||||
|
||||
node_data = HumanInputNodeData(
|
||||
title="Human Input Test",
|
||||
desc="Test node description",
|
||||
delivery_methods=delivery_methods,
|
||||
form_content="# Test Form\n\nPlease provide input:\n\n{{#$output.content#}}",
|
||||
inputs=inputs,
|
||||
user_actions=user_actions,
|
||||
timeout=24,
|
||||
timeout_unit=TimeoutUnit.HOUR,
|
||||
)
|
||||
|
||||
assert node_data.title == "Human Input Test"
|
||||
assert node_data.desc == "Test node description"
|
||||
assert len(node_data.delivery_methods) == 1
|
||||
assert node_data.form_content.startswith("# Test Form")
|
||||
assert len(node_data.inputs) == 1
|
||||
assert len(node_data.user_actions) == 1
|
||||
assert node_data.timeout == 24
|
||||
assert node_data.timeout_unit == TimeoutUnit.HOUR
|
||||
|
||||
def test_node_data_with_multiple_delivery_methods(self):
|
||||
"""Test node data with multiple delivery methods."""
|
||||
delivery_methods = [
|
||||
WebAppDeliveryMethod(enabled=True, config=_WebAppDeliveryConfig()),
|
||||
EmailDeliveryMethod(
|
||||
enabled=False, # Disabled method should be fine
|
||||
config=EmailDeliveryConfig(
|
||||
subject="Hi there", body="", recipients=EmailRecipients(whole_workspace=True)
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
node_data = HumanInputNodeData(
|
||||
title="Test Node", delivery_methods=delivery_methods, timeout=1, timeout_unit=TimeoutUnit.DAY
|
||||
)
|
||||
|
||||
assert len(node_data.delivery_methods) == 2
|
||||
assert node_data.timeout == 1
|
||||
assert node_data.timeout_unit == TimeoutUnit.DAY
|
||||
|
||||
def test_node_data_defaults(self):
|
||||
"""Test node data with default values."""
|
||||
node_data = HumanInputNodeData(title="Test Node")
|
||||
|
||||
assert node_data.title == "Test Node"
|
||||
assert node_data.desc is None
|
||||
assert node_data.delivery_methods == []
|
||||
assert node_data.form_content == ""
|
||||
assert node_data.inputs == []
|
||||
assert node_data.user_actions == []
|
||||
assert node_data.timeout == 36
|
||||
assert node_data.timeout_unit == TimeoutUnit.HOUR
|
||||
|
||||
def test_duplicate_input_output_variable_name_raises_validation_error(self):
|
||||
"""Duplicate form input output_variable_name should raise validation error."""
|
||||
duplicate_inputs = [
|
||||
FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="content"),
|
||||
FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="content"),
|
||||
]
|
||||
|
||||
with pytest.raises(ValidationError, match="duplicated output_variable_name 'content'"):
|
||||
HumanInputNodeData(title="Test Node", inputs=duplicate_inputs)
|
||||
|
||||
def test_duplicate_user_action_ids_raise_validation_error(self):
|
||||
"""Duplicate user action ids should raise validation error."""
|
||||
duplicate_actions = [
|
||||
UserAction(id="submit", title="Submit"),
|
||||
UserAction(id="submit", title="Submit Again"),
|
||||
]
|
||||
|
||||
with pytest.raises(ValidationError, match="duplicated user action id 'submit'"):
|
||||
HumanInputNodeData(title="Test Node", user_actions=duplicate_actions)
|
||||
|
||||
def test_extract_outputs_field_names(self):
|
||||
content = r"""This is titile {{#start.title#}}
|
||||
|
||||
A content is required:
|
||||
|
||||
{{#$output.content#}}
|
||||
|
||||
A ending is required:
|
||||
|
||||
{{#$output.ending#}}
|
||||
"""
|
||||
|
||||
node_data = HumanInputNodeData(title="Human Input", form_content=content)
|
||||
field_names = node_data.outputs_field_names()
|
||||
assert field_names == ["content", "ending"]
|
||||
|
||||
|
||||
class TestRecipients:
|
||||
"""Test email recipient entities."""
|
||||
|
||||
def test_member_recipient(self):
|
||||
"""Test member recipient creation."""
|
||||
recipient = MemberRecipient(type=EmailRecipientType.MEMBER, user_id="user-123")
|
||||
|
||||
assert recipient.type == EmailRecipientType.MEMBER
|
||||
assert recipient.user_id == "user-123"
|
||||
|
||||
def test_external_recipient(self):
|
||||
"""Test external recipient creation."""
|
||||
recipient = ExternalRecipient(type=EmailRecipientType.EXTERNAL, email="test@example.com")
|
||||
|
||||
assert recipient.type == EmailRecipientType.EXTERNAL
|
||||
assert recipient.email == "test@example.com"
|
||||
|
||||
def test_email_recipients_whole_workspace(self):
|
||||
"""Test email recipients with whole workspace enabled."""
|
||||
recipients = EmailRecipients(
|
||||
whole_workspace=True, items=[MemberRecipient(type=EmailRecipientType.MEMBER, user_id="user-123")]
|
||||
)
|
||||
|
||||
assert recipients.whole_workspace is True
|
||||
assert len(recipients.items) == 1 # Items are preserved even when whole_workspace is True
|
||||
|
||||
def test_email_recipients_specific_users(self):
|
||||
"""Test email recipients with specific users."""
|
||||
recipients = EmailRecipients(
|
||||
whole_workspace=False,
|
||||
items=[
|
||||
MemberRecipient(type=EmailRecipientType.MEMBER, user_id="user-123"),
|
||||
ExternalRecipient(type=EmailRecipientType.EXTERNAL, email="external@example.com"),
|
||||
],
|
||||
)
|
||||
|
||||
assert recipients.whole_workspace is False
|
||||
assert len(recipients.items) == 2
|
||||
assert recipients.items[0].user_id == "user-123"
|
||||
assert recipients.items[1].email == "external@example.com"
|
||||
|
||||
|
||||
class TestHumanInputNodeVariableResolution:
|
||||
"""Tests for resolving variable-based defaults in HumanInputNode."""
|
||||
|
||||
def test_resolves_variable_defaults(self):
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(
|
||||
user_id="user",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
workflow_execution_id="exec-1",
|
||||
),
|
||||
user_inputs={},
|
||||
conversation_variables=[],
|
||||
)
|
||||
variable_pool.add(("start", "name"), "Jane Doe")
|
||||
runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
graph_config={"nodes": [], "edges": []},
|
||||
user_id="user",
|
||||
user_from="account",
|
||||
invoke_from="debugger",
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
node_data = HumanInputNodeData(
|
||||
title="Human Input",
|
||||
form_content="Provide your name",
|
||||
inputs=[
|
||||
FormInput(
|
||||
type=FormInputType.TEXT_INPUT,
|
||||
output_variable_name="user_name",
|
||||
default=FormInputDefault(type=PlaceholderType.VARIABLE, selector=["start", "name"]),
|
||||
),
|
||||
FormInput(
|
||||
type=FormInputType.TEXT_INPUT,
|
||||
output_variable_name="user_email",
|
||||
default=FormInputDefault(type=PlaceholderType.CONSTANT, value="foo@example.com"),
|
||||
),
|
||||
],
|
||||
user_actions=[UserAction(id="submit", title="Submit")],
|
||||
)
|
||||
config = {"id": "human", "data": node_data.model_dump()}
|
||||
|
||||
mock_repo = MagicMock(spec=HumanInputFormRepository)
|
||||
mock_repo.get_form.return_value = None
|
||||
mock_repo.create_form.return_value = SimpleNamespace(
|
||||
id="form-1",
|
||||
rendered_content="Provide your name",
|
||||
web_app_token="token",
|
||||
recipients=[],
|
||||
submitted=False,
|
||||
)
|
||||
|
||||
node = HumanInputNode(
|
||||
id=config["id"],
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
form_repository=mock_repo,
|
||||
)
|
||||
|
||||
run_result = node._run()
|
||||
pause_event = next(run_result)
|
||||
|
||||
assert isinstance(pause_event, PauseRequestedEvent)
|
||||
expected_values = {"user_name": "Jane Doe"}
|
||||
assert pause_event.reason.resolved_default_values == expected_values
|
||||
|
||||
params = mock_repo.create_form.call_args.args[0]
|
||||
assert params.resolved_default_values == expected_values
|
||||
|
||||
def test_debugger_falls_back_to_recipient_token_when_webapp_disabled(self):
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(
|
||||
user_id="user",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
workflow_execution_id="exec-2",
|
||||
),
|
||||
user_inputs={},
|
||||
conversation_variables=[],
|
||||
)
|
||||
runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
graph_config={"nodes": [], "edges": []},
|
||||
user_id="user",
|
||||
user_from="account",
|
||||
invoke_from="debugger",
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
node_data = HumanInputNodeData(
|
||||
title="Human Input",
|
||||
form_content="Provide your name",
|
||||
inputs=[],
|
||||
user_actions=[UserAction(id="submit", title="Submit")],
|
||||
)
|
||||
config = {"id": "human", "data": node_data.model_dump()}
|
||||
|
||||
mock_repo = MagicMock(spec=HumanInputFormRepository)
|
||||
mock_repo.get_form.return_value = None
|
||||
mock_repo.create_form.return_value = SimpleNamespace(
|
||||
id="form-2",
|
||||
rendered_content="Provide your name",
|
||||
web_app_token="console-token",
|
||||
recipients=[SimpleNamespace(token="recipient-token")],
|
||||
submitted=False,
|
||||
)
|
||||
|
||||
node = HumanInputNode(
|
||||
id=config["id"],
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
form_repository=mock_repo,
|
||||
)
|
||||
|
||||
run_result = node._run()
|
||||
pause_event = next(run_result)
|
||||
|
||||
assert isinstance(pause_event, PauseRequestedEvent)
|
||||
assert pause_event.reason.form_token == "console-token"
|
||||
|
||||
def test_debugger_debug_mode_overrides_email_recipients(self):
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(
|
||||
user_id="user-123",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
workflow_execution_id="exec-3",
|
||||
),
|
||||
user_inputs={},
|
||||
conversation_variables=[],
|
||||
)
|
||||
runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
graph_config={"nodes": [], "edges": []},
|
||||
user_id="user-123",
|
||||
user_from="account",
|
||||
invoke_from="debugger",
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
node_data = HumanInputNodeData(
|
||||
title="Human Input",
|
||||
form_content="Provide your name",
|
||||
inputs=[],
|
||||
user_actions=[UserAction(id="submit", title="Submit")],
|
||||
delivery_methods=[
|
||||
EmailDeliveryMethod(
|
||||
enabled=True,
|
||||
config=EmailDeliveryConfig(
|
||||
recipients=EmailRecipients(
|
||||
whole_workspace=False,
|
||||
items=[ExternalRecipient(type=EmailRecipientType.EXTERNAL, email="target@example.com")],
|
||||
),
|
||||
subject="Subject",
|
||||
body="Body",
|
||||
debug_mode=True,
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
config = {"id": "human", "data": node_data.model_dump()}
|
||||
|
||||
mock_repo = MagicMock(spec=HumanInputFormRepository)
|
||||
mock_repo.get_form.return_value = None
|
||||
mock_repo.create_form.return_value = SimpleNamespace(
|
||||
id="form-3",
|
||||
rendered_content="Provide your name",
|
||||
web_app_token="token",
|
||||
recipients=[],
|
||||
submitted=False,
|
||||
)
|
||||
|
||||
node = HumanInputNode(
|
||||
id=config["id"],
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
form_repository=mock_repo,
|
||||
)
|
||||
|
||||
run_result = node._run()
|
||||
pause_event = next(run_result)
|
||||
assert isinstance(pause_event, PauseRequestedEvent)
|
||||
|
||||
params = mock_repo.create_form.call_args.args[0]
|
||||
assert len(params.delivery_methods) == 1
|
||||
method = params.delivery_methods[0]
|
||||
assert isinstance(method, EmailDeliveryMethod)
|
||||
assert method.config.debug_mode is True
|
||||
assert method.config.recipients.whole_workspace is False
|
||||
assert len(method.config.recipients.items) == 1
|
||||
recipient = method.config.recipients.items[0]
|
||||
assert isinstance(recipient, MemberRecipient)
|
||||
assert recipient.user_id == "user-123"
|
||||
|
||||
|
||||
class TestValidation:
|
||||
"""Test validation scenarios."""
|
||||
|
||||
def test_invalid_form_input_type(self):
|
||||
"""Test validation with invalid form input type."""
|
||||
with pytest.raises(ValidationError):
|
||||
FormInput(
|
||||
type="invalid-type", # Invalid type
|
||||
output_variable_name="test",
|
||||
)
|
||||
|
||||
def test_invalid_button_style(self):
|
||||
"""Test validation with invalid button style."""
|
||||
with pytest.raises(ValidationError):
|
||||
UserAction(
|
||||
id="test",
|
||||
title="Test",
|
||||
button_style="invalid-style", # Invalid style
|
||||
)
|
||||
|
||||
def test_invalid_timeout_unit(self):
|
||||
"""Test validation with invalid timeout unit."""
|
||||
with pytest.raises(ValidationError):
|
||||
HumanInputNodeData(
|
||||
title="Test",
|
||||
timeout_unit="invalid-unit", # Invalid unit
|
||||
)
|
||||
|
||||
|
||||
class TestHumanInputNodeRenderedContent:
|
||||
"""Tests for rendering submitted content."""
|
||||
|
||||
def test_replaces_outputs_placeholders_after_submission(self):
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(
|
||||
user_id="user",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
workflow_execution_id="exec-1",
|
||||
),
|
||||
user_inputs={},
|
||||
conversation_variables=[],
|
||||
)
|
||||
runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
graph_config={"nodes": [], "edges": []},
|
||||
user_id="user",
|
||||
user_from="account",
|
||||
invoke_from="debugger",
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
node_data = HumanInputNodeData(
|
||||
title="Human Input",
|
||||
form_content="Name: {{#$output.name#}}",
|
||||
inputs=[
|
||||
FormInput(
|
||||
type=FormInputType.TEXT_INPUT,
|
||||
output_variable_name="name",
|
||||
)
|
||||
],
|
||||
user_actions=[UserAction(id="approve", title="Approve")],
|
||||
)
|
||||
config = {"id": "human", "data": node_data.model_dump()}
|
||||
|
||||
form_repository = InMemoryHumanInputFormRepository()
|
||||
node = HumanInputNode(
|
||||
id=config["id"],
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
form_repository=form_repository,
|
||||
)
|
||||
|
||||
pause_gen = node._run()
|
||||
pause_event = next(pause_gen)
|
||||
assert isinstance(pause_event, PauseRequestedEvent)
|
||||
with pytest.raises(StopIteration):
|
||||
next(pause_gen)
|
||||
|
||||
form_repository.set_submission(action_id="approve", form_data={"name": "Alice"})
|
||||
|
||||
events = list(node._run())
|
||||
last_event = events[-1]
|
||||
assert isinstance(last_event, StreamCompletedEvent)
|
||||
node_run_result = last_event.node_run_result
|
||||
assert node_run_result.outputs["__rendered_content"] == "Name: Alice"
|
||||
@ -0,0 +1,172 @@
|
||||
import datetime
|
||||
from types import SimpleNamespace
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.graph_events import (
|
||||
NodeRunHumanInputFormFilledEvent,
|
||||
NodeRunHumanInputFormTimeoutEvent,
|
||||
NodeRunStartedEvent,
|
||||
)
|
||||
from core.workflow.nodes.human_input.enums import HumanInputFormStatus
|
||||
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.enums import UserFrom
|
||||
|
||||
|
||||
class _FakeFormRepository:
|
||||
def __init__(self, form):
|
||||
self._form = form
|
||||
|
||||
def get_form(self, *_args, **_kwargs):
|
||||
return self._form
|
||||
|
||||
|
||||
def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name#}}") -> HumanInputNode:
|
||||
system_variables = SystemVariable.default()
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=system_variables, user_inputs={}, environment_variables=[]),
|
||||
start_at=0.0,
|
||||
)
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
graph_config={"nodes": [], "edges": []},
|
||||
user_id="user",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
config = {
|
||||
"id": "node-1",
|
||||
"type": NodeType.HUMAN_INPUT.value,
|
||||
"data": {
|
||||
"title": "Human Input",
|
||||
"form_content": form_content,
|
||||
"inputs": [
|
||||
{
|
||||
"type": "text_input",
|
||||
"output_variable_name": "name",
|
||||
"default": {"type": "constant", "value": ""},
|
||||
}
|
||||
],
|
||||
"user_actions": [
|
||||
{
|
||||
"id": "Accept",
|
||||
"title": "Approve",
|
||||
"button_style": "default",
|
||||
}
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
fake_form = SimpleNamespace(
|
||||
id="form-1",
|
||||
rendered_content=form_content,
|
||||
submitted=True,
|
||||
selected_action_id="Accept",
|
||||
submitted_data={"name": "Alice"},
|
||||
status=HumanInputFormStatus.SUBMITTED,
|
||||
expiration_time=naive_utc_now() + datetime.timedelta(days=1),
|
||||
)
|
||||
|
||||
repo = _FakeFormRepository(fake_form)
|
||||
return HumanInputNode(
|
||||
id="node-1",
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
form_repository=repo,
|
||||
)
|
||||
|
||||
|
||||
def _build_timeout_node() -> HumanInputNode:
|
||||
system_variables = SystemVariable.default()
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=system_variables, user_inputs={}, environment_variables=[]),
|
||||
start_at=0.0,
|
||||
)
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
graph_config={"nodes": [], "edges": []},
|
||||
user_id="user",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
config = {
|
||||
"id": "node-1",
|
||||
"type": NodeType.HUMAN_INPUT.value,
|
||||
"data": {
|
||||
"title": "Human Input",
|
||||
"form_content": "Please enter your name:\n\n{{#$output.name#}}",
|
||||
"inputs": [
|
||||
{
|
||||
"type": "text_input",
|
||||
"output_variable_name": "name",
|
||||
"default": {"type": "constant", "value": ""},
|
||||
}
|
||||
],
|
||||
"user_actions": [
|
||||
{
|
||||
"id": "Accept",
|
||||
"title": "Approve",
|
||||
"button_style": "default",
|
||||
}
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
fake_form = SimpleNamespace(
|
||||
id="form-1",
|
||||
rendered_content="content",
|
||||
submitted=False,
|
||||
selected_action_id=None,
|
||||
submitted_data=None,
|
||||
status=HumanInputFormStatus.TIMEOUT,
|
||||
expiration_time=naive_utc_now() - datetime.timedelta(minutes=1),
|
||||
)
|
||||
|
||||
repo = _FakeFormRepository(fake_form)
|
||||
return HumanInputNode(
|
||||
id="node-1",
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
form_repository=repo,
|
||||
)
|
||||
|
||||
|
||||
def test_human_input_node_emits_form_filled_event_before_succeeded():
|
||||
node = _build_node()
|
||||
|
||||
events = list(node.run())
|
||||
|
||||
assert isinstance(events[0], NodeRunStartedEvent)
|
||||
assert isinstance(events[1], NodeRunHumanInputFormFilledEvent)
|
||||
|
||||
filled_event = events[1]
|
||||
assert filled_event.node_title == "Human Input"
|
||||
assert filled_event.rendered_content.endswith("Alice")
|
||||
assert filled_event.action_id == "Accept"
|
||||
assert filled_event.action_text == "Approve"
|
||||
|
||||
|
||||
def test_human_input_node_emits_timeout_event_before_succeeded():
|
||||
node = _build_timeout_node()
|
||||
|
||||
events = list(node.run())
|
||||
|
||||
assert isinstance(events[0], NodeRunStartedEvent)
|
||||
assert isinstance(events[1], NodeRunHumanInputFormTimeoutEvent)
|
||||
|
||||
timeout_event = events[1]
|
||||
assert timeout_event.node_title == "Human Input"
|
||||
Reference in New Issue
Block a user