From 08175ab32ac3dc4f72a0aca3514ec22273668efe Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Fri, 5 Dec 2025 02:44:04 +0800 Subject: [PATCH] feat: support variable resolution, fix linting --- .../advanced_chat/generate_task_pipeline.py | 3 +- .../repositories/human_input_reposotiry.py | 9 +- api/core/workflow/entities/pause_reason.py | 19 ++++- .../workflow/nodes/human_input/entities.py | 5 +- .../nodes/human_input/human_input_node.py | 24 +++++- .../human_input_form_repository.py | 6 ++ api/core/workflow/workflow_type_encoder.py | 12 +-- api/scripts/workflow_event_subscriber.py | 2 +- .../test_human_input_form_repository_impl.py | 37 ++++++++ .../app/apps/test_workflow_app_generator.py | 21 +---- .../app/entities/test_app_invoke_entities.py | 4 +- .../workflow/entities/test_pause_reason.py | 16 +++- .../test_human_input_pause_multi_branch.py | 6 +- .../nodes/human_input/test_entities.py | 85 +++++++++++++++++++ 14 files changed, 207 insertions(+), 42 deletions(-) diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index aeb950028e..55be001876 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -537,8 +537,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): event=event, task_id=self._application_generate_entity.task_id, ) - for response in responses: - yield response + yield from responses self._base_task_pipeline.queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) def _handle_workflow_failed_event( diff --git a/api/core/repositories/human_input_reposotiry.py b/api/core/repositories/human_input_reposotiry.py index d4df1415d2..f99bc877bc 100644 --- a/api/core/repositories/human_input_reposotiry.py +++ b/api/core/repositories/human_input_reposotiry.py @@ -61,9 +61,7 @@ class _HumanInputFormRecipientEntityImpl(HumanInputFormRecipientEntity): @property def token(self) -> str: if self._recipient_model.access_token is None: - raise AssertionError( - f"access_token should not be None for recipient {self._recipient_model.id}" - ) + raise AssertionError(f"access_token should not be None for recipient {self._recipient_model.id}") return self._recipient_model.access_token @@ -309,6 +307,7 @@ class HumanInputFormRepositoryImpl: rendered_content=params.rendered_content, timeout=form_config.timeout, timeout_unit=form_config.timeout_unit, + placeholder_values=dict(params.resolved_placeholder_values), ) form_model = HumanInputForm( id=form_id, @@ -345,9 +344,7 @@ class HumanInputFormRepositoryImpl: if form_model is None: return None - recipient_query = select(HumanInputFormRecipient).where( - HumanInputFormRecipient.form_id == form_model.id - ) + recipient_query = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id == form_model.id) recipient_models = session.scalars(recipient_query).all() return _HumanInputFormEntityImpl(form_model=form_model, recipient_models=recipient_models) diff --git a/api/core/workflow/entities/pause_reason.py b/api/core/workflow/entities/pause_reason.py index a9dbfd22d2..f1b7eff8dd 100644 --- a/api/core/workflow/entities/pause_reason.py +++ b/api/core/workflow/entities/pause_reason.py @@ -1,5 +1,6 @@ +from collections.abc import Mapping from enum import StrEnum, auto -from typing import Annotated, Literal, TypeAlias +from typing import Annotated, Any, Literal, TypeAlias from pydantic import BaseModel, Field @@ -19,6 +20,22 @@ class HumanInputRequired(BaseModel): actions: list[UserAction] = Field(default_factory=list) node_id: str node_title: str + + # The `resolved_placeholder_values` stores the resolved values of variable placeholders. It's a mapping from + # `output_variable_name` to their resolved values. + # + # For example, The form contains a input with output variable name `name` and placeholder type `VARIABLE`, its + # selector is ["start", "name"]. While the HumanInputNode is executed, the correspond value of variable + # `start.name` in variable pool is `John`. Thus, the resolved value of the output variable `name` is `John`. The + # `resolved_placeholder_values` is `{"name": "John"}`. + # + # Only form inputs with placeholder type `VARIABLE` will be resolved and stored in `resolved_placeholder_values`. + resolved_placeholder_values: Mapping[str, Any] = Field(default_factory=dict) + + # The `web_app_form_token` is the token used to submit the form via webapp. It corresponds to + # `HumanInputFormRecipient.access_token`. + # + # This field is `None` if webapp delivery is not set. web_app_form_token: str | None = None diff --git a/api/core/workflow/nodes/human_input/entities.py b/api/core/workflow/nodes/human_input/entities.py index ffa6a74ef5..fd9ccdb882 100644 --- a/api/core/workflow/nodes/human_input/entities.py +++ b/api/core/workflow/nodes/human_input/entities.py @@ -8,7 +8,7 @@ import uuid from collections.abc import Mapping, Sequence from datetime import datetime, timedelta from enum import StrEnum -from typing import Annotated, Literal, Optional, Self +from typing import Annotated, Any, Literal, Optional, Self from pydantic import BaseModel, Field, field_validator, model_validator @@ -278,3 +278,6 @@ class FormDefinition(BaseModel): timeout: int timeout_unit: TimeoutUnit + + # this is used to store the values of the placeholders + placeholder_values: dict[str, Any] = Field(default_factory=dict) diff --git a/api/core/workflow/nodes/human_input/human_input_node.py b/api/core/workflow/nodes/human_input/human_input_node.py index 3ca1e81a4d..d2403d780d 100644 --- a/api/core/workflow/nodes/human_input/human_input_node.py +++ b/api/core/workflow/nodes/human_input/human_input_node.py @@ -13,9 +13,10 @@ from core.workflow.repositories.human_input_form_repository import ( HumanInputFormEntity, HumanInputFormRepository, ) +from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter from extensions.ext_database import db -from .entities import HumanInputNodeData +from .entities import HumanInputNodeData, PlaceholderType if TYPE_CHECKING: from core.workflow.entities.graph_init_params import GraphInitParams @@ -130,8 +131,27 @@ class HumanInputNode(Node[HumanInputNodeData]): pause_requested_event = PauseRequestedEvent(reason=required_event) return pause_requested_event + def _resolve_inputs(self) -> Mapping[str, Any]: + variable_pool = self.graph_runtime_state.variable_pool + resolved_inputs = {} + for input in self._node_data.inputs: + if (placeholder := input.placeholder) is None: + continue + if placeholder.type == PlaceholderType.CONSTANT: + continue + placeholder_value = variable_pool.get(placeholder.selector) + if placeholder_value is None: + # TODO: How should we handle this? + continue + resolved_inputs[input.output_variable_name] = ( + WorkflowRuntimeTypeConverter().value_to_json_encodable_recursive(placeholder_value.value) + ) + + return resolved_inputs + def _human_input_required_event(self, form_entity: HumanInputFormEntity) -> HumanInputRequired: node_data = self._node_data + resolved_placeholder_values = self._resolve_inputs() return HumanInputRequired( form_id=form_entity.id, form_content=form_entity.rendered_content, @@ -140,6 +160,7 @@ class HumanInputNode(Node[HumanInputNodeData]): node_id=self.id, node_title=node_data.title, web_app_form_token=form_entity.web_app_token, + resolved_placeholder_values=resolved_placeholder_values, ) def _create_form(self) -> Generator[NodeEventBase, None, None] | NodeRunResult: @@ -149,6 +170,7 @@ class HumanInputNode(Node[HumanInputNodeData]): node_id=self.id, form_config=self._node_data, rendered_content=self._render_form_content(), + resolved_placeholder_values=self._resolve_inputs(), ) form_entity = self._form_repository.create_form(params) # Create human input required event diff --git a/api/core/workflow/repositories/human_input_form_repository.py b/api/core/workflow/repositories/human_input_form_repository.py index 271c749eee..6050e86c8d 100644 --- a/api/core/workflow/repositories/human_input_form_repository.py +++ b/api/core/workflow/repositories/human_input_form_repository.py @@ -28,6 +28,12 @@ class FormCreateParams: form_config: HumanInputNodeData rendered_content: str + # resolved_placeholder_values saves the values for placeholders with + # type = VARIABLE. + # + # For type = CONSTANT, the value is not stored inside `resolved_placeholder_values` + resolved_placeholder_values: Mapping[str, Any] + class HumanInputFormEntity(abc.ABC): @property diff --git a/api/core/workflow/workflow_type_encoder.py b/api/core/workflow/workflow_type_encoder.py index 5456043ccd..f1f549e1f8 100644 --- a/api/core/workflow/workflow_type_encoder.py +++ b/api/core/workflow/workflow_type_encoder.py @@ -15,12 +15,14 @@ class WorkflowRuntimeTypeConverter: def to_json_encodable(self, value: None) -> None: ... def to_json_encodable(self, value: Mapping[str, Any] | None) -> Mapping[str, Any] | None: - result = self._to_json_encodable_recursive(value) + """Convert runtime values to JSON-serializable structures.""" + + result = self.value_to_json_encodable_recursive(value) if isinstance(result, Mapping) or result is None: return result return {} - def _to_json_encodable_recursive(self, value: Any): + def value_to_json_encodable_recursive(self, value: Any): if value is None: return value if isinstance(value, (bool, int, str, float)): @@ -29,7 +31,7 @@ class WorkflowRuntimeTypeConverter: # Convert Decimal to float for JSON serialization return float(value) if isinstance(value, Segment): - return self._to_json_encodable_recursive(value.value) + return self.value_to_json_encodable_recursive(value.value) if isinstance(value, File): return value.to_dict() if isinstance(value, BaseModel): @@ -37,11 +39,11 @@ class WorkflowRuntimeTypeConverter: if isinstance(value, dict): res = {} for k, v in value.items(): - res[k] = self._to_json_encodable_recursive(v) + res[k] = self.value_to_json_encodable_recursive(v) return res if isinstance(value, list): res_list = [] for item in value: - res_list.append(self._to_json_encodable_recursive(item)) + res_list.append(self.value_to_json_encodable_recursive(item)) return res_list return value diff --git a/api/scripts/workflow_event_subscriber.py b/api/scripts/workflow_event_subscriber.py index d149ae75bf..579896315f 100644 --- a/api/scripts/workflow_event_subscriber.py +++ b/api/scripts/workflow_event_subscriber.py @@ -67,7 +67,7 @@ def _print_event(event: Mapping | str) -> None: payload = json.dumps(event, ensure_ascii=False) else: payload = event - print(payload) + # print(payload) sys.stdout.flush() diff --git a/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py index 6731cd0a11..269e5c21b1 100644 --- a/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py +++ b/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py @@ -13,6 +13,7 @@ from core.workflow.nodes.human_input.entities import ( EmailDeliveryMethod, EmailRecipients, ExternalRecipient, + FormDefinition, HumanInputNodeData, MemberRecipient, UserAction, @@ -22,6 +23,7 @@ from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.human_input import ( EmailExternalRecipientPayload, EmailMemberRecipientPayload, + HumanInputForm, HumanInputFormRecipient, RecipientType, ) @@ -68,6 +70,7 @@ def _build_form_params(delivery_methods: list[EmailDeliveryMethod]) -> FormCreat node_id="human-input-node", form_config=form_config, rendered_content="

