WIP: resume

This commit is contained in:
QuantumGhost
2025-11-21 10:13:20 +08:00
parent c0e15b9e1b
commit c0f1aeddbe
49 changed files with 2160 additions and 1445 deletions

View File

@ -107,6 +107,7 @@ class ConsoleHumanInputFormApi(Resource):
parser.add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("action", type=str, required=True, location="json")
args = parser.parse_args()
current_user, _ = current_account_with_tenant()
# Submit the form
service = HumanInputService(db.engine)
@ -114,6 +115,7 @@ class ConsoleHumanInputFormApi(Resource):
form_id=form_id,
selected_action_id=args["action"],
form_data=args["inputs"],
user=current_user,
)
return jsonify({})

View File

@ -70,6 +70,7 @@ class HumanInputFormApi(WebApiResource):
form_token=web_app_form_token,
selected_action_id=args["action"],
form_data=args["inputs"],
submission_end_user_id=_end_user.id,
)
except FormNotFoundError:
raise NotFoundError("Form not found")

View File

@ -3,7 +3,6 @@ Web App Workflow Resume APIs.
"""
import logging
import time
from collections.abc import Generator
from flask import Response
@ -14,42 +13,6 @@ from controllers.web.wraps import WebApiResource
logger = logging.getLogger(__name__)
class WorkflowResumeWaitApi(WebApiResource):
"""API for long-polling workflow resume wait."""
def get(self, task_id: str):
"""
Get workflow execution resume notification.
GET /api/workflow/<task_id>/resume-wait
This is a long-polling API that waits for workflow to resume from paused state.
"""
# TODO: Implement actual workflow status checking
# For now, return a basic response
timeout = 30 # 30 seconds timeout for demo
start_time = time.time()
while time.time() - start_time < timeout:
# TODO: Check workflow status from database/cache
# workflow_status = workflow_service.get_status(task_id)
# For demo purposes, simulate different states
# In real implementation, this would check the actual workflow state
workflow_status = "paused" # or "running" or "ended"
if workflow_status == "running":
return {"status": "running"}, 200
elif workflow_status == "ended":
return {"status": "ended"}, 200
time.sleep(1) # Poll every second
# Return paused status if timeout reached
return {"status": "paused"}, 200
class WorkflowEventsApi(WebApiResource):
"""API for getting workflow execution events after resume."""

View File

@ -2,7 +2,7 @@ import contextvars
import logging
import threading
import uuid
from collections.abc import Generator, Mapping
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Literal, TypeVar, Union, overload
from flask import Flask, current_app
@ -24,16 +24,19 @@ from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
from core.helper.trace_id_helper import extract_external_trace_id_from_args
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
from core.repositories import DifyCoreRepositoryFactory
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.repositories.draft_variable_repository import (
DraftVariableSaverFactory,
)
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.runtime import GraphRuntimeState
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
from extensions.ext_database import db
from factories import file_factory
@ -63,6 +66,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
invoke_from: InvokeFrom,
workflow_run_id: uuid.UUID,
streaming: Literal[False],
pause_state_config: PauseStateLayerConfig | None = None,
) -> Mapping[str, Any]: ...
@overload
@ -75,6 +79,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
invoke_from: InvokeFrom,
workflow_run_id: uuid.UUID,
streaming: Literal[True],
pause_state_config: PauseStateLayerConfig | None = None,
) -> Generator[Mapping | str, None, None]: ...
@overload
@ -87,6 +92,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
invoke_from: InvokeFrom,
workflow_run_id: uuid.UUID,
streaming: bool,
pause_state_config: PauseStateLayerConfig | None = None,
) -> Mapping[str, Any] | Generator[str | Mapping, None, None]: ...
def generate(
@ -98,6 +104,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
invoke_from: InvokeFrom,
workflow_run_id: uuid.UUID,
streaming: bool = True,
pause_state_config: PauseStateLayerConfig | None = None,
) -> Mapping[str, Any] | Generator[str | Mapping, None, None]:
"""
Generate App response.
@ -215,6 +222,38 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
workflow_node_execution_repository=workflow_node_execution_repository,
conversation=conversation,
stream=streaming,
pause_state_config=pause_state_config,
)
def resume(
self,
*,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
conversation: Conversation,
message: Message,
application_generate_entity: AdvancedChatAppGenerateEntity,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
graph_runtime_state: GraphRuntimeState,
pause_state_config: PauseStateLayerConfig | None = None,
):
"""
Resume a paused advanced chat execution.
"""
return self._generate(
workflow=workflow,
user=user,
invoke_from=application_generate_entity.invoke_from,
application_generate_entity=application_generate_entity,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
conversation=conversation,
message=message,
stream=application_generate_entity.stream,
pause_state_config=pause_state_config,
graph_runtime_state=graph_runtime_state,
)
def single_iteration_generate(
@ -395,8 +434,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
conversation: Conversation | None = None,
message: Message | None = None,
stream: bool = True,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
pause_state_config: PauseStateLayerConfig | None = None,
graph_runtime_state: GraphRuntimeState | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
"""
Generate App response.
@ -410,12 +453,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
:param conversation: conversation
:param stream: is stream
"""
is_first_conversation = False
if not conversation:
is_first_conversation = True
is_first_conversation = conversation is None
# init generate records
(conversation, message) = self._init_generate_records(application_generate_entity, conversation)
if conversation is not None and message is not None:
pass
else:
conversation, message = self._init_generate_records(application_generate_entity, conversation)
if is_first_conversation:
# update conversation features
@ -438,6 +481,16 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
message_id=message.id,
)
graph_layers: list[GraphEngineLayer] = list(graph_engine_layers)
if pause_state_config is not None:
graph_layers.append(
PauseStatePersistenceLayer(
session_factory=pause_state_config.session_factory,
generate_entity=application_generate_entity,
state_owner_user_id=pause_state_config.state_owner_user_id,
)
)
# new thread with request context and contextvars
context = contextvars.copy_context()
@ -453,6 +506,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
"variable_loader": variable_loader,
"workflow_execution_repository": workflow_execution_repository,
"workflow_node_execution_repository": workflow_node_execution_repository,
"graph_engine_layers": tuple(graph_layers),
"graph_runtime_state": graph_runtime_state,
},
)
@ -498,6 +553,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
variable_loader: VariableLoader,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
graph_runtime_state: GraphRuntimeState | None = None,
):
"""
Generate worker in a new thread.
@ -555,6 +612,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
app=app,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
graph_engine_layers=graph_engine_layers,
graph_runtime_state=graph_runtime_state,
)
try:

View File

@ -64,6 +64,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
graph_runtime_state: GraphRuntimeState | None = None,
):
super().__init__(
queue_manager=queue_manager,
@ -80,6 +81,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
self._app = app
self._workflow_execution_repository = workflow_execution_repository
self._workflow_node_execution_repository = workflow_node_execution_repository
self._resume_graph_runtime_state = graph_runtime_state
@trace_span(WorkflowAppRunnerHandler)
def run(self):
@ -103,7 +105,19 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
if not app_record:
raise ValueError("App not found")
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
resume_state = self._resume_graph_runtime_state
if resume_state is not None:
graph_runtime_state = resume_state
variable_pool = graph_runtime_state.variable_pool
graph = self._init_graph(
graph_config=self._workflow.graph_dict,
graph_runtime_state=graph_runtime_state,
workflow_id=self._workflow.id,
tenant_id=self._workflow.tenant_id,
user_id=self.application_generate_entity.user_id,
)
elif self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
# Handle single iteration or single loop run
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
workflow=self._workflow,

View File

@ -23,6 +23,7 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
from core.helper.trace_id_helper import extract_external_trace_id_from_args
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
@ -31,12 +32,15 @@ from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.runtime import GraphRuntimeState
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
from extensions.ext_database import db
from factories import file_factory
from libs.flask_utils import preserve_flask_contexts
from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.account import Account
from models.enums import WorkflowRunTriggeredFrom
from models.model import App, EndUser
from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
SKIP_PREPARE_USER_INPUTS_KEY = "_skip_prepare_user_inputs"
@ -63,6 +67,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
triggered_from: WorkflowRunTriggeredFrom | None = None,
root_node_id: str | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
pause_state_config: PauseStateLayerConfig | None = None,
) -> Generator[Mapping[str, Any] | str, None, None]: ...
@overload
@ -79,6 +84,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
triggered_from: WorkflowRunTriggeredFrom | None = None,
root_node_id: str | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
pause_state_config: PauseStateLayerConfig | None = None,
) -> Mapping[str, Any]: ...
@overload
@ -95,6 +101,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
triggered_from: WorkflowRunTriggeredFrom | None = None,
root_node_id: str | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
pause_state_config: PauseStateLayerConfig | None = None,
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: ...
def generate(
@ -110,6 +117,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
triggered_from: WorkflowRunTriggeredFrom | None = None,
root_node_id: str | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
pause_state_config: PauseStateLayerConfig | None = None,
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]:
files: Sequence[Mapping[str, Any]] = args.get("files") or []
@ -210,13 +218,40 @@ class WorkflowAppGenerator(BaseAppGenerator):
streaming=streaming,
root_node_id=root_node_id,
graph_engine_layers=graph_engine_layers,
pause_state_config=pause_state_config,
)
def resume(self, *, workflow_run_id: str) -> None:
def resume(
self,
*,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
application_generate_entity: WorkflowAppGenerateEntity,
graph_runtime_state: GraphRuntimeState,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
pause_state_config: PauseStateLayerConfig | None = None,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
"""
@TBD
Resume a paused workflow execution using the persisted runtime state.
"""
pass
return self._generate(
app_model=app_model,
workflow=workflow,
user=user,
application_generate_entity=application_generate_entity,
invoke_from=application_generate_entity.invoke_from,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=application_generate_entity.stream,
variable_loader=variable_loader,
graph_engine_layers=graph_engine_layers,
graph_runtime_state=graph_runtime_state,
pause_state_config=pause_state_config,
)
def _generate(
self,
@ -232,6 +267,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
root_node_id: str | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
graph_runtime_state: GraphRuntimeState | None = None,
pause_state_config: PauseStateLayerConfig | None = None,
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
"""
Generate App response.
@ -245,6 +282,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
:param workflow_node_execution_repository: repository for workflow node execution
:param streaming: is stream
"""
graph_layers: list[GraphEngineLayer] = list(graph_engine_layers)
# init queue manager
queue_manager = WorkflowAppQueueManager(
task_id=application_generate_entity.task_id,
@ -253,6 +292,15 @@ class WorkflowAppGenerator(BaseAppGenerator):
app_mode=app_model.mode,
)
if pause_state_config is not None:
graph_layers.append(
PauseStatePersistenceLayer(
session_factory=pause_state_config.session_factory,
generate_entity=application_generate_entity,
state_owner_user_id=pause_state_config.state_owner_user_id,
)
)
# new thread with request context and contextvars
context = contextvars.copy_context()
@ -270,7 +318,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
"root_node_id": root_node_id,
"workflow_execution_repository": workflow_execution_repository,
"workflow_node_execution_repository": workflow_node_execution_repository,
"graph_engine_layers": graph_engine_layers,
"graph_engine_layers": tuple(graph_layers),
"graph_runtime_state": graph_runtime_state,
},
)
@ -372,6 +421,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
variable_loader=var_loader,
pause_state_config=None,
)
def single_loop_generate(
@ -453,6 +503,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
variable_loader=var_loader,
pause_state_config=None,
)
def _generate_worker(
@ -466,6 +517,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
root_node_id: str | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
graph_runtime_state: GraphRuntimeState | None = None,
) -> None:
"""
Generate worker in a new thread.
@ -511,6 +563,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_node_execution_repository=workflow_node_execution_repository,
root_node_id=root_node_id,
graph_engine_layers=graph_engine_layers,
graph_runtime_state=graph_runtime_state,
)
try:

View File

@ -43,6 +43,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
graph_runtime_state: GraphRuntimeState | None = None,
):
super().__init__(
queue_manager=queue_manager,
@ -56,6 +57,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
self._root_node_id = root_node_id
self._workflow_execution_repository = workflow_execution_repository
self._workflow_node_execution_repository = workflow_node_execution_repository
self._resume_graph_runtime_state = graph_runtime_state
@trace_span(WorkflowAppRunnerHandler)
def run(self):
@ -65,17 +67,12 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
app_config = self.application_generate_entity.app_config
app_config = cast(WorkflowAppConfig, app_config)
system_inputs = SystemVariable(
files=self.application_generate_entity.files,
user_id=self._sys_user_id,
app_id=app_config.app_id,
timestamp=int(naive_utc_now().timestamp()),
workflow_id=app_config.workflow_id,
workflow_execution_id=self.application_generate_entity.workflow_execution_id,
)
resume_state = self._resume_graph_runtime_state
# if only single iteration or single loop run is requested
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
if resume_state is not None:
graph_runtime_state = resume_state
variable_pool = graph_runtime_state.variable_pool
elif self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
workflow=self._workflow,
single_iteration_run=self.application_generate_entity.single_iteration_run,
@ -85,7 +82,14 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
inputs = self.application_generate_entity.inputs
# Create a variable pool.
system_inputs = SystemVariable(
files=self.application_generate_entity.files,
user_id=self._sys_user_id,
app_id=app_config.app_id,
timestamp=int(naive_utc_now().timestamp()),
workflow_id=app_config.workflow_id,
workflow_execution_id=self.application_generate_entity.workflow_execution_id,
)
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
@ -95,7 +99,10 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
# init graph
# init graph (for both resumed and fresh runs when not single-step)
if resume_state is not None or not (
self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run
):
graph = self._init_graph(
graph_config=self._workflow.graph_dict,
graph_runtime_state=graph_runtime_state,

View File

@ -1,3 +1,4 @@
from dataclasses import dataclass
from typing import Annotated, Literal, Self, TypeAlias
from pydantic import BaseModel, Field
@ -52,6 +53,14 @@ class WorkflowResumptionContext(BaseModel):
return self.generate_entity.entity
@dataclass(frozen=True)
class PauseStateLayerConfig:
"""Configuration container for instantiating pause persistence layers."""
session_factory: Engine | sessionmaker[Session]
state_owner_user_id: str
class PauseStatePersistenceLayer(GraphEngineLayer):
def __init__(
self,
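For orientation (an editor's sketch, not part of this diff): the helper below shows how the generators in this commit are expected to turn a PauseStateLayerConfig into a PauseStatePersistenceLayer appended to the engine layers; the function name build_graph_layers and the untyped generate_entity parameter are illustrative assumptions.

from collections.abc import Sequence

def build_graph_layers(
    base_layers: Sequence["GraphEngineLayer"],
    generate_entity,  # AdvancedChatAppGenerateEntity or WorkflowAppGenerateEntity
    pause_state_config: "PauseStateLayerConfig | None",
) -> list["GraphEngineLayer"]:
    # Mirrors the wiring added to AdvancedChatAppGenerator._generate and
    # WorkflowAppGenerator._generate earlier in this commit.
    layers = list(base_layers)
    if pause_state_config is not None:
        layers.append(
            PauseStatePersistenceLayer(
                session_factory=pause_state_config.session_factory,
                generate_entity=generate_entity,
                state_owner_user_id=pause_state_config.state_owner_user_id,
            )
        )
    return layers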

View File

@ -28,8 +28,8 @@ from core.model_runtime.entities.provider_entities import (
)
from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.engine import db
from models.provider import (
LoadBalancingModelConfig,
Provider,

View File

@ -15,10 +15,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from core.helper.encrypter import batch_decrypt_token, encrypt_token, obfuscated_token
from core.ops.entities.config_entity import (
OPS_FILE_PATH,
TracingProviderEnum,
)
from core.ops.entities.config_entity import OPS_FILE_PATH, TracingProviderEnum
from core.ops.entities.trace_entity import (
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
@ -31,11 +28,10 @@ from core.ops.entities.trace_entity import (
WorkflowTraceInfo,
)
from core.ops.utils import get_message_data
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.engine import db
from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig
from models.workflow import WorkflowAppLog
from repositories.factory import DifyAPIRepositoryFactory
from tasks.ops_trace_task import process_trace_tasks
if TYPE_CHECKING:
@ -470,6 +466,8 @@ class TraceTask:
@classmethod
def _get_workflow_run_repo(cls):
from repositories.factory import DifyAPIRepositoryFactory
if cls._workflow_run_repo is None:
with cls._repo_lock:
if cls._workflow_run_repo is None:

View File

@ -5,7 +5,7 @@ from urllib.parse import urlparse
from sqlalchemy import select
from extensions.ext_database import db
from models.engine import db
from models.model import Message

View File

@ -2,27 +2,17 @@
from __future__ import annotations
from importlib import import_module
from typing import Any
from .celery_workflow_execution_repository import CeleryWorkflowExecutionRepository
from .celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository
from .factory import DifyCoreRepositoryFactory, RepositoryImportError
from .sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from .sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
_ATTRIBUTE_MODULE_MAP = {
"CeleryWorkflowExecutionRepository": "core.repositories.celery_workflow_execution_repository",
"CeleryWorkflowNodeExecutionRepository": "core.repositories.celery_workflow_node_execution_repository",
"DifyCoreRepositoryFactory": "core.repositories.factory",
"RepositoryImportError": "core.repositories.factory",
"SQLAlchemyWorkflowNodeExecutionRepository": "core.repositories.sqlalchemy_workflow_node_execution_repository",
}
__all__ = list(_ATTRIBUTE_MODULE_MAP.keys())
def __getattr__(name: str) -> Any:
module_path = _ATTRIBUTE_MODULE_MAP.get(name)
if module_path is None:
raise AttributeError(f"module 'core.repositories' has no attribute '{name}'")
module = import_module(module_path)
return getattr(module, name)
def __dir__() -> list[str]: # pragma: no cover - simple helper
return sorted(__all__)
__all__ = [
"CeleryWorkflowExecutionRepository",
"CeleryWorkflowNodeExecutionRepository",
"DifyCoreRepositoryFactory",
"RepositoryImportError",
"SQLAlchemyWorkflowExecutionRepository",
"SQLAlchemyWorkflowNodeExecutionRepository",
]

View File

@ -1,10 +1,11 @@
import dataclasses
import json
from collections.abc import Mapping, Sequence
from datetime import datetime
from typing import Any
from sqlalchemy import Engine, select
from sqlalchemy.orm import Session, sessionmaker, selectinload
from sqlalchemy.orm import Session, selectinload, sessionmaker
from core.workflow.nodes.human_input.entities import (
DeliveryChannelConfig,
@ -85,6 +86,53 @@ class _FormSubmissionImpl(FormSubmission):
return json.loads(submitted_data)
@dataclasses.dataclass(frozen=True)
class HumanInputFormRecord:
form_id: str
workflow_run_id: str
node_id: str
tenant_id: str
definition: FormDefinition
rendered_content: str
expiration_time: datetime
selected_action_id: str | None
submitted_data: Mapping[str, Any] | None
submitted_at: datetime | None
submission_user_id: str | None
submission_end_user_id: str | None
completed_by_recipient_id: str | None
recipient_id: str | None
recipient_type: RecipientType | None
access_token: str | None
@property
def submitted(self) -> bool:
return self.submitted_at is not None
@classmethod
def from_models(
cls, form_model: HumanInputForm, recipient_model: HumanInputFormRecipient | None
) -> "HumanInputFormRecord":
return cls(
form_id=form_model.id,
workflow_run_id=form_model.workflow_run_id,
node_id=form_model.node_id,
tenant_id=form_model.tenant_id,
definition=FormDefinition.model_validate_json(form_model.form_definition),
rendered_content=form_model.rendered_content,
expiration_time=form_model.expiration_time,
selected_action_id=form_model.selected_action_id,
submitted_data=json.loads(form_model.submitted_data) if form_model.submitted_data else None,
submitted_at=form_model.submitted_at,
submission_user_id=form_model.submission_user_id,
submission_end_user_id=form_model.submission_end_user_id,
completed_by_recipient_id=form_model.completed_by_recipient_id,
recipient_id=recipient_model.id if recipient_model else None,
recipient_type=recipient_model.recipient_type if recipient_model else None,
access_token=recipient_model.access_token if recipient_model else None,
)
class HumanInputFormRepositoryImpl:
def __init__(
self,
@ -275,11 +323,74 @@ class HumanInputFormRepositoryImpl:
return _FormSubmissionImpl(form_model=form_model)
def get_form_by_token(self, token: str, recipient_type: RecipientType | None = None):
class HumanInputFormReadRepository:
"""Read/write repository for fetching and submitting human input forms."""
def __init__(self, session_factory: sessionmaker | Engine):
if isinstance(session_factory, Engine):
session_factory = sessionmaker(bind=session_factory)
self._session_factory = session_factory
def get_by_token(self, form_token: str) -> HumanInputFormRecord | None:
query = (
select(HumanInputFormRecipient)
.options(selectinload(HumanInputFormRecipient.form))
.where(HumanInputFormRecipient.access_token == form_token)
)
with self._session_factory(expire_on_commit=False) as session:
recipient_model = session.scalars(query).first()
if recipient_model is None or recipient_model.form is None:
return None
return HumanInputFormRecord.from_models(recipient_model.form, recipient_model)
def get_by_form_id_and_recipient_type(
self,
form_id: str,
recipient_type: RecipientType,
) -> HumanInputFormRecord | None:
query = (
select(HumanInputFormRecipient)
.options(selectinload(HumanInputFormRecipient.form))
.where(
HumanInputFormRecipient.form_id == form_id,
HumanInputFormRecipient.recipient_type == recipient_type,
)
)
with self._session_factory(expire_on_commit=False) as session:
recipient_model = session.scalars(query).first()
if recipient_model is None or recipient_model.form is None:
return None
return HumanInputFormRecord.from_models(recipient_model.form, recipient_model)
def mark_submitted(
self,
*,
form_id: str,
recipient_id: str | None,
selected_action_id: str,
form_data: Mapping[str, Any],
submission_user_id: str | None,
submission_end_user_id: str | None,
) -> HumanInputFormRecord:
with self._session_factory(expire_on_commit=False) as session, session.begin():
form_model = session.get(HumanInputForm, form_id)
if form_model is None:
raise FormNotFoundError(f"form not found, id={form_id}")
recipient_model = session.get(HumanInputFormRecipient, recipient_id) if recipient_id else None
form_model.selected_action_id = selected_action_id
form_model.submitted_data = json.dumps(form_data)
form_model.submitted_at = naive_utc_now()
form_model.submission_user_id = submission_user_id
form_model.submission_end_user_id = submission_end_user_id
form_model.completed_by_recipient_id = recipient_id
session.add(form_model)
session.flush()
session.refresh(form_model)
if recipient_model is not None:
session.refresh(recipient_model)
return HumanInputFormRecord.from_models(form_model, recipient_model)
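A brief usage sketch (an editor's addition, not part of this diff) of the new read repository, using only the methods introduced above; the engine, token, and form values are placeholders.

# Hypothetical usage; engine, token and submitted values are placeholders.
from extensions.ext_database import db

repo = HumanInputFormReadRepository(db.engine)

record = repo.get_by_token("web-app-form-token")
if record is not None and not record.submitted:
    updated = repo.mark_submitted(
        form_id=record.form_id,
        recipient_id=record.recipient_id,
        selected_action_id="approve",
        form_data={"comment": "looks good"},
        submission_user_id=None,
        submission_end_user_id="end-user-id",
    )
    print(updated.submitted_at)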

View File

@ -1,71 +0,0 @@
import abc
from collections.abc import Callable
from dataclasses import dataclass
from enum import StrEnum
from typing import Annotated, TypeAlias, final
from pydantic import BaseModel, ConfigDict, Discriminator, Tag, field_validator
from core.workflow.nodes.base import BaseNode
@dataclass(frozen=True)
class CommandParams:
# `next_node_instance` is the instance of the next node to run.
next_node: BaseNode
class _CommandTag(StrEnum):
SUSPEND = "suspend"
STOP = "stop"
CONTINUE = "continue"
# Note: Avoid using the `_Command` class directly.
# Instead, use `CommandTypes` for type annotations.
class _Command(BaseModel, abc.ABC):
model_config = ConfigDict(frozen=True)
tag: _CommandTag
@field_validator("tag")
@classmethod
def validate_value_type(cls, value):
if value != cls.model_fields["tag"].default:
raise ValueError("Cannot modify 'tag'")
return value
@final
class StopCommand(_Command):
tag: _CommandTag = _CommandTag.STOP
@final
class SuspendCommand(_Command):
tag: _CommandTag = _CommandTag.SUSPEND
@final
class ContinueCommand(_Command):
tag: _CommandTag = _CommandTag.CONTINUE
def _get_command_tag(command: _Command):
return command.tag
CommandTypes: TypeAlias = Annotated[
(
Annotated[StopCommand, Tag(_CommandTag.STOP)]
| Annotated[SuspendCommand, Tag(_CommandTag.SUSPEND)]
| Annotated[ContinueCommand, Tag(_CommandTag.CONTINUE)]
),
Discriminator(_get_command_tag),
]
# `CommandSource` is a callable that takes a single argument of type `CommandParams` and
# returns a `Command` object to the engine, indicating whether the graph engine should suspend, continue, or stop.
#
# It must not modify the data inside `CommandParams`, including any attributes within its fields.
CommandSource: TypeAlias = Callable[[CommandParams], CommandTypes]

View File

@ -1,8 +1,3 @@
"""
Human Input node implementation.
"""
from .entities import HumanInputNodeData
from .human_input_node import HumanInputNode
__all__ = ["HumanInputNode", "HumanInputNodeData"]

View File

@ -269,16 +269,6 @@ class HumanInputNodeData(BaseNodeData):
return variable_mappings
class HumanInputRequired(BaseModel):
"""Event data for human input required."""
form_id: str
node_id: str
form_content: str
inputs: list[FormInput]
web_app_form_token: Optional[str] = None
class FormDefinition(BaseModel):
form_content: str
inputs: list[FormInput] = Field(default_factory=list)

View File

@ -1,4 +1,3 @@
import dataclasses
import logging
from collections.abc import Generator, Mapping, Sequence
from typing import Any
@ -18,11 +17,6 @@ _SELECTED_BRANCH_KEY = "selected_branch"
logger = logging.getLogger(__name__)
@dataclasses.dataclass
class _FormSubmissionResult:
action_id: str
class HumanInputNode(Node[HumanInputNodeData]):
node_type = NodeType.HUMAN_INPUT
execution_type = NodeExecutionType.BRANCH

View File

@ -5,7 +5,7 @@ import json
from collections.abc import Mapping, Sequence
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Protocol
from typing import TYPE_CHECKING, Any, Protocol
from pydantic.json import pydantic_encoder
@ -13,6 +13,9 @@ from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.entities.pause_reason import PauseReason
from core.workflow.runtime.variable_pool import VariablePool
if TYPE_CHECKING:
from core.workflow.entities.pause_reason import PauseReason
class ReadyQueueProtocol(Protocol):
"""Structural interface required from ready queue implementations."""
@ -59,7 +62,7 @@ class GraphExecutionProtocol(Protocol):
aborted: bool
error: Exception | None
exceptions_count: int
pause_reasons: list[PauseReason]
pause_reasons: Sequence[PauseReason]
def start(self) -> None:
"""Transition execution into the running state."""

View File

@ -57,18 +57,19 @@ class HumanInputForm(DefaultFieldsMixin, Base):
completed_by_recipient_id: Mapped[str | None] = mapped_column(
StringUUID,
sa.ForeignKey("human_input_recipients.id"),
nullable=True,
)
deliveries: Mapped[list["HumanInputDelivery"]] = relationship(
"HumanInputDelivery",
primaryjoin="HumanInputForm.id == foreign(HumanInputDelivery.form_id)",
uselist=True,
back_populates="form",
lazy="raise",
)
completed_by_recipient: Mapped["HumanInputFormRecipient | None"] = relationship(
"HumanInputRecipient",
primaryjoin="HumanInputForm.completed_by_recipient_id == HumanInputRecipient.id",
"HumanInputFormRecipient",
primaryjoin="HumanInputForm.completed_by_recipient_id == foreign(HumanInputFormRecipient.id)",
lazy="raise",
viewonly=True,
)
@ -79,7 +80,6 @@ class HumanInputDelivery(DefaultFieldsMixin, Base):
form_id: Mapped[str] = mapped_column(
StringUUID,
sa.ForeignKey("human_input_forms.id"),
nullable=False,
)
delivery_method_type: Mapped[DeliveryMethodType] = mapped_column(
@ -91,11 +91,16 @@ class HumanInputDelivery(DefaultFieldsMixin, Base):
form: Mapped[HumanInputForm] = relationship(
"HumanInputForm",
uselist=False,
foreign_keys=[form_id],
primaryjoin="HumanInputDelivery.form_id == HumanInputForm.id",
back_populates="deliveries",
lazy="raise",
)
recipients: Mapped[list["HumanInputFormRecipient"]] = relationship(
"HumanInputRecipient",
"HumanInputFormRecipient",
primaryjoin="HumanInputDelivery.id == foreign(HumanInputFormRecipient.delivery_id)",
uselist=True,
back_populates="delivery",
# Require explicit preloading
@ -162,6 +167,7 @@ class HumanInputFormRecipient(DefaultFieldsMixin, Base):
uselist=False,
foreign_keys=[delivery_id],
back_populates="recipients",
primaryjoin="HumanInputFormRecipient.delivery_id == HumanInputDelivery.id",
# Require explicit preloading
lazy="raise",
)
@ -170,7 +176,7 @@ class HumanInputFormRecipient(DefaultFieldsMixin, Base):
"HumanInputForm",
uselist=False,
foreign_keys=[form_id],
back_populates="recipients",
primaryjoin="HumanInputFormRecipient.form_id == HumanInputForm.id",
# Require explicit preloading
lazy="raise",
)

View File

@ -6,38 +6,11 @@ by the core workflow module. These models are independent of the storage mechanism
and don't contain implementation details like tenant_id, app_id, etc.
"""
import enum
from abc import ABC, abstractmethod
from collections.abc import Sequence
from datetime import datetime
from typing import Annotated, Literal, TypeAlias
from pydantic import BaseModel, Field
from sqlalchemy import Sequence
class _PauseTypeEnum(enum.StrEnum):
human_input = enum.auto()
breakpoint = enum.auto()
scheduling = enum.auto()
class HumanInputPause(BaseModel):
type: Literal[_PauseTypeEnum.human_input] = _PauseTypeEnum.human_input
form_id: str
class SchedulingPause(BaseModel):
type: Literal[_PauseTypeEnum.scheduling] = _PauseTypeEnum.scheduling
PauseType: TypeAlias = Annotated[HumanInputPause | SchedulingPause, Field(discriminator="type")]
class PauseDetail(BaseModel):
node_id: str
node_title: str
pause_type: PauseType
from .pause_reason import PauseReason
from core.workflow.entities.pause_reason import PauseReason
@ -100,7 +73,5 @@ class WorkflowPauseEntity(ABC):
Returns a sequence of `PauseReason` objects describing the specific nodes and
reasons for which the workflow execution was paused.
This information is related to, but distinct from, the `PauseReason` type
defined in `api/core/workflow/entities/pause_reason.py`.
"""
...

View File

@ -505,7 +505,6 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
# Mark as resumed
pause_model.resumed_at = naive_utc_now()
workflow_run.pause_id = None # type: ignore
workflow_run.status = WorkflowExecutionStatus.RUNNING
session.add(pause_model)

View File

@ -84,3 +84,13 @@ class SQLAlchemyWorkflowTriggerLogRepository(WorkflowTriggerLogRepository):
)
return list(self.session.scalars(query).all())
def get_by_workflow_run_id(self, workflow_run_id: str) -> WorkflowTriggerLog | None:
"""Get the trigger log associated with a workflow run."""
query = (
select(WorkflowTriggerLog)
.where(WorkflowTriggerLog.workflow_run_id == workflow_run_id)
.order_by(WorkflowTriggerLog.created_at.desc())
.limit(1)
)
return self.session.scalar(query)

View File

@ -109,3 +109,15 @@ class WorkflowTriggerLogRepository(Protocol):
A sequence of recent WorkflowTriggerLog instances
"""
...
def get_by_workflow_run_id(self, workflow_run_id: str) -> WorkflowTriggerLog | None:
"""
Retrieve a trigger log associated with a specific workflow run.
Args:
workflow_run_id: Identifier of the workflow run
Returns:
The matching WorkflowTriggerLog if present, None otherwise
"""
...

View File

@ -14,12 +14,8 @@ from core.app.features.rate_limiting.rate_limit import rate_limit_context
from enums.quota_type import QuotaType, unlimited
from extensions.otel import AppGenerateHandler, trace_span
from models.model import Account, App, AppMode, EndUser
from models.workflow import Workflow
from services.errors.app import InvokeRateLimitError, QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError
from models.model import Account, App, AppMode, EndUser
from models.workflow import Workflow, WorkflowRun
from services.errors.app import WorkflowIdFormatError, WorkflowNotFoundError
from services.errors.llm import InvokeRateLimitError
from services.errors.app import InvokeRateLimitError, QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError
from services.workflow_service import WorkflowService
from tasks.app_generate.workflow_execute_task import ChatflowExecutionParams, chatflow_execute_task
@ -259,7 +255,7 @@ class AppGenerateService:
):
if workflow_run.status.is_ended():
# TODO(QuantumGhost): handle the ended scenario.
return
pass
generator = AdvancedChatAppGenerator()

View File

@ -1,29 +1,49 @@
import abc
import json
import logging
from collections.abc import Mapping
from typing import Any
from sqlalchemy import Engine, select
from sqlalchemy.orm import Session, selectinload, sessionmaker
from sqlalchemy import Engine
from sqlalchemy.orm import Session, sessionmaker
from core.repositories.human_input_reposotiry import HumanInputFormReadRepository, HumanInputFormRecord
from core.workflow.nodes.human_input.entities import FormDefinition
from libs.datetime_utils import naive_utc_now
from libs.exception import BaseHTTPException
from models.account import Account
from models.human_input import HumanInputForm, HumanInputFormRecipient, RecipientType
from models.human_input import RecipientType
from models.model import App, AppMode
from models.workflow import WorkflowRun
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
from services.workflow.entities import WorkflowResumeTaskData
from tasks.app_generate.workflow_execute_task import resume_chatflow_execution
from tasks.async_workflow_tasks import resume_workflow_execution
class Form:
def __init__(self, form_model: HumanInputForm):
self._form_model = form_model
def __init__(self, record: HumanInputFormRecord):
self._record = record
@abc.abstractmethod
def get_definition(self) -> FormDefinition:
pass
return self._record.definition
@abc.abstractmethod
@property
def submitted(self) -> bool:
pass
return self._record.submitted
@property
def id(self) -> str:
return self._record.form_id
@property
def workflow_run_id(self) -> str:
return self._record.workflow_run_id
@property
def recipient_id(self) -> str | None:
return self._record.recipient_id
@property
def recipient_type(self) -> RecipientType | None:
return self._record.recipient_type
class HumanInputError(Exception):
@ -49,93 +69,148 @@ class WebAppDeliveryNotEnabledError(HumanInputError, BaseException):
pass
logger = logging.getLogger(__name__)
class HumanInputService:
def __init__(
self,
session_factory: sessionmaker[Session] | Engine,
form_repository: HumanInputFormReadRepository | None = None,
):
if isinstance(session_factory, Engine):
session_factory = sessionmaker(bind=session_factory)
self._session_factory = session_factory
self._form_repository = form_repository or HumanInputFormReadRepository(session_factory)
def get_form_by_token(self, form_token: str) -> Form | None:
query = (
select(HumanInputFormRecipient)
.options(selectinload(HumanInputFormRecipient.form))
.where(HumanInputFormRecipient.access_token == form_token)
)
with self._session_factory(expire_on_commit=False) as session:
recipient = session.scalars(query).first()
if recipient is None:
record = self._form_repository.get_by_token(form_token)
if record is None:
return None
return Form(record)
return Form(recipient.form)
def get_form_by_id(self, form_id: str) -> Form | None:
query = select(HumanInputForm).where(HumanInputForm.id == form_id)
with self._session_factory(expire_on_commit=False) as session:
form_model = session.scalars(query).first()
if form_model is None:
return None
return Form(form_model)
def submit_form_by_id(self, form_id: str, user: Account, selected_action_id: str, form_data: Mapping[str, Any]):
recipient_query = (
select(HumanInputFormRecipient)
.options(selectinload(HumanInputFormRecipient.form))
.where(
HumanInputFormRecipient.recipient_type == RecipientType.WEBAPP,
HumanInputFormRecipient.form_id == form_id,
)
def get_form_by_id(self, form_id: str, recipient_type: RecipientType = RecipientType.WEBAPP) -> Form | None:
record = self._form_repository.get_by_form_id_and_recipient_type(
form_id=form_id,
recipient_type=recipient_type,
)
if record is None:
return None
return Form(record)
with self._session_factory(expire_on_commit=False) as session:
recipient_model = session.scalars(recipient_query).first()
def get_form_definition_by_id(self, form_id: str) -> Form | None:
form = self.get_form_by_id(form_id, recipient_type=RecipientType.WEBAPP)
if form is None:
return None
self._ensure_not_submitted(form)
return form
if recipient_model is None:
def get_form_definition_by_token(self, recipient_type: RecipientType, form_token: str) -> Form | None:
form = self.get_form_by_token(form_token)
if form is None or form.recipient_type != recipient_type:
return None
self._ensure_not_submitted(form)
return form
def submit_form_by_id(
self,
form_id: str,
selected_action_id: str,
form_data: Mapping[str, Any],
user: Account | None = None,
):
form = self.get_form_by_id(form_id, recipient_type=RecipientType.WEBAPP)
if form is None:
raise WebAppDeliveryNotEnabledError()
form_model = recipient_model.form
form = Form(form_model)
if form.submitted:
raise FormSubmittedError(form_model.id)
self._ensure_not_submitted(form)
with self._session_factory(expire_on_commit=False) as session, session.begin():
form_model.selected_action_id = selected_action_id
form_model.submitted_data = json.dumps(form_data)
form_model.submitted_at = naive_utc_now()
form_model.submission_user_id = user.id
form_model.completed_by_recipient_id = recipient_model.id
session.add(form_model)
# TODO: restart the execution of paused workflow
def submit_form_by_token(self, form_token: str, selected_action_id: str, form_data: Mapping[str, Any]):
recipient_query = (
select(HumanInputFormRecipient)
.options(selectinload(HumanInputFormRecipient.form))
.where(
HumanInputFormRecipient.form_id == form_token,
)
result = self._form_repository.mark_submitted(
form_id=form.id,
recipient_id=form.recipient_id,
selected_action_id=selected_action_id,
form_data=form_data,
submission_user_id=user.id if user else None,
submission_end_user_id=None,
)
with self._session_factory(expire_on_commit=False) as session:
recipient_model = session.scalars(recipient_query).first()
self._enqueue_resume(result.workflow_run_id)
if recipient_model is None:
def submit_form_by_token(
self,
recipient_type: RecipientType,
form_token: str,
selected_action_id: str,
form_data: Mapping[str, Any],
submission_end_user_id: str | None = None,
):
form = self.get_form_by_token(form_token)
if form is None or form.recipient_type != recipient_type:
raise WebAppDeliveryNotEnabledError()
form_model = recipient_model.form
form = Form(form_model)
self._ensure_not_submitted(form)
result = self._form_repository.mark_submitted(
form_id=form.id,
recipient_id=form.recipient_id,
selected_action_id=selected_action_id,
form_data=form_data,
submission_user_id=None,
submission_end_user_id=submission_end_user_id,
)
self._enqueue_resume(result.workflow_run_id)
def _ensure_not_submitted(self, form: Form) -> None:
if form.submitted:
raise FormSubmittedError(form_model.id)
raise FormSubmittedError(form.id)
with self._session_factory(expire_on_commit=False) as session, session.begin():
form_model.selected_action_id = selected_action_id
form_model.submitted_data = json.dumps(form_data)
form_model.submitted_at = naive_utc_now()
form_model.submission_user_id = user.id
def _enqueue_resume(self, workflow_run_id: str) -> None:
with self._session_factory(expire_on_commit=False) as session:
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
trigger_log = trigger_log_repo.get_by_workflow_run_id(workflow_run_id)
form_model.completed_by_recipient_id = recipient_model.id
session.add(form_model)
if trigger_log is not None:
payload = WorkflowResumeTaskData(
workflow_trigger_log_id=trigger_log.id,
workflow_run_id=workflow_run_id,
)
try:
resume_workflow_execution.apply_async(
kwargs={"task_data_dict": payload.model_dump()},
queue=trigger_log.queue_name,
)
except Exception: # pragma: no cover
logger.exception("Failed to enqueue resume task for workflow run %s", workflow_run_id)
return
if self._enqueue_chatflow_resume(workflow_run_id):
return
logger.warning("No workflow trigger log bound to workflow run %s; skipping resume dispatch", workflow_run_id)
def _enqueue_chatflow_resume(self, workflow_run_id: str) -> bool:
with self._session_factory(expire_on_commit=False) as session:
workflow_run = session.get(WorkflowRun, workflow_run_id)
if workflow_run is None:
return False
app = session.get(App, workflow_run.app_id)
if app is None:
return False
if app.mode != AppMode.ADVANCED_CHAT.value:
return False
try:
resume_chatflow_execution.apply_async(
kwargs={"payload": {"workflow_run_id": workflow_run_id}},
queue="chatflow_execute",
)
except Exception: # pragma: no cover
logger.exception("Failed to enqueue chatflow resume for workflow run %s", workflow_run_id)
return False
return True

View File

@ -98,6 +98,13 @@ class WorkflowTaskData(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
class WorkflowResumeTaskData(BaseModel):
"""Payload for workflow resumption tasks."""
workflow_trigger_log_id: str
workflow_run_id: str
class AsyncTriggerExecutionResult(BaseModel):
"""Result from async trigger-based workflow execution"""

View File

@ -5,7 +5,6 @@ from sqlalchemy import Engine
from sqlalchemy.orm import sessionmaker
import contexts
from core.workflow.entities.workflow_pause import PauseDetail
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models import (
@ -160,9 +159,3 @@ class WorkflowRunService:
app_id=app_model.id,
workflow_run_id=run_id,
)
def get_pause_details(self, workflow_run_id: str) -> Sequence[PauseDetail]:
pause = self._workflow_run_repo.get_workflow_pause(workflow_run_id)
if pause is None:
return []
return pause.get_pause_details()

View File

@ -949,7 +949,7 @@ class WorkflowService:
node_data = node.get("data", {})
node_type = node_data.get("type")
if node_type == "human_input":
if node_type == NodeType.HUMAN_INPUT:
self._validate_human_input_node_data(node_data)
def validate_features_structure(self, app_model: App, features: dict):

View File

@ -1,5 +1,3 @@
from .workflow_execute_task import chatflow_execute_task
__all__ = [
"chatflow_execute_taskch"
]
__all__ = ["chatflow_execute_task"]

View File

@ -8,18 +8,21 @@ from typing import Annotated, Any, TypeAlias, Union
from celery import shared_task
from flask import current_app, json
from pydantic import BaseModel, Discriminator, Field, Tag
from sqlalchemy import Engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy import Engine, select
from sqlalchemy.orm import Session, sessionmaker
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
from core.app.entities.app_invoke_entities import (
InvokeFrom,
)
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, WorkflowResumptionContext
from core.repositories import DifyCoreRepositoryFactory
from core.workflow.runtime import GraphRuntimeState
from extensions.ext_database import db
from libs.flask_utils import set_login_user
from models.account import Account
from models.model import App, AppMode, EndUser
from models.workflow import Workflow
from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
from models.model import App, AppMode, Conversation, EndUser, Message
from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom, WorkflowRun
from repositories.factory import DifyAPIRepositoryFactory
logger = logging.getLogger(__name__)
@ -129,6 +132,11 @@ class _ChatflowRunner:
workflow = session.get(Workflow, exec_params.workflow_id)
app = session.get(App, workflow.app_id)
pause_config = PauseStateLayerConfig(
session_factory=self._session_factory,
state_owner_user_id=workflow.created_by,
)
user = self._resolve_user()
chat_generator = AdvancedChatAppGenerator()
@ -144,6 +152,7 @@ class _ChatflowRunner:
invoke_from=exec_params.invoke_from,
streaming=exec_params.streaming,
workflow_run_id=workflow_run_id,
pause_state_config=pause_config,
)
if not exec_params.streaming:
return response
@ -174,11 +183,135 @@ class _ChatflowRunner:
return user
def _resolve_user_for_run(session: Session, workflow_run: WorkflowRun) -> Account | EndUser | None:
role = CreatorUserRole(workflow_run.created_by_role)
if role == CreatorUserRole.ACCOUNT:
user = session.get(Account, workflow_run.created_by)
if user:
user.set_tenant_id(workflow_run.tenant_id)
return user
return session.get(EndUser, workflow_run.created_by)
@shared_task(queue="chatflow_execute")
def chatflow_execute_task(payload: str) -> Mapping[str, Any] | None:
exec_params = ChatflowExecutionParams.model_validate_json(payload)
print("chatflow_execute_task run with params", exec_params)
logger.info("chatflow_execute_task run with params: %s", exec_params)
runner = _ChatflowRunner(db.engine, exec_params=exec_params)
return runner.run()
@shared_task(queue="chatflow_execute", name="resume_chatflow_execution")
def resume_chatflow_execution(payload: dict[str, Any]) -> None:
workflow_run_id = payload["workflow_run_id"]
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker=session_factory)
pause_entity = workflow_run_repo.get_workflow_pause(workflow_run_id)
if pause_entity is None:
logger.warning("No pause entity found for workflow run %s", workflow_run_id)
return
try:
resumption_context = WorkflowResumptionContext.loads(pause_entity.get_state().decode())
except Exception:
logger.exception("Failed to load resumption context for workflow run %s", workflow_run_id)
return
generate_entity = resumption_context.get_generate_entity()
if not isinstance(generate_entity, AdvancedChatAppGenerateEntity):
logger.error(
"Resumption entity is not AdvancedChatAppGenerateEntity for workflow run %s (found %s)",
workflow_run_id,
type(generate_entity),
)
return
graph_runtime_state = GraphRuntimeState.from_snapshot(resumption_context.serialized_graph_runtime_state)
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = session.get(WorkflowRun, workflow_run_id)
if workflow_run is None:
logger.warning("Workflow run %s not found during resume", workflow_run_id)
return
workflow = session.get(Workflow, workflow_run.workflow_id)
if workflow is None:
logger.warning("Workflow %s not found during resume", workflow_run.workflow_id)
return
app_model = session.get(App, workflow_run.app_id)
if app_model is None:
logger.warning("App %s not found during resume", workflow_run.app_id)
return
if generate_entity.conversation_id is None:
logger.warning("Conversation id missing in resumption context for workflow run %s", workflow_run_id)
return
conversation = session.get(Conversation, generate_entity.conversation_id)
if conversation is None:
logger.warning(
"Conversation %s not found for workflow run %s", generate_entity.conversation_id, workflow_run_id
)
return
message = session.scalar(
select(Message).where(Message.workflow_run_id == workflow_run_id).order_by(Message.created_at.desc())
)
if message is None:
logger.warning("Message not found for workflow run %s", workflow_run_id)
return
user = _resolve_user_for_run(session, workflow_run)
if user is None:
logger.warning("User %s not found for workflow run %s", workflow_run.created_by, workflow_run_id)
return
workflow_run_repo.resume_workflow_pause(workflow_run_id, pause_entity)
pause_config = PauseStateLayerConfig(
session_factory=session_factory,
state_owner_user_id=workflow.created_by,
)
try:
triggered_from = WorkflowRunTriggeredFrom(workflow_run.triggered_from)
except ValueError:
triggered_from = WorkflowRunTriggeredFrom.APP_RUN
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory,
user=user,
app_id=app_model.id,
triggered_from=triggered_from,
)
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=user,
app_id=app_model.id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
generator = AdvancedChatAppGenerator()
try:
generator.resume(
app_model=app_model,
workflow=workflow,
user=user,
conversation=conversation,
message=message,
application_generate_entity=generate_entity,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
graph_runtime_state=graph_runtime_state,
pause_state_config=pause_config,
)
except Exception:
logger.exception("Failed to resume chatflow execution for workflow run %s", workflow_run_id)
raise

View File

@ -5,6 +5,7 @@ These tasks handle workflow execution for different subscription tiers
with appropriate retry policies and error handling.
"""
import logging
from datetime import UTC, datetime
from typing import Any
@ -14,23 +15,37 @@ from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY, WorkflowAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, WorkflowResumptionContext
from core.app.layers.timeslice_layer import TimeSliceLayer
from core.app.layers.trigger_post_layer import TriggerPostLayer
from core.repositories import DifyCoreRepositoryFactory
from core.workflow.runtime import GraphRuntimeState
from extensions.ext_database import db
from models.account import Account
from models.enums import CreatorUserRole, WorkflowTriggerStatus
from models.enums import AppTriggerType, CreatorUserRole, WorkflowRunTriggeredFrom, WorkflowTriggerStatus
from models.model import App, EndUser, Tenant
from models.trigger import WorkflowTriggerLog
from models.workflow import Workflow
from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom
from repositories.factory import DifyAPIRepositoryFactory
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
from services.errors.app import WorkflowNotFoundError
from services.workflow.entities import (
TriggerData,
WorkflowResumeTaskData,
WorkflowTaskData,
)
from tasks.workflow_cfs_scheduler.cfs_scheduler import AsyncWorkflowCFSPlanEntity, AsyncWorkflowCFSPlanScheduler
from tasks.workflow_cfs_scheduler.entities import AsyncWorkflowQueue, AsyncWorkflowSystemStrategy
logger = logging.getLogger(__name__)
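# Maps the app trigger type recorded on the trigger log to the workflow-run source used when resuming.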
_TRIGGER_TO_RUN_SOURCE = {
AppTriggerType.TRIGGER_WEBHOOK: WorkflowRunTriggeredFrom.WEBHOOK,
AppTriggerType.TRIGGER_SCHEDULE: WorkflowRunTriggeredFrom.SCHEDULE,
AppTriggerType.TRIGGER_PLUGIN: WorkflowRunTriggeredFrom.PLUGIN,
}
@shared_task(queue=AsyncWorkflowQueue.PROFESSIONAL_QUEUE)
def execute_workflow_professional(task_data_dict: dict[str, Any]):
@ -144,6 +159,11 @@ def _execute_workflow_common(
if trigger_data.workflow_id:
args["workflow_id"] = str(trigger_data.workflow_id)
pause_config = PauseStateLayerConfig(
session_factory=session_factory,
state_owner_user_id=workflow.created_by,
)
# Execute the workflow with the trigger type
generator.generate(
app_model=app_model,
@ -159,6 +179,7 @@ def _execute_workflow_common(
# TODO: Re-enable TimeSliceLayer after the HITL release.
TriggerPostLayer(cfs_plan_scheduler_entity, start_time, trigger_log.id, session_factory),
],
pause_state_config=pause_config,
)
except Exception as e:
@ -176,6 +197,117 @@ def _execute_workflow_common(
session.commit()
@shared_task(name="resume_workflow_execution")
def resume_workflow_execution(task_data_dict: dict[str, Any]) -> None:
"""Resume a paused workflow run via Celery."""
task_data = WorkflowResumeTaskData.model_validate(task_data_dict)
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_factory)
with session_factory() as session:
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
trigger_log = trigger_log_repo.get_by_id(task_data.workflow_trigger_log_id)
if not trigger_log:
logger.warning("Trigger log not found for resumption: %s", task_data.workflow_trigger_log_id)
return
pause_entity = workflow_run_repo.get_workflow_pause(task_data.workflow_run_id)
if pause_entity is None:
logger.warning("No pause state for workflow run %s", task_data.workflow_run_id)
return
try:
resumption_context = WorkflowResumptionContext.loads(pause_entity.get_state().decode())
except Exception:
logger.exception("Failed to load resumption context for workflow run %s", task_data.workflow_run_id)
raise
generate_entity = resumption_context.get_generate_entity()
if not isinstance(generate_entity, WorkflowAppGenerateEntity):
logger.error(
"Unsupported resumption entity for workflow run %s: %s",
task_data.workflow_run_id,
type(generate_entity),
)
return
graph_runtime_state = GraphRuntimeState.from_snapshot(resumption_context.serialized_graph_runtime_state)
workflow = session.scalar(select(Workflow).where(Workflow.id == trigger_log.workflow_id))
if workflow is None:
raise WorkflowNotFoundError(f"Workflow not found: {trigger_log.workflow_id}")
app_model = session.scalar(select(App).where(App.id == trigger_log.app_id))
if app_model is None:
raise WorkflowNotFoundError(f"App not found: {trigger_log.app_id}")
user = _get_user(session, trigger_log)
cfs_plan_scheduler_entity = AsyncWorkflowCFSPlanEntity(
queue=trigger_log.queue_name,
schedule_strategy=AsyncWorkflowSystemStrategy,
granularity=dify_config.ASYNC_WORKFLOW_SCHEDULER_GRANULARITY,
)
cfs_plan_scheduler = AsyncWorkflowCFSPlanScheduler(plan=cfs_plan_scheduler_entity)
try:
trigger_type = AppTriggerType(trigger_log.trigger_type)
except ValueError:
trigger_type = AppTriggerType.UNKNOWN
triggered_from = _TRIGGER_TO_RUN_SOURCE.get(trigger_type, WorkflowRunTriggeredFrom.APP_RUN)
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory,
user=user,
app_id=generate_entity.app_config.app_id,
triggered_from=triggered_from,
)
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=user,
app_id=generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
pause_config = PauseStateLayerConfig(
session_factory=session_factory,
state_owner_user_id=workflow.created_by,
)
workflow_run_repo.resume_workflow_pause(task_data.workflow_run_id, pause_entity)
trigger_log.status = WorkflowTriggerStatus.RUNNING
trigger_log_repo.update(trigger_log)
session.commit()
generator = WorkflowAppGenerator()
start_time = datetime.now(UTC)
try:
generator.resume(
app_model=app_model,
workflow=workflow,
user=user,
application_generate_entity=generate_entity,
graph_runtime_state=graph_runtime_state,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
graph_engine_layers=[
TimeSliceLayer(cfs_plan_scheduler),
TriggerPostLayer(cfs_plan_scheduler_entity, start_time, trigger_log.id, session_factory),
],
pause_state_config=pause_config,
)
except Exception as exc:
trigger_log.status = WorkflowTriggerStatus.FAILED
trigger_log.error = str(exc)
trigger_log.finished_at = datetime.now(UTC)
trigger_log_repo.update(trigger_log)
session.commit()
raise
def _get_user(session: Session, trigger_log: WorkflowTriggerLog) -> Account | EndUser:
"""Compose user from trigger log"""
tenant = session.scalar(select(Tenant).where(Tenant.id == trigger_log.tenant_id))

View File

@ -0,0 +1,278 @@
import sys
import time
from pathlib import Path
from types import ModuleType, SimpleNamespace
from typing import Any
API_DIR = str(Path(__file__).resolve().parents[5])
if API_DIR not in sys.path:
sys.path.insert(0, API_DIR)
import core.workflow.nodes.human_input.entities # noqa: F401
from core.app.apps.advanced_chat import app_generator as adv_app_gen_module
from core.app.apps.workflow import app_generator as wf_app_gen_module
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities import GraphInitParams
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel
from core.workflow.graph_engine.entities.commands import PauseCommand
from core.workflow.graph_events import (
GraphEngineEvent,
GraphRunPausedEvent,
GraphRunSucceededEvent,
NodeRunSucceededEvent,
)
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig, VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.end.entities import EndNodeData
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
if "core.ops.ops_trace_manager" not in sys.modules:
ops_stub = ModuleType("core.ops.ops_trace_manager")
class _StubTraceQueueManager:
def __init__(self, *_, **__):
pass
ops_stub.TraceQueueManager = _StubTraceQueueManager
sys.modules["core.ops.ops_trace_manager"] = ops_stub
class _StubToolNodeData(BaseNodeData):
pass
class _StubToolNode(Node):
node_type = NodeType.TOOL
def init_node_data(self, data):
self._node_data = _StubToolNodeData.model_validate(data)
def _get_error_strategy(self):
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self):
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
def _run(self):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"value": f"{self.id}-done"},
)
def _patch_tool_node(mocker):
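# Register the stub tool node as the latest TOOL implementation so graphs can run without real tool execution.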
from core.workflow.nodes import node_factory
custom_mapping = dict(node_factory.NODE_TYPE_CLASSES_MAPPING)
custom_versions = dict(custom_mapping[NodeType.TOOL])
custom_versions[node_factory.LATEST_VERSION] = _StubToolNode
custom_mapping[NodeType.TOOL] = custom_versions
mocker.patch("core.workflow.nodes.node_factory.NODE_TYPE_CLASSES_MAPPING", custom_mapping)
def _build_graph(runtime_state: GraphRuntimeState) -> Graph:
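# Builds a graph rooted at the start node with tool_a -> tool_b -> tool_c -> end edges.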
params = GraphInitParams(
tenant_id="tenant",
app_id="app",
workflow_id="workflow",
graph_config={},
user_id="user",
user_from="account",
invoke_from="service-api",
call_depth=0,
)
start_data = StartNodeData(title="start", variables=[])
start_node = StartNode(
id="start",
config={"id": "start", "data": start_data.model_dump()},
graph_init_params=params,
graph_runtime_state=runtime_state,
)
start_node.init_node_data(start_data.model_dump())
tool_data = _StubToolNodeData(title="tool")
tool_a = _StubToolNode(
id="tool_a",
config={"id": "tool_a", "data": tool_data.model_dump()},
graph_init_params=params,
graph_runtime_state=runtime_state,
)
tool_a.init_node_data(tool_data.model_dump())
tool_b = _StubToolNode(
id="tool_b",
config={"id": "tool_b", "data": tool_data.model_dump()},
graph_init_params=params,
graph_runtime_state=runtime_state,
)
tool_b.init_node_data(tool_data.model_dump())
tool_c = _StubToolNode(
id="tool_c",
config={"id": "tool_c", "data": tool_data.model_dump()},
graph_init_params=params,
graph_runtime_state=runtime_state,
)
tool_c.init_node_data(tool_data.model_dump())
end_data = EndNodeData(
title="end",
outputs=[VariableSelector(variable="result", value_selector=["tool_c", "value"])],
desc=None,
)
end_node = EndNode(
id="end",
config={"id": "end", "data": end_data.model_dump()},
graph_init_params=params,
graph_runtime_state=runtime_state,
)
end_node.init_node_data(end_data.model_dump())
return (
Graph.new()
.add_root(start_node)
.add_node(tool_a)
.add_node(tool_b)
.add_node(tool_c)
.add_node(end_node)
.add_edge("tool_a", "tool_b")
.add_edge("tool_b", "tool_c")
.add_edge("tool_c", "end")
.build()
)
def _build_runtime_state(run_id: str) -> GraphRuntimeState:
variable_pool = VariablePool(
system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"),
user_inputs={},
conversation_variables=[],
)
variable_pool.system_variables.workflow_execution_id = run_id
return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
def _run_with_optional_pause(runtime_state: GraphRuntimeState, *, pause_on: str | None) -> list[GraphEngineEvent]:
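# Drives the engine to completion, optionally sending a PauseCommand right after the named node succeeds.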
command_channel = InMemoryChannel()
graph = _build_graph(runtime_state)
engine = GraphEngine(
workflow_id="workflow",
graph=graph,
graph_runtime_state=runtime_state,
command_channel=command_channel,
)
events: list[GraphEngineEvent] = []
for event in engine.run():
if isinstance(event, NodeRunSucceededEvent) and pause_on and event.node_id == pause_on:
command_channel.send_command(PauseCommand(reason="test pause"))
engine._command_processor.process_commands() # type: ignore[attr-defined]
events.append(event)
return events
def _node_successes(events: list[GraphEngineEvent]) -> list[str]:
return [evt.node_id for evt in events if isinstance(evt, NodeRunSucceededEvent)]
def test_workflow_app_pause_resume_matches_baseline(mocker):
_patch_tool_node(mocker)
baseline_state = _build_runtime_state("baseline")
baseline_events = _run_with_optional_pause(baseline_state, pause_on=None)
assert isinstance(baseline_events[-1], GraphRunSucceededEvent)
baseline_nodes = _node_successes(baseline_events)
baseline_outputs = baseline_state.outputs
paused_state = _build_runtime_state("paused-run")
paused_events = _run_with_optional_pause(paused_state, pause_on="tool_a")
assert isinstance(paused_events[-1], GraphRunPausedEvent)
paused_nodes = _node_successes(paused_events)
snapshot = paused_state.dumps()
resumed_state = GraphRuntimeState.from_snapshot(snapshot)
generator = wf_app_gen_module.WorkflowAppGenerator()
def _fake_generate(**kwargs):
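# Stand-in for _generate: finish the run on the resumed state and report which nodes succeeded.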
state: GraphRuntimeState = kwargs["graph_runtime_state"]
events = _run_with_optional_pause(state, pause_on=None)
return _node_successes(events)
mocker.patch.object(generator, "_generate", side_effect=_fake_generate)
resumed_nodes = generator.resume(
app_model=SimpleNamespace(mode="workflow"),
workflow=SimpleNamespace(),
user=SimpleNamespace(),
application_generate_entity=SimpleNamespace(stream=False, invoke_from=InvokeFrom.SERVICE_API),
graph_runtime_state=resumed_state,
workflow_execution_repository=SimpleNamespace(),
workflow_node_execution_repository=SimpleNamespace(),
)
assert paused_nodes + resumed_nodes == baseline_nodes
assert resumed_state.outputs == baseline_outputs
def test_advanced_chat_pause_resume_matches_baseline(mocker):
_patch_tool_node(mocker)
baseline_state = _build_runtime_state("adv-baseline")
baseline_events = _run_with_optional_pause(baseline_state, pause_on=None)
assert isinstance(baseline_events[-1], GraphRunSucceededEvent)
baseline_nodes = _node_successes(baseline_events)
baseline_outputs = baseline_state.outputs
paused_state = _build_runtime_state("adv-paused")
paused_events = _run_with_optional_pause(paused_state, pause_on="tool_a")
assert isinstance(paused_events[-1], GraphRunPausedEvent)
paused_nodes = _node_successes(paused_events)
snapshot = paused_state.dumps()
resumed_state = GraphRuntimeState.from_snapshot(snapshot)
generator = adv_app_gen_module.AdvancedChatAppGenerator()
def _fake_generate(**kwargs):
state: GraphRuntimeState = kwargs["graph_runtime_state"]
events = _run_with_optional_pause(state, pause_on=None)
return _node_successes(events)
mocker.patch.object(generator, "_generate", side_effect=_fake_generate)
resumed_nodes = generator.resume(
app_model=SimpleNamespace(mode="workflow"),
workflow=SimpleNamespace(),
user=SimpleNamespace(),
conversation=SimpleNamespace(id="conv"),
message=SimpleNamespace(id="msg"),
application_generate_entity=SimpleNamespace(stream=False, invoke_from=InvokeFrom.SERVICE_API),
workflow_execution_repository=SimpleNamespace(),
workflow_node_execution_repository=SimpleNamespace(),
graph_runtime_state=resumed_state,
)
assert paused_nodes + resumed_nodes == baseline_nodes
assert resumed_state.outputs == baseline_outputs

View File

@ -61,8 +61,8 @@ class ConcurrentPublisher:
messages.append(message)
if self.delay > 0:
time.sleep(self.delay)
except Exception as e:
_logger.error("Publisher %s error: %s", thread_id, e)
except Exception:
_logger.exception("Pubmsg=lisher %s", thread_id)
with self._lock:
self.published_messages.append(messages)
@ -308,8 +308,8 @@ def measure_throughput(
try:
operation()
count += 1
except Exception as e:
_logger.error("Operation failed: %s", e)
except Exception:
_logger.exception("Operation failed")
break
elapsed = time.time() - start_time

View File

@ -1,4 +1,26 @@
import sys
from types import ModuleType, SimpleNamespace
from unittest.mock import MagicMock
if "core.ops.ops_trace_manager" not in sys.modules:
stub_module = ModuleType("core.ops.ops_trace_manager")
class _StubTraceQueueManager:
def __init__(self, *_, **__):
pass
stub_module.TraceQueueManager = _StubTraceQueueManager
sys.modules["core.ops.ops_trace_manager"] = stub_module
from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY, WorkflowAppGenerator
from tests.unit_tests.core.workflow.graph_engine.test_pause_resume_state import (
_build_pausing_graph,
_build_runtime_state,
_node_successes,
_PausingNode,
_PausingNodeData,
_run_graph,
)
def test_should_prepare_user_inputs_defaults_to_true():
@ -17,3 +39,193 @@ def test_should_prepare_user_inputs_keeps_validation_when_flag_false():
args = {"inputs": {}, SKIP_PREPARE_USER_INPUTS_KEY: False}
assert WorkflowAppGenerator()._should_prepare_user_inputs(args)
def test_resume_delegates_to_generate(mocker):
generator = WorkflowAppGenerator()
mock_generate = mocker.patch.object(generator, "_generate", return_value="ok")
application_generate_entity = SimpleNamespace(stream=False, invoke_from="debugger")
runtime_state = MagicMock(name="runtime-state")
pause_config = MagicMock(name="pause-config")
result = generator.resume(
app_model=MagicMock(),
workflow=MagicMock(),
user=MagicMock(),
application_generate_entity=application_generate_entity,
graph_runtime_state=runtime_state,
workflow_execution_repository=MagicMock(),
workflow_node_execution_repository=MagicMock(),
graph_engine_layers=("layer",),
pause_state_config=pause_config,
variable_loader=MagicMock(),
)
assert result == "ok"
mock_generate.assert_called_once()
kwargs = mock_generate.call_args.kwargs
assert kwargs["graph_runtime_state"] is runtime_state
assert kwargs["pause_state_config"] is pause_config
assert kwargs["streaming"] is False
assert kwargs["invoke_from"] == "debugger"
def test_generate_appends_pause_layer_and_forwards_state(mocker):
generator = WorkflowAppGenerator()
mock_queue_manager = MagicMock()
mocker.patch("core.app.apps.workflow.app_generator.WorkflowAppQueueManager", return_value=mock_queue_manager)
fake_current_app = MagicMock()
fake_current_app._get_current_object.return_value = MagicMock()
mocker.patch("core.app.apps.workflow.app_generator.current_app", fake_current_app)
mocker.patch(
"core.app.apps.workflow.app_generator.WorkflowAppGenerateResponseConverter.convert",
return_value="converted",
)
mocker.patch.object(WorkflowAppGenerator, "_handle_response", return_value="response")
mocker.patch.object(WorkflowAppGenerator, "_get_draft_var_saver_factory", return_value=MagicMock())
pause_layer = MagicMock(name="pause-layer")
mocker.patch(
"core.app.apps.workflow.app_generator.PauseStatePersistenceLayer",
return_value=pause_layer,
)
dummy_session = MagicMock()
dummy_session.close = MagicMock()
mocker.patch("core.app.apps.workflow.app_generator.db.session", dummy_session)
worker_kwargs: dict[str, object] = {}
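# Capture the worker thread's target and kwargs instead of actually starting it.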
class DummyThread:
def __init__(self, target, kwargs):
worker_kwargs["target"] = target
worker_kwargs["kwargs"] = kwargs
def start(self):
return None
mocker.patch("core.app.apps.workflow.app_generator.threading.Thread", DummyThread)
app_model = SimpleNamespace(mode="workflow")
app_config = SimpleNamespace(app_id="app", tenant_id="tenant", workflow_id="wf")
application_generate_entity = SimpleNamespace(
task_id="task",
user_id="user",
invoke_from="service-api",
app_config=app_config,
files=[],
stream=True,
workflow_execution_id="run",
)
graph_runtime_state = MagicMock()
result = generator._generate(
app_model=app_model,
workflow=MagicMock(),
user=MagicMock(),
application_generate_entity=application_generate_entity,
invoke_from="service-api",
workflow_execution_repository=MagicMock(),
workflow_node_execution_repository=MagicMock(),
streaming=True,
graph_engine_layers=("base-layer",),
graph_runtime_state=graph_runtime_state,
pause_state_config=SimpleNamespace(session_factory=MagicMock(), state_owner_user_id="owner"),
)
assert result == "converted"
assert worker_kwargs["kwargs"]["graph_engine_layers"] == ("base-layer", pause_layer)
assert worker_kwargs["kwargs"]["graph_runtime_state"] is graph_runtime_state
def test_resume_path_runs_worker_with_runtime_state(mocker):
generator = WorkflowAppGenerator()
runtime_state = MagicMock(name="runtime-state")
pause_layer = MagicMock(name="pause-layer")
mocker.patch("core.app.apps.workflow.app_generator.PauseStatePersistenceLayer", return_value=pause_layer)
queue_manager = MagicMock()
mocker.patch("core.app.apps.workflow.app_generator.WorkflowAppQueueManager", return_value=queue_manager)
mocker.patch.object(generator, "_handle_response", return_value="raw-response")
mocker.patch(
"core.app.apps.workflow.app_generator.WorkflowAppGenerateResponseConverter.convert",
side_effect=lambda response, invoke_from: response,
)
fake_db = SimpleNamespace(session=MagicMock(), engine=MagicMock())
mocker.patch("core.app.apps.workflow.app_generator.db", fake_db)
workflow = SimpleNamespace(
id="workflow", tenant_id="tenant", app_id="app", graph_dict={}, type="workflow", version="1"
)
end_user = SimpleNamespace(session_id="end-user-session")
app_record = SimpleNamespace(id="app")
session = MagicMock()
session.__enter__.return_value = session
session.__exit__.return_value = False
session.scalar.side_effect = [workflow, end_user, app_record]
mocker.patch("core.app.apps.workflow.app_generator.Session", return_value=session)
runner_instance = MagicMock()
def runner_ctor(**kwargs):
assert kwargs["graph_runtime_state"] is runtime_state
return runner_instance
mocker.patch("core.app.apps.workflow.app_generator.WorkflowAppRunner", side_effect=runner_ctor)
class ImmediateThread:
def __init__(self, target, kwargs):
target(**kwargs)
def start(self):
return None
mocker.patch("core.app.apps.workflow.app_generator.threading.Thread", ImmediateThread)
mocker.patch(
"core.app.apps.workflow.app_generator.DifyCoreRepositoryFactory.create_workflow_execution_repository",
return_value=MagicMock(),
)
mocker.patch(
"core.app.apps.workflow.app_generator.DifyCoreRepositoryFactory.create_workflow_node_execution_repository",
return_value=MagicMock(),
)
pause_config = SimpleNamespace(session_factory=MagicMock(), state_owner_user_id="owner")
app_model = SimpleNamespace(mode="workflow")
app_config = SimpleNamespace(app_id="app", tenant_id="tenant", workflow_id="workflow")
application_generate_entity = SimpleNamespace(
task_id="task",
user_id="user",
invoke_from="service-api",
app_config=app_config,
files=[],
stream=True,
workflow_execution_id="run",
trace_manager=MagicMock(),
)
result = generator.resume(
app_model=app_model,
workflow=workflow,
user=MagicMock(),
application_generate_entity=application_generate_entity,
graph_runtime_state=runtime_state,
workflow_execution_repository=MagicMock(),
workflow_node_execution_repository=MagicMock(),
pause_state_config=pause_config,
)
assert result == "raw-response"
runner_instance.run.assert_called_once()
assert queue_manager.graph_runtime_state is runtime_state

View File

@ -1,317 +0,0 @@
"""
Tests for HumanInputForm domain model and repository.
"""
import json
from datetime import datetime
from unittest.mock import MagicMock, patch
import pytest
from core.repositories.sqlalchemy_human_input_form_repository import SQLAlchemyHumanInputFormRepository
from core.workflow.entities.human_input_form import HumanInputForm, HumanInputFormStatus
class TestHumanInputForm:
"""Test cases for HumanInputForm domain model."""
def test_create_form(self):
"""Test creating a new form."""
form = HumanInputForm.create(
id_="test-form-id",
workflow_run_id="test-workflow-run",
form_definition={"inputs": [], "user_actions": [{"id": "submit", "title": "Submit"}]},
rendered_content="<form>Test form</form>",
)
assert form.id_ == "test-form-id"
assert form.workflow_run_id == "test-workflow-run"
assert form.status == HumanInputFormStatus.WAITING
assert form.can_be_submitted
assert not form.is_submitted
assert not form.is_expired
assert form.is_waiting
def test_submit_form(self):
"""Test submitting a form."""
form = HumanInputForm.create(
id_="test-form-id",
workflow_run_id="test-workflow-run",
form_definition={"inputs": [], "user_actions": [{"id": "submit", "title": "Submit"}]},
rendered_content="<form>Test form</form>",
)
form.submit(
data={"field1": "value1"},
action="submit",
submission_user_id="user123",
)
assert form.is_submitted
assert not form.can_be_submitted
assert form.status == HumanInputFormStatus.SUBMITTED
assert form.submission is not None
assert form.submission.data == {"field1": "value1"}
assert form.submission.action == "submit"
assert form.submission.submission_user_id == "user123"
def test_submit_form_invalid_action(self):
"""Test submitting a form with invalid action."""
form = HumanInputForm.create(
id_="test-form-id",
workflow_run_id="test-workflow-run",
form_definition={"inputs": [], "user_actions": [{"id": "submit", "title": "Submit"}]},
rendered_content="<form>Test form</form>",
)
with pytest.raises(ValueError, match="Invalid action: invalid_action"):
form.submit(data={}, action="invalid_action")
def test_submit_expired_form(self):
"""Test submitting an expired form should fail."""
form = HumanInputForm.create(
id_="test-form-id",
workflow_run_id="test-workflow-run",
form_definition={"inputs": [], "user_actions": [{"id": "submit", "title": "Submit"}]},
rendered_content="<form>Test form</form>",
)
form.expire()
with pytest.raises(ValueError, match="Form cannot be submitted in status: expired"):
form.submit(data={}, action="submit")
def test_expire_form(self):
"""Test expiring a form."""
form = HumanInputForm.create(
id_="test-form-id",
workflow_run_id="test-workflow-run",
form_definition={"inputs": [], "user_actions": [{"id": "submit", "title": "Submit"}]},
rendered_content="<form>Test form</form>",
)
form.expire()
assert form.is_expired
assert not form.can_be_submitted
assert form.status == HumanInputFormStatus.EXPIRED
def test_expire_already_submitted_form(self):
"""Test expiring an already submitted form should fail."""
form = HumanInputForm.create(
id_="test-form-id",
workflow_run_id="test-workflow-run",
form_definition={"inputs": [], "user_actions": [{"id": "submit", "title": "Submit"}]},
rendered_content="<form>Test form</form>",
)
form.submit(data={}, action="submit")
with pytest.raises(ValueError, match="Form cannot be expired in status: submitted"):
form.expire()
def test_get_form_definition_for_display(self):
"""Test getting form definition for display."""
form_definition = {
"inputs": [{"type": "text", "name": "field1"}],
"user_actions": [{"id": "submit", "title": "Submit"}],
}
form = HumanInputForm.create(
id_="test-form-id",
workflow_run_id="test-workflow-run",
form_definition=form_definition,
rendered_content="<form>Test form</form>",
)
result = form.get_form_definition_for_display()
assert result["form_content"] == "<form>Test form</form>"
assert result["inputs"] == form_definition["inputs"]
assert result["user_actions"] == form_definition["user_actions"]
assert "site" not in result
def test_get_form_definition_for_display_with_site_info(self):
"""Test getting form definition for display with site info."""
form = HumanInputForm.create(
id_="test-form-id",
workflow_run_id="test-workflow-run",
form_definition={"inputs": [], "user_actions": []},
rendered_content="<form>Test form</form>",
)
result = form.get_form_definition_for_display(include_site_info=True)
assert "site" in result
assert result["site"]["title"] == "Workflow Form"
def test_get_form_definition_expired_form(self):
"""Test getting form definition for expired form should fail."""
form = HumanInputForm.create(
id_="test-form-id",
workflow_run_id="test-workflow-run",
form_definition={"inputs": [], "user_actions": []},
rendered_content="<form>Test form</form>",
)
form.expire()
with pytest.raises(ValueError, match="Form has expired"):
form.get_form_definition_for_display()
def test_get_form_definition_submitted_form(self):
"""Test getting form definition for submitted form should fail."""
form = HumanInputForm.create(
id_="test-form-id",
workflow_run_id="test-workflow-run",
form_definition={"inputs": [], "user_actions": [{"id": "submit", "title": "Submit"}]},
rendered_content="<form>Test form</form>",
)
form.submit(data={}, action="submit")
with pytest.raises(ValueError, match="Form has already been submitted"):
form.get_form_definition_for_display()
class TestSQLAlchemyHumanInputFormRepository:
"""Test cases for SQLAlchemyHumanInputFormRepository."""
@pytest.fixture
def mock_session_factory(self):
"""Create a mock session factory."""
session = MagicMock()
session_factory = MagicMock()
session_factory.return_value.__enter__.return_value = session
session_factory.return_value.__exit__.return_value = None
return session_factory
@pytest.fixture
def mock_user(self):
"""Create a mock user."""
user = MagicMock()
user.current_tenant_id = "test-tenant-id"
user.id = "test-user-id"
return user
@pytest.fixture
def repository(self, mock_session_factory, mock_user):
"""Create a repository instance."""
return SQLAlchemyHumanInputFormRepository(
session_factory=mock_session_factory,
user=mock_user,
app_id="test-app-id",
)
def test_to_domain_model(self, repository):
"""Test converting DB model to domain model."""
from models.human_input import (
HumanInputForm as DBForm,
)
from models.human_input import (
HumanInputFormStatus as DBStatus,
)
from models.human_input import (
HumanInputSubmissionType as DBSubmissionType,
)
db_form = DBForm()
db_form.id = "test-id"
db_form.workflow_run_id = "test-workflow"
db_form.form_definition = json.dumps({"inputs": [], "user_actions": []})
db_form.rendered_content = "<form>Test</form>"
db_form.status = DBStatus.WAITING
db_form.web_app_token = "test-token"
db_form.created_at = datetime.utcnow()
db_form.submitted_data = json.dumps({"field": "value"})
db_form.submitted_at = datetime.utcnow()
db_form.submission_type = DBSubmissionType.web_form
db_form.submission_user_id = "user123"
domain_form = repository._to_domain_model(db_form)
assert domain_form.id_ == "test-id"
assert domain_form.workflow_run_id == "test-workflow"
assert domain_form.form_definition == {"inputs": [], "user_actions": []}
assert domain_form.rendered_content == "<form>Test</form>"
assert domain_form.status == HumanInputFormStatus.WAITING
assert domain_form.web_app_token == "test-token"
assert domain_form.submission is not None
assert domain_form.submission.data == {"field": "value"}
assert domain_form.submission.submission_user_id == "user123"
def test_to_db_model(self, repository):
"""Test converting domain model to DB model."""
from models.human_input import (
HumanInputFormStatus as DBStatus,
)
domain_form = HumanInputForm.create(
id_="test-id",
workflow_run_id="test-workflow",
form_definition={"inputs": [], "user_actions": []},
rendered_content="<form>Test</form>",
web_app_token="test-token",
)
db_form = repository._to_db_model(domain_form)
assert db_form.id == "test-id"
assert db_form.tenant_id == "test-tenant-id"
assert db_form.app_id == "test-app-id"
assert db_form.workflow_run_id == "test-workflow"
assert json.loads(db_form.form_definition) == {"inputs": [], "user_actions": []}
assert db_form.rendered_content == "<form>Test</form>"
assert db_form.status == DBStatus.WAITING
assert db_form.web_app_token == "test-token"
def test_save(self, repository, mock_session_factory):
"""Test saving a form."""
session = mock_session_factory.return_value.__enter__.return_value
domain_form = HumanInputForm.create(
id_="test-id",
workflow_run_id="test-workflow",
form_definition={"inputs": []},
rendered_content="<form>Test</form>",
)
repository.save(domain_form)
session.merge.assert_called_once()
session.commit.assert_called_once()
def test_get_by_id(self, repository, mock_session_factory):
"""Test getting a form by ID."""
session = mock_session_factory.return_value.__enter__.return_value
mock_db_form = MagicMock()
mock_db_form.id = "test-id"
session.scalar.return_value = mock_db_form
with patch.object(repository, "_to_domain_model") as mock_convert:
domain_form = HumanInputForm.create(
id_="test-id",
workflow_run_id="test-workflow",
form_definition={"inputs": []},
rendered_content="<form>Test</form>",
)
mock_convert.return_value = domain_form
result = repository.get_by_id("test-id")
assert result == domain_form
session.scalar.assert_called_once()
mock_convert.assert_called_once_with(mock_db_form)
def test_get_by_id_not_found(self, repository, mock_session_factory):
"""Test getting a non-existent form by ID."""
session = mock_session_factory.return_value.__enter__.return_value
session.scalar.return_value = None
with pytest.raises(ValueError, match="Human input form not found: test-id"):
repository.get_by_id("test-id")
def test_mark_expired_forms(self, repository, mock_session_factory):
"""Test marking expired forms."""
session = mock_session_factory.return_value.__enter__.return_value
mock_forms = [MagicMock(), MagicMock(), MagicMock()]
session.scalars.return_value.all.return_value = mock_forms
result = repository.mark_expired_forms(expiry_hours=24)
assert result == 3
for form in mock_forms:
assert hasattr(form, "status")
session.commit.assert_called_once()

View File

@ -2,16 +2,27 @@
from __future__ import annotations
import dataclasses
from datetime import datetime
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from core.repositories.human_input_reposotiry import (
HumanInputFormReadRepository,
HumanInputFormRecord,
HumanInputFormRepositoryImpl,
_WorkspaceMemberInfo,
)
from core.workflow.nodes.human_input.entities import ExternalRecipient, MemberRecipient
from core.workflow.nodes.human_input.entities import (
ExternalRecipient,
FormDefinition,
MemberRecipient,
TimeoutUnit,
UserAction,
)
from libs.datetime_utils import naive_utc_now
from models.human_input import (
EmailExternalRecipientPayload,
EmailMemberRecipientPayload,
@ -41,6 +52,23 @@ def _patch_recipient_factory(monkeypatch: pytest.MonkeyPatch) -> list[SimpleName
return created
@pytest.fixture(autouse=True)
def _stub_selectinload(monkeypatch: pytest.MonkeyPatch) -> None:
"""Avoid SQLAlchemy mapper configuration in tests using fake sessions."""
class _FakeSelect:
def options(self, *_args, **_kwargs): # type: ignore[no-untyped-def]
return self
def where(self, *_args, **_kwargs): # type: ignore[no-untyped-def]
return self
monkeypatch.setattr(
"core.repositories.human_input_reposotiry.selectinload", lambda *args, **kwargs: "_loader_option"
)
monkeypatch.setattr("core.repositories.human_input_reposotiry.select", lambda *args, **kwargs: _FakeSelect())
class TestHumanInputFormRepositoryImplHelpers:
def test_create_email_recipients_with_member_and_external(self, monkeypatch: pytest.MonkeyPatch) -> None:
repo = _build_repository()
@ -125,3 +153,201 @@ class TestHumanInputFormRepositoryImplHelpers:
assert len(recipients) == 2
emails = {EmailMemberRecipientPayload.model_validate_json(r.recipient_payload).email for r in recipients}
assert emails == {"member1@example.com", "member2@example.com"}
def _make_form_definition() -> str:
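# Minimal serialized FormDefinition shared by the dummy forms below.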
return FormDefinition(
form_content="hello",
inputs=[],
user_actions=[UserAction(id="submit", title="Submit")],
rendered_content="<p>hello</p>",
timeout=1,
timeout_unit=TimeoutUnit.HOUR,
).model_dump_json()
@dataclasses.dataclass
class _DummyForm:
id: str
workflow_run_id: str
node_id: str
tenant_id: str
form_definition: str
rendered_content: str
expiration_time: datetime
selected_action_id: str | None = None
submitted_data: str | None = None
submitted_at: datetime | None = None
submission_user_id: str | None = None
submission_end_user_id: str | None = None
completed_by_recipient_id: str | None = None
@dataclasses.dataclass
class _DummyRecipient:
id: str
form_id: str
recipient_type: RecipientType
access_token: str
form: _DummyForm | None = None
class _FakeScalarResult:
def __init__(self, obj):
self._obj = obj
def first(self):
return self._obj
class _FakeSession:
def __init__(
self,
*,
scalars_result=None,
forms: dict[str, _DummyForm] | None = None,
recipients: dict[str, _DummyRecipient] | None = None,
):
self._scalars_result = scalars_result
self.forms = forms or {}
self.recipients = recipients or {}
def scalars(self, _query):
return _FakeScalarResult(self._scalars_result)
def get(self, model_cls, obj_id): # type: ignore[no-untyped-def]
if getattr(model_cls, "__name__", None) == "HumanInputForm":
return self.forms.get(obj_id)
if getattr(model_cls, "__name__", None) == "HumanInputFormRecipient":
return self.recipients.get(obj_id)
return None
def add(self, _obj):
return None
def flush(self):
return None
def refresh(self, _obj):
return None
def begin(self):
return self
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return None
def _session_factory(session: _FakeSession):
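# Wraps the fake session in a context-manager factory matching the repository's sessionmaker usage.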
class _SessionContext:
def __enter__(self):
return session
def __exit__(self, exc_type, exc, tb):
return None
def _factory(*_args, **_kwargs):
return _SessionContext()
return _factory
class TestHumanInputFormReadRepository:
def test_get_by_token_returns_record(self):
form = _DummyForm(
id="form-1",
workflow_run_id="run-1",
node_id="node-1",
tenant_id="tenant-1",
form_definition=_make_form_definition(),
rendered_content="<p>hello</p>",
expiration_time=naive_utc_now(),
)
recipient = _DummyRecipient(
id="recipient-1",
form_id=form.id,
recipient_type=RecipientType.WEBAPP,
access_token="token-123",
form=form,
)
session = _FakeSession(scalars_result=recipient)
repo = HumanInputFormReadRepository(_session_factory(session))
record = repo.get_by_token("token-123")
assert record is not None
assert record.form_id == form.id
assert record.recipient_type == RecipientType.WEBAPP
assert record.submitted is False
def test_get_by_form_id_and_recipient_type_uses_recipient(self):
form = _DummyForm(
id="form-1",
workflow_run_id="run-1",
node_id="node-1",
tenant_id="tenant-1",
form_definition=_make_form_definition(),
rendered_content="<p>hello</p>",
expiration_time=naive_utc_now(),
)
recipient = _DummyRecipient(
id="recipient-1",
form_id=form.id,
recipient_type=RecipientType.WEBAPP,
access_token="token-123",
form=form,
)
session = _FakeSession(scalars_result=recipient)
repo = HumanInputFormReadRepository(_session_factory(session))
record = repo.get_by_form_id_and_recipient_type(form_id=form.id, recipient_type=RecipientType.WEBAPP)
assert record is not None
assert record.recipient_id == recipient.id
assert record.access_token == recipient.access_token
def test_mark_submitted_updates_fields(self, monkeypatch: pytest.MonkeyPatch):
fixed_now = datetime(2024, 1, 1, 0, 0, 0)
monkeypatch.setattr("core.repositories.human_input_reposotiry.naive_utc_now", lambda: fixed_now)
form = _DummyForm(
id="form-1",
workflow_run_id="run-1",
node_id="node-1",
tenant_id="tenant-1",
form_definition=_make_form_definition(),
rendered_content="<p>hello</p>",
expiration_time=fixed_now,
)
recipient = _DummyRecipient(
id="recipient-1",
form_id="form-1",
recipient_type=RecipientType.WEBAPP,
access_token="token-123",
)
session = _FakeSession(
forms={form.id: form},
recipients={recipient.id: recipient},
)
repo = HumanInputFormReadRepository(_session_factory(session))
record: HumanInputFormRecord = repo.mark_submitted(
form_id=form.id,
recipient_id=recipient.id,
selected_action_id="approve",
form_data={"field": "value"},
submission_user_id="user-1",
submission_end_user_id="end-user-1",
)
assert form.selected_action_id == "approve"
assert form.completed_by_recipient_id == recipient.id
assert form.submission_user_id == "user-1"
assert form.submission_end_user_id == "end-user-1"
assert form.submitted_at == fixed_now
assert record.submitted is True
assert record.selected_action_id == "approve"
assert record.submitted_data == {"field": "value"}

View File

@ -28,10 +28,11 @@ class TestPauseReasonDiscriminator:
{
"reason": {
"TYPE": "human_input_required",
"human_input_form_id": "form_id",
"form_id": "form_id",
"form_content": "form_content",
},
},
HumanInputRequired(human_input_form_id="form_id"),
HumanInputRequired(form_id="form_id", form_content="form_content"),
id="HumanInputRequired",
),
pytest.param(
@ -56,7 +57,7 @@ class TestPauseReasonDiscriminator:
@pytest.mark.parametrize(
"reason",
[
HumanInputRequired(human_input_form_id="form_id"),
HumanInputRequired(form_id="form_id", form_content="form_content"),
SchedulingPause(message="Hold on"),
],
ids=lambda x: type(x).__name__,

View File

@ -764,203 +764,3 @@ def test_graph_run_emits_partial_success_when_node_failure_recovered():
assert partial_event.outputs.get("answer") == "fallback response"
assert not any(isinstance(event, GraphRunSucceededEvent) for event in events)
def test_suspend_and_resume():
graph_config = {
"edges": [
{
"data": {"isInLoop": False, "sourceType": "start", "targetType": "if-else"},
"id": "1753041723554-source-1753041730748-target",
"source": "1753041723554",
"sourceHandle": "source",
"target": "1753041730748",
"targetHandle": "target",
"type": "custom",
"zIndex": 0,
},
{
"data": {"isInLoop": False, "sourceType": "if-else", "targetType": "answer"},
"id": "1753041730748-true-answer-target",
"source": "1753041730748",
"sourceHandle": "true",
"target": "answer",
"targetHandle": "target",
"type": "custom",
"zIndex": 0,
},
{
"data": {
"isInIteration": False,
"isInLoop": False,
"sourceType": "if-else",
"targetType": "answer",
},
"id": "1753041730748-false-1753041952799-target",
"source": "1753041730748",
"sourceHandle": "false",
"target": "1753041952799",
"targetHandle": "target",
"type": "custom",
"zIndex": 0,
},
],
"nodes": [
{
"data": {"desc": "", "selected": False, "title": "Start", "type": "start", "variables": []},
"height": 54,
"id": "1753041723554",
"position": {"x": 32, "y": 282},
"positionAbsolute": {"x": 32, "y": 282},
"selected": False,
"sourcePosition": "right",
"targetPosition": "left",
"type": "custom",
"width": 244,
},
{
"data": {
"cases": [
{
"case_id": "true",
"conditions": [
{
"comparison_operator": "contains",
"id": "5db4103a-7e62-4e71-a0a6-c45ac11c0b3d",
"value": "a",
"varType": "string",
"variable_selector": ["sys", "query"],
}
],
"id": "true",
"logical_operator": "and",
}
],
"desc": "",
"selected": False,
"title": "IF/ELSE",
"type": "if-else",
},
"height": 126,
"id": "1753041730748",
"position": {"x": 368, "y": 282},
"positionAbsolute": {"x": 368, "y": 282},
"selected": False,
"sourcePosition": "right",
"targetPosition": "left",
"type": "custom",
"width": 244,
},
{
"data": {
"answer": "A",
"desc": "",
"selected": False,
"title": "Answer A",
"type": "answer",
"variables": [],
},
"height": 102,
"id": "answer",
"position": {"x": 746, "y": 282},
"positionAbsolute": {"x": 746, "y": 282},
"selected": False,
"sourcePosition": "right",
"targetPosition": "left",
"type": "custom",
"width": 244,
},
{
"data": {
"answer": "Else",
"desc": "",
"selected": False,
"title": "Answer Else",
"type": "answer",
"variables": [],
},
"height": 102,
"id": "1753041952799",
"position": {"x": 746, "y": 426},
"positionAbsolute": {"x": 746, "y": 426},
"selected": True,
"sourcePosition": "right",
"targetPosition": "left",
"type": "custom",
"width": 244,
},
],
"viewport": {"x": -420, "y": -76.5, "zoom": 1},
}
graph = Graph.init(graph_config)
variable_pool = VariablePool(
system_variables=SystemVariable(
user_id="aaa",
files=[],
query="hello",
conversation_id="abababa",
),
user_inputs={"uid": "takato"},
)
graph_runtime_state = GraphRuntimeState(
variable_pool=variable_pool,
)
graph_init_params = GraphInitParams(
tenant_id="111",
app_id="222",
workflow_type=WorkflowType.CHAT,
workflow_id="333",
graph_config=graph_config,
user_id="444",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.WEB_APP,
call_depth=0,
max_execution_steps=500,
max_execution_time=1200,
)
_IF_ELSE_NODE_ID = "1753041730748"
def command_source(params: CommandParams) -> CommandTypes:
# requires the engine to suspend before the execution
# of If-Else node.
if params.next_node.node_id == _IF_ELSE_NODE_ID:
return SuspendCommand()
else:
return ContinueCommand()
graph_engine = GraphEngine(
graph=graph,
graph_runtime_state=graph_runtime_state,
graph_init_params=graph_init_params,
command_source=command_source,
)
events = list(graph_engine.run())
last_event = events[-1]
assert isinstance(last_event, GraphRunSuspendedEvent)
assert last_event.current_node_id == _IF_ELSE_NODE_ID
state = graph_engine.save()
assert state != ""
engine2 = GraphEngine.resume(
state=state,
graph=graph,
)
events = list(engine2.run())
assert isinstance(events[-1], GraphRunSucceededEvent)
node_run_succeeded_events = [i for i in events if isinstance(i, NodeRunSucceededEvent)]
assert node_run_succeeded_events
start_events = [i for i in node_run_succeeded_events if i.node_id == "1753041723554"]
assert not start_events
ifelse_succeeded_events = [i for i in node_run_succeeded_events if i.node_id == _IF_ELSE_NODE_ID]
assert ifelse_succeeded_events
answer_else_events = [i for i in node_run_succeeded_events if i.node_id == "1753041952799"]
assert answer_else_events
assert answer_else_events[0].route_node_state.node_run_result.outputs == {
"answer": "Else",
"files": ArrayFileSegment(value=[]),
}
answer_a_events = [i for i in node_run_succeeded_events if i.node_id == "answer"]
assert not answer_a_events

View File

@ -17,8 +17,8 @@ from core.workflow.graph_events import (
from core.workflow.nodes.base.entities import OutputVariableEntity, OutputVariableType
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.end.entities import EndNodeData
from core.workflow.nodes.human_input import HumanInputNode
from core.workflow.nodes.human_input.entities import HumanInputNodeData
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
from core.workflow.nodes.llm.entities import (
ContextConfig,
LLMNodeChatModelMessage,

View File

@ -16,8 +16,8 @@ from core.workflow.graph_events import (
from core.workflow.nodes.base.entities import OutputVariableEntity, OutputVariableType
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.end.entities import EndNodeData
from core.workflow.nodes.human_input import HumanInputNode
from core.workflow.nodes.human_input.entities import HumanInputNodeData
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
from core.workflow.nodes.llm.entities import (
ContextConfig,
LLMNodeChatModelMessage,

View File

@ -0,0 +1,185 @@
import time
from collections.abc import Generator, Mapping
from typing import Any
import core.workflow.nodes.human_input.entities # noqa: F401
from core.workflow.entities import GraphInitParams
from core.workflow.entities.pause_reason import SchedulingPause
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel
from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.graph_events import (
GraphEngineEvent,
GraphRunPausedEvent,
GraphRunSucceededEvent,
NodeRunSucceededEvent,
)
from core.workflow.node_events import NodeEventBase, NodeRunResult, PauseRequestedEvent
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig, VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.end.entities import EndNodeData
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
class _PausingNodeData(BaseNodeData):
pass
class _PausingNode(Node):
node_type = NodeType.TOOL
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = _PausingNodeData.model_validate(data)
def _get_error_strategy(self):
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@staticmethod
def _pause_generator(event: PauseRequestedEvent) -> Generator[NodeEventBase, None, None]:
yield event
def _run(self) -> NodeRunResult | Generator[NodeEventBase, None, None]:
resumed_flag = self.graph_runtime_state.variable_pool.get((self.id, "resumed"))
if resumed_flag is None:
# mark as resumed and request pause
self.graph_runtime_state.variable_pool.add((self.id, "resumed"), True)
return self._pause_generator(PauseRequestedEvent(reason=SchedulingPause(message="test pause")))
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"value": "completed"},
)
def _build_runtime_state() -> GraphRuntimeState:
variable_pool = VariablePool(
system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"),
user_inputs={},
conversation_variables=[],
)
return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
def _build_pausing_graph(runtime_state: GraphRuntimeState) -> Graph:
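# Graph of start, pausing and end nodes; the pausing node requests a pause on its first run and succeeds once the resumed flag is set.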
graph_config: dict[str, object] = {"nodes": [], "edges": []}
params = GraphInitParams(
tenant_id="tenant",
app_id="app",
workflow_id="workflow",
graph_config=graph_config,
user_id="user",
user_from="account",
invoke_from="service-api",
call_depth=0,
)
start_data = StartNodeData(title="start", variables=[])
start_node = StartNode(
id="start",
config={"id": "start", "data": start_data.model_dump()},
graph_init_params=params,
graph_runtime_state=runtime_state,
)
start_node.init_node_data(start_data.model_dump())
pause_data = _PausingNodeData(title="pausing")
pause_node = _PausingNode(
id="pausing",
config={"id": "pausing", "data": pause_data.model_dump()},
graph_init_params=params,
graph_runtime_state=runtime_state,
)
pause_node.init_node_data(pause_data.model_dump())
end_data = EndNodeData(
title="end",
outputs=[
VariableSelector(variable="result", value_selector=["pausing", "value"]),
],
desc=None,
)
end_node = EndNode(
id="end",
config={"id": "end", "data": end_data.model_dump()},
graph_init_params=params,
graph_runtime_state=runtime_state,
)
end_node.init_node_data(end_data.model_dump())
return Graph.new().add_root(start_node).add_node(pause_node).add_node(end_node).build()
def _run_graph(graph: Graph, runtime_state: GraphRuntimeState) -> list[GraphEngineEvent]:
engine = GraphEngine(
workflow_id="workflow",
graph=graph,
graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(),
)
return list(engine.run())
def _node_successes(events: list[GraphEngineEvent]) -> list[str]:
return [event.node_id for event in events if isinstance(event, NodeRunSucceededEvent)]
def _segment_value(variable_pool: VariablePool, selector: tuple[str, str]) -> Any:
segment = variable_pool.get(selector)
assert segment is not None
return getattr(segment, "value", segment)
def test_engine_resume_restores_state_and_completion():
# Baseline run without pausing
baseline_state = _build_runtime_state()
baseline_graph = _build_pausing_graph(baseline_state)
baseline_state.variable_pool.add(("pausing", "resumed"), True)
baseline_events = _run_graph(baseline_graph, baseline_state)
assert isinstance(baseline_events[-1], GraphRunSucceededEvent)
baseline_success_nodes = _node_successes(baseline_events)
# Run with pause
paused_state = _build_runtime_state()
paused_graph = _build_pausing_graph(paused_state)
paused_events = _run_graph(paused_graph, paused_state)
assert isinstance(paused_events[-1], GraphRunPausedEvent)
snapshot = paused_state.dumps()
# Resume from snapshot
resumed_state = GraphRuntimeState.from_snapshot(snapshot)
resumed_graph = _build_pausing_graph(resumed_state)
resumed_events = _run_graph(resumed_graph, resumed_state)
assert isinstance(resumed_events[-1], GraphRunSucceededEvent)
combined_success_nodes = _node_successes(paused_events) + _node_successes(resumed_events)
assert combined_success_nodes == baseline_success_nodes
assert baseline_state.outputs == resumed_state.outputs
assert _segment_value(baseline_state.variable_pool, ("pausing", "resumed")) == _segment_value(
resumed_state.variable_pool, ("pausing", "resumed")
)
assert _segment_value(baseline_state.variable_pool, ("pausing", "value")) == _segment_value(
resumed_state.variable_pool, ("pausing", "value")
)
assert baseline_state.graph_execution.completed
assert resumed_state.graph_execution.completed

View File

@ -1,283 +0,0 @@
"""
Unit tests for human input node implementation.
"""
import uuid
from unittest.mock import Mock, patch
import pytest
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.human_input import HumanInputNode, HumanInputNodeData
class TestHumanInputNode:
"""Test HumanInputNode implementation."""
@pytest.fixture
def mock_graph_init_params(self):
"""Create mock graph initialization parameters."""
mock_params = Mock()
mock_params.tenant_id = "tenant-123"
mock_params.app_id = "app-456"
mock_params.user_id = "user-789"
mock_params.user_from = "web"
mock_params.invoke_from = "web_app"
mock_params.call_depth = 0
return mock_params
@pytest.fixture
def mock_graph(self):
"""Create mock graph."""
return Mock()
@pytest.fixture
def mock_graph_runtime_state(self):
"""Create mock graph runtime state."""
return Mock()
@pytest.fixture
def sample_node_config(self):
"""Create sample node configuration."""
return {
"id": "human_input_123",
"data": {
"title": "User Confirmation",
"desc": "Please confirm the action",
"delivery_methods": [{"type": "webapp", "enabled": True, "config": {}}],
"form_content": "# Confirmation\n\nPlease confirm: {{#$output.confirmation#}}",
"inputs": [
{
"type": "text-input",
"output_variable_name": "confirmation",
"placeholder": {"type": "constant", "value": "Type 'yes' to confirm"},
}
],
"user_actions": [
{"id": "confirm", "title": "Confirm", "button_style": "primary"},
{"id": "cancel", "title": "Cancel", "button_style": "default"},
],
"timeout": 24,
"timeout_unit": "hour",
},
}
@pytest.fixture
def human_input_node(self, sample_node_config, mock_graph_init_params, mock_graph, mock_graph_runtime_state):
"""Create HumanInputNode instance."""
node = HumanInputNode(
id="node_123",
config=sample_node_config,
graph_init_params=mock_graph_init_params,
graph=mock_graph,
graph_runtime_state=mock_graph_runtime_state,
)
return node
def test_node_initialization(self, human_input_node):
"""Test node initialization."""
assert human_input_node.node_id == "human_input_123"
assert human_input_node.tenant_id == "tenant-123"
assert human_input_node.app_id == "app-456"
assert isinstance(human_input_node.node_data, HumanInputNodeData)
assert human_input_node.node_data.title == "User Confirmation"
def test_node_type_and_version(self, human_input_node):
"""Test node type and version."""
assert human_input_node.type_.value == "human_input"
assert human_input_node.version() == "1"
def test_node_properties(self, human_input_node):
"""Test node properties access."""
assert human_input_node.title == "User Confirmation"
assert human_input_node.description == "Please confirm the action"
assert human_input_node.error_strategy is None
assert human_input_node.retry_config.retry_enabled is False
@patch("uuid.uuid4")
def test_node_run_success(self, mock_uuid, human_input_node):
"""Test successful node execution."""
# Setup mocks
mock_form_id = uuid.UUID("12345678-1234-5678-9abc-123456789012")
mock_token = uuid.UUID("87654321-4321-8765-cba9-876543210987")
mock_uuid.side_effect = [mock_form_id, mock_token]
# Execute the node
result = human_input_node._run()
# Verify result
assert result.status == WorkflowNodeExecutionStatus.RUNNING
assert result.metadata["suspended"] is True
assert result.metadata["form_id"] == str(mock_form_id)
assert result.metadata["web_app_form_token"] == str(mock_token).replace("-", "")
# Verify event data in metadata
human_input_event = result.metadata["human_input_event"]
assert human_input_event["form_id"] == str(mock_form_id)
assert human_input_event["node_id"] == "human_input_123"
assert human_input_event["form_content"] == "# Confirmation\n\nPlease confirm: {{#$output.confirmation#}}"
assert len(human_input_event["inputs"]) == 1
suspended_event = result.metadata["suspended_event"]
assert suspended_event["suspended_at_node_ids"] == ["human_input_123"]
def test_node_run_without_webapp_delivery(self, human_input_node):
"""Test node execution without webapp delivery method."""
# Modify node data to disable webapp delivery
human_input_node.node_data.delivery_methods[0].enabled = False
result = human_input_node._run()
# Should still work, but without web app token
assert result.status == WorkflowNodeExecutionStatus.RUNNING
assert result.metadata["web_app_form_token"] is None
def test_resume_from_human_input_success(self, human_input_node):
"""Test successful resume from human input."""
form_submission_data = {"inputs": {"confirmation": "yes"}, "action": "confirm"}
result = human_input_node.resume_from_human_input(form_submission_data)
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs["confirmation"] == "yes"
assert result.outputs["_action"] == "confirm"
assert result.metadata["form_submitted"] is True
assert result.metadata["submitted_action"] == "confirm"
def test_resume_from_human_input_partial_inputs(self, human_input_node):
"""Test resume with partial inputs."""
form_submission_data = {
"inputs": {}, # Empty inputs
"action": "cancel",
}
result = human_input_node.resume_from_human_input(form_submission_data)
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert "confirmation" not in result.outputs # Field not provided
assert result.outputs["_action"] == "cancel"
def test_resume_from_human_input_missing_data(self, human_input_node):
"""Test resume with missing submission data."""
form_submission_data = {} # Missing required fields
result = human_input_node.resume_from_human_input(form_submission_data)
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs["_action"] == "" # Default empty action
def test_get_default_config(self):
"""Test getting default configuration."""
config = HumanInputNode.get_default_config()
assert config["type"] == "human_input"
assert "config" in config
config_data = config["config"]
assert len(config_data["delivery_methods"]) == 1
assert config_data["delivery_methods"][0]["type"] == "webapp"
assert config_data["delivery_methods"][0]["enabled"] is True
assert config_data["form_content"] == "# Human Input\n\nPlease provide your input:\n\n{{#$output.input#}}"
assert len(config_data["inputs"]) == 1
assert config_data["inputs"][0]["output_variable_name"] == "input"
assert len(config_data["user_actions"]) == 1
assert config_data["user_actions"][0]["id"] == "submit"
assert config_data["timeout"] == 24
assert config_data["timeout_unit"] == "hour"
def test_process_form_content(self, human_input_node):
"""Test form content processing."""
# This is a placeholder test since the actual variable substitution
# logic is marked as TODO in the implementation
processed_content = human_input_node._process_form_content()
# For now, should return the raw content
expected_content = "# Confirmation\n\nPlease confirm: {{#$output.confirmation#}}"
assert processed_content == expected_content
def test_extract_variable_selector_mapping(self):
"""Test variable selector extraction."""
graph_config = {}
node_data = {
"form_content": "Hello {{#node_123.output#}}",
"inputs": [
{
"type": "text-input",
"output_variable_name": "test",
"placeholder": {"type": "variable", "selector": ["node_456", "var_name"]},
}
],
}
# This is a placeholder test since the actual extraction logic
# is marked as TODO in the implementation
mapping = HumanInputNode._extract_variable_selector_to_variable_mapping(
graph_config=graph_config, node_id="test_node", node_data=node_data
)
# For now, should return empty dict
assert mapping == {}
class TestHumanInputNodeValidation:
"""Test validation scenarios for HumanInputNode."""
def test_node_with_invalid_config(self):
"""Test node creation with invalid configuration."""
invalid_config = {
"id": "test_node",
"data": {
"title": "Test",
"delivery_methods": [
{
"type": "invalid_type", # Invalid delivery method type
"enabled": True,
"config": {},
}
],
},
}
mock_params = Mock()
mock_params.tenant_id = "tenant-123"
mock_params.app_id = "app-456"
mock_params.user_id = "user-789"
mock_params.user_from = "web"
mock_params.invoke_from = "web_app"
mock_params.call_depth = 0
with pytest.raises(ValueError):
HumanInputNode(
id="node_123",
config=invalid_config,
graph_init_params=mock_params,
graph=Mock(),
graph_runtime_state=Mock(),
)
def test_node_with_missing_node_id(self):
"""Test node creation with missing node ID in config."""
invalid_config = {
# Missing "id" field
"data": {"title": "Test"}
}
mock_params = Mock()
mock_params.tenant_id = "tenant-123"
mock_params.app_id = "app-456"
mock_params.user_id = "user-789"
mock_params.user_from = "web"
mock_params.invoke_from = "web_app"
mock_params.call_depth = 0
with pytest.raises(ValueError, match="Node ID is required"):
HumanInputNode(
id="node_123",
config=invalid_config,
graph_init_params=mock_params,
graph=Mock(),
graph_runtime_state=Mock(),
)

View File

@@ -1 +1 @@
# Unit tests for human input library
# Treat this directory as a package so support modules can be imported relatively.

View File

@@ -0,0 +1,248 @@
from __future__ import annotations
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import Any, Optional
from core.workflow.nodes.human_input.entities import FormInput, TimeoutUnit
# Exceptions
class HumanInputError(Exception):
error_code: str = "unknown"
def __init__(self, message: str = "", error_code: str | None = None):
super().__init__(message)
self.message = message or self.__class__.__name__
if error_code:
self.error_code = error_code
class FormNotFoundError(HumanInputError):
error_code = "form_not_found"
class FormExpiredError(HumanInputError):
error_code = "human_input_form_expired"
class FormAlreadySubmittedError(HumanInputError):
error_code = "human_input_form_submitted"
class InvalidFormDataError(HumanInputError):
error_code = "invalid_form_data"
# Models
@dataclass
class HumanInputForm:
form_id: str
workflow_run_id: str
node_id: str
tenant_id: str
app_id: str | None
form_content: str
inputs: list[FormInput]
user_actions: list[dict[str, Any]]
timeout: int
timeout_unit: TimeoutUnit
web_app_form_token: str | None = None
created_at: datetime = field(default_factory=datetime.utcnow)
expires_at: datetime | None = None
submitted_at: datetime | None = None
submitted_data: dict[str, Any] | None = None
submitted_action: str | None = None
def __post_init__(self) -> None:
if self.expires_at is None:
self.calculate_expiration()
@property
def is_expired(self) -> bool:
return self.expires_at is not None and datetime.utcnow() > self.expires_at
@property
def is_submitted(self) -> bool:
return self.submitted_at is not None
def mark_submitted(self, inputs: dict[str, Any], action: str) -> None:
self.submitted_data = inputs
self.submitted_action = action
self.submitted_at = datetime.utcnow()
def submit(self, inputs: dict[str, Any], action: str) -> None:
self.mark_submitted(inputs, action)
def calculate_expiration(self) -> None:
start = self.created_at
if self.timeout_unit == TimeoutUnit.HOUR:
self.expires_at = start + timedelta(hours=self.timeout)
elif self.timeout_unit == TimeoutUnit.DAY:
self.expires_at = start + timedelta(days=self.timeout)
else:
raise ValueError(f"Unsupported timeout unit {self.timeout_unit}")
def to_response_dict(self, *, include_site_info: bool) -> dict[str, Any]:
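# Enum member names (e.g. TEXT_INPUT) are rendered as the API-style strings
# such as "text-input" that callers expect in responses.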
inputs_response = [
{
"type": form_input.type.name.lower().replace("_", "-"),
"output_variable_name": form_input.output_variable_name,
}
for form_input in self.inputs
]
response = {
"form_content": self.form_content,
"inputs": inputs_response,
"user_actions": self.user_actions,
}
if include_site_info:
response["site"] = {"app_id": self.app_id, "title": "Workflow Form"}
return response
@dataclass
class FormSubmissionData:
form_id: str
inputs: dict[str, Any]
action: str
submitted_at: datetime = field(default_factory=datetime.utcnow)
@classmethod
def from_request(cls, form_id: str, request: FormSubmissionRequest) -> FormSubmissionData: # type: ignore
return cls(form_id=form_id, inputs=request.inputs, action=request.action)
@dataclass
class FormSubmissionRequest:
inputs: dict[str, Any]
action: str
# Repository
class InMemoryFormRepository:
"""
Simple in-memory repository used by unit tests.
"""
def __init__(self):
self._forms: dict[str, HumanInputForm] = {}
@property
def forms(self) -> dict[str, HumanInputForm]:
return self._forms
def save(self, form: HumanInputForm) -> None:
self._forms[form.form_id] = form
def get_by_id(self, form_id: str) -> Optional[HumanInputForm]:
return self._forms.get(form_id)
def get_by_token(self, token: str) -> Optional[HumanInputForm]:
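# A linear scan is sufficient for the handful of forms a unit test creates.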
for form in self._forms.values():
if form.web_app_form_token == token:
return form
return None
def delete(self, form_id: str) -> None:
self._forms.pop(form_id, None)
# Service
class FormService:
"""Service layer for managing human input forms in tests."""
def __init__(self, repository: InMemoryFormRepository):
self.repository = repository
def create_form(
self,
*,
form_id: str,
workflow_run_id: str,
node_id: str,
tenant_id: str,
app_id: str | None,
form_content: str,
inputs,
user_actions,
timeout: int,
timeout_unit: TimeoutUnit,
web_app_form_token: str | None = None,
) -> HumanInputForm:
form = HumanInputForm(
form_id=form_id,
workflow_run_id=workflow_run_id,
node_id=node_id,
tenant_id=tenant_id,
app_id=app_id,
form_content=form_content,
inputs=list(inputs),
user_actions=[{"id": action.id, "title": action.title} for action in user_actions],
timeout=timeout,
timeout_unit=timeout_unit,
web_app_form_token=web_app_form_token,
)
form.calculate_expiration()
self.repository.save(form)
return form
def get_form_by_id(self, form_id: str) -> HumanInputForm:
form = self.repository.get_by_id(form_id)
if form is None:
raise FormNotFoundError()
return form
def get_form_by_token(self, token: str) -> HumanInputForm:
form = self.repository.get_by_token(token)
if form is None:
raise FormNotFoundError()
return form
def get_form_definition(self, form_id: str, *, is_token: bool) -> dict:
form = self.get_form_by_token(form_id) if is_token else self.get_form_by_id(form_id)
if form.is_expired:
raise FormExpiredError()
if form.is_submitted:
raise FormAlreadySubmittedError()
definition = {
"form_content": form.form_content,
"inputs": form.inputs,
"user_actions": form.user_actions,
}
if is_token:
definition["site"] = {"title": "Workflow Form"}
return definition
def submit_form(self, form_id: str, submission_data: FormSubmissionData, *, is_token: bool) -> None:
form = self.get_form_by_token(form_id) if is_token else self.get_form_by_id(form_id)
if form.is_expired:
raise FormExpiredError()
if form.is_submitted:
raise FormAlreadySubmittedError()
self._validate_submission(form=form, submission_data=submission_data)
form.mark_submitted(inputs=submission_data.inputs, action=submission_data.action)
self.repository.save(form)
def cleanup_expired_forms(self) -> int:
expired_ids = [form_id for form_id, form in list(self.repository.forms.items()) if form.is_expired]
for form_id in expired_ids:
self.repository.delete(form_id)
return len(expired_ids)
def _validate_submission(self, form: HumanInputForm, submission_data: FormSubmissionData) -> None:
defined_actions = {action["id"] for action in form.user_actions}
if submission_data.action not in defined_actions:
raise InvalidFormDataError(f"Invalid action: {submission_data.action}")
missing_inputs = []
for form_input in form.inputs:
if form_input.output_variable_name not in submission_data.inputs:
missing_inputs.append(form_input.output_variable_name)
if missing_inputs:
raise InvalidFormDataError(f"Missing required inputs: {', '.join(missing_inputs)}")
# Extra inputs are allowed; no further validation required.
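# Illustrative sketch (not exercised by the tests above) of how these helpers fit
# together; SimpleNamespace stands in for the UserAction entities the real tests pass.
def _example_form_round_trip() -> HumanInputForm:
    from types import SimpleNamespace

    repository = InMemoryFormRepository()
    service = FormService(repository)
    form = service.create_form(
        form_id="form-1",
        workflow_run_id="run-1",
        node_id="node-1",
        tenant_id="tenant-1",
        app_id=None,
        form_content="Please confirm",
        inputs=[],
        user_actions=[SimpleNamespace(id="confirm", title="Confirm")],
        timeout=1,
        timeout_unit=TimeoutUnit.HOUR,
    )
    # Submission is validated against the declared action ids and input names.
    service.submit_form(
        form.form_id,
        FormSubmissionData(form_id=form.form_id, inputs={}, action="confirm"),
        is_token=False,
    )
    return form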

View File

@@ -5,15 +5,6 @@ Unit tests for FormService.
from datetime import datetime, timedelta
import pytest
from libs._human_input.exceptions import (
FormAlreadySubmittedError,
FormExpiredError,
FormNotFoundError,
InvalidFormDataError,
)
from libs._human_input.form_service import FormService
from libs._human_input.models import FormSubmissionData
from libs._human_input.repository import InMemoryFormRepository
from core.workflow.nodes.human_input.entities import (
FormInput,
@@ -22,6 +13,16 @@ from core.workflow.nodes.human_input.entities import (
UserAction,
)
from .support import (
FormAlreadySubmittedError,
FormExpiredError,
FormNotFoundError,
FormService,
FormSubmissionData,
InMemoryFormRepository,
InvalidFormDataError,
)
class TestFormService:
"""Test FormService functionality."""

View File

@@ -5,16 +5,16 @@ Unit tests for human input form models.
from datetime import datetime, timedelta
import pytest
from libs._human_input.models import FormSubmissionData, HumanInputForm
from core.workflow.nodes.human_input.entities import (
FormInput,
FormInputType,
FormSubmissionRequest,
TimeoutUnit,
UserAction,
)
from .support import FormSubmissionData, FormSubmissionRequest, HumanInputForm
class TestHumanInputForm:
"""Test HumanInputForm model."""

View File

@@ -0,0 +1,205 @@
import dataclasses
from datetime import datetime
from unittest.mock import MagicMock
import pytest
from core.repositories.human_input_reposotiry import HumanInputFormReadRepository, HumanInputFormRecord
from core.workflow.nodes.human_input.entities import FormDefinition, TimeoutUnit, UserAction
from models.account import Account
from models.human_input import RecipientType
from services.human_input_service import FormSubmittedError, HumanInputService
@pytest.fixture
def mock_session_factory():
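# Stand-in for the service's session factory: calling it returns a context
# manager whose __enter__ yields the mock session.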
session = MagicMock()
session_cm = MagicMock()
session_cm.__enter__.return_value = session
session_cm.__exit__.return_value = None
factory = MagicMock()
factory.return_value = session_cm
return factory, session
@pytest.fixture
def sample_form_record():
return HumanInputFormRecord(
form_id="form-id",
workflow_run_id="workflow-run-id",
node_id="node-id",
tenant_id="tenant-id",
definition=FormDefinition(
form_content="hello",
inputs=[],
user_actions=[UserAction(id="submit", title="Submit")],
rendered_content="<p>hello</p>",
timeout=1,
timeout_unit=TimeoutUnit.HOUR,
),
rendered_content="<p>hello</p>",
expiration_time=datetime(2024, 1, 1),
selected_action_id=None,
submitted_data=None,
submitted_at=None,
submission_user_id=None,
submission_end_user_id=None,
completed_by_recipient_id=None,
recipient_id="recipient-id",
recipient_type=RecipientType.WEBAPP,
access_token="token",
)
def test_enqueue_resume_dispatches_task(mocker, mock_session_factory):
session_factory, session = mock_session_factory
service = HumanInputService(session_factory)
trigger_log = MagicMock()
trigger_log.id = "trigger-log-id"
trigger_log.queue_name = "workflow_queue"
repo_cls = mocker.patch(
"services.human_input_service.SQLAlchemyWorkflowTriggerLogRepository",
autospec=True,
)
repo = repo_cls.return_value
repo.get_by_workflow_run_id.return_value = trigger_log
resume_task = mocker.patch("services.human_input_service.resume_workflow_execution")
service._enqueue_resume("workflow-run-id")
repo_cls.assert_called_once_with(session)
resume_task.apply_async.assert_called_once()
call_kwargs = resume_task.apply_async.call_args.kwargs
assert call_kwargs["queue"] == "workflow_queue"
payload = call_kwargs["kwargs"]["task_data_dict"]
assert payload["workflow_trigger_log_id"] == "trigger-log-id"
assert payload["workflow_run_id"] == "workflow-run-id"
def test_enqueue_resume_no_trigger_log(mocker, mock_session_factory):
session_factory, session = mock_session_factory
service = HumanInputService(session_factory)
repo_cls = mocker.patch(
"services.human_input_service.SQLAlchemyWorkflowTriggerLogRepository",
autospec=True,
)
repo = repo_cls.return_value
repo.get_by_workflow_run_id.return_value = None
resume_task = mocker.patch("services.human_input_service.resume_workflow_execution")
service._enqueue_resume("workflow-run-id")
repo_cls.assert_called_once_with(session)
resume_task.apply_async.assert_not_called()
def test_enqueue_resume_chatflow_fallback(mocker, mock_session_factory):
session_factory, session = mock_session_factory
service = HumanInputService(session_factory)
repo_cls = mocker.patch(
"services.human_input_service.SQLAlchemyWorkflowTriggerLogRepository",
autospec=True,
)
repo = repo_cls.return_value
repo.get_by_workflow_run_id.return_value = None
workflow_run = MagicMock()
workflow_run.app_id = "app-id"
app = MagicMock()
app.mode = "advanced-chat"
session.get.side_effect = [workflow_run, app]
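# With no trigger log, the service falls back to loading the WorkflowRun and its
# App via session.get; advanced-chat apps are resumed through the chatflow task.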
resume_task = mocker.patch("services.human_input_service.resume_chatflow_execution")
service._enqueue_resume("workflow-run-id")
resume_task.apply_async.assert_called_once()
call_kwargs = resume_task.apply_async.call_args.kwargs
assert call_kwargs["queue"] == "chatflow_execute"
assert call_kwargs["kwargs"]["payload"]["workflow_run_id"] == "workflow-run-id"
def test_get_form_definition_by_id_uses_repository(sample_form_record, mock_session_factory):
session_factory, _ = mock_session_factory
repo = MagicMock(spec=HumanInputFormReadRepository)
repo.get_by_form_id_and_recipient_type.return_value = sample_form_record
service = HumanInputService(session_factory, form_repository=repo)
form = service.get_form_definition_by_id("form-id")
repo.get_by_form_id_and_recipient_type.assert_called_once_with(
form_id="form-id",
recipient_type=RecipientType.WEBAPP,
)
assert form is not None
assert form.get_definition() == sample_form_record.definition
def test_get_form_definition_by_id_raises_on_submitted(sample_form_record, mock_session_factory):
session_factory, _ = mock_session_factory
submitted_record = dataclasses.replace(sample_form_record, submitted_at=datetime(2024, 1, 1))
repo = MagicMock(spec=HumanInputFormReadRepository)
repo.get_by_form_id_and_recipient_type.return_value = submitted_record
service = HumanInputService(session_factory, form_repository=repo)
with pytest.raises(FormSubmittedError):
service.get_form_definition_by_id("form-id")
def test_submit_form_by_token_calls_repository_and_enqueue(sample_form_record, mock_session_factory, mocker):
session_factory, _ = mock_session_factory
repo = MagicMock(spec=HumanInputFormReadRepository)
repo.get_by_token.return_value = sample_form_record
repo.mark_submitted.return_value = sample_form_record
service = HumanInputService(session_factory, form_repository=repo)
enqueue_spy = mocker.patch.object(service, "_enqueue_resume")
service.submit_form_by_token(
recipient_type=RecipientType.WEBAPP,
form_token="token",
selected_action_id="approve",
form_data={"field": "value"},
submission_end_user_id="end-user-id",
)
repo.get_by_token.assert_called_once_with("token")
repo.mark_submitted.assert_called_once()
call_kwargs = repo.mark_submitted.call_args.kwargs
assert call_kwargs["form_id"] == sample_form_record.form_id
assert call_kwargs["recipient_id"] == sample_form_record.recipient_id
assert call_kwargs["selected_action_id"] == "approve"
assert call_kwargs["form_data"] == {"field": "value"}
assert call_kwargs["submission_end_user_id"] == "end-user-id"
enqueue_spy.assert_called_once_with(sample_form_record.workflow_run_id)
def test_submit_form_by_id_passes_account(sample_form_record, mock_session_factory, mocker):
session_factory, _ = mock_session_factory
repo = MagicMock(spec=HumanInputFormReadRepository)
repo.get_by_form_id_and_recipient_type.return_value = sample_form_record
repo.mark_submitted.return_value = sample_form_record
service = HumanInputService(session_factory, form_repository=repo)
enqueue_spy = mocker.patch.object(service, "_enqueue_resume")
account = MagicMock(spec=Account)
account.id = "account-id"
service.submit_form_by_id(
form_id="form-id",
selected_action_id="approve",
form_data={"x": 1},
user=account,
)
repo.get_by_form_id_and_recipient_type.assert_called_once()
repo.mark_submitted.assert_called_once()
assert repo.mark_submitted.call_args.kwargs["submission_user_id"] == "account-id"
enqueue_spy.assert_called_once_with(sample_form_record.workflow_run_id)

View File

@@ -35,7 +35,6 @@ class TestDataFactory:
app_id: str = "app-789",
workflow_id: str = "workflow-101",
status: str | WorkflowExecutionStatus = "paused",
pause_id: str | None = None,
**kwargs,
) -> MagicMock:
"""Create a mock WorkflowRun object."""
@@ -45,7 +44,6 @@
mock_run.app_id = app_id
mock_run.workflow_id = workflow_id
mock_run.status = status
mock_run.pause_id = pause_id
for key, value in kwargs.items():
setattr(mock_run, key, value)

View File

@@ -161,292 +161,3 @@ class TestWorkflowService:
assert workflows == []
assert has_more is False
mock_session.scalars.assert_called_once()
class TestWorkflowServiceHumanInputValidation:
@pytest.fixture
def workflow_service(self):
# Mock sessionmaker to avoid database dependency
mock_session_maker = MagicMock()
return WorkflowService(mock_session_maker)
def test_validate_graph_structure_valid_human_input(self, workflow_service):
"""Test validation of valid HumanInput node data."""
graph = {
"nodes": [
{
"id": "node-1",
"data": {
"type": "human_input",
"title": "Human Input",
"delivery_methods": [{"type": "webapp", "enabled": True, "config": {}}],
"form_content": "Please provide your input",
"inputs": [
{
"type": "text-input",
"output_variable_name": "user_input",
"placeholder": {"type": "constant", "value": "Enter text here"},
}
],
"user_actions": [{"id": "submit", "title": "Submit", "button_style": "primary"}],
"timeout": 24,
"timeout_unit": "hour",
},
}
]
}
# Should not raise any exception
workflow_service.validate_graph_structure(graph)
def test_validate_graph_structure_empty_graph(self, workflow_service):
"""Test validation of empty graph."""
graph = {}
# Should not raise any exception
workflow_service.validate_graph_structure(graph)
def test_validate_graph_structure_no_nodes(self, workflow_service):
"""Test validation of graph with no nodes."""
graph = {"nodes": []}
# Should not raise any exception
workflow_service.validate_graph_structure(graph)
def test_validate_graph_structure_non_human_input_node(self, workflow_service):
"""Test validation ignores non-HumanInput nodes."""
graph = {"nodes": [{"id": "node-1", "data": {"type": "start", "title": "Start"}}]}
# Should not raise any exception
workflow_service.validate_graph_structure(graph)
def test_validate_human_input_node_data_invalid_delivery_method_type(self, workflow_service):
"""Test validation fails with invalid delivery method type."""
graph = {
"nodes": [
{
"id": "node-1",
"data": {
"type": "human_input",
"title": "Human Input",
"delivery_methods": [{"type": "invalid_type", "enabled": True, "config": {}}],
"form_content": "Please provide your input",
"inputs": [],
"user_actions": [],
"timeout": 24,
"timeout_unit": "hour",
},
}
]
}
with pytest.raises(ValueError, match="Invalid HumanInput node data"):
workflow_service.validate_graph_structure(graph)
def test_validate_human_input_node_data_invalid_form_input_type(self, workflow_service):
"""Test validation fails with invalid form input type."""
graph = {
"nodes": [
{
"id": "node-1",
"data": {
"type": "human_input",
"title": "Human Input",
"delivery_methods": [{"type": "webapp", "enabled": True, "config": {}}],
"form_content": "Please provide your input",
"inputs": [{"type": "invalid-input-type", "output_variable_name": "user_input"}],
"user_actions": [],
"timeout": 24,
"timeout_unit": "hour",
},
}
]
}
with pytest.raises(ValueError, match="Invalid HumanInput node data"):
workflow_service.validate_graph_structure(graph)
def test_validate_human_input_node_data_missing_required_fields(self, workflow_service):
"""Test validation fails with missing required fields."""
graph = {
"nodes": [
{
"id": "node-1",
"data": {
"type": "human_input",
# Missing required fields like title
"delivery_methods": [],
"form_content": "",
"inputs": [],
"user_actions": [],
},
}
]
}
with pytest.raises(ValueError, match="Invalid HumanInput node data"):
workflow_service.validate_graph_structure(graph)
def test_validate_human_input_node_data_invalid_timeout_unit(self, workflow_service):
"""Test validation fails with invalid timeout unit."""
graph = {
"nodes": [
{
"id": "node-1",
"data": {
"type": "human_input",
"title": "Human Input",
"delivery_methods": [{"type": "webapp", "enabled": True, "config": {}}],
"form_content": "Please provide your input",
"inputs": [],
"user_actions": [],
"timeout": 24,
"timeout_unit": "invalid_unit",
},
}
]
}
with pytest.raises(ValueError, match="Invalid HumanInput node data"):
workflow_service.validate_graph_structure(graph)
def test_validate_human_input_node_data_invalid_button_style(self, workflow_service):
"""Test validation fails with invalid button style."""
graph = {
"nodes": [
{
"id": "node-1",
"data": {
"type": "human_input",
"title": "Human Input",
"delivery_methods": [{"type": "webapp", "enabled": True, "config": {}}],
"form_content": "Please provide your input",
"inputs": [],
"user_actions": [{"id": "submit", "title": "Submit", "button_style": "invalid_style"}],
"timeout": 24,
"timeout_unit": "hour",
},
}
]
}
with pytest.raises(ValueError, match="Invalid HumanInput node data"):
workflow_service.validate_graph_structure(graph)
def test_validate_human_input_node_data_email_delivery_config(self, workflow_service):
"""Test validation of HumanInput node with email delivery configuration."""
graph = {
"nodes": [
{
"id": "node-1",
"data": {
"type": "human_input",
"title": "Human Input",
"delivery_methods": [
{
"type": "email",
"enabled": True,
"config": {
"recipients": {
"whole_workspace": False,
"items": [{"type": "external", "email": "user@example.com"}],
},
"subject": "Input Required",
"body": "Please provide your input",
},
}
],
"form_content": "Please provide your input",
"inputs": [
{
"type": "paragraph",
"output_variable_name": "feedback",
"placeholder": {"type": "variable", "selector": ["node", "output"]},
}
],
"user_actions": [
{"id": "approve", "title": "Approve", "button_style": "accent"},
{"id": "reject", "title": "Reject", "button_style": "ghost"},
],
"timeout": 7,
"timeout_unit": "day",
},
}
]
}
# Should not raise any exception
workflow_service.validate_graph_structure(graph)
def test_validate_human_input_node_data_invalid_email_recipient(self, workflow_service):
"""Test validation fails with invalid email recipient."""
graph = {
"nodes": [
{
"id": "node-1",
"data": {
"type": "human_input",
"title": "Human Input",
"delivery_methods": [
{
"type": "email",
"enabled": True,
"config": {
"recipients": {
"whole_workspace": False,
"items": [{"type": "invalid_recipient_type", "email": "user@example.com"}],
},
"subject": "Input Required",
"body": "Please provide your input",
},
}
],
"form_content": "Please provide your input",
"inputs": [],
"user_actions": [],
"timeout": 24,
"timeout_unit": "hour",
},
}
]
}
with pytest.raises(ValueError, match="Invalid HumanInput node data"):
workflow_service.validate_graph_structure(graph)
def test_validate_human_input_node_data_multiple_nodes_mixed_valid_invalid(self, workflow_service):
"""Test validation with multiple nodes where some are valid and some invalid."""
graph = {
"nodes": [
{"id": "node-1", "data": {"type": "start", "title": "Start"}},
{
"id": "node-2",
"data": {
"type": "human_input",
"title": "Valid Human Input",
"delivery_methods": [{"type": "webapp", "enabled": True, "config": {}}],
"form_content": "Valid input",
"inputs": [],
"user_actions": [],
"timeout": 24,
"timeout_unit": "hour",
},
},
{
"id": "node-3",
"data": {
"type": "human_input",
"title": "Invalid Human Input",
"delivery_methods": [{"type": "invalid_method", "enabled": True}],
"form_content": "Invalid input",
"inputs": [],
"user_actions": [],
"timeout": 24,
"timeout_unit": "hour",
},
},
]
}
with pytest.raises(ValueError, match="Invalid HumanInput node data"):
workflow_service.validate_graph_structure(graph)