feat: support variable resolution, fix linting

This commit is contained in:
QuantumGhost
2025-12-05 02:44:04 +08:00
parent 23c6afe790
commit 08175ab32a
14 changed files with 207 additions and 42 deletions

View File

@ -537,8 +537,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
event=event,
task_id=self._application_generate_entity.task_id,
)
for response in responses:
yield response
yield from responses
self._base_task_pipeline.queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
def _handle_workflow_failed_event(

View File

@ -61,9 +61,7 @@ class _HumanInputFormRecipientEntityImpl(HumanInputFormRecipientEntity):
@property
def token(self) -> str:
if self._recipient_model.access_token is None:
raise AssertionError(
f"access_token should not be None for recipient {self._recipient_model.id}"
)
raise AssertionError(f"access_token should not be None for recipient {self._recipient_model.id}")
return self._recipient_model.access_token
@ -309,6 +307,7 @@ class HumanInputFormRepositoryImpl:
rendered_content=params.rendered_content,
timeout=form_config.timeout,
timeout_unit=form_config.timeout_unit,
placeholder_values=dict(params.resolved_placeholder_values),
)
form_model = HumanInputForm(
id=form_id,
@ -345,9 +344,7 @@ class HumanInputFormRepositoryImpl:
if form_model is None:
return None
recipient_query = select(HumanInputFormRecipient).where(
HumanInputFormRecipient.form_id == form_model.id
)
recipient_query = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id == form_model.id)
recipient_models = session.scalars(recipient_query).all()
return _HumanInputFormEntityImpl(form_model=form_model, recipient_models=recipient_models)

View File

@ -1,5 +1,6 @@
from collections.abc import Mapping
from enum import StrEnum, auto
from typing import Annotated, Literal, TypeAlias
from typing import Annotated, Any, Literal, TypeAlias
from pydantic import BaseModel, Field
@ -19,6 +20,22 @@ class HumanInputRequired(BaseModel):
actions: list[UserAction] = Field(default_factory=list)
node_id: str
node_title: str
# The `resolved_placeholder_values` stores the resolved values of variable placeholders. It's a mapping from
# `output_variable_name` to their resolved values.
#
# For example, The form contains a input with output variable name `name` and placeholder type `VARIABLE`, its
# selector is ["start", "name"]. While the HumanInputNode is executed, the correspond value of variable
# `start.name` in variable pool is `John`. Thus, the resolved value of the output variable `name` is `John`. The
# `resolved_placeholder_values` is `{"name": "John"}`.
#
# Only form inputs with placeholder type `VARIABLE` will be resolved and stored in `resolved_placeholder_values`.
resolved_placeholder_values: Mapping[str, Any] = Field(default_factory=dict)
# The `web_app_form_token` is the token used to submit the form via webapp. It corresponds to
# `HumanInputFormRecipient.access_token`.
#
# This field is `None` if webapp delivery is not set.
web_app_form_token: str | None = None

View File

@ -8,7 +8,7 @@ import uuid
from collections.abc import Mapping, Sequence
from datetime import datetime, timedelta
from enum import StrEnum
from typing import Annotated, Literal, Optional, Self
from typing import Annotated, Any, Literal, Optional, Self
from pydantic import BaseModel, Field, field_validator, model_validator
@ -278,3 +278,6 @@ class FormDefinition(BaseModel):
timeout: int
timeout_unit: TimeoutUnit
# this is used to store the values of the placeholders
placeholder_values: dict[str, Any] = Field(default_factory=dict)

View File

