mirror of
https://github.com/langgenius/dify.git
synced 2026-03-30 02:20:16 +08:00
resume test
This commit is contained in:
@ -1,19 +1,29 @@
|
||||
import logging
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.workflow.entities.pause_reason import HumanInputRequired
|
||||
from core.repositories.human_input_reposotiry import HumanInputFormRepositoryImpl
|
||||
from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult, PauseRequestedEvent
|
||||
from core.workflow.node_events.base import NodeEventBase
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.repositories.human_input_form_repository import FormCreateParams, HumanInputFormRepository
|
||||
from core.workflow.repositories.human_input_form_repository import (
|
||||
FormCreateParams,
|
||||
HumanInputFormEntity,
|
||||
HumanInputFormRepository,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
|
||||
from .entities import HumanInputNodeData
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.runtime.graph_runtime_state import GraphRuntimeState
|
||||
|
||||
|
||||
_SELECTED_BRANCH_KEY = "selected_branch"
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -34,6 +44,28 @@ class HumanInputNode(Node[HumanInputNodeData]):
|
||||
)
|
||||
|
||||
_node_data: HumanInputNodeData
|
||||
_form_repository: HumanInputFormRepository
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
form_repository: HumanInputFormRepository | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
id=id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
if form_repository is None:
|
||||
form_repository = HumanInputFormRepositoryImpl(
|
||||
session_factory=db.engine,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
self._form_repository = form_repository
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
@ -86,19 +118,44 @@ class HumanInputNode(Node[HumanInputNodeData]):
|
||||
|
||||
return None
|
||||
|
||||
def _create_form_repository(self) -> HumanInputFormRepository:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _pause_generator(event: PauseRequestedEvent) -> Generator[NodeEventBase, None, None]:
|
||||
yield event
|
||||
|
||||
@property
|
||||
def _workflow_execution_id(self) -> str:
|
||||
workflow_exec_id = self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id
|
||||
assert workflow_exec_id is not None
|
||||
return workflow_exec_id
|
||||
|
||||
def _form_to_pause_event(self, form_entity: HumanInputFormEntity):
|
||||
required_event = self._human_input_required_event(form_entity)
|
||||
pause_requested_event = PauseRequestedEvent(reason=required_event)
|
||||
return pause_requested_event
|
||||
|
||||
def _create_form(self) -> Generator[NodeEventBase, None, None] | NodeRunResult:
|
||||
try:
|
||||
params = FormCreateParams(
|
||||
workflow_execution_id=self._workflow_execution_id,
|
||||
node_id=self.id,
|
||||
form_config=self._node_data,
|
||||
rendered_content=self._render_form_content(),
|
||||
)
|
||||
form_entity = self._form_repository.create_form(params)
|
||||
# Create human input required event
|
||||
|
||||
logger.info(
|
||||
"Human Input node suspended workflow for form. workflow_run_id=%s, node_id=%s, form_id=%s",
|
||||
self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id,
|
||||
self.id,
|
||||
form_entity.id,
|
||||
)
|
||||
yield self._human
|
||||
yield self._form_to_pause_event(form_entity)
|
||||
except Exception as e:
|
||||
logger.exception("Human Input node failed to execute, node_id=%s", self.id)
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(e),
|
||||
error_type="HumanInputNodeError",
|
||||
)
|
||||
|
||||
def _run(self) -> NodeRunResult | Generator[NodeEventBase, None, None]:
|
||||
"""
|
||||
Execute the human input node.
|
||||
@ -111,51 +168,26 @@ class HumanInputNode(Node[HumanInputNodeData]):
|
||||
5. Suspend workflow execution
|
||||
6. Wait for form submission to resume
|
||||
"""
|
||||
repo = self._create_form_repository()
|
||||
submission_result = repo.get_form_submission(self._workflow_execution_id, self.app_id)
|
||||
repo = self._form_repository
|
||||
form = repo.get_form(self._workflow_execution_id, self.id)
|
||||
if form is None:
|
||||
return self._create_form()
|
||||
|
||||
submission_result = repo.get_form_submission(form.id)
|
||||
if submission_result:
|
||||
outputs: dict[str, Any] = dict(submission_result.form_data())
|
||||
outputs["action_id"] = submission_result.selected_action_id
|
||||
outputs["__action_id"] = submission_result.selected_action_id
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={
|
||||
"action_id": submission_result.selected_action_id,
|
||||
},
|
||||
outputs=outputs,
|
||||
edge_source_handle=submission_result.selected_action_id,
|
||||
)
|
||||
try:
|
||||
repo = self._create_form_repository()
|
||||
params = FormCreateParams(
|
||||
workflow_execution_id=self._workflow_execution_id,
|
||||
node_id=self.id,
|
||||
form_config=self._node_data,
|
||||
rendered_content=self._render_form_content(),
|
||||
)
|
||||
result = repo.create_form(params)
|
||||
# Create human input required event
|
||||
|
||||
required_event = HumanInputRequired(
|
||||
form_id=result.id,
|
||||
form_content=self._node_data.form_content,
|
||||
inputs=self._node_data.inputs,
|
||||
web_app_form_token=result.web_app_token,
|
||||
)
|
||||
pause_requested_event = PauseRequestedEvent(reason=required_event)
|
||||
return self._pause_with_form(form)
|
||||
|
||||
# Create workflow suspended event
|
||||
|
||||
logger.info(
|
||||
"Human Input node suspended workflow for form. workflow_run_id=%s, node_id=%s, form_id=%s",
|
||||
self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id,
|
||||
self.id,
|
||||
result.id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Human Input node failed to execute, node_id=%s", self.id)
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(e),
|
||||
error_type="HumanInputNodeError",
|
||||
)
|
||||
return self._pause_generator(pause_requested_event)
|
||||
def _pause_with_form(self, form_entity: HumanInputFormEntity) -> Generator[NodeEventBase, None, None]:
|
||||
yield self._form_to_pause_event(form_entity)
|
||||
|
||||
def _render_form_content(self) -> str:
|
||||
"""
|
||||
|
||||
@ -93,13 +93,18 @@ class HumanInputFormRepository(Protocol):
|
||||
application domains or deployment scenarios.
|
||||
"""
|
||||
|
||||
def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None:
|
||||
"""Get the form created for a given human input node in a workflow execution. Returns
|
||||
`None` if the form has not been created yet."""
|
||||
...
|
||||
|
||||
def create_form(self, params: FormCreateParams) -> HumanInputFormEntity:
|
||||
"""
|
||||
Create a human input form from form definition.
|
||||
"""
|
||||
...
|
||||
|
||||
def get_form_submission(self, workflow_execution_id: str, node_id: str) -> FormSubmission | None:
|
||||
def get_form_submission(self, form_id: str) -> FormSubmission | None:
|
||||
"""Retrieve the submission for a specific human input node.
|
||||
|
||||
Returns `FormSubmission` if the form has been submitted, or `None` if not.
|
||||
|
||||
Reference in New Issue
Block a user