mirror of
https://github.com/langgenius/dify.git
synced 2026-04-20 10:47:21 +08:00
feat: support variable resolution, fix linting
This commit is contained in:
@ -537,8 +537,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
)
|
||||
for response in responses:
|
||||
yield response
|
||||
yield from responses
|
||||
self._base_task_pipeline.queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
|
||||
|
||||
def _handle_workflow_failed_event(
|
||||
|
||||
@ -61,9 +61,7 @@ class _HumanInputFormRecipientEntityImpl(HumanInputFormRecipientEntity):
|
||||
@property
|
||||
def token(self) -> str:
|
||||
if self._recipient_model.access_token is None:
|
||||
raise AssertionError(
|
||||
f"access_token should not be None for recipient {self._recipient_model.id}"
|
||||
)
|
||||
raise AssertionError(f"access_token should not be None for recipient {self._recipient_model.id}")
|
||||
return self._recipient_model.access_token
|
||||
|
||||
|
||||
@ -309,6 +307,7 @@ class HumanInputFormRepositoryImpl:
|
||||
rendered_content=params.rendered_content,
|
||||
timeout=form_config.timeout,
|
||||
timeout_unit=form_config.timeout_unit,
|
||||
placeholder_values=dict(params.resolved_placeholder_values),
|
||||
)
|
||||
form_model = HumanInputForm(
|
||||
id=form_id,
|
||||
@ -345,9 +344,7 @@ class HumanInputFormRepositoryImpl:
|
||||
if form_model is None:
|
||||
return None
|
||||
|
||||
recipient_query = select(HumanInputFormRecipient).where(
|
||||
HumanInputFormRecipient.form_id == form_model.id
|
||||
)
|
||||
recipient_query = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id == form_model.id)
|
||||
recipient_models = session.scalars(recipient_query).all()
|
||||
return _HumanInputFormEntityImpl(form_model=form_model, recipient_models=recipient_models)
|
||||
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from collections.abc import Mapping
|
||||
from enum import StrEnum, auto
|
||||
from typing import Annotated, Literal, TypeAlias
|
||||
from typing import Annotated, Any, Literal, TypeAlias
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@ -19,6 +20,22 @@ class HumanInputRequired(BaseModel):
|
||||
actions: list[UserAction] = Field(default_factory=list)
|
||||
node_id: str
|
||||
node_title: str
|
||||
|
||||
# The `resolved_placeholder_values` stores the resolved values of variable placeholders. It's a mapping from
|
||||
# `output_variable_name` to their resolved values.
|
||||
#
|
||||
# For example, The form contains a input with output variable name `name` and placeholder type `VARIABLE`, its
|
||||
# selector is ["start", "name"]. While the HumanInputNode is executed, the correspond value of variable
|
||||
# `start.name` in variable pool is `John`. Thus, the resolved value of the output variable `name` is `John`. The
|
||||
# `resolved_placeholder_values` is `{"name": "John"}`.
|
||||
#
|
||||
# Only form inputs with placeholder type `VARIABLE` will be resolved and stored in `resolved_placeholder_values`.
|
||||
resolved_placeholder_values: Mapping[str, Any] = Field(default_factory=dict)
|
||||
|
||||
# The `web_app_form_token` is the token used to submit the form via webapp. It corresponds to
|
||||
# `HumanInputFormRecipient.access_token`.
|
||||
#
|
||||
# This field is `None` if webapp delivery is not set.
|
||||
web_app_form_token: str | None = None
|
||||
|
||||
|
||||
|
||||
@ -8,7 +8,7 @@ import uuid
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime, timedelta
|
||||
from enum import StrEnum
|
||||
from typing import Annotated, Literal, Optional, Self
|
||||
from typing import Annotated, Any, Literal, Optional, Self
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
@ -278,3 +278,6 @@ class FormDefinition(BaseModel):
|
||||
|
||||
timeout: int
|
||||
timeout_unit: TimeoutUnit
|
||||
|
||||
# this is used to store the values of the placeholders
|
||||
placeholder_values: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
@ -13,9 +13,10 @@ from core.workflow.repositories.human_input_form_repository import (
|
||||
HumanInputFormEntity,
|
||||
HumanInputFormRepository,
|
||||
)
|
||||
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
|
||||
from extensions.ext_database import db
|
||||
|
||||
from .entities import HumanInputNodeData
|
||||
from .entities import HumanInputNodeData, PlaceholderType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities.graph_init_params import GraphInitParams
|
||||
@ -130,8 +131,27 @@ class HumanInputNode(Node[HumanInputNodeData]):
|
||||
pause_requested_event = PauseRequestedEvent(reason=required_event)
|
||||
return pause_requested_event
|
||||
|
||||
def _resolve_inputs(self) -> Mapping[str, Any]:
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
resolved_inputs = {}
|
||||
for input in self._node_data.inputs:
|
||||
if (placeholder := input.placeholder) is None:
|
||||
continue
|
||||
if placeholder.type == PlaceholderType.CONSTANT:
|
||||
continue
|
||||
placeholder_value = variable_pool.get(placeholder.selector)
|
||||
if placeholder_value is None:
|
||||
# TODO: How should we handle this?
|
||||
continue
|
||||
resolved_inputs[input.output_variable_name] = (
|
||||
WorkflowRuntimeTypeConverter().value_to_json_encodable_recursive(placeholder_value.value)
|
||||
)
|
||||
|
||||
return resolved_inputs
|
||||
|
||||
def _human_input_required_event(self, form_entity: HumanInputFormEntity) -> HumanInputRequired:
|
||||
node_data = self._node_data
|
||||
resolved_placeholder_values = self._resolve_inputs()
|
||||
return HumanInputRequired(
|
||||
form_id=form_entity.id,
|
||||
form_content=form_entity.rendered_content,
|
||||
@ -140,6 +160,7 @@ class HumanInputNode(Node[HumanInputNodeData]):
|
||||
node_id=self.id,
|
||||
node_title=node_data.title,
|
||||
web_app_form_token=form_entity.web_app_token,
|
||||
resolved_placeholder_values=resolved_placeholder_values,
|
||||
)
|
||||
|
||||
def _create_form(self) -> Generator[NodeEventBase, None, None] | NodeRunResult:
|
||||
@ -149,6 +170,7 @@ class HumanInputNode(Node[HumanInputNodeData]):
|
||||
node_id=self.id,
|
||||
form_config=self._node_data,
|
||||
rendered_content=self._render_form_content(),
|
||||
resolved_placeholder_values=self._resolve_inputs(),
|
||||
)
|
||||
form_entity = self._form_repository.create_form(params)
|
||||
# Create human input required event
|
||||
|
||||
@ -28,6 +28,12 @@ class FormCreateParams:
|
||||
form_config: HumanInputNodeData
|
||||
rendered_content: str
|
||||
|
||||
# resolved_placeholder_values saves the values for placeholders with
|
||||
# type = VARIABLE.
|
||||
#
|
||||
# For type = CONSTANT, the value is not stored inside `resolved_placeholder_values`
|
||||
resolved_placeholder_values: Mapping[str, Any]
|
||||
|
||||
|
||||
class HumanInputFormEntity(abc.ABC):
|
||||
@property
|
||||
|
||||
@ -15,12 +15,14 @@ class WorkflowRuntimeTypeConverter:
|
||||
def to_json_encodable(self, value: None) -> None: ...
|
||||
|
||||
def to_json_encodable(self, value: Mapping[str, Any] | None) -> Mapping[str, Any] | None:
|
||||
result = self._to_json_encodable_recursive(value)
|
||||
"""Convert runtime values to JSON-serializable structures."""
|
||||
|
||||
result = self.value_to_json_encodable_recursive(value)
|
||||
if isinstance(result, Mapping) or result is None:
|
||||
return result
|
||||
return {}
|
||||
|
||||
def _to_json_encodable_recursive(self, value: Any):
|
||||
def value_to_json_encodable_recursive(self, value: Any):
|
||||
if value is None:
|
||||
return value
|
||||
if isinstance(value, (bool, int, str, float)):
|
||||
@ -29,7 +31,7 @@ class WorkflowRuntimeTypeConverter:
|
||||
# Convert Decimal to float for JSON serialization
|
||||
return float(value)
|
||||
if isinstance(value, Segment):
|
||||
return self._to_json_encodable_recursive(value.value)
|
||||
return self.value_to_json_encodable_recursive(value.value)
|
||||
if isinstance(value, File):
|
||||
return value.to_dict()
|
||||
if isinstance(value, BaseModel):
|
||||
@ -37,11 +39,11 @@ class WorkflowRuntimeTypeConverter:
|
||||
if isinstance(value, dict):
|
||||
res = {}
|
||||
for k, v in value.items():
|
||||
res[k] = self._to_json_encodable_recursive(v)
|
||||
res[k] = self.value_to_json_encodable_recursive(v)
|
||||
return res
|
||||
if isinstance(value, list):
|
||||
res_list = []
|
||||
for item in value:
|
||||
res_list.append(self._to_json_encodable_recursive(item))
|
||||
res_list.append(self.value_to_json_encodable_recursive(item))
|
||||
return res_list
|
||||
return value
|
||||
|
||||
@ -67,7 +67,7 @@ def _print_event(event: Mapping | str) -> None:
|
||||
payload = json.dumps(event, ensure_ascii=False)
|
||||
else:
|
||||
payload = event
|
||||
print(payload)
|
||||
# print(payload)
|
||||
sys.stdout.flush()
|
||||
|
||||
|
||||
|
||||
@ -13,6 +13,7 @@ from core.workflow.nodes.human_input.entities import (
|
||||
EmailDeliveryMethod,
|
||||
EmailRecipients,
|
||||
ExternalRecipient,
|
||||
FormDefinition,
|
||||
HumanInputNodeData,
|
||||
MemberRecipient,
|
||||
UserAction,
|
||||
@ -22,6 +23,7 @@ from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.human_input import (
|
||||
EmailExternalRecipientPayload,
|
||||
EmailMemberRecipientPayload,
|
||||
HumanInputForm,
|
||||
HumanInputFormRecipient,
|
||||
RecipientType,
|
||||
)
|
||||
@ -68,6 +70,7 @@ def _build_form_params(delivery_methods: list[EmailDeliveryMethod]) -> FormCreat
|
||||
node_id="human-input-node",
|
||||
form_config=form_config,
|
||||
rendered_content="<p>Approve?</p>",
|
||||
resolved_placeholder_values={},
|
||||
)
|
||||
|
||||
|
||||
@ -156,3 +159,37 @@ class TestHumanInputFormRepositoryImplWithContainers:
|
||||
]
|
||||
assert len(external_payloads) == 1
|
||||
assert external_payloads[0].email == "external@example.com"
|
||||
|
||||
def test_create_form_persists_placeholder_values(self, db_session_with_containers: Session) -> None:
|
||||
engine = db_session_with_containers.get_bind()
|
||||
assert isinstance(engine, Engine)
|
||||
tenant, _ = _create_tenant_with_members(
|
||||
db_session_with_containers,
|
||||
member_emails=["prefill@example.com"],
|
||||
)
|
||||
|
||||
repository = HumanInputFormRepositoryImpl(session_factory=engine, tenant_id=tenant.id)
|
||||
resolved_values = {"greeting": "Hello!"}
|
||||
params = FormCreateParams(
|
||||
workflow_execution_id=str(uuid4()),
|
||||
node_id="human-input-node",
|
||||
form_config=HumanInputNodeData(
|
||||
title="Human Approval",
|
||||
form_content="<p>Approve?</p>",
|
||||
inputs=[],
|
||||
user_actions=[UserAction(id="approve", title="Approve")],
|
||||
),
|
||||
rendered_content="<p>Approve?</p>",
|
||||
resolved_placeholder_values=resolved_values,
|
||||
)
|
||||
|
||||
form_entity = repository.create_form(params)
|
||||
|
||||
with Session(engine) as verification_session:
|
||||
form_model = verification_session.scalars(
|
||||
select(HumanInputForm).where(HumanInputForm.id == form_entity.id)
|
||||
).first()
|
||||
|
||||
assert form_model is not None
|
||||
definition = FormDefinition.model_validate_json(form_model.form_definition)
|
||||
assert definition.placeholder_values == resolved_values
|
||||
|
||||
@ -1,26 +1,7 @@
|
||||
import sys
|
||||
from types import ModuleType, SimpleNamespace
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
if "core.ops.ops_trace_manager" not in sys.modules:
|
||||
stub_module = ModuleType("core.ops.ops_trace_manager")
|
||||
|
||||
class _StubTraceQueueManager:
|
||||
def __init__(self, *_, **__):
|
||||
pass
|
||||
|
||||
stub_module.TraceQueueManager = _StubTraceQueueManager
|
||||
sys.modules["core.ops.ops_trace_manager"] = stub_module
|
||||
|
||||
from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY, WorkflowAppGenerator
|
||||
from tests.unit_tests.core.workflow.graph_engine.test_pause_resume_state import (
|
||||
_build_pausing_graph,
|
||||
_build_runtime_state,
|
||||
_node_successes,
|
||||
_PausingNode,
|
||||
_PausingNodeData,
|
||||
_run_graph,
|
||||
)
|
||||
|
||||
|
||||
def test_should_prepare_user_inputs_defaults_to_true():
|
||||
|
||||
@ -52,7 +52,9 @@ def _create_workflow_generate_entity(trace_manager: TraceQueueManager | None = N
|
||||
)
|
||||
|
||||
|
||||
def _create_advanced_chat_generate_entity(trace_manager: TraceQueueManager | None = None) -> AdvancedChatAppGenerateEntity:
|
||||
def _create_advanced_chat_generate_entity(
|
||||
trace_manager: TraceQueueManager | None = None,
|
||||
) -> AdvancedChatAppGenerateEntity:
|
||||
return AdvancedChatAppGenerateEntity(
|
||||
task_id="advanced-task",
|
||||
app_config=_build_workflow_app_config(AppMode.ADVANCED_CHAT),
|
||||
|
||||
@ -30,9 +30,16 @@ class TestPauseReasonDiscriminator:
|
||||
"TYPE": "human_input_required",
|
||||
"form_id": "form_id",
|
||||
"form_content": "form_content",
|
||||
"node_id": "node_id",
|
||||
"node_title": "node_title",
|
||||
},
|
||||
},
|
||||
HumanInputRequired(form_id="form_id", form_content="form_content"),
|
||||
HumanInputRequired(
|
||||
form_id="form_id",
|
||||
form_content="form_content",
|
||||
node_id="node_id",
|
||||
node_title="node_title",
|
||||
),
|
||||
id="HumanInputRequired",
|
||||
),
|
||||
pytest.param(
|
||||
@ -57,7 +64,12 @@ class TestPauseReasonDiscriminator:
|
||||
@pytest.mark.parametrize(
|
||||
"reason",
|
||||
[
|
||||
HumanInputRequired(form_id="form_id", form_content="form_content"),
|
||||
HumanInputRequired(
|
||||
form_id="form_id",
|
||||
form_content="form_content",
|
||||
node_id="node_id",
|
||||
node_title="node_title",
|
||||
),
|
||||
SchedulingPause(message="Hold on"),
|
||||
],
|
||||
ids=lambda x: type(x).__name__,
|
||||
|
||||
@ -251,7 +251,7 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None:
|
||||
mock_form_entity.rendered_content = "rendered"
|
||||
mock_create_repo.create_form.return_value = mock_form_entity
|
||||
|
||||
def initial_graph_factory() -> tuple[Graph, GraphRuntimeState]:
|
||||
def initial_graph_factory(mock_create_repo=mock_create_repo) -> tuple[Graph, GraphRuntimeState]:
|
||||
return _build_branching_graph(mock_config, mock_create_repo)
|
||||
|
||||
initial_case = WorkflowTestCase(
|
||||
@ -303,7 +303,9 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None:
|
||||
mock_get_repo.get_form_submission.return_value = mock_form_submission
|
||||
mock_get_repo.get_form.return_value = mock_form_entity
|
||||
|
||||
def resume_graph_factory() -> tuple[Graph, GraphRuntimeState]:
|
||||
def resume_graph_factory(
|
||||
initial_result=initial_result, mock_get_repo=mock_get_repo
|
||||
) -> tuple[Graph, GraphRuntimeState]:
|
||||
assert initial_result.graph_runtime_state is not None
|
||||
serialized_runtime_state = initial_result.graph_runtime_state.dumps()
|
||||
resume_runtime_state = GraphRuntimeState.from_snapshot(serialized_runtime_state)
|
||||
|
||||
@ -2,9 +2,14 @@
|
||||
Unit tests for human input node entities.
|
||||
"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.node_events import PauseRequestedEvent
|
||||
from core.workflow.nodes.human_input.entities import (
|
||||
ButtonStyle,
|
||||
DeliveryMethodType,
|
||||
@ -24,6 +29,10 @@ from core.workflow.nodes.human_input.entities import (
|
||||
WebAppDeliveryMethod,
|
||||
_WebAppDeliveryConfig,
|
||||
)
|
||||
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
|
||||
from core.workflow.repositories.human_input_form_repository import HumanInputFormRepository
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
|
||||
|
||||
class TestDeliveryMethod:
|
||||
@ -263,6 +272,82 @@ class TestRecipients:
|
||||
assert recipients.items[1].email == "external@example.com"
|
||||
|
||||
|
||||
class TestHumanInputNodeVariableResolution:
|
||||
"""Tests for resolving variable-based placeholders in HumanInputNode."""
|
||||
|
||||
def test_resolves_variable_placeholders(self):
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(
|
||||
user_id="user",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
workflow_execution_id="exec-1",
|
||||
),
|
||||
user_inputs={},
|
||||
conversation_variables=[],
|
||||
)
|
||||
variable_pool.add(("start", "name"), "Jane Doe")
|
||||
runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
graph_config={"nodes": [], "edges": []},
|
||||
user_id="user",
|
||||
user_from="account",
|
||||
invoke_from="debugger",
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
node_data = HumanInputNodeData(
|
||||
title="Human Input",
|
||||
form_content="Provide your name",
|
||||
inputs=[
|
||||
FormInput(
|
||||
type=FormInputType.TEXT_INPUT,
|
||||
output_variable_name="user_name",
|
||||
placeholder=FormInputPlaceholder(type=PlaceholderType.VARIABLE, selector=["start", "name"]),
|
||||
),
|
||||
FormInput(
|
||||
type=FormInputType.TEXT_INPUT,
|
||||
output_variable_name="user_email",
|
||||
placeholder=FormInputPlaceholder(type=PlaceholderType.CONSTANT, value="foo@example.com"),
|
||||
),
|
||||
],
|
||||
user_actions=[UserAction(id="submit", title="Submit")],
|
||||
)
|
||||
config = {"id": "human", "data": node_data.model_dump()}
|
||||
|
||||
mock_repo = MagicMock(spec=HumanInputFormRepository)
|
||||
mock_repo.get_form.return_value = None
|
||||
mock_repo.get_form_submission.return_value = None
|
||||
mock_repo.create_form.return_value = SimpleNamespace(
|
||||
id="form-1",
|
||||
rendered_content="Provide your name",
|
||||
web_app_token="token",
|
||||
recipients=[],
|
||||
)
|
||||
|
||||
node = HumanInputNode(
|
||||
id=config["id"],
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
form_repository=mock_repo,
|
||||
)
|
||||
node.init_node_data(config["data"])
|
||||
|
||||
run_result = node._run()
|
||||
pause_event = next(run_result)
|
||||
|
||||
assert isinstance(pause_event, PauseRequestedEvent)
|
||||
expected_values = {"user_name": "Jane Doe"}
|
||||
assert pause_event.reason.resolved_placeholder_values == expected_values
|
||||
|
||||
params = mock_repo.create_form.call_args.args[0]
|
||||
assert params.resolved_placeholder_values == expected_values
|
||||
|
||||
|
||||
class TestValidation:
|
||||
"""Test validation scenarios."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user