@ -13,9 +13,10 @@ from core.workflow.repositories.human_input_form_repository import (
HumanInputFormEntity,
HumanInputFormRepository,
)
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.ext_database import db
from .entities import HumanInputNodeData
from .entities import HumanInputNodeData, PlaceholderType
if TYPE_CHECKING:
from core.workflow.entities.graph_init_params import GraphInitParams
@ -130,8 +131,27 @@ class HumanInputNode(Node[HumanInputNodeData]):
pause_requested_event = PauseRequestedEvent(reason=required_event)
return pause_requested_event
def _resolve_inputs(self) -> Mapping[str, Any]:
variable_pool = self.graph_runtime_state.variable_pool
resolved_inputs = {}
for input in self._node_data.inputs:
if (placeholder := input.placeholder) is None:
continue
if placeholder.type == PlaceholderType.CONSTANT:
continue
placeholder_value = variable_pool.get(placeholder.selector)
if placeholder_value is None:
# TODO: How should we handle this?
continue
resolved_inputs[input.output_variable_name] = (
WorkflowRuntimeTypeConverter().value_to_json_encodable_recursive(placeholder_value.value)
)
return resolved_inputs
def _human_input_required_event(self, form_entity: HumanInputFormEntity) -> HumanInputRequired:
node_data = self._node_data
resolved_placeholder_values = self._resolve_inputs()
return HumanInputRequired(
form_id=form_entity.id,
form_content=form_entity.rendered_content,
@ -140,6 +160,7 @@ class HumanInputNode(Node[HumanInputNodeData]):
node_id=self.id,
node_title=node_data.title,
web_app_form_token=form_entity.web_app_token,
resolved_placeholder_values=resolved_placeholder_values,
)
def _create_form(self) -> Generator[NodeEventBase, None, None] | NodeRunResult:
@ -149,6 +170,7 @@ class HumanInputNode(Node[HumanInputNodeData]):
node_id=self.id,
form_config=self._node_data,
rendered_content=self._render_form_content(),
resolved_placeholder_values=self._resolve_inputs(),
)
form_entity = self._form_repository.create_form(params)
# Create human input required event

View File

@ -28,6 +28,12 @@ class FormCreateParams:
form_config: HumanInputNodeData
rendered_content: str
# resolved_placeholder_values saves the values for placeholders with
# type = VARIABLE.
#
# For type = CONSTANT, the value is not stored inside `resolved_placeholder_values`
resolved_placeholder_values: Mapping[str, Any]
class HumanInputFormEntity(abc.ABC):
@property

View File

@ -15,12 +15,14 @@ class WorkflowRuntimeTypeConverter:
def to_json_encodable(self, value: None) -> None: ...
def to_json_encodable(self, value: Mapping[str, Any] | None) -> Mapping[str, Any] | None:
result = self._to_json_encodable_recursive(value)
"""Convert runtime values to JSON-serializable structures."""
result = self.value_to_json_encodable_recursive(value)
if isinstance(result, Mapping) or result is None:
return result
return {}
def _to_json_encodable_recursive(self, value: Any):
def value_to_json_encodable_recursive(self, value: Any):
if value is None:
return value
if isinstance(value, (bool, int, str, float)):
@ -29,7 +31,7 @@ class WorkflowRuntimeTypeConverter:
# Convert Decimal to float for JSON serialization
return float(value)
if isinstance(value, Segment):
return self._to_json_encodable_recursive(value.value)
return self.value_to_json_encodable_recursive(value.value)
if isinstance(value, File):
return value.to_dict()
if isinstance(value, BaseModel):
@ -37,11 +39,11 @@ class WorkflowRuntimeTypeConverter:
if isinstance(value, dict):
res = {}
for k, v in value.items():
res[k] = self._to_json_encodable_recursive(v)
res[k] = self.value_to_json_encodable_recursive(v)
return res
if isinstance(value, list):
res_list = []
for item in value:
res_list.append(self._to_json_encodable_recursive(item))
res_list.append(self.value_to_json_encodable_recursive(item))
return res_list
return value

View File

@ -67,7 +67,7 @@ def _print_event(event: Mapping | str) -> None:
payload = json.dumps(event, ensure_ascii=False)
else:
payload = event
print(payload)
# print(payload)
sys.stdout.flush()

View File

