mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 09:28:04 +08:00
feat(api): Implement HITL for Workflow, add is_resumption for start event
This commit is contained in:
@ -32,6 +32,8 @@ ignore_imports =
|
||||
core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine
|
||||
core.workflow.nodes.loop.loop_node -> core.workflow.graph
|
||||
core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine.command_channels
|
||||
# TODO(QuantumGhost): fix the import violation later
|
||||
core.workflow.entities.pause_reason -> core.workflow.nodes.human_input.entities
|
||||
|
||||
[importlinter:contract:rsc]
|
||||
name = RSC
|
||||
|
||||
@ -527,6 +527,11 @@ class WorkflowDraftRunLoopNodeApi(Resource):
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class HumanInputSubmitPayload(BaseModel):
|
||||
inputs: dict[str, Any]
|
||||
action: str
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflows/draft/human-input/nodes/<string:node_id>/form")
|
||||
class AdvancedChatDraftHumanInputFormApi(Resource):
|
||||
@console_ns.doc("get_advanced_chat_draft_human_input_form")
|
||||
@ -580,19 +585,14 @@ class AdvancedChatDraftHumanInputFormApi(Resource):
|
||||
Submit human input form preview
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, location="json")
|
||||
.add_argument("action", type=str, required=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = HumanInputSubmitPayload.model_validate(console_ns.payload or {})
|
||||
workflow_service = WorkflowService()
|
||||
result = workflow_service.submit_human_input_form_preview(
|
||||
app_model=app_model,
|
||||
account=current_user,
|
||||
node_id=node_id,
|
||||
form_inputs=args["inputs"],
|
||||
action=args["action"],
|
||||
form_inputs=args.inputs,
|
||||
action=args.action,
|
||||
)
|
||||
return jsonable_encoder(result)
|
||||
|
||||
@ -650,19 +650,14 @@ class WorkflowDraftHumanInputFormApi(Resource):
|
||||
Submit human input form preview
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, location="json")
|
||||
.add_argument("action", type=str, required=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
workflow_service = WorkflowService()
|
||||
args = HumanInputSubmitPayload.model_validate(console_ns.payload or {})
|
||||
result = workflow_service.submit_human_input_form_preview(
|
||||
app_model=app_model,
|
||||
account=current_user,
|
||||
node_id=node_id,
|
||||
form_inputs=args["inputs"],
|
||||
action=args["action"],
|
||||
form_inputs=args.inputs,
|
||||
action=args.action,
|
||||
)
|
||||
return jsonable_encoder(result)
|
||||
|
||||
|
||||
@ -411,11 +411,8 @@ class ConsoleWorkflowPauseDetailsApi(Resource):
|
||||
is_paused = workflow_run.status == WorkflowExecutionStatus.PAUSED
|
||||
if not is_paused:
|
||||
return {
|
||||
"is_suspended": False,
|
||||
"paused_at": None,
|
||||
"paused_nodes": [],
|
||||
"pending_human_inputs": [],
|
||||
"pause_reasons": [],
|
||||
}, 200
|
||||
|
||||
pause_entity = workflow_run_repo.get_workflow_pause(workflow_run_id)
|
||||
@ -430,11 +427,8 @@ class ConsoleWorkflowPauseDetailsApi(Resource):
|
||||
|
||||
# Build response
|
||||
response = {
|
||||
"is_suspended": True,
|
||||
"paused_at": workflow_run.created_at.isoformat() + "Z" if workflow_run.created_at else None,
|
||||
"paused_nodes": [],
|
||||
"pending_human_inputs": [],
|
||||
"pause_reasons": pause_reasons,
|
||||
}
|
||||
|
||||
# Add pending human input forms
|
||||
|
||||
@ -157,6 +157,7 @@ class ConsoleWorkflowEventsApi(Resource):
|
||||
app = _retrieve_app_for_workflow_run(session, workflow_run)
|
||||
|
||||
if workflow_run.finished_at is not None:
|
||||
# TODO(QuantumGhost): should we modify the handling for finished workflow run here?
|
||||
response = WorkflowResponseConverter.workflow_run_result_to_finish_response(
|
||||
task_id=workflow_run.id,
|
||||
workflow_run=workflow_run,
|
||||
|
||||
@ -309,6 +309,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run_id=run_id,
|
||||
workflow_id=self._workflow_id,
|
||||
is_resumption=event.is_resumption,
|
||||
)
|
||||
|
||||
yield workflow_start_resp
|
||||
|
||||
@ -197,6 +197,7 @@ class WorkflowResponseConverter:
|
||||
task_id: str,
|
||||
workflow_run_id: str,
|
||||
workflow_id: str,
|
||||
is_resumption: bool,
|
||||
) -> WorkflowStartStreamResponse:
|
||||
run_id = self._ensure_workflow_run_id(workflow_run_id)
|
||||
started_at = naive_utc_now()
|
||||
@ -210,6 +211,7 @@ class WorkflowResponseConverter:
|
||||
workflow_id=workflow_id,
|
||||
inputs=self._workflow_inputs,
|
||||
created_at=int(started_at.timestamp()),
|
||||
is_resumption=is_resumption,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@ -64,6 +64,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[True],
|
||||
call_depth: int,
|
||||
workflow_run_id: str | uuid.UUID | None = None,
|
||||
triggered_from: WorkflowRunTriggeredFrom | None = None,
|
||||
root_node_id: str | None = None,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
@ -81,6 +82,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[False],
|
||||
call_depth: int,
|
||||
workflow_run_id: str | uuid.UUID | None = None,
|
||||
triggered_from: WorkflowRunTriggeredFrom | None = None,
|
||||
root_node_id: str | None = None,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
@ -98,6 +100,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool,
|
||||
call_depth: int,
|
||||
workflow_run_id: str | uuid.UUID | None = None,
|
||||
triggered_from: WorkflowRunTriggeredFrom | None = None,
|
||||
root_node_id: str | None = None,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
@ -114,6 +117,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
call_depth: int = 0,
|
||||
workflow_run_id: str | uuid.UUID | None = None,
|
||||
triggered_from: WorkflowRunTriggeredFrom | None = None,
|
||||
root_node_id: str | None = None,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
@ -152,7 +156,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
extras = {
|
||||
**extract_external_trace_id_from_args(args),
|
||||
}
|
||||
workflow_run_id = str(uuid.uuid4())
|
||||
workflow_run_id = str(workflow_run_id or uuid.uuid4())
|
||||
# FIXME (Yeuoly): we need to remove the SKIP_PREPARE_USER_INPUTS_KEY from the args
|
||||
# trigger shouldn't prepare user inputs
|
||||
if self._should_prepare_user_inputs(args):
|
||||
|
||||
7
api/core/app/apps/workflow/errors.py
Normal file
7
api/core/app/apps/workflow/errors.py
Normal file
@ -0,0 +1,7 @@
|
||||
from libs.exception import BaseHTTPException
|
||||
|
||||
|
||||
class WorkflowPausedInBlockingModeError(BaseHTTPException):
|
||||
error_code = "workflow_paused_in_blocking_mode"
|
||||
description = "Workflow execution paused for human input; blocking response mode is not supported."
|
||||
code = 400
|
||||
@ -10,6 +10,7 @@ from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.apps.common.graph_runtime_state_support import GraphRuntimeStateSupport
|
||||
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
|
||||
from core.app.apps.workflow.errors import WorkflowPausedInBlockingModeError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
@ -47,6 +48,7 @@ from core.app.entities.task_entities import (
|
||||
WorkflowAppBlockingResponse,
|
||||
WorkflowAppStreamResponse,
|
||||
WorkflowFinishStreamResponse,
|
||||
WorkflowPauseStreamResponse,
|
||||
WorkflowStartStreamResponse,
|
||||
)
|
||||
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
||||
@ -133,6 +135,8 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
for stream_response in generator:
|
||||
if isinstance(stream_response, ErrorStreamResponse):
|
||||
raise stream_response.err
|
||||
elif isinstance(stream_response, WorkflowPauseStreamResponse):
|
||||
raise WorkflowPausedInBlockingModeError()
|
||||
elif isinstance(stream_response, WorkflowFinishStreamResponse):
|
||||
response = WorkflowAppBlockingResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
@ -267,6 +271,7 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run_id=run_id,
|
||||
workflow_id=self._workflow.id,
|
||||
is_resumption=event.is_resumption,
|
||||
)
|
||||
yield start_resp
|
||||
|
||||
@ -452,7 +457,7 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
)
|
||||
yield from response
|
||||
yield from responses
|
||||
|
||||
def _handle_workflow_failed_and_stop_events(
|
||||
self,
|
||||
|
||||
@ -358,7 +358,7 @@ class WorkflowBasedAppRunner:
|
||||
:param event: event
|
||||
"""
|
||||
if isinstance(event, GraphRunStartedEvent):
|
||||
self._publish_event(QueueWorkflowStartedEvent())
|
||||
self._publish_event(QueueWorkflowStartedEvent(is_resumption=event.is_resumption))
|
||||
elif isinstance(event, GraphRunSucceededEvent):
|
||||
self._publish_event(QueueWorkflowSucceededEvent(outputs=event.outputs))
|
||||
elif isinstance(event, GraphRunPartialSucceededEvent):
|
||||
|
||||
@ -264,6 +264,10 @@ class QueueWorkflowStartedEvent(AppQueueEvent):
|
||||
|
||||
event: QueueEvent = QueueEvent.WORKFLOW_STARTED
|
||||
|
||||
# is_resumption indicating whether this `start` is a
|
||||
# resumption of previously suspended execution.
|
||||
is_resumption: bool = False
|
||||
|
||||
|
||||
class QueueWorkflowSucceededEvent(AppQueueEvent):
|
||||
"""
|
||||
|
||||
@ -208,6 +208,7 @@ class WorkflowStartStreamResponse(StreamResponse):
|
||||
workflow_id: str
|
||||
inputs: Mapping[str, Any]
|
||||
created_at: int
|
||||
is_resumption: bool = False
|
||||
|
||||
event: StreamEvent = StreamEvent.WORKFLOW_STARTED
|
||||
workflow_run_id: str
|
||||
|
||||
@ -235,7 +235,7 @@ class GraphEngine:
|
||||
self._graph_execution.paused = False
|
||||
self._graph_execution.pause_reasons = []
|
||||
|
||||
start_event = GraphRunStartedEvent()
|
||||
start_event = GraphRunStartedEvent(is_resumption=is_resume)
|
||||
self._event_manager.notify_layers(start_event)
|
||||
yield start_event
|
||||
|
||||
|
||||
@ -5,7 +5,9 @@ from core.workflow.graph_events import BaseGraphEvent
|
||||
|
||||
|
||||
class GraphRunStartedEvent(BaseGraphEvent):
|
||||
pass
|
||||
# is_resumption indicating whether this `start` is a
|
||||
# resumption of previously suspended execution.
|
||||
is_resumption: bool = False
|
||||
|
||||
|
||||
class GraphRunSucceededEvent(BaseGraphEvent):
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
"""Add human input related models
|
||||
|
||||
Revision ID: d411af417245
|
||||
Revises: 669ffd70119c
|
||||
Revises: 03ea244985ce
|
||||
Create Date: 2025-11-24 03:36:50.565145
|
||||
|
||||
"""
|
||||
@ -13,7 +13,7 @@ import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "d411af417245"
|
||||
down_revision = "669ffd70119c"
|
||||
down_revision = "03ea244985ce"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
@ -7,6 +7,7 @@ from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
||||
from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator
|
||||
from core.app.apps.chat.app_generator import ChatAppGenerator
|
||||
from core.app.apps.completion.app_generator import CompletionAppGenerator
|
||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.features.rate_limiting import RateLimit
|
||||
@ -17,7 +18,7 @@ from models.model import Account, App, AppMode, EndUser
|
||||
from models.workflow import Workflow, WorkflowRun
|
||||
from services.errors.app import InvokeRateLimitError, QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError
|
||||
from services.workflow_service import WorkflowService
|
||||
from tasks.app_generate.workflow_execute_task import ChatflowExecutionParams, chatflow_execute_task
|
||||
from tasks.app_generate.workflow_execute_task import AppExecutionParams, chatflow_execute_task
|
||||
|
||||
|
||||
class AppGenerateService:
|
||||
@ -85,13 +86,14 @@ class AppGenerateService:
|
||||
workflow_id = args.get("workflow_id")
|
||||
workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
|
||||
with rate_limit_context(rate_limit, request_id):
|
||||
payload = ChatflowExecutionParams.new(
|
||||
payload = AppExecutionParams.new(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
args=args,
|
||||
invoke_from=invoke_from,
|
||||
streaming=streaming,
|
||||
call_depth=0,
|
||||
)
|
||||
chatflow_execute_task.delay(payload.model_dump_json())
|
||||
generator = AdvancedChatAppGenerator()
|
||||
@ -104,6 +106,27 @@ class AppGenerateService:
|
||||
elif app_model.mode == AppMode.WORKFLOW:
|
||||
workflow_id = args.get("workflow_id")
|
||||
workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
|
||||
if streaming:
|
||||
with rate_limit_context(rate_limit, request_id):
|
||||
payload = AppExecutionParams.new(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
args=args,
|
||||
invoke_from=invoke_from,
|
||||
streaming=True,
|
||||
call_depth=0,
|
||||
root_node_id=root_node_id,
|
||||
workflow_run_id=uuid.uuid4(),
|
||||
)
|
||||
chatflow_execute_task.delay(payload.model_dump_json())
|
||||
return rate_limit.generate(
|
||||
WorkflowAppGenerator.convert_to_event_stream(
|
||||
MessageBasedAppGenerator.retrieve_events(AppMode.WORKFLOW, payload.workflow_run_id),
|
||||
),
|
||||
request_id,
|
||||
)
|
||||
|
||||
return rate_limit.generate(
|
||||
WorkflowAppGenerator.convert_to_event_stream(
|
||||
WorkflowAppGenerator().generate(
|
||||
@ -112,7 +135,7 @@ class AppGenerateService:
|
||||
user=user,
|
||||
args=args,
|
||||
invoke_from=invoke_from,
|
||||
streaming=streaming,
|
||||
streaming=False,
|
||||
root_node_id=root_node_id,
|
||||
call_depth=0,
|
||||
),
|
||||
|
||||
@ -2,7 +2,7 @@ import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import Engine
|
||||
from sqlalchemy import Engine, select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.repositories.human_input_reposotiry import (
|
||||
@ -16,11 +16,8 @@ from libs.exception import BaseHTTPException
|
||||
from models.account import Account
|
||||
from models.human_input import RecipientType
|
||||
from models.model import App, AppMode
|
||||
from models.workflow import WorkflowRun
|
||||
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
|
||||
from services.workflow.entities import WorkflowResumeTaskData
|
||||
from tasks.app_generate.workflow_execute_task import resume_chatflow_execution
|
||||
from tasks.async_workflow_tasks import resume_workflow_execution
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
from tasks.app_generate.workflow_execute_task import APP_EXECUTE_QUEUE, resume_app_execution
|
||||
|
||||
|
||||
class Form:
|
||||
@ -223,51 +220,29 @@ class HumanInputService:
|
||||
raise InvalidFormDataError(f"Missing required inputs: {', '.join(missing_inputs)}")
|
||||
|
||||
def _enqueue_resume(self, workflow_run_id: str) -> None:
|
||||
workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(self._session_factory)
|
||||
workflow_run = workflow_run_repo.get_workflow_run_by_id_without_tenant(workflow_run_id)
|
||||
|
||||
if workflow_run is None:
|
||||
raise AssertionError(f"WorkflowRun not found, id={workflow_run_id}")
|
||||
with self._session_factory(expire_on_commit=False) as session:
|
||||
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
|
||||
trigger_log = trigger_log_repo.get_by_workflow_run_id(workflow_run_id)
|
||||
|
||||
if trigger_log is not None:
|
||||
payload = WorkflowResumeTaskData(
|
||||
workflow_trigger_log_id=trigger_log.id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
app_query = select(App).where(App.id == workflow_run.app_id)
|
||||
app = session.execute(app_query).scalar_one_or_none()
|
||||
if app is None:
|
||||
logger.error(
|
||||
"App not found for WorkflowRun, workflow_run_id=%s, app_id=%s", workflow_run_id, workflow_run.app_id
|
||||
)
|
||||
return
|
||||
|
||||
if app.mode in {AppMode.WORKFLOW, AppMode.ADVANCED_CHAT}:
|
||||
payload = {"workflow_run_id": workflow_run_id}
|
||||
try:
|
||||
resume_workflow_execution.apply_async(
|
||||
kwargs={"task_data_dict": payload.model_dump()},
|
||||
queue=trigger_log.queue_name,
|
||||
resume_app_execution.apply_async(
|
||||
kwargs={"payload": payload},
|
||||
queue=APP_EXECUTE_QUEUE,
|
||||
)
|
||||
except Exception: # pragma: no cover
|
||||
logger.exception("Failed to enqueue resume task for workflow run %s", workflow_run_id)
|
||||
return
|
||||
|
||||
if self._enqueue_chatflow_resume(workflow_run_id):
|
||||
return
|
||||
|
||||
logger.warning("No workflow trigger log bound to workflow run %s; skipping resume dispatch", workflow_run_id)
|
||||
|
||||
def _enqueue_chatflow_resume(self, workflow_run_id: str) -> bool:
|
||||
with self._session_factory(expire_on_commit=False) as session:
|
||||
workflow_run = session.get(WorkflowRun, workflow_run_id)
|
||||
if workflow_run is None:
|
||||
return False
|
||||
|
||||
app = session.get(App, workflow_run.app_id)
|
||||
|
||||
if app is None:
|
||||
return False
|
||||
|
||||
if app.mode != AppMode.ADVANCED_CHAT.value:
|
||||
return False
|
||||
|
||||
try:
|
||||
resume_chatflow_execution.apply_async(
|
||||
kwargs={"payload": {"workflow_run_id": workflow_run_id}},
|
||||
queue="chatflow_execute",
|
||||
)
|
||||
except Exception: # pragma: no cover
|
||||
logger.exception("Failed to enqueue chatflow resume for workflow run %s", workflow_run_id)
|
||||
return False
|
||||
|
||||
return True
|
||||
logger.warning("App mode %s does not support resume for workflow run %s", app.mode, workflow_run_id)
|
||||
|
||||
@ -101,7 +101,6 @@ class WorkflowTaskData(BaseModel):
|
||||
class WorkflowResumeTaskData(BaseModel):
|
||||
"""Payload for workflow resumption tasks."""
|
||||
|
||||
workflow_trigger_log_id: str
|
||||
workflow_run_id: str
|
||||
|
||||
|
||||
|
||||
@ -14,7 +14,6 @@ from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
||||
from core.file import File
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from core.variables import Variable
|
||||
from core.variables.consts import SELECTORS_LENGTH
|
||||
from core.variables.variables import VariableUnion
|
||||
from core.workflow.entities import GraphInitParams, WorkflowNodeExecution
|
||||
from core.workflow.entities.pause_reason import HumanInputRequired
|
||||
@ -24,12 +23,13 @@ from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent, N
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.human_input.entities import _OUTPUT_VARIABLE_PATTERN, HumanInputNodeData
|
||||
from core.workflow.nodes.human_input.entities import HumanInputNodeData
|
||||
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
|
||||
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
|
||||
from core.workflow.nodes.start.entities import StartNodeData
|
||||
from core.workflow.runtime import VariablePool
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.variable_loader import load_into_variable_pool
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
|
||||
@ -836,11 +836,19 @@ class WorkflowService:
|
||||
raise ValueError(f"Missing inputs: {missing_list}")
|
||||
|
||||
rendered_content = node._render_form_content()
|
||||
filled_inputs = dict(form_inputs)
|
||||
rendered_content_with_outputs = self._render_content_with_output_values(rendered_content, filled_inputs)
|
||||
|
||||
outputs: dict[str, Any] = dict(filled_inputs)
|
||||
outputs: dict[str, Any] = dict(form_inputs)
|
||||
outputs["__action_id"] = action
|
||||
rendered_content_with_outputs = rendered_content
|
||||
for field_name in node_data.outputs_field_names():
|
||||
placeholder = f"{{{{#$outputs.{field_name}#}}}}"
|
||||
value = outputs.get(field_name)
|
||||
if value is None:
|
||||
replacement = ""
|
||||
elif isinstance(value, (dict, list)):
|
||||
replacement = json.dumps(value, ensure_ascii=False)
|
||||
else:
|
||||
replacement = str(value)
|
||||
rendered_content_with_outputs = rendered_content_with_outputs.replace(placeholder, replacement)
|
||||
outputs["__rendered_content"] = rendered_content_with_outputs
|
||||
|
||||
enclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id(node_config)
|
||||
@ -919,42 +927,25 @@ class WorkflowService:
|
||||
graph_config=workflow.graph_dict,
|
||||
config=node_config,
|
||||
)
|
||||
selectors_to_load: list[list[str]] = []
|
||||
for selector in variable_mapping.values():
|
||||
if variable_pool.get(selector) is None:
|
||||
selectors_to_load.append(list(selector))
|
||||
|
||||
loaded_variables = variable_loader.load_variables(selectors_to_load)
|
||||
for variable in loaded_variables:
|
||||
variable_pool.add([variable.selector[0], variable.selector[1]], variable)
|
||||
|
||||
normalized_user_inputs: dict[str, Any] = dict(manual_inputs)
|
||||
for raw_key, value in manual_inputs.items():
|
||||
selector = self._parse_selector(raw_key)
|
||||
variable_pool.add(selector, value)
|
||||
normalized_user_inputs[f"#{raw_key}#"] = value
|
||||
|
||||
load_into_variable_pool(
|
||||
variable_loader=variable_loader,
|
||||
variable_pool=variable_pool,
|
||||
variable_mapping=variable_mapping,
|
||||
user_inputs=normalized_user_inputs,
|
||||
)
|
||||
WorkflowEntry.mapping_user_inputs_to_variable_pool(
|
||||
variable_mapping=variable_mapping,
|
||||
user_inputs=normalized_user_inputs,
|
||||
variable_pool=variable_pool,
|
||||
tenant_id=app_model.tenant_id,
|
||||
)
|
||||
|
||||
return variable_pool
|
||||
|
||||
def _parse_selector(self, selector_key: str) -> list[str]:
|
||||
cleaned = selector_key.strip()
|
||||
if cleaned.startswith("#") and cleaned.endswith("#"):
|
||||
cleaned = cleaned[1:-1]
|
||||
selector = cleaned.split(".")
|
||||
if len(selector) != SELECTORS_LENGTH:
|
||||
raise ValueError(f"Invalid selector '{selector_key}', expected format '<node_id>.<variable_name>'.")
|
||||
return selector
|
||||
|
||||
def _render_content_with_output_values(self, content: str, outputs: Mapping[str, Any]) -> str:
|
||||
def _replace(match):
|
||||
field_name = match.group("field_name")
|
||||
value = outputs.get(field_name)
|
||||
if value is None:
|
||||
return ""
|
||||
if isinstance(value, (dict, list)):
|
||||
return json.dumps(value, ensure_ascii=False)
|
||||
return str(value)
|
||||
|
||||
return _OUTPUT_VARIABLE_PATTERN.sub(_replace, content)
|
||||
|
||||
def run_free_workflow_node(
|
||||
self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
|
||||
) -> WorkflowNodeExecution:
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
from .workflow_execute_task import chatflow_execute_task
|
||||
from .workflow_execute_task import AppExecutionParams, chatflow_execute_task, resume_app_execution
|
||||
|
||||
__all__ = ["chatflow_execute_task"]
|
||||
__all__ = ["AppExecutionParams", "chatflow_execute_task", "resume_app_execution"]
|
||||
|
||||
@ -12,7 +12,13 @@ from sqlalchemy import Engine, select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
|
||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
AdvancedChatAppGenerateEntity,
|
||||
InvokeFrom,
|
||||
WorkflowAppGenerateEntity,
|
||||
)
|
||||
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, WorkflowResumptionContext
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
@ -26,6 +32,8 @@ from repositories.factory import DifyAPIRepositoryFactory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
APP_EXECUTE_QUEUE = "chatflow_execute"
|
||||
|
||||
|
||||
class _UserType(StrEnum):
|
||||
ACCOUNT = "account"
|
||||
@ -66,16 +74,18 @@ User: TypeAlias = Annotated[
|
||||
]
|
||||
|
||||
|
||||
class ChatflowExecutionParams(BaseModel):
|
||||
class AppExecutionParams(BaseModel):
|
||||
app_id: str
|
||||
workflow_id: str
|
||||
tenant_id: str
|
||||
app_mode: AppMode = AppMode.ADVANCED_CHAT
|
||||
user: User
|
||||
args: Mapping[str, Any]
|
||||
|
||||
invoke_from: InvokeFrom
|
||||
streaming: bool = True
|
||||
call_depth: int = 0
|
||||
root_node_id: str | None = None
|
||||
workflow_run_id: uuid.UUID = Field(default_factory=uuid.uuid4)
|
||||
|
||||
@classmethod
|
||||
@ -87,6 +97,9 @@ class ChatflowExecutionParams(BaseModel):
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
call_depth: int = 0,
|
||||
root_node_id: str | None = None,
|
||||
workflow_run_id: uuid.UUID | None = None,
|
||||
):
|
||||
user_params: _Account | _EndUser
|
||||
if isinstance(user, Account):
|
||||
@ -99,16 +112,19 @@ class ChatflowExecutionParams(BaseModel):
|
||||
app_id=app_model.id,
|
||||
workflow_id=workflow.id,
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_mode=AppMode.value_of(app_model.mode),
|
||||
user=user_params,
|
||||
args=args,
|
||||
invoke_from=invoke_from,
|
||||
streaming=streaming,
|
||||
workflow_run_id=uuid.uuid4(),
|
||||
call_depth=call_depth,
|
||||
root_node_id=root_node_id,
|
||||
workflow_run_id=workflow_run_id or uuid.uuid4(),
|
||||
)
|
||||
|
||||
|
||||
class _ChatflowRunner:
|
||||
def __init__(self, session_factory: sessionmaker | Engine, exec_params: ChatflowExecutionParams):
|
||||
class _AppRunner:
|
||||
def __init__(self, session_factory: sessionmaker | Engine, exec_params: AppExecutionParams):
|
||||
if isinstance(session_factory, Engine):
|
||||
session_factory = sessionmaker(bind=session_factory)
|
||||
self._session_factory = session_factory
|
||||
@ -130,7 +146,13 @@ class _ChatflowRunner:
|
||||
exec_params = self._exec_params
|
||||
with self._session() as session:
|
||||
workflow = session.get(Workflow, exec_params.workflow_id)
|
||||
if workflow is None:
|
||||
logger.warning("Workflow %s not found for execution", exec_params.workflow_id)
|
||||
return None
|
||||
app = session.get(App, workflow.app_id)
|
||||
if app is None:
|
||||
logger.warning("App %s not found for workflow %s", workflow.app_id, exec_params.workflow_id)
|
||||
return None
|
||||
|
||||
pause_config = PauseStateLayerConfig(
|
||||
session_factory=self._session_factory,
|
||||
@ -139,25 +161,54 @@ class _ChatflowRunner:
|
||||
|
||||
user = self._resolve_user()
|
||||
|
||||
chat_generator = AdvancedChatAppGenerator()
|
||||
|
||||
workflow_run_id = exec_params.workflow_run_id
|
||||
|
||||
with self._setup_flask_context(user):
|
||||
response = chat_generator.generate(
|
||||
response = self._run_app(
|
||||
app=app,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
pause_state_config=pause_config,
|
||||
)
|
||||
if not exec_params.streaming:
|
||||
return response
|
||||
|
||||
_publish_streaming_response(response, exec_params.workflow_run_id, exec_params.app_mode)
|
||||
|
||||
def _run_app(
|
||||
self,
|
||||
*,
|
||||
app: App,
|
||||
workflow: Workflow,
|
||||
user: Account | EndUser,
|
||||
pause_state_config: PauseStateLayerConfig,
|
||||
):
|
||||
exec_params = self._exec_params
|
||||
if exec_params.app_mode == AppMode.ADVANCED_CHAT:
|
||||
return AdvancedChatAppGenerator().generate(
|
||||
app_model=app,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
args=exec_params.args,
|
||||
invoke_from=exec_params.invoke_from,
|
||||
streaming=exec_params.streaming,
|
||||
workflow_run_id=workflow_run_id,
|
||||
pause_state_config=pause_config,
|
||||
workflow_run_id=exec_params.workflow_run_id,
|
||||
pause_state_config=pause_state_config,
|
||||
)
|
||||
if exec_params.app_mode == AppMode.WORKFLOW:
|
||||
return WorkflowAppGenerator().generate(
|
||||
app_model=app,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
args=exec_params.args,
|
||||
invoke_from=exec_params.invoke_from,
|
||||
streaming=exec_params.streaming,
|
||||
call_depth=exec_params.call_depth,
|
||||
root_node_id=exec_params.root_node_id,
|
||||
workflow_run_id=exec_params.workflow_run_id,
|
||||
pause_state_config=pause_state_config,
|
||||
)
|
||||
if not exec_params.streaming:
|
||||
return response
|
||||
|
||||
_publish_streaming_response(response, workflow_run_id)
|
||||
logger.error("Unsupported app mode for execution: %s", exec_params.app_mode)
|
||||
return None
|
||||
|
||||
def _resolve_user(self) -> Account | EndUser:
|
||||
user_params = self._exec_params.user
|
||||
@ -199,13 +250,13 @@ def _coerce_uuid(value: Any) -> uuid.UUID | None:
|
||||
return None
|
||||
|
||||
|
||||
def _publish_streaming_response(response_stream: Iterable[Any], workflow_run_id: Any) -> None:
|
||||
def _publish_streaming_response(response_stream: Iterable[Any], workflow_run_id: Any, app_mode: AppMode) -> None:
|
||||
workflow_run_uuid = _coerce_uuid(workflow_run_id)
|
||||
if workflow_run_uuid is None:
|
||||
logger.warning("Unable to publish streaming response without valid workflow_run_id: %s", workflow_run_id)
|
||||
return
|
||||
|
||||
topic = AdvancedChatAppGenerator.get_response_topic(AppMode.ADVANCED_CHAT, workflow_run_uuid)
|
||||
topic = MessageBasedAppGenerator.get_response_topic(app_mode, workflow_run_uuid)
|
||||
for event in response_stream:
|
||||
try:
|
||||
payload = json.dumps(event)
|
||||
@ -216,18 +267,17 @@ def _publish_streaming_response(response_stream: Iterable[Any], workflow_run_id:
|
||||
topic.publish(payload.encode())
|
||||
|
||||
|
||||
@shared_task(queue="chatflow_execute")
|
||||
@shared_task(queue=APP_EXECUTE_QUEUE)
|
||||
def chatflow_execute_task(payload: str) -> Mapping[str, Any] | None:
|
||||
exec_params = ChatflowExecutionParams.model_validate_json(payload)
|
||||
exec_params = AppExecutionParams.model_validate_json(payload)
|
||||
|
||||
logger.info("chatflow_execute_task run with params: %s", exec_params)
|
||||
|
||||
runner = _ChatflowRunner(db.engine, exec_params=exec_params)
|
||||
runner = _AppRunner(db.engine, exec_params=exec_params)
|
||||
return runner.run()
|
||||
|
||||
|
||||
@shared_task(queue="chatflow_execute", name="resume_chatflow_execution")
|
||||
def resume_chatflow_execution(payload: dict[str, Any]) -> None:
|
||||
def _resume_app_execution(payload: dict[str, Any]) -> None:
|
||||
workflow_run_id = payload["workflow_run_id"]
|
||||
|
||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
@ -245,16 +295,11 @@ def resume_chatflow_execution(payload: dict[str, Any]) -> None:
|
||||
return
|
||||
|
||||
generate_entity = resumption_context.get_generate_entity()
|
||||
if not isinstance(generate_entity, AdvancedChatAppGenerateEntity):
|
||||
logger.error(
|
||||
"Resumption entity is not AdvancedChatAppGenerateEntity for workflow run %s (found %s)",
|
||||
workflow_run_id,
|
||||
type(generate_entity),
|
||||
)
|
||||
return
|
||||
|
||||
graph_runtime_state = GraphRuntimeState.from_snapshot(resumption_context.serialized_graph_runtime_state)
|
||||
|
||||
conversation = None
|
||||
message = None
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = session.get(WorkflowRun, workflow_run_id)
|
||||
if workflow_run is None:
|
||||
@ -271,29 +316,38 @@ def resume_chatflow_execution(payload: dict[str, Any]) -> None:
|
||||
logger.warning("App %s not found during resume", workflow_run.app_id)
|
||||
return
|
||||
|
||||
if generate_entity.conversation_id is None:
|
||||
logger.warning("Conversation id missing in resumption context for workflow run %s", workflow_run_id)
|
||||
return
|
||||
|
||||
conversation = session.get(Conversation, generate_entity.conversation_id)
|
||||
if conversation is None:
|
||||
logger.warning(
|
||||
"Conversation %s not found for workflow run %s", generate_entity.conversation_id, workflow_run_id
|
||||
)
|
||||
return
|
||||
|
||||
message = session.scalar(
|
||||
select(Message).where(Message.workflow_run_id == workflow_run_id).order_by(Message.created_at.desc())
|
||||
)
|
||||
if message is None:
|
||||
logger.warning("Message not found for workflow run %s", workflow_run_id)
|
||||
return
|
||||
|
||||
user = _resolve_user_for_run(session, workflow_run)
|
||||
if user is None:
|
||||
logger.warning("User %s not found for workflow run %s", workflow_run.created_by, workflow_run_id)
|
||||
return
|
||||
|
||||
if isinstance(generate_entity, AdvancedChatAppGenerateEntity):
|
||||
if generate_entity.conversation_id is None:
|
||||
logger.warning("Conversation id missing in resumption context for workflow run %s", workflow_run_id)
|
||||
return
|
||||
|
||||
conversation = session.get(Conversation, generate_entity.conversation_id)
|
||||
if conversation is None:
|
||||
logger.warning(
|
||||
"Conversation %s not found for workflow run %s", generate_entity.conversation_id, workflow_run_id
|
||||
)
|
||||
return
|
||||
|
||||
message = session.scalar(
|
||||
select(Message).where(Message.workflow_run_id == workflow_run_id).order_by(Message.created_at.desc())
|
||||
)
|
||||
if message is None:
|
||||
logger.warning("Message not found for workflow run %s", workflow_run_id)
|
||||
return
|
||||
|
||||
if not isinstance(generate_entity, (AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity)):
|
||||
logger.error(
|
||||
"Unsupported resumption entity for workflow run %s (found %s)",
|
||||
workflow_run_id,
|
||||
type(generate_entity),
|
||||
)
|
||||
return
|
||||
|
||||
workflow_run_repo.resume_workflow_pause(workflow_run_id, pause_entity)
|
||||
|
||||
pause_config = PauseStateLayerConfig(
|
||||
@ -301,6 +355,52 @@ def resume_chatflow_execution(payload: dict[str, Any]) -> None:
|
||||
state_owner_user_id=workflow.created_by,
|
||||
)
|
||||
|
||||
if isinstance(generate_entity, AdvancedChatAppGenerateEntity):
|
||||
assert conversation is not None
|
||||
assert message is not None
|
||||
_resume_advanced_chat(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
conversation=conversation,
|
||||
message=message,
|
||||
generate_entity=generate_entity,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
session_factory=session_factory,
|
||||
pause_state_config=pause_config,
|
||||
workflow_run_id=workflow_run_id,
|
||||
workflow_run=workflow_run,
|
||||
)
|
||||
elif isinstance(generate_entity, WorkflowAppGenerateEntity):
|
||||
_resume_workflow(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
generate_entity=generate_entity,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
session_factory=session_factory,
|
||||
pause_state_config=pause_config,
|
||||
workflow_run_id=workflow_run_id,
|
||||
workflow_run=workflow_run,
|
||||
workflow_run_repo=workflow_run_repo,
|
||||
pause_entity=pause_entity,
|
||||
)
|
||||
|
||||
|
||||
def _resume_advanced_chat(
|
||||
*,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Account | EndUser,
|
||||
conversation: Conversation,
|
||||
message: Message,
|
||||
generate_entity: AdvancedChatAppGenerateEntity,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
session_factory: sessionmaker,
|
||||
pause_state_config: PauseStateLayerConfig,
|
||||
workflow_run_id: str,
|
||||
workflow_run: WorkflowRun,
|
||||
) -> None:
|
||||
try:
|
||||
triggered_from = WorkflowRunTriggeredFrom(workflow_run.triggered_from)
|
||||
except ValueError:
|
||||
@ -332,7 +432,7 @@ def resume_chatflow_execution(payload: dict[str, Any]) -> None:
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
pause_state_config=pause_config,
|
||||
pause_state_config=pause_state_config,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to resume chatflow execution for workflow run %s", workflow_run_id)
|
||||
@ -346,4 +446,76 @@ def resume_chatflow_execution(payload: dict[str, Any]) -> None:
|
||||
workflow_run_id,
|
||||
)
|
||||
else:
|
||||
_publish_streaming_response(response, publish_uuid)
|
||||
_publish_streaming_response(response, publish_uuid, AppMode.ADVANCED_CHAT)
|
||||
|
||||
|
||||
def _resume_workflow(
|
||||
*,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Account | EndUser,
|
||||
generate_entity: WorkflowAppGenerateEntity,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
session_factory: sessionmaker,
|
||||
pause_state_config: PauseStateLayerConfig,
|
||||
workflow_run_id: str,
|
||||
workflow_run: WorkflowRun,
|
||||
workflow_run_repo,
|
||||
pause_entity,
|
||||
) -> None:
|
||||
try:
|
||||
triggered_from = WorkflowRunTriggeredFrom(workflow_run.triggered_from)
|
||||
except ValueError:
|
||||
triggered_from = WorkflowRunTriggeredFrom.APP_RUN
|
||||
|
||||
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
|
||||
session_factory=session_factory,
|
||||
user=user,
|
||||
app_id=app_model.id,
|
||||
triggered_from=triggered_from,
|
||||
)
|
||||
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
|
||||
session_factory=session_factory,
|
||||
user=user,
|
||||
app_id=app_model.id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
generator = WorkflowAppGenerator()
|
||||
|
||||
try:
|
||||
response = generator.resume(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
application_generate_entity=generate_entity,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
pause_state_config=pause_state_config,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to resume workflow execution for workflow run %s", workflow_run_id)
|
||||
raise
|
||||
|
||||
if generate_entity.stream:
|
||||
publish_uuid = _coerce_uuid(generate_entity.workflow_execution_id) or _coerce_uuid(workflow_run_id)
|
||||
if publish_uuid is None:
|
||||
logger.warning(
|
||||
"Unable to publish streaming response for workflow run %s due to missing workflow_run_id",
|
||||
workflow_run_id,
|
||||
)
|
||||
else:
|
||||
_publish_streaming_response(response, publish_uuid, AppMode.WORKFLOW)
|
||||
|
||||
workflow_run_repo.delete_workflow_pause(pause_entity)
|
||||
|
||||
|
||||
@shared_task(queue=APP_EXECUTE_QUEUE, name="resume_app_execution")
|
||||
def resume_app_execution(payload: dict[str, Any]) -> None:
|
||||
_resume_app_execution(payload)
|
||||
|
||||
|
||||
@shared_task(queue=APP_EXECUTE_QUEUE, name="resume_chatflow_execution")
|
||||
def resume_chatflow_execution(payload: dict[str, Any]) -> None:
|
||||
_resume_app_execution(payload)
|
||||
|
||||
@ -26,7 +26,7 @@ from models.account import Account
|
||||
from models.enums import AppTriggerType, CreatorUserRole, WorkflowRunTriggeredFrom, WorkflowTriggerStatus
|
||||
from models.model import App, EndUser, Tenant
|
||||
from models.trigger import WorkflowTriggerLog
|
||||
from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom
|
||||
from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom, WorkflowRun
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
|
||||
from services.errors.app import WorkflowNotFoundError
|
||||
@ -40,12 +40,6 @@ from tasks.workflow_cfs_scheduler.entities import AsyncWorkflowQueue, AsyncWorkf
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_TRIGGER_TO_RUN_SOURCE = {
|
||||
AppTriggerType.TRIGGER_WEBHOOK: WorkflowRunTriggeredFrom.WEBHOOK,
|
||||
AppTriggerType.TRIGGER_SCHEDULE: WorkflowRunTriggeredFrom.SCHEDULE,
|
||||
AppTriggerType.TRIGGER_PLUGIN: WorkflowRunTriggeredFrom.PLUGIN,
|
||||
}
|
||||
|
||||
|
||||
@shared_task(queue=AsyncWorkflowQueue.PROFESSIONAL_QUEUE)
|
||||
def execute_workflow_professional(task_data_dict: dict[str, Any]):
|
||||
@ -204,44 +198,135 @@ def resume_workflow_execution(task_data_dict: dict[str, Any]) -> None:
|
||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_factory)
|
||||
|
||||
pause_entity = workflow_run_repo.get_workflow_pause(task_data.workflow_run_id)
|
||||
if pause_entity is None:
|
||||
logger.warning("No pause state for workflow run %s", task_data.workflow_run_id)
|
||||
return
|
||||
workflow_run = workflow_run_repo.get_workflow_run_by_id_without_tenant(pause_entity.workflow_execution_id)
|
||||
if workflow_run is None:
|
||||
logger.warning("Workflow run not found for pause entity: pause_entity_id=%s", pause_entity.id)
|
||||
return
|
||||
|
||||
try:
|
||||
resumption_context = WorkflowResumptionContext.loads(pause_entity.get_state().decode())
|
||||
except Exception as exc:
|
||||
logger.exception("Failed to load resumption context for workflow run %s", task_data.workflow_run_id)
|
||||
raise exc
|
||||
|
||||
generate_entity = resumption_context.get_generate_entity()
|
||||
if not isinstance(generate_entity, WorkflowAppGenerateEntity):
|
||||
logger.error(
|
||||
"Unsupported resumption entity for workflow run %s: %s",
|
||||
task_data.workflow_run_id,
|
||||
type(generate_entity),
|
||||
)
|
||||
return
|
||||
|
||||
graph_runtime_state = GraphRuntimeState.from_snapshot(resumption_context.serialized_graph_runtime_state)
|
||||
|
||||
with session_factory() as session:
|
||||
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
|
||||
trigger_log = trigger_log_repo.get_by_id(task_data.workflow_trigger_log_id)
|
||||
if not trigger_log:
|
||||
logger.warning("Trigger log not found for resumption: %s", task_data.workflow_trigger_log_id)
|
||||
return
|
||||
|
||||
pause_entity = workflow_run_repo.get_workflow_pause(task_data.workflow_run_id)
|
||||
if pause_entity is None:
|
||||
logger.warning("No pause state for workflow run %s", task_data.workflow_run_id)
|
||||
return
|
||||
|
||||
try:
|
||||
resumption_context = WorkflowResumptionContext.loads(pause_entity.get_state().decode())
|
||||
except Exception as exc:
|
||||
logger.exception("Failed to load resumption context for workflow run %s", task_data.workflow_run_id)
|
||||
raise exc
|
||||
|
||||
generate_entity = resumption_context.get_generate_entity()
|
||||
if not isinstance(generate_entity, WorkflowAppGenerateEntity):
|
||||
logger.error(
|
||||
"Unsupported resumption entity for workflow run %s: %s",
|
||||
task_data.workflow_run_id,
|
||||
type(generate_entity),
|
||||
)
|
||||
return
|
||||
|
||||
graph_runtime_state = GraphRuntimeState.from_snapshot(resumption_context.serialized_graph_runtime_state)
|
||||
|
||||
workflow = session.scalar(select(Workflow).where(Workflow.id == trigger_log.workflow_id))
|
||||
workflow = session.scalar(select(Workflow).where(Workflow.id == workflow_run.workflow_id))
|
||||
if workflow is None:
|
||||
raise WorkflowNotFoundError(f"Workflow not found: {trigger_log.workflow_id}")
|
||||
|
||||
app_model = session.scalar(select(App).where(App.id == trigger_log.app_id))
|
||||
raise WorkflowNotFoundError(
|
||||
"Workflow not found: workflow_run_id=%s, workflow_id=%s", workflow_run.id, workflow_run.workflow_id
|
||||
)
|
||||
user = _get_user(session, workflow_run)
|
||||
app_model = session.scalar(select(App).where(App.id == workflow_run.app_id))
|
||||
if app_model is None:
|
||||
raise WorkflowNotFoundError(f"App not found: {trigger_log.app_id}")
|
||||
raise _AppNotFoundError(
|
||||
"App not found: app_id=%s, workflow_run_id=%s", workflow_run.app_id, workflow_run.id
|
||||
)
|
||||
|
||||
user = _get_user(session, trigger_log)
|
||||
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
|
||||
session_factory=session_factory,
|
||||
user=user,
|
||||
app_id=generate_entity.app_config.app_id,
|
||||
triggered_from=WorkflowRunTriggeredFrom(workflow_run.triggered_from),
|
||||
)
|
||||
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
|
||||
session_factory=session_factory,
|
||||
user=user,
|
||||
app_id=generate_entity.app_config.app_id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
pause_config = PauseStateLayerConfig(
|
||||
session_factory=session_factory,
|
||||
state_owner_user_id=workflow.created_by,
|
||||
)
|
||||
|
||||
generator = WorkflowAppGenerator()
|
||||
start_time = datetime.now(UTC)
|
||||
graph_engine_layers = []
|
||||
trigger_log = _query_trigger_log_info(session_factory, task_data.workflow_run_id)
|
||||
|
||||
if trigger_log:
|
||||
cfs_plan_scheduler_entity = AsyncWorkflowCFSPlanEntity(
|
||||
queue=AsyncWorkflowQueue(trigger_log.queue_name),
|
||||
schedule_strategy=AsyncWorkflowSystemStrategy,
|
||||
granularity=dify_config.ASYNC_WORKFLOW_SCHEDULER_GRANULARITY,
|
||||
)
|
||||
cfs_plan_scheduler = AsyncWorkflowCFSPlanScheduler(plan=cfs_plan_scheduler_entity)
|
||||
|
||||
graph_engine_layers.extend(
|
||||
[
|
||||
TimeSliceLayer(cfs_plan_scheduler),
|
||||
TriggerPostLayer(cfs_plan_scheduler_entity, start_time, trigger_log.id, session_factory),
|
||||
]
|
||||
)
|
||||
|
||||
workflow_run_repo.resume_workflow_pause(task_data.workflow_run_id, pause_entity)
|
||||
|
||||
generator.resume(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
application_generate_entity=generate_entity,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
graph_engine_layers=graph_engine_layers,
|
||||
pause_state_config=pause_config,
|
||||
)
|
||||
workflow_run_repo.delete_workflow_pause(pause_entity)
|
||||
|
||||
|
||||
def _get_user(session: Session, workflow_run: WorkflowRun) -> Account | EndUser:
|
||||
"""Compose user from trigger log"""
|
||||
tenant = session.scalar(select(Tenant).where(Tenant.id == workflow_run.tenant_id))
|
||||
if not tenant:
|
||||
raise _TenantNotFoundError(
|
||||
"Tenant not found for WorkflowRun: tenant_id=%s, workflow_run_id=%s",
|
||||
workflow_run.tenant_id,
|
||||
workflow_run.id,
|
||||
)
|
||||
|
||||
# Get user from trigger log
|
||||
if workflow_run.created_by_role == CreatorUserRole.ACCOUNT:
|
||||
user = session.scalar(select(Account).where(Account.id == workflow_run.created_by))
|
||||
if user:
|
||||
user.current_tenant = tenant
|
||||
else: # CreatorUserRole.END_USER
|
||||
user = session.scalar(select(EndUser).where(EndUser.id == workflow_run.created_by))
|
||||
|
||||
if not user:
|
||||
raise _UserNotFoundError(
|
||||
"User not found: user_id=%s, created_by_role=%s, workflow_run_id=%s",
|
||||
workflow_run.created_by,
|
||||
workflow_run.created_by_role,
|
||||
workflow_run.id,
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
def _query_trigger_log_info(session_factory: sessionmaker[Session], workflow_run_id) -> WorkflowTriggerLog | None:
|
||||
with session_factory() as session, session.begin():
|
||||
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
|
||||
trigger_log = trigger_log_repo.get_by_workflow_run_id(workflow_run_id)
|
||||
if not trigger_log:
|
||||
logger.debug("Trigger log not found for workflow_run: workflow_run_id=%s", workflow_run_id)
|
||||
return
|
||||
|
||||
cfs_plan_scheduler_entity = AsyncWorkflowCFSPlanEntity(
|
||||
queue=trigger_log.queue_name,
|
||||
@ -255,74 +340,14 @@ def resume_workflow_execution(task_data_dict: dict[str, Any]) -> None:
|
||||
except ValueError:
|
||||
trigger_type = AppTriggerType.UNKNOWN
|
||||
|
||||
triggered_from = _TRIGGER_TO_RUN_SOURCE.get(trigger_type, WorkflowRunTriggeredFrom.APP_RUN)
|
||||
|
||||
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
|
||||
session_factory=session_factory,
|
||||
user=user,
|
||||
app_id=generate_entity.app_config.app_id,
|
||||
triggered_from=triggered_from,
|
||||
)
|
||||
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
|
||||
session_factory=session_factory,
|
||||
user=user,
|
||||
app_id=generate_entity.app_config.app_id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
pause_config = PauseStateLayerConfig(
|
||||
session_factory=session_factory,
|
||||
state_owner_user_id=workflow.created_by,
|
||||
)
|
||||
|
||||
workflow_run_repo.resume_workflow_pause(task_data.workflow_run_id, pause_entity)
|
||||
|
||||
trigger_log.status = WorkflowTriggerStatus.RUNNING
|
||||
trigger_log_repo.update(trigger_log)
|
||||
session.commit()
|
||||
|
||||
generator = WorkflowAppGenerator()
|
||||
start_time = datetime.now(UTC)
|
||||
|
||||
try:
|
||||
generator.resume(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
application_generate_entity=generate_entity,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
graph_engine_layers=[
|
||||
TimeSliceLayer(cfs_plan_scheduler),
|
||||
TriggerPostLayer(cfs_plan_scheduler_entity, start_time, trigger_log.id, session_factory),
|
||||
],
|
||||
pause_state_config=pause_config,
|
||||
)
|
||||
except Exception as exc:
|
||||
trigger_log.status = WorkflowTriggerStatus.FAILED
|
||||
trigger_log.error = str(exc)
|
||||
trigger_log.finished_at = datetime.now(UTC)
|
||||
trigger_log_repo.update(trigger_log)
|
||||
session.commit()
|
||||
raise
|
||||
class _TenantNotFoundError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def _get_user(session: Session, trigger_log: WorkflowTriggerLog) -> Account | EndUser:
|
||||
"""Compose user from trigger log"""
|
||||
tenant = session.scalar(select(Tenant).where(Tenant.id == trigger_log.tenant_id))
|
||||
if not tenant:
|
||||
raise ValueError(f"Tenant not found: {trigger_log.tenant_id}")
|
||||
class _UserNotFoundError(Exception):
|
||||
pass
|
||||
|
||||
# Get user from trigger log
|
||||
if trigger_log.created_by_role == CreatorUserRole.ACCOUNT:
|
||||
user = session.scalar(select(Account).where(Account.id == trigger_log.created_by))
|
||||
if user:
|
||||
user.current_tenant = tenant
|
||||
else: # CreatorUserRole.END_USER
|
||||
user = session.scalar(select(EndUser).where(EndUser.id == trigger_log.created_by))
|
||||
|
||||
if not user:
|
||||
raise ValueError(f"User not found: {trigger_log.created_by} (role: {trigger_log.created_by_role})")
|
||||
|
||||
return user
|
||||
class _AppNotFoundError(Exception):
|
||||
pass
|
||||
|
||||
@ -0,0 +1,160 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.console import wraps as console_wraps
|
||||
from controllers.console.app import workflow as workflow_module
|
||||
from controllers.console.app import wraps as app_wraps
|
||||
from libs import login as login_lib
|
||||
from models.account import Account, AccountStatus, TenantAccountRole
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
def _make_account() -> Account:
|
||||
account = Account(name="tester", email="tester@example.com")
|
||||
account.status = AccountStatus.ACTIVE
|
||||
account.role = TenantAccountRole.OWNER
|
||||
account.id = "account-123" # type: ignore[assignment]
|
||||
account._current_tenant = SimpleNamespace(id="tenant-123") # type: ignore[attr-defined]
|
||||
account._get_current_object = lambda: account # type: ignore[attr-defined]
|
||||
return account
|
||||
|
||||
|
||||
def _make_app(mode: AppMode) -> SimpleNamespace:
|
||||
return SimpleNamespace(id="app-123", tenant_id="tenant-123", mode=mode.value)
|
||||
|
||||
|
||||
def _patch_console_guards(monkeypatch: pytest.MonkeyPatch, account: Account, app_model: SimpleNamespace) -> None:
|
||||
# Skip setup and auth guardrails
|
||||
monkeypatch.setattr("configs.dify_config.EDITION", "CLOUD")
|
||||
monkeypatch.setattr(login_lib.dify_config, "LOGIN_DISABLED", True)
|
||||
monkeypatch.setattr(login_lib, "current_user", account)
|
||||
monkeypatch.setattr(login_lib, "current_account_with_tenant", lambda: (account, account.current_tenant_id))
|
||||
monkeypatch.setattr(login_lib, "check_csrf_token", lambda *_, **__: None)
|
||||
monkeypatch.setattr(console_wraps, "current_account_with_tenant", lambda: (account, account.current_tenant_id))
|
||||
monkeypatch.setattr(app_wraps, "current_account_with_tenant", lambda: (account, account.current_tenant_id))
|
||||
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (account, account.current_tenant_id))
|
||||
monkeypatch.setattr(console_wraps.dify_config, "EDITION", "CLOUD")
|
||||
monkeypatch.delenv("INIT_PASSWORD", raising=False)
|
||||
|
||||
# Avoid hitting the database when resolving the app model
|
||||
monkeypatch.setattr(app_wraps, "_load_app_model", lambda _app_id: app_model)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PreviewCase:
|
||||
resource_cls: type
|
||||
path: str
|
||||
mode: AppMode
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
[
|
||||
PreviewCase(
|
||||
resource_cls=workflow_module.AdvancedChatDraftHumanInputFormApi,
|
||||
path="/console/api/apps/app-123/advanced-chat/workflows/draft/human-input/nodes/node-42/form",
|
||||
mode=AppMode.ADVANCED_CHAT,
|
||||
),
|
||||
PreviewCase(
|
||||
resource_cls=workflow_module.WorkflowDraftHumanInputFormApi,
|
||||
path="/console/api/apps/app-123/workflows/draft/human-input/nodes/node-42/form",
|
||||
mode=AppMode.WORKFLOW,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_human_input_preview_delegates_to_service(
|
||||
app: Flask, monkeypatch: pytest.MonkeyPatch, case: PreviewCase
|
||||
) -> None:
|
||||
account = _make_account()
|
||||
app_model = _make_app(case.mode)
|
||||
_patch_console_guards(monkeypatch, account, app_model)
|
||||
|
||||
preview_payload = {
|
||||
"form_id": "node-42",
|
||||
"form_content": "<div>example</div>",
|
||||
"inputs": [{"name": "topic"}],
|
||||
"actions": [{"id": "continue"}],
|
||||
}
|
||||
service_instance = MagicMock()
|
||||
service_instance.get_human_input_form_preview.return_value = preview_payload
|
||||
monkeypatch.setattr(workflow_module, "WorkflowService", MagicMock(return_value=service_instance))
|
||||
|
||||
with app.test_request_context(case.path, method="GET", json={"inputs": {"topic": "tech"}}):
|
||||
response = case.resource_cls().get(app_id=app_model.id, node_id="node-42")
|
||||
|
||||
assert response == preview_payload
|
||||
service_instance.get_human_input_form_preview.assert_called_once_with(
|
||||
app_model=app_model,
|
||||
account=account,
|
||||
node_id="node-42",
|
||||
manual_inputs={"topic": "tech"},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SubmitCase:
|
||||
resource_cls: type
|
||||
path: str
|
||||
mode: AppMode
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
[
|
||||
SubmitCase(
|
||||
resource_cls=workflow_module.AdvancedChatDraftHumanInputFormApi,
|
||||
path="/console/api/apps/app-123/advanced-chat/workflows/draft/human-input/nodes/node-99/form",
|
||||
mode=AppMode.ADVANCED_CHAT,
|
||||
),
|
||||
SubmitCase(
|
||||
resource_cls=workflow_module.WorkflowDraftHumanInputFormApi,
|
||||
path="/console/api/apps/app-123/workflows/draft/human-input/nodes/node-99/form",
|
||||
mode=AppMode.WORKFLOW,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_human_input_submit_forwards_payload(app: Flask, monkeypatch: pytest.MonkeyPatch, case: SubmitCase) -> None:
|
||||
account = _make_account()
|
||||
app_model = _make_app(case.mode)
|
||||
_patch_console_guards(monkeypatch, account, app_model)
|
||||
|
||||
result_payload = {"node_id": "node-99", "outputs": {"__rendered_content": "<p>done</p>"}, "action": "approve"}
|
||||
service_instance = MagicMock()
|
||||
service_instance.submit_human_input_form_preview.return_value = result_payload
|
||||
monkeypatch.setattr(workflow_module, "WorkflowService", MagicMock(return_value=service_instance))
|
||||
|
||||
with app.test_request_context(
|
||||
case.path,
|
||||
method="POST",
|
||||
json={"inputs": {"answer": "42"}, "action": "approve"},
|
||||
):
|
||||
response = case.resource_cls().post(app_id=app_model.id, node_id="node-99")
|
||||
|
||||
assert response == result_payload
|
||||
service_instance.submit_human_input_form_preview.assert_called_once_with(
|
||||
app_model=app_model,
|
||||
account=account,
|
||||
node_id="node-99",
|
||||
form_inputs={"answer": "42"},
|
||||
action="approve",
|
||||
)
|
||||
|
||||
|
||||
def test_human_input_preview_rejects_non_mapping(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
account = _make_account()
|
||||
app_model = _make_app(AppMode.ADVANCED_CHAT)
|
||||
_patch_console_guards(monkeypatch, account, app_model)
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-123/advanced-chat/workflows/draft/human-input/nodes/node-1/form",
|
||||
method="GET",
|
||||
json={"inputs": ["not-a-dict"]},
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
workflow_module.AdvancedChatDraftHumanInputFormApi().get(app_id=app_model.id, node_id="node-1")
|
||||
@ -124,7 +124,12 @@ class TestWorkflowResponseConverter:
|
||||
original_data = {"large_field": "x" * 10000, "metadata": "info"}
|
||||
truncated_data = {"large_field": "[TRUNCATED]", "metadata": "info"}
|
||||
|
||||
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
|
||||
converter.workflow_start_to_stream_response(
|
||||
task_id="bootstrap",
|
||||
workflow_run_id="run-id",
|
||||
workflow_id="wf-id",
|
||||
is_resumption=False,
|
||||
)
|
||||
start_event = self.create_node_started_event()
|
||||
converter.workflow_node_start_to_stream_response(
|
||||
event=start_event,
|
||||
@ -160,7 +165,12 @@ class TestWorkflowResponseConverter:
|
||||
|
||||
original_data = {"small": "data"}
|
||||
|
||||
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
|
||||
converter.workflow_start_to_stream_response(
|
||||
task_id="bootstrap",
|
||||
workflow_run_id="run-id",
|
||||
workflow_id="wf-id",
|
||||
is_resumption=False,
|
||||
)
|
||||
start_event = self.create_node_started_event()
|
||||
converter.workflow_node_start_to_stream_response(
|
||||
event=start_event,
|
||||
@ -191,7 +201,12 @@ class TestWorkflowResponseConverter:
|
||||
"""Test node finish response when process_data is None."""
|
||||
converter = self.create_workflow_response_converter()
|
||||
|
||||
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
|
||||
converter.workflow_start_to_stream_response(
|
||||
task_id="bootstrap",
|
||||
workflow_run_id="run-id",
|
||||
workflow_id="wf-id",
|
||||
is_resumption=False,
|
||||
)
|
||||
start_event = self.create_node_started_event()
|
||||
converter.workflow_node_start_to_stream_response(
|
||||
event=start_event,
|
||||
@ -225,7 +240,12 @@ class TestWorkflowResponseConverter:
|
||||
original_data = {"large_field": "x" * 10000, "metadata": "info"}
|
||||
truncated_data = {"large_field": "[TRUNCATED]", "metadata": "info"}
|
||||
|
||||
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
|
||||
converter.workflow_start_to_stream_response(
|
||||
task_id="bootstrap",
|
||||
workflow_run_id="run-id",
|
||||
workflow_id="wf-id",
|
||||
is_resumption=False,
|
||||
)
|
||||
start_event = self.create_node_started_event()
|
||||
converter.workflow_node_start_to_stream_response(
|
||||
event=start_event,
|
||||
@ -261,7 +281,12 @@ class TestWorkflowResponseConverter:
|
||||
|
||||
original_data = {"small": "data"}
|
||||
|
||||
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
|
||||
converter.workflow_start_to_stream_response(
|
||||
task_id="bootstrap",
|
||||
workflow_run_id="run-id",
|
||||
workflow_id="wf-id",
|
||||
is_resumption=False,
|
||||
)
|
||||
start_event = self.create_node_started_event()
|
||||
converter.workflow_node_start_to_stream_response(
|
||||
event=start_event,
|
||||
@ -400,6 +425,7 @@ class TestWorkflowResponseConverterServiceApiTruncation:
|
||||
task_id="test-task-id",
|
||||
workflow_run_id="test-workflow-run-id",
|
||||
workflow_id="test-workflow-id",
|
||||
is_resumption=False,
|
||||
)
|
||||
return converter
|
||||
|
||||
|
||||
@ -112,7 +112,12 @@ def _build_converter():
|
||||
|
||||
def test_queue_workflow_paused_event_to_stream_responses():
|
||||
converter = _build_converter()
|
||||
converter.workflow_start_to_stream_response(task_id="task", workflow_run_id="run-id", workflow_id="workflow-id")
|
||||
converter.workflow_start_to_stream_response(
|
||||
task_id="task",
|
||||
workflow_run_id="run-id",
|
||||
workflow_id="workflow-id",
|
||||
is_resumption=False,
|
||||
)
|
||||
|
||||
reason = HumanInputRequired(
|
||||
form_id="form-1",
|
||||
|
||||
@ -508,9 +508,12 @@ class TestConversationServiceMessageCreation:
|
||||
within conversations.
|
||||
"""
|
||||
|
||||
@patch("services.message_service._create_execution_extra_content_repository")
|
||||
@patch("services.message_service.db.session")
|
||||
@patch("services.message_service.ConversationService.get_conversation")
|
||||
def test_pagination_by_first_id_without_first_id(self, mock_get_conversation, mock_db_session):
|
||||
def test_pagination_by_first_id_without_first_id(
|
||||
self, mock_get_conversation, mock_db_session, mock_create_extra_repo
|
||||
):
|
||||
"""
|
||||
Test message pagination without specifying first_id.
|
||||
|
||||
@ -540,6 +543,9 @@ class TestConversationServiceMessageCreation:
|
||||
mock_query.order_by.return_value = mock_query # ORDER BY returns self for chaining
|
||||
mock_query.limit.return_value = mock_query # LIMIT returns self for chaining
|
||||
mock_query.all.return_value = messages # Final .all() returns the messages
|
||||
mock_repository = MagicMock()
|
||||
mock_repository.get_by_message_ids.return_value = [[] for _ in messages]
|
||||
mock_create_extra_repo.return_value = mock_repository
|
||||
|
||||
# Act - Call the pagination method without first_id
|
||||
result = MessageService.pagination_by_first_id(
|
||||
@ -556,9 +562,10 @@ class TestConversationServiceMessageCreation:
|
||||
# Verify conversation was looked up with correct parameters
|
||||
mock_get_conversation.assert_called_once_with(app_model=app_model, user=user, conversation_id=conversation.id)
|
||||
|
||||
@patch("services.message_service._create_execution_extra_content_repository")
|
||||
@patch("services.message_service.db.session")
|
||||
@patch("services.message_service.ConversationService.get_conversation")
|
||||
def test_pagination_by_first_id_with_first_id(self, mock_get_conversation, mock_db_session):
|
||||
def test_pagination_by_first_id_with_first_id(self, mock_get_conversation, mock_db_session, mock_create_extra_repo):
|
||||
"""
|
||||
Test message pagination with first_id specified.
|
||||
|
||||
@ -590,6 +597,9 @@ class TestConversationServiceMessageCreation:
|
||||
mock_query.limit.return_value = mock_query # LIMIT returns self for chaining
|
||||
mock_query.first.return_value = first_message # First message returned
|
||||
mock_query.all.return_value = messages # Remaining messages returned
|
||||
mock_repository = MagicMock()
|
||||
mock_repository.get_by_message_ids.return_value = [[] for _ in messages]
|
||||
mock_create_extra_repo.return_value = mock_repository
|
||||
|
||||
# Act - Call the pagination method with first_id
|
||||
result = MessageService.pagination_by_first_id(
|
||||
@ -684,9 +694,10 @@ class TestConversationServiceMessageCreation:
|
||||
assert result.data == []
|
||||
assert result.has_more is False
|
||||
|
||||
@patch("services.message_service._create_execution_extra_content_repository")
|
||||
@patch("services.message_service.db.session")
|
||||
@patch("services.message_service.ConversationService.get_conversation")
|
||||
def test_pagination_with_has_more_flag(self, mock_get_conversation, mock_db_session):
|
||||
def test_pagination_with_has_more_flag(self, mock_get_conversation, mock_db_session, mock_create_extra_repo):
|
||||
"""
|
||||
Test that has_more flag is correctly set when there are more messages.
|
||||
|
||||
@ -716,6 +727,9 @@ class TestConversationServiceMessageCreation:
|
||||
mock_query.order_by.return_value = mock_query # ORDER BY returns self for chaining
|
||||
mock_query.limit.return_value = mock_query # LIMIT returns self for chaining
|
||||
mock_query.all.return_value = messages # Final .all() returns the messages
|
||||
mock_repository = MagicMock()
|
||||
mock_repository.get_by_message_ids.return_value = [[] for _ in messages]
|
||||
mock_create_extra_repo.return_value = mock_repository
|
||||
|
||||
# Act
|
||||
result = MessageService.pagination_by_first_id(
|
||||
@ -730,9 +744,10 @@ class TestConversationServiceMessageCreation:
|
||||
assert len(result.data) == limit # Extra message should be removed
|
||||
assert result.has_more is True # Flag should be set
|
||||
|
||||
@patch("services.message_service._create_execution_extra_content_repository")
|
||||
@patch("services.message_service.db.session")
|
||||
@patch("services.message_service.ConversationService.get_conversation")
|
||||
def test_pagination_with_ascending_order(self, mock_get_conversation, mock_db_session):
|
||||
def test_pagination_with_ascending_order(self, mock_get_conversation, mock_db_session, mock_create_extra_repo):
|
||||
"""
|
||||
Test message pagination with ascending order.
|
||||
|
||||
@ -761,6 +776,9 @@ class TestConversationServiceMessageCreation:
|
||||
mock_query.order_by.return_value = mock_query # ORDER BY returns self for chaining
|
||||
mock_query.limit.return_value = mock_query # LIMIT returns self for chaining
|
||||
mock_query.all.return_value = messages # Final .all() returns the messages
|
||||
mock_repository = MagicMock()
|
||||
mock_repository.get_by_message_ids.return_value = [[] for _ in messages]
|
||||
mock_create_extra_repo.return_value = mock_repository
|
||||
|
||||
# Act
|
||||
result = MessageService.pagination_by_first_id(
|
||||
|
||||
@ -65,72 +65,25 @@ def sample_form_record():
|
||||
)
|
||||
|
||||
|
||||
def test_enqueue_resume_dispatches_task(mocker, mock_session_factory):
|
||||
def test_enqueue_resume_dispatches_task_for_workflow(mocker, mock_session_factory):
|
||||
session_factory, session = mock_session_factory
|
||||
service = HumanInputService(session_factory)
|
||||
|
||||
trigger_log = MagicMock()
|
||||
trigger_log.id = "trigger-log-id"
|
||||
trigger_log.queue_name = "workflow_queue"
|
||||
|
||||
repo_cls = mocker.patch(
|
||||
"services.human_input_service.SQLAlchemyWorkflowTriggerLogRepository",
|
||||
autospec=True,
|
||||
)
|
||||
repo = repo_cls.return_value
|
||||
repo.get_by_workflow_run_id.return_value = trigger_log
|
||||
|
||||
resume_task = mocker.patch("services.human_input_service.resume_workflow_execution")
|
||||
|
||||
service._enqueue_resume("workflow-run-id")
|
||||
|
||||
repo_cls.assert_called_once_with(session)
|
||||
resume_task.apply_async.assert_called_once()
|
||||
call_kwargs = resume_task.apply_async.call_args.kwargs
|
||||
assert call_kwargs["queue"] == "workflow_queue"
|
||||
payload = call_kwargs["kwargs"]["task_data_dict"]
|
||||
assert payload["workflow_trigger_log_id"] == "trigger-log-id"
|
||||
assert payload["workflow_run_id"] == "workflow-run-id"
|
||||
|
||||
|
||||
def test_enqueue_resume_no_trigger_log(mocker, mock_session_factory):
|
||||
session_factory, session = mock_session_factory
|
||||
service = HumanInputService(session_factory)
|
||||
|
||||
repo_cls = mocker.patch(
|
||||
"services.human_input_service.SQLAlchemyWorkflowTriggerLogRepository",
|
||||
autospec=True,
|
||||
)
|
||||
repo = repo_cls.return_value
|
||||
repo.get_by_workflow_run_id.return_value = None
|
||||
|
||||
resume_task = mocker.patch("services.human_input_service.resume_workflow_execution")
|
||||
|
||||
service._enqueue_resume("workflow-run-id")
|
||||
|
||||
repo_cls.assert_called_once_with(session)
|
||||
resume_task.apply_async.assert_not_called()
|
||||
|
||||
|
||||
def test_enqueue_resume_chatflow_fallback(mocker, mock_session_factory):
|
||||
session_factory, session = mock_session_factory
|
||||
service = HumanInputService(session_factory)
|
||||
|
||||
repo_cls = mocker.patch(
|
||||
"services.human_input_service.SQLAlchemyWorkflowTriggerLogRepository",
|
||||
autospec=True,
|
||||
)
|
||||
repo = repo_cls.return_value
|
||||
repo.get_by_workflow_run_id.return_value = None
|
||||
|
||||
workflow_run = MagicMock()
|
||||
workflow_run.app_id = "app-id"
|
||||
|
||||
workflow_run_repo = MagicMock()
|
||||
workflow_run_repo.get_workflow_run_by_id_without_tenant.return_value = workflow_run
|
||||
mocker.patch(
|
||||
"services.human_input_service.DifyAPIRepositoryFactory.create_api_workflow_run_repository",
|
||||
return_value=workflow_run_repo,
|
||||
)
|
||||
|
||||
app = MagicMock()
|
||||
app.mode = "advanced-chat"
|
||||
app.mode = "workflow"
|
||||
session.execute.return_value.scalar_one_or_none.return_value = app
|
||||
|
||||
session.get.side_effect = [workflow_run, app]
|
||||
|
||||
resume_task = mocker.patch("services.human_input_service.resume_chatflow_execution")
|
||||
resume_task = mocker.patch("services.human_input_service.resume_app_execution")
|
||||
|
||||
service._enqueue_resume("workflow-run-id")
|
||||
|
||||
@ -140,6 +93,59 @@ def test_enqueue_resume_chatflow_fallback(mocker, mock_session_factory):
|
||||
assert call_kwargs["kwargs"]["payload"]["workflow_run_id"] == "workflow-run-id"
|
||||
|
||||
|
||||
def test_enqueue_resume_dispatches_task_for_advanced_chat(mocker, mock_session_factory):
|
||||
session_factory, session = mock_session_factory
|
||||
service = HumanInputService(session_factory)
|
||||
|
||||
workflow_run = MagicMock()
|
||||
workflow_run.app_id = "app-id"
|
||||
|
||||
workflow_run_repo = MagicMock()
|
||||
workflow_run_repo.get_workflow_run_by_id_without_tenant.return_value = workflow_run
|
||||
mocker.patch(
|
||||
"services.human_input_service.DifyAPIRepositoryFactory.create_api_workflow_run_repository",
|
||||
return_value=workflow_run_repo,
|
||||
)
|
||||
|
||||
app = MagicMock()
|
||||
app.mode = "advanced-chat"
|
||||
session.execute.return_value.scalar_one_or_none.return_value = app
|
||||
|
||||
resume_task = mocker.patch("services.human_input_service.resume_app_execution")
|
||||
|
||||
service._enqueue_resume("workflow-run-id")
|
||||
|
||||
resume_task.apply_async.assert_called_once()
|
||||
call_kwargs = resume_task.apply_async.call_args.kwargs
|
||||
assert call_kwargs["queue"] == "chatflow_execute"
|
||||
assert call_kwargs["kwargs"]["payload"]["workflow_run_id"] == "workflow-run-id"
|
||||
|
||||
|
||||
def test_enqueue_resume_skips_unsupported_app_mode(mocker, mock_session_factory):
|
||||
session_factory, session = mock_session_factory
|
||||
service = HumanInputService(session_factory)
|
||||
|
||||
workflow_run = MagicMock()
|
||||
workflow_run.app_id = "app-id"
|
||||
|
||||
workflow_run_repo = MagicMock()
|
||||
workflow_run_repo.get_workflow_run_by_id_without_tenant.return_value = workflow_run
|
||||
mocker.patch(
|
||||
"services.human_input_service.DifyAPIRepositoryFactory.create_api_workflow_run_repository",
|
||||
return_value=workflow_run_repo,
|
||||
)
|
||||
|
||||
app = MagicMock()
|
||||
app.mode = "completion"
|
||||
session.execute.return_value.scalar_one_or_none.return_value = app
|
||||
|
||||
resume_task = mocker.patch("services.human_input_service.resume_app_execution")
|
||||
|
||||
service._enqueue_resume("workflow-run-id")
|
||||
|
||||
resume_task.apply_async.assert_not_called()
|
||||
|
||||
|
||||
def test_get_form_definition_by_id_uses_repository(sample_form_record, mock_session_factory):
|
||||
session_factory, _ = mock_session_factory
|
||||
repo = MagicMock(spec=HumanInputFormSubmissionRepository)
|
||||
|
||||
Reference in New Issue
Block a user