Approve?

", + resolved_placeholder_values={}, ) @@ -156,3 +159,37 @@ class TestHumanInputFormRepositoryImplWithContainers: ] assert len(external_payloads) == 1 assert external_payloads[0].email == "external@example.com" + + def test_create_form_persists_placeholder_values(self, db_session_with_containers: Session) -> None: + engine = db_session_with_containers.get_bind() + assert isinstance(engine, Engine) + tenant, _ = _create_tenant_with_members( + db_session_with_containers, + member_emails=["prefill@example.com"], + ) + + repository = HumanInputFormRepositoryImpl(session_factory=engine, tenant_id=tenant.id) + resolved_values = {"greeting": "Hello!"} + params = FormCreateParams( + workflow_execution_id=str(uuid4()), + node_id="human-input-node", + form_config=HumanInputNodeData( + title="Human Approval", + form_content="

Approve?

", + inputs=[], + user_actions=[UserAction(id="approve", title="Approve")], + ), + rendered_content="

Approve?

", + resolved_placeholder_values=resolved_values, + ) + + form_entity = repository.create_form(params) + + with Session(engine) as verification_session: + form_model = verification_session.scalars( + select(HumanInputForm).where(HumanInputForm.id == form_entity.id) + ).first() + + assert form_model is not None + definition = FormDefinition.model_validate_json(form_model.form_definition) + assert definition.placeholder_values == resolved_values diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_generator.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_generator.py index e9ca28b8d6..c9c02ec992 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_generator.py @@ -1,26 +1,7 @@ -import sys -from types import ModuleType, SimpleNamespace +from types import SimpleNamespace from unittest.mock import MagicMock -if "core.ops.ops_trace_manager" not in sys.modules: - stub_module = ModuleType("core.ops.ops_trace_manager") - - class _StubTraceQueueManager: - def __init__(self, *_, **__): - pass - - stub_module.TraceQueueManager = _StubTraceQueueManager - sys.modules["core.ops.ops_trace_manager"] = stub_module - from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY, WorkflowAppGenerator -from tests.unit_tests.core.workflow.graph_engine.test_pause_resume_state import ( - _build_pausing_graph, - _build_runtime_state, - _node_successes, - _PausingNode, - _PausingNodeData, - _run_graph, -) def test_should_prepare_user_inputs_defaults_to_true(): diff --git a/api/tests/unit_tests/core/app/entities/test_app_invoke_entities.py b/api/tests/unit_tests/core/app/entities/test_app_invoke_entities.py index aedfe723e5..86c80985c4 100644 --- a/api/tests/unit_tests/core/app/entities/test_app_invoke_entities.py +++ b/api/tests/unit_tests/core/app/entities/test_app_invoke_entities.py @@ -52,7 +52,9 @@ def _create_workflow_generate_entity(trace_manager: TraceQueueManager | None = N ) -def _create_advanced_chat_generate_entity(trace_manager: TraceQueueManager | None = None) -> AdvancedChatAppGenerateEntity: +def _create_advanced_chat_generate_entity( + trace_manager: TraceQueueManager | None = None, +) -> AdvancedChatAppGenerateEntity: return AdvancedChatAppGenerateEntity( task_id="advanced-task", app_config=_build_workflow_app_config(AppMode.ADVANCED_CHAT), diff --git a/api/tests/unit_tests/core/workflow/entities/test_pause_reason.py b/api/tests/unit_tests/core/workflow/entities/test_pause_reason.py index d4dc91c4b4..6144df06e0 100644 --- a/api/tests/unit_tests/core/workflow/entities/test_pause_reason.py +++ b/api/tests/unit_tests/core/workflow/entities/test_pause_reason.py @@ -30,9 +30,16 @@ class TestPauseReasonDiscriminator: "TYPE": "human_input_required", "form_id": "form_id", "form_content": "form_content", + "node_id": "node_id", + "node_title": "node_title", }, }, - HumanInputRequired(form_id="form_id", form_content="form_content"), + HumanInputRequired( + form_id="form_id", + form_content="form_content", + node_id="node_id", + node_title="node_title", + ), id="HumanInputRequired", ), pytest.param( @@ -57,7 +64,12 @@ class TestPauseReasonDiscriminator: @pytest.mark.parametrize( "reason", [ - HumanInputRequired(form_id="form_id", form_content="form_content"), + HumanInputRequired( + form_id="form_id", + form_content="form_content", + node_id="node_id", + node_title="node_title", + ), SchedulingPause(message="Hold on"), ], ids=lambda x: type(x).__name__, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py index 85553e410e..8058432d8c 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py @@ -251,7 +251,7 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None: mock_form_entity.rendered_content = "rendered" mock_create_repo.create_form.return_value = mock_form_entity - def initial_graph_factory() -> tuple[Graph, GraphRuntimeState]: + def initial_graph_factory(mock_create_repo=mock_create_repo) -> tuple[Graph, GraphRuntimeState]: return _build_branching_graph(mock_config, mock_create_repo) initial_case = WorkflowTestCase( @@ -303,7 +303,9 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None: mock_get_repo.get_form_submission.return_value = mock_form_submission mock_get_repo.get_form.return_value = mock_form_entity - def resume_graph_factory() -> tuple[Graph, GraphRuntimeState]: + def resume_graph_factory( + initial_result=initial_result, mock_get_repo=mock_get_repo + ) -> tuple[Graph, GraphRuntimeState]: assert initial_result.graph_runtime_state is not None serialized_runtime_state = initial_result.graph_runtime_state.dumps() resume_runtime_state = GraphRuntimeState.from_snapshot(serialized_runtime_state) diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py index 73d109f033..ba24c6a679 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py @@ -2,9 +2,14 @@ Unit tests for human input node entities. """ +from types import SimpleNamespace +from unittest.mock import MagicMock + import pytest from pydantic import ValidationError +from core.workflow.entities import GraphInitParams +from core.workflow.node_events import PauseRequestedEvent from core.workflow.nodes.human_input.entities import ( ButtonStyle, DeliveryMethodType, @@ -24,6 +29,10 @@ from core.workflow.nodes.human_input.entities import ( WebAppDeliveryMethod, _WebAppDeliveryConfig, ) +from core.workflow.nodes.human_input.human_input_node import HumanInputNode +from core.workflow.repositories.human_input_form_repository import HumanInputFormRepository +from core.workflow.runtime import GraphRuntimeState, VariablePool +from core.workflow.system_variable import SystemVariable class TestDeliveryMethod: @@ -263,6 +272,82 @@ class TestRecipients: assert recipients.items[1].email == "external@example.com" +class TestHumanInputNodeVariableResolution: + """Tests for resolving variable-based placeholders in HumanInputNode.""" + + def test_resolves_variable_placeholders(self): + variable_pool = VariablePool( + system_variables=SystemVariable( + user_id="user", + app_id="app", + workflow_id="workflow", + workflow_execution_id="exec-1", + ), + user_inputs={}, + conversation_variables=[], + ) + variable_pool.add(("start", "name"), "Jane Doe") + runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) + graph_init_params = GraphInitParams( + tenant_id="tenant", + app_id="app", + workflow_id="workflow", + graph_config={"nodes": [], "edges": []}, + user_id="user", + user_from="account", + invoke_from="debugger", + call_depth=0, + ) + + node_data = HumanInputNodeData( + title="Human Input", + form_content="Provide your name", + inputs=[ + FormInput( + type=FormInputType.TEXT_INPUT, + output_variable_name="user_name", + placeholder=FormInputPlaceholder(type=PlaceholderType.VARIABLE, selector=["start", "name"]), + ), + FormInput( + type=FormInputType.TEXT_INPUT, + output_variable_name="user_email", + placeholder=FormInputPlaceholder(type=PlaceholderType.CONSTANT, value="foo@example.com"), + ), + ], + user_actions=[UserAction(id="submit", title="Submit")], + ) + config = {"id": "human", "data": node_data.model_dump()} + + mock_repo = MagicMock(spec=HumanInputFormRepository) + mock_repo.get_form.return_value = None + mock_repo.get_form_submission.return_value = None + mock_repo.create_form.return_value = SimpleNamespace( + id="form-1", + rendered_content="Provide your name", + web_app_token="token", + recipients=[], + ) + + node = HumanInputNode( + id=config["id"], + config=config, + graph_init_params=graph_init_params, + graph_runtime_state=runtime_state, + form_repository=mock_repo, + ) + node.init_node_data(config["data"]) + + run_result = node._run() + pause_event = next(run_result) + + assert isinstance(pause_event, PauseRequestedEvent) + expected_values = {"user_name": "Jane Doe"} + assert pause_event.reason.resolved_placeholder_values == expected_values + + params = mock_repo.create_form.call_args.args[0] + assert params.resolved_placeholder_values == expected_values + + class TestValidation: """Test validation scenarios."""