feat(api): Implement HITL for Workflow, add is_resumption for start event

This commit is contained in:
QuantumGhost
2025-12-30 16:40:08 +08:00
parent 01325c543f
commit 37dd61558c
27 changed files with 762 additions and 344 deletions

View File

@ -32,6 +32,8 @@ ignore_imports =
core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine
core.workflow.nodes.loop.loop_node -> core.workflow.graph
core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine.command_channels
# TODO(QuantumGhost): fix the import violation later
core.workflow.entities.pause_reason -> core.workflow.nodes.human_input.entities
[importlinter:contract:rsc]
name = RSC

View File

@ -527,6 +527,11 @@ class WorkflowDraftRunLoopNodeApi(Resource):
raise InternalServerError()
class HumanInputSubmitPayload(BaseModel):
    """Request body for submitting a human-input form preview.

    Validated from the raw JSON payload via ``model_validate`` in the
    console form-submission endpoints.
    """

    # Form field values keyed by field name; forwarded unchanged as
    # `form_inputs` to WorkflowService.submit_human_input_form_preview.
    inputs: dict[str, Any]
    # The action selected by the user; forwarded as `action` to the service.
    action: str
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflows/draft/human-input/nodes/<string:node_id>/form")
class AdvancedChatDraftHumanInputFormApi(Resource):
@console_ns.doc("get_advanced_chat_draft_human_input_form")
@ -580,19 +585,14 @@ class AdvancedChatDraftHumanInputFormApi(Resource):
Submit human input form preview
"""
current_user, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, location="json")
.add_argument("action", type=str, required=True, location="json")
)
args = parser.parse_args()
args = HumanInputSubmitPayload.model_validate(console_ns.payload or {})
workflow_service = WorkflowService()
result = workflow_service.submit_human_input_form_preview(
app_model=app_model,
account=current_user,
node_id=node_id,
form_inputs=args["inputs"],
action=args["action"],
form_inputs=args.inputs,
action=args.action,
)
return jsonable_encoder(result)
@ -650,19 +650,14 @@ class WorkflowDraftHumanInputFormApi(Resource):
Submit human input form preview
"""
current_user, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, location="json")
.add_argument("action", type=str, required=True, location="json")
)
args = parser.parse_args()
workflow_service = WorkflowService()
args = HumanInputSubmitPayload.model_validate(console_ns.payload or {})
result = workflow_service.submit_human_input_form_preview(
app_model=app_model,
account=current_user,
node_id=node_id,
form_inputs=args["inputs"],
action=args["action"],
form_inputs=args.inputs,
action=args.action,
)
return jsonable_encoder(result)

View File

@ -411,11 +411,8 @@ class ConsoleWorkflowPauseDetailsApi(Resource):
is_paused = workflow_run.status == WorkflowExecutionStatus.PAUSED
if not is_paused:
return {
"is_suspended": False,
"paused_at": None,
"paused_nodes": [],
"pending_human_inputs": [],
"pause_reasons": [],
}, 200
pause_entity = workflow_run_repo.get_workflow_pause(workflow_run_id)
@ -430,11 +427,8 @@ class ConsoleWorkflowPauseDetailsApi(Resource):
# Build response
response = {
"is_suspended": True,
"paused_at": workflow_run.created_at.isoformat() + "Z" if workflow_run.created_at else None,
"paused_nodes": [],
"pending_human_inputs": [],
"pause_reasons": pause_reasons,
}
# Add pending human input forms

View File

@ -157,6 +157,7 @@ class ConsoleWorkflowEventsApi(Resource):
app = _retrieve_app_for_workflow_run(session, workflow_run)
if workflow_run.finished_at is not None:
# TODO(QuantumGhost): should we modify the handling for finished workflow run here?
response = WorkflowResponseConverter.workflow_run_result_to_finish_response(
task_id=workflow_run.id,
workflow_run=workflow_run,

View File

@ -309,6 +309,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
task_id=self._application_generate_entity.task_id,
workflow_run_id=run_id,
workflow_id=self._workflow_id,
is_resumption=event.is_resumption,
)
yield workflow_start_resp

View File

@ -197,6 +197,7 @@ class WorkflowResponseConverter:
task_id: str,
workflow_run_id: str,
workflow_id: str,
is_resumption: bool,
) -> WorkflowStartStreamResponse:
run_id = self._ensure_workflow_run_id(workflow_run_id)
started_at = naive_utc_now()
@ -210,6 +211,7 @@ class WorkflowResponseConverter:
workflow_id=workflow_id,
inputs=self._workflow_inputs,
created_at=int(started_at.timestamp()),
is_resumption=is_resumption,
),
)

View File

@ -64,6 +64,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
streaming: Literal[True],
call_depth: int,
workflow_run_id: str | uuid.UUID | None = None,
triggered_from: WorkflowRunTriggeredFrom | None = None,
root_node_id: str | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
@ -81,6 +82,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
streaming: Literal[False],
call_depth: int,
workflow_run_id: str | uuid.UUID | None = None,
triggered_from: WorkflowRunTriggeredFrom | None = None,
root_node_id: str | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
@ -98,6 +100,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
streaming: bool,
call_depth: int,
workflow_run_id: str | uuid.UUID | None = None,
triggered_from: WorkflowRunTriggeredFrom | None = None,
root_node_id: str | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
@ -114,6 +117,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
streaming: bool = True,
call_depth: int = 0,
workflow_run_id: str | uuid.UUID | None = None,
triggered_from: WorkflowRunTriggeredFrom | None = None,
root_node_id: str | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
@ -152,7 +156,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
extras = {
**extract_external_trace_id_from_args(args),
}
workflow_run_id = str(uuid.uuid4())
workflow_run_id = str(workflow_run_id or uuid.uuid4())
# FIXME (Yeuoly): we need to remove the SKIP_PREPARE_USER_INPUTS_KEY from the args
# trigger shouldn't prepare user inputs
if self._should_prepare_user_inputs(args):

View File

@ -0,0 +1,7 @@
from libs.exception import BaseHTTPException
class WorkflowPausedInBlockingModeError(BaseHTTPException):
    """Raised when a workflow pauses for human input during a blocking
    (non-streaming) request.

    A blocking response cannot be suspended mid-request, so the generate
    task pipeline raises this error when it encounters a pause stream
    response while assembling a blocking response.
    """

    error_code = "workflow_paused_in_blocking_mode"
    description = "Workflow execution paused for human input; blocking response mode is not supported."
    # HTTP status code returned to the client.
    code = 400

View File

