mirror of
https://github.com/langgenius/dify.git
synced 2026-02-22 19:15:47 +08:00
WIP: resume
This commit is contained in:
@ -107,6 +107,7 @@ class ConsoleHumanInputFormApi(Resource):
|
||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||
parser.add_argument("action", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
# Submit the form
|
||||
service = HumanInputService(db.engine)
|
||||
@ -114,6 +115,7 @@ class ConsoleHumanInputFormApi(Resource):
|
||||
form_id=form_id,
|
||||
selected_action_id=args["action"],
|
||||
form_data=args["inputs"],
|
||||
user=current_user,
|
||||
)
|
||||
|
||||
return jsonify({})
|
||||
|
||||
@ -70,6 +70,7 @@ class HumanInputFormApi(WebApiResource):
|
||||
form_token=web_app_form_token,
|
||||
selected_action_id=args["action"],
|
||||
form_data=args["inputs"],
|
||||
submission_end_user_id=_end_user.id,
|
||||
)
|
||||
except FormNotFoundError:
|
||||
raise NotFoundError("Form not found")
|
||||
|
||||
@ -3,7 +3,6 @@ Web App Workflow Resume APIs.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
|
||||
from flask import Response
|
||||
@ -14,42 +13,6 @@ from controllers.web.wraps import WebApiResource
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowResumeWaitApi(WebApiResource):
|
||||
"""API for long-polling workflow resume wait."""
|
||||
|
||||
def get(self, task_id: str):
|
||||
"""
|
||||
Get workflow execution resume notification.
|
||||
|
||||
GET /api/workflow/<task_id>/resume-wait
|
||||
|
||||
This is a long-polling API that waits for workflow to resume from paused state.
|
||||
"""
|
||||
# TODO: Implement actual workflow status checking
|
||||
# For now, return a basic response
|
||||
|
||||
timeout = 30 # 30 seconds timeout for demo
|
||||
start_time = time.time()
|
||||
|
||||
while time.time() - start_time < timeout:
|
||||
# TODO: Check workflow status from database/cache
|
||||
# workflow_status = workflow_service.get_status(task_id)
|
||||
|
||||
# For demo purposes, simulate different states
|
||||
# In real implementation, this would check the actual workflow state
|
||||
workflow_status = "paused" # or "running" or "ended"
|
||||
|
||||
if workflow_status == "running":
|
||||
return {"status": "running"}, 200
|
||||
elif workflow_status == "ended":
|
||||
return {"status": "ended"}, 200
|
||||
|
||||
time.sleep(1) # Poll every second
|
||||
|
||||
# Return paused status if timeout reached
|
||||
return {"status": "paused"}, 200
|
||||
|
||||
|
||||
class WorkflowEventsApi(WebApiResource):
|
||||
"""API for getting workflow execution events after resume."""
|
||||
|
||||
|
||||
@ -2,7 +2,7 @@ import contextvars
|
||||
import logging
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any, Literal, TypeVar, Union, overload
|
||||
|
||||
from flask import Flask, current_app
|
||||
@ -24,16 +24,19 @@ from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
|
||||
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
|
||||
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
|
||||
from core.helper.trace_id_helper import extract_external_trace_id_from_args
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.repositories.draft_variable_repository import (
|
||||
DraftVariableSaverFactory,
|
||||
)
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
@ -63,6 +66,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
invoke_from: InvokeFrom,
|
||||
workflow_run_id: uuid.UUID,
|
||||
streaming: Literal[False],
|
||||
pause_state_config: PauseStateLayerConfig | None = None,
|
||||
) -> Mapping[str, Any]: ...
|
||||
|
||||
@overload
|
||||
@ -75,6 +79,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
invoke_from: InvokeFrom,
|
||||
workflow_run_id: uuid.UUID,
|
||||
streaming: Literal[True],
|
||||
pause_state_config: PauseStateLayerConfig | None = None,
|
||||
) -> Generator[Mapping | str, None, None]: ...
|
||||
|
||||
@overload
|
||||
@ -87,6 +92,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
invoke_from: InvokeFrom,
|
||||
workflow_run_id: uuid.UUID,
|
||||
streaming: bool,
|
||||
pause_state_config: PauseStateLayerConfig | None = None,
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping, None, None]: ...
|
||||
|
||||
def generate(
|
||||
@ -98,6 +104,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
invoke_from: InvokeFrom,
|
||||
workflow_run_id: uuid.UUID,
|
||||
streaming: bool = True,
|
||||
pause_state_config: PauseStateLayerConfig | None = None,
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping, None, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
@ -215,6 +222,38 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
conversation=conversation,
|
||||
stream=streaming,
|
||||
pause_state_config=pause_state_config,
|
||||
)
|
||||
|
||||
def resume(
|
||||
self,
|
||||
*,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
conversation: Conversation,
|
||||
message: Message,
|
||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
pause_state_config: PauseStateLayerConfig | None = None,
|
||||
):
|
||||
"""
|
||||
Resume a paused advanced chat execution.
|
||||
"""
|
||||
return self._generate(
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
application_generate_entity=application_generate_entity,
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
conversation=conversation,
|
||||
message=message,
|
||||
stream=application_generate_entity.stream,
|
||||
pause_state_config=pause_state_config,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
def single_iteration_generate(
|
||||
@ -395,8 +434,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
conversation: Conversation | None = None,
|
||||
message: Message | None = None,
|
||||
stream: bool = True,
|
||||
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
|
||||
pause_state_config: PauseStateLayerConfig | None = None,
|
||||
graph_runtime_state: GraphRuntimeState | None = None,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
@ -410,12 +453,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
:param conversation: conversation
|
||||
:param stream: is stream
|
||||
"""
|
||||
is_first_conversation = False
|
||||
if not conversation:
|
||||
is_first_conversation = True
|
||||
is_first_conversation = conversation is None
|
||||
|
||||
# init generate records
|
||||
(conversation, message) = self._init_generate_records(application_generate_entity, conversation)
|
||||
if conversation is not None and message is not None:
|
||||
pass
|
||||
else:
|
||||
conversation, message = self._init_generate_records(application_generate_entity, conversation)
|
||||
|
||||
if is_first_conversation:
|
||||
# update conversation features
|
||||
@ -438,6 +481,16 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
message_id=message.id,
|
||||
)
|
||||
|
||||
graph_layers: list[GraphEngineLayer] = list(graph_engine_layers)
|
||||
if pause_state_config is not None:
|
||||
graph_layers.append(
|
||||
PauseStatePersistenceLayer(
|
||||
session_factory=pause_state_config.session_factory,
|
||||
generate_entity=application_generate_entity,
|
||||
state_owner_user_id=pause_state_config.state_owner_user_id,
|
||||
)
|
||||
)
|
||||
|
||||
# new thread with request context and contextvars
|
||||
context = contextvars.copy_context()
|
||||
|
||||
@ -453,6 +506,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
"variable_loader": variable_loader,
|
||||
"workflow_execution_repository": workflow_execution_repository,
|
||||
"workflow_node_execution_repository": workflow_node_execution_repository,
|
||||
"graph_engine_layers": tuple(graph_layers),
|
||||
"graph_runtime_state": graph_runtime_state,
|
||||
},
|
||||
)
|
||||
|
||||
@ -498,6 +553,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
variable_loader: VariableLoader,
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
graph_runtime_state: GraphRuntimeState | None = None,
|
||||
):
|
||||
"""
|
||||
Generate worker in a new thread.
|
||||
@ -555,6 +612,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
app=app,
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
graph_engine_layers=graph_engine_layers,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@ -64,6 +64,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
graph_runtime_state: GraphRuntimeState | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
queue_manager=queue_manager,
|
||||
@ -80,6 +81,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
self._app = app
|
||||
self._workflow_execution_repository = workflow_execution_repository
|
||||
self._workflow_node_execution_repository = workflow_node_execution_repository
|
||||
self._resume_graph_runtime_state = graph_runtime_state
|
||||
|
||||
@trace_span(WorkflowAppRunnerHandler)
|
||||
def run(self):
|
||||
@ -103,7 +105,19 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
if not app_record:
|
||||
raise ValueError("App not found")
|
||||
|
||||
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
|
||||
resume_state = self._resume_graph_runtime_state
|
||||
|
||||
if resume_state is not None:
|
||||
graph_runtime_state = resume_state
|
||||
variable_pool = graph_runtime_state.variable_pool
|
||||
graph = self._init_graph(
|
||||
graph_config=self._workflow.graph_dict,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
workflow_id=self._workflow.id,
|
||||
tenant_id=self._workflow.tenant_id,
|
||||
user_id=self.application_generate_entity.user_id,
|
||||
)
|
||||
elif self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
|
||||
# Handle single iteration or single loop run
|
||||
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
|
||||
workflow=self._workflow,
|
||||
|
||||
@ -23,6 +23,7 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera
|
||||
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
|
||||
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
|
||||
from core.helper.trace_id_helper import extract_external_trace_id_from_args
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
@ -31,12 +32,15 @@ from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from libs.flask_utils import preserve_flask_contexts
|
||||
from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
|
||||
from models.account import Account
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.model import App, EndUser
|
||||
from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom
|
||||
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
|
||||
|
||||
SKIP_PREPARE_USER_INPUTS_KEY = "_skip_prepare_user_inputs"
|
||||
@ -63,6 +67,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
triggered_from: WorkflowRunTriggeredFrom | None = None,
|
||||
root_node_id: str | None = None,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
pause_state_config: PauseStateLayerConfig | None = None,
|
||||
) -> Generator[Mapping[str, Any] | str, None, None]: ...
|
||||
|
||||
@overload
|
||||
@ -79,6 +84,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
triggered_from: WorkflowRunTriggeredFrom | None = None,
|
||||
root_node_id: str | None = None,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
pause_state_config: PauseStateLayerConfig | None = None,
|
||||
) -> Mapping[str, Any]: ...
|
||||
|
||||
@overload
|
||||
@ -95,6 +101,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
triggered_from: WorkflowRunTriggeredFrom | None = None,
|
||||
root_node_id: str | None = None,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
pause_state_config: PauseStateLayerConfig | None = None,
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: ...
|
||||
|
||||
def generate(
|
||||
@ -110,6 +117,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
triggered_from: WorkflowRunTriggeredFrom | None = None,
|
||||
root_node_id: str | None = None,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
pause_state_config: PauseStateLayerConfig | None = None,
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]:
|
||||
files: Sequence[Mapping[str, Any]] = args.get("files") or []
|
||||
|
||||
@ -210,13 +218,40 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
streaming=streaming,
|
||||
root_node_id=root_node_id,
|
||||
graph_engine_layers=graph_engine_layers,
|
||||
pause_state_config=pause_state_config,
|
||||
)
|
||||
|
||||
def resume(self, *, workflow_run_id: str) -> None:
|
||||
def resume(
|
||||
self,
|
||||
*,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
application_generate_entity: WorkflowAppGenerateEntity,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
pause_state_config: PauseStateLayerConfig | None = None,
|
||||
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
|
||||
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
|
||||
"""
|
||||
@TBD
|
||||
Resume a paused workflow execution using the persisted runtime state.
|
||||
"""
|
||||
pass
|
||||
return self._generate(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
application_generate_entity=application_generate_entity,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
streaming=application_generate_entity.stream,
|
||||
variable_loader=variable_loader,
|
||||
graph_engine_layers=graph_engine_layers,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
pause_state_config=pause_state_config,
|
||||
)
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
@ -232,6 +267,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
|
||||
root_node_id: str | None = None,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
graph_runtime_state: GraphRuntimeState | None = None,
|
||||
pause_state_config: PauseStateLayerConfig | None = None,
|
||||
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
|
||||
"""
|
||||
Generate App response.
|
||||
@ -245,6 +282,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
:param workflow_node_execution_repository: repository for workflow node execution
|
||||
:param streaming: is stream
|
||||
"""
|
||||
graph_layers: list[GraphEngineLayer] = list(graph_engine_layers)
|
||||
|
||||
# init queue manager
|
||||
queue_manager = WorkflowAppQueueManager(
|
||||
task_id=application_generate_entity.task_id,
|
||||
@ -253,6 +292,15 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
app_mode=app_model.mode,
|
||||
)
|
||||
|
||||
if pause_state_config is not None:
|
||||
graph_layers.append(
|
||||
PauseStatePersistenceLayer(
|
||||
session_factory=pause_state_config.session_factory,
|
||||
generate_entity=application_generate_entity,
|
||||
state_owner_user_id=pause_state_config.state_owner_user_id,
|
||||
)
|
||||
)
|
||||
|
||||
# new thread with request context and contextvars
|
||||
context = contextvars.copy_context()
|
||||
|
||||
@ -270,7 +318,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
"root_node_id": root_node_id,
|
||||
"workflow_execution_repository": workflow_execution_repository,
|
||||
"workflow_node_execution_repository": workflow_node_execution_repository,
|
||||
"graph_engine_layers": graph_engine_layers,
|
||||
"graph_engine_layers": tuple(graph_layers),
|
||||
"graph_runtime_state": graph_runtime_state,
|
||||
},
|
||||
)
|
||||
|
||||
@ -372,6 +421,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
streaming=streaming,
|
||||
variable_loader=var_loader,
|
||||
pause_state_config=None,
|
||||
)
|
||||
|
||||
def single_loop_generate(
|
||||
@ -453,6 +503,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
streaming=streaming,
|
||||
variable_loader=var_loader,
|
||||
pause_state_config=None,
|
||||
)
|
||||
|
||||
def _generate_worker(
|
||||
@ -466,6 +517,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
root_node_id: str | None = None,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
graph_runtime_state: GraphRuntimeState | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Generate worker in a new thread.
|
||||
@ -511,6 +563,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
root_node_id=root_node_id,
|
||||
graph_engine_layers=graph_engine_layers,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@ -43,6 +43,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
graph_runtime_state: GraphRuntimeState | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
queue_manager=queue_manager,
|
||||
@ -56,6 +57,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
self._root_node_id = root_node_id
|
||||
self._workflow_execution_repository = workflow_execution_repository
|
||||
self._workflow_node_execution_repository = workflow_node_execution_repository
|
||||
self._resume_graph_runtime_state = graph_runtime_state
|
||||
|
||||
@trace_span(WorkflowAppRunnerHandler)
|
||||
def run(self):
|
||||
@ -65,17 +67,12 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
app_config = self.application_generate_entity.app_config
|
||||
app_config = cast(WorkflowAppConfig, app_config)
|
||||
|
||||
system_inputs = SystemVariable(
|
||||
files=self.application_generate_entity.files,
|
||||
user_id=self._sys_user_id,
|
||||
app_id=app_config.app_id,
|
||||
timestamp=int(naive_utc_now().timestamp()),
|
||||
workflow_id=app_config.workflow_id,
|
||||
workflow_execution_id=self.application_generate_entity.workflow_execution_id,
|
||||
)
|
||||
resume_state = self._resume_graph_runtime_state
|
||||
|
||||
# if only single iteration or single loop run is requested
|
||||
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
|
||||
if resume_state is not None:
|
||||
graph_runtime_state = resume_state
|
||||
variable_pool = graph_runtime_state.variable_pool
|
||||
elif self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
|
||||
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
|
||||
workflow=self._workflow,
|
||||
single_iteration_run=self.application_generate_entity.single_iteration_run,
|
||||
@ -85,7 +82,14 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
inputs = self.application_generate_entity.inputs
|
||||
|
||||
# Create a variable pool.
|
||||
|
||||
system_inputs = SystemVariable(
|
||||
files=self.application_generate_entity.files,
|
||||
user_id=self._sys_user_id,
|
||||
app_id=app_config.app_id,
|
||||
timestamp=int(naive_utc_now().timestamp()),
|
||||
workflow_id=app_config.workflow_id,
|
||||
workflow_execution_id=self.application_generate_entity.workflow_execution_id,
|
||||
)
|
||||
variable_pool = VariablePool(
|
||||
system_variables=system_inputs,
|
||||
user_inputs=inputs,
|
||||
@ -95,7 +99,10 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
# init graph
|
||||
# init graph (for both resumed and fresh runs when not single-step)
|
||||
if resume_state is not None or not (
|
||||
self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run
|
||||
):
|
||||
graph = self._init_graph(
|
||||
graph_config=self._workflow.graph_dict,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Annotated, Literal, Self, TypeAlias
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
@ -52,6 +53,14 @@ class WorkflowResumptionContext(BaseModel):
|
||||
return self.generate_entity.entity
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PauseStateLayerConfig:
|
||||
"""Configuration container for instantiating pause persistence layers."""
|
||||
|
||||
session_factory: Engine | sessionmaker[Session]
|
||||
state_owner_user_id: str
|
||||
|
||||
|
||||
class PauseStatePersistenceLayer(GraphEngineLayer):
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@ -28,8 +28,8 @@ from core.model_runtime.entities.provider_entities import (
|
||||
)
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.engine import db
|
||||
from models.provider import (
|
||||
LoadBalancingModelConfig,
|
||||
Provider,
|
||||
|
||||
@ -15,10 +15,7 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.helper.encrypter import batch_decrypt_token, encrypt_token, obfuscated_token
|
||||
from core.ops.entities.config_entity import (
|
||||
OPS_FILE_PATH,
|
||||
TracingProviderEnum,
|
||||
)
|
||||
from core.ops.entities.config_entity import OPS_FILE_PATH, TracingProviderEnum
|
||||
from core.ops.entities.trace_entity import (
|
||||
DatasetRetrievalTraceInfo,
|
||||
GenerateNameTraceInfo,
|
||||
@ -31,11 +28,10 @@ from core.ops.entities.trace_entity import (
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.ops.utils import get_message_data
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from models.engine import db
|
||||
from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig
|
||||
from models.workflow import WorkflowAppLog
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
from tasks.ops_trace_task import process_trace_tasks
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -470,6 +466,8 @@ class TraceTask:
|
||||
|
||||
@classmethod
|
||||
def _get_workflow_run_repo(cls):
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
|
||||
if cls._workflow_run_repo is None:
|
||||
with cls._repo_lock:
|
||||
if cls._workflow_run_repo is None:
|
||||
|
||||
@ -5,7 +5,7 @@ from urllib.parse import urlparse
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.engine import db
|
||||
from models.model import Message
|
||||
|
||||
|
||||
|
||||
@ -2,27 +2,17 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from importlib import import_module
|
||||
from typing import Any
|
||||
from .celery_workflow_execution_repository import CeleryWorkflowExecutionRepository
|
||||
from .celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository
|
||||
from .factory import DifyCoreRepositoryFactory, RepositoryImportError
|
||||
from .sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
|
||||
from .sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
|
||||
_ATTRIBUTE_MODULE_MAP = {
|
||||
"CeleryWorkflowExecutionRepository": "core.repositories.celery_workflow_execution_repository",
|
||||
"CeleryWorkflowNodeExecutionRepository": "core.repositories.celery_workflow_node_execution_repository",
|
||||
"DifyCoreRepositoryFactory": "core.repositories.factory",
|
||||
"RepositoryImportError": "core.repositories.factory",
|
||||
"SQLAlchemyWorkflowNodeExecutionRepository": "core.repositories.sqlalchemy_workflow_node_execution_repository",
|
||||
}
|
||||
|
||||
__all__ = list(_ATTRIBUTE_MODULE_MAP.keys())
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
module_path = _ATTRIBUTE_MODULE_MAP.get(name)
|
||||
if module_path is None:
|
||||
raise AttributeError(f"module 'core.repositories' has no attribute '{name}'")
|
||||
module = import_module(module_path)
|
||||
return getattr(module, name)
|
||||
|
||||
|
||||
def __dir__() -> list[str]: # pragma: no cover - simple helper
|
||||
return sorted(__all__)
|
||||
__all__ = [
|
||||
"CeleryWorkflowExecutionRepository",
|
||||
"CeleryWorkflowNodeExecutionRepository",
|
||||
"DifyCoreRepositoryFactory",
|
||||
"RepositoryImportError",
|
||||
"SQLAlchemyWorkflowExecutionRepository",
|
||||
"SQLAlchemyWorkflowNodeExecutionRepository",
|
||||
]
|
||||
|
||||
@ -1,10 +1,11 @@
|
||||
import dataclasses
|
||||
import json
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import Engine, select
|
||||
from sqlalchemy.orm import Session, sessionmaker, selectinload
|
||||
from sqlalchemy.orm import Session, selectinload, sessionmaker
|
||||
|
||||
from core.workflow.nodes.human_input.entities import (
|
||||
DeliveryChannelConfig,
|
||||
@ -85,6 +86,53 @@ class _FormSubmissionImpl(FormSubmission):
|
||||
return json.loads(submitted_data)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class HumanInputFormRecord:
|
||||
form_id: str
|
||||
workflow_run_id: str
|
||||
node_id: str
|
||||
tenant_id: str
|
||||
definition: FormDefinition
|
||||
rendered_content: str
|
||||
expiration_time: datetime
|
||||
selected_action_id: str | None
|
||||
submitted_data: Mapping[str, Any] | None
|
||||
submitted_at: datetime | None
|
||||
submission_user_id: str | None
|
||||
submission_end_user_id: str | None
|
||||
completed_by_recipient_id: str | None
|
||||
recipient_id: str | None
|
||||
recipient_type: RecipientType | None
|
||||
access_token: str | None
|
||||
|
||||
@property
|
||||
def submitted(self) -> bool:
|
||||
return self.submitted_at is not None
|
||||
|
||||
@classmethod
|
||||
def from_models(
|
||||
cls, form_model: HumanInputForm, recipient_model: HumanInputFormRecipient | None
|
||||
) -> "HumanInputFormRecord":
|
||||
return cls(
|
||||
form_id=form_model.id,
|
||||
workflow_run_id=form_model.workflow_run_id,
|
||||
node_id=form_model.node_id,
|
||||
tenant_id=form_model.tenant_id,
|
||||
definition=FormDefinition.model_validate_json(form_model.form_definition),
|
||||
rendered_content=form_model.rendered_content,
|
||||
expiration_time=form_model.expiration_time,
|
||||
selected_action_id=form_model.selected_action_id,
|
||||
submitted_data=json.loads(form_model.submitted_data) if form_model.submitted_data else None,
|
||||
submitted_at=form_model.submitted_at,
|
||||
submission_user_id=form_model.submission_user_id,
|
||||
submission_end_user_id=form_model.submission_end_user_id,
|
||||
completed_by_recipient_id=form_model.completed_by_recipient_id,
|
||||
recipient_id=recipient_model.id if recipient_model else None,
|
||||
recipient_type=recipient_model.recipient_type if recipient_model else None,
|
||||
access_token=recipient_model.access_token if recipient_model else None,
|
||||
)
|
||||
|
||||
|
||||
class HumanInputFormRepositoryImpl:
|
||||
def __init__(
|
||||
self,
|
||||
@ -275,11 +323,74 @@ class HumanInputFormRepositoryImpl:
|
||||
|
||||
return _FormSubmissionImpl(form_model=form_model)
|
||||
|
||||
def get_form_by_token(self, token: str, recipient_type: RecipientType | None = None):
|
||||
|
||||
class HumanInputFormReadRepository:
|
||||
"""Read/write repository for fetching and submitting human input forms."""
|
||||
|
||||
def __init__(self, session_factory: sessionmaker | Engine):
|
||||
if isinstance(session_factory, Engine):
|
||||
session_factory = sessionmaker(bind=session_factory)
|
||||
self._session_factory = session_factory
|
||||
|
||||
def get_by_token(self, form_token: str) -> HumanInputFormRecord | None:
|
||||
query = (
|
||||
select(HumanInputFormRecipient)
|
||||
.options(selectinload(HumanInputFormRecipient.form))
|
||||
.where()
|
||||
|
||||
.where(HumanInputFormRecipient.access_token == form_token)
|
||||
)
|
||||
with self._session_factory(expire_on_commit=False) as session:
|
||||
form_recipient = session.qu
|
||||
recipient_model = session.scalars(query).first()
|
||||
if recipient_model is None or recipient_model.form is None:
|
||||
return None
|
||||
return HumanInputFormRecord.from_models(recipient_model.form, recipient_model)
|
||||
|
||||
def get_by_form_id_and_recipient_type(
|
||||
self,
|
||||
form_id: str,
|
||||
recipient_type: RecipientType,
|
||||
) -> HumanInputFormRecord | None:
|
||||
query = (
|
||||
select(HumanInputFormRecipient)
|
||||
.options(selectinload(HumanInputFormRecipient.form))
|
||||
.where(
|
||||
HumanInputFormRecipient.form_id == form_id,
|
||||
HumanInputFormRecipient.recipient_type == recipient_type,
|
||||
)
|
||||
)
|
||||
with self._session_factory(expire_on_commit=False) as session:
|
||||
recipient_model = session.scalars(query).first()
|
||||
if recipient_model is None or recipient_model.form is None:
|
||||
return None
|
||||
return HumanInputFormRecord.from_models(recipient_model.form, recipient_model)
|
||||
|
||||
def mark_submitted(
|
||||
self,
|
||||
*,
|
||||
form_id: str,
|
||||
recipient_id: str | None,
|
||||
selected_action_id: str,
|
||||
form_data: Mapping[str, Any],
|
||||
submission_user_id: str | None,
|
||||
submission_end_user_id: str | None,
|
||||
) -> HumanInputFormRecord:
|
||||
with self._session_factory(expire_on_commit=False) as session, session.begin():
|
||||
form_model = session.get(HumanInputForm, form_id)
|
||||
if form_model is None:
|
||||
raise FormNotFoundError(f"form not found, id={form_id}")
|
||||
|
||||
recipient_model = session.get(HumanInputFormRecipient, recipient_id) if recipient_id else None
|
||||
|
||||
form_model.selected_action_id = selected_action_id
|
||||
form_model.submitted_data = json.dumps(form_data)
|
||||
form_model.submitted_at = naive_utc_now()
|
||||
form_model.submission_user_id = submission_user_id
|
||||
form_model.submission_end_user_id = submission_end_user_id
|
||||
form_model.completed_by_recipient_id = recipient_id
|
||||
|
||||
session.add(form_model)
|
||||
session.flush()
|
||||
session.refresh(form_model)
|
||||
if recipient_model is not None:
|
||||
session.refresh(recipient_model)
|
||||
|
||||
return HumanInputFormRecord.from_models(form_model, recipient_model)
|
||||
|
||||
@ -1,71 +0,0 @@
|
||||
import abc
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from enum import StrEnum
|
||||
from typing import Annotated, TypeAlias, final
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Discriminator, Tag, field_validator
|
||||
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CommandParams:
|
||||
# `next_node_instance` is the instance of the next node to run.
|
||||
next_node: BaseNode
|
||||
|
||||
|
||||
class _CommandTag(StrEnum):
|
||||
SUSPEND = "suspend"
|
||||
STOP = "stop"
|
||||
CONTINUE = "continue"
|
||||
|
||||
|
||||
# Note: Avoid using the `_Command` class directly.
|
||||
# Instead, use `CommandTypes` for type annotations.
|
||||
class _Command(BaseModel, abc.ABC):
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
tag: _CommandTag
|
||||
|
||||
@field_validator("tag")
|
||||
@classmethod
|
||||
def validate_value_type(cls, value):
|
||||
if value != cls.model_fields["tag"].default:
|
||||
raise ValueError("Cannot modify 'tag'")
|
||||
return value
|
||||
|
||||
|
||||
@final
|
||||
class StopCommand(_Command):
|
||||
tag: _CommandTag = _CommandTag.STOP
|
||||
|
||||
|
||||
@final
|
||||
class SuspendCommand(_Command):
|
||||
tag: _CommandTag = _CommandTag.SUSPEND
|
||||
|
||||
|
||||
@final
|
||||
class ContinueCommand(_Command):
|
||||
tag: _CommandTag = _CommandTag.CONTINUE
|
||||
|
||||
|
||||
def _get_command_tag(command: _Command):
|
||||
return command.tag
|
||||
|
||||
|
||||
CommandTypes: TypeAlias = Annotated[
|
||||
(
|
||||
Annotated[StopCommand, Tag(_CommandTag.STOP)]
|
||||
| Annotated[SuspendCommand, Tag(_CommandTag.SUSPEND)]
|
||||
| Annotated[ContinueCommand, Tag(_CommandTag.CONTINUE)]
|
||||
),
|
||||
Discriminator(_get_command_tag),
|
||||
]
|
||||
|
||||
# `CommandSource` is a callable that takes a single argument of type `CommandParams` and
|
||||
# returns a `Command` object to the engine, indicating whether the graph engine should suspend, continue, or stop.
|
||||
#
|
||||
# It must not modify the data inside `CommandParams`, including any attributes within its fields.
|
||||
CommandSource: TypeAlias = Callable[[CommandParams], CommandTypes]
|
||||
@ -1,8 +1,3 @@
|
||||
"""
|
||||
Human Input node implementation.
|
||||
"""
|
||||
|
||||
from .entities import HumanInputNodeData
|
||||
from .human_input_node import HumanInputNode
|
||||
|
||||
__all__ = ["HumanInputNode", "HumanInputNodeData"]
|
||||
|
||||
@ -269,16 +269,6 @@ class HumanInputNodeData(BaseNodeData):
|
||||
return variable_mappings
|
||||
|
||||
|
||||
class HumanInputRequired(BaseModel):
|
||||
"""Event data for human input required."""
|
||||
|
||||
form_id: str
|
||||
node_id: str
|
||||
form_content: str
|
||||
inputs: list[FormInput]
|
||||
web_app_form_token: Optional[str] = None
|
||||
|
||||
|
||||
class FormDefinition(BaseModel):
|
||||
form_content: str
|
||||
inputs: list[FormInput] = Field(default_factory=list)
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
import dataclasses
|
||||
import logging
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any
|
||||
@ -18,11 +17,6 @@ _SELECTED_BRANCH_KEY = "selected_branch"
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _FormSubmissionResult:
|
||||
action_id: str
|
||||
|
||||
|
||||
class HumanInputNode(Node[HumanInputNodeData]):
|
||||
node_type = NodeType.HUMAN_INPUT
|
||||
execution_type = NodeExecutionType.BRANCH
|
||||
|
||||
@ -5,7 +5,7 @@ import json
|
||||
from collections.abc import Mapping, Sequence
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Protocol
|
||||
from typing import TYPE_CHECKING, Any, Protocol
|
||||
|
||||
from pydantic.json import pydantic_encoder
|
||||
|
||||
@ -13,6 +13,9 @@ from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.workflow.entities.pause_reason import PauseReason
|
||||
from core.workflow.runtime.variable_pool import VariablePool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities.pause_reason import PauseReason
|
||||
|
||||
|
||||
class ReadyQueueProtocol(Protocol):
|
||||
"""Structural interface required from ready queue implementations."""
|
||||
@ -59,7 +62,7 @@ class GraphExecutionProtocol(Protocol):
|
||||
aborted: bool
|
||||
error: Exception | None
|
||||
exceptions_count: int
|
||||
pause_reasons: list[PauseReason]
|
||||
pause_reasons: Sequence[PauseReason]
|
||||
|
||||
def start(self) -> None:
|
||||
"""Transition execution into the running state."""
|
||||
|
||||
@ -57,18 +57,19 @@ class HumanInputForm(DefaultFieldsMixin, Base):
|
||||
|
||||
completed_by_recipient_id: Mapped[str | None] = mapped_column(
|
||||
StringUUID,
|
||||
sa.ForeignKey("human_input_recipients.id"),
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
deliveries: Mapped[list["HumanInputDelivery"]] = relationship(
|
||||
"HumanInputDelivery",
|
||||
primaryjoin="HumanInputForm.id == foreign(HumanInputDelivery.form_id)",
|
||||
uselist=True,
|
||||
back_populates="form",
|
||||
lazy="raise",
|
||||
)
|
||||
completed_by_recipient: Mapped["HumanInputFormRecipient | None"] = relationship(
|
||||
"HumanInputRecipient",
|
||||
primaryjoin="HumanInputForm.completed_by_recipient_id == HumanInputRecipient.id",
|
||||
"HumanInputFormRecipient",
|
||||
primaryjoin="HumanInputForm.completed_by_recipient_id == foreign(HumanInputFormRecipient.id)",
|
||||
lazy="raise",
|
||||
viewonly=True,
|
||||
)
|
||||
@ -79,7 +80,6 @@ class HumanInputDelivery(DefaultFieldsMixin, Base):
|
||||
|
||||
form_id: Mapped[str] = mapped_column(
|
||||
StringUUID,
|
||||
sa.ForeignKey("human_input_forms.id"),
|
||||
nullable=False,
|
||||
)
|
||||
delivery_method_type: Mapped[DeliveryMethodType] = mapped_column(
|
||||
@ -91,11 +91,16 @@ class HumanInputDelivery(DefaultFieldsMixin, Base):
|
||||
|
||||
form: Mapped[HumanInputForm] = relationship(
|
||||
"HumanInputForm",
|
||||
uselist=False,
|
||||
foreign_keys=[form_id],
|
||||
primaryjoin="HumanInputDelivery.form_id == HumanInputForm.id",
|
||||
back_populates="deliveries",
|
||||
lazy="raise",
|
||||
)
|
||||
|
||||
recipients: Mapped[list["HumanInputFormRecipient"]] = relationship(
|
||||
"HumanInputRecipient",
|
||||
"HumanInputFormRecipient",
|
||||
primaryjoin="HumanInputDelivery.id == foreign(HumanInputFormRecipient.delivery_id)",
|
||||
uselist=True,
|
||||
back_populates="delivery",
|
||||
# Require explicit preloading
|
||||
@ -162,6 +167,7 @@ class HumanInputFormRecipient(DefaultFieldsMixin, Base):
|
||||
uselist=False,
|
||||
foreign_keys=[delivery_id],
|
||||
back_populates="recipients",
|
||||
primaryjoin="HumanInputFormRecipient.delivery_id == HumanInputDelivery.id",
|
||||
# Require explicit preloading
|
||||
lazy="raise",
|
||||
)
|
||||
@ -170,7 +176,7 @@ class HumanInputFormRecipient(DefaultFieldsMixin, Base):
|
||||
"HumanInputForm",
|
||||
uselist=False,
|
||||
foreign_keys=[form_id],
|
||||
back_populates="recipients",
|
||||
primaryjoin="HumanInputFormRecipient.form_id == HumanInputForm.id",
|
||||
# Require explicit preloading
|
||||
lazy="raise",
|
||||
)
|
||||
|
||||
@ -6,38 +6,11 @@ by the core workflow module. These models are independent of the storage mechani
|
||||
and don't contain implementation details like tenant_id, app_id, etc.
|
||||
"""
|
||||
|
||||
import enum
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from typing import Annotated, Literal, TypeAlias
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import Sequence
|
||||
|
||||
|
||||
class _PauseTypeEnum(enum.StrEnum):
|
||||
human_input = enum.auto()
|
||||
breakpoint = enum.auto()
|
||||
scheduling = enum.auto()
|
||||
|
||||
|
||||
class HumanInputPause(BaseModel):
|
||||
type: Literal[_PauseTypeEnum.human_input] = _PauseTypeEnum.human_input
|
||||
form_id: str
|
||||
|
||||
|
||||
class SchedulingPause(BaseModel):
|
||||
type: Literal[_PauseTypeEnum.scheduling] = _PauseTypeEnum.scheduling
|
||||
|
||||
|
||||
PauseType: TypeAlias = Annotated[HumanInputPause | SchedulingPause, Field(discriminator="type")]
|
||||
|
||||
|
||||
class PauseDetail(BaseModel):
|
||||
node_id: str
|
||||
node_title: str
|
||||
pause_type: PauseType
|
||||
from .pause_reason import PauseReason
|
||||
|
||||
|
||||
from core.workflow.entities.pause_reason import PauseReason
|
||||
@ -100,7 +73,5 @@ class WorkflowPauseEntity(ABC):
|
||||
|
||||
Returns a sequence of `PauseReason` objects describing the specific nodes and
|
||||
reasons for which the workflow execution was paused.
|
||||
This information is related to, but distinct from, the `PauseReason` type
|
||||
defined in `api/core/workflow/entities/pause_reason.py`.
|
||||
"""
|
||||
...
|
||||
|
||||
@ -505,7 +505,6 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
|
||||
# Mark as resumed
|
||||
pause_model.resumed_at = naive_utc_now()
|
||||
workflow_run.pause_id = None # type: ignore
|
||||
workflow_run.status = WorkflowExecutionStatus.RUNNING
|
||||
|
||||
session.add(pause_model)
|
||||
|
||||
@ -84,3 +84,13 @@ class SQLAlchemyWorkflowTriggerLogRepository(WorkflowTriggerLogRepository):
|
||||
)
|
||||
|
||||
return list(self.session.scalars(query).all())
|
||||
|
||||
def get_by_workflow_run_id(self, workflow_run_id: str) -> WorkflowTriggerLog | None:
|
||||
"""Get the trigger log associated with a workflow run."""
|
||||
query = (
|
||||
select(WorkflowTriggerLog)
|
||||
.where(WorkflowTriggerLog.workflow_run_id == workflow_run_id)
|
||||
.order_by(WorkflowTriggerLog.created_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
return self.session.scalar(query)
|
||||
|
||||
@ -109,3 +109,15 @@ class WorkflowTriggerLogRepository(Protocol):
|
||||
A sequence of recent WorkflowTriggerLog instances
|
||||
"""
|
||||
...
|
||||
|
||||
def get_by_workflow_run_id(self, workflow_run_id: str) -> WorkflowTriggerLog | None:
|
||||
"""
|
||||
Retrieve a trigger log associated with a specific workflow run.
|
||||
|
||||
Args:
|
||||
workflow_run_id: Identifier of the workflow run
|
||||
|
||||
Returns:
|
||||
The matching WorkflowTriggerLog if present, None otherwise
|
||||
"""
|
||||
...
|
||||
|
||||
@ -14,12 +14,8 @@ from core.app.features.rate_limiting.rate_limit import rate_limit_context
|
||||
from enums.quota_type import QuotaType, unlimited
|
||||
from extensions.otel import AppGenerateHandler, trace_span
|
||||
from models.model import Account, App, AppMode, EndUser
|
||||
from models.workflow import Workflow
|
||||
from services.errors.app import InvokeRateLimitError, QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError
|
||||
from models.model import Account, App, AppMode, EndUser
|
||||
from models.workflow import Workflow, WorkflowRun
|
||||
from services.errors.app import WorkflowIdFormatError, WorkflowNotFoundError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
from services.errors.app import InvokeRateLimitError, QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError
|
||||
from services.workflow_service import WorkflowService
|
||||
from tasks.app_generate.workflow_execute_task import ChatflowExecutionParams, chatflow_execute_task
|
||||
|
||||
@ -259,7 +255,7 @@ class AppGenerateService:
|
||||
):
|
||||
if workflow_run.status.is_ended():
|
||||
# TODO(QuantumGhost): handled the ended scenario.
|
||||
return
|
||||
pass
|
||||
|
||||
generator = AdvancedChatAppGenerator()
|
||||
|
||||
|
||||
@ -1,29 +1,49 @@
|
||||
import abc
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import Engine, select
|
||||
from sqlalchemy.orm import Session, selectinload, sessionmaker
|
||||
from sqlalchemy import Engine
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.repositories.human_input_reposotiry import HumanInputFormReadRepository, HumanInputFormRecord
|
||||
from core.workflow.nodes.human_input.entities import FormDefinition
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.exception import BaseHTTPException
|
||||
from models.account import Account
|
||||
from models.human_input import HumanInputForm, HumanInputFormRecipient, RecipientType
|
||||
from models.human_input import RecipientType
|
||||
from models.model import App, AppMode
|
||||
from models.workflow import WorkflowRun
|
||||
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
|
||||
from services.workflow.entities import WorkflowResumeTaskData
|
||||
from tasks.app_generate.workflow_execute_task import resume_chatflow_execution
|
||||
from tasks.async_workflow_tasks import resume_workflow_execution
|
||||
|
||||
|
||||
class Form:
|
||||
def __init__(self, form_model: HumanInputForm):
|
||||
self._form_model = form_model
|
||||
def __init__(self, record: HumanInputFormRecord):
|
||||
self._record = record
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_definition(self) -> FormDefinition:
|
||||
pass
|
||||
return self._record.definition
|
||||
|
||||
@abc.abstractmethod
|
||||
@property
|
||||
def submitted(self) -> bool:
|
||||
pass
|
||||
return self._record.submitted
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
return self._record.form_id
|
||||
|
||||
@property
|
||||
def workflow_run_id(self) -> str:
|
||||
return self._record.workflow_run_id
|
||||
|
||||
@property
|
||||
def recipient_id(self) -> str | None:
|
||||
return self._record.recipient_id
|
||||
|
||||
@property
|
||||
def recipient_type(self) -> RecipientType | None:
|
||||
return self._record.recipient_type
|
||||
|
||||
|
||||
class HumanInputError(Exception):
|
||||
@ -49,93 +69,148 @@ class WebAppDeliveryNotEnabledError(HumanInputError, BaseException):
|
||||
pass
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HumanInputService:
|
||||
def __init__(
|
||||
self,
|
||||
session_factory: sessionmaker[Session] | Engine,
|
||||
form_repository: HumanInputFormReadRepository | None = None,
|
||||
):
|
||||
if isinstance(session_factory, Engine):
|
||||
session_factory = sessionmaker(bind=session_factory)
|
||||
self._session_factory = session_factory
|
||||
self._form_repository = form_repository or HumanInputFormReadRepository(session_factory)
|
||||
|
||||
def get_form_by_token(self, form_token: str) -> Form | None:
|
||||
query = (
|
||||
select(HumanInputFormRecipient)
|
||||
.options(selectinload(HumanInputFormRecipient.form))
|
||||
.where(HumanInputFormRecipient.access_token == form_token)
|
||||
)
|
||||
with self._session_factory(expire_on_commit=False) as session:
|
||||
recipient = session.scalars(query).first()
|
||||
if recipient is None:
|
||||
record = self._form_repository.get_by_token(form_token)
|
||||
if record is None:
|
||||
return None
|
||||
return Form(record)
|
||||
|
||||
return Form(recipient.form)
|
||||
|
||||
def get_form_by_id(self, form_id: str) -> Form | None:
|
||||
query = select(HumanInputForm).where(HumanInputForm.id == form_id)
|
||||
with self._session_factory(expire_on_commit=False) as session:
|
||||
form_model = session.scalars(query).first()
|
||||
if form_model is None:
|
||||
return None
|
||||
|
||||
return Form(form_model)
|
||||
|
||||
def submit_form_by_id(self, form_id: str, user: Account, selected_action_id: str, form_data: Mapping[str, Any]):
|
||||
recipient_query = (
|
||||
select(HumanInputFormRecipient)
|
||||
.options(selectinload(HumanInputFormRecipient.form))
|
||||
.where(
|
||||
HumanInputFormRecipient.recipient_type == RecipientType.WEBAPP,
|
||||
HumanInputFormRecipient.form_id == form_id,
|
||||
)
|
||||
def get_form_by_id(self, form_id: str, recipient_type: RecipientType = RecipientType.WEBAPP) -> Form | None:
|
||||
record = self._form_repository.get_by_form_id_and_recipient_type(
|
||||
form_id=form_id,
|
||||
recipient_type=recipient_type,
|
||||
)
|
||||
if record is None:
|
||||
return None
|
||||
return Form(record)
|
||||
|
||||
with self._session_factory(expire_on_commit=False) as session:
|
||||
recipient_model = session.scalars(recipient_query).first()
|
||||
def get_form_definition_by_id(self, form_id: str) -> Form | None:
|
||||
form = self.get_form_by_id(form_id, recipient_type=RecipientType.WEBAPP)
|
||||
if form is None:
|
||||
return None
|
||||
self._ensure_not_submitted(form)
|
||||
return form
|
||||
|
||||
if recipient_model is None:
|
||||
def get_form_definition_by_token(self, recipient_type: RecipientType, form_token: str) -> Form | None:
|
||||
form = self.get_form_by_token(form_token)
|
||||
if form is None or form.recipient_type != recipient_type:
|
||||
return None
|
||||
self._ensure_not_submitted(form)
|
||||
return form
|
||||
|
||||
def submit_form_by_id(
|
||||
self,
|
||||
form_id: str,
|
||||
selected_action_id: str,
|
||||
form_data: Mapping[str, Any],
|
||||
user: Account | None = None,
|
||||
):
|
||||
form = self.get_form_by_id(form_id, recipient_type=RecipientType.WEBAPP)
|
||||
if form is None:
|
||||
raise WebAppDeliveryNotEnabledError()
|
||||
|
||||
form_model = recipient_model.form
|
||||
form = Form(form_model)
|
||||
if form.submitted:
|
||||
raise FormSubmittedError(form_model.id)
|
||||
self._ensure_not_submitted(form)
|
||||
|
||||
with self._session_factory(expire_on_commit=False) as session, session.begin():
|
||||
form_model.selected_action_id = selected_action_id
|
||||
form_model.submitted_data = json.dumps(form_data)
|
||||
form_model.submitted_at = naive_utc_now()
|
||||
form_model.submission_user_id = user.id
|
||||
|
||||
form_model.completed_by_recipient_id = recipient_model.id
|
||||
session.add(form_model)
|
||||
# TODO: restart the execution of paused workflow
|
||||
|
||||
def submit_form_by_token(self, form_token: str, selected_action_id: str, form_data: Mapping[str, Any]):
|
||||
recipient_query = (
|
||||
select(HumanInputFormRecipient)
|
||||
.options(selectinload(HumanInputFormRecipient.form))
|
||||
.where(
|
||||
HumanInputFormRecipient.form_id == form_token,
|
||||
)
|
||||
result = self._form_repository.mark_submitted(
|
||||
form_id=form.id,
|
||||
recipient_id=form.recipient_id,
|
||||
selected_action_id=selected_action_id,
|
||||
form_data=form_data,
|
||||
submission_user_id=user.id if user else None,
|
||||
submission_end_user_id=None,
|
||||
)
|
||||
|
||||
with self._session_factory(expire_on_commit=False) as session:
|
||||
recipient_model = session.scalars(recipient_query).first()
|
||||
self._enqueue_resume(result.workflow_run_id)
|
||||
|
||||
if recipient_model is None:
|
||||
def submit_form_by_token(
|
||||
self,
|
||||
recipient_type: RecipientType,
|
||||
form_token: str,
|
||||
selected_action_id: str,
|
||||
form_data: Mapping[str, Any],
|
||||
submission_end_user_id: str | None = None,
|
||||
):
|
||||
form = self.get_form_by_token(form_token)
|
||||
if form is None or form.recipient_type != recipient_type:
|
||||
raise WebAppDeliveryNotEnabledError()
|
||||
|
||||
form_model = recipient_model.form
|
||||
form = Form(form_model)
|
||||
self._ensure_not_submitted(form)
|
||||
|
||||
result = self._form_repository.mark_submitted(
|
||||
form_id=form.id,
|
||||
recipient_id=form.recipient_id,
|
||||
selected_action_id=selected_action_id,
|
||||
form_data=form_data,
|
||||
submission_user_id=None,
|
||||
submission_end_user_id=submission_end_user_id,
|
||||
)
|
||||
|
||||
self._enqueue_resume(result.workflow_run_id)
|
||||
|
||||
def _ensure_not_submitted(self, form: Form) -> None:
|
||||
if form.submitted:
|
||||
raise FormSubmittedError(form_model.id)
|
||||
raise FormSubmittedError(form.id)
|
||||
|
||||
with self._session_factory(expire_on_commit=False) as session, session.begin():
|
||||
form_model.selected_action_id = selected_action_id
|
||||
form_model.submitted_data = json.dumps(form_data)
|
||||
form_model.submitted_at = naive_utc_now()
|
||||
form_model.submission_user_id = user.id
|
||||
def _enqueue_resume(self, workflow_run_id: str) -> None:
|
||||
with self._session_factory(expire_on_commit=False) as session:
|
||||
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
|
||||
trigger_log = trigger_log_repo.get_by_workflow_run_id(workflow_run_id)
|
||||
|
||||
form_model.completed_by_recipient_id = recipient_model.id
|
||||
session.add(form_model)
|
||||
if trigger_log is not None:
|
||||
payload = WorkflowResumeTaskData(
|
||||
workflow_trigger_log_id=trigger_log.id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
)
|
||||
|
||||
try:
|
||||
resume_workflow_execution.apply_async(
|
||||
kwargs={"task_data_dict": payload.model_dump()},
|
||||
queue=trigger_log.queue_name,
|
||||
)
|
||||
except Exception: # pragma: no cover
|
||||
logger.exception("Failed to enqueue resume task for workflow run %s", workflow_run_id)
|
||||
return
|
||||
|
||||
if self._enqueue_chatflow_resume(workflow_run_id):
|
||||
return
|
||||
|
||||
logger.warning("No workflow trigger log bound to workflow run %s; skipping resume dispatch", workflow_run_id)
|
||||
|
||||
def _enqueue_chatflow_resume(self, workflow_run_id: str) -> bool:
|
||||
with self._session_factory(expire_on_commit=False) as session:
|
||||
workflow_run = session.get(WorkflowRun, workflow_run_id)
|
||||
if workflow_run is None:
|
||||
return False
|
||||
|
||||
app = session.get(App, workflow_run.app_id)
|
||||
|
||||
if app is None:
|
||||
return False
|
||||
|
||||
if app.mode != AppMode.ADVANCED_CHAT.value:
|
||||
return False
|
||||
|
||||
try:
|
||||
resume_chatflow_execution.apply_async(
|
||||
kwargs={"payload": {"workflow_run_id": workflow_run_id}},
|
||||
queue="chatflow_execute",
|
||||
)
|
||||
except Exception: # pragma: no cover
|
||||
logger.exception("Failed to enqueue chatflow resume for workflow run %s", workflow_run_id)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@ -98,6 +98,13 @@ class WorkflowTaskData(BaseModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class WorkflowResumeTaskData(BaseModel):
|
||||
"""Payload for workflow resumption tasks."""
|
||||
|
||||
workflow_trigger_log_id: str
|
||||
workflow_run_id: str
|
||||
|
||||
|
||||
class AsyncTriggerExecutionResult(BaseModel):
|
||||
"""Result from async trigger-based workflow execution"""
|
||||
|
||||
|
||||
@ -5,7 +5,6 @@ from sqlalchemy import Engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
import contexts
|
||||
from core.workflow.entities.workflow_pause import PauseDetail
|
||||
from extensions.ext_database import db
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from models import (
|
||||
@ -160,9 +159,3 @@ class WorkflowRunService:
|
||||
app_id=app_model.id,
|
||||
workflow_run_id=run_id,
|
||||
)
|
||||
|
||||
def get_pause_details(self, workflow_run_id: str) -> Sequence[PauseDetail]:
|
||||
pause = self._workflow_run_repo.get_workflow_pause(workflow_run_id)
|
||||
if pause is None:
|
||||
return []
|
||||
return pause.get_pause_details()
|
||||
|
||||
@ -949,7 +949,7 @@ class WorkflowService:
|
||||
node_data = node.get("data", {})
|
||||
node_type = node_data.get("type")
|
||||
|
||||
if node_type == "human_input":
|
||||
if node_type == NodeType.HUMAN_INPUT:
|
||||
self._validate_human_input_node_data(node_data)
|
||||
|
||||
def validate_features_structure(self, app_model: App, features: dict):
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
from .workflow_execute_task import chatflow_execute_task
|
||||
|
||||
__all__ = [
|
||||
"chatflow_execute_taskch"
|
||||
]
|
||||
__all__ = ["chatflow_execute_task"]
|
||||
|
||||
@ -8,18 +8,21 @@ from typing import Annotated, Any, TypeAlias, Union
|
||||
from celery import shared_task
|
||||
from flask import current_app, json
|
||||
from pydantic import BaseModel, Discriminator, Field, Tag
|
||||
from sqlalchemy import Engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy import Engine, select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
InvokeFrom,
|
||||
)
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
|
||||
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, WorkflowResumptionContext
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
from extensions.ext_database import db
|
||||
from libs.flask_utils import set_login_user
|
||||
from models.account import Account
|
||||
from models.model import App, AppMode, EndUser
|
||||
from models.workflow import Workflow
|
||||
from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
|
||||
from models.model import App, AppMode, Conversation, EndUser, Message
|
||||
from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom, WorkflowRun
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -129,6 +132,11 @@ class _ChatflowRunner:
|
||||
workflow = session.get(Workflow, exec_params.workflow_id)
|
||||
app = session.get(App, workflow.app_id)
|
||||
|
||||
pause_config = PauseStateLayerConfig(
|
||||
session_factory=self._session_factory,
|
||||
state_owner_user_id=workflow.created_by,
|
||||
)
|
||||
|
||||
user = self._resolve_user()
|
||||
|
||||
chat_generator = AdvancedChatAppGenerator()
|
||||
@ -144,6 +152,7 @@ class _ChatflowRunner:
|
||||
invoke_from=exec_params.invoke_from,
|
||||
streaming=exec_params.streaming,
|
||||
workflow_run_id=workflow_run_id,
|
||||
pause_state_config=pause_config,
|
||||
)
|
||||
if not exec_params.streaming:
|
||||
return response
|
||||
@ -174,11 +183,135 @@ class _ChatflowRunner:
|
||||
return user
|
||||
|
||||
|
||||
def _resolve_user_for_run(session: Session, workflow_run: WorkflowRun) -> Account | EndUser | None:
|
||||
role = CreatorUserRole(workflow_run.created_by_role)
|
||||
if role == CreatorUserRole.ACCOUNT:
|
||||
user = session.get(Account, workflow_run.created_by)
|
||||
if user:
|
||||
user.set_tenant_id(workflow_run.tenant_id)
|
||||
return user
|
||||
|
||||
return session.get(EndUser, workflow_run.created_by)
|
||||
|
||||
|
||||
@shared_task(queue="chatflow_execute")
|
||||
def chatflow_execute_task(payload: str) -> Mapping[str, Any] | None:
|
||||
exec_params = ChatflowExecutionParams.model_validate_json(payload)
|
||||
|
||||
print("chatflow_execute_task run with params", exec_params)
|
||||
logger.info("chatflow_execute_task run with params: %s", exec_params)
|
||||
|
||||
runner = _ChatflowRunner(db.engine, exec_params=exec_params)
|
||||
return runner.run()
|
||||
|
||||
|
||||
@shared_task(queue="chatflow_execute", name="resume_chatflow_execution")
|
||||
def resume_chatflow_execution(payload: dict[str, Any]) -> None:
|
||||
workflow_run_id = payload["workflow_run_id"]
|
||||
|
||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker=session_factory)
|
||||
|
||||
pause_entity = workflow_run_repo.get_workflow_pause(workflow_run_id)
|
||||
if pause_entity is None:
|
||||
logger.warning("No pause entity found for workflow run %s", workflow_run_id)
|
||||
return
|
||||
|
||||
try:
|
||||
resumption_context = WorkflowResumptionContext.loads(pause_entity.get_state().decode())
|
||||
except Exception:
|
||||
logger.exception("Failed to load resumption context for workflow run %s", workflow_run_id)
|
||||
return
|
||||
|
||||
generate_entity = resumption_context.get_generate_entity()
|
||||
if not isinstance(generate_entity, AdvancedChatAppGenerateEntity):
|
||||
logger.error(
|
||||
"Resumption entity is not AdvancedChatAppGenerateEntity for workflow run %s (found %s)",
|
||||
workflow_run_id,
|
||||
type(generate_entity),
|
||||
)
|
||||
return
|
||||
|
||||
graph_runtime_state = GraphRuntimeState.from_snapshot(resumption_context.serialized_graph_runtime_state)
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = session.get(WorkflowRun, workflow_run_id)
|
||||
if workflow_run is None:
|
||||
logger.warning("Workflow run %s not found during resume", workflow_run_id)
|
||||
return
|
||||
|
||||
workflow = session.get(Workflow, workflow_run.workflow_id)
|
||||
if workflow is None:
|
||||
logger.warning("Workflow %s not found during resume", workflow_run.workflow_id)
|
||||
return
|
||||
|
||||
app_model = session.get(App, workflow_run.app_id)
|
||||
if app_model is None:
|
||||
logger.warning("App %s not found during resume", workflow_run.app_id)
|
||||
return
|
||||
|
||||
if generate_entity.conversation_id is None:
|
||||
logger.warning("Conversation id missing in resumption context for workflow run %s", workflow_run_id)
|
||||
return
|
||||
|
||||
conversation = session.get(Conversation, generate_entity.conversation_id)
|
||||
if conversation is None:
|
||||
logger.warning(
|
||||
"Conversation %s not found for workflow run %s", generate_entity.conversation_id, workflow_run_id
|
||||
)
|
||||
return
|
||||
|
||||
message = session.scalar(
|
||||
select(Message).where(Message.workflow_run_id == workflow_run_id).order_by(Message.created_at.desc())
|
||||
)
|
||||
if message is None:
|
||||
logger.warning("Message not found for workflow run %s", workflow_run_id)
|
||||
return
|
||||
|
||||
user = _resolve_user_for_run(session, workflow_run)
|
||||
if user is None:
|
||||
logger.warning("User %s not found for workflow run %s", workflow_run.created_by, workflow_run_id)
|
||||
return
|
||||
|
||||
workflow_run_repo.resume_workflow_pause(workflow_run_id, pause_entity)
|
||||
|
||||
pause_config = PauseStateLayerConfig(
|
||||
session_factory=session_factory,
|
||||
state_owner_user_id=workflow.created_by,
|
||||
)
|
||||
|
||||
try:
|
||||
triggered_from = WorkflowRunTriggeredFrom(workflow_run.triggered_from)
|
||||
except ValueError:
|
||||
triggered_from = WorkflowRunTriggeredFrom.APP_RUN
|
||||
|
||||
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
|
||||
session_factory=session_factory,
|
||||
user=user,
|
||||
app_id=app_model.id,
|
||||
triggered_from=triggered_from,
|
||||
)
|
||||
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
|
||||
session_factory=session_factory,
|
||||
user=user,
|
||||
app_id=app_model.id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
generator = AdvancedChatAppGenerator()
|
||||
|
||||
try:
|
||||
generator.resume(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
conversation=conversation,
|
||||
message=message,
|
||||
application_generate_entity=generate_entity,
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
pause_state_config=pause_config,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to resume chatflow execution for workflow run %s", workflow_run_id)
|
||||
raise
|
||||
|
||||
@ -5,6 +5,7 @@ These tasks handle workflow execution for different subscription tiers
|
||||
with appropriate retry policies and error handling.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
@ -14,23 +15,37 @@ from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY, WorkflowAppGenerator
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, WorkflowResumptionContext
|
||||
from core.app.layers.timeslice_layer import TimeSliceLayer
|
||||
from core.app.layers.trigger_post_layer import TriggerPostLayer
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.enums import CreatorUserRole, WorkflowTriggerStatus
|
||||
from models.enums import AppTriggerType, CreatorUserRole, WorkflowRunTriggeredFrom, WorkflowTriggerStatus
|
||||
from models.model import App, EndUser, Tenant
|
||||
from models.trigger import WorkflowTriggerLog
|
||||
from models.workflow import Workflow
|
||||
from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
|
||||
from services.errors.app import WorkflowNotFoundError
|
||||
from services.workflow.entities import (
|
||||
TriggerData,
|
||||
WorkflowResumeTaskData,
|
||||
WorkflowTaskData,
|
||||
)
|
||||
from tasks.workflow_cfs_scheduler.cfs_scheduler import AsyncWorkflowCFSPlanEntity, AsyncWorkflowCFSPlanScheduler
|
||||
from tasks.workflow_cfs_scheduler.entities import AsyncWorkflowQueue, AsyncWorkflowSystemStrategy
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_TRIGGER_TO_RUN_SOURCE = {
|
||||
AppTriggerType.TRIGGER_WEBHOOK: WorkflowRunTriggeredFrom.WEBHOOK,
|
||||
AppTriggerType.TRIGGER_SCHEDULE: WorkflowRunTriggeredFrom.SCHEDULE,
|
||||
AppTriggerType.TRIGGER_PLUGIN: WorkflowRunTriggeredFrom.PLUGIN,
|
||||
}
|
||||
|
||||
|
||||
@shared_task(queue=AsyncWorkflowQueue.PROFESSIONAL_QUEUE)
|
||||
def execute_workflow_professional(task_data_dict: dict[str, Any]):
|
||||
@ -144,6 +159,11 @@ def _execute_workflow_common(
|
||||
if trigger_data.workflow_id:
|
||||
args["workflow_id"] = str(trigger_data.workflow_id)
|
||||
|
||||
pause_config = PauseStateLayerConfig(
|
||||
session_factory=session_factory,
|
||||
state_owner_user_id=workflow.created_by,
|
||||
)
|
||||
|
||||
# Execute the workflow with the trigger type
|
||||
generator.generate(
|
||||
app_model=app_model,
|
||||
@ -159,6 +179,7 @@ def _execute_workflow_common(
|
||||
# TODO: Re-enable TimeSliceLayer after the HITL release.
|
||||
TriggerPostLayer(cfs_plan_scheduler_entity, start_time, trigger_log.id, session_factory),
|
||||
],
|
||||
pause_state_config=pause_config,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@ -176,6 +197,117 @@ def _execute_workflow_common(
|
||||
session.commit()
|
||||
|
||||
|
||||
@shared_task(name="resume_workflow_execution")
|
||||
def resume_workflow_execution(task_data_dict: dict[str, Any]) -> None:
|
||||
"""Resume a paused workflow run via Celery."""
|
||||
task_data = WorkflowResumeTaskData.model_validate(task_data_dict)
|
||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_factory)
|
||||
|
||||
with session_factory() as session:
|
||||
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
|
||||
trigger_log = trigger_log_repo.get_by_id(task_data.workflow_trigger_log_id)
|
||||
if not trigger_log:
|
||||
logger.warning("Trigger log not found for resumption: %s", task_data.workflow_trigger_log_id)
|
||||
return
|
||||
|
||||
pause_entity = workflow_run_repo.get_workflow_pause(task_data.workflow_run_id)
|
||||
if pause_entity is None:
|
||||
logger.warning("No pause state for workflow run %s", task_data.workflow_run_id)
|
||||
return
|
||||
|
||||
try:
|
||||
resumption_context = WorkflowResumptionContext.loads(pause_entity.get_state().decode())
|
||||
except Exception as exc:
|
||||
logger.exception("Failed to load resumption context for workflow run %s", task_data.workflow_run_id)
|
||||
raise exc
|
||||
|
||||
generate_entity = resumption_context.get_generate_entity()
|
||||
if not isinstance(generate_entity, WorkflowAppGenerateEntity):
|
||||
logger.error(
|
||||
"Unsupported resumption entity for workflow run %s: %s",
|
||||
task_data.workflow_run_id,
|
||||
type(generate_entity),
|
||||
)
|
||||
return
|
||||
|
||||
graph_runtime_state = GraphRuntimeState.from_snapshot(resumption_context.serialized_graph_runtime_state)
|
||||
|
||||
workflow = session.scalar(select(Workflow).where(Workflow.id == trigger_log.workflow_id))
|
||||
if workflow is None:
|
||||
raise WorkflowNotFoundError(f"Workflow not found: {trigger_log.workflow_id}")
|
||||
|
||||
app_model = session.scalar(select(App).where(App.id == trigger_log.app_id))
|
||||
if app_model is None:
|
||||
raise WorkflowNotFoundError(f"App not found: {trigger_log.app_id}")
|
||||
|
||||
user = _get_user(session, trigger_log)
|
||||
|
||||
cfs_plan_scheduler_entity = AsyncWorkflowCFSPlanEntity(
|
||||
queue=trigger_log.queue_name,
|
||||
schedule_strategy=AsyncWorkflowSystemStrategy,
|
||||
granularity=dify_config.ASYNC_WORKFLOW_SCHEDULER_GRANULARITY,
|
||||
)
|
||||
cfs_plan_scheduler = AsyncWorkflowCFSPlanScheduler(plan=cfs_plan_scheduler_entity)
|
||||
|
||||
try:
|
||||
trigger_type = AppTriggerType(trigger_log.trigger_type)
|
||||
except ValueError:
|
||||
trigger_type = AppTriggerType.UNKNOWN
|
||||
|
||||
triggered_from = _TRIGGER_TO_RUN_SOURCE.get(trigger_type, WorkflowRunTriggeredFrom.APP_RUN)
|
||||
|
||||
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
|
||||
session_factory=session_factory,
|
||||
user=user,
|
||||
app_id=generate_entity.app_config.app_id,
|
||||
triggered_from=triggered_from,
|
||||
)
|
||||
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
|
||||
session_factory=session_factory,
|
||||
user=user,
|
||||
app_id=generate_entity.app_config.app_id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
pause_config = PauseStateLayerConfig(
|
||||
session_factory=session_factory,
|
||||
state_owner_user_id=workflow.created_by,
|
||||
)
|
||||
|
||||
workflow_run_repo.resume_workflow_pause(task_data.workflow_run_id, pause_entity)
|
||||
|
||||
trigger_log.status = WorkflowTriggerStatus.RUNNING
|
||||
trigger_log_repo.update(trigger_log)
|
||||
session.commit()
|
||||
|
||||
generator = WorkflowAppGenerator()
|
||||
start_time = datetime.now(UTC)
|
||||
|
||||
try:
|
||||
generator.resume(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
application_generate_entity=generate_entity,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
graph_engine_layers=[
|
||||
TimeSliceLayer(cfs_plan_scheduler),
|
||||
TriggerPostLayer(cfs_plan_scheduler_entity, start_time, trigger_log.id, session_factory),
|
||||
],
|
||||
pause_state_config=pause_config,
|
||||
)
|
||||
except Exception as exc:
|
||||
trigger_log.status = WorkflowTriggerStatus.FAILED
|
||||
trigger_log.error = str(exc)
|
||||
trigger_log.finished_at = datetime.now(UTC)
|
||||
trigger_log_repo.update(trigger_log)
|
||||
session.commit()
|
||||
raise
|
||||
|
||||
|
||||
def _get_user(session: Session, trigger_log: WorkflowTriggerLog) -> Account | EndUser:
|
||||
"""Compose user from trigger log"""
|
||||
tenant = session.scalar(select(Tenant).where(Tenant.id == trigger_log.tenant_id))
|
||||
|
||||
@ -0,0 +1,278 @@
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from types import ModuleType, SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
API_DIR = str(Path(__file__).resolve().parents[5])
|
||||
if API_DIR not in sys.path:
|
||||
sys.path.insert(0, API_DIR)
|
||||
|
||||
import core.workflow.nodes.human_input.entities # noqa: F401
|
||||
from core.app.apps.advanced_chat import app_generator as adv_app_gen_module
|
||||
from core.app.apps.workflow import app_generator as wf_app_gen_module
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine import GraphEngine
|
||||
from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel
|
||||
from core.workflow.graph_engine.entities.commands import PauseCommand
|
||||
from core.workflow.graph_events import (
|
||||
GraphEngineEvent,
|
||||
GraphRunPausedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig, VariableSelector
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.end.end_node import EndNode
|
||||
from core.workflow.nodes.end.entities import EndNodeData
|
||||
from core.workflow.nodes.start.entities import StartNodeData
|
||||
from core.workflow.nodes.start.start_node import StartNode
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
|
||||
if "core.ops.ops_trace_manager" not in sys.modules:
|
||||
ops_stub = ModuleType("core.ops.ops_trace_manager")
|
||||
|
||||
class _StubTraceQueueManager:
|
||||
def __init__(self, *_, **__):
|
||||
pass
|
||||
|
||||
ops_stub.TraceQueueManager = _StubTraceQueueManager
|
||||
sys.modules["core.ops.ops_trace_manager"] = ops_stub
|
||||
|
||||
|
||||
class _StubToolNodeData(BaseNodeData):
|
||||
pass
|
||||
|
||||
|
||||
class _StubToolNode(Node):
|
||||
node_type = NodeType.TOOL
|
||||
|
||||
def init_node_data(self, data):
|
||||
self._node_data = _StubToolNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self):
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
return self._node_data.retry_config
|
||||
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self):
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
return self._node_data.default_value_dict
|
||||
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
return self._node_data
|
||||
|
||||
def _run(self):
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={"value": f"{self.id}-done"},
|
||||
)
|
||||
|
||||
|
||||
def _patch_tool_node(mocker):
|
||||
from core.workflow.nodes import node_factory
|
||||
|
||||
custom_mapping = dict(node_factory.NODE_TYPE_CLASSES_MAPPING)
|
||||
custom_versions = dict(custom_mapping[NodeType.TOOL])
|
||||
custom_versions[node_factory.LATEST_VERSION] = _StubToolNode
|
||||
custom_mapping[NodeType.TOOL] = custom_versions
|
||||
mocker.patch("core.workflow.nodes.node_factory.NODE_TYPE_CLASSES_MAPPING", custom_mapping)
|
||||
|
||||
|
||||
def _build_graph(runtime_state: GraphRuntimeState) -> Graph:
|
||||
params = GraphInitParams(
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
graph_config={},
|
||||
user_id="user",
|
||||
user_from="account",
|
||||
invoke_from="service-api",
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
start_data = StartNodeData(title="start", variables=[])
|
||||
start_node = StartNode(
|
||||
id="start",
|
||||
config={"id": "start", "data": start_data.model_dump()},
|
||||
graph_init_params=params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
start_node.init_node_data(start_data.model_dump())
|
||||
|
||||
tool_data = _StubToolNodeData(title="tool")
|
||||
tool_a = _StubToolNode(
|
||||
id="tool_a",
|
||||
config={"id": "tool_a", "data": tool_data.model_dump()},
|
||||
graph_init_params=params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
tool_a.init_node_data(tool_data.model_dump())
|
||||
|
||||
tool_b = _StubToolNode(
|
||||
id="tool_b",
|
||||
config={"id": "tool_b", "data": tool_data.model_dump()},
|
||||
graph_init_params=params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
tool_b.init_node_data(tool_data.model_dump())
|
||||
|
||||
tool_c = _StubToolNode(
|
||||
id="tool_c",
|
||||
config={"id": "tool_c", "data": tool_data.model_dump()},
|
||||
graph_init_params=params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
tool_c.init_node_data(tool_data.model_dump())
|
||||
|
||||
end_data = EndNodeData(
|
||||
title="end",
|
||||
outputs=[VariableSelector(variable="result", value_selector=["tool_c", "value"])],
|
||||
desc=None,
|
||||
)
|
||||
end_node = EndNode(
|
||||
id="end",
|
||||
config={"id": "end", "data": end_data.model_dump()},
|
||||
graph_init_params=params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
end_node.init_node_data(end_data.model_dump())
|
||||
|
||||
return (
|
||||
Graph.new()
|
||||
.add_root(start_node)
|
||||
.add_node(tool_a)
|
||||
.add_node(tool_b)
|
||||
.add_node(tool_c)
|
||||
.add_node(end_node)
|
||||
.add_edge("tool_a", "tool_b")
|
||||
.add_edge("tool_b", "tool_c")
|
||||
.add_edge("tool_c", "end")
|
||||
.build()
|
||||
)
|
||||
|
||||
|
||||
def _build_runtime_state(run_id: str) -> GraphRuntimeState:
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"),
|
||||
user_inputs={},
|
||||
conversation_variables=[],
|
||||
)
|
||||
variable_pool.system_variables.workflow_execution_id = run_id
|
||||
return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
|
||||
def _run_with_optional_pause(runtime_state: GraphRuntimeState, *, pause_on: str | None) -> list[GraphEngineEvent]:
|
||||
command_channel = InMemoryChannel()
|
||||
graph = _build_graph(runtime_state)
|
||||
engine = GraphEngine(
|
||||
workflow_id="workflow",
|
||||
graph=graph,
|
||||
graph_runtime_state=runtime_state,
|
||||
command_channel=command_channel,
|
||||
)
|
||||
|
||||
events: list[GraphEngineEvent] = []
|
||||
for event in engine.run():
|
||||
if isinstance(event, NodeRunSucceededEvent) and pause_on and event.node_id == pause_on:
|
||||
command_channel.send_command(PauseCommand(reason="test pause"))
|
||||
engine._command_processor.process_commands() # type: ignore[attr-defined]
|
||||
events.append(event)
|
||||
return events
|
||||
|
||||
|
||||
def _node_successes(events: list[GraphEngineEvent]) -> list[str]:
|
||||
return [evt.node_id for evt in events if isinstance(evt, NodeRunSucceededEvent)]
|
||||
|
||||
|
||||
def test_workflow_app_pause_resume_matches_baseline(mocker):
|
||||
_patch_tool_node(mocker)
|
||||
|
||||
baseline_state = _build_runtime_state("baseline")
|
||||
baseline_events = _run_with_optional_pause(baseline_state, pause_on=None)
|
||||
assert isinstance(baseline_events[-1], GraphRunSucceededEvent)
|
||||
baseline_nodes = _node_successes(baseline_events)
|
||||
baseline_outputs = baseline_state.outputs
|
||||
|
||||
paused_state = _build_runtime_state("paused-run")
|
||||
paused_events = _run_with_optional_pause(paused_state, pause_on="tool_a")
|
||||
assert isinstance(paused_events[-1], GraphRunPausedEvent)
|
||||
paused_nodes = _node_successes(paused_events)
|
||||
snapshot = paused_state.dumps()
|
||||
|
||||
resumed_state = GraphRuntimeState.from_snapshot(snapshot)
|
||||
|
||||
generator = wf_app_gen_module.WorkflowAppGenerator()
|
||||
|
||||
def _fake_generate(**kwargs):
|
||||
state: GraphRuntimeState = kwargs["graph_runtime_state"]
|
||||
events = _run_with_optional_pause(state, pause_on=None)
|
||||
return _node_successes(events)
|
||||
|
||||
mocker.patch.object(generator, "_generate", side_effect=_fake_generate)
|
||||
|
||||
resumed_nodes = generator.resume(
|
||||
app_model=SimpleNamespace(mode="workflow"),
|
||||
workflow=SimpleNamespace(),
|
||||
user=SimpleNamespace(),
|
||||
application_generate_entity=SimpleNamespace(stream=False, invoke_from=InvokeFrom.SERVICE_API),
|
||||
graph_runtime_state=resumed_state,
|
||||
workflow_execution_repository=SimpleNamespace(),
|
||||
workflow_node_execution_repository=SimpleNamespace(),
|
||||
)
|
||||
|
||||
assert paused_nodes + resumed_nodes == baseline_nodes
|
||||
assert resumed_state.outputs == baseline_outputs
|
||||
|
||||
|
||||
def test_advanced_chat_pause_resume_matches_baseline(mocker):
|
||||
_patch_tool_node(mocker)
|
||||
|
||||
baseline_state = _build_runtime_state("adv-baseline")
|
||||
baseline_events = _run_with_optional_pause(baseline_state, pause_on=None)
|
||||
assert isinstance(baseline_events[-1], GraphRunSucceededEvent)
|
||||
baseline_nodes = _node_successes(baseline_events)
|
||||
baseline_outputs = baseline_state.outputs
|
||||
|
||||
paused_state = _build_runtime_state("adv-paused")
|
||||
paused_events = _run_with_optional_pause(paused_state, pause_on="tool_a")
|
||||
assert isinstance(paused_events[-1], GraphRunPausedEvent)
|
||||
paused_nodes = _node_successes(paused_events)
|
||||
snapshot = paused_state.dumps()
|
||||
|
||||
resumed_state = GraphRuntimeState.from_snapshot(snapshot)
|
||||
|
||||
generator = adv_app_gen_module.AdvancedChatAppGenerator()
|
||||
|
||||
def _fake_generate(**kwargs):
|
||||
state: GraphRuntimeState = kwargs["graph_runtime_state"]
|
||||
events = _run_with_optional_pause(state, pause_on=None)
|
||||
return _node_successes(events)
|
||||
|
||||
mocker.patch.object(generator, "_generate", side_effect=_fake_generate)
|
||||
|
||||
resumed_nodes = generator.resume(
|
||||
app_model=SimpleNamespace(mode="workflow"),
|
||||
workflow=SimpleNamespace(),
|
||||
user=SimpleNamespace(),
|
||||
conversation=SimpleNamespace(id="conv"),
|
||||
message=SimpleNamespace(id="msg"),
|
||||
application_generate_entity=SimpleNamespace(stream=False, invoke_from=InvokeFrom.SERVICE_API),
|
||||
workflow_execution_repository=SimpleNamespace(),
|
||||
workflow_node_execution_repository=SimpleNamespace(),
|
||||
graph_runtime_state=resumed_state,
|
||||
)
|
||||
|
||||
assert paused_nodes + resumed_nodes == baseline_nodes
|
||||
assert resumed_state.outputs == baseline_outputs
|
||||
@ -61,8 +61,8 @@ class ConcurrentPublisher:
|
||||
messages.append(message)
|
||||
if self.delay > 0:
|
||||
time.sleep(self.delay)
|
||||
except Exception as e:
|
||||
_logger.error("Publisher %s error: %s", thread_id, e)
|
||||
except Exception:
|
||||
_logger.exception("Pubmsg=lisher %s", thread_id)
|
||||
|
||||
with self._lock:
|
||||
self.published_messages.append(messages)
|
||||
@ -308,8 +308,8 @@ def measure_throughput(
|
||||
try:
|
||||
operation()
|
||||
count += 1
|
||||
except Exception as e:
|
||||
_logger.error("Operation failed: %s", e)
|
||||
except Exception:
|
||||
_logger.exception("Operation failed")
|
||||
break
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
@ -1,4 +1,26 @@
|
||||
import sys
|
||||
from types import ModuleType, SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
if "core.ops.ops_trace_manager" not in sys.modules:
|
||||
stub_module = ModuleType("core.ops.ops_trace_manager")
|
||||
|
||||
class _StubTraceQueueManager:
|
||||
def __init__(self, *_, **__):
|
||||
pass
|
||||
|
||||
stub_module.TraceQueueManager = _StubTraceQueueManager
|
||||
sys.modules["core.ops.ops_trace_manager"] = stub_module
|
||||
|
||||
from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY, WorkflowAppGenerator
|
||||
from tests.unit_tests.core.workflow.graph_engine.test_pause_resume_state import (
|
||||
_build_pausing_graph,
|
||||
_build_runtime_state,
|
||||
_node_successes,
|
||||
_PausingNode,
|
||||
_PausingNodeData,
|
||||
_run_graph,
|
||||
)
|
||||
|
||||
|
||||
def test_should_prepare_user_inputs_defaults_to_true():
|
||||
@ -17,3 +39,193 @@ def test_should_prepare_user_inputs_keeps_validation_when_flag_false():
|
||||
args = {"inputs": {}, SKIP_PREPARE_USER_INPUTS_KEY: False}
|
||||
|
||||
assert WorkflowAppGenerator()._should_prepare_user_inputs(args)
|
||||
|
||||
|
||||
def test_resume_delegates_to_generate(mocker):
|
||||
generator = WorkflowAppGenerator()
|
||||
mock_generate = mocker.patch.object(generator, "_generate", return_value="ok")
|
||||
|
||||
application_generate_entity = SimpleNamespace(stream=False, invoke_from="debugger")
|
||||
runtime_state = MagicMock(name="runtime-state")
|
||||
pause_config = MagicMock(name="pause-config")
|
||||
|
||||
result = generator.resume(
|
||||
app_model=MagicMock(),
|
||||
workflow=MagicMock(),
|
||||
user=MagicMock(),
|
||||
application_generate_entity=application_generate_entity,
|
||||
graph_runtime_state=runtime_state,
|
||||
workflow_execution_repository=MagicMock(),
|
||||
workflow_node_execution_repository=MagicMock(),
|
||||
graph_engine_layers=("layer",),
|
||||
pause_state_config=pause_config,
|
||||
variable_loader=MagicMock(),
|
||||
)
|
||||
|
||||
assert result == "ok"
|
||||
mock_generate.assert_called_once()
|
||||
kwargs = mock_generate.call_args.kwargs
|
||||
assert kwargs["graph_runtime_state"] is runtime_state
|
||||
assert kwargs["pause_state_config"] is pause_config
|
||||
assert kwargs["streaming"] is False
|
||||
assert kwargs["invoke_from"] == "debugger"
|
||||
|
||||
|
||||
def test_generate_appends_pause_layer_and_forwards_state(mocker):
|
||||
generator = WorkflowAppGenerator()
|
||||
|
||||
mock_queue_manager = MagicMock()
|
||||
mocker.patch("core.app.apps.workflow.app_generator.WorkflowAppQueueManager", return_value=mock_queue_manager)
|
||||
|
||||
fake_current_app = MagicMock()
|
||||
fake_current_app._get_current_object.return_value = MagicMock()
|
||||
mocker.patch("core.app.apps.workflow.app_generator.current_app", fake_current_app)
|
||||
|
||||
mocker.patch(
|
||||
"core.app.apps.workflow.app_generator.WorkflowAppGenerateResponseConverter.convert",
|
||||
return_value="converted",
|
||||
)
|
||||
mocker.patch.object(WorkflowAppGenerator, "_handle_response", return_value="response")
|
||||
mocker.patch.object(WorkflowAppGenerator, "_get_draft_var_saver_factory", return_value=MagicMock())
|
||||
|
||||
pause_layer = MagicMock(name="pause-layer")
|
||||
mocker.patch(
|
||||
"core.app.apps.workflow.app_generator.PauseStatePersistenceLayer",
|
||||
return_value=pause_layer,
|
||||
)
|
||||
|
||||
dummy_session = MagicMock()
|
||||
dummy_session.close = MagicMock()
|
||||
mocker.patch("core.app.apps.workflow.app_generator.db.session", dummy_session)
|
||||
|
||||
worker_kwargs: dict[str, object] = {}
|
||||
|
||||
class DummyThread:
|
||||
def __init__(self, target, kwargs):
|
||||
worker_kwargs["target"] = target
|
||||
worker_kwargs["kwargs"] = kwargs
|
||||
|
||||
def start(self):
|
||||
return None
|
||||
|
||||
mocker.patch("core.app.apps.workflow.app_generator.threading.Thread", DummyThread)
|
||||
|
||||
app_model = SimpleNamespace(mode="workflow")
|
||||
app_config = SimpleNamespace(app_id="app", tenant_id="tenant", workflow_id="wf")
|
||||
application_generate_entity = SimpleNamespace(
|
||||
task_id="task",
|
||||
user_id="user",
|
||||
invoke_from="service-api",
|
||||
app_config=app_config,
|
||||
files=[],
|
||||
stream=True,
|
||||
workflow_execution_id="run",
|
||||
)
|
||||
|
||||
graph_runtime_state = MagicMock()
|
||||
|
||||
result = generator._generate(
|
||||
app_model=app_model,
|
||||
workflow=MagicMock(),
|
||||
user=MagicMock(),
|
||||
application_generate_entity=application_generate_entity,
|
||||
invoke_from="service-api",
|
||||
workflow_execution_repository=MagicMock(),
|
||||
workflow_node_execution_repository=MagicMock(),
|
||||
streaming=True,
|
||||
graph_engine_layers=("base-layer",),
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
pause_state_config=SimpleNamespace(session_factory=MagicMock(), state_owner_user_id="owner"),
|
||||
)
|
||||
|
||||
assert result == "converted"
|
||||
assert worker_kwargs["kwargs"]["graph_engine_layers"] == ("base-layer", pause_layer)
|
||||
assert worker_kwargs["kwargs"]["graph_runtime_state"] is graph_runtime_state
|
||||
|
||||
|
||||
def test_resume_path_runs_worker_with_runtime_state(mocker):
|
||||
generator = WorkflowAppGenerator()
|
||||
runtime_state = MagicMock(name="runtime-state")
|
||||
|
||||
pause_layer = MagicMock(name="pause-layer")
|
||||
mocker.patch("core.app.apps.workflow.app_generator.PauseStatePersistenceLayer", return_value=pause_layer)
|
||||
|
||||
queue_manager = MagicMock()
|
||||
mocker.patch("core.app.apps.workflow.app_generator.WorkflowAppQueueManager", return_value=queue_manager)
|
||||
|
||||
mocker.patch.object(generator, "_handle_response", return_value="raw-response")
|
||||
mocker.patch(
|
||||
"core.app.apps.workflow.app_generator.WorkflowAppGenerateResponseConverter.convert",
|
||||
side_effect=lambda response, invoke_from: response,
|
||||
)
|
||||
|
||||
fake_db = SimpleNamespace(session=MagicMock(), engine=MagicMock())
|
||||
mocker.patch("core.app.apps.workflow.app_generator.db", fake_db)
|
||||
|
||||
workflow = SimpleNamespace(
|
||||
id="workflow", tenant_id="tenant", app_id="app", graph_dict={}, type="workflow", version="1"
|
||||
)
|
||||
end_user = SimpleNamespace(session_id="end-user-session")
|
||||
app_record = SimpleNamespace(id="app")
|
||||
|
||||
session = MagicMock()
|
||||
session.__enter__.return_value = session
|
||||
session.__exit__.return_value = False
|
||||
session.scalar.side_effect = [workflow, end_user, app_record]
|
||||
mocker.patch("core.app.apps.workflow.app_generator.Session", return_value=session)
|
||||
|
||||
runner_instance = MagicMock()
|
||||
|
||||
def runner_ctor(**kwargs):
|
||||
assert kwargs["graph_runtime_state"] is runtime_state
|
||||
return runner_instance
|
||||
|
||||
mocker.patch("core.app.apps.workflow.app_generator.WorkflowAppRunner", side_effect=runner_ctor)
|
||||
|
||||
class ImmediateThread:
|
||||
def __init__(self, target, kwargs):
|
||||
target(**kwargs)
|
||||
|
||||
def start(self):
|
||||
return None
|
||||
|
||||
mocker.patch("core.app.apps.workflow.app_generator.threading.Thread", ImmediateThread)
|
||||
|
||||
mocker.patch(
|
||||
"core.app.apps.workflow.app_generator.DifyCoreRepositoryFactory.create_workflow_execution_repository",
|
||||
return_value=MagicMock(),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.workflow.app_generator.DifyCoreRepositoryFactory.create_workflow_node_execution_repository",
|
||||
return_value=MagicMock(),
|
||||
)
|
||||
|
||||
pause_config = SimpleNamespace(session_factory=MagicMock(), state_owner_user_id="owner")
|
||||
|
||||
app_model = SimpleNamespace(mode="workflow")
|
||||
app_config = SimpleNamespace(app_id="app", tenant_id="tenant", workflow_id="workflow")
|
||||
application_generate_entity = SimpleNamespace(
|
||||
task_id="task",
|
||||
user_id="user",
|
||||
invoke_from="service-api",
|
||||
app_config=app_config,
|
||||
files=[],
|
||||
stream=True,
|
||||
workflow_execution_id="run",
|
||||
trace_manager=MagicMock(),
|
||||
)
|
||||
|
||||
result = generator.resume(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=MagicMock(),
|
||||
application_generate_entity=application_generate_entity,
|
||||
graph_runtime_state=runtime_state,
|
||||
workflow_execution_repository=MagicMock(),
|
||||
workflow_node_execution_repository=MagicMock(),
|
||||
pause_state_config=pause_config,
|
||||
)
|
||||
|
||||
assert result == "raw-response"
|
||||
runner_instance.run.assert_called_once()
|
||||
queue_manager.graph_runtime_state = runtime_state
|
||||
|
||||
@ -1,317 +0,0 @@
|
||||
"""
|
||||
Tests for HumanInputForm domain model and repository.
|
||||
"""
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from core.repositories.sqlalchemy_human_input_form_repository import SQLAlchemyHumanInputFormRepository
|
||||
from core.workflow.entities.human_input_form import HumanInputForm, HumanInputFormStatus
|
||||
|
||||
|
||||
class TestHumanInputForm:
|
||||
"""Test cases for HumanInputForm domain model."""
|
||||
|
||||
def test_create_form(self):
|
||||
"""Test creating a new form."""
|
||||
form = HumanInputForm.create(
|
||||
id_="test-form-id",
|
||||
workflow_run_id="test-workflow-run",
|
||||
form_definition={"inputs": [], "user_actions": [{"id": "submit", "title": "Submit"}]},
|
||||
rendered_content="<form>Test form</form>",
|
||||
)
|
||||
|
||||
assert form.id_ == "test-form-id"
|
||||
assert form.workflow_run_id == "test-workflow-run"
|
||||
assert form.status == HumanInputFormStatus.WAITING
|
||||
assert form.can_be_submitted
|
||||
assert not form.is_submitted
|
||||
assert not form.is_expired
|
||||
assert form.is_waiting
|
||||
|
||||
def test_submit_form(self):
|
||||
"""Test submitting a form."""
|
||||
form = HumanInputForm.create(
|
||||
id_="test-form-id",
|
||||
workflow_run_id="test-workflow-run",
|
||||
form_definition={"inputs": [], "user_actions": [{"id": "submit", "title": "Submit"}]},
|
||||
rendered_content="<form>Test form</form>",
|
||||
)
|
||||
|
||||
form.submit(
|
||||
data={"field1": "value1"},
|
||||
action="submit",
|
||||
submission_user_id="user123",
|
||||
)
|
||||
|
||||
assert form.is_submitted
|
||||
assert not form.can_be_submitted
|
||||
assert form.status == HumanInputFormStatus.SUBMITTED
|
||||
assert form.submission is not None
|
||||
assert form.submission.data == {"field1": "value1"}
|
||||
assert form.submission.action == "submit"
|
||||
assert form.submission.submission_user_id == "user123"
|
||||
|
||||
def test_submit_form_invalid_action(self):
|
||||
"""Test submitting a form with invalid action."""
|
||||
form = HumanInputForm.create(
|
||||
id_="test-form-id",
|
||||
workflow_run_id="test-workflow-run",
|
||||
form_definition={"inputs": [], "user_actions": [{"id": "submit", "title": "Submit"}]},
|
||||
rendered_content="<form>Test form</form>",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid action: invalid_action"):
|
||||
form.submit(data={}, action="invalid_action")
|
||||
|
||||
def test_submit_expired_form(self):
|
||||
"""Test submitting an expired form should fail."""
|
||||
form = HumanInputForm.create(
|
||||
id_="test-form-id",
|
||||
workflow_run_id="test-workflow-run",
|
||||
form_definition={"inputs": [], "user_actions": [{"id": "submit", "title": "Submit"}]},
|
||||
rendered_content="<form>Test form</form>",
|
||||
)
|
||||
form.expire()
|
||||
|
||||
with pytest.raises(ValueError, match="Form cannot be submitted in status: expired"):
|
||||
form.submit(data={}, action="submit")
|
||||
|
||||
def test_expire_form(self):
|
||||
"""Test expiring a form."""
|
||||
form = HumanInputForm.create(
|
||||
id_="test-form-id",
|
||||
workflow_run_id="test-workflow-run",
|
||||
form_definition={"inputs": [], "user_actions": [{"id": "submit", "title": "Submit"}]},
|
||||
rendered_content="<form>Test form</form>",
|
||||
)
|
||||
|
||||
form.expire()
|
||||
assert form.is_expired
|
||||
assert not form.can_be_submitted
|
||||
assert form.status == HumanInputFormStatus.EXPIRED
|
||||
|
||||
def test_expire_already_submitted_form(self):
|
||||
"""Test expiring an already submitted form should fail."""
|
||||
form = HumanInputForm.create(
|
||||
id_="test-form-id",
|
||||
workflow_run_id="test-workflow-run",
|
||||
form_definition={"inputs": [], "user_actions": [{"id": "submit", "title": "Submit"}]},
|
||||
rendered_content="<form>Test form</form>",
|
||||
)
|
||||
form.submit(data={}, action="submit")
|
||||
|
||||
with pytest.raises(ValueError, match="Form cannot be expired in status: submitted"):
|
||||
form.expire()
|
||||
|
||||
def test_get_form_definition_for_display(self):
|
||||
"""Test getting form definition for display."""
|
||||
form_definition = {
|
||||
"inputs": [{"type": "text", "name": "field1"}],
|
||||
"user_actions": [{"id": "submit", "title": "Submit"}],
|
||||
}
|
||||
form = HumanInputForm.create(
|
||||
id_="test-form-id",
|
||||
workflow_run_id="test-workflow-run",
|
||||
form_definition=form_definition,
|
||||
rendered_content="<form>Test form</form>",
|
||||
)
|
||||
|
||||
result = form.get_form_definition_for_display()
|
||||
|
||||
assert result["form_content"] == "<form>Test form</form>"
|
||||
assert result["inputs"] == form_definition["inputs"]
|
||||
assert result["user_actions"] == form_definition["user_actions"]
|
||||
assert "site" not in result
|
||||
|
||||
def test_get_form_definition_for_display_with_site_info(self):
|
||||
"""Test getting form definition for display with site info."""
|
||||
form = HumanInputForm.create(
|
||||
id_="test-form-id",
|
||||
workflow_run_id="test-workflow-run",
|
||||
form_definition={"inputs": [], "user_actions": []},
|
||||
rendered_content="<form>Test form</form>",
|
||||
)
|
||||
|
||||
result = form.get_form_definition_for_display(include_site_info=True)
|
||||
|
||||
assert "site" in result
|
||||
assert result["site"]["title"] == "Workflow Form"
|
||||
|
||||
def test_get_form_definition_expired_form(self):
|
||||
"""Test getting form definition for expired form should fail."""
|
||||
form = HumanInputForm.create(
|
||||
id_="test-form-id",
|
||||
workflow_run_id="test-workflow-run",
|
||||
form_definition={"inputs": [], "user_actions": []},
|
||||
rendered_content="<form>Test form</form>",
|
||||
)
|
||||
form.expire()
|
||||
|
||||
with pytest.raises(ValueError, match="Form has expired"):
|
||||
form.get_form_definition_for_display()
|
||||
|
||||
def test_get_form_definition_submitted_form(self):
|
||||
"""Test getting form definition for submitted form should fail."""
|
||||
form = HumanInputForm.create(
|
||||
id_="test-form-id",
|
||||
workflow_run_id="test-workflow-run",
|
||||
form_definition={"inputs": [], "user_actions": [{"id": "submit", "title": "Submit"}]},
|
||||
rendered_content="<form>Test form</form>",
|
||||
)
|
||||
form.submit(data={}, action="submit")
|
||||
|
||||
with pytest.raises(ValueError, match="Form has already been submitted"):
|
||||
form.get_form_definition_for_display()
|
||||
|
||||
|
||||
class TestSQLAlchemyHumanInputFormRepository:
|
||||
"""Test cases for SQLAlchemyHumanInputFormRepository."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session_factory(self):
|
||||
"""Create a mock session factory."""
|
||||
session = MagicMock()
|
||||
session_factory = MagicMock()
|
||||
session_factory.return_value.__enter__.return_value = session
|
||||
session_factory.return_value.__exit__.return_value = None
|
||||
return session_factory
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user(self):
|
||||
"""Create a mock user."""
|
||||
user = MagicMock()
|
||||
user.current_tenant_id = "test-tenant-id"
|
||||
user.id = "test-user-id"
|
||||
return user
|
||||
|
||||
@pytest.fixture
|
||||
def repository(self, mock_session_factory, mock_user):
|
||||
"""Create a repository instance."""
|
||||
return SQLAlchemyHumanInputFormRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_user,
|
||||
app_id="test-app-id",
|
||||
)
|
||||
|
||||
def test_to_domain_model(self, repository):
|
||||
"""Test converting DB model to domain model."""
|
||||
from models.human_input import (
|
||||
HumanInputForm as DBForm,
|
||||
)
|
||||
from models.human_input import (
|
||||
HumanInputFormStatus as DBStatus,
|
||||
)
|
||||
from models.human_input import (
|
||||
HumanInputSubmissionType as DBSubmissionType,
|
||||
)
|
||||
|
||||
db_form = DBForm()
|
||||
db_form.id = "test-id"
|
||||
db_form.workflow_run_id = "test-workflow"
|
||||
db_form.form_definition = json.dumps({"inputs": [], "user_actions": []})
|
||||
db_form.rendered_content = "<form>Test</form>"
|
||||
db_form.status = DBStatus.WAITING
|
||||
db_form.web_app_token = "test-token"
|
||||
db_form.created_at = datetime.utcnow()
|
||||
db_form.submitted_data = json.dumps({"field": "value"})
|
||||
db_form.submitted_at = datetime.utcnow()
|
||||
db_form.submission_type = DBSubmissionType.web_form
|
||||
db_form.submission_user_id = "user123"
|
||||
|
||||
domain_form = repository._to_domain_model(db_form)
|
||||
|
||||
assert domain_form.id_ == "test-id"
|
||||
assert domain_form.workflow_run_id == "test-workflow"
|
||||
assert domain_form.form_definition == {"inputs": [], "user_actions": []}
|
||||
assert domain_form.rendered_content == "<form>Test</form>"
|
||||
assert domain_form.status == HumanInputFormStatus.WAITING
|
||||
assert domain_form.web_app_token == "test-token"
|
||||
assert domain_form.submission is not None
|
||||
assert domain_form.submission.data == {"field": "value"}
|
||||
assert domain_form.submission.submission_user_id == "user123"
|
||||
|
||||
def test_to_db_model(self, repository):
|
||||
"""Test converting domain model to DB model."""
|
||||
from models.human_input import (
|
||||
HumanInputFormStatus as DBStatus,
|
||||
)
|
||||
|
||||
domain_form = HumanInputForm.create(
|
||||
id_="test-id",
|
||||
workflow_run_id="test-workflow",
|
||||
form_definition={"inputs": [], "user_actions": []},
|
||||
rendered_content="<form>Test</form>",
|
||||
web_app_token="test-token",
|
||||
)
|
||||
|
||||
db_form = repository._to_db_model(domain_form)
|
||||
|
||||
assert db_form.id == "test-id"
|
||||
assert db_form.tenant_id == "test-tenant-id"
|
||||
assert db_form.app_id == "test-app-id"
|
||||
assert db_form.workflow_run_id == "test-workflow"
|
||||
assert json.loads(db_form.form_definition) == {"inputs": [], "user_actions": []}
|
||||
assert db_form.rendered_content == "<form>Test</form>"
|
||||
assert db_form.status == DBStatus.WAITING
|
||||
assert db_form.web_app_token == "test-token"
|
||||
|
||||
def test_save(self, repository, mock_session_factory):
|
||||
"""Test saving a form."""
|
||||
session = mock_session_factory.return_value.__enter__.return_value
|
||||
domain_form = HumanInputForm.create(
|
||||
id_="test-id",
|
||||
workflow_run_id="test-workflow",
|
||||
form_definition={"inputs": []},
|
||||
rendered_content="<form>Test</form>",
|
||||
)
|
||||
|
||||
repository.save(domain_form)
|
||||
|
||||
session.merge.assert_called_once()
|
||||
session.commit.assert_called_once()
|
||||
|
||||
def test_get_by_id(self, repository, mock_session_factory):
|
||||
"""Test getting a form by ID."""
|
||||
session = mock_session_factory.return_value.__enter__.return_value
|
||||
mock_db_form = MagicMock()
|
||||
mock_db_form.id = "test-id"
|
||||
session.scalar.return_value = mock_db_form
|
||||
|
||||
with patch.object(repository, "_to_domain_model") as mock_convert:
|
||||
domain_form = HumanInputForm.create(
|
||||
id_="test-id",
|
||||
workflow_run_id="test-workflow",
|
||||
form_definition={"inputs": []},
|
||||
rendered_content="<form>Test</form>",
|
||||
)
|
||||
mock_convert.return_value = domain_form
|
||||
|
||||
result = repository.get_by_id("test-id")
|
||||
|
||||
assert result == domain_form
|
||||
session.scalar.assert_called_once()
|
||||
mock_convert.assert_called_once_with(mock_db_form)
|
||||
|
||||
def test_get_by_id_not_found(self, repository, mock_session_factory):
|
||||
"""Test getting a non-existent form by ID."""
|
||||
session = mock_session_factory.return_value.__enter__.return_value
|
||||
session.scalar.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="Human input form not found: test-id"):
|
||||
repository.get_by_id("test-id")
|
||||
|
||||
def test_mark_expired_forms(self, repository, mock_session_factory):
|
||||
"""Test marking expired forms."""
|
||||
session = mock_session_factory.return_value.__enter__.return_value
|
||||
mock_forms = [MagicMock(), MagicMock(), MagicMock()]
|
||||
session.scalars.return_value.all.return_value = mock_forms
|
||||
|
||||
result = repository.mark_expired_forms(expiry_hours=24)
|
||||
|
||||
assert result == 3
|
||||
for form in mock_forms:
|
||||
assert hasattr(form, "status")
|
||||
session.commit.assert_called_once()
|
||||
@ -2,16 +2,27 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.repositories.human_input_reposotiry import (
|
||||
HumanInputFormReadRepository,
|
||||
HumanInputFormRecord,
|
||||
HumanInputFormRepositoryImpl,
|
||||
_WorkspaceMemberInfo,
|
||||
)
|
||||
from core.workflow.nodes.human_input.entities import ExternalRecipient, MemberRecipient
|
||||
from core.workflow.nodes.human_input.entities import (
|
||||
ExternalRecipient,
|
||||
FormDefinition,
|
||||
MemberRecipient,
|
||||
TimeoutUnit,
|
||||
UserAction,
|
||||
)
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.human_input import (
|
||||
EmailExternalRecipientPayload,
|
||||
EmailMemberRecipientPayload,
|
||||
@ -41,6 +52,23 @@ def _patch_recipient_factory(monkeypatch: pytest.MonkeyPatch) -> list[SimpleName
|
||||
return created
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _stub_selectinload(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Avoid SQLAlchemy mapper configuration in tests using fake sessions."""
|
||||
|
||||
class _FakeSelect:
|
||||
def options(self, *_args, **_kwargs): # type: ignore[no-untyped-def]
|
||||
return self
|
||||
|
||||
def where(self, *_args, **_kwargs): # type: ignore[no-untyped-def]
|
||||
return self
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.human_input_reposotiry.selectinload", lambda *args, **kwargs: "_loader_option"
|
||||
)
|
||||
monkeypatch.setattr("core.repositories.human_input_reposotiry.select", lambda *args, **kwargs: _FakeSelect())
|
||||
|
||||
|
||||
class TestHumanInputFormRepositoryImplHelpers:
|
||||
def test_create_email_recipients_with_member_and_external(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
repo = _build_repository()
|
||||
@ -125,3 +153,201 @@ class TestHumanInputFormRepositoryImplHelpers:
|
||||
assert len(recipients) == 2
|
||||
emails = {EmailMemberRecipientPayload.model_validate_json(r.recipient_payload).email for r in recipients}
|
||||
assert emails == {"member1@example.com", "member2@example.com"}
|
||||
|
||||
|
||||
def _make_form_definition() -> str:
|
||||
return FormDefinition(
|
||||
form_content="hello",
|
||||
inputs=[],
|
||||
user_actions=[UserAction(id="submit", title="Submit")],
|
||||
rendered_content="<p>hello</p>",
|
||||
timeout=1,
|
||||
timeout_unit=TimeoutUnit.HOUR,
|
||||
).model_dump_json()
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _DummyForm:
|
||||
id: str
|
||||
workflow_run_id: str
|
||||
node_id: str
|
||||
tenant_id: str
|
||||
form_definition: str
|
||||
rendered_content: str
|
||||
expiration_time: datetime
|
||||
selected_action_id: str | None = None
|
||||
submitted_data: str | None = None
|
||||
submitted_at: datetime | None = None
|
||||
submission_user_id: str | None = None
|
||||
submission_end_user_id: str | None = None
|
||||
completed_by_recipient_id: str | None = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _DummyRecipient:
|
||||
id: str
|
||||
form_id: str
|
||||
recipient_type: RecipientType
|
||||
access_token: str
|
||||
form: _DummyForm | None = None
|
||||
|
||||
|
||||
class _FakeScalarResult:
|
||||
def __init__(self, obj):
|
||||
self._obj = obj
|
||||
|
||||
def first(self):
|
||||
return self._obj
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
scalars_result=None,
|
||||
forms: dict[str, _DummyForm] | None = None,
|
||||
recipients: dict[str, _DummyRecipient] | None = None,
|
||||
):
|
||||
self._scalars_result = scalars_result
|
||||
self.forms = forms or {}
|
||||
self.recipients = recipients or {}
|
||||
|
||||
def scalars(self, _query):
|
||||
return _FakeScalarResult(self._scalars_result)
|
||||
|
||||
def get(self, model_cls, obj_id): # type: ignore[no-untyped-def]
|
||||
if getattr(model_cls, "__name__", None) == "HumanInputForm":
|
||||
return self.forms.get(obj_id)
|
||||
if getattr(model_cls, "__name__", None) == "HumanInputFormRecipient":
|
||||
return self.recipients.get(obj_id)
|
||||
return None
|
||||
|
||||
def add(self, _obj):
|
||||
return None
|
||||
|
||||
def flush(self):
|
||||
return None
|
||||
|
||||
def refresh(self, _obj):
|
||||
return None
|
||||
|
||||
def begin(self):
|
||||
return self
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return None
|
||||
|
||||
|
||||
def _session_factory(session: _FakeSession):
|
||||
class _SessionContext:
|
||||
def __enter__(self):
|
||||
return session
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return None
|
||||
|
||||
def _factory(*_args, **_kwargs):
|
||||
return _SessionContext()
|
||||
|
||||
return _factory
|
||||
|
||||
|
||||
class TestHumanInputFormReadRepository:
|
||||
def test_get_by_token_returns_record(self):
|
||||
form = _DummyForm(
|
||||
id="form-1",
|
||||
workflow_run_id="run-1",
|
||||
node_id="node-1",
|
||||
tenant_id="tenant-1",
|
||||
form_definition=_make_form_definition(),
|
||||
rendered_content="<p>hello</p>",
|
||||
expiration_time=naive_utc_now(),
|
||||
)
|
||||
recipient = _DummyRecipient(
|
||||
id="recipient-1",
|
||||
form_id=form.id,
|
||||
recipient_type=RecipientType.WEBAPP,
|
||||
access_token="token-123",
|
||||
form=form,
|
||||
)
|
||||
session = _FakeSession(scalars_result=recipient)
|
||||
repo = HumanInputFormReadRepository(_session_factory(session))
|
||||
|
||||
record = repo.get_by_token("token-123")
|
||||
|
||||
assert record is not None
|
||||
assert record.form_id == form.id
|
||||
assert record.recipient_type == RecipientType.WEBAPP
|
||||
assert record.submitted is False
|
||||
|
||||
def test_get_by_form_id_and_recipient_type_uses_recipient(self):
|
||||
form = _DummyForm(
|
||||
id="form-1",
|
||||
workflow_run_id="run-1",
|
||||
node_id="node-1",
|
||||
tenant_id="tenant-1",
|
||||
form_definition=_make_form_definition(),
|
||||
rendered_content="<p>hello</p>",
|
||||
expiration_time=naive_utc_now(),
|
||||
)
|
||||
recipient = _DummyRecipient(
|
||||
id="recipient-1",
|
||||
form_id=form.id,
|
||||
recipient_type=RecipientType.WEBAPP,
|
||||
access_token="token-123",
|
||||
form=form,
|
||||
)
|
||||
session = _FakeSession(scalars_result=recipient)
|
||||
repo = HumanInputFormReadRepository(_session_factory(session))
|
||||
|
||||
record = repo.get_by_form_id_and_recipient_type(form_id=form.id, recipient_type=RecipientType.WEBAPP)
|
||||
|
||||
assert record is not None
|
||||
assert record.recipient_id == recipient.id
|
||||
assert record.access_token == recipient.access_token
|
||||
|
||||
def test_mark_submitted_updates_fields(self, monkeypatch: pytest.MonkeyPatch):
|
||||
fixed_now = datetime(2024, 1, 1, 0, 0, 0)
|
||||
monkeypatch.setattr("core.repositories.human_input_reposotiry.naive_utc_now", lambda: fixed_now)
|
||||
|
||||
form = _DummyForm(
|
||||
id="form-1",
|
||||
workflow_run_id="run-1",
|
||||
node_id="node-1",
|
||||
tenant_id="tenant-1",
|
||||
form_definition=_make_form_definition(),
|
||||
rendered_content="<p>hello</p>",
|
||||
expiration_time=fixed_now,
|
||||
)
|
||||
recipient = _DummyRecipient(
|
||||
id="recipient-1",
|
||||
form_id="form-1",
|
||||
recipient_type=RecipientType.WEBAPP,
|
||||
access_token="token-123",
|
||||
)
|
||||
session = _FakeSession(
|
||||
forms={form.id: form},
|
||||
recipients={recipient.id: recipient},
|
||||
)
|
||||
repo = HumanInputFormReadRepository(_session_factory(session))
|
||||
|
||||
record: HumanInputFormRecord = repo.mark_submitted(
|
||||
form_id=form.id,
|
||||
recipient_id=recipient.id,
|
||||
selected_action_id="approve",
|
||||
form_data={"field": "value"},
|
||||
submission_user_id="user-1",
|
||||
submission_end_user_id="end-user-1",
|
||||
)
|
||||
|
||||
assert form.selected_action_id == "approve"
|
||||
assert form.completed_by_recipient_id == recipient.id
|
||||
assert form.submission_user_id == "user-1"
|
||||
assert form.submission_end_user_id == "end-user-1"
|
||||
assert form.submitted_at == fixed_now
|
||||
assert record.submitted is True
|
||||
assert record.selected_action_id == "approve"
|
||||
assert record.submitted_data == {"field": "value"}
|
||||
|
||||
@ -28,10 +28,11 @@ class TestPauseReasonDiscriminator:
|
||||
{
|
||||
"reason": {
|
||||
"TYPE": "human_input_required",
|
||||
"human_input_form_id": "form_id",
|
||||
"form_id": "form_id",
|
||||
"form_content": "form_content",
|
||||
},
|
||||
},
|
||||
HumanInputRequired(human_input_form_id="form_id"),
|
||||
HumanInputRequired(form_id="form_id", form_content="form_content"),
|
||||
id="HumanInputRequired",
|
||||
),
|
||||
pytest.param(
|
||||
@ -56,7 +57,7 @@ class TestPauseReasonDiscriminator:
|
||||
@pytest.mark.parametrize(
|
||||
"reason",
|
||||
[
|
||||
HumanInputRequired(human_input_form_id="form_id"),
|
||||
HumanInputRequired(form_id="form_id", form_content="form_content"),
|
||||
SchedulingPause(message="Hold on"),
|
||||
],
|
||||
ids=lambda x: type(x).__name__,
|
||||
|
||||
@ -764,203 +764,3 @@ def test_graph_run_emits_partial_success_when_node_failure_recovered():
|
||||
assert partial_event.outputs.get("answer") == "fallback response"
|
||||
|
||||
assert not any(isinstance(event, GraphRunSucceededEvent) for event in events)
|
||||
|
||||
|
||||
def test_suspend_and_resume():
|
||||
graph_config = {
|
||||
"edges": [
|
||||
{
|
||||
"data": {"isInLoop": False, "sourceType": "start", "targetType": "if-else"},
|
||||
"id": "1753041723554-source-1753041730748-target",
|
||||
"source": "1753041723554",
|
||||
"sourceHandle": "source",
|
||||
"target": "1753041730748",
|
||||
"targetHandle": "target",
|
||||
"type": "custom",
|
||||
"zIndex": 0,
|
||||
},
|
||||
{
|
||||
"data": {"isInLoop": False, "sourceType": "if-else", "targetType": "answer"},
|
||||
"id": "1753041730748-true-answer-target",
|
||||
"source": "1753041730748",
|
||||
"sourceHandle": "true",
|
||||
"target": "answer",
|
||||
"targetHandle": "target",
|
||||
"type": "custom",
|
||||
"zIndex": 0,
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"isInIteration": False,
|
||||
"isInLoop": False,
|
||||
"sourceType": "if-else",
|
||||
"targetType": "answer",
|
||||
},
|
||||
"id": "1753041730748-false-1753041952799-target",
|
||||
"source": "1753041730748",
|
||||
"sourceHandle": "false",
|
||||
"target": "1753041952799",
|
||||
"targetHandle": "target",
|
||||
"type": "custom",
|
||||
"zIndex": 0,
|
||||
},
|
||||
],
|
||||
"nodes": [
|
||||
{
|
||||
"data": {"desc": "", "selected": False, "title": "Start", "type": "start", "variables": []},
|
||||
"height": 54,
|
||||
"id": "1753041723554",
|
||||
"position": {"x": 32, "y": 282},
|
||||
"positionAbsolute": {"x": 32, "y": 282},
|
||||
"selected": False,
|
||||
"sourcePosition": "right",
|
||||
"targetPosition": "left",
|
||||
"type": "custom",
|
||||
"width": 244,
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"cases": [
|
||||
{
|
||||
"case_id": "true",
|
||||
"conditions": [
|
||||
{
|
||||
"comparison_operator": "contains",
|
||||
"id": "5db4103a-7e62-4e71-a0a6-c45ac11c0b3d",
|
||||
"value": "a",
|
||||
"varType": "string",
|
||||
"variable_selector": ["sys", "query"],
|
||||
}
|
||||
],
|
||||
"id": "true",
|
||||
"logical_operator": "and",
|
||||
}
|
||||
],
|
||||
"desc": "",
|
||||
"selected": False,
|
||||
"title": "IF/ELSE",
|
||||
"type": "if-else",
|
||||
},
|
||||
"height": 126,
|
||||
"id": "1753041730748",
|
||||
"position": {"x": 368, "y": 282},
|
||||
"positionAbsolute": {"x": 368, "y": 282},
|
||||
"selected": False,
|
||||
"sourcePosition": "right",
|
||||
"targetPosition": "left",
|
||||
"type": "custom",
|
||||
"width": 244,
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"answer": "A",
|
||||
"desc": "",
|
||||
"selected": False,
|
||||
"title": "Answer A",
|
||||
"type": "answer",
|
||||
"variables": [],
|
||||
},
|
||||
"height": 102,
|
||||
"id": "answer",
|
||||
"position": {"x": 746, "y": 282},
|
||||
"positionAbsolute": {"x": 746, "y": 282},
|
||||
"selected": False,
|
||||
"sourcePosition": "right",
|
||||
"targetPosition": "left",
|
||||
"type": "custom",
|
||||
"width": 244,
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"answer": "Else",
|
||||
"desc": "",
|
||||
"selected": False,
|
||||
"title": "Answer Else",
|
||||
"type": "answer",
|
||||
"variables": [],
|
||||
},
|
||||
"height": 102,
|
||||
"id": "1753041952799",
|
||||
"position": {"x": 746, "y": 426},
|
||||
"positionAbsolute": {"x": 746, "y": 426},
|
||||
"selected": True,
|
||||
"sourcePosition": "right",
|
||||
"targetPosition": "left",
|
||||
"type": "custom",
|
||||
"width": 244,
|
||||
},
|
||||
],
|
||||
"viewport": {"x": -420, "y": -76.5, "zoom": 1},
|
||||
}
|
||||
graph = Graph.init(graph_config)
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(
|
||||
user_id="aaa",
|
||||
files=[],
|
||||
query="hello",
|
||||
conversation_id="abababa",
|
||||
),
|
||||
user_inputs={"uid": "takato"},
|
||||
)
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="111",
|
||||
app_id="222",
|
||||
workflow_type=WorkflowType.CHAT,
|
||||
workflow_id="333",
|
||||
graph_config=graph_config,
|
||||
user_id="444",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
call_depth=0,
|
||||
max_execution_steps=500,
|
||||
max_execution_time=1200,
|
||||
)
|
||||
|
||||
_IF_ELSE_NODE_ID = "1753041730748"
|
||||
|
||||
def command_source(params: CommandParams) -> CommandTypes:
|
||||
# requires the engine to suspend before the execution
|
||||
# of If-Else node.
|
||||
if params.next_node.node_id == _IF_ELSE_NODE_ID:
|
||||
return SuspendCommand()
|
||||
else:
|
||||
return ContinueCommand()
|
||||
|
||||
graph_engine = GraphEngine(
|
||||
graph=graph,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
graph_init_params=graph_init_params,
|
||||
command_source=command_source,
|
||||
)
|
||||
events = list(graph_engine.run())
|
||||
last_event = events[-1]
|
||||
assert isinstance(last_event, GraphRunSuspendedEvent)
|
||||
assert last_event.current_node_id == _IF_ELSE_NODE_ID
|
||||
state = graph_engine.save()
|
||||
assert state != ""
|
||||
|
||||
engine2 = GraphEngine.resume(
|
||||
state=state,
|
||||
graph=graph,
|
||||
)
|
||||
events = list(engine2.run())
|
||||
assert isinstance(events[-1], GraphRunSucceededEvent)
|
||||
node_run_succeeded_events = [i for i in events if isinstance(i, NodeRunSucceededEvent)]
|
||||
assert node_run_succeeded_events
|
||||
start_events = [i for i in node_run_succeeded_events if i.node_id == "1753041723554"]
|
||||
assert not start_events
|
||||
ifelse_succeeded_events = [i for i in node_run_succeeded_events if i.node_id == _IF_ELSE_NODE_ID]
|
||||
assert ifelse_succeeded_events
|
||||
answer_else_events = [i for i in node_run_succeeded_events if i.node_id == "1753041952799"]
|
||||
assert answer_else_events
|
||||
assert answer_else_events[0].route_node_state.node_run_result.outputs == {
|
||||
"answer": "Else",
|
||||
"files": ArrayFileSegment(value=[]),
|
||||
}
|
||||
|
||||
answer_a_events = [i for i in node_run_succeeded_events if i.node_id == "answer"]
|
||||
assert not answer_a_events
|
||||
|
||||
@ -17,8 +17,8 @@ from core.workflow.graph_events import (
|
||||
from core.workflow.nodes.base.entities import OutputVariableEntity, OutputVariableType
|
||||
from core.workflow.nodes.end.end_node import EndNode
|
||||
from core.workflow.nodes.end.entities import EndNodeData
|
||||
from core.workflow.nodes.human_input import HumanInputNode
|
||||
from core.workflow.nodes.human_input.entities import HumanInputNodeData
|
||||
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
|
||||
from core.workflow.nodes.llm.entities import (
|
||||
ContextConfig,
|
||||
LLMNodeChatModelMessage,
|
||||
|
||||
@ -16,8 +16,8 @@ from core.workflow.graph_events import (
|
||||
from core.workflow.nodes.base.entities import OutputVariableEntity, OutputVariableType
|
||||
from core.workflow.nodes.end.end_node import EndNode
|
||||
from core.workflow.nodes.end.entities import EndNodeData
|
||||
from core.workflow.nodes.human_input import HumanInputNode
|
||||
from core.workflow.nodes.human_input.entities import HumanInputNodeData
|
||||
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
|
||||
from core.workflow.nodes.llm.entities import (
|
||||
ContextConfig,
|
||||
LLMNodeChatModelMessage,
|
||||
|
||||
@ -0,0 +1,185 @@
|
||||
import time
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any
|
||||
|
||||
import core.workflow.nodes.human_input.entities # noqa: F401
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.entities.pause_reason import SchedulingPause
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel
|
||||
from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||
from core.workflow.graph_events import (
|
||||
GraphEngineEvent,
|
||||
GraphRunPausedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.node_events import NodeEventBase, NodeRunResult, PauseRequestedEvent
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig, VariableSelector
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.end.end_node import EndNode
|
||||
from core.workflow.nodes.end.entities import EndNodeData
|
||||
from core.workflow.nodes.start.entities import StartNodeData
|
||||
from core.workflow.nodes.start.start_node import StartNode
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
|
||||
|
||||
class _PausingNodeData(BaseNodeData):
|
||||
pass
|
||||
|
||||
|
||||
class _PausingNode(Node):
|
||||
node_type = NodeType.TOOL
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
self._node_data = _PausingNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self):
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
return self._node_data.retry_config
|
||||
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
return self._node_data.default_value_dict
|
||||
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
return self._node_data
|
||||
|
||||
@staticmethod
|
||||
def _pause_generator(event: PauseRequestedEvent) -> Generator[NodeEventBase, None, None]:
|
||||
yield event
|
||||
|
||||
def _run(self) -> NodeRunResult | Generator[NodeEventBase, None, None]:
|
||||
resumed_flag = self.graph_runtime_state.variable_pool.get((self.id, "resumed"))
|
||||
if resumed_flag is None:
|
||||
# mark as resumed and request pause
|
||||
self.graph_runtime_state.variable_pool.add((self.id, "resumed"), True)
|
||||
return self._pause_generator(PauseRequestedEvent(reason=SchedulingPause(message="test pause")))
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={"value": "completed"},
|
||||
)
|
||||
|
||||
|
||||
def _build_runtime_state() -> GraphRuntimeState:
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"),
|
||||
user_inputs={},
|
||||
conversation_variables=[],
|
||||
)
|
||||
return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
|
||||
def _build_pausing_graph(runtime_state: GraphRuntimeState) -> Graph:
|
||||
graph_config: dict[str, object] = {"nodes": [], "edges": []}
|
||||
params = GraphInitParams(
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
graph_config=graph_config,
|
||||
user_id="user",
|
||||
user_from="account",
|
||||
invoke_from="service-api",
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
start_data = StartNodeData(title="start", variables=[])
|
||||
start_node = StartNode(
|
||||
id="start",
|
||||
config={"id": "start", "data": start_data.model_dump()},
|
||||
graph_init_params=params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
start_node.init_node_data(start_data.model_dump())
|
||||
|
||||
pause_data = _PausingNodeData(title="pausing")
|
||||
pause_node = _PausingNode(
|
||||
id="pausing",
|
||||
config={"id": "pausing", "data": pause_data.model_dump()},
|
||||
graph_init_params=params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
pause_node.init_node_data(pause_data.model_dump())
|
||||
|
||||
end_data = EndNodeData(
|
||||
title="end",
|
||||
outputs=[
|
||||
VariableSelector(variable="result", value_selector=["pausing", "value"]),
|
||||
],
|
||||
desc=None,
|
||||
)
|
||||
end_node = EndNode(
|
||||
id="end",
|
||||
config={"id": "end", "data": end_data.model_dump()},
|
||||
graph_init_params=params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
end_node.init_node_data(end_data.model_dump())
|
||||
|
||||
return Graph.new().add_root(start_node).add_node(pause_node).add_node(end_node).build()
|
||||
|
||||
|
||||
def _run_graph(graph: Graph, runtime_state: GraphRuntimeState) -> list[GraphEngineEvent]:
|
||||
engine = GraphEngine(
|
||||
workflow_id="workflow",
|
||||
graph=graph,
|
||||
graph_runtime_state=runtime_state,
|
||||
command_channel=InMemoryChannel(),
|
||||
)
|
||||
return list(engine.run())
|
||||
|
||||
|
||||
def _node_successes(events: list[GraphEngineEvent]) -> list[str]:
|
||||
return [event.node_id for event in events if isinstance(event, NodeRunSucceededEvent)]
|
||||
|
||||
|
||||
def _segment_value(variable_pool: VariablePool, selector: tuple[str, str]) -> Any:
|
||||
segment = variable_pool.get(selector)
|
||||
assert segment is not None
|
||||
return getattr(segment, "value", segment)
|
||||
|
||||
|
||||
def test_engine_resume_restores_state_and_completion():
|
||||
# Baseline run without pausing
|
||||
baseline_state = _build_runtime_state()
|
||||
baseline_graph = _build_pausing_graph(baseline_state)
|
||||
baseline_state.variable_pool.add(("pausing", "resumed"), True)
|
||||
baseline_events = _run_graph(baseline_graph, baseline_state)
|
||||
assert isinstance(baseline_events[-1], GraphRunSucceededEvent)
|
||||
baseline_success_nodes = _node_successes(baseline_events)
|
||||
|
||||
# Run with pause
|
||||
paused_state = _build_runtime_state()
|
||||
paused_graph = _build_pausing_graph(paused_state)
|
||||
paused_events = _run_graph(paused_graph, paused_state)
|
||||
assert isinstance(paused_events[-1], GraphRunPausedEvent)
|
||||
snapshot = paused_state.dumps()
|
||||
|
||||
# Resume from snapshot
|
||||
resumed_state = GraphRuntimeState.from_snapshot(snapshot)
|
||||
resumed_graph = _build_pausing_graph(resumed_state)
|
||||
resumed_events = _run_graph(resumed_graph, resumed_state)
|
||||
assert isinstance(resumed_events[-1], GraphRunSucceededEvent)
|
||||
|
||||
combined_success_nodes = _node_successes(paused_events) + _node_successes(resumed_events)
|
||||
assert combined_success_nodes == baseline_success_nodes
|
||||
|
||||
assert baseline_state.outputs == resumed_state.outputs
|
||||
assert _segment_value(baseline_state.variable_pool, ("pausing", "resumed")) == _segment_value(
|
||||
resumed_state.variable_pool, ("pausing", "resumed")
|
||||
)
|
||||
assert _segment_value(baseline_state.variable_pool, ("pausing", "value")) == _segment_value(
|
||||
resumed_state.variable_pool, ("pausing", "value")
|
||||
)
|
||||
assert baseline_state.graph_execution.completed
|
||||
assert resumed_state.graph_execution.completed
|
||||
@ -1,283 +0,0 @@
|
||||
"""
|
||||
Unit tests for human input node implementation.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.human_input import HumanInputNode, HumanInputNodeData
|
||||
|
||||
|
||||
class TestHumanInputNode:
|
||||
"""Test HumanInputNode implementation."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_graph_init_params(self):
|
||||
"""Create mock graph initialization parameters."""
|
||||
mock_params = Mock()
|
||||
mock_params.tenant_id = "tenant-123"
|
||||
mock_params.app_id = "app-456"
|
||||
mock_params.user_id = "user-789"
|
||||
mock_params.user_from = "web"
|
||||
mock_params.invoke_from = "web_app"
|
||||
mock_params.call_depth = 0
|
||||
return mock_params
|
||||
|
||||
@pytest.fixture
|
||||
def mock_graph(self):
|
||||
"""Create mock graph."""
|
||||
return Mock()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_graph_runtime_state(self):
|
||||
"""Create mock graph runtime state."""
|
||||
return Mock()
|
||||
|
||||
@pytest.fixture
|
||||
def sample_node_config(self):
|
||||
"""Create sample node configuration."""
|
||||
return {
|
||||
"id": "human_input_123",
|
||||
"data": {
|
||||
"title": "User Confirmation",
|
||||
"desc": "Please confirm the action",
|
||||
"delivery_methods": [{"type": "webapp", "enabled": True, "config": {}}],
|
||||
"form_content": "# Confirmation\n\nPlease confirm: {{#$output.confirmation#}}",
|
||||
"inputs": [
|
||||
{
|
||||
"type": "text-input",
|
||||
"output_variable_name": "confirmation",
|
||||
"placeholder": {"type": "constant", "value": "Type 'yes' to confirm"},
|
||||
}
|
||||
],
|
||||
"user_actions": [
|
||||
{"id": "confirm", "title": "Confirm", "button_style": "primary"},
|
||||
{"id": "cancel", "title": "Cancel", "button_style": "default"},
|
||||
],
|
||||
"timeout": 24,
|
||||
"timeout_unit": "hour",
|
||||
},
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def human_input_node(self, sample_node_config, mock_graph_init_params, mock_graph, mock_graph_runtime_state):
|
||||
"""Create HumanInputNode instance."""
|
||||
node = HumanInputNode(
|
||||
id="node_123",
|
||||
config=sample_node_config,
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph=mock_graph,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
return node
|
||||
|
||||
def test_node_initialization(self, human_input_node):
|
||||
"""Test node initialization."""
|
||||
assert human_input_node.node_id == "human_input_123"
|
||||
assert human_input_node.tenant_id == "tenant-123"
|
||||
assert human_input_node.app_id == "app-456"
|
||||
assert isinstance(human_input_node.node_data, HumanInputNodeData)
|
||||
assert human_input_node.node_data.title == "User Confirmation"
|
||||
|
||||
def test_node_type_and_version(self, human_input_node):
|
||||
"""Test node type and version."""
|
||||
assert human_input_node.type_.value == "human_input"
|
||||
assert human_input_node.version() == "1"
|
||||
|
||||
def test_node_properties(self, human_input_node):
|
||||
"""Test node properties access."""
|
||||
assert human_input_node.title == "User Confirmation"
|
||||
assert human_input_node.description == "Please confirm the action"
|
||||
assert human_input_node.error_strategy is None
|
||||
assert human_input_node.retry_config.retry_enabled is False
|
||||
|
||||
@patch("uuid.uuid4")
|
||||
def test_node_run_success(self, mock_uuid, human_input_node):
|
||||
"""Test successful node execution."""
|
||||
# Setup mocks
|
||||
mock_form_id = uuid.UUID("12345678-1234-5678-9abc-123456789012")
|
||||
mock_token = uuid.UUID("87654321-4321-8765-cba9-876543210987")
|
||||
mock_uuid.side_effect = [mock_form_id, mock_token]
|
||||
|
||||
# Execute the node
|
||||
result = human_input_node._run()
|
||||
|
||||
# Verify result
|
||||
assert result.status == WorkflowNodeExecutionStatus.RUNNING
|
||||
assert result.metadata["suspended"] is True
|
||||
assert result.metadata["form_id"] == str(mock_form_id)
|
||||
assert result.metadata["web_app_form_token"] == str(mock_token).replace("-", "")
|
||||
|
||||
# Verify event data in metadata
|
||||
human_input_event = result.metadata["human_input_event"]
|
||||
assert human_input_event["form_id"] == str(mock_form_id)
|
||||
assert human_input_event["node_id"] == "human_input_123"
|
||||
assert human_input_event["form_content"] == "# Confirmation\n\nPlease confirm: {{#$output.confirmation#}}"
|
||||
assert len(human_input_event["inputs"]) == 1
|
||||
|
||||
suspended_event = result.metadata["suspended_event"]
|
||||
assert suspended_event["suspended_at_node_ids"] == ["human_input_123"]
|
||||
|
||||
def test_node_run_without_webapp_delivery(self, human_input_node):
|
||||
"""Test node execution without webapp delivery method."""
|
||||
# Modify node data to disable webapp delivery
|
||||
human_input_node.node_data.delivery_methods[0].enabled = False
|
||||
|
||||
result = human_input_node._run()
|
||||
|
||||
# Should still work, but without web app token
|
||||
assert result.status == WorkflowNodeExecutionStatus.RUNNING
|
||||
assert result.metadata["web_app_form_token"] is None
|
||||
|
||||
def test_resume_from_human_input_success(self, human_input_node):
|
||||
"""Test successful resume from human input."""
|
||||
form_submission_data = {"inputs": {"confirmation": "yes"}, "action": "confirm"}
|
||||
|
||||
result = human_input_node.resume_from_human_input(form_submission_data)
|
||||
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.outputs["confirmation"] == "yes"
|
||||
assert result.outputs["_action"] == "confirm"
|
||||
assert result.metadata["form_submitted"] is True
|
||||
assert result.metadata["submitted_action"] == "confirm"
|
||||
|
||||
def test_resume_from_human_input_partial_inputs(self, human_input_node):
|
||||
"""Test resume with partial inputs."""
|
||||
form_submission_data = {
|
||||
"inputs": {}, # Empty inputs
|
||||
"action": "cancel",
|
||||
}
|
||||
|
||||
result = human_input_node.resume_from_human_input(form_submission_data)
|
||||
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert "confirmation" not in result.outputs # Field not provided
|
||||
assert result.outputs["_action"] == "cancel"
|
||||
|
||||
def test_resume_from_human_input_missing_data(self, human_input_node):
|
||||
"""Test resume with missing submission data."""
|
||||
form_submission_data = {} # Missing required fields
|
||||
|
||||
result = human_input_node.resume_from_human_input(form_submission_data)
|
||||
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.outputs["_action"] == "" # Default empty action
|
||||
|
||||
def test_get_default_config(self):
|
||||
"""Test getting default configuration."""
|
||||
config = HumanInputNode.get_default_config()
|
||||
|
||||
assert config["type"] == "human_input"
|
||||
assert "config" in config
|
||||
config_data = config["config"]
|
||||
|
||||
assert len(config_data["delivery_methods"]) == 1
|
||||
assert config_data["delivery_methods"][0]["type"] == "webapp"
|
||||
assert config_data["delivery_methods"][0]["enabled"] is True
|
||||
|
||||
assert config_data["form_content"] == "# Human Input\n\nPlease provide your input:\n\n{{#$output.input#}}"
|
||||
assert len(config_data["inputs"]) == 1
|
||||
assert config_data["inputs"][0]["output_variable_name"] == "input"
|
||||
|
||||
assert len(config_data["user_actions"]) == 1
|
||||
assert config_data["user_actions"][0]["id"] == "submit"
|
||||
|
||||
assert config_data["timeout"] == 24
|
||||
assert config_data["timeout_unit"] == "hour"
|
||||
|
||||
def test_process_form_content(self, human_input_node):
|
||||
"""Test form content processing."""
|
||||
# This is a placeholder test since the actual variable substitution
|
||||
# logic is marked as TODO in the implementation
|
||||
processed_content = human_input_node._process_form_content()
|
||||
|
||||
# For now, should return the raw content
|
||||
expected_content = "# Confirmation\n\nPlease confirm: {{#$output.confirmation#}}"
|
||||
assert processed_content == expected_content
|
||||
|
||||
def test_extract_variable_selector_mapping(self):
|
||||
"""Test variable selector extraction."""
|
||||
graph_config = {}
|
||||
node_data = {
|
||||
"form_content": "Hello {{#node_123.output#}}",
|
||||
"inputs": [
|
||||
{
|
||||
"type": "text-input",
|
||||
"output_variable_name": "test",
|
||||
"placeholder": {"type": "variable", "selector": ["node_456", "var_name"]},
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
# This is a placeholder test since the actual extraction logic
|
||||
# is marked as TODO in the implementation
|
||||
mapping = HumanInputNode._extract_variable_selector_to_variable_mapping(
|
||||
graph_config=graph_config, node_id="test_node", node_data=node_data
|
||||
)
|
||||
|
||||
# For now, should return empty dict
|
||||
assert mapping == {}
|
||||
|
||||
|
||||
class TestHumanInputNodeValidation:
|
||||
"""Test validation scenarios for HumanInputNode."""
|
||||
|
||||
def test_node_with_invalid_config(self):
|
||||
"""Test node creation with invalid configuration."""
|
||||
invalid_config = {
|
||||
"id": "test_node",
|
||||
"data": {
|
||||
"title": "Test",
|
||||
"delivery_methods": [
|
||||
{
|
||||
"type": "invalid_type", # Invalid delivery method type
|
||||
"enabled": True,
|
||||
"config": {},
|
||||
}
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
mock_params = Mock()
|
||||
mock_params.tenant_id = "tenant-123"
|
||||
mock_params.app_id = "app-456"
|
||||
mock_params.user_id = "user-789"
|
||||
mock_params.user_from = "web"
|
||||
mock_params.invoke_from = "web_app"
|
||||
mock_params.call_depth = 0
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
HumanInputNode(
|
||||
id="node_123",
|
||||
config=invalid_config,
|
||||
graph_init_params=mock_params,
|
||||
graph=Mock(),
|
||||
graph_runtime_state=Mock(),
|
||||
)
|
||||
|
||||
def test_node_with_missing_node_id(self):
|
||||
"""Test node creation with missing node ID in config."""
|
||||
invalid_config = {
|
||||
# Missing "id" field
|
||||
"data": {"title": "Test"}
|
||||
}
|
||||
|
||||
mock_params = Mock()
|
||||
mock_params.tenant_id = "tenant-123"
|
||||
mock_params.app_id = "app-456"
|
||||
mock_params.user_id = "user-789"
|
||||
mock_params.user_from = "web"
|
||||
mock_params.invoke_from = "web_app"
|
||||
mock_params.call_depth = 0
|
||||
|
||||
with pytest.raises(ValueError, match="Node ID is required"):
|
||||
HumanInputNode(
|
||||
id="node_123",
|
||||
config=invalid_config,
|
||||
graph_init_params=mock_params,
|
||||
graph=Mock(),
|
||||
graph_runtime_state=Mock(),
|
||||
)
|
||||
@ -1 +1 @@
|
||||
# Unit tests for human input library
|
||||
# Treat this directory as a package so support modules can be imported relatively.
|
||||
|
||||
248
api/tests/unit_tests/libs/_human_input/support.py
Normal file
248
api/tests/unit_tests/libs/_human_input/support.py
Normal file
@ -0,0 +1,248 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.workflow.nodes.human_input.entities import FormInput, TimeoutUnit
|
||||
|
||||
|
||||
# Exceptions
|
||||
class HumanInputError(Exception):
|
||||
error_code: str = "unknown"
|
||||
|
||||
def __init__(self, message: str = "", error_code: str | None = None):
|
||||
super().__init__(message)
|
||||
self.message = message or self.__class__.__name__
|
||||
if error_code:
|
||||
self.error_code = error_code
|
||||
|
||||
|
||||
class FormNotFoundError(HumanInputError):
|
||||
error_code = "form_not_found"
|
||||
|
||||
|
||||
class FormExpiredError(HumanInputError):
|
||||
error_code = "human_input_form_expired"
|
||||
|
||||
|
||||
class FormAlreadySubmittedError(HumanInputError):
|
||||
error_code = "human_input_form_submitted"
|
||||
|
||||
|
||||
class InvalidFormDataError(HumanInputError):
|
||||
error_code = "invalid_form_data"
|
||||
|
||||
|
||||
# Models
|
||||
@dataclass
|
||||
class HumanInputForm:
|
||||
form_id: str
|
||||
workflow_run_id: str
|
||||
node_id: str
|
||||
tenant_id: str
|
||||
app_id: str | None
|
||||
form_content: str
|
||||
inputs: list[FormInput]
|
||||
user_actions: list[dict[str, Any]]
|
||||
timeout: int
|
||||
timeout_unit: TimeoutUnit
|
||||
web_app_form_token: str | None = None
|
||||
created_at: datetime = field(default_factory=datetime.utcnow)
|
||||
expires_at: datetime | None = None
|
||||
submitted_at: datetime | None = None
|
||||
submitted_data: dict[str, Any] | None = None
|
||||
submitted_action: str | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.expires_at is None:
|
||||
self.calculate_expiration()
|
||||
|
||||
@property
|
||||
def is_expired(self) -> bool:
|
||||
return self.expires_at is not None and datetime.utcnow() > self.expires_at
|
||||
|
||||
@property
|
||||
def is_submitted(self) -> bool:
|
||||
return self.submitted_at is not None
|
||||
|
||||
def mark_submitted(self, inputs: dict[str, Any], action: str) -> None:
|
||||
self.submitted_data = inputs
|
||||
self.submitted_action = action
|
||||
self.submitted_at = datetime.utcnow()
|
||||
|
||||
def submit(self, inputs: dict[str, Any], action: str) -> None:
|
||||
self.mark_submitted(inputs, action)
|
||||
|
||||
def calculate_expiration(self) -> None:
|
||||
start = self.created_at
|
||||
if self.timeout_unit == TimeoutUnit.HOUR:
|
||||
self.expires_at = start + timedelta(hours=self.timeout)
|
||||
elif self.timeout_unit == TimeoutUnit.DAY:
|
||||
self.expires_at = start + timedelta(days=self.timeout)
|
||||
else:
|
||||
raise ValueError(f"Unsupported timeout unit {self.timeout_unit}")
|
||||
|
||||
def to_response_dict(self, *, include_site_info: bool) -> dict[str, Any]:
|
||||
inputs_response = [
|
||||
{
|
||||
"type": form_input.type.name.lower().replace("_", "-"),
|
||||
"output_variable_name": form_input.output_variable_name,
|
||||
}
|
||||
for form_input in self.inputs
|
||||
]
|
||||
response = {
|
||||
"form_content": self.form_content,
|
||||
"inputs": inputs_response,
|
||||
"user_actions": self.user_actions,
|
||||
}
|
||||
if include_site_info:
|
||||
response["site"] = {"app_id": self.app_id, "title": "Workflow Form"}
|
||||
return response
|
||||
|
||||
|
||||
@dataclass
|
||||
class FormSubmissionData:
|
||||
form_id: str
|
||||
inputs: dict[str, Any]
|
||||
action: str
|
||||
submitted_at: datetime = field(default_factory=datetime.utcnow)
|
||||
|
||||
@classmethod
|
||||
def from_request(cls, form_id: str, request: FormSubmissionRequest) -> FormSubmissionData: # type: ignore
|
||||
return cls(form_id=form_id, inputs=request.inputs, action=request.action)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FormSubmissionRequest:
|
||||
inputs: dict[str, Any]
|
||||
action: str
|
||||
|
||||
|
||||
# Repository
|
||||
class InMemoryFormRepository:
|
||||
"""
|
||||
Simple in-memory repository used by unit tests.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._forms: dict[str, HumanInputForm] = {}
|
||||
|
||||
@property
|
||||
def forms(self) -> dict[str, HumanInputForm]:
|
||||
return self._forms
|
||||
|
||||
def save(self, form: HumanInputForm) -> None:
|
||||
self._forms[form.form_id] = form
|
||||
|
||||
def get_by_id(self, form_id: str) -> Optional[HumanInputForm]:
|
||||
return self._forms.get(form_id)
|
||||
|
||||
def get_by_token(self, token: str) -> Optional[HumanInputForm]:
|
||||
for form in self._forms.values():
|
||||
if form.web_app_form_token == token:
|
||||
return form
|
||||
return None
|
||||
|
||||
def delete(self, form_id: str) -> None:
|
||||
self._forms.pop(form_id, None)
|
||||
|
||||
|
||||
# Service
|
||||
class FormService:
|
||||
"""Service layer for managing human input forms in tests."""
|
||||
|
||||
def __init__(self, repository: InMemoryFormRepository):
|
||||
self.repository = repository
|
||||
|
||||
def create_form(
|
||||
self,
|
||||
*,
|
||||
form_id: str,
|
||||
workflow_run_id: str,
|
||||
node_id: str,
|
||||
tenant_id: str,
|
||||
app_id: str | None,
|
||||
form_content: str,
|
||||
inputs,
|
||||
user_actions,
|
||||
timeout: int,
|
||||
timeout_unit: TimeoutUnit,
|
||||
web_app_form_token: str | None = None,
|
||||
) -> HumanInputForm:
|
||||
form = HumanInputForm(
|
||||
form_id=form_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
node_id=node_id,
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
form_content=form_content,
|
||||
inputs=list(inputs),
|
||||
user_actions=[{"id": action.id, "title": action.title} for action in user_actions],
|
||||
timeout=timeout,
|
||||
timeout_unit=timeout_unit,
|
||||
web_app_form_token=web_app_form_token,
|
||||
)
|
||||
form.calculate_expiration()
|
||||
self.repository.save(form)
|
||||
return form
|
||||
|
||||
def get_form_by_id(self, form_id: str) -> HumanInputForm:
|
||||
form = self.repository.get_by_id(form_id)
|
||||
if form is None:
|
||||
raise FormNotFoundError()
|
||||
return form
|
||||
|
||||
def get_form_by_token(self, token: str) -> HumanInputForm:
|
||||
form = self.repository.get_by_token(token)
|
||||
if form is None:
|
||||
raise FormNotFoundError()
|
||||
return form
|
||||
|
||||
def get_form_definition(self, form_id: str, *, is_token: bool) -> dict:
|
||||
form = self.get_form_by_token(form_id) if is_token else self.get_form_by_id(form_id)
|
||||
if form.is_expired:
|
||||
raise FormExpiredError()
|
||||
if form.is_submitted:
|
||||
raise FormAlreadySubmittedError()
|
||||
|
||||
definition = {
|
||||
"form_content": form.form_content,
|
||||
"inputs": form.inputs,
|
||||
"user_actions": form.user_actions,
|
||||
}
|
||||
if is_token:
|
||||
definition["site"] = {"title": "Workflow Form"}
|
||||
return definition
|
||||
|
||||
def submit_form(self, form_id: str, submission_data: FormSubmissionData, *, is_token: bool) -> None:
|
||||
form = self.get_form_by_token(form_id) if is_token else self.get_form_by_id(form_id)
|
||||
if form.is_expired:
|
||||
raise FormExpiredError()
|
||||
if form.is_submitted:
|
||||
raise FormAlreadySubmittedError()
|
||||
|
||||
self._validate_submission(form=form, submission_data=submission_data)
|
||||
form.mark_submitted(inputs=submission_data.inputs, action=submission_data.action)
|
||||
self.repository.save(form)
|
||||
|
||||
def cleanup_expired_forms(self) -> int:
|
||||
expired_ids = [form_id for form_id, form in list(self.repository.forms.items()) if form.is_expired]
|
||||
for form_id in expired_ids:
|
||||
self.repository.delete(form_id)
|
||||
return len(expired_ids)
|
||||
|
||||
def _validate_submission(self, form: HumanInputForm, submission_data: FormSubmissionData) -> None:
|
||||
defined_actions = {action["id"] for action in form.user_actions}
|
||||
if submission_data.action not in defined_actions:
|
||||
raise InvalidFormDataError(f"Invalid action: {submission_data.action}")
|
||||
|
||||
missing_inputs = []
|
||||
for form_input in form.inputs:
|
||||
if form_input.output_variable_name not in submission_data.inputs:
|
||||
missing_inputs.append(form_input.output_variable_name)
|
||||
|
||||
if missing_inputs:
|
||||
raise InvalidFormDataError(f"Missing required inputs: {', '.join(missing_inputs)}")
|
||||
|
||||
# Extra inputs are allowed; no further validation required.
|
||||
@ -5,15 +5,6 @@ Unit tests for FormService.
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import pytest
|
||||
from libs._human_input.exceptions import (
|
||||
FormAlreadySubmittedError,
|
||||
FormExpiredError,
|
||||
FormNotFoundError,
|
||||
InvalidFormDataError,
|
||||
)
|
||||
from libs._human_input.form_service import FormService
|
||||
from libs._human_input.models import FormSubmissionData
|
||||
from libs._human_input.repository import InMemoryFormRepository
|
||||
|
||||
from core.workflow.nodes.human_input.entities import (
|
||||
FormInput,
|
||||
@ -22,6 +13,16 @@ from core.workflow.nodes.human_input.entities import (
|
||||
UserAction,
|
||||
)
|
||||
|
||||
from .support import (
|
||||
FormAlreadySubmittedError,
|
||||
FormExpiredError,
|
||||
FormNotFoundError,
|
||||
FormService,
|
||||
FormSubmissionData,
|
||||
InMemoryFormRepository,
|
||||
InvalidFormDataError,
|
||||
)
|
||||
|
||||
|
||||
class TestFormService:
|
||||
"""Test FormService functionality."""
|
||||
|
||||
@ -5,16 +5,16 @@ Unit tests for human input form models.
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import pytest
|
||||
from libs._human_input.models import FormSubmissionData, HumanInputForm
|
||||
|
||||
from core.workflow.nodes.human_input.entities import (
|
||||
FormInput,
|
||||
FormInputType,
|
||||
FormSubmissionRequest,
|
||||
TimeoutUnit,
|
||||
UserAction,
|
||||
)
|
||||
|
||||
from .support import FormSubmissionData, FormSubmissionRequest, HumanInputForm
|
||||
|
||||
|
||||
class TestHumanInputForm:
|
||||
"""Test HumanInputForm model."""
|
||||
|
||||
205
api/tests/unit_tests/services/test_human_input_service.py
Normal file
205
api/tests/unit_tests/services/test_human_input_service.py
Normal file
@ -0,0 +1,205 @@
|
||||
import dataclasses
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.repositories.human_input_reposotiry import HumanInputFormReadRepository, HumanInputFormRecord
|
||||
from core.workflow.nodes.human_input.entities import FormDefinition, TimeoutUnit, UserAction
|
||||
from models.account import Account
|
||||
from models.human_input import RecipientType
|
||||
from services.human_input_service import FormSubmittedError, HumanInputService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session_factory():
|
||||
session = MagicMock()
|
||||
session_cm = MagicMock()
|
||||
session_cm.__enter__.return_value = session
|
||||
session_cm.__exit__.return_value = None
|
||||
|
||||
factory = MagicMock()
|
||||
factory.return_value = session_cm
|
||||
return factory, session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_form_record():
|
||||
return HumanInputFormRecord(
|
||||
form_id="form-id",
|
||||
workflow_run_id="workflow-run-id",
|
||||
node_id="node-id",
|
||||
tenant_id="tenant-id",
|
||||
definition=FormDefinition(
|
||||
form_content="hello",
|
||||
inputs=[],
|
||||
user_actions=[UserAction(id="submit", title="Submit")],
|
||||
rendered_content="<p>hello</p>",
|
||||
timeout=1,
|
||||
timeout_unit=TimeoutUnit.HOUR,
|
||||
),
|
||||
rendered_content="<p>hello</p>",
|
||||
expiration_time=datetime(2024, 1, 1),
|
||||
selected_action_id=None,
|
||||
submitted_data=None,
|
||||
submitted_at=None,
|
||||
submission_user_id=None,
|
||||
submission_end_user_id=None,
|
||||
completed_by_recipient_id=None,
|
||||
recipient_id="recipient-id",
|
||||
recipient_type=RecipientType.WEBAPP,
|
||||
access_token="token",
|
||||
)
|
||||
|
||||
|
||||
def test_enqueue_resume_dispatches_task(mocker, mock_session_factory):
|
||||
session_factory, session = mock_session_factory
|
||||
service = HumanInputService(session_factory)
|
||||
|
||||
trigger_log = MagicMock()
|
||||
trigger_log.id = "trigger-log-id"
|
||||
trigger_log.queue_name = "workflow_queue"
|
||||
|
||||
repo_cls = mocker.patch(
|
||||
"services.human_input_service.SQLAlchemyWorkflowTriggerLogRepository",
|
||||
autospec=True,
|
||||
)
|
||||
repo = repo_cls.return_value
|
||||
repo.get_by_workflow_run_id.return_value = trigger_log
|
||||
|
||||
resume_task = mocker.patch("services.human_input_service.resume_workflow_execution")
|
||||
|
||||
service._enqueue_resume("workflow-run-id")
|
||||
|
||||
repo_cls.assert_called_once_with(session)
|
||||
resume_task.apply_async.assert_called_once()
|
||||
call_kwargs = resume_task.apply_async.call_args.kwargs
|
||||
assert call_kwargs["queue"] == "workflow_queue"
|
||||
payload = call_kwargs["kwargs"]["task_data_dict"]
|
||||
assert payload["workflow_trigger_log_id"] == "trigger-log-id"
|
||||
assert payload["workflow_run_id"] == "workflow-run-id"
|
||||
|
||||
|
||||
def test_enqueue_resume_no_trigger_log(mocker, mock_session_factory):
|
||||
session_factory, session = mock_session_factory
|
||||
service = HumanInputService(session_factory)
|
||||
|
||||
repo_cls = mocker.patch(
|
||||
"services.human_input_service.SQLAlchemyWorkflowTriggerLogRepository",
|
||||
autospec=True,
|
||||
)
|
||||
repo = repo_cls.return_value
|
||||
repo.get_by_workflow_run_id.return_value = None
|
||||
|
||||
resume_task = mocker.patch("services.human_input_service.resume_workflow_execution")
|
||||
|
||||
service._enqueue_resume("workflow-run-id")
|
||||
|
||||
repo_cls.assert_called_once_with(session)
|
||||
resume_task.apply_async.assert_not_called()
|
||||
|
||||
|
||||
def test_enqueue_resume_chatflow_fallback(mocker, mock_session_factory):
|
||||
session_factory, session = mock_session_factory
|
||||
service = HumanInputService(session_factory)
|
||||
|
||||
repo_cls = mocker.patch(
|
||||
"services.human_input_service.SQLAlchemyWorkflowTriggerLogRepository",
|
||||
autospec=True,
|
||||
)
|
||||
repo = repo_cls.return_value
|
||||
repo.get_by_workflow_run_id.return_value = None
|
||||
|
||||
workflow_run = MagicMock()
|
||||
workflow_run.app_id = "app-id"
|
||||
app = MagicMock()
|
||||
app.mode = "advanced-chat"
|
||||
|
||||
session.get.side_effect = [workflow_run, app]
|
||||
|
||||
resume_task = mocker.patch("services.human_input_service.resume_chatflow_execution")
|
||||
|
||||
service._enqueue_resume("workflow-run-id")
|
||||
|
||||
resume_task.apply_async.assert_called_once()
|
||||
call_kwargs = resume_task.apply_async.call_args.kwargs
|
||||
assert call_kwargs["queue"] == "chatflow_execute"
|
||||
assert call_kwargs["kwargs"]["payload"]["workflow_run_id"] == "workflow-run-id"
|
||||
|
||||
|
||||
def test_get_form_definition_by_id_uses_repository(sample_form_record, mock_session_factory):
|
||||
session_factory, _ = mock_session_factory
|
||||
repo = MagicMock(spec=HumanInputFormReadRepository)
|
||||
repo.get_by_form_id_and_recipient_type.return_value = sample_form_record
|
||||
|
||||
service = HumanInputService(session_factory, form_repository=repo)
|
||||
form = service.get_form_definition_by_id("form-id")
|
||||
|
||||
repo.get_by_form_id_and_recipient_type.assert_called_once_with(
|
||||
form_id="form-id",
|
||||
recipient_type=RecipientType.WEBAPP,
|
||||
)
|
||||
assert form is not None
|
||||
assert form.get_definition() == sample_form_record.definition
|
||||
|
||||
|
||||
def test_get_form_definition_by_id_raises_on_submitted(sample_form_record, mock_session_factory):
|
||||
session_factory, _ = mock_session_factory
|
||||
submitted_record = dataclasses.replace(sample_form_record, submitted_at=datetime(2024, 1, 1))
|
||||
repo = MagicMock(spec=HumanInputFormReadRepository)
|
||||
repo.get_by_form_id_and_recipient_type.return_value = submitted_record
|
||||
|
||||
service = HumanInputService(session_factory, form_repository=repo)
|
||||
|
||||
with pytest.raises(FormSubmittedError):
|
||||
service.get_form_definition_by_id("form-id")
|
||||
|
||||
|
||||
def test_submit_form_by_token_calls_repository_and_enqueue(sample_form_record, mock_session_factory, mocker):
|
||||
session_factory, _ = mock_session_factory
|
||||
repo = MagicMock(spec=HumanInputFormReadRepository)
|
||||
repo.get_by_token.return_value = sample_form_record
|
||||
repo.mark_submitted.return_value = sample_form_record
|
||||
service = HumanInputService(session_factory, form_repository=repo)
|
||||
enqueue_spy = mocker.patch.object(service, "_enqueue_resume")
|
||||
|
||||
service.submit_form_by_token(
|
||||
recipient_type=RecipientType.WEBAPP,
|
||||
form_token="token",
|
||||
selected_action_id="approve",
|
||||
form_data={"field": "value"},
|
||||
submission_end_user_id="end-user-id",
|
||||
)
|
||||
|
||||
repo.get_by_token.assert_called_once_with("token")
|
||||
repo.mark_submitted.assert_called_once()
|
||||
call_kwargs = repo.mark_submitted.call_args.kwargs
|
||||
assert call_kwargs["form_id"] == sample_form_record.form_id
|
||||
assert call_kwargs["recipient_id"] == sample_form_record.recipient_id
|
||||
assert call_kwargs["selected_action_id"] == "approve"
|
||||
assert call_kwargs["form_data"] == {"field": "value"}
|
||||
assert call_kwargs["submission_end_user_id"] == "end-user-id"
|
||||
enqueue_spy.assert_called_once_with(sample_form_record.workflow_run_id)
|
||||
|
||||
|
||||
def test_submit_form_by_id_passes_account(sample_form_record, mock_session_factory, mocker):
|
||||
session_factory, _ = mock_session_factory
|
||||
repo = MagicMock(spec=HumanInputFormReadRepository)
|
||||
repo.get_by_form_id_and_recipient_type.return_value = sample_form_record
|
||||
repo.mark_submitted.return_value = sample_form_record
|
||||
service = HumanInputService(session_factory, form_repository=repo)
|
||||
enqueue_spy = mocker.patch.object(service, "_enqueue_resume")
|
||||
account = MagicMock(spec=Account)
|
||||
account.id = "account-id"
|
||||
|
||||
service.submit_form_by_id(
|
||||
form_id="form-id",
|
||||
selected_action_id="approve",
|
||||
form_data={"x": 1},
|
||||
user=account,
|
||||
)
|
||||
|
||||
repo.get_by_form_id_and_recipient_type.assert_called_once()
|
||||
repo.mark_submitted.assert_called_once()
|
||||
assert repo.mark_submitted.call_args.kwargs["submission_user_id"] == "account-id"
|
||||
enqueue_spy.assert_called_once_with(sample_form_record.workflow_run_id)
|
||||
@ -35,7 +35,6 @@ class TestDataFactory:
|
||||
app_id: str = "app-789",
|
||||
workflow_id: str = "workflow-101",
|
||||
status: str | WorkflowExecutionStatus = "paused",
|
||||
pause_id: str | None = None,
|
||||
**kwargs,
|
||||
) -> MagicMock:
|
||||
"""Create a mock WorkflowRun object."""
|
||||
@ -45,7 +44,6 @@ class TestDataFactory:
|
||||
mock_run.app_id = app_id
|
||||
mock_run.workflow_id = workflow_id
|
||||
mock_run.status = status
|
||||
mock_run.pause_id = pause_id
|
||||
|
||||
for key, value in kwargs.items():
|
||||
setattr(mock_run, key, value)
|
||||
|
||||
@ -161,292 +161,3 @@ class TestWorkflowService:
|
||||
assert workflows == []
|
||||
assert has_more is False
|
||||
mock_session.scalars.assert_called_once()
|
||||
|
||||
|
||||
class TestWorkflowServiceHumanInputValidation:
|
||||
@pytest.fixture
|
||||
def workflow_service(self):
|
||||
# Mock sessionmaker to avoid database dependency
|
||||
mock_session_maker = MagicMock()
|
||||
return WorkflowService(mock_session_maker)
|
||||
|
||||
def test_validate_graph_structure_valid_human_input(self, workflow_service):
|
||||
"""Test validation of valid HumanInput node data."""
|
||||
graph = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "node-1",
|
||||
"data": {
|
||||
"type": "human_input",
|
||||
"title": "Human Input",
|
||||
"delivery_methods": [{"type": "webapp", "enabled": True, "config": {}}],
|
||||
"form_content": "Please provide your input",
|
||||
"inputs": [
|
||||
{
|
||||
"type": "text-input",
|
||||
"output_variable_name": "user_input",
|
||||
"placeholder": {"type": "constant", "value": "Enter text here"},
|
||||
}
|
||||
],
|
||||
"user_actions": [{"id": "submit", "title": "Submit", "button_style": "primary"}],
|
||||
"timeout": 24,
|
||||
"timeout_unit": "hour",
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
# Should not raise any exception
|
||||
workflow_service.validate_graph_structure(graph)
|
||||
|
||||
def test_validate_graph_structure_empty_graph(self, workflow_service):
|
||||
"""Test validation of empty graph."""
|
||||
graph = {}
|
||||
|
||||
# Should not raise any exception
|
||||
workflow_service.validate_graph_structure(graph)
|
||||
|
||||
def test_validate_graph_structure_no_nodes(self, workflow_service):
|
||||
"""Test validation of graph with no nodes."""
|
||||
graph = {"nodes": []}
|
||||
|
||||
# Should not raise any exception
|
||||
workflow_service.validate_graph_structure(graph)
|
||||
|
||||
def test_validate_graph_structure_non_human_input_node(self, workflow_service):
|
||||
"""Test validation ignores non-HumanInput nodes."""
|
||||
graph = {"nodes": [{"id": "node-1", "data": {"type": "start", "title": "Start"}}]}
|
||||
|
||||
# Should not raise any exception
|
||||
workflow_service.validate_graph_structure(graph)
|
||||
|
||||
def test_validate_human_input_node_data_invalid_delivery_method_type(self, workflow_service):
|
||||
"""Test validation fails with invalid delivery method type."""
|
||||
graph = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "node-1",
|
||||
"data": {
|
||||
"type": "human_input",
|
||||
"title": "Human Input",
|
||||
"delivery_methods": [{"type": "invalid_type", "enabled": True, "config": {}}],
|
||||
"form_content": "Please provide your input",
|
||||
"inputs": [],
|
||||
"user_actions": [],
|
||||
"timeout": 24,
|
||||
"timeout_unit": "hour",
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid HumanInput node data"):
|
||||
workflow_service.validate_graph_structure(graph)
|
||||
|
||||
def test_validate_human_input_node_data_invalid_form_input_type(self, workflow_service):
|
||||
"""Test validation fails with invalid form input type."""
|
||||
graph = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "node-1",
|
||||
"data": {
|
||||
"type": "human_input",
|
||||
"title": "Human Input",
|
||||
"delivery_methods": [{"type": "webapp", "enabled": True, "config": {}}],
|
||||
"form_content": "Please provide your input",
|
||||
"inputs": [{"type": "invalid-input-type", "output_variable_name": "user_input"}],
|
||||
"user_actions": [],
|
||||
"timeout": 24,
|
||||
"timeout_unit": "hour",
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid HumanInput node data"):
|
||||
workflow_service.validate_graph_structure(graph)
|
||||
|
||||
def test_validate_human_input_node_data_missing_required_fields(self, workflow_service):
|
||||
"""Test validation fails with missing required fields."""
|
||||
graph = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "node-1",
|
||||
"data": {
|
||||
"type": "human_input",
|
||||
# Missing required fields like title
|
||||
"delivery_methods": [],
|
||||
"form_content": "",
|
||||
"inputs": [],
|
||||
"user_actions": [],
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid HumanInput node data"):
|
||||
workflow_service.validate_graph_structure(graph)
|
||||
|
||||
def test_validate_human_input_node_data_invalid_timeout_unit(self, workflow_service):
|
||||
"""Test validation fails with invalid timeout unit."""
|
||||
graph = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "node-1",
|
||||
"data": {
|
||||
"type": "human_input",
|
||||
"title": "Human Input",
|
||||
"delivery_methods": [{"type": "webapp", "enabled": True, "config": {}}],
|
||||
"form_content": "Please provide your input",
|
||||
"inputs": [],
|
||||
"user_actions": [],
|
||||
"timeout": 24,
|
||||
"timeout_unit": "invalid_unit",
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid HumanInput node data"):
|
||||
workflow_service.validate_graph_structure(graph)
|
||||
|
||||
def test_validate_human_input_node_data_invalid_button_style(self, workflow_service):
|
||||
"""Test validation fails with invalid button style."""
|
||||
graph = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "node-1",
|
||||
"data": {
|
||||
"type": "human_input",
|
||||
"title": "Human Input",
|
||||
"delivery_methods": [{"type": "webapp", "enabled": True, "config": {}}],
|
||||
"form_content": "Please provide your input",
|
||||
"inputs": [],
|
||||
"user_actions": [{"id": "submit", "title": "Submit", "button_style": "invalid_style"}],
|
||||
"timeout": 24,
|
||||
"timeout_unit": "hour",
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid HumanInput node data"):
|
||||
workflow_service.validate_graph_structure(graph)
|
||||
|
||||
def test_validate_human_input_node_data_email_delivery_config(self, workflow_service):
|
||||
"""Test validation of HumanInput node with email delivery configuration."""
|
||||
graph = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "node-1",
|
||||
"data": {
|
||||
"type": "human_input",
|
||||
"title": "Human Input",
|
||||
"delivery_methods": [
|
||||
{
|
||||
"type": "email",
|
||||
"enabled": True,
|
||||
"config": {
|
||||
"recipients": {
|
||||
"whole_workspace": False,
|
||||
"items": [{"type": "external", "email": "user@example.com"}],
|
||||
},
|
||||
"subject": "Input Required",
|
||||
"body": "Please provide your input",
|
||||
},
|
||||
}
|
||||
],
|
||||
"form_content": "Please provide your input",
|
||||
"inputs": [
|
||||
{
|
||||
"type": "paragraph",
|
||||
"output_variable_name": "feedback",
|
||||
"placeholder": {"type": "variable", "selector": ["node", "output"]},
|
||||
}
|
||||
],
|
||||
"user_actions": [
|
||||
{"id": "approve", "title": "Approve", "button_style": "accent"},
|
||||
{"id": "reject", "title": "Reject", "button_style": "ghost"},
|
||||
],
|
||||
"timeout": 7,
|
||||
"timeout_unit": "day",
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
# Should not raise any exception
|
||||
workflow_service.validate_graph_structure(graph)
|
||||
|
||||
def test_validate_human_input_node_data_invalid_email_recipient(self, workflow_service):
|
||||
"""Test validation fails with invalid email recipient."""
|
||||
graph = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "node-1",
|
||||
"data": {
|
||||
"type": "human_input",
|
||||
"title": "Human Input",
|
||||
"delivery_methods": [
|
||||
{
|
||||
"type": "email",
|
||||
"enabled": True,
|
||||
"config": {
|
||||
"recipients": {
|
||||
"whole_workspace": False,
|
||||
"items": [{"type": "invalid_recipient_type", "email": "user@example.com"}],
|
||||
},
|
||||
"subject": "Input Required",
|
||||
"body": "Please provide your input",
|
||||
},
|
||||
}
|
||||
],
|
||||
"form_content": "Please provide your input",
|
||||
"inputs": [],
|
||||
"user_actions": [],
|
||||
"timeout": 24,
|
||||
"timeout_unit": "hour",
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid HumanInput node data"):
|
||||
workflow_service.validate_graph_structure(graph)
|
||||
|
||||
def test_validate_human_input_node_data_multiple_nodes_mixed_valid_invalid(self, workflow_service):
|
||||
"""Test validation with multiple nodes where some are valid and some invalid."""
|
||||
graph = {
|
||||
"nodes": [
|
||||
{"id": "node-1", "data": {"type": "start", "title": "Start"}},
|
||||
{
|
||||
"id": "node-2",
|
||||
"data": {
|
||||
"type": "human_input",
|
||||
"title": "Valid Human Input",
|
||||
"delivery_methods": [{"type": "webapp", "enabled": True, "config": {}}],
|
||||
"form_content": "Valid input",
|
||||
"inputs": [],
|
||||
"user_actions": [],
|
||||
"timeout": 24,
|
||||
"timeout_unit": "hour",
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": "node-3",
|
||||
"data": {
|
||||
"type": "human_input",
|
||||
"title": "Invalid Human Input",
|
||||
"delivery_methods": [{"type": "invalid_method", "enabled": True}],
|
||||
"form_content": "Invalid input",
|
||||
"inputs": [],
|
||||
"user_actions": [],
|
||||
"timeout": 24,
|
||||
"timeout_unit": "hour",
|
||||
},
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid HumanInput node data"):
|
||||
workflow_service.validate_graph_structure(graph)
|
||||
|
||||
Reference in New Issue
Block a user