mirror of
https://github.com/langgenius/dify.git
synced 2026-04-22 03:37:44 +08:00
feat: Human Input Node (#32060)
The frontend and backend implementation for the human input node. Co-authored-by: twwu <twwu@dify.ai> Co-authored-by: JzoNg <jzongcode@gmail.com> Co-authored-by: yyh <92089059+lyzno1@users.noreply.github.com> Co-authored-by: zhsama <torvalds@linux.do>
This commit is contained in:
@ -0,0 +1,187 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.apps.advanced_chat import generate_task_pipeline as pipeline_module
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import QueueTextChunkEvent, QueueWorkflowPausedEvent
|
||||
from core.workflow.entities.pause_reason import HumanInputRequired
|
||||
from models.enums import MessageStatus
|
||||
from models.execution_extra_content import HumanInputContent
|
||||
from models.model import EndUser
|
||||
|
||||
|
||||
def _build_pipeline() -> pipeline_module.AdvancedChatAppGenerateTaskPipeline:
    """Create a pipeline instance without running ``__init__``.

    Only the identifier attributes that the methods under test read are set.
    """
    instance = pipeline_module.AdvancedChatAppGenerateTaskPipeline.__new__(
        pipeline_module.AdvancedChatAppGenerateTaskPipeline
    )
    instance._workflow_run_id = "run-1"
    instance._message_id = "message-1"
    instance._workflow_tenant_id = "tenant-1"
    return instance
||||
|
||||
|
||||
def test_persist_human_input_extra_content_adds_record(monkeypatch: pytest.MonkeyPatch) -> None:
    """A HumanInputContent row is added when a form id resolves and no row exists yet."""
    pipeline = _build_pipeline()
    monkeypatch.setattr(pipeline, "_load_human_input_form_id", lambda **kwargs: "form-1")

    sessions: list[mock.Mock] = []

    @contextmanager
    def recording_session():
        db_session = mock.Mock()
        # No pre-existing record for this run/message/form.
        db_session.scalar.return_value = None
        sessions.append(db_session)
        yield db_session

    pipeline._database_session = recording_session  # type: ignore[method-assign]

    pipeline._persist_human_input_extra_content(node_id="node-1")

    (db_session,) = sessions
    db_session.add.assert_called_once()
    persisted = db_session.add.call_args.args[0]
    assert isinstance(persisted, HumanInputContent)
    assert persisted.workflow_run_id == "run-1"
    assert persisted.message_id == "message-1"
    assert persisted.form_id == "form-1"
||||
|
||||
|
||||
def test_persist_human_input_extra_content_skips_when_form_missing(monkeypatch: pytest.MonkeyPatch) -> None:
    """No database session should even be opened when the node has no form id."""
    pipeline = _build_pipeline()
    monkeypatch.setattr(pipeline, "_load_human_input_form_id", lambda **kwargs: None)

    session_openings: list[bool] = []

    @contextmanager
    def tracking_session():
        session_openings.append(True)
        yield mock.Mock()

    pipeline._database_session = tracking_session  # type: ignore[method-assign]

    pipeline._persist_human_input_extra_content(node_id="node-1")

    assert not session_openings
||||
|
||||
|
||||
def test_persist_human_input_extra_content_skips_when_existing(monkeypatch: pytest.MonkeyPatch) -> None:
    """When a matching HumanInputContent row already exists, no duplicate is added."""
    pipeline = _build_pipeline()
    monkeypatch.setattr(pipeline, "_load_human_input_form_id", lambda **kwargs: "form-1")

    captured_session: dict[str, mock.Mock] = {}

    @contextmanager
    def fake_session():
        session = mock.Mock()
        # Simulate an existing record for the same run/message/form triple.
        session.scalar.return_value = HumanInputContent(
            workflow_run_id="run-1",
            message_id="message-1",
            form_id="form-1",
        )
        captured_session["session"] = session
        yield session

    pipeline._database_session = fake_session  # type: ignore[method-assign]

    pipeline._persist_human_input_extra_content(node_id="node-1")

    session = captured_session["session"]
    # The duplicate must not be persisted again.
    session.add.assert_not_called()
||||
|
||||
|
||||
def test_handle_workflow_paused_event_persists_human_input_extra_content() -> None:
    """Handling a paused event persists the human-input content and pauses the message.

    Every collaborator of ``_handle_workflow_paused_event`` is replaced with a
    mock so the test only observes the persistence call and the message status
    transition NORMAL -> PAUSED.
    """
    pipeline = _build_pipeline()
    pipeline._application_generate_entity = SimpleNamespace(task_id="task-1")
    pipeline._workflow_response_converter = mock.Mock()
    # No stream responses need to be produced for this test.
    pipeline._workflow_response_converter.workflow_pause_to_stream_response.return_value = []
    pipeline._ensure_graph_runtime_initialized = mock.Mock(
        return_value=SimpleNamespace(
            total_tokens=0,
            node_run_steps=0,
        ),
    )
    pipeline._save_message = mock.Mock()
    message = SimpleNamespace(status=MessageStatus.NORMAL)
    pipeline._get_message = mock.Mock(return_value=message)
    pipeline._persist_human_input_extra_content = mock.Mock()
    pipeline._base_task_pipeline = mock.Mock()
    pipeline._base_task_pipeline.queue_manager = mock.Mock()
    pipeline._message_saved_on_pause = False

    @contextmanager
    def fake_session():
        session = mock.Mock()
        yield session

    pipeline._database_session = fake_session  # type: ignore[method-assign]

    # A single human-input pause reason for node-1.
    reason = HumanInputRequired(
        form_id="form-1",
        form_content="content",
        inputs=[],
        actions=[],
        node_id="node-1",
        node_title="Approval",
        form_token="token-1",
        resolved_default_values={},
    )
    event = QueueWorkflowPausedEvent(reasons=[reason], outputs={}, paused_nodes=["node-1"])

    # Drain the generator so the handler runs to completion.
    list(pipeline._handle_workflow_paused_event(event))

    pipeline._persist_human_input_extra_content.assert_called_once_with(form_id="form-1", node_id="node-1")
    assert message.status == MessageStatus.PAUSED
||||
|
||||
|
||||
def test_resume_appends_chunks_to_paused_answer() -> None:
    """Resuming a paused message appends new chunks to the existing answer.

    Unlike the other tests in this module, this one constructs a real
    pipeline via ``__init__`` (with SimpleNamespace stand-ins for its
    collaborators) so the text-chunk accumulation path runs for real.
    """
    app_config = SimpleNamespace(app_id="app-1", tenant_id="tenant-1", sensitive_word_avoidance=None)
    application_generate_entity = SimpleNamespace(
        app_config=app_config,
        files=[],
        workflow_run_id="run-1",
        query="hello",
        invoke_from=InvokeFrom.WEB_APP,
        inputs={},
        task_id="task-1",
    )
    queue_manager = SimpleNamespace(graph_runtime_state=None)
    conversation = SimpleNamespace(id="conversation-1", mode="advanced-chat")
    # The message starts PAUSED with a partial answer already stored.
    message = SimpleNamespace(
        id="message-1",
        created_at=datetime(2024, 1, 1),
        query="hello",
        answer="before",
        status=MessageStatus.PAUSED,
    )
    user = EndUser()
    user.id = "user-1"
    user.session_id = "session-1"
    workflow = SimpleNamespace(id="workflow-1", tenant_id="tenant-1", features_dict={})

    pipeline = pipeline_module.AdvancedChatAppGenerateTaskPipeline(
        application_generate_entity=application_generate_entity,
        workflow=workflow,
        queue_manager=queue_manager,
        conversation=conversation,
        message=message,
        user=user,
        stream=True,
        dialogue_count=1,
        draft_var_saver_factory=SimpleNamespace(),
    )

    pipeline._get_message = mock.Mock(return_value=message)
    pipeline._recorded_files = []

    # Stream a new chunk, then save: the chunk must be appended, not replace
    # the pre-pause answer, and the status must return to NORMAL.
    list(pipeline._handle_text_chunk_event(QueueTextChunkEvent(text="after")))
    pipeline._save_message(session=mock.Mock())

    assert message.answer == "beforeafter"
    assert message.status == MessageStatus.NORMAL
||||
@ -0,0 +1,87 @@
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
|
||||
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import QueueHumanInputFormFilledEvent, QueueHumanInputFormTimeoutEvent
|
||||
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
|
||||
|
||||
def _build_converter() -> WorkflowResponseConverter:
    """Construct a minimal WorkflowResponseConverter for testing.

    Only the fields the converter reads in these tests are populated; the
    generate entity and user are SimpleNamespace stand-ins.
    """
    system_variables = SystemVariable(
        files=[],
        user_id="user-1",
        app_id="app-1",
        workflow_id="wf-1",
        workflow_execution_id="run-1",
    )
    # NOTE: the previously built GraphRuntimeState was never used — removed.
    app_entity = SimpleNamespace(
        task_id="task-1",
        app_config=SimpleNamespace(app_id="app-1", tenant_id="tenant-1"),
        invoke_from=InvokeFrom.EXPLORE,
        files=[],
        inputs={},
        workflow_execution_id="run-1",
        call_depth=0,
    )
    account = SimpleNamespace(id="acc-1", name="tester", email="tester@example.com")
    return WorkflowResponseConverter(
        application_generate_entity=app_entity,
        user=account,
        system_variables=system_variables,
    )
||||
|
||||
|
||||
def test_human_input_form_filled_stream_response_contains_rendered_content():
    """The form-filled stream response carries node info, rendered markdown, and action."""
    converter = _build_converter()
    # Prime the converter with a workflow-start so it knows the run id.
    converter.workflow_start_to_stream_response(
        task_id="task-1",
        workflow_run_id="run-1",
        workflow_id="wf-1",
        reason=WorkflowStartReason.INITIAL,
    )

    queue_event = QueueHumanInputFormFilledEvent(
        node_execution_id="exec-1",
        node_id="node-1",
        node_type="human-input",
        node_title="Human Input",
        rendered_content="# Title\nvalue",
        action_id="Approve",
        action_text="Approve",
    )

    resp = converter.human_input_form_filled_to_stream_response(event=queue_event, task_id="task-1")

    assert resp.workflow_run_id == "run-1"
    assert resp.data.node_id == "node-1"
    assert resp.data.node_title == "Human Input"
    assert resp.data.rendered_content.startswith("# Title")
    assert resp.data.action_id == "Approve"
||||
|
||||
|
||||
def test_human_input_form_timeout_stream_response_contains_timeout_metadata():
    """The timeout stream response carries node info and the expiry as a unix timestamp."""
    converter = _build_converter()
    # Prime the converter with a workflow-start so it knows the run id.
    converter.workflow_start_to_stream_response(
        task_id="task-1",
        workflow_run_id="run-1",
        workflow_id="wf-1",
        reason=WorkflowStartReason.INITIAL,
    )

    queue_event = QueueHumanInputFormTimeoutEvent(
        node_id="node-1",
        node_type="human-input",
        node_title="Human Input",
        expiration_time=datetime(2025, 1, 1, tzinfo=UTC),
    )

    resp = converter.human_input_form_timeout_to_stream_response(event=queue_event, task_id="task-1")

    assert resp.workflow_run_id == "run-1"
    assert resp.data.node_id == "node-1"
    assert resp.data.node_title == "Human Input"
    # 1735689600 == 2025-01-01T00:00:00Z as a unix epoch timestamp.
    assert resp.data.expiration_time == 1735689600
||||
@ -0,0 +1,56 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
|
||||
|
||||
def _build_converter() -> WorkflowResponseConverter:
    """Construct a minimal WorkflowResponseConverter for testing.

    Only the fields the converter reads in these tests are populated; the
    generate entity and user are SimpleNamespace stand-ins.
    """
    system_variables = SystemVariable(
        files=[],
        user_id="user-1",
        app_id="app-1",
        workflow_id="wf-1",
        workflow_execution_id="run-1",
    )
    # NOTE: the previously built GraphRuntimeState was never used — removed.
    app_entity = SimpleNamespace(
        task_id="task-1",
        app_config=SimpleNamespace(app_id="app-1", tenant_id="tenant-1"),
        invoke_from=InvokeFrom.EXPLORE,
        files=[],
        inputs={},
        workflow_execution_id="run-1",
        call_depth=0,
    )
    account = SimpleNamespace(id="acc-1", name="tester", email="tester@example.com")
    return WorkflowResponseConverter(
        application_generate_entity=app_entity,
        user=account,
        system_variables=system_variables,
    )
||||
|
||||
|
||||
def test_workflow_start_stream_response_carries_resumption_reason():
    """A RESUMPTION start reason must survive into the stream response payload."""
    response = _build_converter().workflow_start_to_stream_response(
        task_id="task-1",
        workflow_run_id="run-1",
        workflow_id="wf-1",
        reason=WorkflowStartReason.RESUMPTION,
    )
    assert response.data.reason is WorkflowStartReason.RESUMPTION
||||
|
||||
|
||||
def test_workflow_start_stream_response_carries_initial_reason():
    """An INITIAL start reason must survive into the stream response payload."""
    response = _build_converter().workflow_start_to_stream_response(
        task_id="task-1",
        workflow_run_id="run-1",
        workflow_id="wf-1",
        reason=WorkflowStartReason.INITIAL,
    )
    assert response.data.reason is WorkflowStartReason.INITIAL
||||
@ -23,6 +23,7 @@ from core.app.entities.queue_entities import (
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
)
|
||||
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
@ -124,7 +125,12 @@ class TestWorkflowResponseConverter:
|
||||
original_data = {"large_field": "x" * 10000, "metadata": "info"}
|
||||
truncated_data = {"large_field": "[TRUNCATED]", "metadata": "info"}
|
||||
|
||||
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
|
||||
converter.workflow_start_to_stream_response(
|
||||
task_id="bootstrap",
|
||||
workflow_run_id="run-id",
|
||||
workflow_id="wf-id",
|
||||
reason=WorkflowStartReason.INITIAL,
|
||||
)
|
||||
start_event = self.create_node_started_event()
|
||||
converter.workflow_node_start_to_stream_response(
|
||||
event=start_event,
|
||||
@ -160,7 +166,12 @@ class TestWorkflowResponseConverter:
|
||||
|
||||
original_data = {"small": "data"}
|
||||
|
||||
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
|
||||
converter.workflow_start_to_stream_response(
|
||||
task_id="bootstrap",
|
||||
workflow_run_id="run-id",
|
||||
workflow_id="wf-id",
|
||||
reason=WorkflowStartReason.INITIAL,
|
||||
)
|
||||
start_event = self.create_node_started_event()
|
||||
converter.workflow_node_start_to_stream_response(
|
||||
event=start_event,
|
||||
@ -191,7 +202,12 @@ class TestWorkflowResponseConverter:
|
||||
"""Test node finish response when process_data is None."""
|
||||
converter = self.create_workflow_response_converter()
|
||||
|
||||
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
|
||||
converter.workflow_start_to_stream_response(
|
||||
task_id="bootstrap",
|
||||
workflow_run_id="run-id",
|
||||
workflow_id="wf-id",
|
||||
reason=WorkflowStartReason.INITIAL,
|
||||
)
|
||||
start_event = self.create_node_started_event()
|
||||
converter.workflow_node_start_to_stream_response(
|
||||
event=start_event,
|
||||
@ -225,7 +241,12 @@ class TestWorkflowResponseConverter:
|
||||
original_data = {"large_field": "x" * 10000, "metadata": "info"}
|
||||
truncated_data = {"large_field": "[TRUNCATED]", "metadata": "info"}
|
||||
|
||||
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
|
||||
converter.workflow_start_to_stream_response(
|
||||
task_id="bootstrap",
|
||||
workflow_run_id="run-id",
|
||||
workflow_id="wf-id",
|
||||
reason=WorkflowStartReason.INITIAL,
|
||||
)
|
||||
start_event = self.create_node_started_event()
|
||||
converter.workflow_node_start_to_stream_response(
|
||||
event=start_event,
|
||||
@ -261,7 +282,12 @@ class TestWorkflowResponseConverter:
|
||||
|
||||
original_data = {"small": "data"}
|
||||
|
||||
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
|
||||
converter.workflow_start_to_stream_response(
|
||||
task_id="bootstrap",
|
||||
workflow_run_id="run-id",
|
||||
workflow_id="wf-id",
|
||||
reason=WorkflowStartReason.INITIAL,
|
||||
)
|
||||
start_event = self.create_node_started_event()
|
||||
converter.workflow_node_start_to_stream_response(
|
||||
event=start_event,
|
||||
@ -400,6 +426,7 @@ class TestWorkflowResponseConverterServiceApiTruncation:
|
||||
task_id="test-task-id",
|
||||
workflow_run_id="test-workflow-run-id",
|
||||
workflow_id="test-workflow-id",
|
||||
reason=WorkflowStartReason.INITIAL,
|
||||
)
|
||||
return converter
|
||||
|
||||
|
||||
@ -0,0 +1,139 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig
|
||||
from core.app.apps import message_based_app_generator
|
||||
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
|
||||
from core.app.task_pipeline import message_cycle_manager
|
||||
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
|
||||
from models.model import AppMode, Conversation, Message
|
||||
|
||||
|
||||
def _make_app_config() -> WorkflowUIBasedAppConfig:
    """Build a minimal advanced-chat app config used by the tests in this module."""
    config_kwargs = {
        "tenant_id": "tenant-id",
        "app_id": "app-id",
        "app_mode": AppMode.ADVANCED_CHAT,
        "workflow_id": "workflow-id",
        "additional_features": AppAdditionalFeatures(),
        "variables": [],
    }
    return WorkflowUIBasedAppConfig(**config_kwargs)
||||
|
||||
|
||||
def _make_generate_entity(app_config: WorkflowUIBasedAppConfig) -> AdvancedChatAppGenerateEntity:
    """Build a streaming web-app generate entity with no pre-existing conversation."""
    entity_kwargs = {
        "task_id": "task-id",
        "app_config": app_config,
        "file_upload_config": None,
        "conversation_id": None,
        "inputs": {},
        "query": "hello",
        "files": [],
        "parent_message_id": None,
        "user_id": "user-id",
        "stream": True,
        "invoke_from": InvokeFrom.WEB_APP,
        "extras": {},
        "workflow_run_id": "workflow-run-id",
    }
    return AdvancedChatAppGenerateEntity(**entity_kwargs)
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _mock_db_session(monkeypatch):
    """Replace the generator module's db handle with a MagicMock session.

    ``session.refresh`` simulates the database assigning primary keys: a
    Conversation or Message with ``id is None`` gets a deterministic
    generated id the assertions can rely on.
    """
    session = MagicMock()

    def refresh_side_effect(obj):
        if isinstance(obj, Conversation) and obj.id is None:
            obj.id = "generated-conversation-id"
        if isinstance(obj, Message) and obj.id is None:
            obj.id = "generated-message-id"

    session.refresh.side_effect = refresh_side_effect
    # add/commit are no-ops; only refresh has observable behavior.
    session.add.return_value = None
    session.commit.return_value = None

    monkeypatch.setattr(message_based_app_generator, "db", SimpleNamespace(session=session))
    return session
||||
|
||||
|
||||
def test_init_generate_records_sets_conversation_metadata():
    """With no prior conversation, init must create one and flag the entity as new."""
    entity = _make_generate_entity(_make_app_config())

    created_conversation, _ = AdvancedChatAppGenerator()._init_generate_records(entity, conversation=None)

    assert created_conversation.id == "generated-conversation-id"
    assert entity.conversation_id == "generated-conversation-id"
    assert entity.is_new_conversation is True
||||
|
||||
|
||||
def test_init_generate_records_marks_existing_conversation():
    """Passing an existing conversation must reuse it and clear the is-new flag."""
    app_config = _make_app_config()
    entity = _make_generate_entity(app_config)

    # A fully populated Conversation row with a known id, as if loaded from the DB.
    existing_conversation = Conversation(
        app_id=app_config.app_id,
        app_model_config_id=None,
        model_provider=None,
        override_model_configs=None,
        model_id=None,
        mode=app_config.app_mode.value,
        name="existing",
        inputs={},
        introduction="",
        system_instruction="",
        system_instruction_tokens=0,
        status="normal",
        invoke_from=InvokeFrom.WEB_APP.value,
        from_source="api",
        from_end_user_id="user-id",
        from_account_id=None,
    )
    existing_conversation.id = "existing-conversation-id"

    generator = AdvancedChatAppGenerator()

    conversation, _ = generator._init_generate_records(entity, conversation=existing_conversation)

    assert entity.conversation_id == "existing-conversation-id"
    # The very same object is returned, not a copy or a fresh row.
    assert conversation is existing_conversation
    assert entity.is_new_conversation is False
||||
|
||||
|
||||
def test_message_cycle_manager_uses_new_conversation_flag(monkeypatch):
    """generate_conversation_name starts a naming thread and consumes the is-new flag."""
    app_config = _make_app_config()
    entity = _make_generate_entity(app_config)
    entity.conversation_id = "existing-conversation-id"
    entity.is_new_conversation = True
    entity.extras = {"auto_generate_conversation_name": True}

    captured = {}

    class DummyThread:
        # Records constructor kwargs and whether start() was invoked,
        # without spawning a real thread.
        def __init__(self, **kwargs):
            self.kwargs = kwargs
            self.started = False

        def start(self):
            self.started = True

    def fake_thread(**kwargs):
        thread = DummyThread(**kwargs)
        captured["thread"] = thread
        return thread

    # Intercept the Thread class used inside message_cycle_manager.
    monkeypatch.setattr(message_cycle_manager, "Thread", fake_thread)

    manager = MessageCycleManager(application_generate_entity=entity, task_state=MagicMock())
    thread = manager.generate_conversation_name(conversation_id="existing-conversation-id", query="hello")

    assert thread is captured["thread"]
    assert thread.started is True
    # The flag is one-shot: it must be reset after the naming thread is launched.
    assert entity.is_new_conversation is False
||||
@ -0,0 +1,127 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.app_config.entities import (
|
||||
AppAdditionalFeatures,
|
||||
EasyUIBasedAppConfig,
|
||||
EasyUIBasedAppModelConfigFrom,
|
||||
ModelConfigEntity,
|
||||
PromptTemplateEntity,
|
||||
)
|
||||
from core.app.apps import message_based_app_generator
|
||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom
|
||||
from models.model import AppMode, Conversation, Message
|
||||
|
||||
|
||||
class DummyModelConf:
    """Minimal stand-in for a model configuration exposing provider and model names."""

    def __init__(self, provider: str = "mock-provider", model: str = "mock-model") -> None:
        self.model = model
        self.provider = provider
||||
|
||||
|
||||
class DummyCompletionGenerateEntity:
    """Completion-style generate entity without conversation fields.

    Uses ``__slots__`` so attributes not listed here cannot be set at all —
    this lets the tests assert via ``hasattr`` that the generator never
    touched ``conversation_id`` / ``is_new_conversation`` on a
    non-conversational entity.
    """

    __slots__ = ("app_config", "invoke_from", "user_id", "query", "inputs", "files", "model_conf")

    app_config: EasyUIBasedAppConfig
    invoke_from: InvokeFrom
    user_id: str
    query: str
    inputs: dict
    files: list
    model_conf: DummyModelConf

    def __init__(self, app_config: EasyUIBasedAppConfig) -> None:
        self.app_config = app_config
        self.invoke_from = InvokeFrom.WEB_APP
        self.user_id = "user-id"
        self.query = "hello"
        self.inputs = {}
        self.files = []
        self.model_conf = DummyModelConf()
||||
|
||||
|
||||
def _make_app_config(app_mode: AppMode) -> EasyUIBasedAppConfig:
    """Build a minimal easy-UI app config for the given *app_mode*.

    Model and prompt-template fields carry mock values; only the app mode
    varies between tests.
    """
    return EasyUIBasedAppConfig(
        tenant_id="tenant-id",
        app_id="app-id",
        app_mode=app_mode,
        app_model_config_from=EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG,
        app_model_config_id="model-config-id",
        app_model_config_dict={},
        model=ModelConfigEntity(provider="mock-provider", model="mock-model", mode="chat"),
        prompt_template=PromptTemplateEntity(
            prompt_type=PromptTemplateEntity.PromptType.SIMPLE,
            simple_prompt_template="Hello",
        ),
        additional_features=AppAdditionalFeatures(),
        variables=[],
    )
||||
|
||||
|
||||
def _make_chat_generate_entity(app_config: EasyUIBasedAppConfig) -> ChatAppGenerateEntity:
    """Build a chat generate entity with no existing conversation.

    ``model_construct`` is used so the DummyModelConf stand-in passes through
    without pydantic validation.
    """
    return ChatAppGenerateEntity.model_construct(
        task_id="task-id",
        app_config=app_config,
        model_conf=DummyModelConf(),
        file_upload_config=None,
        conversation_id=None,
        inputs={},
        query="hello",
        files=[],
        parent_message_id=None,
        user_id="user-id",
        stream=False,
        invoke_from=InvokeFrom.WEB_APP,
        extras={},
        call_depth=0,
        trace_manager=None,
    )
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _mock_db_session(monkeypatch):
    """Replace the generator module's db handle with a MagicMock session.

    ``session.refresh`` simulates the database assigning primary keys: a
    Conversation or Message with ``id is None`` gets a deterministic
    generated id the assertions can rely on.
    """
    session = MagicMock()

    def refresh_side_effect(obj):
        if isinstance(obj, Conversation) and obj.id is None:
            obj.id = "generated-conversation-id"
        if isinstance(obj, Message) and obj.id is None:
            obj.id = "generated-message-id"

    session.refresh.side_effect = refresh_side_effect
    # add/commit are no-ops; only refresh has observable behavior.
    session.add.return_value = None
    session.commit.return_value = None

    monkeypatch.setattr(message_based_app_generator, "db", SimpleNamespace(session=session))
    return session
||||
|
||||
|
||||
def test_init_generate_records_skips_conversation_fields_for_non_conversation_entity():
    """Completion entities must not receive conversation_id / is_new_conversation."""
    entity = DummyCompletionGenerateEntity(app_config=_make_app_config(AppMode.COMPLETION))

    conversation, message = MessageBasedAppGenerator()._init_generate_records(entity, conversation=None)

    assert conversation.id == "generated-conversation-id"
    assert message.id == "generated-message-id"
    # __slots__ on the dummy entity guarantees these would have raised if set.
    assert not hasattr(entity, "conversation_id")
    assert not hasattr(entity, "is_new_conversation")
||||
|
||||
|
||||
def test_init_generate_records_sets_conversation_fields_for_chat_entity():
    """Chat entities must get a fresh conversation id and the is-new flag set."""
    entity = _make_chat_generate_entity(_make_app_config(AppMode.CHAT))

    conversation, _ = MessageBasedAppGenerator()._init_generate_records(entity, conversation=None)

    assert conversation.id == "generated-conversation-id"
    assert entity.conversation_id == "generated-conversation-id"
    assert entity.is_new_conversation is True
287
api/tests/unit_tests/core/app/apps/test_pause_resume.py
Normal file
287
api/tests/unit_tests/core/app/apps/test_pause_resume.py
Normal file
@ -0,0 +1,287 @@
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from types import ModuleType, SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
# Make the repository's `api` directory importable regardless of how pytest
# is invoked; this file lives five directory levels below it.
API_DIR = str(Path(__file__).resolve().parents[5])
if API_DIR not in sys.path:
    sys.path.insert(0, API_DIR)
||||
|
||||
import core.workflow.nodes.human_input.entities # noqa: F401
|
||||
from core.app.apps.advanced_chat import app_generator as adv_app_gen_module
|
||||
from core.app.apps.workflow import app_generator as wf_app_gen_module
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.workflow.node_factory import DifyNodeFactory
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.entities.pause_reason import SchedulingPause
|
||||
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine import GraphEngine
|
||||
from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel
|
||||
from core.workflow.graph_events import (
|
||||
GraphEngineEvent,
|
||||
GraphRunPausedEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.node_events import NodeRunResult, PauseRequestedEvent
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, OutputVariableEntity, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.end.entities import EndNodeData
|
||||
from core.workflow.nodes.start.entities import StartNodeData
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
|
||||
# Install a lightweight stand-in for core.ops.ops_trace_manager (only if the
# real module has not been imported yet) so that importing the app generators
# does not pull in tracing dependencies during these unit tests.
if "core.ops.ops_trace_manager" not in sys.modules:
    ops_stub = ModuleType("core.ops.ops_trace_manager")

    class _StubTraceQueueManager:
        # Accepts any constructor arguments and does nothing.
        def __init__(self, *_, **__):
            pass

    ops_stub.TraceQueueManager = _StubTraceQueueManager
    sys.modules["core.ops.ops_trace_manager"] = ops_stub
||||
|
||||
|
||||
class _StubToolNodeData(BaseNodeData):
    """Node data for the stub tool node."""

    # When True, the stub node requests a pause instead of completing.
    pause_on: bool = False
||||
|
||||
|
||||
class _StubToolNode(Node[_StubToolNodeData]):
    """Tool node replacement that either succeeds immediately or requests a pause.

    Substituted for real tool nodes by ``_patch_tool_node`` so pause/resume
    behavior can be driven deterministically from the graph config.
    """

    node_type = NodeType.TOOL

    @classmethod
    def version(cls) -> str:
        return "1"

    def init_node_data(self, data):
        self._node_data = _StubToolNodeData.model_validate(data)

    def _get_error_strategy(self):
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self):
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    def _run(self):
        # Pause instead of running when configured to do so.
        if self.node_data.pause_on:
            yield PauseRequestedEvent(reason=SchedulingPause(message="test pause"))
            return

        # Otherwise succeed with a node-specific output value.
        result = NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            outputs={"value": f"{self.id}-done"},
        )
        yield self._convert_node_run_result_to_graph_node_event(result)
||||
|
||||
|
||||
def _patch_tool_node(mocker):
    """Patch DifyNodeFactory so TOOL nodes are built as ``_StubToolNode``.

    All other node types fall through to the original factory method.
    """
    original_create_node = DifyNodeFactory.create_node

    def _patched_create_node(self, node_config: dict[str, object]) -> Node:
        node_data = node_config.get("data", {})
        if isinstance(node_data, dict) and node_data.get("type") == NodeType.TOOL.value:
            return _StubToolNode(
                id=str(node_config["id"]),
                config=node_config,
                graph_init_params=self.graph_init_params,
                graph_runtime_state=self.graph_runtime_state,
            )
        # Non-tool nodes keep their real implementation.
        return original_create_node(self, node_config)

    mocker.patch.object(DifyNodeFactory, "create_node", _patched_create_node)
||||
|
||||
|
||||
def _node_data(node_type: NodeType, data: BaseNodeData) -> dict[str, object]:
    """Serialize *data* and tag it with the node-type discriminator."""
    return {**data.model_dump(), "type": node_type.value}
||||
|
||||
|
||||
def _build_graph_config(*, pause_on: str | None) -> dict[str, object]:
    """Build a linear start -> tool_a -> tool_b -> tool_c -> end graph config.

    Args:
        pause_on: id of the tool node ("tool_a"/"tool_b"/"tool_c") that should
            request a pause, or None for an uninterrupted run.
    """
    start_data = StartNodeData(title="start", variables=[])
    # Exactly one tool (at most) is armed to pause, selected by id.
    tool_data_a = _StubToolNodeData(title="tool", pause_on=pause_on == "tool_a")
    tool_data_b = _StubToolNodeData(title="tool", pause_on=pause_on == "tool_b")
    tool_data_c = _StubToolNodeData(title="tool", pause_on=pause_on == "tool_c")
    end_data = EndNodeData(
        title="end",
        outputs=[OutputVariableEntity(variable="result", value_selector=["tool_c", "value"])],
        desc=None,
    )

    nodes = [
        {"id": "start", "data": _node_data(NodeType.START, start_data)},
        {"id": "tool_a", "data": _node_data(NodeType.TOOL, tool_data_a)},
        {"id": "tool_b", "data": _node_data(NodeType.TOOL, tool_data_b)},
        {"id": "tool_c", "data": _node_data(NodeType.TOOL, tool_data_c)},
        {"id": "end", "data": _node_data(NodeType.END, end_data)},
    ]
    edges = [
        {"source": "start", "target": "tool_a"},
        {"source": "tool_a", "target": "tool_b"},
        {"source": "tool_b", "target": "tool_c"},
        {"source": "tool_c", "target": "end"},
    ]
    return {"nodes": nodes, "edges": edges}
||||
|
||||
|
||||
def _build_graph(runtime_state: GraphRuntimeState, *, pause_on: str | None) -> Graph:
    """Construct a Graph over the stub config, wiring nodes through DifyNodeFactory."""
    graph_config = _build_graph_config(pause_on=pause_on)
    init_params = GraphInitParams(
        tenant_id="tenant",
        app_id="app",
        workflow_id="workflow",
        graph_config=graph_config,
        user_id="user",
        user_from="account",
        invoke_from="service-api",
        call_depth=0,
    )
    factory = DifyNodeFactory(
        graph_init_params=init_params,
        graph_runtime_state=runtime_state,
    )
    return Graph.init(graph_config=graph_config, node_factory=factory)
|
||||
|
||||
|
||||
def _build_runtime_state(run_id: str) -> GraphRuntimeState:
    """Create a fresh runtime state whose system variables carry *run_id*."""
    pool = VariablePool(
        system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"),
        user_inputs={},
        conversation_variables=[],
    )
    # The engine reports this id in its run events.
    pool.system_variables.workflow_execution_id = run_id
    return GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter())
|
||||
|
||||
|
||||
def _run_with_optional_pause(runtime_state: GraphRuntimeState, *, pause_on: str | None) -> list[GraphEngineEvent]:
    """Run the stub graph to completion (or pause) and collect every engine event."""
    engine = GraphEngine(
        workflow_id="workflow",
        graph=_build_graph(runtime_state, pause_on=pause_on),
        graph_runtime_state=runtime_state,
        command_channel=InMemoryChannel(),
    )
    return list(engine.run())
|
||||
|
||||
|
||||
def _node_successes(events: list[GraphEngineEvent]) -> list[str]:
    """Return the node ids of successful node runs, preserving event order."""
    return [event.node_id for event in events if isinstance(event, NodeRunSucceededEvent)]
|
||||
|
||||
|
||||
def test_workflow_app_pause_resume_matches_baseline(mocker):
    """Pausing mid-run and resuming must replay the same nodes and outputs as an uninterrupted run."""
    _patch_tool_node(mocker)

    # Uninterrupted baseline run for comparison.
    baseline_state = _build_runtime_state("baseline")
    baseline_events = _run_with_optional_pause(baseline_state, pause_on=None)
    assert isinstance(baseline_events[-1], GraphRunSucceededEvent)
    baseline_nodes = _node_successes(baseline_events)
    baseline_outputs = baseline_state.outputs

    # Second run pauses at tool_a; snapshot the paused state for resumption.
    paused_state = _build_runtime_state("paused-run")
    paused_events = _run_with_optional_pause(paused_state, pause_on="tool_a")
    assert isinstance(paused_events[-1], GraphRunPausedEvent)
    paused_nodes = _node_successes(paused_events)
    snapshot = paused_state.dumps()

    resumed_state = GraphRuntimeState.from_snapshot(snapshot)

    generator = wf_app_gen_module.WorkflowAppGenerator()

    def _fake_generate(**kwargs):
        # Stand in for the real pipeline: just finish the run on the restored state.
        state: GraphRuntimeState = kwargs["graph_runtime_state"]
        return _node_successes(_run_with_optional_pause(state, pause_on=None))

    mocker.patch.object(generator, "_generate", side_effect=_fake_generate)

    resumed_nodes = generator.resume(
        app_model=SimpleNamespace(mode="workflow"),
        workflow=SimpleNamespace(),
        user=SimpleNamespace(),
        application_generate_entity=SimpleNamespace(stream=False, invoke_from=InvokeFrom.SERVICE_API),
        graph_runtime_state=resumed_state,
        workflow_execution_repository=SimpleNamespace(),
        workflow_node_execution_repository=SimpleNamespace(),
    )

    # Paused prefix + resumed suffix must reproduce the baseline exactly.
    assert paused_nodes + resumed_nodes == baseline_nodes
    assert resumed_state.outputs == baseline_outputs
|
||||
|
||||
|
||||
def test_advanced_chat_pause_resume_matches_baseline(mocker):
    """Advanced-chat resume must replay the same nodes and outputs as an uninterrupted run."""
    _patch_tool_node(mocker)

    # Uninterrupted baseline run for comparison.
    baseline_state = _build_runtime_state("adv-baseline")
    baseline_events = _run_with_optional_pause(baseline_state, pause_on=None)
    assert isinstance(baseline_events[-1], GraphRunSucceededEvent)
    baseline_nodes = _node_successes(baseline_events)
    baseline_outputs = baseline_state.outputs

    # Second run pauses at tool_a; snapshot the paused state for resumption.
    paused_state = _build_runtime_state("adv-paused")
    paused_events = _run_with_optional_pause(paused_state, pause_on="tool_a")
    assert isinstance(paused_events[-1], GraphRunPausedEvent)
    paused_nodes = _node_successes(paused_events)
    snapshot = paused_state.dumps()

    resumed_state = GraphRuntimeState.from_snapshot(snapshot)

    generator = adv_app_gen_module.AdvancedChatAppGenerator()

    def _fake_generate(**kwargs):
        # Stand in for the real pipeline: just finish the run on the restored state.
        state: GraphRuntimeState = kwargs["graph_runtime_state"]
        return _node_successes(_run_with_optional_pause(state, pause_on=None))

    mocker.patch.object(generator, "_generate", side_effect=_fake_generate)

    resumed_nodes = generator.resume(
        app_model=SimpleNamespace(mode="workflow"),
        workflow=SimpleNamespace(),
        user=SimpleNamespace(),
        conversation=SimpleNamespace(id="conv"),
        message=SimpleNamespace(id="msg"),
        application_generate_entity=SimpleNamespace(stream=False, invoke_from=InvokeFrom.SERVICE_API),
        workflow_execution_repository=SimpleNamespace(),
        workflow_node_execution_repository=SimpleNamespace(),
        graph_runtime_state=resumed_state,
    )

    # Paused prefix + resumed suffix must reproduce the baseline exactly.
    assert paused_nodes + resumed_nodes == baseline_nodes
    assert resumed_state.outputs == baseline_outputs
|
||||
|
||||
|
||||
def test_resume_emits_resumption_start_reason(mocker) -> None:
    """The first run starts as INITIAL; a run restored from a snapshot starts as RESUMPTION."""
    _patch_tool_node(mocker)

    paused_state = _build_runtime_state("resume-reason")
    paused_events = _run_with_optional_pause(paused_state, pause_on="tool_a")
    initial_start = next(event for event in paused_events if isinstance(event, GraphRunStartedEvent))
    assert initial_start.reason == WorkflowStartReason.INITIAL

    # Rehydrate from the paused snapshot and run again.
    resumed_state = GraphRuntimeState.from_snapshot(paused_state.dumps())
    resumed_events = _run_with_optional_pause(resumed_state, pause_on=None)
    resume_start = next(event for event in resumed_events if isinstance(event, GraphRunStartedEvent))
    assert resume_start.reason == WorkflowStartReason.RESUMPTION
|
||||
80
api/tests/unit_tests/core/app/apps/test_streaming_utils.py
Normal file
80
api/tests/unit_tests/core/app/apps/test_streaming_utils.py
Normal file
@ -0,0 +1,80 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import queue
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.entities.task_entities import StreamEvent
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
class FakeSubscription:
    """In-memory stand-in for a pub/sub subscription backed by a queue.Queue."""

    def __init__(self, message_queue: queue.Queue[bytes], state: dict[str, bool]) -> None:
        self._queue = message_queue
        self._state = state
        self._closed = False

    def __enter__(self):
        # Record that a consumer actually subscribed (observed via FakeTopic.subscribed).
        self._state["subscribed"] = True
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def close(self) -> None:
        self._closed = True

    def receive(self, timeout: float | None = 0.1) -> bytes | None:
        """Pop one payload; None when closed or nothing arrives within *timeout*."""
        if self._closed:
            return None
        try:
            # timeout=None means block until a payload arrives.
            return self._queue.get() if timeout is None else self._queue.get(timeout=timeout)
        except queue.Empty:
            return None
|
||||
|
||||
|
||||
class FakeTopic:
    """In-memory topic: publish enqueues payloads; subscribe hands out FakeSubscriptions."""

    def __init__(self) -> None:
        self._queue: queue.Queue[bytes] = queue.Queue()
        # Shared mutable flag flipped by FakeSubscription.__enter__.
        self._state = {"subscribed": False}

    def subscribe(self) -> FakeSubscription:
        return FakeSubscription(self._queue, self._state)

    def publish(self, payload: bytes) -> None:
        self._queue.put(payload)

    @property
    def subscribed(self) -> bool:
        """Whether any subscription's context has been entered."""
        return self._state["subscribed"]
|
||||
|
||||
|
||||
def test_retrieve_events_calls_on_subscribe_after_subscription(monkeypatch):
    """on_subscribe must fire only after the topic subscription is live, so no event can be missed."""
    topic = FakeTopic()

    monkeypatch.setattr(
        MessageBasedAppGenerator,
        "get_response_topic",
        classmethod(lambda cls, app_mode, workflow_run_id: topic),
    )

    def on_subscribe() -> None:
        # By the time this callback runs, the subscription must exist.
        assert topic.subscribed is True
        topic.publish(json.dumps({"event": StreamEvent.WORKFLOW_FINISHED.value}).encode())

    generator = MessageBasedAppGenerator.retrieve_events(
        AppMode.WORKFLOW,
        "workflow-run-id",
        idle_timeout=0.5,
        on_subscribe=on_subscribe,
    )

    # First a ping, then the published finish event, then exhaustion.
    assert next(generator) == StreamEvent.PING.value
    event = next(generator)
    assert event["event"] == StreamEvent.WORKFLOW_FINISHED.value
    with pytest.raises(StopIteration):
        next(generator)
|
||||
@ -1,3 +1,6 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY, WorkflowAppGenerator
|
||||
|
||||
|
||||
@ -17,3 +20,193 @@ def test_should_prepare_user_inputs_keeps_validation_when_flag_false():
|
||||
args = {"inputs": {}, SKIP_PREPARE_USER_INPUTS_KEY: False}
|
||||
|
||||
assert WorkflowAppGenerator()._should_prepare_user_inputs(args)
|
||||
|
||||
|
||||
def test_resume_delegates_to_generate(mocker):
    """resume() must forward state/config to _generate with flags derived from the entity."""
    generator = WorkflowAppGenerator()
    mock_generate = mocker.patch.object(generator, "_generate", return_value="ok")

    generate_entity = SimpleNamespace(stream=False, invoke_from="debugger")
    runtime_state = MagicMock(name="runtime-state")
    pause_config = MagicMock(name="pause-config")

    result = generator.resume(
        app_model=MagicMock(),
        workflow=MagicMock(),
        user=MagicMock(),
        application_generate_entity=generate_entity,
        graph_runtime_state=runtime_state,
        workflow_execution_repository=MagicMock(),
        workflow_node_execution_repository=MagicMock(),
        graph_engine_layers=("layer",),
        pause_state_config=pause_config,
        variable_loader=MagicMock(),
    )

    assert result == "ok"
    mock_generate.assert_called_once()
    forwarded = mock_generate.call_args.kwargs
    # State and pause config are passed through by identity, not copied.
    assert forwarded["graph_runtime_state"] is runtime_state
    assert forwarded["pause_state_config"] is pause_config
    # streaming/invoke_from are derived from the generate entity.
    assert forwarded["streaming"] is False
    assert forwarded["invoke_from"] == "debugger"
|
||||
|
||||
|
||||
def test_generate_appends_pause_layer_and_forwards_state(mocker):
    """_generate must append the pause-persistence layer to the caller's layers and hand the runtime state to the worker thread."""
    generator = WorkflowAppGenerator()

    mocker.patch("core.app.apps.workflow.app_generator.WorkflowAppQueueManager", return_value=MagicMock())

    fake_current_app = MagicMock()
    fake_current_app._get_current_object.return_value = MagicMock()
    mocker.patch("core.app.apps.workflow.app_generator.current_app", fake_current_app)

    mocker.patch(
        "core.app.apps.workflow.app_generator.WorkflowAppGenerateResponseConverter.convert",
        return_value="converted",
    )
    mocker.patch.object(WorkflowAppGenerator, "_handle_response", return_value="response")
    mocker.patch.object(WorkflowAppGenerator, "_get_draft_var_saver_factory", return_value=MagicMock())

    pause_layer = MagicMock(name="pause-layer")
    mocker.patch(
        "core.app.apps.workflow.app_generator.PauseStatePersistenceLayer",
        return_value=pause_layer,
    )

    dummy_session = MagicMock()
    dummy_session.close = MagicMock()
    mocker.patch("core.app.apps.workflow.app_generator.db.session", dummy_session)

    # Capture the worker-thread kwargs instead of actually spawning a thread.
    captured: dict[str, object] = {}

    class DummyThread:
        def __init__(self, target, kwargs):
            captured["target"] = target
            captured["kwargs"] = kwargs

        def start(self):
            return None

    mocker.patch("core.app.apps.workflow.app_generator.threading.Thread", DummyThread)

    app_config = SimpleNamespace(app_id="app", tenant_id="tenant", workflow_id="wf")
    application_generate_entity = SimpleNamespace(
        task_id="task",
        user_id="user",
        invoke_from="service-api",
        app_config=app_config,
        files=[],
        stream=True,
        workflow_execution_id="run",
    )

    graph_runtime_state = MagicMock()

    result = generator._generate(
        app_model=SimpleNamespace(mode="workflow"),
        workflow=MagicMock(),
        user=MagicMock(),
        application_generate_entity=application_generate_entity,
        invoke_from="service-api",
        workflow_execution_repository=MagicMock(),
        workflow_node_execution_repository=MagicMock(),
        streaming=True,
        graph_engine_layers=("base-layer",),
        graph_runtime_state=graph_runtime_state,
        pause_state_config=SimpleNamespace(session_factory=MagicMock(), state_owner_user_id="owner"),
    )

    assert result == "converted"
    # The pause layer is appended after the caller-supplied layers.
    assert captured["kwargs"]["graph_engine_layers"] == ("base-layer", pause_layer)
    assert captured["kwargs"]["graph_runtime_state"] is graph_runtime_state
|
||||
|
||||
|
||||
def test_resume_path_runs_worker_with_runtime_state(mocker):
    """Resuming must run the worker inline with the restored runtime state.

    Fix: the original test ended with ``queue_manager.graph_runtime_state =
    runtime_state`` — a no-op assignment to a mock where an assertion was
    clearly intended, so the final check asserted nothing.
    """
    generator = WorkflowAppGenerator()
    runtime_state = MagicMock(name="runtime-state")

    pause_layer = MagicMock(name="pause-layer")
    mocker.patch("core.app.apps.workflow.app_generator.PauseStatePersistenceLayer", return_value=pause_layer)

    queue_manager = MagicMock()
    mocker.patch("core.app.apps.workflow.app_generator.WorkflowAppQueueManager", return_value=queue_manager)

    mocker.patch.object(generator, "_handle_response", return_value="raw-response")
    mocker.patch(
        "core.app.apps.workflow.app_generator.WorkflowAppGenerateResponseConverter.convert",
        side_effect=lambda response, invoke_from: response,
    )

    fake_db = SimpleNamespace(session=MagicMock(), engine=MagicMock())
    mocker.patch("core.app.apps.workflow.app_generator.db", fake_db)

    workflow = SimpleNamespace(
        id="workflow", tenant_id="tenant", app_id="app", graph_dict={}, type="workflow", version="1"
    )
    end_user = SimpleNamespace(session_id="end-user-session")
    app_record = SimpleNamespace(id="app")

    # One session mock serves the three DB lookups performed during resume.
    session = MagicMock()
    session.__enter__.return_value = session
    session.__exit__.return_value = False
    session.scalar.side_effect = [workflow, end_user, app_record]
    mocker.patch("core.app.apps.workflow.app_generator.session_factory", return_value=session)

    runner_instance = MagicMock()

    def runner_ctor(**kwargs):
        # The restored state must reach the runner by identity.
        assert kwargs["graph_runtime_state"] is runtime_state
        return runner_instance

    mocker.patch("core.app.apps.workflow.app_generator.WorkflowAppRunner", side_effect=runner_ctor)

    class ImmediateThread:
        # Run the worker inline so its side effects are visible to the asserts.
        def __init__(self, target, kwargs):
            target(**kwargs)

        def start(self):
            return None

    mocker.patch("core.app.apps.workflow.app_generator.threading.Thread", ImmediateThread)

    mocker.patch(
        "core.app.apps.workflow.app_generator.DifyCoreRepositoryFactory.create_workflow_execution_repository",
        return_value=MagicMock(),
    )
    mocker.patch(
        "core.app.apps.workflow.app_generator.DifyCoreRepositoryFactory.create_workflow_node_execution_repository",
        return_value=MagicMock(),
    )

    pause_config = SimpleNamespace(session_factory=MagicMock(), state_owner_user_id="owner")

    app_config = SimpleNamespace(app_id="app", tenant_id="tenant", workflow_id="workflow")
    application_generate_entity = SimpleNamespace(
        task_id="task",
        user_id="user",
        invoke_from="service-api",
        app_config=app_config,
        files=[],
        stream=True,
        workflow_execution_id="run",
        trace_manager=MagicMock(),
    )

    result = generator.resume(
        app_model=SimpleNamespace(mode="workflow"),
        workflow=workflow,
        user=MagicMock(),
        application_generate_entity=application_generate_entity,
        graph_runtime_state=runtime_state,
        workflow_execution_repository=MagicMock(),
        workflow_node_execution_repository=MagicMock(),
        pause_state_config=pause_config,
    )

    assert result == "raw-response"
    runner_instance.run.assert_called_once()
    # Was a no-op assignment; assert the generator attached the restored state
    # to the queue manager (the task pipeline reads it from there).
    assert queue_manager.graph_runtime_state is runtime_state
|
||||
|
||||
@ -0,0 +1,59 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
|
||||
from core.app.entities.queue_entities import QueueWorkflowPausedEvent
|
||||
from core.workflow.entities.pause_reason import HumanInputRequired
|
||||
from core.workflow.graph_events.graph import GraphRunPausedEvent
|
||||
|
||||
|
||||
class _DummyQueueManager:
|
||||
def __init__(self):
|
||||
self.published = []
|
||||
|
||||
def publish(self, event, _from):
|
||||
self.published.append(event)
|
||||
|
||||
|
||||
class _DummyRuntimeState:
|
||||
def get_paused_nodes(self):
|
||||
return ["node-1"]
|
||||
|
||||
|
||||
class _DummyGraphEngine:
    """Graph-engine stand-in exposing only the runtime state the runner reads."""

    def __init__(self):
        self.graph_runtime_state = _DummyRuntimeState()
|
||||
|
||||
|
||||
class _DummyWorkflowEntry:
    """Workflow-entry stand-in holding a dummy graph engine."""

    def __init__(self):
        self.graph_engine = _DummyGraphEngine()
|
||||
|
||||
|
||||
def test_handle_pause_event_enqueues_email_task(monkeypatch: pytest.MonkeyPatch):
    """A pause with a HumanInputRequired reason dispatches the notification email and a queue pause event."""
    queue_manager = _DummyQueueManager()
    runner = WorkflowBasedAppRunner(queue_manager=queue_manager, app_id="app-id")
    workflow_entry = _DummyWorkflowEntry()

    reason = HumanInputRequired(
        form_id="form-123",
        form_content="content",
        inputs=[],
        actions=[],
        node_id="node-1",
        node_title="Review",
    )
    pause_event = GraphRunPausedEvent(reasons=[reason], outputs={})

    email_task = MagicMock()
    monkeypatch.setattr("core.app.apps.workflow_app_runner.dispatch_human_input_email_task", email_task)

    runner._handle_event(workflow_entry, pause_event)

    # The email task gets the form details from the pause reason.
    email_task.apply_async.assert_called_once()
    task_kwargs = email_task.apply_async.call_args.kwargs["kwargs"]
    assert task_kwargs["form_id"] == "form-123"
    assert task_kwargs["node_title"] == "Review"

    # The pause is also propagated to the queue for the task pipeline.
    assert any(isinstance(evt, QueueWorkflowPausedEvent) for evt in queue_manager.published)
|
||||
183
api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py
Normal file
183
api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py
Normal file
@ -0,0 +1,183 @@
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.apps.common import workflow_response_converter
|
||||
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
|
||||
from core.app.apps.workflow.app_runner import WorkflowAppRunner
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import QueueWorkflowPausedEvent
|
||||
from core.app.entities.task_entities import HumanInputRequiredResponse, WorkflowPauseStreamResponse
|
||||
from core.workflow.entities.pause_reason import HumanInputRequired
|
||||
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
|
||||
from core.workflow.graph_events.graph import GraphRunPausedEvent
|
||||
from core.workflow.nodes.human_input.entities import FormInput, UserAction
|
||||
from core.workflow.nodes.human_input.enums import FormInputType
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from models.account import Account
|
||||
|
||||
|
||||
class _RecordingWorkflowAppRunner(WorkflowAppRunner):
    """WorkflowAppRunner that captures published events instead of queueing them."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.published_events = []

    def _publish_event(self, event):
        self.published_events.append(event)
|
||||
|
||||
|
||||
class _FakeRuntimeState:
|
||||
def get_paused_nodes(self):
|
||||
return ["node-pause-1"]
|
||||
|
||||
|
||||
def _build_runner():
    """Assemble a _RecordingWorkflowAppRunner over lightweight namespace stubs."""
    app_entity = SimpleNamespace(
        app_config=SimpleNamespace(app_id="app-id"),
        inputs={},
        files=[],
        invoke_from=InvokeFrom.SERVICE_API,
        single_iteration_run=None,
        single_loop_run=None,
        workflow_execution_id="run-id",
        user_id="user-id",
    )
    workflow = SimpleNamespace(
        graph_dict={},
        tenant_id="tenant-id",
        environment_variables={},
        id="workflow-id",
    )
    return _RecordingWorkflowAppRunner(
        application_generate_entity=app_entity,
        # The real queue is bypassed; _publish_event records events instead.
        queue_manager=SimpleNamespace(publish=lambda event, pub_from: None),
        variable_loader=MagicMock(),
        workflow=workflow,
        system_user_id="sys-user",
        root_node_id=None,
        workflow_execution_repository=MagicMock(),
        workflow_node_execution_repository=MagicMock(),
        graph_engine_layers=(),
        graph_runtime_state=None,
    )
|
||||
|
||||
|
||||
def test_graph_run_paused_event_emits_queue_pause_event():
    """A GraphRunPausedEvent is translated into one QueueWorkflowPausedEvent carrying reasons, outputs, and paused nodes."""
    runner = _build_runner()
    reason = HumanInputRequired(
        form_id="form-1",
        form_content="content",
        inputs=[],
        actions=[],
        node_id="node-human",
        node_title="Human Step",
        form_token="tok",
    )
    pause_event = GraphRunPausedEvent(reasons=[reason], outputs={"foo": "bar"})
    workflow_entry = SimpleNamespace(
        graph_engine=SimpleNamespace(graph_runtime_state=_FakeRuntimeState()),
    )

    runner._handle_event(workflow_entry, pause_event)

    assert len(runner.published_events) == 1
    queue_event = runner.published_events[0]
    assert isinstance(queue_event, QueueWorkflowPausedEvent)
    assert queue_event.reasons == [reason]
    assert queue_event.outputs == {"foo": "bar"}
    # Paused node ids come from the engine's runtime state.
    assert queue_event.paused_nodes == ["node-pause-1"]
|
||||
|
||||
|
||||
def _build_converter():
    """Create a WorkflowResponseConverter bound to a stub Account and system variables."""
    application_generate_entity = SimpleNamespace(
        inputs={},
        files=[],
        invoke_from=InvokeFrom.SERVICE_API,
        app_config=SimpleNamespace(app_id="app-id", tenant_id="tenant-id"),
    )
    system_variables = SystemVariable(
        user_id="user",
        app_id="app-id",
        workflow_id="workflow-id",
        workflow_execution_id="run-id",
    )
    user = MagicMock(spec=Account)
    user.id = "account-id"
    user.name = "Tester"
    user.email = "tester@example.com"
    return WorkflowResponseConverter(
        application_generate_entity=application_generate_entity,
        user=user,
        system_variables=system_variables,
    )
|
||||
|
||||
|
||||
def test_queue_workflow_paused_event_to_stream_responses(monkeypatch: pytest.MonkeyPatch):
    """A pause event yields a HumanInputRequiredResponse first and a WorkflowPauseStreamResponse last."""
    converter = _build_converter()
    # Prime the converter with a started run so the pause can reference it.
    converter.workflow_start_to_stream_response(
        task_id="task",
        workflow_run_id="run-id",
        workflow_id="workflow-id",
        reason=WorkflowStartReason.INITIAL,
    )

    expiration_time = datetime(2024, 1, 1, tzinfo=UTC)

    class _FakeSession:
        # Stands in for a SQLAlchemy session: one row mapping form id -> expiry.
        def execute(self, _stmt):
            return [("form-1", expiration_time)]

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc, tb):
            return False

    monkeypatch.setattr(workflow_response_converter, "Session", lambda **_: _FakeSession())
    monkeypatch.setattr(workflow_response_converter, "db", SimpleNamespace(engine=object()))

    reason = HumanInputRequired(
        form_id="form-1",
        form_content="Rendered",
        inputs=[
            FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="field", default=None),
        ],
        actions=[UserAction(id="approve", title="Approve")],
        display_in_ui=True,
        node_id="node-id",
        node_title="Human Step",
        form_token="token",
    )
    queue_event = QueueWorkflowPausedEvent(
        reasons=[reason],
        outputs={"answer": "value"},
        paused_nodes=["node-id"],
    )

    runtime_state = SimpleNamespace(total_tokens=0, node_run_steps=0)
    responses = converter.workflow_pause_to_stream_response(
        event=queue_event,
        task_id="task",
        graph_runtime_state=runtime_state,
    )

    # Pause summary comes last and hides the raw outputs from the client.
    pause_resp = responses[-1]
    assert isinstance(pause_resp, WorkflowPauseStreamResponse)
    assert pause_resp.workflow_run_id == "run-id"
    assert pause_resp.data.paused_nodes == ["node-id"]
    assert pause_resp.data.outputs == {}
    assert pause_resp.data.reasons[0]["form_id"] == "form-1"
    assert pause_resp.data.reasons[0]["display_in_ui"] is True

    # The human-input form detail is streamed first.
    hi_resp = responses[0]
    assert isinstance(hi_resp, HumanInputRequiredResponse)
    assert hi_resp.data.form_id == "form-1"
    assert hi_resp.data.node_id == "node-id"
    assert hi_resp.data.node_title == "Human Step"
    assert hi_resp.data.inputs[0].output_variable_name == "field"
    assert hi_resp.data.actions[0].id == "approve"
    assert hi_resp.data.display_in_ui is True
    # Expiry is converted to a unix timestamp from the DB row.
    assert hi_resp.data.expiration_time == int(expiration_time.timestamp())
|
||||
@ -0,0 +1,96 @@
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.app.app_config.entities import WorkflowUIBasedAppConfig
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.app.entities.queue_entities import QueueWorkflowStartedEvent
|
||||
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from models.account import Account
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
def _build_workflow_app_config() -> WorkflowUIBasedAppConfig:
    """Minimal workflow-mode app config for the pipeline under test."""
    return WorkflowUIBasedAppConfig(
        tenant_id="tenant-id",
        app_id="app-id",
        app_mode=AppMode.WORKFLOW,
        workflow_id="workflow-id",
    )
|
||||
|
||||
|
||||
def _build_generate_entity(run_id: str) -> WorkflowAppGenerateEntity:
    """Non-streaming service-API generate entity bound to *run_id*."""
    return WorkflowAppGenerateEntity(
        task_id="task-id",
        app_config=_build_workflow_app_config(),
        inputs={},
        files=[],
        user_id="user-id",
        stream=False,
        invoke_from=InvokeFrom.SERVICE_API,
        workflow_execution_id=run_id,
    )
|
||||
|
||||
|
||||
def _build_runtime_state(run_id: str) -> GraphRuntimeState:
    """Runtime state whose system variables reference *run_id*."""
    pool = VariablePool(
        system_variables=SystemVariable(workflow_execution_id=run_id),
        user_inputs={},
        conversation_variables=[],
    )
    return GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter())
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _noop_session():
|
||||
yield MagicMock()
|
||||
|
||||
|
||||
def _build_pipeline(run_id: str) -> WorkflowAppGenerateTaskPipeline:
    """Wire a non-streaming pipeline around a mocked queue manager, workflow, and DB session."""
    queue_manager = MagicMock(spec=AppQueueManager)
    queue_manager.invoke_from = InvokeFrom.SERVICE_API
    # The pipeline reads the runtime state off the queue manager.
    queue_manager.graph_runtime_state = _build_runtime_state(run_id)

    workflow = MagicMock()
    workflow.id = "workflow-id"
    workflow.features_dict = {}

    pipeline = WorkflowAppGenerateTaskPipeline(
        application_generate_entity=_build_generate_entity(run_id),
        workflow=workflow,
        queue_manager=queue_manager,
        user=Account(name="user", email="user@example.com"),
        stream=False,
        draft_var_saver_factory=MagicMock(),
    )
    # Avoid touching a real database inside the pipeline.
    pipeline._database_session = _noop_session
    return pipeline
|
||||
|
||||
|
||||
def test_workflow_app_log_saved_only_on_initial_start() -> None:
    """An INITIAL start persists the app log exactly once, keyed by the run id."""
    run_id = "run-initial"
    pipeline = _build_pipeline(run_id)
    pipeline._save_workflow_app_log = MagicMock()

    started = QueueWorkflowStartedEvent(reason=WorkflowStartReason.INITIAL)
    list(pipeline._handle_workflow_started_event(started))

    pipeline._save_workflow_app_log.assert_called_once()
    _, kwargs = pipeline._save_workflow_app_log.call_args
    assert kwargs["workflow_run_id"] == run_id
    assert pipeline._workflow_execution_id == run_id
|
||||
|
||||
|
||||
def test_workflow_app_log_skipped_on_resumption_start() -> None:
    """A RESUMPTION start must not create a second app-log entry."""
    run_id = "run-resume"
    pipeline = _build_pipeline(run_id)
    pipeline._save_workflow_app_log = MagicMock()

    started = QueueWorkflowStartedEvent(reason=WorkflowStartReason.RESUMPTION)
    list(pipeline._handle_workflow_started_event(started))

    pipeline._save_workflow_app_log.assert_not_called()
    # The execution id is still tracked even when the log is skipped.
    assert pipeline._workflow_execution_id == run_id
|
||||
@ -0,0 +1,143 @@
|
||||
import json
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.app_config.entities import WorkflowUIBasedAppConfig
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
AdvancedChatAppGenerateEntity,
|
||||
InvokeFrom,
|
||||
WorkflowAppGenerateEntity,
|
||||
)
|
||||
from core.app.layers.pause_state_persist_layer import (
|
||||
WorkflowResumptionContext,
|
||||
_AdvancedChatAppGenerateEntityWrapper,
|
||||
_WorkflowGenerateEntityWrapper,
|
||||
)
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
class TraceQueueManagerStub(TraceQueueManager):
    """Minimal TraceQueueManager stub that avoids Flask dependencies."""

    def __init__(self):
        # Deliberately skip TraceQueueManager.__init__: it starts timers and
        # touches Flask globals that are unavailable in unit tests.
        pass
|
||||
|
||||
|
||||
def _build_workflow_app_config(app_mode: AppMode) -> WorkflowUIBasedAppConfig:
    """App config for *app_mode*, with a mode-derived workflow id."""
    return WorkflowUIBasedAppConfig(
        tenant_id="tenant-id",
        app_id="app-id",
        app_mode=app_mode,
        workflow_id=f"{app_mode.value}-workflow-id",
    )
|
||||
|
||||
|
||||
def _create_workflow_generate_entity(trace_manager: TraceQueueManager | None = None) -> WorkflowAppGenerateEntity:
    """Fully populated workflow generate entity for serialization round-trips."""
    return WorkflowAppGenerateEntity(
        task_id="workflow-task",
        app_config=_build_workflow_app_config(AppMode.WORKFLOW),
        inputs={"topic": "serialization"},
        files=[],
        user_id="user-workflow",
        stream=True,
        invoke_from=InvokeFrom.DEBUGGER,
        call_depth=1,
        trace_manager=trace_manager,
        workflow_execution_id="workflow-exec-id",
        extras={"external_trace_id": "trace-id"},
    )
|
||||
|
||||
|
||||
def _create_advanced_chat_generate_entity(
    trace_manager: TraceQueueManager | None = None,
) -> AdvancedChatAppGenerateEntity:
    """Fully populated advanced-chat generate entity for serialization round-trips."""
    return AdvancedChatAppGenerateEntity(
        task_id="advanced-task",
        app_config=_build_workflow_app_config(AppMode.ADVANCED_CHAT),
        conversation_id="conversation-id",
        inputs={"topic": "roundtrip"},
        files=[],
        user_id="user-advanced",
        stream=False,
        invoke_from=InvokeFrom.DEBUGGER,
        query="Explain serialization",
        extras={"auto_generate_conversation_name": True},
        trace_manager=trace_manager,
        workflow_run_id="workflow-run-id",
    )
|
||||
|
||||
|
||||
def test_workflow_app_generate_entity_roundtrip_excludes_trace_manager():
    """JSON serialization must omit trace_manager and round-trip everything else."""
    entity = _create_workflow_generate_entity(trace_manager=TraceQueueManagerStub())

    serialized = entity.model_dump_json()
    # The trace manager is runtime-only state and must never be serialized.
    assert "trace_manager" not in json.loads(serialized)

    restored = WorkflowAppGenerateEntity.model_validate_json(serialized)
    assert restored.trace_manager is None
    assert restored.model_dump() == entity.model_dump()
|
||||
|
||||
|
||||
def test_advanced_chat_generate_entity_roundtrip_excludes_trace_manager():
    """JSON serialization must omit trace_manager and round-trip everything else."""
    entity = _create_advanced_chat_generate_entity(trace_manager=TraceQueueManagerStub())

    serialized = entity.model_dump_json()
    # The trace manager is runtime-only state and must never be serialized.
    assert "trace_manager" not in json.loads(serialized)

    restored = AdvancedChatAppGenerateEntity.model_validate_json(serialized)
    assert restored.trace_manager is None
    assert restored.model_dump() == entity.model_dump()
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ResumptionContextCase:
    """Named factory bundle describing one resumption-context roundtrip scenario."""

    # Human-readable label for the scenario.
    name: str
    # Produces the context to round-trip plus the entity type expected after restore.
    context_factory: Callable[[], tuple[WorkflowResumptionContext, type]]
|
||||
|
||||
|
||||
def _workflow_resumption_case() -> tuple[WorkflowResumptionContext, type]:
    """Case: workflow-mode context; restore should yield a ``WorkflowAppGenerateEntity``."""
    wrapper = _WorkflowGenerateEntityWrapper(
        entity=_create_workflow_generate_entity(trace_manager=TraceQueueManagerStub())
    )
    context = WorkflowResumptionContext(
        serialized_graph_runtime_state=json.dumps({"state": "workflow"}),
        generate_entity=wrapper,
    )
    return context, WorkflowAppGenerateEntity
|
||||
|
||||
|
||||
def _advanced_chat_resumption_case() -> tuple[WorkflowResumptionContext, type]:
    """Case: advanced-chat context; restore should yield an ``AdvancedChatAppGenerateEntity``."""
    wrapper = _AdvancedChatAppGenerateEntityWrapper(
        entity=_create_advanced_chat_generate_entity(trace_manager=TraceQueueManagerStub())
    )
    context = WorkflowResumptionContext(
        serialized_graph_runtime_state=json.dumps({"state": "advanced"}),
        generate_entity=wrapper,
    )
    return context, AdvancedChatAppGenerateEntity
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "case",
    [
        pytest.param(ResumptionContextCase("workflow", _workflow_resumption_case), id="workflow"),
        pytest.param(ResumptionContextCase("advanced_chat", _advanced_chat_resumption_case), id="advanced_chat"),
    ],
)
def test_workflow_resumption_context_roundtrip(case: ResumptionContextCase):
    """dumps()/loads() must preserve graph state and rebuild the correct entity type.

    The restored entity must also drop the (non-serializable) trace manager.
    """
    original_context, expected_entity_type = case.context_factory()

    restored = WorkflowResumptionContext.loads(original_context.dumps())

    assert restored.serialized_graph_runtime_state == original_context.serialized_graph_runtime_state

    restored_entity = restored.get_generate_entity()
    assert isinstance(restored_entity, expected_entity_type)
    # Trace managers are runtime-only and must not survive the roundtrip.
    assert restored_entity.trace_manager is None
    assert restored_entity.model_dump() == original_context.get_generate_entity().model_dump()
|
||||
Reference in New Issue
Block a user