@ -13,6 +13,7 @@ from core.workflow.nodes.human_input.entities import (
EmailDeliveryMethod,
EmailRecipients,
ExternalRecipient,
FormDefinition,
HumanInputNodeData,
MemberRecipient,
UserAction,
@ -22,6 +23,7 @@ from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.human_input import (
EmailExternalRecipientPayload,
EmailMemberRecipientPayload,
HumanInputForm,
HumanInputFormRecipient,
RecipientType,
)
@ -68,6 +70,7 @@ def _build_form_params(delivery_methods: list[EmailDeliveryMethod]) -> FormCreat
node_id="human-input-node",
form_config=form_config,
rendered_content="<p>Approve?</p>",
resolved_placeholder_values={},
)
@ -156,3 +159,37 @@ class TestHumanInputFormRepositoryImplWithContainers:
]
assert len(external_payloads) == 1
assert external_payloads[0].email == "external@example.com"
def test_create_form_persists_placeholder_values(self, db_session_with_containers: Session) -> None:
engine = db_session_with_containers.get_bind()
assert isinstance(engine, Engine)
tenant, _ = _create_tenant_with_members(
db_session_with_containers,
member_emails=["prefill@example.com"],
)
repository = HumanInputFormRepositoryImpl(session_factory=engine, tenant_id=tenant.id)
resolved_values = {"greeting": "Hello!"}
params = FormCreateParams(
workflow_execution_id=str(uuid4()),
node_id="human-input-node",
form_config=HumanInputNodeData(
title="Human Approval",
form_content="<p>Approve?</p>",
inputs=[],
user_actions=[UserAction(id="approve", title="Approve")],
),
rendered_content="<p>Approve?</p>",
resolved_placeholder_values=resolved_values,
)
form_entity = repository.create_form(params)
with Session(engine) as verification_session:
form_model = verification_session.scalars(
select(HumanInputForm).where(HumanInputForm.id == form_entity.id)
).first()
assert form_model is not None
definition = FormDefinition.model_validate_json(form_model.form_definition)
assert definition.placeholder_values == resolved_values

View File

@ -1,26 +1,7 @@
import sys
from types import ModuleType, SimpleNamespace
from types import SimpleNamespace
from unittest.mock import MagicMock
if "core.ops.ops_trace_manager" not in sys.modules:
stub_module = ModuleType("core.ops.ops_trace_manager")
class _StubTraceQueueManager:
def __init__(self, *_, **__):
pass
stub_module.TraceQueueManager = _StubTraceQueueManager
sys.modules["core.ops.ops_trace_manager"] = stub_module
from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY, WorkflowAppGenerator
from tests.unit_tests.core.workflow.graph_engine.test_pause_resume_state import (
_build_pausing_graph,
_build_runtime_state,
_node_successes,
_PausingNode,
_PausingNodeData,
_run_graph,
)
def test_should_prepare_user_inputs_defaults_to_true():

View File