@ -10,6 +10,7 @@ from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.common.graph_runtime_state_support import GraphRuntimeStateSupport
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.apps.workflow.errors import WorkflowPausedInBlockingModeError
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import (
AppQueueEvent,
@ -47,6 +48,7 @@ from core.app.entities.task_entities import (
WorkflowAppBlockingResponse,
WorkflowAppStreamResponse,
WorkflowFinishStreamResponse,
WorkflowPauseStreamResponse,
WorkflowStartStreamResponse,
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
@ -133,6 +135,8 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
for stream_response in generator:
if isinstance(stream_response, ErrorStreamResponse):
raise stream_response.err
elif isinstance(stream_response, WorkflowPauseStreamResponse):
raise WorkflowPausedInBlockingModeError()
elif isinstance(stream_response, WorkflowFinishStreamResponse):
response = WorkflowAppBlockingResponse(
task_id=self._application_generate_entity.task_id,
@ -267,6 +271,7 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
task_id=self._application_generate_entity.task_id,
workflow_run_id=run_id,
workflow_id=self._workflow.id,
is_resumption=event.is_resumption,
)
yield start_resp
@ -452,7 +457,7 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
event=event,
task_id=self._application_generate_entity.task_id,
)
yield from response
yield from responses
def _handle_workflow_failed_and_stop_events(
self,

View File

@ -358,7 +358,7 @@ class WorkflowBasedAppRunner:
:param event: event
"""
if isinstance(event, GraphRunStartedEvent):
self._publish_event(QueueWorkflowStartedEvent())
self._publish_event(QueueWorkflowStartedEvent(is_resumption=event.is_resumption))
elif isinstance(event, GraphRunSucceededEvent):
self._publish_event(QueueWorkflowSucceededEvent(outputs=event.outputs))
elif isinstance(event, GraphRunPartialSucceededEvent):

View File

@ -264,6 +264,10 @@ class QueueWorkflowStartedEvent(AppQueueEvent):
event: QueueEvent = QueueEvent.WORKFLOW_STARTED
# is_resumption indicates whether this `start` event is a
# resumption of a previously suspended execution.
is_resumption: bool = False
class QueueWorkflowSucceededEvent(AppQueueEvent):
"""

View File

@ -208,6 +208,7 @@ class WorkflowStartStreamResponse(StreamResponse):
workflow_id: str
inputs: Mapping[str, Any]
created_at: int
is_resumption: bool = False
event: StreamEvent = StreamEvent.WORKFLOW_STARTED
workflow_run_id: str

View File

@ -235,7 +235,7 @@ class GraphEngine:
self._graph_execution.paused = False
self._graph_execution.pause_reasons = []
start_event = GraphRunStartedEvent()
start_event = GraphRunStartedEvent(is_resumption=is_resume)
self._event_manager.notify_layers(start_event)
yield start_event

View File

@ -5,7 +5,9 @@ from core.workflow.graph_events import BaseGraphEvent
class GraphRunStartedEvent(BaseGraphEvent):
pass
# is_resumption indicates whether this `start` event is a
# resumption of a previously suspended execution.
is_resumption: bool = False
class GraphRunSucceededEvent(BaseGraphEvent):

View File

@ -1,7 +1,7 @@
"""Add human input related models
Revision ID: d411af417245
Revises: 669ffd70119c
Revises: 03ea244985ce
Create Date: 2025-11-24 03:36:50.565145
"""
@ -13,7 +13,7 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "d411af417245"
down_revision = "669ffd70119c"
down_revision = "03ea244985ce"
branch_labels = None
depends_on = None

View File

@ -7,6 +7,7 @@ from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator
from core.app.apps.chat.app_generator import ChatAppGenerator
from core.app.apps.completion.app_generator import CompletionAppGenerator
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.features.rate_limiting import RateLimit
@ -17,7 +18,7 @@ from models.model import Account, App, AppMode, EndUser
from models.workflow import Workflow, WorkflowRun
from services.errors.app import InvokeRateLimitError, QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError
from services.workflow_service import WorkflowService
from tasks.app_generate.workflow_execute_task import ChatflowExecutionParams, chatflow_execute_task
from tasks.app_generate.workflow_execute_task import AppExecutionParams, chatflow_execute_task
class AppGenerateService:
@ -85,13 +86,14 @@ class AppGenerateService:
workflow_id = args.get("workflow_id")
workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
with rate_limit_context(rate_limit, request_id):
payload = ChatflowExecutionParams.new(
payload = AppExecutionParams.new(
app_model=app_model,
workflow=workflow,
user=user,
args=args,
invoke_from=invoke_from,
streaming=streaming,
call_depth=0,
)
chatflow_execute_task.delay(payload.model_dump_json())
generator = AdvancedChatAppGenerator()
@ -104,6 +106,27 @@ class AppGenerateService:
elif app_model.mode == AppMode.WORKFLOW:
workflow_id = args.get("workflow_id")
workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
if streaming:
with rate_limit_context(rate_limit, request_id):
payload = AppExecutionParams.new(
app_model=app_model,
workflow=workflow,
user=user,
args=args,
invoke_from=invoke_from,
streaming=True,
call_depth=0,
root_node_id=root_node_id,
workflow_run_id=uuid.uuid4(),
)
chatflow_execute_task.delay(payload.model_dump_json())
return rate_limit.generate(
WorkflowAppGenerator.convert_to_event_stream(
MessageBasedAppGenerator.retrieve_events(AppMode.WORKFLOW, payload.workflow_run_id),
),
request_id,
)
return rate_limit.generate(
WorkflowAppGenerator.convert_to_event_stream(
WorkflowAppGenerator().generate(
@ -112,7 +135,7 @@ class AppGenerateService:
user=user,
args=args,
invoke_from=invoke_from,
streaming=streaming,
streaming=False,
root_node_id=root_node_id,
call_depth=0,
),

View File

@ -2,7 +2,7 @@ import logging
from collections.abc import Mapping
from typing import Any
from sqlalchemy import Engine
from sqlalchemy import Engine, select
from sqlalchemy.orm import Session, sessionmaker
from core.repositories.human_input_reposotiry import (
@ -16,11 +16,8 @@ from libs.exception import BaseHTTPException
from models.account import Account
from models.human_input import RecipientType
from models.model import App, AppMode
from models.workflow import WorkflowRun
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
from services.workflow.entities import WorkflowResumeTaskData
from tasks.app_generate.workflow_execute_task import resume_chatflow_execution
from tasks.async_workflow_tasks import resume_workflow_execution
from repositories.factory import DifyAPIRepositoryFactory
from tasks.app_generate.workflow_execute_task import APP_EXECUTE_QUEUE, resume_app_execution
class Form:
@ -223,51 +220,29 @@ class HumanInputService:
raise InvalidFormDataError(f"Missing required inputs: {', '.join(missing_inputs)}")
def _enqueue_resume(self, workflow_run_id: str) -> None:
workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(self._session_factory)
workflow_run = workflow_run_repo.get_workflow_run_by_id_without_tenant(workflow_run_id)
if workflow_run is None:
raise AssertionError(f"WorkflowRun not found, id={workflow_run_id}")
with self._session_factory(expire_on_commit=False) as session:
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
trigger_log = trigger_log_repo.get_by_workflow_run_id(workflow_run_id)
if trigger_log is not None:
payload = WorkflowResumeTaskData(
workflow_trigger_log_id=trigger_log.id,
workflow_run_id=workflow_run_id,
app_query = select(App).where(App.id == workflow_run.app_id)
app = session.execute(app_query).scalar_one_or_none()
if app is None:
logger.error(
"App not found for WorkflowRun, workflow_run_id=%s, app_id=%s", workflow_run_id, workflow_run.app_id
)
return
if app.mode in {AppMode.WORKFLOW, AppMode.ADVANCED_CHAT}:
payload = {"workflow_run_id": workflow_run_id}
try:
resume_workflow_execution.apply_async(
kwargs={"task_data_dict": payload.model_dump()},
queue=trigger_log.queue_name,
resume_app_execution.apply_async(
kwargs={"payload": payload},
queue=APP_EXECUTE_QUEUE,
)
except Exception: # pragma: no cover
logger.exception("Failed to enqueue resume task for workflow run %s", workflow_run_id)
return
if self._enqueue_chatflow_resume(workflow_run_id):
return
logger.warning("No workflow trigger log bound to workflow run %s; skipping resume dispatch", workflow_run_id)
def _enqueue_chatflow_resume(self, workflow_run_id: str) -> bool:
with self._session_factory(expire_on_commit=False) as session:
workflow_run = session.get(WorkflowRun, workflow_run_id)
if workflow_run is None:
return False
app = session.get(App, workflow_run.app_id)
if app is None:
return False
if app.mode != AppMode.ADVANCED_CHAT.value:
return False
try:
resume_chatflow_execution.apply_async(
kwargs={"payload": {"workflow_run_id": workflow_run_id}},
queue="chatflow_execute",
)
except Exception: # pragma: no cover
logger.exception("Failed to enqueue chatflow resume for workflow run %s", workflow_run_id)
return False
return True
logger.warning("App mode %s does not support resume for workflow run %s", app.mode, workflow_run_id)

View File

@ -101,7 +101,6 @@ class WorkflowTaskData(BaseModel):
class WorkflowResumeTaskData(BaseModel):
"""Payload for workflow resumption tasks."""
workflow_trigger_log_id: str
workflow_run_id: str

View File

@ -14,7 +14,6 @@ from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.file import File
from core.repositories import DifyCoreRepositoryFactory
from core.variables import Variable
from core.variables.consts import SELECTORS_LENGTH
from core.variables.variables import VariableUnion
from core.workflow.entities import GraphInitParams, WorkflowNodeExecution
from core.workflow.entities.pause_reason import HumanInputRequired
@ -24,12 +23,13 @@ from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent, N
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes import NodeType
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.human_input.entities import _OUTPUT_VARIABLE_PATTERN, HumanInputNodeData
from core.workflow.nodes.human_input.entities import HumanInputNodeData
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.runtime import VariablePool
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import load_into_variable_pool
from core.workflow.workflow_entry import WorkflowEntry
from enums.cloud_plan import CloudPlan
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
@ -836,11 +836,19 @@ class WorkflowService:
raise ValueError(f"Missing inputs: {missing_list}")
rendered_content = node._render_form_content()
filled_inputs = dict(form_inputs)
rendered_content_with_outputs = self._render_content_with_output_values(rendered_content, filled_inputs)
outputs: dict[str, Any] = dict(filled_inputs)
outputs: dict[str, Any] = dict(form_inputs)
outputs["__action_id"] = action
rendered_content_with_outputs = rendered_content
for field_name in node_data.outputs_field_names():
placeholder = f"{{{{#$outputs.{field_name}#}}}}"
value = outputs.get(field_name)
if value is None:
replacement = ""
elif isinstance(value, (dict, list)):
replacement = json.dumps(value, ensure_ascii=False)
else:
replacement = str(value)
rendered_content_with_outputs = rendered_content_with_outputs.replace(placeholder, replacement)
outputs["__rendered_content"] = rendered_content_with_outputs
enclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id(node_config)
@ -919,42 +927,25 @@ class WorkflowService:
graph_config=workflow.graph_dict,
config=node_config,
)
selectors_to_load: list[list[str]] = []
for selector in variable_mapping.values():
if variable_pool.get(selector) is None:
selectors_to_load.append(list(selector))
loaded_variables = variable_loader.load_variables(selectors_to_load)
for variable in loaded_variables:
variable_pool.add([variable.selector[0], variable.selector[1]], variable)
normalized_user_inputs: dict[str, Any] = dict(manual_inputs)
for raw_key, value in manual_inputs.items():
selector = self._parse_selector(raw_key)
variable_pool.add(selector, value)
normalized_user_inputs[f"#{raw_key}#"] = value
load_into_variable_pool(
variable_loader=variable_loader,
variable_pool=variable_pool,
variable_mapping=variable_mapping,
user_inputs=normalized_user_inputs,
)
WorkflowEntry.mapping_user_inputs_to_variable_pool(
variable_mapping=variable_mapping,
user_inputs=normalized_user_inputs,
variable_pool=variable_pool,
tenant_id=app_model.tenant_id,
)
return variable_pool
def _parse_selector(self, selector_key: str) -> list[str]:
cleaned = selector_key.strip()
if cleaned.startswith("#") and cleaned.endswith("#"):
cleaned = cleaned[1:-1]
selector = cleaned.split(".")
if len(selector) != SELECTORS_LENGTH:
raise ValueError(f"Invalid selector '{selector_key}', expected format '<node_id>.<variable_name>'.")
return selector
def _render_content_with_output_values(self, content: str, outputs: Mapping[str, Any]) -> str:
def _replace(match):
field_name = match.group("field_name")
value = outputs.get(field_name)
if value is None:
return ""
if isinstance(value, (dict, list)):
return json.dumps(value, ensure_ascii=False)
return str(value)
return _OUTPUT_VARIABLE_PATTERN.sub(_replace, content)
def run_free_workflow_node(
self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
) -> WorkflowNodeExecution:

View File

@ -1,3 +1,3 @@
from .workflow_execute_task import chatflow_execute_task
from .workflow_execute_task import AppExecutionParams, chatflow_execute_task, resume_app_execution
__all__ = ["chatflow_execute_task"]
__all__ = ["AppExecutionParams", "chatflow_execute_task", "resume_app_execution"]

View File

@ -12,7 +12,13 @@ from sqlalchemy import Engine, select
from sqlalchemy.orm import Session, sessionmaker
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
InvokeFrom,
WorkflowAppGenerateEntity,
)
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, WorkflowResumptionContext
from core.repositories import DifyCoreRepositoryFactory
from core.workflow.runtime import GraphRuntimeState
@ -26,6 +32,8 @@ from repositories.factory import DifyAPIRepositoryFactory
logger = logging.getLogger(__name__)
APP_EXECUTE_QUEUE = "chatflow_execute"
class _UserType(StrEnum):
ACCOUNT = "account"
@ -66,16 +74,18 @@ User: TypeAlias = Annotated[
]
class ChatflowExecutionParams(BaseModel):
class AppExecutionParams(BaseModel):
app_id: str
workflow_id: str
tenant_id: str
app_mode: AppMode = AppMode.ADVANCED_CHAT
user: User
args: Mapping[str, Any]
invoke_from: InvokeFrom
streaming: bool = True
call_depth: int = 0
root_node_id: str | None = None
workflow_run_id: uuid.UUID = Field(default_factory=uuid.uuid4)
@classmethod
@ -87,6 +97,9 @@ class ChatflowExecutionParams(BaseModel):
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
call_depth: int = 0,
root_node_id: str | None = None,
workflow_run_id: uuid.UUID | None = None,
):
user_params: _Account | _EndUser
if isinstance(user, Account):
@ -99,16 +112,19 @@ class ChatflowExecutionParams(BaseModel):
app_id=app_model.id,
workflow_id=workflow.id,
tenant_id=app_model.tenant_id,
app_mode=AppMode.value_of(app_model.mode),
user=user_params,
args=args,
invoke_from=invoke_from,
streaming=streaming,
workflow_run_id=uuid.uuid4(),
call_depth=call_depth,
root_node_id=root_node_id,
workflow_run_id=workflow_run_id or uuid.uuid4(),
)
class _ChatflowRunner:
def __init__(self, session_factory: sessionmaker | Engine, exec_params: ChatflowExecutionParams):
class _AppRunner:
def __init__(self, session_factory: sessionmaker | Engine, exec_params: AppExecutionParams):
if isinstance(session_factory, Engine):
session_factory = sessionmaker(bind=session_factory)
self._session_factory = session_factory
@ -130,7 +146,13 @@ class _ChatflowRunner:
exec_params = self._exec_params
with self._session() as session:
workflow = session.get(Workflow, exec_params.workflow_id)
if workflow is None:
logger.warning("Workflow %s not found for execution", exec_params.workflow_id)
return None
app = session.get(App, workflow.app_id)
if app is None:
logger.warning("App %s not found for workflow %s", workflow.app_id, exec_params.workflow_id)
return None
pause_config = PauseStateLayerConfig(
session_factory=self._session_factory,
@ -139,25 +161,54 @@ class _ChatflowRunner:
user = self._resolve_user()
chat_generator = AdvancedChatAppGenerator()
workflow_run_id = exec_params.workflow_run_id
with self._setup_flask_context(user):
response = chat_generator.generate(
response = self._run_app(
app=app,
workflow=workflow,
user=user,
pause_state_config=pause_config,
)
if not exec_params.streaming:
return response
_publish_streaming_response(response, exec_params.workflow_run_id, exec_params.app_mode)
def _run_app(
self,
*,
app: App,
workflow: Workflow,
user: Account | EndUser,
pause_state_config: PauseStateLayerConfig,
):
exec_params = self._exec_params
if exec_params.app_mode == AppMode.ADVANCED_CHAT:
return AdvancedChatAppGenerator().generate(
app_model=app,
workflow=workflow,
user=user,
args=exec_params.args,
invoke_from=exec_params.invoke_from,
streaming=exec_params.streaming,
workflow_run_id=workflow_run_id,
pause_state_config=pause_config,
workflow_run_id=exec_params.workflow_run_id,
pause_state_config=pause_state_config,
)
if exec_params.app_mode == AppMode.WORKFLOW:
return WorkflowAppGenerator().generate(
app_model=app,
workflow=workflow,
user=user,
args=exec_params.args,
invoke_from=exec_params.invoke_from,
streaming=exec_params.streaming,
call_depth=exec_params.call_depth,
root_node_id=exec_params.root_node_id,
workflow_run_id=exec_params.workflow_run_id,
pause_state_config=pause_state_config,
)
if not exec_params.streaming:
return response
_publish_streaming_response(response, workflow_run_id)
logger.error("Unsupported app mode for execution: %s", exec_params.app_mode)
return None
def _resolve_user(self) -> Account | EndUser:
user_params = self._exec_params.user
@ -199,13 +250,13 @@ def _coerce_uuid(value: Any) -> uuid.UUID | None:
return None
def _publish_streaming_response(response_stream: Iterable[Any], workflow_run_id: Any) -> None:
def _publish_streaming_response(response_stream: Iterable[Any], workflow_run_id: Any, app_mode: AppMode) -> None:
workflow_run_uuid = _coerce_uuid(workflow_run_id)
if workflow_run_uuid is None:
logger.warning("Unable to publish streaming response without valid workflow_run_id: %s", workflow_run_id)
return
topic = AdvancedChatAppGenerator.get_response_topic(AppMode.ADVANCED_CHAT, workflow_run_uuid)
topic = MessageBasedAppGenerator.get_response_topic(app_mode, workflow_run_uuid)
for event in response_stream:
try:
payload = json.dumps(event)
@ -216,18 +267,17 @@ def _publish_streaming_response(response_stream: Iterable[Any], workflow_run_id:
topic.publish(payload.encode())
@shared_task(queue="chatflow_execute")
@shared_task(queue=APP_EXECUTE_QUEUE)
def chatflow_execute_task(payload: str) -> Mapping[str, Any] | None:
exec_params = ChatflowExecutionParams.model_validate_json(payload)
exec_params = AppExecutionParams.model_validate_json(payload)
logger.info("chatflow_execute_task run with params: %s", exec_params)
runner = _ChatflowRunner(db.engine, exec_params=exec_params)
runner = _AppRunner(db.engine, exec_params=exec_params)
return runner.run()
@shared_task(queue="chatflow_execute", name="resume_chatflow_execution")
def resume_chatflow_execution(payload: dict[str, Any]) -> None:
def _resume_app_execution(payload: dict[str, Any]) -> None:
workflow_run_id = payload["workflow_run_id"]
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
@ -245,16 +295,11 @@ def resume_chatflow_execution(payload: dict[str, Any]) -> None:
return
generate_entity = resumption_context.get_generate_entity()
if not isinstance(generate_entity, AdvancedChatAppGenerateEntity):
logger.error(
"Resumption entity is not AdvancedChatAppGenerateEntity for workflow run %s (found %s)",
workflow_run_id,
type(generate_entity),
)
return
graph_runtime_state = GraphRuntimeState.from_snapshot(resumption_context.serialized_graph_runtime_state)
conversation = None
message = None
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = session.get(WorkflowRun, workflow_run_id)
if workflow_run is None:
@ -271,29 +316,38 @@ def resume_chatflow_execution(payload: dict[str, Any]) -> None:
logger.warning("App %s not found during resume", workflow_run.app_id)
return
if generate_entity.conversation_id is None:
logger.warning("Conversation id missing in resumption context for workflow run %s", workflow_run_id)
return
conversation = session.get(Conversation, generate_entity.conversation_id)
if conversation is None:
logger.warning(
"Conversation %s not found for workflow run %s", generate_entity.conversation_id, workflow_run_id
)
return
message = session.scalar(
select(Message).where(Message.workflow_run_id == workflow_run_id).order_by(Message.created_at.desc())
)
if message is None:
logger.warning("Message not found for workflow run %s", workflow_run_id)
return
user = _resolve_user_for_run(session, workflow_run)
if user is None:
logger.warning("User %s not found for workflow run %s", workflow_run.created_by, workflow_run_id)
return
if isinstance(generate_entity, AdvancedChatAppGenerateEntity):
if generate_entity.conversation_id is None:
logger.warning("Conversation id missing in resumption context for workflow run %s", workflow_run_id)
return
conversation = session.get(Conversation, generate_entity.conversation_id)
if conversation is None:
logger.warning(
"Conversation %s not found for workflow run %s", generate_entity.conversation_id, workflow_run_id
)
return
message = session.scalar(
select(Message).where(Message.workflow_run_id == workflow_run_id).order_by(Message.created_at.desc())
)
if message is None:
logger.warning("Message not found for workflow run %s", workflow_run_id)
return
if not isinstance(generate_entity, (AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity)):
logger.error(
"Unsupported resumption entity for workflow run %s (found %s)",
workflow_run_id,
type(generate_entity),
)
return
workflow_run_repo.resume_workflow_pause(workflow_run_id, pause_entity)
pause_config = PauseStateLayerConfig(
@ -301,6 +355,52 @@ def resume_chatflow_execution(payload: dict[str, Any]) -> None:
state_owner_user_id=workflow.created_by,
)
if isinstance(generate_entity, AdvancedChatAppGenerateEntity):
assert conversation is not None
assert message is not None
_resume_advanced_chat(
app_model=app_model,
workflow=workflow,
user=user,
conversation=conversation,
message=message,
generate_entity=generate_entity,
graph_runtime_state=graph_runtime_state,
session_factory=session_factory,
pause_state_config=pause_config,
workflow_run_id=workflow_run_id,
workflow_run=workflow_run,
)
elif isinstance(generate_entity, WorkflowAppGenerateEntity):
_resume_workflow(
app_model=app_model,
workflow=workflow,
user=user,
generate_entity=generate_entity,
graph_runtime_state=graph_runtime_state,
session_factory=session_factory,
pause_state_config=pause_config,
workflow_run_id=workflow_run_id,
workflow_run=workflow_run,
workflow_run_repo=workflow_run_repo,
pause_entity=pause_entity,
)
def _resume_advanced_chat(
*,
app_model: App,
workflow: Workflow,
user: Account | EndUser,
conversation: Conversation,
message: Message,
generate_entity: AdvancedChatAppGenerateEntity,
graph_runtime_state: GraphRuntimeState,
session_factory: sessionmaker,
pause_state_config: PauseStateLayerConfig,
workflow_run_id: str,
workflow_run: WorkflowRun,
) -> None:
try:
triggered_from = WorkflowRunTriggeredFrom(workflow_run.triggered_from)
except ValueError:
@ -332,7 +432,7 @@ def resume_chatflow_execution(payload: dict[str, Any]) -> None:
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
graph_runtime_state=graph_runtime_state,
pause_state_config=pause_config,
pause_state_config=pause_state_config,
)
except Exception:
logger.exception("Failed to resume chatflow execution for workflow run %s", workflow_run_id)
@ -346,4 +446,76 @@ def resume_chatflow_execution(payload: dict[str, Any]) -> None:
workflow_run_id,
)
else:
_publish_streaming_response(response, publish_uuid)
_publish_streaming_response(response, publish_uuid, AppMode.ADVANCED_CHAT)
def _resume_workflow(
    *,
    app_model: App,
    workflow: Workflow,
    user: Account | EndUser,
    generate_entity: WorkflowAppGenerateEntity,
    graph_runtime_state: GraphRuntimeState,
    session_factory: sessionmaker,
    pause_state_config: PauseStateLayerConfig,
    workflow_run_id: str,
    workflow_run: WorkflowRun,
    workflow_run_repo,  # repository exposing delete_workflow_pause(); exact type not visible here
    pause_entity,  # persisted pause record removed once resumption completes; exact type not visible here
) -> None:
    """Resume a paused workflow-app run and clean up its pause state.

    Rebuilds the execution/node-execution repositories for the run, asks
    ``WorkflowAppGenerator.resume`` to continue from ``graph_runtime_state``,
    optionally publishes the streaming response, and finally deletes the
    pause record via ``workflow_run_repo``.
    """
    try:
        triggered_from = WorkflowRunTriggeredFrom(workflow_run.triggered_from)
    except ValueError:
        # Unknown/legacy trigger values fall back to a plain app run.
        triggered_from = WorkflowRunTriggeredFrom.APP_RUN
    workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
        session_factory=session_factory,
        user=user,
        app_id=app_model.id,
        triggered_from=triggered_from,
    )
    workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
        session_factory=session_factory,
        user=user,
        app_id=app_model.id,
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    generator = WorkflowAppGenerator()
    try:
        response = generator.resume(
            app_model=app_model,
            workflow=workflow,
            user=user,
            application_generate_entity=generate_entity,
            graph_runtime_state=graph_runtime_state,
            workflow_execution_repository=workflow_execution_repository,
            workflow_node_execution_repository=workflow_node_execution_repository,
            pause_state_config=pause_state_config,
        )
    except Exception:
        # Log with full traceback, then let the caller / task runner handle the failure.
        logger.exception("Failed to resume workflow execution for workflow run %s", workflow_run_id)
        raise
    if generate_entity.stream:
        # Prefer the generate entity's execution id; fall back to the run id.
        publish_uuid = _coerce_uuid(generate_entity.workflow_execution_id) or _coerce_uuid(workflow_run_id)
        if publish_uuid is None:
            logger.warning(
                "Unable to publish streaming response for workflow run %s due to missing workflow_run_id",
                workflow_run_id,
            )
        else:
            _publish_streaming_response(response, publish_uuid, AppMode.WORKFLOW)
    # The pause record is removed even when streaming publication was skipped.
    workflow_run_repo.delete_workflow_pause(pause_entity)
@shared_task(queue=APP_EXECUTE_QUEUE, name="resume_app_execution")
def resume_app_execution(payload: dict[str, Any]) -> None:
    """Celery entry point: resume a paused app execution.

    Thin wrapper that delegates to the module-level dispatcher, which inspects
    the stored resumption context and routes to the concrete resume path.
    """
    _resume_app_execution(payload)
@shared_task(queue=APP_EXECUTE_QUEUE, name="resume_chatflow_execution")
def resume_chatflow_execution(payload: dict[str, Any]) -> None:
    """Celery entry point registered under the chatflow-specific task name.

    Delegates to the same dispatcher as ``resume_app_execution`` —
    presumably kept so tasks already enqueued under this name still
    resume correctly (TODO confirm before removing).
    """
    _resume_app_execution(payload)

View File

@ -26,7 +26,7 @@ from models.account import Account
from models.enums import AppTriggerType, CreatorUserRole, WorkflowRunTriggeredFrom, WorkflowTriggerStatus
from models.model import App, EndUser, Tenant
from models.trigger import WorkflowTriggerLog
from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom
from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom, WorkflowRun
from repositories.factory import DifyAPIRepositoryFactory
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
from services.errors.app import WorkflowNotFoundError
@ -40,12 +40,6 @@ from tasks.workflow_cfs_scheduler.entities import AsyncWorkflowQueue, AsyncWorkf
logger = logging.getLogger(__name__)
_TRIGGER_TO_RUN_SOURCE = {
AppTriggerType.TRIGGER_WEBHOOK: WorkflowRunTriggeredFrom.WEBHOOK,
AppTriggerType.TRIGGER_SCHEDULE: WorkflowRunTriggeredFrom.SCHEDULE,
AppTriggerType.TRIGGER_PLUGIN: WorkflowRunTriggeredFrom.PLUGIN,
}
@shared_task(queue=AsyncWorkflowQueue.PROFESSIONAL_QUEUE)
def execute_workflow_professional(task_data_dict: dict[str, Any]):
@ -204,44 +198,135 @@ def resume_workflow_execution(task_data_dict: dict[str, Any]) -> None:
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_factory)
pause_entity = workflow_run_repo.get_workflow_pause(task_data.workflow_run_id)
if pause_entity is None:
logger.warning("No pause state for workflow run %s", task_data.workflow_run_id)
return
workflow_run = workflow_run_repo.get_workflow_run_by_id_without_tenant(pause_entity.workflow_execution_id)
if workflow_run is None:
logger.warning("Workflow run not found for pause entity: pause_entity_id=%s", pause_entity.id)
return
try:
resumption_context = WorkflowResumptionContext.loads(pause_entity.get_state().decode())
except Exception as exc:
logger.exception("Failed to load resumption context for workflow run %s", task_data.workflow_run_id)
raise exc
generate_entity = resumption_context.get_generate_entity()
if not isinstance(generate_entity, WorkflowAppGenerateEntity):
logger.error(
"Unsupported resumption entity for workflow run %s: %s",
task_data.workflow_run_id,
type(generate_entity),
)
return
graph_runtime_state = GraphRuntimeState.from_snapshot(resumption_context.serialized_graph_runtime_state)
with session_factory() as session:
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
trigger_log = trigger_log_repo.get_by_id(task_data.workflow_trigger_log_id)
if not trigger_log:
logger.warning("Trigger log not found for resumption: %s", task_data.workflow_trigger_log_id)
return
pause_entity = workflow_run_repo.get_workflow_pause(task_data.workflow_run_id)
if pause_entity is None:
logger.warning("No pause state for workflow run %s", task_data.workflow_run_id)
return
try:
resumption_context = WorkflowResumptionContext.loads(pause_entity.get_state().decode())
except Exception as exc:
logger.exception("Failed to load resumption context for workflow run %s", task_data.workflow_run_id)
raise exc
generate_entity = resumption_context.get_generate_entity()
if not isinstance(generate_entity, WorkflowAppGenerateEntity):
logger.error(
"Unsupported resumption entity for workflow run %s: %s",
task_data.workflow_run_id,
type(generate_entity),
)
return
graph_runtime_state = GraphRuntimeState.from_snapshot(resumption_context.serialized_graph_runtime_state)
workflow = session.scalar(select(Workflow).where(Workflow.id == trigger_log.workflow_id))
workflow = session.scalar(select(Workflow).where(Workflow.id == workflow_run.workflow_id))
if workflow is None:
raise WorkflowNotFoundError(f"Workflow not found: {trigger_log.workflow_id}")
app_model = session.scalar(select(App).where(App.id == trigger_log.app_id))
raise WorkflowNotFoundError(
"Workflow not found: workflow_run_id=%s, workflow_id=%s", workflow_run.id, workflow_run.workflow_id
)
user = _get_user(session, workflow_run)
app_model = session.scalar(select(App).where(App.id == workflow_run.app_id))
if app_model is None:
raise WorkflowNotFoundError(f"App not found: {trigger_log.app_id}")
raise _AppNotFoundError(
"App not found: app_id=%s, workflow_run_id=%s", workflow_run.app_id, workflow_run.id
)
user = _get_user(session, trigger_log)
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory,
user=user,
app_id=generate_entity.app_config.app_id,
triggered_from=WorkflowRunTriggeredFrom(workflow_run.triggered_from),
)
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=user,
app_id=generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
pause_config = PauseStateLayerConfig(
session_factory=session_factory,
state_owner_user_id=workflow.created_by,
)
generator = WorkflowAppGenerator()
start_time = datetime.now(UTC)
graph_engine_layers = []
trigger_log = _query_trigger_log_info(session_factory, task_data.workflow_run_id)
if trigger_log:
cfs_plan_scheduler_entity = AsyncWorkflowCFSPlanEntity(
queue=AsyncWorkflowQueue(trigger_log.queue_name),
schedule_strategy=AsyncWorkflowSystemStrategy,
granularity=dify_config.ASYNC_WORKFLOW_SCHEDULER_GRANULARITY,
)
cfs_plan_scheduler = AsyncWorkflowCFSPlanScheduler(plan=cfs_plan_scheduler_entity)
graph_engine_layers.extend(
[
TimeSliceLayer(cfs_plan_scheduler),
TriggerPostLayer(cfs_plan_scheduler_entity, start_time, trigger_log.id, session_factory),
]
)
workflow_run_repo.resume_workflow_pause(task_data.workflow_run_id, pause_entity)
generator.resume(
app_model=app_model,
workflow=workflow,
user=user,
application_generate_entity=generate_entity,
graph_runtime_state=graph_runtime_state,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
graph_engine_layers=graph_engine_layers,
pause_state_config=pause_config,
)
workflow_run_repo.delete_workflow_pause(pause_entity)
def _get_user(session: Session, workflow_run: WorkflowRun) -> Account | EndUser:
    """Resolve the acting user recorded on a workflow run.

    Args:
        session: Open SQLAlchemy session used for the lookups.
        workflow_run: Run whose ``created_by`` / ``created_by_role`` identify the user.

    Returns:
        The ``Account`` (with ``current_tenant`` attached) or ``EndUser`` that
        created the run.

    Raises:
        _TenantNotFoundError: If the run's tenant no longer exists.
        _UserNotFoundError: If the recorded creator cannot be found.
    """
    tenant = session.scalar(select(Tenant).where(Tenant.id == workflow_run.tenant_id))
    if not tenant:
        # Exception messages are formatted eagerly with f-strings: plain Exception
        # subclasses do not interpolate logging-style "%s" placeholder arguments.
        raise _TenantNotFoundError(
            f"Tenant not found for WorkflowRun: tenant_id={workflow_run.tenant_id}, "
            f"workflow_run_id={workflow_run.id}"
        )

    # Resolve the creator according to the role recorded on the run.
    if workflow_run.created_by_role == CreatorUserRole.ACCOUNT:
        user = session.scalar(select(Account).where(Account.id == workflow_run.created_by))
        if user:
            # Accounts carry the tenant context for downstream permission checks.
            user.current_tenant = tenant
    else:  # CreatorUserRole.END_USER
        user = session.scalar(select(EndUser).where(EndUser.id == workflow_run.created_by))

    if not user:
        raise _UserNotFoundError(
            f"User not found: user_id={workflow_run.created_by}, "
            f"created_by_role={workflow_run.created_by_role}, workflow_run_id={workflow_run.id}"
        )
    return user
def _query_trigger_log_info(session_factory: sessionmaker[Session], workflow_run_id) -> WorkflowTriggerLog | None:
with session_factory() as session, session.begin():
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
trigger_log = trigger_log_repo.get_by_workflow_run_id(workflow_run_id)
if not trigger_log:
logger.debug("Trigger log not found for workflow_run: workflow_run_id=%s", workflow_run_id)
return
cfs_plan_scheduler_entity = AsyncWorkflowCFSPlanEntity(
queue=trigger_log.queue_name,
@ -255,74 +340,14 @@ def resume_workflow_execution(task_data_dict: dict[str, Any]) -> None:
except ValueError:
trigger_type = AppTriggerType.UNKNOWN
triggered_from = _TRIGGER_TO_RUN_SOURCE.get(trigger_type, WorkflowRunTriggeredFrom.APP_RUN)
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory,
user=user,
app_id=generate_entity.app_config.app_id,
triggered_from=triggered_from,
)
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=user,
app_id=generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
pause_config = PauseStateLayerConfig(
session_factory=session_factory,
state_owner_user_id=workflow.created_by,
)
workflow_run_repo.resume_workflow_pause(task_data.workflow_run_id, pause_entity)
trigger_log.status = WorkflowTriggerStatus.RUNNING
trigger_log_repo.update(trigger_log)
session.commit()
generator = WorkflowAppGenerator()
start_time = datetime.now(UTC)
try:
generator.resume(
app_model=app_model,
workflow=workflow,
user=user,
application_generate_entity=generate_entity,
graph_runtime_state=graph_runtime_state,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
graph_engine_layers=[
TimeSliceLayer(cfs_plan_scheduler),
TriggerPostLayer(cfs_plan_scheduler_entity, start_time, trigger_log.id, session_factory),
],
pause_state_config=pause_config,
)
except Exception as exc:
trigger_log.status = WorkflowTriggerStatus.FAILED
trigger_log.error = str(exc)
trigger_log.finished_at = datetime.now(UTC)
trigger_log_repo.update(trigger_log)
session.commit()
raise
class _TenantNotFoundError(Exception):
    """Raised when the tenant referenced by a workflow run no longer exists."""

    pass
def _get_user(session: Session, trigger_log: WorkflowTriggerLog) -> Account | EndUser:
"""Compose user from trigger log"""
tenant = session.scalar(select(Tenant).where(Tenant.id == trigger_log.tenant_id))
if not tenant:
raise ValueError(f"Tenant not found: {trigger_log.tenant_id}")
class _UserNotFoundError(Exception):
    """Raised when the user recorded as a workflow run's creator cannot be found."""

    pass
# Get user from trigger log
if trigger_log.created_by_role == CreatorUserRole.ACCOUNT:
user = session.scalar(select(Account).where(Account.id == trigger_log.created_by))
if user:
user.current_tenant = tenant
else: # CreatorUserRole.END_USER
user = session.scalar(select(EndUser).where(EndUser.id == trigger_log.created_by))
if not user:
raise ValueError(f"User not found: {trigger_log.created_by} (role: {trigger_log.created_by_role})")
return user
class _AppNotFoundError(Exception):
    """Raised when the app referenced by a workflow run no longer exists."""

    pass

View File

@ -0,0 +1,160 @@
from __future__ import annotations
from dataclasses import dataclass
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from flask import Flask
from controllers.console import wraps as console_wraps
from controllers.console.app import workflow as workflow_module
from controllers.console.app import wraps as app_wraps
from libs import login as login_lib
from models.account import Account, AccountStatus, TenantAccountRole
from models.model import AppMode
def _make_account() -> Account:
    """Build an active owner account wired with a fake tenant for the console guards."""
    acct = Account(name="tester", email="tester@example.com")
    acct.id = "account-123"  # type: ignore[assignment]
    acct.role = TenantAccountRole.OWNER
    acct.status = AccountStatus.ACTIVE
    # Fake tenant context plus the flask-login proxy hook the guards call.
    acct._current_tenant = SimpleNamespace(id="tenant-123")  # type: ignore[attr-defined]
    acct._get_current_object = lambda: acct  # type: ignore[attr-defined]
    return acct
def _make_app(mode: AppMode) -> SimpleNamespace:
    """Minimal stand-in for an App model carrying only the fields the guards read."""
    attrs = {"id": "app-123", "tenant_id": "tenant-123", "mode": mode.value}
    return SimpleNamespace(**attrs)
def _patch_console_guards(monkeypatch: pytest.MonkeyPatch, account: Account, app_model: SimpleNamespace) -> None:
    """Disable console setup/auth guardrails so the resource handlers run in isolation."""
    fake_current_account = lambda: (account, account.current_tenant_id)
    # Several modules resolve the acting account independently; patch each one.
    for module in (login_lib, console_wraps, app_wraps, workflow_module):
        monkeypatch.setattr(module, "current_account_with_tenant", fake_current_account)
    monkeypatch.setattr("configs.dify_config.EDITION", "CLOUD")
    monkeypatch.setattr(console_wraps.dify_config, "EDITION", "CLOUD")
    monkeypatch.setattr(login_lib.dify_config, "LOGIN_DISABLED", True)
    monkeypatch.setattr(login_lib, "current_user", account)
    monkeypatch.setattr(login_lib, "check_csrf_token", lambda *_, **__: None)
    monkeypatch.delenv("INIT_PASSWORD", raising=False)
    # Short-circuit the DB lookup performed by the app-model decorator.
    monkeypatch.setattr(app_wraps, "_load_app_model", lambda _app_id: app_model)
@dataclass
class PreviewCase:
    """Parametrization record for the human-input form GET (preview) endpoint tests."""

    resource_cls: type  # console resource class under test
    path: str  # request path routed to that resource
    mode: AppMode  # app mode the fake app model reports
@pytest.mark.parametrize(
    "case",
    [
        PreviewCase(
            resource_cls=workflow_module.AdvancedChatDraftHumanInputFormApi,
            path="/console/api/apps/app-123/advanced-chat/workflows/draft/human-input/nodes/node-42/form",
            mode=AppMode.ADVANCED_CHAT,
        ),
        PreviewCase(
            resource_cls=workflow_module.WorkflowDraftHumanInputFormApi,
            path="/console/api/apps/app-123/workflows/draft/human-input/nodes/node-42/form",
            mode=AppMode.WORKFLOW,
        ),
    ],
)
def test_human_input_preview_delegates_to_service(
    app: Flask, monkeypatch: pytest.MonkeyPatch, case: PreviewCase
) -> None:
    """GET on both preview endpoints returns the service payload unchanged
    and forwards the request's ``inputs`` as ``manual_inputs``."""
    account = _make_account()
    app_model = _make_app(case.mode)
    _patch_console_guards(monkeypatch, account, app_model)

    preview_payload = {
        "form_id": "node-42",
        "form_content": "<div>example</div>",
        "inputs": [{"name": "topic"}],
        "actions": [{"id": "continue"}],
    }
    # Replace WorkflowService so the handler hits the mock, not the real service.
    service_instance = MagicMock()
    service_instance.get_human_input_form_preview.return_value = preview_payload
    monkeypatch.setattr(workflow_module, "WorkflowService", MagicMock(return_value=service_instance))

    with app.test_request_context(case.path, method="GET", json={"inputs": {"topic": "tech"}}):
        response = case.resource_cls().get(app_id=app_model.id, node_id="node-42")

    # The handler must return the service result as-is and delegate exactly once.
    assert response == preview_payload
    service_instance.get_human_input_form_preview.assert_called_once_with(
        app_model=app_model,
        account=account,
        node_id="node-42",
        manual_inputs={"topic": "tech"},
    )
@dataclass
class SubmitCase:
    """Parametrization record for the human-input form POST (submit) endpoint tests."""

    resource_cls: type  # console resource class under test
    path: str  # request path routed to that resource
    mode: AppMode  # app mode the fake app model reports
@pytest.mark.parametrize(
    "case",
    [
        SubmitCase(
            resource_cls=workflow_module.AdvancedChatDraftHumanInputFormApi,
            path="/console/api/apps/app-123/advanced-chat/workflows/draft/human-input/nodes/node-99/form",
            mode=AppMode.ADVANCED_CHAT,
        ),
        SubmitCase(
            resource_cls=workflow_module.WorkflowDraftHumanInputFormApi,
            path="/console/api/apps/app-123/workflows/draft/human-input/nodes/node-99/form",
            mode=AppMode.WORKFLOW,
        ),
    ],
)
def test_human_input_submit_forwards_payload(app: Flask, monkeypatch: pytest.MonkeyPatch, case: SubmitCase) -> None:
    """POST on both submit endpoints forwards ``inputs`` and ``action`` from the
    JSON body to the service and returns the service result unchanged."""
    account = _make_account()
    app_model = _make_app(case.mode)
    _patch_console_guards(monkeypatch, account, app_model)

    result_payload = {"node_id": "node-99", "outputs": {"__rendered_content": "<p>done</p>"}, "action": "approve"}
    # Replace WorkflowService so the handler hits the mock, not the real service.
    service_instance = MagicMock()
    service_instance.submit_human_input_form_preview.return_value = result_payload
    monkeypatch.setattr(workflow_module, "WorkflowService", MagicMock(return_value=service_instance))

    with app.test_request_context(
        case.path,
        method="POST",
        json={"inputs": {"answer": "42"}, "action": "approve"},
    ):
        response = case.resource_cls().post(app_id=app_model.id, node_id="node-99")

    # The handler must return the service result as-is and delegate exactly once.
    assert response == result_payload
    service_instance.submit_human_input_form_preview.assert_called_once_with(
        app_model=app_model,
        account=account,
        node_id="node-99",
        form_inputs={"answer": "42"},
        action="approve",
    )
def test_human_input_preview_rejects_non_mapping(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
    """A non-dict ``inputs`` payload must be rejected before reaching the service."""
    account = _make_account()
    app_model = _make_app(AppMode.ADVANCED_CHAT)
    _patch_console_guards(monkeypatch, account, app_model)

    url = "/console/api/apps/app-123/advanced-chat/workflows/draft/human-input/nodes/node-1/form"
    bad_payload = {"inputs": ["not-a-dict"]}
    with app.test_request_context(url, method="GET", json=bad_payload), pytest.raises(ValueError):
        workflow_module.AdvancedChatDraftHumanInputFormApi().get(app_id=app_model.id, node_id="node-1")

View File

@ -124,7 +124,12 @@ class TestWorkflowResponseConverter:
original_data = {"large_field": "x" * 10000, "metadata": "info"}
truncated_data = {"large_field": "[TRUNCATED]", "metadata": "info"}
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
converter.workflow_start_to_stream_response(
task_id="bootstrap",
workflow_run_id="run-id",
workflow_id="wf-id",
is_resumption=False,
)
start_event = self.create_node_started_event()
converter.workflow_node_start_to_stream_response(
event=start_event,
@ -160,7 +165,12 @@ class TestWorkflowResponseConverter:
original_data = {"small": "data"}
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
converter.workflow_start_to_stream_response(
task_id="bootstrap",
workflow_run_id="run-id",
workflow_id="wf-id",
is_resumption=False,
)
start_event = self.create_node_started_event()
converter.workflow_node_start_to_stream_response(
event=start_event,
@ -191,7 +201,12 @@ class TestWorkflowResponseConverter:
"""Test node finish response when process_data is None."""
converter = self.create_workflow_response_converter()
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
converter.workflow_start_to_stream_response(
task_id="bootstrap",
workflow_run_id="run-id",
workflow_id="wf-id",
is_resumption=False,
)
start_event = self.create_node_started_event()
converter.workflow_node_start_to_stream_response(
event=start_event,
@ -225,7 +240,12 @@ class TestWorkflowResponseConverter:
original_data = {"large_field": "x" * 10000, "metadata": "info"}
truncated_data = {"large_field": "[TRUNCATED]", "metadata": "info"}
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
converter.workflow_start_to_stream_response(
task_id="bootstrap",
workflow_run_id="run-id",
workflow_id="wf-id",
is_resumption=False,
)
start_event = self.create_node_started_event()
converter.workflow_node_start_to_stream_response(
event=start_event,
@ -261,7 +281,12 @@ class TestWorkflowResponseConverter:
original_data = {"small": "data"}
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
converter.workflow_start_to_stream_response(
task_id="bootstrap",
workflow_run_id="run-id",
workflow_id="wf-id",
is_resumption=False,
)
start_event = self.create_node_started_event()
converter.workflow_node_start_to_stream_response(
event=start_event,
@ -400,6 +425,7 @@ class TestWorkflowResponseConverterServiceApiTruncation:
task_id="test-task-id",
workflow_run_id="test-workflow-run-id",
workflow_id="test-workflow-id",
is_resumption=False,
)
return converter

View File

@ -112,7 +112,12 @@ def _build_converter():
def test_queue_workflow_paused_event_to_stream_responses():
converter = _build_converter()
converter.workflow_start_to_stream_response(task_id="task", workflow_run_id="run-id", workflow_id="workflow-id")
converter.workflow_start_to_stream_response(
task_id="task",
workflow_run_id="run-id",
workflow_id="workflow-id",
is_resumption=False,
)
reason = HumanInputRequired(
form_id="form-1",

View File

@ -508,9 +508,12 @@ class TestConversationServiceMessageCreation:
within conversations.
"""
@patch("services.message_service._create_execution_extra_content_repository")
@patch("services.message_service.db.session")
@patch("services.message_service.ConversationService.get_conversation")
def test_pagination_by_first_id_without_first_id(self, mock_get_conversation, mock_db_session):
def test_pagination_by_first_id_without_first_id(
self, mock_get_conversation, mock_db_session, mock_create_extra_repo
):
"""
Test message pagination without specifying first_id.
@ -540,6 +543,9 @@ class TestConversationServiceMessageCreation:
mock_query.order_by.return_value = mock_query # ORDER BY returns self for chaining
mock_query.limit.return_value = mock_query # LIMIT returns self for chaining
mock_query.all.return_value = messages # Final .all() returns the messages
mock_repository = MagicMock()
mock_repository.get_by_message_ids.return_value = [[] for _ in messages]
mock_create_extra_repo.return_value = mock_repository
# Act - Call the pagination method without first_id
result = MessageService.pagination_by_first_id(
@ -556,9 +562,10 @@ class TestConversationServiceMessageCreation:
# Verify conversation was looked up with correct parameters
mock_get_conversation.assert_called_once_with(app_model=app_model, user=user, conversation_id=conversation.id)
@patch("services.message_service._create_execution_extra_content_repository")
@patch("services.message_service.db.session")
@patch("services.message_service.ConversationService.get_conversation")
def test_pagination_by_first_id_with_first_id(self, mock_get_conversation, mock_db_session):
def test_pagination_by_first_id_with_first_id(self, mock_get_conversation, mock_db_session, mock_create_extra_repo):
"""
Test message pagination with first_id specified.
@ -590,6 +597,9 @@ class TestConversationServiceMessageCreation:
mock_query.limit.return_value = mock_query # LIMIT returns self for chaining
mock_query.first.return_value = first_message # First message returned
mock_query.all.return_value = messages # Remaining messages returned
mock_repository = MagicMock()
mock_repository.get_by_message_ids.return_value = [[] for _ in messages]
mock_create_extra_repo.return_value = mock_repository
# Act - Call the pagination method with first_id
result = MessageService.pagination_by_first_id(
@ -684,9 +694,10 @@ class TestConversationServiceMessageCreation:
assert result.data == []
assert result.has_more is False
@patch("services.message_service._create_execution_extra_content_repository")
@patch("services.message_service.db.session")
@patch("services.message_service.ConversationService.get_conversation")
def test_pagination_with_has_more_flag(self, mock_get_conversation, mock_db_session):
def test_pagination_with_has_more_flag(self, mock_get_conversation, mock_db_session, mock_create_extra_repo):
"""
Test that has_more flag is correctly set when there are more messages.
@ -716,6 +727,9 @@ class TestConversationServiceMessageCreation:
mock_query.order_by.return_value = mock_query # ORDER BY returns self for chaining
mock_query.limit.return_value = mock_query # LIMIT returns self for chaining
mock_query.all.return_value = messages # Final .all() returns the messages
mock_repository = MagicMock()
mock_repository.get_by_message_ids.return_value = [[] for _ in messages]
mock_create_extra_repo.return_value = mock_repository
# Act
result = MessageService.pagination_by_first_id(
@ -730,9 +744,10 @@ class TestConversationServiceMessageCreation:
assert len(result.data) == limit # Extra message should be removed
assert result.has_more is True # Flag should be set
@patch("services.message_service._create_execution_extra_content_repository")
@patch("services.message_service.db.session")
@patch("services.message_service.ConversationService.get_conversation")
def test_pagination_with_ascending_order(self, mock_get_conversation, mock_db_session):
def test_pagination_with_ascending_order(self, mock_get_conversation, mock_db_session, mock_create_extra_repo):
"""
Test message pagination with ascending order.
@ -761,6 +776,9 @@ class TestConversationServiceMessageCreation:
mock_query.order_by.return_value = mock_query # ORDER BY returns self for chaining
mock_query.limit.return_value = mock_query # LIMIT returns self for chaining
mock_query.all.return_value = messages # Final .all() returns the messages
mock_repository = MagicMock()
mock_repository.get_by_message_ids.return_value = [[] for _ in messages]
mock_create_extra_repo.return_value = mock_repository
# Act
result = MessageService.pagination_by_first_id(

View File

@ -65,72 +65,25 @@ def sample_form_record():
)
def test_enqueue_resume_dispatches_task(mocker, mock_session_factory):
def test_enqueue_resume_dispatches_task_for_workflow(mocker, mock_session_factory):
session_factory, session = mock_session_factory
service = HumanInputService(session_factory)
trigger_log = MagicMock()
trigger_log.id = "trigger-log-id"
trigger_log.queue_name = "workflow_queue"
repo_cls = mocker.patch(
"services.human_input_service.SQLAlchemyWorkflowTriggerLogRepository",
autospec=True,
)
repo = repo_cls.return_value
repo.get_by_workflow_run_id.return_value = trigger_log
resume_task = mocker.patch("services.human_input_service.resume_workflow_execution")
service._enqueue_resume("workflow-run-id")
repo_cls.assert_called_once_with(session)
resume_task.apply_async.assert_called_once()
call_kwargs = resume_task.apply_async.call_args.kwargs
assert call_kwargs["queue"] == "workflow_queue"
payload = call_kwargs["kwargs"]["task_data_dict"]
assert payload["workflow_trigger_log_id"] == "trigger-log-id"
assert payload["workflow_run_id"] == "workflow-run-id"
def test_enqueue_resume_no_trigger_log(mocker, mock_session_factory):
session_factory, session = mock_session_factory
service = HumanInputService(session_factory)
repo_cls = mocker.patch(
"services.human_input_service.SQLAlchemyWorkflowTriggerLogRepository",
autospec=True,
)
repo = repo_cls.return_value
repo.get_by_workflow_run_id.return_value = None
resume_task = mocker.patch("services.human_input_service.resume_workflow_execution")
service._enqueue_resume("workflow-run-id")
repo_cls.assert_called_once_with(session)
resume_task.apply_async.assert_not_called()
def test_enqueue_resume_chatflow_fallback(mocker, mock_session_factory):
session_factory, session = mock_session_factory
service = HumanInputService(session_factory)
repo_cls = mocker.patch(
"services.human_input_service.SQLAlchemyWorkflowTriggerLogRepository",
autospec=True,
)
repo = repo_cls.return_value
repo.get_by_workflow_run_id.return_value = None
workflow_run = MagicMock()
workflow_run.app_id = "app-id"
workflow_run_repo = MagicMock()
workflow_run_repo.get_workflow_run_by_id_without_tenant.return_value = workflow_run
mocker.patch(
"services.human_input_service.DifyAPIRepositoryFactory.create_api_workflow_run_repository",
return_value=workflow_run_repo,
)
app = MagicMock()
app.mode = "advanced-chat"
app.mode = "workflow"
session.execute.return_value.scalar_one_or_none.return_value = app
session.get.side_effect = [workflow_run, app]
resume_task = mocker.patch("services.human_input_service.resume_chatflow_execution")
resume_task = mocker.patch("services.human_input_service.resume_app_execution")
service._enqueue_resume("workflow-run-id")
@ -140,6 +93,59 @@ def test_enqueue_resume_chatflow_fallback(mocker, mock_session_factory):
assert call_kwargs["kwargs"]["payload"]["workflow_run_id"] == "workflow-run-id"
def test_enqueue_resume_dispatches_task_for_advanced_chat(mocker, mock_session_factory):
    """Advanced-chat apps are resumed via ``resume_app_execution`` on the chatflow queue."""
    session_factory, session = mock_session_factory
    service = HumanInputService(session_factory)

    # Fake the workflow-run lookup performed through the repository factory.
    workflow_run = MagicMock()
    workflow_run.app_id = "app-id"
    workflow_run_repo = MagicMock()
    workflow_run_repo.get_workflow_run_by_id_without_tenant.return_value = workflow_run
    mocker.patch(
        "services.human_input_service.DifyAPIRepositoryFactory.create_api_workflow_run_repository",
        return_value=workflow_run_repo,
    )
    # The app-mode query resolves to an advanced-chat app.
    app = MagicMock()
    app.mode = "advanced-chat"
    session.execute.return_value.scalar_one_or_none.return_value = app

    resume_task = mocker.patch("services.human_input_service.resume_app_execution")

    service._enqueue_resume("workflow-run-id")

    # The task is dispatched once, on the chatflow queue, carrying the run id.
    resume_task.apply_async.assert_called_once()
    call_kwargs = resume_task.apply_async.call_args.kwargs
    assert call_kwargs["queue"] == "chatflow_execute"
    assert call_kwargs["kwargs"]["payload"]["workflow_run_id"] == "workflow-run-id"
def test_enqueue_resume_skips_unsupported_app_mode(mocker, mock_session_factory):
    """Apps with an unsupported mode (e.g. completion) must not be dispatched."""
    session_factory, session = mock_session_factory
    service = HumanInputService(session_factory)

    # Fake the workflow-run lookup performed through the repository factory.
    workflow_run = MagicMock()
    workflow_run.app_id = "app-id"
    workflow_run_repo = MagicMock()
    workflow_run_repo.get_workflow_run_by_id_without_tenant.return_value = workflow_run
    mocker.patch(
        "services.human_input_service.DifyAPIRepositoryFactory.create_api_workflow_run_repository",
        return_value=workflow_run_repo,
    )
    # The app-mode query resolves to a mode the service does not resume.
    app = MagicMock()
    app.mode = "completion"
    session.execute.return_value.scalar_one_or_none.return_value = app

    resume_task = mocker.patch("services.human_input_service.resume_app_execution")

    service._enqueue_resume("workflow-run-id")

    resume_task.apply_async.assert_not_called()
def test_get_form_definition_by_id_uses_repository(sample_form_record, mock_session_factory):
session_factory, _ = mock_session_factory
repo = MagicMock(spec=HumanInputFormSubmissionRepository)