diff --git a/api/controllers/console/human_input_form.py b/api/controllers/console/human_input_form.py index 0fe7148cf1..21700459bf 100644 --- a/api/controllers/console/human_input_form.py +++ b/api/controllers/console/human_input_form.py @@ -107,6 +107,7 @@ class ConsoleHumanInputFormApi(Resource): parser.add_argument("inputs", type=dict, required=True, location="json") parser.add_argument("action", type=str, required=True, location="json") args = parser.parse_args() + current_user, _ = current_account_with_tenant() # Submit the form service = HumanInputService(db.engine) @@ -114,6 +115,7 @@ class ConsoleHumanInputFormApi(Resource): form_id=form_id, selected_action_id=args["action"], form_data=args["inputs"], + user=current_user, ) return jsonify({}) diff --git a/api/controllers/web/human_input_form.py b/api/controllers/web/human_input_form.py index 1e22ae94ce..d33d348b91 100644 --- a/api/controllers/web/human_input_form.py +++ b/api/controllers/web/human_input_form.py @@ -70,6 +70,7 @@ class HumanInputFormApi(WebApiResource): form_token=web_app_form_token, selected_action_id=args["action"], form_data=args["inputs"], + submission_end_user_id=_end_user.id, ) except FormNotFoundError: raise NotFoundError("Form not found") diff --git a/api/controllers/web/workflow_events.py b/api/controllers/web/workflow_events.py index 73430a3177..0421c4a457 100644 --- a/api/controllers/web/workflow_events.py +++ b/api/controllers/web/workflow_events.py @@ -3,7 +3,6 @@ Web App Workflow Resume APIs. """ import logging -import time from collections.abc import Generator from flask import Response @@ -14,42 +13,6 @@ from controllers.web.wraps import WebApiResource logger = logging.getLogger(__name__) -class WorkflowResumeWaitApi(WebApiResource): - """API for long-polling workflow resume wait.""" - - def get(self, task_id: str): - """ - Get workflow execution resume notification. - - GET /api/workflow//resume-wait - - This is a long-polling API that waits for workflow to resume from paused state. - """ - # TODO: Implement actual workflow status checking - # For now, return a basic response - - timeout = 30 # 30 seconds timeout for demo - start_time = time.time() - - while time.time() - start_time < timeout: - # TODO: Check workflow status from database/cache - # workflow_status = workflow_service.get_status(task_id) - - # For demo purposes, simulate different states - # In real implementation, this would check the actual workflow state - workflow_status = "paused" # or "running" or "ended" - - if workflow_status == "running": - return {"status": "running"}, 200 - elif workflow_status == "ended": - return {"status": "ended"}, 200 - - time.sleep(1) # Poll every second - - # Return paused status if timeout reached - return {"status": "paused"}, 200 - - class WorkflowEventsApi(WebApiResource): """API for getting workflow execution events after resume.""" diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 9e271423bf..fd913b807d 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -2,7 +2,7 @@ import contextvars import logging import threading import uuid -from collections.abc import Generator, Mapping +from collections.abc import Generator, Mapping, Sequence from typing import Any, Literal, TypeVar, Union, overload from flask import Flask, current_app @@ -24,16 +24,19 @@ from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse +from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer from core.helper.trace_id_helper import extract_external_trace_id_from_args from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.ops.ops_trace_manager import TraceQueueManager from core.prompt.utils.get_thread_messages_length import get_thread_messages_length from core.repositories import DifyCoreRepositoryFactory +from core.workflow.graph_engine.layers.base import GraphEngineLayer from core.workflow.repositories.draft_variable_repository import ( DraftVariableSaverFactory, ) from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from core.workflow.runtime import GraphRuntimeState from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from extensions.ext_database import db from factories import file_factory @@ -63,6 +66,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): invoke_from: InvokeFrom, workflow_run_id: uuid.UUID, streaming: Literal[False], + pause_state_config: PauseStateLayerConfig | None = None, ) -> Mapping[str, Any]: ... @overload @@ -75,6 +79,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): invoke_from: InvokeFrom, workflow_run_id: uuid.UUID, streaming: Literal[True], + pause_state_config: PauseStateLayerConfig | None = None, ) -> Generator[Mapping | str, None, None]: ... @overload @@ -87,6 +92,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): invoke_from: InvokeFrom, workflow_run_id: uuid.UUID, streaming: bool, + pause_state_config: PauseStateLayerConfig | None = None, ) -> Mapping[str, Any] | Generator[str | Mapping, None, None]: ... def generate( @@ -98,6 +104,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): invoke_from: InvokeFrom, workflow_run_id: uuid.UUID, streaming: bool = True, + pause_state_config: PauseStateLayerConfig | None = None, ) -> Mapping[str, Any] | Generator[str | Mapping, None, None]: """ Generate App response. @@ -215,6 +222,38 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): workflow_node_execution_repository=workflow_node_execution_repository, conversation=conversation, stream=streaming, + pause_state_config=pause_state_config, + ) + + def resume( + self, + *, + app_model: App, + workflow: Workflow, + user: Union[Account, EndUser], + conversation: Conversation, + message: Message, + application_generate_entity: AdvancedChatAppGenerateEntity, + workflow_execution_repository: WorkflowExecutionRepository, + workflow_node_execution_repository: WorkflowNodeExecutionRepository, + graph_runtime_state: GraphRuntimeState, + pause_state_config: PauseStateLayerConfig | None = None, + ): + """ + Resume a paused advanced chat execution. + """ + return self._generate( + workflow=workflow, + user=user, + invoke_from=application_generate_entity.invoke_from, + application_generate_entity=application_generate_entity, + workflow_execution_repository=workflow_execution_repository, + workflow_node_execution_repository=workflow_node_execution_repository, + conversation=conversation, + message=message, + stream=application_generate_entity.stream, + pause_state_config=pause_state_config, + graph_runtime_state=graph_runtime_state, ) def single_iteration_generate( @@ -395,8 +434,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): workflow_execution_repository: WorkflowExecutionRepository, workflow_node_execution_repository: WorkflowNodeExecutionRepository, conversation: Conversation | None = None, + message: Message | None = None, stream: bool = True, variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER, + pause_state_config: PauseStateLayerConfig | None = None, + graph_runtime_state: GraphRuntimeState | None = None, + graph_engine_layers: Sequence[GraphEngineLayer] = (), ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]: """ Generate App response. @@ -410,12 +453,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): :param conversation: conversation :param stream: is stream """ - is_first_conversation = False - if not conversation: - is_first_conversation = True + is_first_conversation = conversation is None - # init generate records - (conversation, message) = self._init_generate_records(application_generate_entity, conversation) + if conversation is not None and message is not None: + pass + else: + conversation, message = self._init_generate_records(application_generate_entity, conversation) if is_first_conversation: # update conversation features @@ -438,6 +481,16 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): message_id=message.id, ) + graph_layers: list[GraphEngineLayer] = list(graph_engine_layers) + if pause_state_config is not None: + graph_layers.append( + PauseStatePersistenceLayer( + session_factory=pause_state_config.session_factory, + generate_entity=application_generate_entity, + state_owner_user_id=pause_state_config.state_owner_user_id, + ) + ) + # new thread with request context and contextvars context = contextvars.copy_context() @@ -453,6 +506,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): "variable_loader": variable_loader, "workflow_execution_repository": workflow_execution_repository, "workflow_node_execution_repository": workflow_node_execution_repository, + "graph_engine_layers": tuple(graph_layers), + "graph_runtime_state": graph_runtime_state, }, ) @@ -498,6 +553,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): variable_loader: VariableLoader, workflow_execution_repository: WorkflowExecutionRepository, workflow_node_execution_repository: WorkflowNodeExecutionRepository, + graph_engine_layers: Sequence[GraphEngineLayer] = (), + graph_runtime_state: GraphRuntimeState | None = None, ): """ Generate worker in a new thread. @@ -555,6 +612,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): app=app, workflow_execution_repository=workflow_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository, + graph_engine_layers=graph_engine_layers, + graph_runtime_state=graph_runtime_state, ) try: diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index ee092e55c5..f06a6f9e9b 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -64,6 +64,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): workflow_execution_repository: WorkflowExecutionRepository, workflow_node_execution_repository: WorkflowNodeExecutionRepository, graph_engine_layers: Sequence[GraphEngineLayer] = (), + graph_runtime_state: GraphRuntimeState | None = None, ): super().__init__( queue_manager=queue_manager, @@ -80,6 +81,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): self._app = app self._workflow_execution_repository = workflow_execution_repository self._workflow_node_execution_repository = workflow_node_execution_repository + self._resume_graph_runtime_state = graph_runtime_state @trace_span(WorkflowAppRunnerHandler) def run(self): @@ -103,7 +105,19 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): if not app_record: raise ValueError("App not found") - if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run: + resume_state = self._resume_graph_runtime_state + + if resume_state is not None: + graph_runtime_state = resume_state + variable_pool = graph_runtime_state.variable_pool + graph = self._init_graph( + graph_config=self._workflow.graph_dict, + graph_runtime_state=graph_runtime_state, + workflow_id=self._workflow.id, + tenant_id=self._workflow.tenant_id, + user_id=self.application_generate_entity.user_id, + ) + elif self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run: # Handle single iteration or single loop run graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution( workflow=self._workflow, diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 0165c74295..6f6fe60b5a 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -23,6 +23,7 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse +from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer from core.helper.trace_id_helper import extract_external_trace_id_from_args from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.ops.ops_trace_manager import TraceQueueManager @@ -31,12 +32,15 @@ from core.workflow.graph_engine.layers.base import GraphEngineLayer from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from core.workflow.runtime import GraphRuntimeState from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from extensions.ext_database import db from factories import file_factory from libs.flask_utils import preserve_flask_contexts -from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom +from models.account import Account from models.enums import WorkflowRunTriggeredFrom +from models.model import App, EndUser +from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService SKIP_PREPARE_USER_INPUTS_KEY = "_skip_prepare_user_inputs" @@ -63,6 +67,7 @@ class WorkflowAppGenerator(BaseAppGenerator): triggered_from: WorkflowRunTriggeredFrom | None = None, root_node_id: str | None = None, graph_engine_layers: Sequence[GraphEngineLayer] = (), + pause_state_config: PauseStateLayerConfig | None = None, ) -> Generator[Mapping[str, Any] | str, None, None]: ... @overload @@ -79,6 +84,7 @@ class WorkflowAppGenerator(BaseAppGenerator): triggered_from: WorkflowRunTriggeredFrom | None = None, root_node_id: str | None = None, graph_engine_layers: Sequence[GraphEngineLayer] = (), + pause_state_config: PauseStateLayerConfig | None = None, ) -> Mapping[str, Any]: ... @overload @@ -95,6 +101,7 @@ class WorkflowAppGenerator(BaseAppGenerator): triggered_from: WorkflowRunTriggeredFrom | None = None, root_node_id: str | None = None, graph_engine_layers: Sequence[GraphEngineLayer] = (), + pause_state_config: PauseStateLayerConfig | None = None, ) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: ... def generate( @@ -110,6 +117,7 @@ class WorkflowAppGenerator(BaseAppGenerator): triggered_from: WorkflowRunTriggeredFrom | None = None, root_node_id: str | None = None, graph_engine_layers: Sequence[GraphEngineLayer] = (), + pause_state_config: PauseStateLayerConfig | None = None, ) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: files: Sequence[Mapping[str, Any]] = args.get("files") or [] @@ -210,13 +218,40 @@ class WorkflowAppGenerator(BaseAppGenerator): streaming=streaming, root_node_id=root_node_id, graph_engine_layers=graph_engine_layers, + pause_state_config=pause_state_config, ) - def resume(self, *, workflow_run_id: str) -> None: + def resume( + self, + *, + app_model: App, + workflow: Workflow, + user: Union[Account, EndUser], + application_generate_entity: WorkflowAppGenerateEntity, + graph_runtime_state: GraphRuntimeState, + workflow_execution_repository: WorkflowExecutionRepository, + workflow_node_execution_repository: WorkflowNodeExecutionRepository, + graph_engine_layers: Sequence[GraphEngineLayer] = (), + pause_state_config: PauseStateLayerConfig | None = None, + variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER, + ) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: """ - @TBD + Resume a paused workflow execution using the persisted runtime state. """ - pass + return self._generate( + app_model=app_model, + workflow=workflow, + user=user, + application_generate_entity=application_generate_entity, + invoke_from=application_generate_entity.invoke_from, + workflow_execution_repository=workflow_execution_repository, + workflow_node_execution_repository=workflow_node_execution_repository, + streaming=application_generate_entity.stream, + variable_loader=variable_loader, + graph_engine_layers=graph_engine_layers, + graph_runtime_state=graph_runtime_state, + pause_state_config=pause_state_config, + ) def _generate( self, @@ -232,6 +267,8 @@ class WorkflowAppGenerator(BaseAppGenerator): variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER, root_node_id: str | None = None, graph_engine_layers: Sequence[GraphEngineLayer] = (), + graph_runtime_state: GraphRuntimeState | None = None, + pause_state_config: PauseStateLayerConfig | None = None, ) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: """ Generate App response. @@ -245,6 +282,8 @@ class WorkflowAppGenerator(BaseAppGenerator): :param workflow_node_execution_repository: repository for workflow node execution :param streaming: is stream """ + graph_layers: list[GraphEngineLayer] = list(graph_engine_layers) + # init queue manager queue_manager = WorkflowAppQueueManager( task_id=application_generate_entity.task_id, @@ -253,6 +292,15 @@ class WorkflowAppGenerator(BaseAppGenerator): app_mode=app_model.mode, ) + if pause_state_config is not None: + graph_layers.append( + PauseStatePersistenceLayer( + session_factory=pause_state_config.session_factory, + generate_entity=application_generate_entity, + state_owner_user_id=pause_state_config.state_owner_user_id, + ) + ) + # new thread with request context and contextvars context = contextvars.copy_context() @@ -270,7 +318,8 @@ class WorkflowAppGenerator(BaseAppGenerator): "root_node_id": root_node_id, "workflow_execution_repository": workflow_execution_repository, "workflow_node_execution_repository": workflow_node_execution_repository, - "graph_engine_layers": graph_engine_layers, + "graph_engine_layers": tuple(graph_layers), + "graph_runtime_state": graph_runtime_state, }, ) @@ -372,6 +421,7 @@ class WorkflowAppGenerator(BaseAppGenerator): workflow_node_execution_repository=workflow_node_execution_repository, streaming=streaming, variable_loader=var_loader, + pause_state_config=None, ) def single_loop_generate( @@ -453,6 +503,7 @@ class WorkflowAppGenerator(BaseAppGenerator): workflow_node_execution_repository=workflow_node_execution_repository, streaming=streaming, variable_loader=var_loader, + pause_state_config=None, ) def _generate_worker( @@ -466,6 +517,7 @@ class WorkflowAppGenerator(BaseAppGenerator): workflow_node_execution_repository: WorkflowNodeExecutionRepository, root_node_id: str | None = None, graph_engine_layers: Sequence[GraphEngineLayer] = (), + graph_runtime_state: GraphRuntimeState | None = None, ) -> None: """ Generate worker in a new thread. @@ -511,6 +563,7 @@ class WorkflowAppGenerator(BaseAppGenerator): workflow_node_execution_repository=workflow_node_execution_repository, root_node_id=root_node_id, graph_engine_layers=graph_engine_layers, + graph_runtime_state=graph_runtime_state, ) try: diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index 894e6f397a..44e63c7c4d 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -43,6 +43,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): workflow_execution_repository: WorkflowExecutionRepository, workflow_node_execution_repository: WorkflowNodeExecutionRepository, graph_engine_layers: Sequence[GraphEngineLayer] = (), + graph_runtime_state: GraphRuntimeState | None = None, ): super().__init__( queue_manager=queue_manager, @@ -56,6 +57,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): self._root_node_id = root_node_id self._workflow_execution_repository = workflow_execution_repository self._workflow_node_execution_repository = workflow_node_execution_repository + self._resume_graph_runtime_state = graph_runtime_state @trace_span(WorkflowAppRunnerHandler) def run(self): @@ -65,17 +67,12 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): app_config = self.application_generate_entity.app_config app_config = cast(WorkflowAppConfig, app_config) - system_inputs = SystemVariable( - files=self.application_generate_entity.files, - user_id=self._sys_user_id, - app_id=app_config.app_id, - timestamp=int(naive_utc_now().timestamp()), - workflow_id=app_config.workflow_id, - workflow_execution_id=self.application_generate_entity.workflow_execution_id, - ) + resume_state = self._resume_graph_runtime_state - # if only single iteration or single loop run is requested - if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run: + if resume_state is not None: + graph_runtime_state = resume_state + variable_pool = graph_runtime_state.variable_pool + elif self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run: graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution( workflow=self._workflow, single_iteration_run=self.application_generate_entity.single_iteration_run, @@ -85,7 +82,14 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): inputs = self.application_generate_entity.inputs # Create a variable pool. - + system_inputs = SystemVariable( + files=self.application_generate_entity.files, + user_id=self._sys_user_id, + app_id=app_config.app_id, + timestamp=int(naive_utc_now().timestamp()), + workflow_id=app_config.workflow_id, + workflow_execution_id=self.application_generate_entity.workflow_execution_id, + ) variable_pool = VariablePool( system_variables=system_inputs, user_inputs=inputs, @@ -95,7 +99,10 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - # init graph + # init graph (for both resumed and fresh runs when not single-step) + if resume_state is not None or not ( + self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run + ): graph = self._init_graph( graph_config=self._workflow.graph_dict, graph_runtime_state=graph_runtime_state, diff --git a/api/core/app/layers/pause_state_persist_layer.py b/api/core/app/layers/pause_state_persist_layer.py index 61a3e1baca..e1b5352c2a 100644 --- a/api/core/app/layers/pause_state_persist_layer.py +++ b/api/core/app/layers/pause_state_persist_layer.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass from typing import Annotated, Literal, Self, TypeAlias from pydantic import BaseModel, Field @@ -52,6 +53,14 @@ class WorkflowResumptionContext(BaseModel): return self.generate_entity.entity +@dataclass(frozen=True) +class PauseStateLayerConfig: + """Configuration container for instantiating pause persistence layers.""" + + session_factory: Engine | sessionmaker[Session] + state_owner_user_id: str + + class PauseStatePersistenceLayer(GraphEngineLayer): def __init__( self, diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index e8d41b9387..8a26b2e91b 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -28,8 +28,8 @@ from core.model_runtime.entities.provider_entities import ( ) from core.model_runtime.model_providers.__base.ai_model import AIModel from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory -from extensions.ext_database import db from libs.datetime_utils import naive_utc_now +from models.engine import db from models.provider import ( LoadBalancingModelConfig, Provider, diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index f45f15a6da..c36391c940 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -15,10 +15,7 @@ from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker from core.helper.encrypter import batch_decrypt_token, encrypt_token, obfuscated_token -from core.ops.entities.config_entity import ( - OPS_FILE_PATH, - TracingProviderEnum, -) +from core.ops.entities.config_entity import OPS_FILE_PATH, TracingProviderEnum from core.ops.entities.trace_entity import ( DatasetRetrievalTraceInfo, GenerateNameTraceInfo, @@ -31,11 +28,10 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.ops.utils import get_message_data -from extensions.ext_database import db from extensions.ext_storage import storage +from models.engine import db from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig from models.workflow import WorkflowAppLog -from repositories.factory import DifyAPIRepositoryFactory from tasks.ops_trace_task import process_trace_tasks if TYPE_CHECKING: @@ -470,6 +466,8 @@ class TraceTask: @classmethod def _get_workflow_run_repo(cls): + from repositories.factory import DifyAPIRepositoryFactory + if cls._workflow_run_repo is None: with cls._repo_lock: if cls._workflow_run_repo is None: diff --git a/api/core/ops/utils.py b/api/core/ops/utils.py index 631e3b77b2..a5196d66c0 100644 --- a/api/core/ops/utils.py +++ b/api/core/ops/utils.py @@ -5,7 +5,7 @@ from urllib.parse import urlparse from sqlalchemy import select -from extensions.ext_database import db +from models.engine import db from models.model import Message diff --git a/api/core/repositories/__init__.py b/api/core/repositories/__init__.py index 6c01056c4d..6f2826f634 100644 --- a/api/core/repositories/__init__.py +++ b/api/core/repositories/__init__.py @@ -2,27 +2,17 @@ from __future__ import annotations -from importlib import import_module -from typing import Any +from .celery_workflow_execution_repository import CeleryWorkflowExecutionRepository +from .celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository +from .factory import DifyCoreRepositoryFactory, RepositoryImportError +from .sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository +from .sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository -_ATTRIBUTE_MODULE_MAP = { - "CeleryWorkflowExecutionRepository": "core.repositories.celery_workflow_execution_repository", - "CeleryWorkflowNodeExecutionRepository": "core.repositories.celery_workflow_node_execution_repository", - "DifyCoreRepositoryFactory": "core.repositories.factory", - "RepositoryImportError": "core.repositories.factory", - "SQLAlchemyWorkflowNodeExecutionRepository": "core.repositories.sqlalchemy_workflow_node_execution_repository", -} - -__all__ = list(_ATTRIBUTE_MODULE_MAP.keys()) - - -def __getattr__(name: str) -> Any: - module_path = _ATTRIBUTE_MODULE_MAP.get(name) - if module_path is None: - raise AttributeError(f"module 'core.repositories' has no attribute '{name}'") - module = import_module(module_path) - return getattr(module, name) - - -def __dir__() -> list[str]: # pragma: no cover - simple helper - return sorted(__all__) +__all__ = [ + "CeleryWorkflowExecutionRepository", + "CeleryWorkflowNodeExecutionRepository", + "DifyCoreRepositoryFactory", + "RepositoryImportError", + "SQLAlchemyWorkflowExecutionRepository", + "SQLAlchemyWorkflowNodeExecutionRepository", +] diff --git a/api/core/repositories/human_input_reposotiry.py b/api/core/repositories/human_input_reposotiry.py index 4fc84664dc..df9442e7d3 100644 --- a/api/core/repositories/human_input_reposotiry.py +++ b/api/core/repositories/human_input_reposotiry.py @@ -1,10 +1,11 @@ import dataclasses import json from collections.abc import Mapping, Sequence +from datetime import datetime from typing import Any from sqlalchemy import Engine, select -from sqlalchemy.orm import Session, sessionmaker, selectinload +from sqlalchemy.orm import Session, selectinload, sessionmaker from core.workflow.nodes.human_input.entities import ( DeliveryChannelConfig, @@ -85,6 +86,53 @@ class _FormSubmissionImpl(FormSubmission): return json.loads(submitted_data) +@dataclasses.dataclass(frozen=True) +class HumanInputFormRecord: + form_id: str + workflow_run_id: str + node_id: str + tenant_id: str + definition: FormDefinition + rendered_content: str + expiration_time: datetime + selected_action_id: str | None + submitted_data: Mapping[str, Any] | None + submitted_at: datetime | None + submission_user_id: str | None + submission_end_user_id: str | None + completed_by_recipient_id: str | None + recipient_id: str | None + recipient_type: RecipientType | None + access_token: str | None + + @property + def submitted(self) -> bool: + return self.submitted_at is not None + + @classmethod + def from_models( + cls, form_model: HumanInputForm, recipient_model: HumanInputFormRecipient | None + ) -> "HumanInputFormRecord": + return cls( + form_id=form_model.id, + workflow_run_id=form_model.workflow_run_id, + node_id=form_model.node_id, + tenant_id=form_model.tenant_id, + definition=FormDefinition.model_validate_json(form_model.form_definition), + rendered_content=form_model.rendered_content, + expiration_time=form_model.expiration_time, + selected_action_id=form_model.selected_action_id, + submitted_data=json.loads(form_model.submitted_data) if form_model.submitted_data else None, + submitted_at=form_model.submitted_at, + submission_user_id=form_model.submission_user_id, + submission_end_user_id=form_model.submission_end_user_id, + completed_by_recipient_id=form_model.completed_by_recipient_id, + recipient_id=recipient_model.id if recipient_model else None, + recipient_type=recipient_model.recipient_type if recipient_model else None, + access_token=recipient_model.access_token if recipient_model else None, + ) + + class HumanInputFormRepositoryImpl: def __init__( self, @@ -275,11 +323,74 @@ class HumanInputFormRepositoryImpl: return _FormSubmissionImpl(form_model=form_model) - def get_form_by_token(self, token: str, recipient_type: RecipientType | None = None): + +class HumanInputFormReadRepository: + """Read/write repository for fetching and submitting human input forms.""" + + def __init__(self, session_factory: sessionmaker | Engine): + if isinstance(session_factory, Engine): + session_factory = sessionmaker(bind=session_factory) + self._session_factory = session_factory + + def get_by_token(self, form_token: str) -> HumanInputFormRecord | None: query = ( select(HumanInputFormRecipient) .options(selectinload(HumanInputFormRecipient.form)) - .where() - + .where(HumanInputFormRecipient.access_token == form_token) + ) with self._session_factory(expire_on_commit=False) as session: - form_recipient = session.qu + recipient_model = session.scalars(query).first() + if recipient_model is None or recipient_model.form is None: + return None + return HumanInputFormRecord.from_models(recipient_model.form, recipient_model) + + def get_by_form_id_and_recipient_type( + self, + form_id: str, + recipient_type: RecipientType, + ) -> HumanInputFormRecord | None: + query = ( + select(HumanInputFormRecipient) + .options(selectinload(HumanInputFormRecipient.form)) + .where( + HumanInputFormRecipient.form_id == form_id, + HumanInputFormRecipient.recipient_type == recipient_type, + ) + ) + with self._session_factory(expire_on_commit=False) as session: + recipient_model = session.scalars(query).first() + if recipient_model is None or recipient_model.form is None: + return None + return HumanInputFormRecord.from_models(recipient_model.form, recipient_model) + + def mark_submitted( + self, + *, + form_id: str, + recipient_id: str | None, + selected_action_id: str, + form_data: Mapping[str, Any], + submission_user_id: str | None, + submission_end_user_id: str | None, + ) -> HumanInputFormRecord: + with self._session_factory(expire_on_commit=False) as session, session.begin(): + form_model = session.get(HumanInputForm, form_id) + if form_model is None: + raise FormNotFoundError(f"form not found, id={form_id}") + + recipient_model = session.get(HumanInputFormRecipient, recipient_id) if recipient_id else None + + form_model.selected_action_id = selected_action_id + form_model.submitted_data = json.dumps(form_data) + form_model.submitted_at = naive_utc_now() + form_model.submission_user_id = submission_user_id + form_model.submission_end_user_id = submission_end_user_id + form_model.completed_by_recipient_id = recipient_id + + session.add(form_model) + session.flush() + session.refresh(form_model) + if recipient_model is not None: + session.refresh(recipient_model) + + return HumanInputFormRecord.from_models(form_model, recipient_model) diff --git a/api/core/workflow/graph_engine/command_source.py b/api/core/workflow/graph_engine/command_source.py deleted file mode 100644 index 2d0d4b8211..0000000000 --- a/api/core/workflow/graph_engine/command_source.py +++ /dev/null @@ -1,71 +0,0 @@ -import abc -from collections.abc import Callable -from dataclasses import dataclass -from enum import StrEnum -from typing import Annotated, TypeAlias, final - -from pydantic import BaseModel, ConfigDict, Discriminator, Tag, field_validator - -from core.workflow.nodes.base import BaseNode - - -@dataclass(frozen=True) -class CommandParams: - # `next_node_instance` is the instance of the next node to run. - next_node: BaseNode - - -class _CommandTag(StrEnum): - SUSPEND = "suspend" - STOP = "stop" - CONTINUE = "continue" - - -# Note: Avoid using the `_Command` class directly. -# Instead, use `CommandTypes` for type annotations. -class _Command(BaseModel, abc.ABC): - model_config = ConfigDict(frozen=True) - - tag: _CommandTag - - @field_validator("tag") - @classmethod - def validate_value_type(cls, value): - if value != cls.model_fields["tag"].default: - raise ValueError("Cannot modify 'tag'") - return value - - -@final -class StopCommand(_Command): - tag: _CommandTag = _CommandTag.STOP - - -@final -class SuspendCommand(_Command): - tag: _CommandTag = _CommandTag.SUSPEND - - -@final -class ContinueCommand(_Command): - tag: _CommandTag = _CommandTag.CONTINUE - - -def _get_command_tag(command: _Command): - return command.tag - - -CommandTypes: TypeAlias = Annotated[ - ( - Annotated[StopCommand, Tag(_CommandTag.STOP)] - | Annotated[SuspendCommand, Tag(_CommandTag.SUSPEND)] - | Annotated[ContinueCommand, Tag(_CommandTag.CONTINUE)] - ), - Discriminator(_get_command_tag), -] - -# `CommandSource` is a callable that takes a single argument of type `CommandParams` and -# returns a `Command` object to the engine, indicating whether the graph engine should suspend, continue, or stop. -# -# It must not modify the data inside `CommandParams`, including any attributes within its fields. -CommandSource: TypeAlias = Callable[[CommandParams], CommandTypes] diff --git a/api/core/workflow/nodes/human_input/__init__.py b/api/core/workflow/nodes/human_input/__init__.py index d3861d2c99..1789604577 100644 --- a/api/core/workflow/nodes/human_input/__init__.py +++ b/api/core/workflow/nodes/human_input/__init__.py @@ -1,8 +1,3 @@ """ Human Input node implementation. """ - -from .entities import HumanInputNodeData -from .human_input_node import HumanInputNode - -__all__ = ["HumanInputNode", "HumanInputNodeData"] diff --git a/api/core/workflow/nodes/human_input/entities.py b/api/core/workflow/nodes/human_input/entities.py index 06bf83bcfd..ce84e61fb0 100644 --- a/api/core/workflow/nodes/human_input/entities.py +++ b/api/core/workflow/nodes/human_input/entities.py @@ -269,16 +269,6 @@ class HumanInputNodeData(BaseNodeData): return variable_mappings -class HumanInputRequired(BaseModel): - """Event data for human input required.""" - - form_id: str - node_id: str - form_content: str - inputs: list[FormInput] - web_app_form_token: Optional[str] = None - - class FormDefinition(BaseModel): form_content: str inputs: list[FormInput] = Field(default_factory=list) diff --git a/api/core/workflow/nodes/human_input/human_input_node.py b/api/core/workflow/nodes/human_input/human_input_node.py index 02bfcf6b32..b114895958 100644 --- a/api/core/workflow/nodes/human_input/human_input_node.py +++ b/api/core/workflow/nodes/human_input/human_input_node.py @@ -1,4 +1,3 @@ -import dataclasses import logging from collections.abc import Generator, Mapping, Sequence from typing import Any @@ -18,11 +17,6 @@ _SELECTED_BRANCH_KEY = "selected_branch" logger = logging.getLogger(__name__) -@dataclasses.dataclass -class _FormSubmissionResult: - action_id: str - - class HumanInputNode(Node[HumanInputNodeData]): node_type = NodeType.HUMAN_INPUT execution_type = NodeExecutionType.BRANCH diff --git a/api/core/workflow/runtime/graph_runtime_state.py b/api/core/workflow/runtime/graph_runtime_state.py index fb53a9ec63..0179759a48 100644 --- a/api/core/workflow/runtime/graph_runtime_state.py +++ b/api/core/workflow/runtime/graph_runtime_state.py @@ -5,7 +5,7 @@ import json from collections.abc import Mapping, Sequence from copy import deepcopy from dataclasses import dataclass -from typing import Any, Protocol +from typing import TYPE_CHECKING, Any, Protocol from pydantic.json import pydantic_encoder @@ -13,6 +13,9 @@ from core.model_runtime.entities.llm_entities import LLMUsage from core.workflow.entities.pause_reason import PauseReason from core.workflow.runtime.variable_pool import VariablePool +if TYPE_CHECKING: + from core.workflow.entities.pause_reason import PauseReason + class ReadyQueueProtocol(Protocol): """Structural interface required from ready queue implementations.""" @@ -59,7 +62,7 @@ class GraphExecutionProtocol(Protocol): aborted: bool error: Exception | None exceptions_count: int - pause_reasons: list[PauseReason] + pause_reasons: Sequence[PauseReason] def start(self) -> None: """Transition execution into the running state.""" diff --git a/api/models/human_input.py b/api/models/human_input.py index e5b2d332cb..b6f5b7911a 100644 --- a/api/models/human_input.py +++ b/api/models/human_input.py @@ -57,18 +57,19 @@ class HumanInputForm(DefaultFieldsMixin, Base): completed_by_recipient_id: Mapped[str | None] = mapped_column( StringUUID, - sa.ForeignKey("human_input_recipients.id"), nullable=True, ) deliveries: Mapped[list["HumanInputDelivery"]] = relationship( "HumanInputDelivery", + primaryjoin="HumanInputForm.id == foreign(HumanInputDelivery.form_id)", + uselist=True, back_populates="form", lazy="raise", ) completed_by_recipient: Mapped["HumanInputFormRecipient | None"] = relationship( - "HumanInputRecipient", - primaryjoin="HumanInputForm.completed_by_recipient_id == HumanInputRecipient.id", + "HumanInputFormRecipient", + primaryjoin="HumanInputForm.completed_by_recipient_id == foreign(HumanInputFormRecipient.id)", lazy="raise", viewonly=True, ) @@ -79,7 +80,6 @@ class HumanInputDelivery(DefaultFieldsMixin, Base): form_id: Mapped[str] = mapped_column( StringUUID, - sa.ForeignKey("human_input_forms.id"), nullable=False, ) delivery_method_type: Mapped[DeliveryMethodType] = mapped_column( @@ -91,11 +91,16 @@ class HumanInputDelivery(DefaultFieldsMixin, Base): form: Mapped[HumanInputForm] = relationship( "HumanInputForm", + uselist=False, + foreign_keys=[form_id], + primaryjoin="HumanInputDelivery.form_id == HumanInputForm.id", back_populates="deliveries", lazy="raise", ) + recipients: Mapped[list["HumanInputFormRecipient"]] = relationship( - "HumanInputRecipient", + "HumanInputFormRecipient", + primaryjoin="HumanInputDelivery.id == foreign(HumanInputFormRecipient.delivery_id)", uselist=True, back_populates="delivery", # Require explicit preloading @@ -162,6 +167,7 @@ class HumanInputFormRecipient(DefaultFieldsMixin, Base): uselist=False, foreign_keys=[delivery_id], back_populates="recipients", + primaryjoin="HumanInputFormRecipient.delivery_id == HumanInputDelivery.id", # Require explicit preloading lazy="raise", ) @@ -170,7 +176,7 @@ class HumanInputFormRecipient(DefaultFieldsMixin, Base): "HumanInputForm", uselist=False, foreign_keys=[form_id], - back_populates="recipients", + primaryjoin="HumanInputFormRecipient.form_id == HumanInputForm.id", # Require explicit preloading lazy="raise", ) diff --git a/api/repositories/entities/workflow_pause.py b/api/repositories/entities/workflow_pause.py index e3925af1e7..364cd3a894 100644 --- a/api/repositories/entities/workflow_pause.py +++ b/api/repositories/entities/workflow_pause.py @@ -6,38 +6,11 @@ by the core workflow module. These models are independent of the storage mechani and don't contain implementation details like tenant_id, app_id, etc. """ -import enum from abc import ABC, abstractmethod from collections.abc import Sequence from datetime import datetime -from typing import Annotated, Literal, TypeAlias -from pydantic import BaseModel, Field -from sqlalchemy import Sequence - - -class _PauseTypeEnum(enum.StrEnum): - human_input = enum.auto() - breakpoint = enum.auto() - scheduling = enum.auto() - - -class HumanInputPause(BaseModel): - type: Literal[_PauseTypeEnum.human_input] = _PauseTypeEnum.human_input - form_id: str - - -class SchedulingPause(BaseModel): - type: Literal[_PauseTypeEnum.scheduling] = _PauseTypeEnum.scheduling - - -PauseType: TypeAlias = Annotated[HumanInputPause | SchedulingPause, Field(discriminator="type")] - - -class PauseDetail(BaseModel): - node_id: str - node_title: str - pause_type: PauseType +from .pause_reason import PauseReason from core.workflow.entities.pause_reason import PauseReason @@ -100,7 +73,5 @@ class WorkflowPauseEntity(ABC): Returns a sequence of `PauseReason` objects describing the specific nodes and reasons for which the workflow execution was paused. - This information is related to, but distinct from, the `PauseReason` type - defined in `api/core/workflow/entities/pause_reason.py`. """ ... diff --git a/api/repositories/sqlalchemy_api_workflow_run_repository.py b/api/repositories/sqlalchemy_api_workflow_run_repository.py index a29365b3b2..298517a45a 100644 --- a/api/repositories/sqlalchemy_api_workflow_run_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_run_repository.py @@ -505,7 +505,6 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): # Mark as resumed pause_model.resumed_at = naive_utc_now() - workflow_run.pause_id = None # type: ignore workflow_run.status = WorkflowExecutionStatus.RUNNING session.add(pause_model) diff --git a/api/repositories/sqlalchemy_workflow_trigger_log_repository.py b/api/repositories/sqlalchemy_workflow_trigger_log_repository.py index 0d67e286b0..c828cc60c2 100644 --- a/api/repositories/sqlalchemy_workflow_trigger_log_repository.py +++ b/api/repositories/sqlalchemy_workflow_trigger_log_repository.py @@ -84,3 +84,13 @@ class SQLAlchemyWorkflowTriggerLogRepository(WorkflowTriggerLogRepository): ) return list(self.session.scalars(query).all()) + + def get_by_workflow_run_id(self, workflow_run_id: str) -> WorkflowTriggerLog | None: + """Get the trigger log associated with a workflow run.""" + query = ( + select(WorkflowTriggerLog) + .where(WorkflowTriggerLog.workflow_run_id == workflow_run_id) + .order_by(WorkflowTriggerLog.created_at.desc()) + .limit(1) + ) + return self.session.scalar(query) diff --git a/api/repositories/workflow_trigger_log_repository.py b/api/repositories/workflow_trigger_log_repository.py index 138b8779ac..e78f1db532 100644 --- a/api/repositories/workflow_trigger_log_repository.py +++ b/api/repositories/workflow_trigger_log_repository.py @@ -109,3 +109,15 @@ class WorkflowTriggerLogRepository(Protocol): A sequence of recent WorkflowTriggerLog instances """ ... + + def get_by_workflow_run_id(self, workflow_run_id: str) -> WorkflowTriggerLog | None: + """ + Retrieve a trigger log associated with a specific workflow run. + + Args: + workflow_run_id: Identifier of the workflow run + + Returns: + The matching WorkflowTriggerLog if present, None otherwise + """ + ... diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py index 764f9b3ab6..fc027b5bc1 100644 --- a/api/services/app_generate_service.py +++ b/api/services/app_generate_service.py @@ -14,12 +14,8 @@ from core.app.features.rate_limiting.rate_limit import rate_limit_context from enums.quota_type import QuotaType, unlimited from extensions.otel import AppGenerateHandler, trace_span from models.model import Account, App, AppMode, EndUser -from models.workflow import Workflow -from services.errors.app import InvokeRateLimitError, QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError -from models.model import Account, App, AppMode, EndUser from models.workflow import Workflow, WorkflowRun -from services.errors.app import WorkflowIdFormatError, WorkflowNotFoundError -from services.errors.llm import InvokeRateLimitError +from services.errors.app import InvokeRateLimitError, QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError from services.workflow_service import WorkflowService from tasks.app_generate.workflow_execute_task import ChatflowExecutionParams, chatflow_execute_task @@ -259,7 +255,7 @@ class AppGenerateService: ): if workflow_run.status.is_ended(): # TODO(QuantumGhost): handled the ended scenario. - return + pass generator = AdvancedChatAppGenerator() diff --git a/api/services/human_input_service.py b/api/services/human_input_service.py index 67c4f25376..56bb8f2eaf 100644 --- a/api/services/human_input_service.py +++ b/api/services/human_input_service.py @@ -1,29 +1,49 @@ -import abc -import json +import logging from collections.abc import Mapping from typing import Any -from sqlalchemy import Engine, select -from sqlalchemy.orm import Session, selectinload, sessionmaker +from sqlalchemy import Engine +from sqlalchemy.orm import Session, sessionmaker +from core.repositories.human_input_reposotiry import HumanInputFormReadRepository, HumanInputFormRecord from core.workflow.nodes.human_input.entities import FormDefinition -from libs.datetime_utils import naive_utc_now from libs.exception import BaseHTTPException from models.account import Account -from models.human_input import HumanInputForm, HumanInputFormRecipient, RecipientType +from models.human_input import RecipientType +from models.model import App, AppMode +from models.workflow import WorkflowRun +from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository +from services.workflow.entities import WorkflowResumeTaskData +from tasks.app_generate.workflow_execute_task import resume_chatflow_execution +from tasks.async_workflow_tasks import resume_workflow_execution class Form: - def __init__(self, form_model: HumanInputForm): - self._form_model = form_model + def __init__(self, record: HumanInputFormRecord): + self._record = record - @abc.abstractmethod def get_definition(self) -> FormDefinition: - pass + return self._record.definition - @abc.abstractmethod + @property def submitted(self) -> bool: - pass + return self._record.submitted + + @property + def id(self) -> str: + return self._record.form_id + + @property + def workflow_run_id(self) -> str: + return self._record.workflow_run_id + + @property + def recipient_id(self) -> str | None: + return self._record.recipient_id + + @property + def recipient_type(self) -> RecipientType | None: + return self._record.recipient_type class HumanInputError(Exception): @@ -49,93 +69,148 @@ class WebAppDeliveryNotEnabledError(HumanInputError, BaseException): pass +logger = logging.getLogger(__name__) + + class HumanInputService: def __init__( self, session_factory: sessionmaker[Session] | Engine, + form_repository: HumanInputFormReadRepository | None = None, ): if isinstance(session_factory, Engine): session_factory = sessionmaker(bind=session_factory) self._session_factory = session_factory + self._form_repository = form_repository or HumanInputFormReadRepository(session_factory) def get_form_by_token(self, form_token: str) -> Form | None: - query = ( - select(HumanInputFormRecipient) - .options(selectinload(HumanInputFormRecipient.form)) - .where(HumanInputFormRecipient.access_token == form_token) - ) - with self._session_factory(expire_on_commit=False) as session: - recipient = session.scalars(query).first() - if recipient is None: + record = self._form_repository.get_by_token(form_token) + if record is None: return None + return Form(record) - return Form(recipient.form) - - def get_form_by_id(self, form_id: str) -> Form | None: - query = select(HumanInputForm).where(HumanInputForm.id == form_id) - with self._session_factory(expire_on_commit=False) as session: - form_model = session.scalars(query).first() - if form_model is None: - return None - - return Form(form_model) - - def submit_form_by_id(self, form_id: str, user: Account, selected_action_id: str, form_data: Mapping[str, Any]): - recipient_query = ( - select(HumanInputFormRecipient) - .options(selectinload(HumanInputFormRecipient.form)) - .where( - HumanInputFormRecipient.recipient_type == RecipientType.WEBAPP, - HumanInputFormRecipient.form_id == form_id, - ) + def get_form_by_id(self, form_id: str, recipient_type: RecipientType = RecipientType.WEBAPP) -> Form | None: + record = self._form_repository.get_by_form_id_and_recipient_type( + form_id=form_id, + recipient_type=recipient_type, ) + if record is None: + return None + return Form(record) - with self._session_factory(expire_on_commit=False) as session: - recipient_model = session.scalars(recipient_query).first() + def get_form_definition_by_id(self, form_id: str) -> Form | None: + form = self.get_form_by_id(form_id, recipient_type=RecipientType.WEBAPP) + if form is None: + return None + self._ensure_not_submitted(form) + return form - if recipient_model is None: + def get_form_definition_by_token(self, recipient_type: RecipientType, form_token: str) -> Form | None: + form = self.get_form_by_token(form_token) + if form is None or form.recipient_type != recipient_type: + return None + self._ensure_not_submitted(form) + return form + + def submit_form_by_id( + self, + form_id: str, + selected_action_id: str, + form_data: Mapping[str, Any], + user: Account | None = None, + ): + form = self.get_form_by_id(form_id, recipient_type=RecipientType.WEBAPP) + if form is None: raise WebAppDeliveryNotEnabledError() - form_model = recipient_model.form - form = Form(form_model) - if form.submitted: - raise FormSubmittedError(form_model.id) + self._ensure_not_submitted(form) - with self._session_factory(expire_on_commit=False) as session, session.begin(): - form_model.selected_action_id = selected_action_id - form_model.submitted_data = json.dumps(form_data) - form_model.submitted_at = naive_utc_now() - form_model.submission_user_id = user.id - - form_model.completed_by_recipient_id = recipient_model.id - session.add(form_model) - # TODO: restart the execution of paused workflow - - def submit_form_by_token(self, form_token: str, selected_action_id: str, form_data: Mapping[str, Any]): - recipient_query = ( - select(HumanInputFormRecipient) - .options(selectinload(HumanInputFormRecipient.form)) - .where( - HumanInputFormRecipient.form_id == form_token, - ) + result = self._form_repository.mark_submitted( + form_id=form.id, + recipient_id=form.recipient_id, + selected_action_id=selected_action_id, + form_data=form_data, + submission_user_id=user.id if user else None, + submission_end_user_id=None, ) - with self._session_factory(expire_on_commit=False) as session: - recipient_model = session.scalars(recipient_query).first() + self._enqueue_resume(result.workflow_run_id) - if recipient_model is None: + def submit_form_by_token( + self, + recipient_type: RecipientType, + form_token: str, + selected_action_id: str, + form_data: Mapping[str, Any], + submission_end_user_id: str | None = None, + ): + form = self.get_form_by_token(form_token) + if form is None or form.recipient_type != recipient_type: raise WebAppDeliveryNotEnabledError() - form_model = recipient_model.form - form = Form(form_model) + self._ensure_not_submitted(form) + + result = self._form_repository.mark_submitted( + form_id=form.id, + recipient_id=form.recipient_id, + selected_action_id=selected_action_id, + form_data=form_data, + submission_user_id=None, + submission_end_user_id=submission_end_user_id, + ) + + self._enqueue_resume(result.workflow_run_id) + + def _ensure_not_submitted(self, form: Form) -> None: if form.submitted: - raise FormSubmittedError(form_model.id) + raise FormSubmittedError(form.id) - with self._session_factory(expire_on_commit=False) as session, session.begin(): - form_model.selected_action_id = selected_action_id - form_model.submitted_data = json.dumps(form_data) - form_model.submitted_at = naive_utc_now() - form_model.submission_user_id = user.id + def _enqueue_resume(self, workflow_run_id: str) -> None: + with self._session_factory(expire_on_commit=False) as session: + trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session) + trigger_log = trigger_log_repo.get_by_workflow_run_id(workflow_run_id) - form_model.completed_by_recipient_id = recipient_model.id - session.add(form_model) + if trigger_log is not None: + payload = WorkflowResumeTaskData( + workflow_trigger_log_id=trigger_log.id, + workflow_run_id=workflow_run_id, + ) + + try: + resume_workflow_execution.apply_async( + kwargs={"task_data_dict": payload.model_dump()}, + queue=trigger_log.queue_name, + ) + except Exception: # pragma: no cover + logger.exception("Failed to enqueue resume task for workflow run %s", workflow_run_id) + return + + if self._enqueue_chatflow_resume(workflow_run_id): + return + + logger.warning("No workflow trigger log bound to workflow run %s; skipping resume dispatch", workflow_run_id) + + def _enqueue_chatflow_resume(self, workflow_run_id: str) -> bool: + with self._session_factory(expire_on_commit=False) as session: + workflow_run = session.get(WorkflowRun, workflow_run_id) + if workflow_run is None: + return False + + app = session.get(App, workflow_run.app_id) + + if app is None: + return False + + if app.mode != AppMode.ADVANCED_CHAT.value: + return False + + try: + resume_chatflow_execution.apply_async( + kwargs={"payload": {"workflow_run_id": workflow_run_id}}, + queue="chatflow_execute", + ) + except Exception: # pragma: no cover + logger.exception("Failed to enqueue chatflow resume for workflow run %s", workflow_run_id) + return False + + return True diff --git a/api/services/workflow/entities.py b/api/services/workflow/entities.py index 70ec8d6e2a..33c2aedaef 100644 --- a/api/services/workflow/entities.py +++ b/api/services/workflow/entities.py @@ -98,6 +98,13 @@ class WorkflowTaskData(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) +class WorkflowResumeTaskData(BaseModel): + """Payload for workflow resumption tasks.""" + + workflow_trigger_log_id: str + workflow_run_id: str + + class AsyncTriggerExecutionResult(BaseModel): """Result from async trigger-based workflow execution""" diff --git a/api/services/workflow_run_service.py b/api/services/workflow_run_service.py index d0f34d3a54..b903d8df5f 100644 --- a/api/services/workflow_run_service.py +++ b/api/services/workflow_run_service.py @@ -5,7 +5,6 @@ from sqlalchemy import Engine from sqlalchemy.orm import sessionmaker import contexts -from core.workflow.entities.workflow_pause import PauseDetail from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import ( @@ -160,9 +159,3 @@ class WorkflowRunService: app_id=app_model.id, workflow_run_id=run_id, ) - - def get_pause_details(self, workflow_run_id: str) -> Sequence[PauseDetail]: - pause = self._workflow_run_repo.get_workflow_pause(workflow_run_id) - if pause is None: - return [] - return pause.get_pause_details() diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 260ed0f89e..096ef36b78 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -949,7 +949,7 @@ class WorkflowService: node_data = node.get("data", {}) node_type = node_data.get("type") - if node_type == "human_input": + if node_type == NodeType.HUMAN_INPUT: self._validate_human_input_node_data(node_data) def validate_features_structure(self, app_model: App, features: dict): diff --git a/api/tasks/app_generate/__init__.py b/api/tasks/app_generate/__init__.py index a247e7bc89..6c457d985a 100644 --- a/api/tasks/app_generate/__init__.py +++ b/api/tasks/app_generate/__init__.py @@ -1,5 +1,3 @@ from .workflow_execute_task import chatflow_execute_task -__all__ = [ - "chatflow_execute_taskch" -] +__all__ = ["chatflow_execute_task"] diff --git a/api/tasks/app_generate/workflow_execute_task.py b/api/tasks/app_generate/workflow_execute_task.py index 57acc9725e..7f91410cf1 100644 --- a/api/tasks/app_generate/workflow_execute_task.py +++ b/api/tasks/app_generate/workflow_execute_task.py @@ -8,18 +8,21 @@ from typing import Annotated, Any, TypeAlias, Union from celery import shared_task from flask import current_app, json from pydantic import BaseModel, Discriminator, Field, Tag -from sqlalchemy import Engine -from sqlalchemy.orm import sessionmaker +from sqlalchemy import Engine, select +from sqlalchemy.orm import Session, sessionmaker from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator -from core.app.entities.app_invoke_entities import ( - InvokeFrom, -) +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom +from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, WorkflowResumptionContext +from core.repositories import DifyCoreRepositoryFactory +from core.workflow.runtime import GraphRuntimeState from extensions.ext_database import db from libs.flask_utils import set_login_user from models.account import Account -from models.model import App, AppMode, EndUser -from models.workflow import Workflow +from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom +from models.model import App, AppMode, Conversation, EndUser, Message +from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom, WorkflowRun +from repositories.factory import DifyAPIRepositoryFactory logger = logging.getLogger(__name__) @@ -129,6 +132,11 @@ class _ChatflowRunner: workflow = session.get(Workflow, exec_params.workflow_id) app = session.get(App, workflow.app_id) + pause_config = PauseStateLayerConfig( + session_factory=self._session_factory, + state_owner_user_id=workflow.created_by, + ) + user = self._resolve_user() chat_generator = AdvancedChatAppGenerator() @@ -144,6 +152,7 @@ class _ChatflowRunner: invoke_from=exec_params.invoke_from, streaming=exec_params.streaming, workflow_run_id=workflow_run_id, + pause_state_config=pause_config, ) if not exec_params.streaming: return response @@ -174,11 +183,135 @@ class _ChatflowRunner: return user +def _resolve_user_for_run(session: Session, workflow_run: WorkflowRun) -> Account | EndUser | None: + role = CreatorUserRole(workflow_run.created_by_role) + if role == CreatorUserRole.ACCOUNT: + user = session.get(Account, workflow_run.created_by) + if user: + user.set_tenant_id(workflow_run.tenant_id) + return user + + return session.get(EndUser, workflow_run.created_by) + + @shared_task(queue="chatflow_execute") def chatflow_execute_task(payload: str) -> Mapping[str, Any] | None: exec_params = ChatflowExecutionParams.model_validate_json(payload) - print("chatflow_execute_task run with params", exec_params) + logger.info("chatflow_execute_task run with params: %s", exec_params) runner = _ChatflowRunner(db.engine, exec_params=exec_params) return runner.run() + + +@shared_task(queue="chatflow_execute", name="resume_chatflow_execution") +def resume_chatflow_execution(payload: dict[str, Any]) -> None: + workflow_run_id = payload["workflow_run_id"] + + session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker=session_factory) + + pause_entity = workflow_run_repo.get_workflow_pause(workflow_run_id) + if pause_entity is None: + logger.warning("No pause entity found for workflow run %s", workflow_run_id) + return + + try: + resumption_context = WorkflowResumptionContext.loads(pause_entity.get_state().decode()) + except Exception: + logger.exception("Failed to load resumption context for workflow run %s", workflow_run_id) + return + + generate_entity = resumption_context.get_generate_entity() + if not isinstance(generate_entity, AdvancedChatAppGenerateEntity): + logger.error( + "Resumption entity is not AdvancedChatAppGenerateEntity for workflow run %s (found %s)", + workflow_run_id, + type(generate_entity), + ) + return + + graph_runtime_state = GraphRuntimeState.from_snapshot(resumption_context.serialized_graph_runtime_state) + + with Session(db.engine, expire_on_commit=False) as session: + workflow_run = session.get(WorkflowRun, workflow_run_id) + if workflow_run is None: + logger.warning("Workflow run %s not found during resume", workflow_run_id) + return + + workflow = session.get(Workflow, workflow_run.workflow_id) + if workflow is None: + logger.warning("Workflow %s not found during resume", workflow_run.workflow_id) + return + + app_model = session.get(App, workflow_run.app_id) + if app_model is None: + logger.warning("App %s not found during resume", workflow_run.app_id) + return + + if generate_entity.conversation_id is None: + logger.warning("Conversation id missing in resumption context for workflow run %s", workflow_run_id) + return + + conversation = session.get(Conversation, generate_entity.conversation_id) + if conversation is None: + logger.warning( + "Conversation %s not found for workflow run %s", generate_entity.conversation_id, workflow_run_id + ) + return + + message = session.scalar( + select(Message).where(Message.workflow_run_id == workflow_run_id).order_by(Message.created_at.desc()) + ) + if message is None: + logger.warning("Message not found for workflow run %s", workflow_run_id) + return + + user = _resolve_user_for_run(session, workflow_run) + if user is None: + logger.warning("User %s not found for workflow run %s", workflow_run.created_by, workflow_run_id) + return + + workflow_run_repo.resume_workflow_pause(workflow_run_id, pause_entity) + + pause_config = PauseStateLayerConfig( + session_factory=session_factory, + state_owner_user_id=workflow.created_by, + ) + + try: + triggered_from = WorkflowRunTriggeredFrom(workflow_run.triggered_from) + except ValueError: + triggered_from = WorkflowRunTriggeredFrom.APP_RUN + + workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( + session_factory=session_factory, + user=user, + app_id=app_model.id, + triggered_from=triggered_from, + ) + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( + session_factory=session_factory, + user=user, + app_id=app_model.id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + generator = AdvancedChatAppGenerator() + + try: + generator.resume( + app_model=app_model, + workflow=workflow, + user=user, + conversation=conversation, + message=message, + application_generate_entity=generate_entity, + workflow_execution_repository=workflow_execution_repository, + workflow_node_execution_repository=workflow_node_execution_repository, + graph_runtime_state=graph_runtime_state, + pause_state_config=pause_config, + ) + except Exception: + logger.exception("Failed to resume chatflow execution for workflow run %s", workflow_run_id) + raise diff --git a/api/tasks/async_workflow_tasks.py b/api/tasks/async_workflow_tasks.py index f8aac5b469..231a63bf70 100644 --- a/api/tasks/async_workflow_tasks.py +++ b/api/tasks/async_workflow_tasks.py @@ -5,6 +5,7 @@ These tasks handle workflow execution for different subscription tiers with appropriate retry policies and error handling. """ +import logging from datetime import UTC, datetime from typing import Any @@ -14,23 +15,37 @@ from sqlalchemy.orm import Session, sessionmaker from configs import dify_config from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY, WorkflowAppGenerator -from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity +from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, WorkflowResumptionContext +from core.app.layers.timeslice_layer import TimeSliceLayer from core.app.layers.trigger_post_layer import TriggerPostLayer +from core.repositories import DifyCoreRepositoryFactory +from core.workflow.runtime import GraphRuntimeState from extensions.ext_database import db from models.account import Account -from models.enums import CreatorUserRole, WorkflowTriggerStatus +from models.enums import AppTriggerType, CreatorUserRole, WorkflowRunTriggeredFrom, WorkflowTriggerStatus from models.model import App, EndUser, Tenant from models.trigger import WorkflowTriggerLog -from models.workflow import Workflow +from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom +from repositories.factory import DifyAPIRepositoryFactory from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository from services.errors.app import WorkflowNotFoundError from services.workflow.entities import ( TriggerData, + WorkflowResumeTaskData, WorkflowTaskData, ) from tasks.workflow_cfs_scheduler.cfs_scheduler import AsyncWorkflowCFSPlanEntity, AsyncWorkflowCFSPlanScheduler from tasks.workflow_cfs_scheduler.entities import AsyncWorkflowQueue, AsyncWorkflowSystemStrategy +logger = logging.getLogger(__name__) + +_TRIGGER_TO_RUN_SOURCE = { + AppTriggerType.TRIGGER_WEBHOOK: WorkflowRunTriggeredFrom.WEBHOOK, + AppTriggerType.TRIGGER_SCHEDULE: WorkflowRunTriggeredFrom.SCHEDULE, + AppTriggerType.TRIGGER_PLUGIN: WorkflowRunTriggeredFrom.PLUGIN, +} + @shared_task(queue=AsyncWorkflowQueue.PROFESSIONAL_QUEUE) def execute_workflow_professional(task_data_dict: dict[str, Any]): @@ -144,6 +159,11 @@ def _execute_workflow_common( if trigger_data.workflow_id: args["workflow_id"] = str(trigger_data.workflow_id) + pause_config = PauseStateLayerConfig( + session_factory=session_factory, + state_owner_user_id=workflow.created_by, + ) + # Execute the workflow with the trigger type generator.generate( app_model=app_model, @@ -159,6 +179,7 @@ def _execute_workflow_common( # TODO: Re-enable TimeSliceLayer after the HITL release. TriggerPostLayer(cfs_plan_scheduler_entity, start_time, trigger_log.id, session_factory), ], + pause_state_config=pause_config, ) except Exception as e: @@ -176,6 +197,117 @@ def _execute_workflow_common( session.commit() +@shared_task(name="resume_workflow_execution") +def resume_workflow_execution(task_data_dict: dict[str, Any]) -> None: + """Resume a paused workflow run via Celery.""" + task_data = WorkflowResumeTaskData.model_validate(task_data_dict) + session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_factory) + + with session_factory() as session: + trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session) + trigger_log = trigger_log_repo.get_by_id(task_data.workflow_trigger_log_id) + if not trigger_log: + logger.warning("Trigger log not found for resumption: %s", task_data.workflow_trigger_log_id) + return + + pause_entity = workflow_run_repo.get_workflow_pause(task_data.workflow_run_id) + if pause_entity is None: + logger.warning("No pause state for workflow run %s", task_data.workflow_run_id) + return + + try: + resumption_context = WorkflowResumptionContext.loads(pause_entity.get_state().decode()) + except Exception as exc: + logger.exception("Failed to load resumption context for workflow run %s", task_data.workflow_run_id) + raise exc + + generate_entity = resumption_context.get_generate_entity() + if not isinstance(generate_entity, WorkflowAppGenerateEntity): + logger.error( + "Unsupported resumption entity for workflow run %s: %s", + task_data.workflow_run_id, + type(generate_entity), + ) + return + + graph_runtime_state = GraphRuntimeState.from_snapshot(resumption_context.serialized_graph_runtime_state) + + workflow = session.scalar(select(Workflow).where(Workflow.id == trigger_log.workflow_id)) + if workflow is None: + raise WorkflowNotFoundError(f"Workflow not found: {trigger_log.workflow_id}") + + app_model = session.scalar(select(App).where(App.id == trigger_log.app_id)) + if app_model is None: + raise WorkflowNotFoundError(f"App not found: {trigger_log.app_id}") + + user = _get_user(session, trigger_log) + + cfs_plan_scheduler_entity = AsyncWorkflowCFSPlanEntity( + queue=trigger_log.queue_name, + schedule_strategy=AsyncWorkflowSystemStrategy, + granularity=dify_config.ASYNC_WORKFLOW_SCHEDULER_GRANULARITY, + ) + cfs_plan_scheduler = AsyncWorkflowCFSPlanScheduler(plan=cfs_plan_scheduler_entity) + + try: + trigger_type = AppTriggerType(trigger_log.trigger_type) + except ValueError: + trigger_type = AppTriggerType.UNKNOWN + + triggered_from = _TRIGGER_TO_RUN_SOURCE.get(trigger_type, WorkflowRunTriggeredFrom.APP_RUN) + + workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( + session_factory=session_factory, + user=user, + app_id=generate_entity.app_config.app_id, + triggered_from=triggered_from, + ) + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( + session_factory=session_factory, + user=user, + app_id=generate_entity.app_config.app_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + pause_config = PauseStateLayerConfig( + session_factory=session_factory, + state_owner_user_id=workflow.created_by, + ) + + workflow_run_repo.resume_workflow_pause(task_data.workflow_run_id, pause_entity) + + trigger_log.status = WorkflowTriggerStatus.RUNNING + trigger_log_repo.update(trigger_log) + session.commit() + + generator = WorkflowAppGenerator() + start_time = datetime.now(UTC) + + try: + generator.resume( + app_model=app_model, + workflow=workflow, + user=user, + application_generate_entity=generate_entity, + graph_runtime_state=graph_runtime_state, + workflow_execution_repository=workflow_execution_repository, + workflow_node_execution_repository=workflow_node_execution_repository, + graph_engine_layers=[ + TimeSliceLayer(cfs_plan_scheduler), + TriggerPostLayer(cfs_plan_scheduler_entity, start_time, trigger_log.id, session_factory), + ], + pause_state_config=pause_config, + ) + except Exception as exc: + trigger_log.status = WorkflowTriggerStatus.FAILED + trigger_log.error = str(exc) + trigger_log.finished_at = datetime.now(UTC) + trigger_log_repo.update(trigger_log) + session.commit() + raise + + def _get_user(session: Session, trigger_log: WorkflowTriggerLog) -> Account | EndUser: """Compose user from trigger log""" tenant = session.scalar(select(Tenant).where(Tenant.id == trigger_log.tenant_id)) diff --git a/api/tests/integration/core/app/apps/test_pause_resume_integration.py b/api/tests/integration/core/app/apps/test_pause_resume_integration.py new file mode 100644 index 0000000000..fe07cb0010 --- /dev/null +++ b/api/tests/integration/core/app/apps/test_pause_resume_integration.py @@ -0,0 +1,278 @@ +import sys +import time +from pathlib import Path +from types import ModuleType, SimpleNamespace +from typing import Any + +API_DIR = str(Path(__file__).resolve().parents[5]) +if API_DIR not in sys.path: + sys.path.insert(0, API_DIR) + +import core.workflow.nodes.human_input.entities # noqa: F401 +from core.app.apps.advanced_chat import app_generator as adv_app_gen_module +from core.app.apps.workflow import app_generator as wf_app_gen_module +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities import GraphInitParams +from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus +from core.workflow.graph import Graph +from core.workflow.graph_engine import GraphEngine +from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel +from core.workflow.graph_engine.entities.commands import PauseCommand +from core.workflow.graph_events import ( + GraphEngineEvent, + GraphRunPausedEvent, + GraphRunSucceededEvent, + NodeRunSucceededEvent, +) +from core.workflow.node_events import NodeRunResult +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig, VariableSelector +from core.workflow.nodes.base.node import Node +from core.workflow.nodes.end.end_node import EndNode +from core.workflow.nodes.end.entities import EndNodeData +from core.workflow.nodes.start.entities import StartNodeData +from core.workflow.nodes.start.start_node import StartNode +from core.workflow.runtime import GraphRuntimeState, VariablePool +from core.workflow.system_variable import SystemVariable + +if "core.ops.ops_trace_manager" not in sys.modules: + ops_stub = ModuleType("core.ops.ops_trace_manager") + + class _StubTraceQueueManager: + def __init__(self, *_, **__): + pass + + ops_stub.TraceQueueManager = _StubTraceQueueManager + sys.modules["core.ops.ops_trace_manager"] = ops_stub + + +class _StubToolNodeData(BaseNodeData): + pass + + +class _StubToolNode(Node): + node_type = NodeType.TOOL + + def init_node_data(self, data): + self._node_data = _StubToolNodeData.model_validate(data) + + def _get_error_strategy(self): + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self): + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + + def _run(self): + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + outputs={"value": f"{self.id}-done"}, + ) + + +def _patch_tool_node(mocker): + from core.workflow.nodes import node_factory + + custom_mapping = dict(node_factory.NODE_TYPE_CLASSES_MAPPING) + custom_versions = dict(custom_mapping[NodeType.TOOL]) + custom_versions[node_factory.LATEST_VERSION] = _StubToolNode + custom_mapping[NodeType.TOOL] = custom_versions + mocker.patch("core.workflow.nodes.node_factory.NODE_TYPE_CLASSES_MAPPING", custom_mapping) + + +def _build_graph(runtime_state: GraphRuntimeState) -> Graph: + params = GraphInitParams( + tenant_id="tenant", + app_id="app", + workflow_id="workflow", + graph_config={}, + user_id="user", + user_from="account", + invoke_from="service-api", + call_depth=0, + ) + + start_data = StartNodeData(title="start", variables=[]) + start_node = StartNode( + id="start", + config={"id": "start", "data": start_data.model_dump()}, + graph_init_params=params, + graph_runtime_state=runtime_state, + ) + start_node.init_node_data(start_data.model_dump()) + + tool_data = _StubToolNodeData(title="tool") + tool_a = _StubToolNode( + id="tool_a", + config={"id": "tool_a", "data": tool_data.model_dump()}, + graph_init_params=params, + graph_runtime_state=runtime_state, + ) + tool_a.init_node_data(tool_data.model_dump()) + + tool_b = _StubToolNode( + id="tool_b", + config={"id": "tool_b", "data": tool_data.model_dump()}, + graph_init_params=params, + graph_runtime_state=runtime_state, + ) + tool_b.init_node_data(tool_data.model_dump()) + + tool_c = _StubToolNode( + id="tool_c", + config={"id": "tool_c", "data": tool_data.model_dump()}, + graph_init_params=params, + graph_runtime_state=runtime_state, + ) + tool_c.init_node_data(tool_data.model_dump()) + + end_data = EndNodeData( + title="end", + outputs=[VariableSelector(variable="result", value_selector=["tool_c", "value"])], + desc=None, + ) + end_node = EndNode( + id="end", + config={"id": "end", "data": end_data.model_dump()}, + graph_init_params=params, + graph_runtime_state=runtime_state, + ) + end_node.init_node_data(end_data.model_dump()) + + return ( + Graph.new() + .add_root(start_node) + .add_node(tool_a) + .add_node(tool_b) + .add_node(tool_c) + .add_node(end_node) + .add_edge("tool_a", "tool_b") + .add_edge("tool_b", "tool_c") + .add_edge("tool_c", "end") + .build() + ) + + +def _build_runtime_state(run_id: str) -> GraphRuntimeState: + variable_pool = VariablePool( + system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"), + user_inputs={}, + conversation_variables=[], + ) + variable_pool.system_variables.workflow_execution_id = run_id + return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + + +def _run_with_optional_pause(runtime_state: GraphRuntimeState, *, pause_on: str | None) -> list[GraphEngineEvent]: + command_channel = InMemoryChannel() + graph = _build_graph(runtime_state) + engine = GraphEngine( + workflow_id="workflow", + graph=graph, + graph_runtime_state=runtime_state, + command_channel=command_channel, + ) + + events: list[GraphEngineEvent] = [] + for event in engine.run(): + if isinstance(event, NodeRunSucceededEvent) and pause_on and event.node_id == pause_on: + command_channel.send_command(PauseCommand(reason="test pause")) + engine._command_processor.process_commands() # type: ignore[attr-defined] + events.append(event) + return events + + +def _node_successes(events: list[GraphEngineEvent]) -> list[str]: + return [evt.node_id for evt in events if isinstance(evt, NodeRunSucceededEvent)] + + +def test_workflow_app_pause_resume_matches_baseline(mocker): + _patch_tool_node(mocker) + + baseline_state = _build_runtime_state("baseline") + baseline_events = _run_with_optional_pause(baseline_state, pause_on=None) + assert isinstance(baseline_events[-1], GraphRunSucceededEvent) + baseline_nodes = _node_successes(baseline_events) + baseline_outputs = baseline_state.outputs + + paused_state = _build_runtime_state("paused-run") + paused_events = _run_with_optional_pause(paused_state, pause_on="tool_a") + assert isinstance(paused_events[-1], GraphRunPausedEvent) + paused_nodes = _node_successes(paused_events) + snapshot = paused_state.dumps() + + resumed_state = GraphRuntimeState.from_snapshot(snapshot) + + generator = wf_app_gen_module.WorkflowAppGenerator() + + def _fake_generate(**kwargs): + state: GraphRuntimeState = kwargs["graph_runtime_state"] + events = _run_with_optional_pause(state, pause_on=None) + return _node_successes(events) + + mocker.patch.object(generator, "_generate", side_effect=_fake_generate) + + resumed_nodes = generator.resume( + app_model=SimpleNamespace(mode="workflow"), + workflow=SimpleNamespace(), + user=SimpleNamespace(), + application_generate_entity=SimpleNamespace(stream=False, invoke_from=InvokeFrom.SERVICE_API), + graph_runtime_state=resumed_state, + workflow_execution_repository=SimpleNamespace(), + workflow_node_execution_repository=SimpleNamespace(), + ) + + assert paused_nodes + resumed_nodes == baseline_nodes + assert resumed_state.outputs == baseline_outputs + + +def test_advanced_chat_pause_resume_matches_baseline(mocker): + _patch_tool_node(mocker) + + baseline_state = _build_runtime_state("adv-baseline") + baseline_events = _run_with_optional_pause(baseline_state, pause_on=None) + assert isinstance(baseline_events[-1], GraphRunSucceededEvent) + baseline_nodes = _node_successes(baseline_events) + baseline_outputs = baseline_state.outputs + + paused_state = _build_runtime_state("adv-paused") + paused_events = _run_with_optional_pause(paused_state, pause_on="tool_a") + assert isinstance(paused_events[-1], GraphRunPausedEvent) + paused_nodes = _node_successes(paused_events) + snapshot = paused_state.dumps() + + resumed_state = GraphRuntimeState.from_snapshot(snapshot) + + generator = adv_app_gen_module.AdvancedChatAppGenerator() + + def _fake_generate(**kwargs): + state: GraphRuntimeState = kwargs["graph_runtime_state"] + events = _run_with_optional_pause(state, pause_on=None) + return _node_successes(events) + + mocker.patch.object(generator, "_generate", side_effect=_fake_generate) + + resumed_nodes = generator.resume( + app_model=SimpleNamespace(mode="workflow"), + workflow=SimpleNamespace(), + user=SimpleNamespace(), + conversation=SimpleNamespace(id="conv"), + message=SimpleNamespace(id="msg"), + application_generate_entity=SimpleNamespace(stream=False, invoke_from=InvokeFrom.SERVICE_API), + workflow_execution_repository=SimpleNamespace(), + workflow_node_execution_repository=SimpleNamespace(), + graph_runtime_state=resumed_state, + ) + + assert paused_nodes + resumed_nodes == baseline_nodes + assert resumed_state.outputs == baseline_outputs diff --git a/api/tests/integration_tests/libs/broadcast_channel/redis/utils/test_helpers.py b/api/tests/integration_tests/libs/broadcast_channel/redis/utils/test_helpers.py index 80895553cb..65f3007b01 100644 --- a/api/tests/integration_tests/libs/broadcast_channel/redis/utils/test_helpers.py +++ b/api/tests/integration_tests/libs/broadcast_channel/redis/utils/test_helpers.py @@ -61,8 +61,8 @@ class ConcurrentPublisher: messages.append(message) if self.delay > 0: time.sleep(self.delay) - except Exception as e: - _logger.error("Publisher %s error: %s", thread_id, e) + except Exception: + _logger.exception("Pubmsg=lisher %s", thread_id) with self._lock: self.published_messages.append(messages) @@ -308,8 +308,8 @@ def measure_throughput( try: operation() count += 1 - except Exception as e: - _logger.error("Operation failed: %s", e) + except Exception: + _logger.exception("Operation failed") break elapsed = time.time() - start_time diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_generator.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_generator.py index 83ac3a5591..e9ca28b8d6 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_generator.py @@ -1,4 +1,26 @@ +import sys +from types import ModuleType, SimpleNamespace +from unittest.mock import MagicMock + +if "core.ops.ops_trace_manager" not in sys.modules: + stub_module = ModuleType("core.ops.ops_trace_manager") + + class _StubTraceQueueManager: + def __init__(self, *_, **__): + pass + + stub_module.TraceQueueManager = _StubTraceQueueManager + sys.modules["core.ops.ops_trace_manager"] = stub_module + from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY, WorkflowAppGenerator +from tests.unit_tests.core.workflow.graph_engine.test_pause_resume_state import ( + _build_pausing_graph, + _build_runtime_state, + _node_successes, + _PausingNode, + _PausingNodeData, + _run_graph, +) def test_should_prepare_user_inputs_defaults_to_true(): @@ -17,3 +39,193 @@ def test_should_prepare_user_inputs_keeps_validation_when_flag_false(): args = {"inputs": {}, SKIP_PREPARE_USER_INPUTS_KEY: False} assert WorkflowAppGenerator()._should_prepare_user_inputs(args) + + +def test_resume_delegates_to_generate(mocker): + generator = WorkflowAppGenerator() + mock_generate = mocker.patch.object(generator, "_generate", return_value="ok") + + application_generate_entity = SimpleNamespace(stream=False, invoke_from="debugger") + runtime_state = MagicMock(name="runtime-state") + pause_config = MagicMock(name="pause-config") + + result = generator.resume( + app_model=MagicMock(), + workflow=MagicMock(), + user=MagicMock(), + application_generate_entity=application_generate_entity, + graph_runtime_state=runtime_state, + workflow_execution_repository=MagicMock(), + workflow_node_execution_repository=MagicMock(), + graph_engine_layers=("layer",), + pause_state_config=pause_config, + variable_loader=MagicMock(), + ) + + assert result == "ok" + mock_generate.assert_called_once() + kwargs = mock_generate.call_args.kwargs + assert kwargs["graph_runtime_state"] is runtime_state + assert kwargs["pause_state_config"] is pause_config + assert kwargs["streaming"] is False + assert kwargs["invoke_from"] == "debugger" + + +def test_generate_appends_pause_layer_and_forwards_state(mocker): + generator = WorkflowAppGenerator() + + mock_queue_manager = MagicMock() + mocker.patch("core.app.apps.workflow.app_generator.WorkflowAppQueueManager", return_value=mock_queue_manager) + + fake_current_app = MagicMock() + fake_current_app._get_current_object.return_value = MagicMock() + mocker.patch("core.app.apps.workflow.app_generator.current_app", fake_current_app) + + mocker.patch( + "core.app.apps.workflow.app_generator.WorkflowAppGenerateResponseConverter.convert", + return_value="converted", + ) + mocker.patch.object(WorkflowAppGenerator, "_handle_response", return_value="response") + mocker.patch.object(WorkflowAppGenerator, "_get_draft_var_saver_factory", return_value=MagicMock()) + + pause_layer = MagicMock(name="pause-layer") + mocker.patch( + "core.app.apps.workflow.app_generator.PauseStatePersistenceLayer", + return_value=pause_layer, + ) + + dummy_session = MagicMock() + dummy_session.close = MagicMock() + mocker.patch("core.app.apps.workflow.app_generator.db.session", dummy_session) + + worker_kwargs: dict[str, object] = {} + + class DummyThread: + def __init__(self, target, kwargs): + worker_kwargs["target"] = target + worker_kwargs["kwargs"] = kwargs + + def start(self): + return None + + mocker.patch("core.app.apps.workflow.app_generator.threading.Thread", DummyThread) + + app_model = SimpleNamespace(mode="workflow") + app_config = SimpleNamespace(app_id="app", tenant_id="tenant", workflow_id="wf") + application_generate_entity = SimpleNamespace( + task_id="task", + user_id="user", + invoke_from="service-api", + app_config=app_config, + files=[], + stream=True, + workflow_execution_id="run", + ) + + graph_runtime_state = MagicMock() + + result = generator._generate( + app_model=app_model, + workflow=MagicMock(), + user=MagicMock(), + application_generate_entity=application_generate_entity, + invoke_from="service-api", + workflow_execution_repository=MagicMock(), + workflow_node_execution_repository=MagicMock(), + streaming=True, + graph_engine_layers=("base-layer",), + graph_runtime_state=graph_runtime_state, + pause_state_config=SimpleNamespace(session_factory=MagicMock(), state_owner_user_id="owner"), + ) + + assert result == "converted" + assert worker_kwargs["kwargs"]["graph_engine_layers"] == ("base-layer", pause_layer) + assert worker_kwargs["kwargs"]["graph_runtime_state"] is graph_runtime_state + + +def test_resume_path_runs_worker_with_runtime_state(mocker): + generator = WorkflowAppGenerator() + runtime_state = MagicMock(name="runtime-state") + + pause_layer = MagicMock(name="pause-layer") + mocker.patch("core.app.apps.workflow.app_generator.PauseStatePersistenceLayer", return_value=pause_layer) + + queue_manager = MagicMock() + mocker.patch("core.app.apps.workflow.app_generator.WorkflowAppQueueManager", return_value=queue_manager) + + mocker.patch.object(generator, "_handle_response", return_value="raw-response") + mocker.patch( + "core.app.apps.workflow.app_generator.WorkflowAppGenerateResponseConverter.convert", + side_effect=lambda response, invoke_from: response, + ) + + fake_db = SimpleNamespace(session=MagicMock(), engine=MagicMock()) + mocker.patch("core.app.apps.workflow.app_generator.db", fake_db) + + workflow = SimpleNamespace( + id="workflow", tenant_id="tenant", app_id="app", graph_dict={}, type="workflow", version="1" + ) + end_user = SimpleNamespace(session_id="end-user-session") + app_record = SimpleNamespace(id="app") + + session = MagicMock() + session.__enter__.return_value = session + session.__exit__.return_value = False + session.scalar.side_effect = [workflow, end_user, app_record] + mocker.patch("core.app.apps.workflow.app_generator.Session", return_value=session) + + runner_instance = MagicMock() + + def runner_ctor(**kwargs): + assert kwargs["graph_runtime_state"] is runtime_state + return runner_instance + + mocker.patch("core.app.apps.workflow.app_generator.WorkflowAppRunner", side_effect=runner_ctor) + + class ImmediateThread: + def __init__(self, target, kwargs): + target(**kwargs) + + def start(self): + return None + + mocker.patch("core.app.apps.workflow.app_generator.threading.Thread", ImmediateThread) + + mocker.patch( + "core.app.apps.workflow.app_generator.DifyCoreRepositoryFactory.create_workflow_execution_repository", + return_value=MagicMock(), + ) + mocker.patch( + "core.app.apps.workflow.app_generator.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", + return_value=MagicMock(), + ) + + pause_config = SimpleNamespace(session_factory=MagicMock(), state_owner_user_id="owner") + + app_model = SimpleNamespace(mode="workflow") + app_config = SimpleNamespace(app_id="app", tenant_id="tenant", workflow_id="workflow") + application_generate_entity = SimpleNamespace( + task_id="task", + user_id="user", + invoke_from="service-api", + app_config=app_config, + files=[], + stream=True, + workflow_execution_id="run", + trace_manager=MagicMock(), + ) + + result = generator.resume( + app_model=app_model, + workflow=workflow, + user=MagicMock(), + application_generate_entity=application_generate_entity, + graph_runtime_state=runtime_state, + workflow_execution_repository=MagicMock(), + workflow_node_execution_repository=MagicMock(), + pause_state_config=pause_config, + ) + + assert result == "raw-response" + runner_instance.run.assert_called_once() + queue_manager.graph_runtime_state = runtime_state diff --git a/api/tests/unit_tests/core/human_input_form_test.py b/api/tests/unit_tests/core/human_input_form_test.py deleted file mode 100644 index 25e1707871..0000000000 --- a/api/tests/unit_tests/core/human_input_form_test.py +++ /dev/null @@ -1,317 +0,0 @@ -""" -Tests for HumanInputForm domain model and repository. -""" - -import json -from datetime import datetime -from unittest.mock import MagicMock, patch - -import pytest -from core.repositories.sqlalchemy_human_input_form_repository import SQLAlchemyHumanInputFormRepository -from core.workflow.entities.human_input_form import HumanInputForm, HumanInputFormStatus - - -class TestHumanInputForm: - """Test cases for HumanInputForm domain model.""" - - def test_create_form(self): - """Test creating a new form.""" - form = HumanInputForm.create( - id_="test-form-id", - workflow_run_id="test-workflow-run", - form_definition={"inputs": [], "user_actions": [{"id": "submit", "title": "Submit"}]}, - rendered_content="
Test form
", - ) - - assert form.id_ == "test-form-id" - assert form.workflow_run_id == "test-workflow-run" - assert form.status == HumanInputFormStatus.WAITING - assert form.can_be_submitted - assert not form.is_submitted - assert not form.is_expired - assert form.is_waiting - - def test_submit_form(self): - """Test submitting a form.""" - form = HumanInputForm.create( - id_="test-form-id", - workflow_run_id="test-workflow-run", - form_definition={"inputs": [], "user_actions": [{"id": "submit", "title": "Submit"}]}, - rendered_content="
Test form
", - ) - - form.submit( - data={"field1": "value1"}, - action="submit", - submission_user_id="user123", - ) - - assert form.is_submitted - assert not form.can_be_submitted - assert form.status == HumanInputFormStatus.SUBMITTED - assert form.submission is not None - assert form.submission.data == {"field1": "value1"} - assert form.submission.action == "submit" - assert form.submission.submission_user_id == "user123" - - def test_submit_form_invalid_action(self): - """Test submitting a form with invalid action.""" - form = HumanInputForm.create( - id_="test-form-id", - workflow_run_id="test-workflow-run", - form_definition={"inputs": [], "user_actions": [{"id": "submit", "title": "Submit"}]}, - rendered_content="
Test form
", - ) - - with pytest.raises(ValueError, match="Invalid action: invalid_action"): - form.submit(data={}, action="invalid_action") - - def test_submit_expired_form(self): - """Test submitting an expired form should fail.""" - form = HumanInputForm.create( - id_="test-form-id", - workflow_run_id="test-workflow-run", - form_definition={"inputs": [], "user_actions": [{"id": "submit", "title": "Submit"}]}, - rendered_content="
Test form
", - ) - form.expire() - - with pytest.raises(ValueError, match="Form cannot be submitted in status: expired"): - form.submit(data={}, action="submit") - - def test_expire_form(self): - """Test expiring a form.""" - form = HumanInputForm.create( - id_="test-form-id", - workflow_run_id="test-workflow-run", - form_definition={"inputs": [], "user_actions": [{"id": "submit", "title": "Submit"}]}, - rendered_content="
Test form
", - ) - - form.expire() - assert form.is_expired - assert not form.can_be_submitted - assert form.status == HumanInputFormStatus.EXPIRED - - def test_expire_already_submitted_form(self): - """Test expiring an already submitted form should fail.""" - form = HumanInputForm.create( - id_="test-form-id", - workflow_run_id="test-workflow-run", - form_definition={"inputs": [], "user_actions": [{"id": "submit", "title": "Submit"}]}, - rendered_content="
Test form
", - ) - form.submit(data={}, action="submit") - - with pytest.raises(ValueError, match="Form cannot be expired in status: submitted"): - form.expire() - - def test_get_form_definition_for_display(self): - """Test getting form definition for display.""" - form_definition = { - "inputs": [{"type": "text", "name": "field1"}], - "user_actions": [{"id": "submit", "title": "Submit"}], - } - form = HumanInputForm.create( - id_="test-form-id", - workflow_run_id="test-workflow-run", - form_definition=form_definition, - rendered_content="
Test form
", - ) - - result = form.get_form_definition_for_display() - - assert result["form_content"] == "
Test form
" - assert result["inputs"] == form_definition["inputs"] - assert result["user_actions"] == form_definition["user_actions"] - assert "site" not in result - - def test_get_form_definition_for_display_with_site_info(self): - """Test getting form definition for display with site info.""" - form = HumanInputForm.create( - id_="test-form-id", - workflow_run_id="test-workflow-run", - form_definition={"inputs": [], "user_actions": []}, - rendered_content="
Test form
", - ) - - result = form.get_form_definition_for_display(include_site_info=True) - - assert "site" in result - assert result["site"]["title"] == "Workflow Form" - - def test_get_form_definition_expired_form(self): - """Test getting form definition for expired form should fail.""" - form = HumanInputForm.create( - id_="test-form-id", - workflow_run_id="test-workflow-run", - form_definition={"inputs": [], "user_actions": []}, - rendered_content="
Test form
", - ) - form.expire() - - with pytest.raises(ValueError, match="Form has expired"): - form.get_form_definition_for_display() - - def test_get_form_definition_submitted_form(self): - """Test getting form definition for submitted form should fail.""" - form = HumanInputForm.create( - id_="test-form-id", - workflow_run_id="test-workflow-run", - form_definition={"inputs": [], "user_actions": [{"id": "submit", "title": "Submit"}]}, - rendered_content="
Test form
", - ) - form.submit(data={}, action="submit") - - with pytest.raises(ValueError, match="Form has already been submitted"): - form.get_form_definition_for_display() - - -class TestSQLAlchemyHumanInputFormRepository: - """Test cases for SQLAlchemyHumanInputFormRepository.""" - - @pytest.fixture - def mock_session_factory(self): - """Create a mock session factory.""" - session = MagicMock() - session_factory = MagicMock() - session_factory.return_value.__enter__.return_value = session - session_factory.return_value.__exit__.return_value = None - return session_factory - - @pytest.fixture - def mock_user(self): - """Create a mock user.""" - user = MagicMock() - user.current_tenant_id = "test-tenant-id" - user.id = "test-user-id" - return user - - @pytest.fixture - def repository(self, mock_session_factory, mock_user): - """Create a repository instance.""" - return SQLAlchemyHumanInputFormRepository( - session_factory=mock_session_factory, - user=mock_user, - app_id="test-app-id", - ) - - def test_to_domain_model(self, repository): - """Test converting DB model to domain model.""" - from models.human_input import ( - HumanInputForm as DBForm, - ) - from models.human_input import ( - HumanInputFormStatus as DBStatus, - ) - from models.human_input import ( - HumanInputSubmissionType as DBSubmissionType, - ) - - db_form = DBForm() - db_form.id = "test-id" - db_form.workflow_run_id = "test-workflow" - db_form.form_definition = json.dumps({"inputs": [], "user_actions": []}) - db_form.rendered_content = "
Test
" - db_form.status = DBStatus.WAITING - db_form.web_app_token = "test-token" - db_form.created_at = datetime.utcnow() - db_form.submitted_data = json.dumps({"field": "value"}) - db_form.submitted_at = datetime.utcnow() - db_form.submission_type = DBSubmissionType.web_form - db_form.submission_user_id = "user123" - - domain_form = repository._to_domain_model(db_form) - - assert domain_form.id_ == "test-id" - assert domain_form.workflow_run_id == "test-workflow" - assert domain_form.form_definition == {"inputs": [], "user_actions": []} - assert domain_form.rendered_content == "
Test
" - assert domain_form.status == HumanInputFormStatus.WAITING - assert domain_form.web_app_token == "test-token" - assert domain_form.submission is not None - assert domain_form.submission.data == {"field": "value"} - assert domain_form.submission.submission_user_id == "user123" - - def test_to_db_model(self, repository): - """Test converting domain model to DB model.""" - from models.human_input import ( - HumanInputFormStatus as DBStatus, - ) - - domain_form = HumanInputForm.create( - id_="test-id", - workflow_run_id="test-workflow", - form_definition={"inputs": [], "user_actions": []}, - rendered_content="
Test
", - web_app_token="test-token", - ) - - db_form = repository._to_db_model(domain_form) - - assert db_form.id == "test-id" - assert db_form.tenant_id == "test-tenant-id" - assert db_form.app_id == "test-app-id" - assert db_form.workflow_run_id == "test-workflow" - assert json.loads(db_form.form_definition) == {"inputs": [], "user_actions": []} - assert db_form.rendered_content == "
Test
" - assert db_form.status == DBStatus.WAITING - assert db_form.web_app_token == "test-token" - - def test_save(self, repository, mock_session_factory): - """Test saving a form.""" - session = mock_session_factory.return_value.__enter__.return_value - domain_form = HumanInputForm.create( - id_="test-id", - workflow_run_id="test-workflow", - form_definition={"inputs": []}, - rendered_content="
Test
", - ) - - repository.save(domain_form) - - session.merge.assert_called_once() - session.commit.assert_called_once() - - def test_get_by_id(self, repository, mock_session_factory): - """Test getting a form by ID.""" - session = mock_session_factory.return_value.__enter__.return_value - mock_db_form = MagicMock() - mock_db_form.id = "test-id" - session.scalar.return_value = mock_db_form - - with patch.object(repository, "_to_domain_model") as mock_convert: - domain_form = HumanInputForm.create( - id_="test-id", - workflow_run_id="test-workflow", - form_definition={"inputs": []}, - rendered_content="
Test
", - ) - mock_convert.return_value = domain_form - - result = repository.get_by_id("test-id") - - assert result == domain_form - session.scalar.assert_called_once() - mock_convert.assert_called_once_with(mock_db_form) - - def test_get_by_id_not_found(self, repository, mock_session_factory): - """Test getting a non-existent form by ID.""" - session = mock_session_factory.return_value.__enter__.return_value - session.scalar.return_value = None - - with pytest.raises(ValueError, match="Human input form not found: test-id"): - repository.get_by_id("test-id") - - def test_mark_expired_forms(self, repository, mock_session_factory): - """Test marking expired forms.""" - session = mock_session_factory.return_value.__enter__.return_value - mock_forms = [MagicMock(), MagicMock(), MagicMock()] - session.scalars.return_value.all.return_value = mock_forms - - result = repository.mark_expired_forms(expiry_hours=24) - - assert result == 3 - for form in mock_forms: - assert hasattr(form, "status") - session.commit.assert_called_once() diff --git a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py index 649edbb37c..96b6123334 100644 --- a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py +++ b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py @@ -2,16 +2,27 @@ from __future__ import annotations +import dataclasses +from datetime import datetime from types import SimpleNamespace from unittest.mock import MagicMock import pytest from core.repositories.human_input_reposotiry import ( + HumanInputFormReadRepository, + HumanInputFormRecord, HumanInputFormRepositoryImpl, _WorkspaceMemberInfo, ) -from core.workflow.nodes.human_input.entities import ExternalRecipient, MemberRecipient +from core.workflow.nodes.human_input.entities import ( + ExternalRecipient, + FormDefinition, + MemberRecipient, + TimeoutUnit, + UserAction, +) +from libs.datetime_utils import naive_utc_now from models.human_input import ( EmailExternalRecipientPayload, EmailMemberRecipientPayload, @@ -41,6 +52,23 @@ def _patch_recipient_factory(monkeypatch: pytest.MonkeyPatch) -> list[SimpleName return created +@pytest.fixture(autouse=True) +def _stub_selectinload(monkeypatch: pytest.MonkeyPatch) -> None: + """Avoid SQLAlchemy mapper configuration in tests using fake sessions.""" + + class _FakeSelect: + def options(self, *_args, **_kwargs): # type: ignore[no-untyped-def] + return self + + def where(self, *_args, **_kwargs): # type: ignore[no-untyped-def] + return self + + monkeypatch.setattr( + "core.repositories.human_input_reposotiry.selectinload", lambda *args, **kwargs: "_loader_option" + ) + monkeypatch.setattr("core.repositories.human_input_reposotiry.select", lambda *args, **kwargs: _FakeSelect()) + + class TestHumanInputFormRepositoryImplHelpers: def test_create_email_recipients_with_member_and_external(self, monkeypatch: pytest.MonkeyPatch) -> None: repo = _build_repository() @@ -125,3 +153,201 @@ class TestHumanInputFormRepositoryImplHelpers: assert len(recipients) == 2 emails = {EmailMemberRecipientPayload.model_validate_json(r.recipient_payload).email for r in recipients} assert emails == {"member1@example.com", "member2@example.com"} + + +def _make_form_definition() -> str: + return FormDefinition( + form_content="hello", + inputs=[], + user_actions=[UserAction(id="submit", title="Submit")], + rendered_content="

hello

", + timeout=1, + timeout_unit=TimeoutUnit.HOUR, + ).model_dump_json() + + +@dataclasses.dataclass +class _DummyForm: + id: str + workflow_run_id: str + node_id: str + tenant_id: str + form_definition: str + rendered_content: str + expiration_time: datetime + selected_action_id: str | None = None + submitted_data: str | None = None + submitted_at: datetime | None = None + submission_user_id: str | None = None + submission_end_user_id: str | None = None + completed_by_recipient_id: str | None = None + + +@dataclasses.dataclass +class _DummyRecipient: + id: str + form_id: str + recipient_type: RecipientType + access_token: str + form: _DummyForm | None = None + + +class _FakeScalarResult: + def __init__(self, obj): + self._obj = obj + + def first(self): + return self._obj + + +class _FakeSession: + def __init__( + self, + *, + scalars_result=None, + forms: dict[str, _DummyForm] | None = None, + recipients: dict[str, _DummyRecipient] | None = None, + ): + self._scalars_result = scalars_result + self.forms = forms or {} + self.recipients = recipients or {} + + def scalars(self, _query): + return _FakeScalarResult(self._scalars_result) + + def get(self, model_cls, obj_id): # type: ignore[no-untyped-def] + if getattr(model_cls, "__name__", None) == "HumanInputForm": + return self.forms.get(obj_id) + if getattr(model_cls, "__name__", None) == "HumanInputFormRecipient": + return self.recipients.get(obj_id) + return None + + def add(self, _obj): + return None + + def flush(self): + return None + + def refresh(self, _obj): + return None + + def begin(self): + return self + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return None + + +def _session_factory(session: _FakeSession): + class _SessionContext: + def __enter__(self): + return session + + def __exit__(self, exc_type, exc, tb): + return None + + def _factory(*_args, **_kwargs): + return _SessionContext() + + return _factory + + +class TestHumanInputFormReadRepository: + def test_get_by_token_returns_record(self): + form = _DummyForm( + id="form-1", + workflow_run_id="run-1", + node_id="node-1", + tenant_id="tenant-1", + form_definition=_make_form_definition(), + rendered_content="

hello

", + expiration_time=naive_utc_now(), + ) + recipient = _DummyRecipient( + id="recipient-1", + form_id=form.id, + recipient_type=RecipientType.WEBAPP, + access_token="token-123", + form=form, + ) + session = _FakeSession(scalars_result=recipient) + repo = HumanInputFormReadRepository(_session_factory(session)) + + record = repo.get_by_token("token-123") + + assert record is not None + assert record.form_id == form.id + assert record.recipient_type == RecipientType.WEBAPP + assert record.submitted is False + + def test_get_by_form_id_and_recipient_type_uses_recipient(self): + form = _DummyForm( + id="form-1", + workflow_run_id="run-1", + node_id="node-1", + tenant_id="tenant-1", + form_definition=_make_form_definition(), + rendered_content="

hello

", + expiration_time=naive_utc_now(), + ) + recipient = _DummyRecipient( + id="recipient-1", + form_id=form.id, + recipient_type=RecipientType.WEBAPP, + access_token="token-123", + form=form, + ) + session = _FakeSession(scalars_result=recipient) + repo = HumanInputFormReadRepository(_session_factory(session)) + + record = repo.get_by_form_id_and_recipient_type(form_id=form.id, recipient_type=RecipientType.WEBAPP) + + assert record is not None + assert record.recipient_id == recipient.id + assert record.access_token == recipient.access_token + + def test_mark_submitted_updates_fields(self, monkeypatch: pytest.MonkeyPatch): + fixed_now = datetime(2024, 1, 1, 0, 0, 0) + monkeypatch.setattr("core.repositories.human_input_reposotiry.naive_utc_now", lambda: fixed_now) + + form = _DummyForm( + id="form-1", + workflow_run_id="run-1", + node_id="node-1", + tenant_id="tenant-1", + form_definition=_make_form_definition(), + rendered_content="

hello

", + expiration_time=fixed_now, + ) + recipient = _DummyRecipient( + id="recipient-1", + form_id="form-1", + recipient_type=RecipientType.WEBAPP, + access_token="token-123", + ) + session = _FakeSession( + forms={form.id: form}, + recipients={recipient.id: recipient}, + ) + repo = HumanInputFormReadRepository(_session_factory(session)) + + record: HumanInputFormRecord = repo.mark_submitted( + form_id=form.id, + recipient_id=recipient.id, + selected_action_id="approve", + form_data={"field": "value"}, + submission_user_id="user-1", + submission_end_user_id="end-user-1", + ) + + assert form.selected_action_id == "approve" + assert form.completed_by_recipient_id == recipient.id + assert form.submission_user_id == "user-1" + assert form.submission_end_user_id == "end-user-1" + assert form.submitted_at == fixed_now + assert record.submitted is True + assert record.selected_action_id == "approve" + assert record.submitted_data == {"field": "value"} diff --git a/api/tests/unit_tests/core/workflow/entities/test_pause_reason.py b/api/tests/unit_tests/core/workflow/entities/test_pause_reason.py index 7784bbdd2f..d4dc91c4b4 100644 --- a/api/tests/unit_tests/core/workflow/entities/test_pause_reason.py +++ b/api/tests/unit_tests/core/workflow/entities/test_pause_reason.py @@ -28,10 +28,11 @@ class TestPauseReasonDiscriminator: { "reason": { "TYPE": "human_input_required", - "human_input_form_id": "form_id", + "form_id": "form_id", + "form_content": "form_content", }, }, - HumanInputRequired(human_input_form_id="form_id"), + HumanInputRequired(form_id="form_id", form_content="form_content"), id="HumanInputRequired", ), pytest.param( @@ -56,7 +57,7 @@ class TestPauseReasonDiscriminator: @pytest.mark.parametrize( "reason", [ - HumanInputRequired(human_input_form_id="form_id"), + HumanInputRequired(form_id="form_id", form_content="form_content"), SchedulingPause(message="Hold on"), ], ids=lambda x: type(x).__name__, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py index ed550379b5..02f20413e0 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py @@ -764,203 +764,3 @@ def test_graph_run_emits_partial_success_when_node_failure_recovered(): assert partial_event.outputs.get("answer") == "fallback response" assert not any(isinstance(event, GraphRunSucceededEvent) for event in events) - - -def test_suspend_and_resume(): - graph_config = { - "edges": [ - { - "data": {"isInLoop": False, "sourceType": "start", "targetType": "if-else"}, - "id": "1753041723554-source-1753041730748-target", - "source": "1753041723554", - "sourceHandle": "source", - "target": "1753041730748", - "targetHandle": "target", - "type": "custom", - "zIndex": 0, - }, - { - "data": {"isInLoop": False, "sourceType": "if-else", "targetType": "answer"}, - "id": "1753041730748-true-answer-target", - "source": "1753041730748", - "sourceHandle": "true", - "target": "answer", - "targetHandle": "target", - "type": "custom", - "zIndex": 0, - }, - { - "data": { - "isInIteration": False, - "isInLoop": False, - "sourceType": "if-else", - "targetType": "answer", - }, - "id": "1753041730748-false-1753041952799-target", - "source": "1753041730748", - "sourceHandle": "false", - "target": "1753041952799", - "targetHandle": "target", - "type": "custom", - "zIndex": 0, - }, - ], - "nodes": [ - { - "data": {"desc": "", "selected": False, "title": "Start", "type": "start", "variables": []}, - "height": 54, - "id": "1753041723554", - "position": {"x": 32, "y": 282}, - "positionAbsolute": {"x": 32, "y": 282}, - "selected": False, - "sourcePosition": "right", - "targetPosition": "left", - "type": "custom", - "width": 244, - }, - { - "data": { - "cases": [ - { - "case_id": "true", - "conditions": [ - { - "comparison_operator": "contains", - "id": "5db4103a-7e62-4e71-a0a6-c45ac11c0b3d", - "value": "a", - "varType": "string", - "variable_selector": ["sys", "query"], - } - ], - "id": "true", - "logical_operator": "and", - } - ], - "desc": "", - "selected": False, - "title": "IF/ELSE", - "type": "if-else", - }, - "height": 126, - "id": "1753041730748", - "position": {"x": 368, "y": 282}, - "positionAbsolute": {"x": 368, "y": 282}, - "selected": False, - "sourcePosition": "right", - "targetPosition": "left", - "type": "custom", - "width": 244, - }, - { - "data": { - "answer": "A", - "desc": "", - "selected": False, - "title": "Answer A", - "type": "answer", - "variables": [], - }, - "height": 102, - "id": "answer", - "position": {"x": 746, "y": 282}, - "positionAbsolute": {"x": 746, "y": 282}, - "selected": False, - "sourcePosition": "right", - "targetPosition": "left", - "type": "custom", - "width": 244, - }, - { - "data": { - "answer": "Else", - "desc": "", - "selected": False, - "title": "Answer Else", - "type": "answer", - "variables": [], - }, - "height": 102, - "id": "1753041952799", - "position": {"x": 746, "y": 426}, - "positionAbsolute": {"x": 746, "y": 426}, - "selected": True, - "sourcePosition": "right", - "targetPosition": "left", - "type": "custom", - "width": 244, - }, - ], - "viewport": {"x": -420, "y": -76.5, "zoom": 1}, - } - graph = Graph.init(graph_config) - variable_pool = VariablePool( - system_variables=SystemVariable( - user_id="aaa", - files=[], - query="hello", - conversation_id="abababa", - ), - user_inputs={"uid": "takato"}, - ) - - graph_runtime_state = GraphRuntimeState( - variable_pool=variable_pool, - ) - graph_init_params = GraphInitParams( - tenant_id="111", - app_id="222", - workflow_type=WorkflowType.CHAT, - workflow_id="333", - graph_config=graph_config, - user_id="444", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.WEB_APP, - call_depth=0, - max_execution_steps=500, - max_execution_time=1200, - ) - - _IF_ELSE_NODE_ID = "1753041730748" - - def command_source(params: CommandParams) -> CommandTypes: - # requires the engine to suspend before the execution - # of If-Else node. - if params.next_node.node_id == _IF_ELSE_NODE_ID: - return SuspendCommand() - else: - return ContinueCommand() - - graph_engine = GraphEngine( - graph=graph, - graph_runtime_state=graph_runtime_state, - graph_init_params=graph_init_params, - command_source=command_source, - ) - events = list(graph_engine.run()) - last_event = events[-1] - assert isinstance(last_event, GraphRunSuspendedEvent) - assert last_event.current_node_id == _IF_ELSE_NODE_ID - state = graph_engine.save() - assert state != "" - - engine2 = GraphEngine.resume( - state=state, - graph=graph, - ) - events = list(engine2.run()) - assert isinstance(events[-1], GraphRunSucceededEvent) - node_run_succeeded_events = [i for i in events if isinstance(i, NodeRunSucceededEvent)] - assert node_run_succeeded_events - start_events = [i for i in node_run_succeeded_events if i.node_id == "1753041723554"] - assert not start_events - ifelse_succeeded_events = [i for i in node_run_succeeded_events if i.node_id == _IF_ELSE_NODE_ID] - assert ifelse_succeeded_events - answer_else_events = [i for i in node_run_succeeded_events if i.node_id == "1753041952799"] - assert answer_else_events - assert answer_else_events[0].route_node_state.node_run_result.outputs == { - "answer": "Else", - "files": ArrayFileSegment(value=[]), - } - - answer_a_events = [i for i in node_run_succeeded_events if i.node_id == "answer"] - assert not answer_a_events diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py index c398e4e8c1..8aa04a448c 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py @@ -17,8 +17,8 @@ from core.workflow.graph_events import ( from core.workflow.nodes.base.entities import OutputVariableEntity, OutputVariableType from core.workflow.nodes.end.end_node import EndNode from core.workflow.nodes.end.entities import EndNodeData -from core.workflow.nodes.human_input import HumanInputNode from core.workflow.nodes.human_input.entities import HumanInputNodeData +from core.workflow.nodes.human_input.human_input_node import HumanInputNode from core.workflow.nodes.llm.entities import ( ContextConfig, LLMNodeChatModelMessage, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py index ece69b080b..47e3412b74 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py @@ -16,8 +16,8 @@ from core.workflow.graph_events import ( from core.workflow.nodes.base.entities import OutputVariableEntity, OutputVariableType from core.workflow.nodes.end.end_node import EndNode from core.workflow.nodes.end.entities import EndNodeData -from core.workflow.nodes.human_input import HumanInputNode from core.workflow.nodes.human_input.entities import HumanInputNodeData +from core.workflow.nodes.human_input.human_input_node import HumanInputNode from core.workflow.nodes.llm.entities import ( ContextConfig, LLMNodeChatModelMessage, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py new file mode 100644 index 0000000000..5b3e8ad76a --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py @@ -0,0 +1,185 @@ +import time +from collections.abc import Generator, Mapping +from typing import Any + +import core.workflow.nodes.human_input.entities # noqa: F401 +from core.workflow.entities import GraphInitParams +from core.workflow.entities.pause_reason import SchedulingPause +from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus +from core.workflow.graph import Graph +from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel +from core.workflow.graph_engine.graph_engine import GraphEngine +from core.workflow.graph_events import ( + GraphEngineEvent, + GraphRunPausedEvent, + GraphRunSucceededEvent, + NodeRunSucceededEvent, +) +from core.workflow.node_events import NodeEventBase, NodeRunResult, PauseRequestedEvent +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig, VariableSelector +from core.workflow.nodes.base.node import Node +from core.workflow.nodes.end.end_node import EndNode +from core.workflow.nodes.end.entities import EndNodeData +from core.workflow.nodes.start.entities import StartNodeData +from core.workflow.nodes.start.start_node import StartNode +from core.workflow.runtime import GraphRuntimeState, VariablePool +from core.workflow.system_variable import SystemVariable + + +class _PausingNodeData(BaseNodeData): + pass + + +class _PausingNode(Node): + node_type = NodeType.TOOL + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = _PausingNodeData.model_validate(data) + + def _get_error_strategy(self): + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> str | None: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + + @staticmethod + def _pause_generator(event: PauseRequestedEvent) -> Generator[NodeEventBase, None, None]: + yield event + + def _run(self) -> NodeRunResult | Generator[NodeEventBase, None, None]: + resumed_flag = self.graph_runtime_state.variable_pool.get((self.id, "resumed")) + if resumed_flag is None: + # mark as resumed and request pause + self.graph_runtime_state.variable_pool.add((self.id, "resumed"), True) + return self._pause_generator(PauseRequestedEvent(reason=SchedulingPause(message="test pause"))) + + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + outputs={"value": "completed"}, + ) + + +def _build_runtime_state() -> GraphRuntimeState: + variable_pool = VariablePool( + system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"), + user_inputs={}, + conversation_variables=[], + ) + return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + + +def _build_pausing_graph(runtime_state: GraphRuntimeState) -> Graph: + graph_config: dict[str, object] = {"nodes": [], "edges": []} + params = GraphInitParams( + tenant_id="tenant", + app_id="app", + workflow_id="workflow", + graph_config=graph_config, + user_id="user", + user_from="account", + invoke_from="service-api", + call_depth=0, + ) + + start_data = StartNodeData(title="start", variables=[]) + start_node = StartNode( + id="start", + config={"id": "start", "data": start_data.model_dump()}, + graph_init_params=params, + graph_runtime_state=runtime_state, + ) + start_node.init_node_data(start_data.model_dump()) + + pause_data = _PausingNodeData(title="pausing") + pause_node = _PausingNode( + id="pausing", + config={"id": "pausing", "data": pause_data.model_dump()}, + graph_init_params=params, + graph_runtime_state=runtime_state, + ) + pause_node.init_node_data(pause_data.model_dump()) + + end_data = EndNodeData( + title="end", + outputs=[ + VariableSelector(variable="result", value_selector=["pausing", "value"]), + ], + desc=None, + ) + end_node = EndNode( + id="end", + config={"id": "end", "data": end_data.model_dump()}, + graph_init_params=params, + graph_runtime_state=runtime_state, + ) + end_node.init_node_data(end_data.model_dump()) + + return Graph.new().add_root(start_node).add_node(pause_node).add_node(end_node).build() + + +def _run_graph(graph: Graph, runtime_state: GraphRuntimeState) -> list[GraphEngineEvent]: + engine = GraphEngine( + workflow_id="workflow", + graph=graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + ) + return list(engine.run()) + + +def _node_successes(events: list[GraphEngineEvent]) -> list[str]: + return [event.node_id for event in events if isinstance(event, NodeRunSucceededEvent)] + + +def _segment_value(variable_pool: VariablePool, selector: tuple[str, str]) -> Any: + segment = variable_pool.get(selector) + assert segment is not None + return getattr(segment, "value", segment) + + +def test_engine_resume_restores_state_and_completion(): + # Baseline run without pausing + baseline_state = _build_runtime_state() + baseline_graph = _build_pausing_graph(baseline_state) + baseline_state.variable_pool.add(("pausing", "resumed"), True) + baseline_events = _run_graph(baseline_graph, baseline_state) + assert isinstance(baseline_events[-1], GraphRunSucceededEvent) + baseline_success_nodes = _node_successes(baseline_events) + + # Run with pause + paused_state = _build_runtime_state() + paused_graph = _build_pausing_graph(paused_state) + paused_events = _run_graph(paused_graph, paused_state) + assert isinstance(paused_events[-1], GraphRunPausedEvent) + snapshot = paused_state.dumps() + + # Resume from snapshot + resumed_state = GraphRuntimeState.from_snapshot(snapshot) + resumed_graph = _build_pausing_graph(resumed_state) + resumed_events = _run_graph(resumed_graph, resumed_state) + assert isinstance(resumed_events[-1], GraphRunSucceededEvent) + + combined_success_nodes = _node_successes(paused_events) + _node_successes(resumed_events) + assert combined_success_nodes == baseline_success_nodes + + assert baseline_state.outputs == resumed_state.outputs + assert _segment_value(baseline_state.variable_pool, ("pausing", "resumed")) == _segment_value( + resumed_state.variable_pool, ("pausing", "resumed") + ) + assert _segment_value(baseline_state.variable_pool, ("pausing", "value")) == _segment_value( + resumed_state.variable_pool, ("pausing", "value") + ) + assert baseline_state.graph_execution.completed + assert resumed_state.graph_execution.completed diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_node.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_node.py deleted file mode 100644 index 57e3c1f69a..0000000000 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_node.py +++ /dev/null @@ -1,283 +0,0 @@ -""" -Unit tests for human input node implementation. -""" - -import uuid -from unittest.mock import Mock, patch - -import pytest - -from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from core.workflow.nodes.human_input import HumanInputNode, HumanInputNodeData - - -class TestHumanInputNode: - """Test HumanInputNode implementation.""" - - @pytest.fixture - def mock_graph_init_params(self): - """Create mock graph initialization parameters.""" - mock_params = Mock() - mock_params.tenant_id = "tenant-123" - mock_params.app_id = "app-456" - mock_params.user_id = "user-789" - mock_params.user_from = "web" - mock_params.invoke_from = "web_app" - mock_params.call_depth = 0 - return mock_params - - @pytest.fixture - def mock_graph(self): - """Create mock graph.""" - return Mock() - - @pytest.fixture - def mock_graph_runtime_state(self): - """Create mock graph runtime state.""" - return Mock() - - @pytest.fixture - def sample_node_config(self): - """Create sample node configuration.""" - return { - "id": "human_input_123", - "data": { - "title": "User Confirmation", - "desc": "Please confirm the action", - "delivery_methods": [{"type": "webapp", "enabled": True, "config": {}}], - "form_content": "# Confirmation\n\nPlease confirm: {{#$output.confirmation#}}", - "inputs": [ - { - "type": "text-input", - "output_variable_name": "confirmation", - "placeholder": {"type": "constant", "value": "Type 'yes' to confirm"}, - } - ], - "user_actions": [ - {"id": "confirm", "title": "Confirm", "button_style": "primary"}, - {"id": "cancel", "title": "Cancel", "button_style": "default"}, - ], - "timeout": 24, - "timeout_unit": "hour", - }, - } - - @pytest.fixture - def human_input_node(self, sample_node_config, mock_graph_init_params, mock_graph, mock_graph_runtime_state): - """Create HumanInputNode instance.""" - node = HumanInputNode( - id="node_123", - config=sample_node_config, - graph_init_params=mock_graph_init_params, - graph=mock_graph, - graph_runtime_state=mock_graph_runtime_state, - ) - return node - - def test_node_initialization(self, human_input_node): - """Test node initialization.""" - assert human_input_node.node_id == "human_input_123" - assert human_input_node.tenant_id == "tenant-123" - assert human_input_node.app_id == "app-456" - assert isinstance(human_input_node.node_data, HumanInputNodeData) - assert human_input_node.node_data.title == "User Confirmation" - - def test_node_type_and_version(self, human_input_node): - """Test node type and version.""" - assert human_input_node.type_.value == "human_input" - assert human_input_node.version() == "1" - - def test_node_properties(self, human_input_node): - """Test node properties access.""" - assert human_input_node.title == "User Confirmation" - assert human_input_node.description == "Please confirm the action" - assert human_input_node.error_strategy is None - assert human_input_node.retry_config.retry_enabled is False - - @patch("uuid.uuid4") - def test_node_run_success(self, mock_uuid, human_input_node): - """Test successful node execution.""" - # Setup mocks - mock_form_id = uuid.UUID("12345678-1234-5678-9abc-123456789012") - mock_token = uuid.UUID("87654321-4321-8765-cba9-876543210987") - mock_uuid.side_effect = [mock_form_id, mock_token] - - # Execute the node - result = human_input_node._run() - - # Verify result - assert result.status == WorkflowNodeExecutionStatus.RUNNING - assert result.metadata["suspended"] is True - assert result.metadata["form_id"] == str(mock_form_id) - assert result.metadata["web_app_form_token"] == str(mock_token).replace("-", "") - - # Verify event data in metadata - human_input_event = result.metadata["human_input_event"] - assert human_input_event["form_id"] == str(mock_form_id) - assert human_input_event["node_id"] == "human_input_123" - assert human_input_event["form_content"] == "# Confirmation\n\nPlease confirm: {{#$output.confirmation#}}" - assert len(human_input_event["inputs"]) == 1 - - suspended_event = result.metadata["suspended_event"] - assert suspended_event["suspended_at_node_ids"] == ["human_input_123"] - - def test_node_run_without_webapp_delivery(self, human_input_node): - """Test node execution without webapp delivery method.""" - # Modify node data to disable webapp delivery - human_input_node.node_data.delivery_methods[0].enabled = False - - result = human_input_node._run() - - # Should still work, but without web app token - assert result.status == WorkflowNodeExecutionStatus.RUNNING - assert result.metadata["web_app_form_token"] is None - - def test_resume_from_human_input_success(self, human_input_node): - """Test successful resume from human input.""" - form_submission_data = {"inputs": {"confirmation": "yes"}, "action": "confirm"} - - result = human_input_node.resume_from_human_input(form_submission_data) - - assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert result.outputs["confirmation"] == "yes" - assert result.outputs["_action"] == "confirm" - assert result.metadata["form_submitted"] is True - assert result.metadata["submitted_action"] == "confirm" - - def test_resume_from_human_input_partial_inputs(self, human_input_node): - """Test resume with partial inputs.""" - form_submission_data = { - "inputs": {}, # Empty inputs - "action": "cancel", - } - - result = human_input_node.resume_from_human_input(form_submission_data) - - assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert "confirmation" not in result.outputs # Field not provided - assert result.outputs["_action"] == "cancel" - - def test_resume_from_human_input_missing_data(self, human_input_node): - """Test resume with missing submission data.""" - form_submission_data = {} # Missing required fields - - result = human_input_node.resume_from_human_input(form_submission_data) - - assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert result.outputs["_action"] == "" # Default empty action - - def test_get_default_config(self): - """Test getting default configuration.""" - config = HumanInputNode.get_default_config() - - assert config["type"] == "human_input" - assert "config" in config - config_data = config["config"] - - assert len(config_data["delivery_methods"]) == 1 - assert config_data["delivery_methods"][0]["type"] == "webapp" - assert config_data["delivery_methods"][0]["enabled"] is True - - assert config_data["form_content"] == "# Human Input\n\nPlease provide your input:\n\n{{#$output.input#}}" - assert len(config_data["inputs"]) == 1 - assert config_data["inputs"][0]["output_variable_name"] == "input" - - assert len(config_data["user_actions"]) == 1 - assert config_data["user_actions"][0]["id"] == "submit" - - assert config_data["timeout"] == 24 - assert config_data["timeout_unit"] == "hour" - - def test_process_form_content(self, human_input_node): - """Test form content processing.""" - # This is a placeholder test since the actual variable substitution - # logic is marked as TODO in the implementation - processed_content = human_input_node._process_form_content() - - # For now, should return the raw content - expected_content = "# Confirmation\n\nPlease confirm: {{#$output.confirmation#}}" - assert processed_content == expected_content - - def test_extract_variable_selector_mapping(self): - """Test variable selector extraction.""" - graph_config = {} - node_data = { - "form_content": "Hello {{#node_123.output#}}", - "inputs": [ - { - "type": "text-input", - "output_variable_name": "test", - "placeholder": {"type": "variable", "selector": ["node_456", "var_name"]}, - } - ], - } - - # This is a placeholder test since the actual extraction logic - # is marked as TODO in the implementation - mapping = HumanInputNode._extract_variable_selector_to_variable_mapping( - graph_config=graph_config, node_id="test_node", node_data=node_data - ) - - # For now, should return empty dict - assert mapping == {} - - -class TestHumanInputNodeValidation: - """Test validation scenarios for HumanInputNode.""" - - def test_node_with_invalid_config(self): - """Test node creation with invalid configuration.""" - invalid_config = { - "id": "test_node", - "data": { - "title": "Test", - "delivery_methods": [ - { - "type": "invalid_type", # Invalid delivery method type - "enabled": True, - "config": {}, - } - ], - }, - } - - mock_params = Mock() - mock_params.tenant_id = "tenant-123" - mock_params.app_id = "app-456" - mock_params.user_id = "user-789" - mock_params.user_from = "web" - mock_params.invoke_from = "web_app" - mock_params.call_depth = 0 - - with pytest.raises(ValueError): - HumanInputNode( - id="node_123", - config=invalid_config, - graph_init_params=mock_params, - graph=Mock(), - graph_runtime_state=Mock(), - ) - - def test_node_with_missing_node_id(self): - """Test node creation with missing node ID in config.""" - invalid_config = { - # Missing "id" field - "data": {"title": "Test"} - } - - mock_params = Mock() - mock_params.tenant_id = "tenant-123" - mock_params.app_id = "app-456" - mock_params.user_id = "user-789" - mock_params.user_from = "web" - mock_params.invoke_from = "web_app" - mock_params.call_depth = 0 - - with pytest.raises(ValueError, match="Node ID is required"): - HumanInputNode( - id="node_123", - config=invalid_config, - graph_init_params=mock_params, - graph=Mock(), - graph_runtime_state=Mock(), - ) diff --git a/api/tests/unit_tests/libs/_human_input/__init__.py b/api/tests/unit_tests/libs/_human_input/__init__.py index 26ca83c9e7..66714e72f8 100644 --- a/api/tests/unit_tests/libs/_human_input/__init__.py +++ b/api/tests/unit_tests/libs/_human_input/__init__.py @@ -1 +1 @@ -# Unit tests for human input library +# Treat this directory as a package so support modules can be imported relatively. diff --git a/api/tests/unit_tests/libs/_human_input/support.py b/api/tests/unit_tests/libs/_human_input/support.py new file mode 100644 index 0000000000..853b79e90b --- /dev/null +++ b/api/tests/unit_tests/libs/_human_input/support.py @@ -0,0 +1,248 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from datetime import datetime, timedelta +from typing import Any, Optional + +from core.workflow.nodes.human_input.entities import FormInput, TimeoutUnit + + +# Exceptions +class HumanInputError(Exception): + error_code: str = "unknown" + + def __init__(self, message: str = "", error_code: str | None = None): + super().__init__(message) + self.message = message or self.__class__.__name__ + if error_code: + self.error_code = error_code + + +class FormNotFoundError(HumanInputError): + error_code = "form_not_found" + + +class FormExpiredError(HumanInputError): + error_code = "human_input_form_expired" + + +class FormAlreadySubmittedError(HumanInputError): + error_code = "human_input_form_submitted" + + +class InvalidFormDataError(HumanInputError): + error_code = "invalid_form_data" + + +# Models +@dataclass +class HumanInputForm: + form_id: str + workflow_run_id: str + node_id: str + tenant_id: str + app_id: str | None + form_content: str + inputs: list[FormInput] + user_actions: list[dict[str, Any]] + timeout: int + timeout_unit: TimeoutUnit + web_app_form_token: str | None = None + created_at: datetime = field(default_factory=datetime.utcnow) + expires_at: datetime | None = None + submitted_at: datetime | None = None + submitted_data: dict[str, Any] | None = None + submitted_action: str | None = None + + def __post_init__(self) -> None: + if self.expires_at is None: + self.calculate_expiration() + + @property + def is_expired(self) -> bool: + return self.expires_at is not None and datetime.utcnow() > self.expires_at + + @property + def is_submitted(self) -> bool: + return self.submitted_at is not None + + def mark_submitted(self, inputs: dict[str, Any], action: str) -> None: + self.submitted_data = inputs + self.submitted_action = action + self.submitted_at = datetime.utcnow() + + def submit(self, inputs: dict[str, Any], action: str) -> None: + self.mark_submitted(inputs, action) + + def calculate_expiration(self) -> None: + start = self.created_at + if self.timeout_unit == TimeoutUnit.HOUR: + self.expires_at = start + timedelta(hours=self.timeout) + elif self.timeout_unit == TimeoutUnit.DAY: + self.expires_at = start + timedelta(days=self.timeout) + else: + raise ValueError(f"Unsupported timeout unit {self.timeout_unit}") + + def to_response_dict(self, *, include_site_info: bool) -> dict[str, Any]: + inputs_response = [ + { + "type": form_input.type.name.lower().replace("_", "-"), + "output_variable_name": form_input.output_variable_name, + } + for form_input in self.inputs + ] + response = { + "form_content": self.form_content, + "inputs": inputs_response, + "user_actions": self.user_actions, + } + if include_site_info: + response["site"] = {"app_id": self.app_id, "title": "Workflow Form"} + return response + + +@dataclass +class FormSubmissionData: + form_id: str + inputs: dict[str, Any] + action: str + submitted_at: datetime = field(default_factory=datetime.utcnow) + + @classmethod + def from_request(cls, form_id: str, request: FormSubmissionRequest) -> FormSubmissionData: # type: ignore + return cls(form_id=form_id, inputs=request.inputs, action=request.action) + + +@dataclass +class FormSubmissionRequest: + inputs: dict[str, Any] + action: str + + +# Repository +class InMemoryFormRepository: + """ + Simple in-memory repository used by unit tests. + """ + + def __init__(self): + self._forms: dict[str, HumanInputForm] = {} + + @property + def forms(self) -> dict[str, HumanInputForm]: + return self._forms + + def save(self, form: HumanInputForm) -> None: + self._forms[form.form_id] = form + + def get_by_id(self, form_id: str) -> Optional[HumanInputForm]: + return self._forms.get(form_id) + + def get_by_token(self, token: str) -> Optional[HumanInputForm]: + for form in self._forms.values(): + if form.web_app_form_token == token: + return form + return None + + def delete(self, form_id: str) -> None: + self._forms.pop(form_id, None) + + +# Service +class FormService: + """Service layer for managing human input forms in tests.""" + + def __init__(self, repository: InMemoryFormRepository): + self.repository = repository + + def create_form( + self, + *, + form_id: str, + workflow_run_id: str, + node_id: str, + tenant_id: str, + app_id: str | None, + form_content: str, + inputs, + user_actions, + timeout: int, + timeout_unit: TimeoutUnit, + web_app_form_token: str | None = None, + ) -> HumanInputForm: + form = HumanInputForm( + form_id=form_id, + workflow_run_id=workflow_run_id, + node_id=node_id, + tenant_id=tenant_id, + app_id=app_id, + form_content=form_content, + inputs=list(inputs), + user_actions=[{"id": action.id, "title": action.title} for action in user_actions], + timeout=timeout, + timeout_unit=timeout_unit, + web_app_form_token=web_app_form_token, + ) + form.calculate_expiration() + self.repository.save(form) + return form + + def get_form_by_id(self, form_id: str) -> HumanInputForm: + form = self.repository.get_by_id(form_id) + if form is None: + raise FormNotFoundError() + return form + + def get_form_by_token(self, token: str) -> HumanInputForm: + form = self.repository.get_by_token(token) + if form is None: + raise FormNotFoundError() + return form + + def get_form_definition(self, form_id: str, *, is_token: bool) -> dict: + form = self.get_form_by_token(form_id) if is_token else self.get_form_by_id(form_id) + if form.is_expired: + raise FormExpiredError() + if form.is_submitted: + raise FormAlreadySubmittedError() + + definition = { + "form_content": form.form_content, + "inputs": form.inputs, + "user_actions": form.user_actions, + } + if is_token: + definition["site"] = {"title": "Workflow Form"} + return definition + + def submit_form(self, form_id: str, submission_data: FormSubmissionData, *, is_token: bool) -> None: + form = self.get_form_by_token(form_id) if is_token else self.get_form_by_id(form_id) + if form.is_expired: + raise FormExpiredError() + if form.is_submitted: + raise FormAlreadySubmittedError() + + self._validate_submission(form=form, submission_data=submission_data) + form.mark_submitted(inputs=submission_data.inputs, action=submission_data.action) + self.repository.save(form) + + def cleanup_expired_forms(self) -> int: + expired_ids = [form_id for form_id, form in list(self.repository.forms.items()) if form.is_expired] + for form_id in expired_ids: + self.repository.delete(form_id) + return len(expired_ids) + + def _validate_submission(self, form: HumanInputForm, submission_data: FormSubmissionData) -> None: + defined_actions = {action["id"] for action in form.user_actions} + if submission_data.action not in defined_actions: + raise InvalidFormDataError(f"Invalid action: {submission_data.action}") + + missing_inputs = [] + for form_input in form.inputs: + if form_input.output_variable_name not in submission_data.inputs: + missing_inputs.append(form_input.output_variable_name) + + if missing_inputs: + raise InvalidFormDataError(f"Missing required inputs: {', '.join(missing_inputs)}") + + # Extra inputs are allowed; no further validation required. diff --git a/api/tests/unit_tests/libs/_human_input/test_form_service.py b/api/tests/unit_tests/libs/_human_input/test_form_service.py index c10e179157..51414a9d39 100644 --- a/api/tests/unit_tests/libs/_human_input/test_form_service.py +++ b/api/tests/unit_tests/libs/_human_input/test_form_service.py @@ -5,15 +5,6 @@ Unit tests for FormService. from datetime import datetime, timedelta import pytest -from libs._human_input.exceptions import ( - FormAlreadySubmittedError, - FormExpiredError, - FormNotFoundError, - InvalidFormDataError, -) -from libs._human_input.form_service import FormService -from libs._human_input.models import FormSubmissionData -from libs._human_input.repository import InMemoryFormRepository from core.workflow.nodes.human_input.entities import ( FormInput, @@ -22,6 +13,16 @@ from core.workflow.nodes.human_input.entities import ( UserAction, ) +from .support import ( + FormAlreadySubmittedError, + FormExpiredError, + FormNotFoundError, + FormService, + FormSubmissionData, + InMemoryFormRepository, + InvalidFormDataError, +) + class TestFormService: """Test FormService functionality.""" diff --git a/api/tests/unit_tests/libs/_human_input/test_models.py b/api/tests/unit_tests/libs/_human_input/test_models.py index f4851c1be6..66c34d577c 100644 --- a/api/tests/unit_tests/libs/_human_input/test_models.py +++ b/api/tests/unit_tests/libs/_human_input/test_models.py @@ -5,16 +5,16 @@ Unit tests for human input form models. from datetime import datetime, timedelta import pytest -from libs._human_input.models import FormSubmissionData, HumanInputForm from core.workflow.nodes.human_input.entities import ( FormInput, FormInputType, - FormSubmissionRequest, TimeoutUnit, UserAction, ) +from .support import FormSubmissionData, FormSubmissionRequest, HumanInputForm + class TestHumanInputForm: """Test HumanInputForm model.""" diff --git a/api/tests/unit_tests/services/test_human_input_service.py b/api/tests/unit_tests/services/test_human_input_service.py new file mode 100644 index 0000000000..ed9569be4b --- /dev/null +++ b/api/tests/unit_tests/services/test_human_input_service.py @@ -0,0 +1,205 @@ +import dataclasses +from datetime import datetime +from unittest.mock import MagicMock + +import pytest + +from core.repositories.human_input_reposotiry import HumanInputFormReadRepository, HumanInputFormRecord +from core.workflow.nodes.human_input.entities import FormDefinition, TimeoutUnit, UserAction +from models.account import Account +from models.human_input import RecipientType +from services.human_input_service import FormSubmittedError, HumanInputService + + +@pytest.fixture +def mock_session_factory(): + session = MagicMock() + session_cm = MagicMock() + session_cm.__enter__.return_value = session + session_cm.__exit__.return_value = None + + factory = MagicMock() + factory.return_value = session_cm + return factory, session + + +@pytest.fixture +def sample_form_record(): + return HumanInputFormRecord( + form_id="form-id", + workflow_run_id="workflow-run-id", + node_id="node-id", + tenant_id="tenant-id", + definition=FormDefinition( + form_content="hello", + inputs=[], + user_actions=[UserAction(id="submit", title="Submit")], + rendered_content="

hello

", + timeout=1, + timeout_unit=TimeoutUnit.HOUR, + ), + rendered_content="

hello

", + expiration_time=datetime(2024, 1, 1), + selected_action_id=None, + submitted_data=None, + submitted_at=None, + submission_user_id=None, + submission_end_user_id=None, + completed_by_recipient_id=None, + recipient_id="recipient-id", + recipient_type=RecipientType.WEBAPP, + access_token="token", + ) + + +def test_enqueue_resume_dispatches_task(mocker, mock_session_factory): + session_factory, session = mock_session_factory + service = HumanInputService(session_factory) + + trigger_log = MagicMock() + trigger_log.id = "trigger-log-id" + trigger_log.queue_name = "workflow_queue" + + repo_cls = mocker.patch( + "services.human_input_service.SQLAlchemyWorkflowTriggerLogRepository", + autospec=True, + ) + repo = repo_cls.return_value + repo.get_by_workflow_run_id.return_value = trigger_log + + resume_task = mocker.patch("services.human_input_service.resume_workflow_execution") + + service._enqueue_resume("workflow-run-id") + + repo_cls.assert_called_once_with(session) + resume_task.apply_async.assert_called_once() + call_kwargs = resume_task.apply_async.call_args.kwargs + assert call_kwargs["queue"] == "workflow_queue" + payload = call_kwargs["kwargs"]["task_data_dict"] + assert payload["workflow_trigger_log_id"] == "trigger-log-id" + assert payload["workflow_run_id"] == "workflow-run-id" + + +def test_enqueue_resume_no_trigger_log(mocker, mock_session_factory): + session_factory, session = mock_session_factory + service = HumanInputService(session_factory) + + repo_cls = mocker.patch( + "services.human_input_service.SQLAlchemyWorkflowTriggerLogRepository", + autospec=True, + ) + repo = repo_cls.return_value + repo.get_by_workflow_run_id.return_value = None + + resume_task = mocker.patch("services.human_input_service.resume_workflow_execution") + + service._enqueue_resume("workflow-run-id") + + repo_cls.assert_called_once_with(session) + resume_task.apply_async.assert_not_called() + + +def test_enqueue_resume_chatflow_fallback(mocker, mock_session_factory): + session_factory, session = mock_session_factory + service = HumanInputService(session_factory) + + repo_cls = mocker.patch( + "services.human_input_service.SQLAlchemyWorkflowTriggerLogRepository", + autospec=True, + ) + repo = repo_cls.return_value + repo.get_by_workflow_run_id.return_value = None + + workflow_run = MagicMock() + workflow_run.app_id = "app-id" + app = MagicMock() + app.mode = "advanced-chat" + + session.get.side_effect = [workflow_run, app] + + resume_task = mocker.patch("services.human_input_service.resume_chatflow_execution") + + service._enqueue_resume("workflow-run-id") + + resume_task.apply_async.assert_called_once() + call_kwargs = resume_task.apply_async.call_args.kwargs + assert call_kwargs["queue"] == "chatflow_execute" + assert call_kwargs["kwargs"]["payload"]["workflow_run_id"] == "workflow-run-id" + + +def test_get_form_definition_by_id_uses_repository(sample_form_record, mock_session_factory): + session_factory, _ = mock_session_factory + repo = MagicMock(spec=HumanInputFormReadRepository) + repo.get_by_form_id_and_recipient_type.return_value = sample_form_record + + service = HumanInputService(session_factory, form_repository=repo) + form = service.get_form_definition_by_id("form-id") + + repo.get_by_form_id_and_recipient_type.assert_called_once_with( + form_id="form-id", + recipient_type=RecipientType.WEBAPP, + ) + assert form is not None + assert form.get_definition() == sample_form_record.definition + + +def test_get_form_definition_by_id_raises_on_submitted(sample_form_record, mock_session_factory): + session_factory, _ = mock_session_factory + submitted_record = dataclasses.replace(sample_form_record, submitted_at=datetime(2024, 1, 1)) + repo = MagicMock(spec=HumanInputFormReadRepository) + repo.get_by_form_id_and_recipient_type.return_value = submitted_record + + service = HumanInputService(session_factory, form_repository=repo) + + with pytest.raises(FormSubmittedError): + service.get_form_definition_by_id("form-id") + + +def test_submit_form_by_token_calls_repository_and_enqueue(sample_form_record, mock_session_factory, mocker): + session_factory, _ = mock_session_factory + repo = MagicMock(spec=HumanInputFormReadRepository) + repo.get_by_token.return_value = sample_form_record + repo.mark_submitted.return_value = sample_form_record + service = HumanInputService(session_factory, form_repository=repo) + enqueue_spy = mocker.patch.object(service, "_enqueue_resume") + + service.submit_form_by_token( + recipient_type=RecipientType.WEBAPP, + form_token="token", + selected_action_id="approve", + form_data={"field": "value"}, + submission_end_user_id="end-user-id", + ) + + repo.get_by_token.assert_called_once_with("token") + repo.mark_submitted.assert_called_once() + call_kwargs = repo.mark_submitted.call_args.kwargs + assert call_kwargs["form_id"] == sample_form_record.form_id + assert call_kwargs["recipient_id"] == sample_form_record.recipient_id + assert call_kwargs["selected_action_id"] == "approve" + assert call_kwargs["form_data"] == {"field": "value"} + assert call_kwargs["submission_end_user_id"] == "end-user-id" + enqueue_spy.assert_called_once_with(sample_form_record.workflow_run_id) + + +def test_submit_form_by_id_passes_account(sample_form_record, mock_session_factory, mocker): + session_factory, _ = mock_session_factory + repo = MagicMock(spec=HumanInputFormReadRepository) + repo.get_by_form_id_and_recipient_type.return_value = sample_form_record + repo.mark_submitted.return_value = sample_form_record + service = HumanInputService(session_factory, form_repository=repo) + enqueue_spy = mocker.patch.object(service, "_enqueue_resume") + account = MagicMock(spec=Account) + account.id = "account-id" + + service.submit_form_by_id( + form_id="form-id", + selected_action_id="approve", + form_data={"x": 1}, + user=account, + ) + + repo.get_by_form_id_and_recipient_type.assert_called_once() + repo.mark_submitted.assert_called_once() + assert repo.mark_submitted.call_args.kwargs["submission_user_id"] == "account-id" + enqueue_spy.assert_called_once_with(sample_form_record.workflow_run_id) diff --git a/api/tests/unit_tests/services/test_workflow_run_service_pause.py b/api/tests/unit_tests/services/test_workflow_run_service_pause.py index f45a72927e..ded141f01a 100644 --- a/api/tests/unit_tests/services/test_workflow_run_service_pause.py +++ b/api/tests/unit_tests/services/test_workflow_run_service_pause.py @@ -35,7 +35,6 @@ class TestDataFactory: app_id: str = "app-789", workflow_id: str = "workflow-101", status: str | WorkflowExecutionStatus = "paused", - pause_id: str | None = None, **kwargs, ) -> MagicMock: """Create a mock WorkflowRun object.""" @@ -45,7 +44,6 @@ class TestDataFactory: mock_run.app_id = app_id mock_run.workflow_id = workflow_id mock_run.status = status - mock_run.pause_id = pause_id for key, value in kwargs.items(): setattr(mock_run, key, value) diff --git a/api/tests/unit_tests/services/workflow/test_workflow_service.py b/api/tests/unit_tests/services/workflow/test_workflow_service.py index 41362dccd0..9700cbaf0e 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_service.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_service.py @@ -161,292 +161,3 @@ class TestWorkflowService: assert workflows == [] assert has_more is False mock_session.scalars.assert_called_once() - - -class TestWorkflowServiceHumanInputValidation: - @pytest.fixture - def workflow_service(self): - # Mock sessionmaker to avoid database dependency - mock_session_maker = MagicMock() - return WorkflowService(mock_session_maker) - - def test_validate_graph_structure_valid_human_input(self, workflow_service): - """Test validation of valid HumanInput node data.""" - graph = { - "nodes": [ - { - "id": "node-1", - "data": { - "type": "human_input", - "title": "Human Input", - "delivery_methods": [{"type": "webapp", "enabled": True, "config": {}}], - "form_content": "Please provide your input", - "inputs": [ - { - "type": "text-input", - "output_variable_name": "user_input", - "placeholder": {"type": "constant", "value": "Enter text here"}, - } - ], - "user_actions": [{"id": "submit", "title": "Submit", "button_style": "primary"}], - "timeout": 24, - "timeout_unit": "hour", - }, - } - ] - } - - # Should not raise any exception - workflow_service.validate_graph_structure(graph) - - def test_validate_graph_structure_empty_graph(self, workflow_service): - """Test validation of empty graph.""" - graph = {} - - # Should not raise any exception - workflow_service.validate_graph_structure(graph) - - def test_validate_graph_structure_no_nodes(self, workflow_service): - """Test validation of graph with no nodes.""" - graph = {"nodes": []} - - # Should not raise any exception - workflow_service.validate_graph_structure(graph) - - def test_validate_graph_structure_non_human_input_node(self, workflow_service): - """Test validation ignores non-HumanInput nodes.""" - graph = {"nodes": [{"id": "node-1", "data": {"type": "start", "title": "Start"}}]} - - # Should not raise any exception - workflow_service.validate_graph_structure(graph) - - def test_validate_human_input_node_data_invalid_delivery_method_type(self, workflow_service): - """Test validation fails with invalid delivery method type.""" - graph = { - "nodes": [ - { - "id": "node-1", - "data": { - "type": "human_input", - "title": "Human Input", - "delivery_methods": [{"type": "invalid_type", "enabled": True, "config": {}}], - "form_content": "Please provide your input", - "inputs": [], - "user_actions": [], - "timeout": 24, - "timeout_unit": "hour", - }, - } - ] - } - - with pytest.raises(ValueError, match="Invalid HumanInput node data"): - workflow_service.validate_graph_structure(graph) - - def test_validate_human_input_node_data_invalid_form_input_type(self, workflow_service): - """Test validation fails with invalid form input type.""" - graph = { - "nodes": [ - { - "id": "node-1", - "data": { - "type": "human_input", - "title": "Human Input", - "delivery_methods": [{"type": "webapp", "enabled": True, "config": {}}], - "form_content": "Please provide your input", - "inputs": [{"type": "invalid-input-type", "output_variable_name": "user_input"}], - "user_actions": [], - "timeout": 24, - "timeout_unit": "hour", - }, - } - ] - } - - with pytest.raises(ValueError, match="Invalid HumanInput node data"): - workflow_service.validate_graph_structure(graph) - - def test_validate_human_input_node_data_missing_required_fields(self, workflow_service): - """Test validation fails with missing required fields.""" - graph = { - "nodes": [ - { - "id": "node-1", - "data": { - "type": "human_input", - # Missing required fields like title - "delivery_methods": [], - "form_content": "", - "inputs": [], - "user_actions": [], - }, - } - ] - } - - with pytest.raises(ValueError, match="Invalid HumanInput node data"): - workflow_service.validate_graph_structure(graph) - - def test_validate_human_input_node_data_invalid_timeout_unit(self, workflow_service): - """Test validation fails with invalid timeout unit.""" - graph = { - "nodes": [ - { - "id": "node-1", - "data": { - "type": "human_input", - "title": "Human Input", - "delivery_methods": [{"type": "webapp", "enabled": True, "config": {}}], - "form_content": "Please provide your input", - "inputs": [], - "user_actions": [], - "timeout": 24, - "timeout_unit": "invalid_unit", - }, - } - ] - } - - with pytest.raises(ValueError, match="Invalid HumanInput node data"): - workflow_service.validate_graph_structure(graph) - - def test_validate_human_input_node_data_invalid_button_style(self, workflow_service): - """Test validation fails with invalid button style.""" - graph = { - "nodes": [ - { - "id": "node-1", - "data": { - "type": "human_input", - "title": "Human Input", - "delivery_methods": [{"type": "webapp", "enabled": True, "config": {}}], - "form_content": "Please provide your input", - "inputs": [], - "user_actions": [{"id": "submit", "title": "Submit", "button_style": "invalid_style"}], - "timeout": 24, - "timeout_unit": "hour", - }, - } - ] - } - - with pytest.raises(ValueError, match="Invalid HumanInput node data"): - workflow_service.validate_graph_structure(graph) - - def test_validate_human_input_node_data_email_delivery_config(self, workflow_service): - """Test validation of HumanInput node with email delivery configuration.""" - graph = { - "nodes": [ - { - "id": "node-1", - "data": { - "type": "human_input", - "title": "Human Input", - "delivery_methods": [ - { - "type": "email", - "enabled": True, - "config": { - "recipients": { - "whole_workspace": False, - "items": [{"type": "external", "email": "user@example.com"}], - }, - "subject": "Input Required", - "body": "Please provide your input", - }, - } - ], - "form_content": "Please provide your input", - "inputs": [ - { - "type": "paragraph", - "output_variable_name": "feedback", - "placeholder": {"type": "variable", "selector": ["node", "output"]}, - } - ], - "user_actions": [ - {"id": "approve", "title": "Approve", "button_style": "accent"}, - {"id": "reject", "title": "Reject", "button_style": "ghost"}, - ], - "timeout": 7, - "timeout_unit": "day", - }, - } - ] - } - - # Should not raise any exception - workflow_service.validate_graph_structure(graph) - - def test_validate_human_input_node_data_invalid_email_recipient(self, workflow_service): - """Test validation fails with invalid email recipient.""" - graph = { - "nodes": [ - { - "id": "node-1", - "data": { - "type": "human_input", - "title": "Human Input", - "delivery_methods": [ - { - "type": "email", - "enabled": True, - "config": { - "recipients": { - "whole_workspace": False, - "items": [{"type": "invalid_recipient_type", "email": "user@example.com"}], - }, - "subject": "Input Required", - "body": "Please provide your input", - }, - } - ], - "form_content": "Please provide your input", - "inputs": [], - "user_actions": [], - "timeout": 24, - "timeout_unit": "hour", - }, - } - ] - } - - with pytest.raises(ValueError, match="Invalid HumanInput node data"): - workflow_service.validate_graph_structure(graph) - - def test_validate_human_input_node_data_multiple_nodes_mixed_valid_invalid(self, workflow_service): - """Test validation with multiple nodes where some are valid and some invalid.""" - graph = { - "nodes": [ - {"id": "node-1", "data": {"type": "start", "title": "Start"}}, - { - "id": "node-2", - "data": { - "type": "human_input", - "title": "Valid Human Input", - "delivery_methods": [{"type": "webapp", "enabled": True, "config": {}}], - "form_content": "Valid input", - "inputs": [], - "user_actions": [], - "timeout": 24, - "timeout_unit": "hour", - }, - }, - { - "id": "node-3", - "data": { - "type": "human_input", - "title": "Invalid Human Input", - "delivery_methods": [{"type": "invalid_method", "enabled": True}], - "form_content": "Invalid input", - "inputs": [], - "user_actions": [], - "timeout": 24, - "timeout_unit": "hour", - }, - }, - ] - } - - with pytest.raises(ValueError, match="Invalid HumanInput node data"): - workflow_service.validate_graph_structure(graph)