@ -52,7 +52,9 @@ def _create_workflow_generate_entity(trace_manager: TraceQueueManager | None = N
)
def _create_advanced_chat_generate_entity(trace_manager: TraceQueueManager | None = None) -> AdvancedChatAppGenerateEntity:
def _create_advanced_chat_generate_entity(
trace_manager: TraceQueueManager | None = None,
) -> AdvancedChatAppGenerateEntity:
return AdvancedChatAppGenerateEntity(
task_id="advanced-task",
app_config=_build_workflow_app_config(AppMode.ADVANCED_CHAT),

View File

@ -30,9 +30,16 @@ class TestPauseReasonDiscriminator:
"TYPE": "human_input_required",
"form_id": "form_id",
"form_content": "form_content",
"node_id": "node_id",
"node_title": "node_title",
},
},
HumanInputRequired(form_id="form_id", form_content="form_content"),
HumanInputRequired(
form_id="form_id",
form_content="form_content",
node_id="node_id",
node_title="node_title",
),
id="HumanInputRequired",
),
pytest.param(
@ -57,7 +64,12 @@ class TestPauseReasonDiscriminator:
@pytest.mark.parametrize(
"reason",
[
HumanInputRequired(form_id="form_id", form_content="form_content"),
HumanInputRequired(
form_id="form_id",
form_content="form_content",
node_id="node_id",
node_title="node_title",
),
SchedulingPause(message="Hold on"),
],
ids=lambda x: type(x).__name__,

View File

@ -251,7 +251,7 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None:
mock_form_entity.rendered_content = "rendered"
mock_create_repo.create_form.return_value = mock_form_entity
def initial_graph_factory() -> tuple[Graph, GraphRuntimeState]:
def initial_graph_factory(mock_create_repo=mock_create_repo) -> tuple[Graph, GraphRuntimeState]:
return _build_branching_graph(mock_config, mock_create_repo)
initial_case = WorkflowTestCase(
@ -303,7 +303,9 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None:
mock_get_repo.get_form_submission.return_value = mock_form_submission
mock_get_repo.get_form.return_value = mock_form_entity
def resume_graph_factory() -> tuple[Graph, GraphRuntimeState]:
def resume_graph_factory(
initial_result=initial_result, mock_get_repo=mock_get_repo
) -> tuple[Graph, GraphRuntimeState]:
assert initial_result.graph_runtime_state is not None
serialized_runtime_state = initial_result.graph_runtime_state.dumps()
resume_runtime_state = GraphRuntimeState.from_snapshot(serialized_runtime_state)

View File

@ -2,9 +2,14 @@
Unit tests for human input node entities.
"""
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from pydantic import ValidationError
from core.workflow.entities import GraphInitParams
from core.workflow.node_events import PauseRequestedEvent
from core.workflow.nodes.human_input.entities import (
ButtonStyle,
DeliveryMethodType,
@ -24,6 +29,10 @@ from core.workflow.nodes.human_input.entities import (
WebAppDeliveryMethod,
_WebAppDeliveryConfig,
)
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
from core.workflow.repositories.human_input_form_repository import HumanInputFormRepository
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
class TestDeliveryMethod:
@ -263,6 +272,82 @@ class TestRecipients:
assert recipients.items[1].email == "external@example.com"
class TestHumanInputNodeVariableResolution:
"""Tests for resolving variable-based placeholders in HumanInputNode."""
def test_resolves_variable_placeholders(self):
variable_pool = VariablePool(
system_variables=SystemVariable(
user_id="user",
app_id="app",
workflow_id="workflow",
workflow_execution_id="exec-1",
),
user_inputs={},
conversation_variables=[],
)
variable_pool.add(("start", "name"), "Jane Doe")
runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
graph_init_params = GraphInitParams(
tenant_id="tenant",
app_id="app",
workflow_id="workflow",
graph_config={"nodes": [], "edges": []},
user_id="user",
user_from="account",
invoke_from="debugger",
call_depth=0,
)
node_data = HumanInputNodeData(
title="Human Input",
form_content="Provide your name",
inputs=[
FormInput(
type=FormInputType.TEXT_INPUT,
output_variable_name="user_name",
placeholder=FormInputPlaceholder(type=PlaceholderType.VARIABLE, selector=["start", "name"]),
),
FormInput(
type=FormInputType.TEXT_INPUT,
output_variable_name="user_email",
placeholder=FormInputPlaceholder(type=PlaceholderType.CONSTANT, value="foo@example.com"),
),
],
user_actions=[UserAction(id="submit", title="Submit")],
)
config = {"id": "human", "data": node_data.model_dump()}
mock_repo = MagicMock(spec=HumanInputFormRepository)
mock_repo.get_form.return_value = None
mock_repo.get_form_submission.return_value = None
mock_repo.create_form.return_value = SimpleNamespace(
id="form-1",
rendered_content="Provide your name",
web_app_token="token",
recipients=[],
)
node = HumanInputNode(
id=config["id"],
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
form_repository=mock_repo,
)
node.init_node_data(config["data"])
run_result = node._run()
pause_event = next(run_result)
assert isinstance(pause_event, PauseRequestedEvent)
expected_values = {"user_name": "Jane Doe"}
assert pause_event.reason.resolved_placeholder_values == expected_values
params = mock_repo.create_form.call_args.args[0]
assert params.resolved_placeholder_values == expected_values
class TestValidation:
"""Test validation scenarios."""