feat: Human Input Node (#32060)

The frontend and backend implementation for the human input node.

Co-authored-by: twwu <twwu@dify.ai>
Co-authored-by: JzoNg <jzongcode@gmail.com>
Co-authored-by: yyh <92089059+lyzno1@users.noreply.github.com>
Co-authored-by: zhsama <torvalds@linux.do>
This commit is contained in:
QuantumGhost
2026-02-09 14:57:23 +08:00
committed by GitHub
parent 56e3a55023
commit a1fc280102
474 changed files with 32667 additions and 2050 deletions

View File

@ -4,8 +4,8 @@ import contextvars
import logging
import threading
import uuid
from collections.abc import Generator, Mapping
from typing import TYPE_CHECKING, Any, Literal, Union, overload
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union, overload
from flask import Flask, current_app
from pydantic import ValidationError
@ -29,21 +29,25 @@ from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
from core.helper.trace_id_helper import extract_external_trace_id_from_args
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
from core.repositories import DifyCoreRepositoryFactory
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.repositories.draft_variable_repository import (
DraftVariableSaverFactory,
)
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.runtime import GraphRuntimeState
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
from extensions.ext_database import db
from factories import file_factory
from libs.flask_utils import preserve_flask_contexts
from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.base import Base
from models.enums import WorkflowRunTriggeredFrom
from services.conversation_service import ConversationService
from services.workflow_draft_variable_service import (
@ -65,7 +69,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
workflow_run_id: str,
streaming: Literal[False],
pause_state_config: PauseStateLayerConfig | None = None,
) -> Mapping[str, Any]: ...
@overload
@ -74,9 +80,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
workflow_run_id: str,
streaming: Literal[True],
pause_state_config: PauseStateLayerConfig | None = None,
) -> Generator[Mapping | str, None, None]: ...
@overload
@ -85,9 +93,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
workflow_run_id: str,
streaming: bool,
pause_state_config: PauseStateLayerConfig | None = None,
) -> Mapping[str, Any] | Generator[str | Mapping, None, None]: ...
def generate(
@ -95,9 +105,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
workflow_run_id: str,
streaming: bool = True,
pause_state_config: PauseStateLayerConfig | None = None,
) -> Mapping[str, Any] | Generator[str | Mapping, None, None]:
"""
Generate App response.
@ -161,7 +173,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# always enable retriever resource in debugger mode
app_config.additional_features.show_retrieve_source = True # type: ignore
workflow_run_id = str(uuid.uuid4())
# init application generate entity
application_generate_entity = AdvancedChatAppGenerateEntity(
task_id=str(uuid.uuid4()),
@ -179,7 +190,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
invoke_from=invoke_from,
extras=extras,
trace_manager=trace_manager,
workflow_run_id=workflow_run_id,
workflow_run_id=str(workflow_run_id),
)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
@ -216,6 +227,38 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
workflow_node_execution_repository=workflow_node_execution_repository,
conversation=conversation,
stream=streaming,
pause_state_config=pause_state_config,
)
def resume(
self,
*,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
conversation: Conversation,
message: Message,
application_generate_entity: AdvancedChatAppGenerateEntity,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
graph_runtime_state: GraphRuntimeState,
pause_state_config: PauseStateLayerConfig | None = None,
):
"""
Resume a paused advanced chat execution.
"""
return self._generate(
workflow=workflow,
user=user,
invoke_from=application_generate_entity.invoke_from,
application_generate_entity=application_generate_entity,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
conversation=conversation,
message=message,
stream=application_generate_entity.stream,
pause_state_config=pause_state_config,
graph_runtime_state=graph_runtime_state,
)
def single_iteration_generate(
@ -396,8 +439,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
conversation: Conversation | None = None,
message: Message | None = None,
stream: bool = True,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
pause_state_config: PauseStateLayerConfig | None = None,
graph_runtime_state: GraphRuntimeState | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
"""
Generate App response.
@ -411,12 +458,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
:param conversation: conversation
:param stream: is stream
"""
is_first_conversation = False
if not conversation:
is_first_conversation = True
is_first_conversation = conversation is None
# init generate records
(conversation, message) = self._init_generate_records(application_generate_entity, conversation)
if conversation is not None and message is not None:
pass
else:
conversation, message = self._init_generate_records(application_generate_entity, conversation)
if is_first_conversation:
# update conversation features
@ -439,6 +486,16 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
message_id=message.id,
)
graph_layers: list[GraphEngineLayer] = list(graph_engine_layers)
if pause_state_config is not None:
graph_layers.append(
PauseStatePersistenceLayer(
session_factory=pause_state_config.session_factory,
generate_entity=application_generate_entity,
state_owner_user_id=pause_state_config.state_owner_user_id,
)
)
# new thread with request context and contextvars
context = contextvars.copy_context()
@ -454,14 +511,25 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
"variable_loader": variable_loader,
"workflow_execution_repository": workflow_execution_repository,
"workflow_node_execution_repository": workflow_node_execution_repository,
"graph_engine_layers": tuple(graph_layers),
"graph_runtime_state": graph_runtime_state,
},
)
worker_thread.start()
# release database connection, because the following new thread operations may take a long time
db.session.refresh(workflow)
db.session.refresh(message)
with Session(bind=db.engine, expire_on_commit=False) as session:
workflow = _refresh_model(session, workflow)
message = _refresh_model(session, message)
# workflow_ = session.get(Workflow, workflow.id)
# assert workflow_ is not None
# workflow = workflow_
# message_ = session.get(Message, message.id)
# assert message_ is not None
# message = message_
# db.session.refresh(workflow)
# db.session.refresh(message)
# db.session.refresh(user)
db.session.close()
@ -490,6 +558,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
variable_loader: VariableLoader,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
graph_runtime_state: GraphRuntimeState | None = None,
):
"""
Generate worker in a new thread.
@ -547,6 +617,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
app=app,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
graph_engine_layers=graph_engine_layers,
graph_runtime_state=graph_runtime_state,
)
try:
@ -614,3 +686,13 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
else:
logger.exception("Failed to process generate task pipeline, conversation_id: %s", conversation.id)
raise e
_T = TypeVar("_T", bound=Base)
def _refresh_model(session, model: _T) -> _T:
with Session(bind=db.engine, expire_on_commit=False) as session:
detach_model = session.get(type(model), model.id)
assert detach_model is not None
return detach_model

View File

@ -66,6 +66,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
graph_runtime_state: GraphRuntimeState | None = None,
):
super().__init__(
queue_manager=queue_manager,
@ -82,6 +83,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
self._app = app
self._workflow_execution_repository = workflow_execution_repository
self._workflow_node_execution_repository = workflow_node_execution_repository
self._resume_graph_runtime_state = graph_runtime_state
@trace_span(WorkflowAppRunnerHandler)
def run(self):
@ -110,7 +112,21 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
invoke_from = InvokeFrom.DEBUGGER
user_from = self._resolve_user_from(invoke_from)
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
resume_state = self._resume_graph_runtime_state
if resume_state is not None:
graph_runtime_state = resume_state
variable_pool = graph_runtime_state.variable_pool
graph = self._init_graph(
graph_config=self._workflow.graph_dict,
graph_runtime_state=graph_runtime_state,
workflow_id=self._workflow.id,
tenant_id=self._workflow.tenant_id,
user_id=self.application_generate_entity.user_id,
invoke_from=invoke_from,
user_from=user_from,
)
elif self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
# Handle single iteration or single loop run
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
workflow=self._workflow,

View File

@ -24,6 +24,8 @@ from core.app.entities.queue_entities import (
QueueAgentLogEvent,
QueueAnnotationReplyEvent,
QueueErrorEvent,
QueueHumanInputFormFilledEvent,
QueueHumanInputFormTimeoutEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
@ -42,6 +44,7 @@ from core.app.entities.queue_entities import (
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowPausedEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
WorkflowQueueMessage,
@ -63,6 +66,8 @@ from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.ops_trace_manager import TraceQueueManager
from core.repositories.human_input_repository import HumanInputFormRepositoryImpl
from core.workflow.entities.pause_reason import HumanInputRequired
from core.workflow.enums import WorkflowExecutionStatus
from core.workflow.nodes import NodeType
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
@ -71,7 +76,8 @@ from core.workflow.system_variable import SystemVariable
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models import Account, Conversation, EndUser, Message, MessageFile
from models.enums import CreatorUserRole
from models.enums import CreatorUserRole, MessageStatus
from models.execution_extra_content import HumanInputContent
from models.workflow import Workflow
logger = logging.getLogger(__name__)
@ -128,6 +134,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
)
self._task_state = WorkflowTaskState()
self._seed_task_state_from_message(message)
self._message_cycle_manager = MessageCycleManager(
application_generate_entity=application_generate_entity, task_state=self._task_state
)
@ -135,6 +142,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
self._application_generate_entity = application_generate_entity
self._workflow_id = workflow.id
self._workflow_features_dict = workflow.features_dict
self._workflow_tenant_id = workflow.tenant_id
self._conversation_id = conversation.id
self._conversation_mode = conversation.mode
self._message_id = message.id
@ -144,8 +152,13 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
self._workflow_run_id: str = ""
self._draft_var_saver_factory = draft_var_saver_factory
self._graph_runtime_state: GraphRuntimeState | None = None
self._message_saved_on_pause = False
self._seed_graph_runtime_state_from_queue_manager()
def _seed_task_state_from_message(self, message: Message) -> None:
if message.status == MessageStatus.PAUSED and message.answer:
self._task_state.answer = message.answer
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
"""
Process generate task pipeline.
@ -308,6 +321,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
task_id=self._application_generate_entity.task_id,
workflow_run_id=run_id,
workflow_id=self._workflow_id,
reason=event.reason,
)
yield workflow_start_resp
@ -525,6 +539,35 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
)
yield workflow_finish_resp
def _handle_workflow_paused_event(
self,
event: QueueWorkflowPausedEvent,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow paused events."""
validated_state = self._ensure_graph_runtime_initialized()
responses = self._workflow_response_converter.workflow_pause_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
graph_runtime_state=validated_state,
)
for reason in event.reasons:
if isinstance(reason, HumanInputRequired):
self._persist_human_input_extra_content(form_id=reason.form_id, node_id=reason.node_id)
yield from responses
resolved_state: GraphRuntimeState | None = None
try:
resolved_state = self._ensure_graph_runtime_initialized()
except ValueError:
resolved_state = None
with self._database_session() as session:
self._save_message(session=session, graph_runtime_state=resolved_state)
message = self._get_message(session=session)
if message is not None:
message.status = MessageStatus.PAUSED
self._message_saved_on_pause = True
self._base_task_pipeline.queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
def _handle_workflow_failed_event(
@ -614,9 +657,10 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION,
)
# Save message
with self._database_session() as session:
self._save_message(session=session, graph_runtime_state=resolved_state)
# Save message unless it has already been persisted on pause.
if not self._message_saved_on_pause:
with self._database_session() as session:
self._save_message(session=session, graph_runtime_state=resolved_state)
yield self._message_end_to_stream_response()
@ -642,6 +686,65 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
"""Handle message replace events."""
yield self._message_cycle_manager.message_replace_to_stream_response(answer=event.text, reason=event.reason)
def _handle_human_input_form_filled_event(
self, event: QueueHumanInputFormFilledEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle human input form filled events."""
self._persist_human_input_extra_content(node_id=event.node_id)
yield self._workflow_response_converter.human_input_form_filled_to_stream_response(
event=event, task_id=self._application_generate_entity.task_id
)
def _handle_human_input_form_timeout_event(
self, event: QueueHumanInputFormTimeoutEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle human input form timeout events."""
yield self._workflow_response_converter.human_input_form_timeout_to_stream_response(
event=event, task_id=self._application_generate_entity.task_id
)
def _persist_human_input_extra_content(self, *, node_id: str | None = None, form_id: str | None = None) -> None:
if not self._workflow_run_id or not self._message_id:
return
if form_id is None:
if node_id is None:
return
form_id = self._load_human_input_form_id(node_id=node_id)
if form_id is None:
logger.warning(
"HumanInput form not found for workflow run %s node %s",
self._workflow_run_id,
node_id,
)
return
with self._database_session() as session:
exists_stmt = select(HumanInputContent).where(
HumanInputContent.workflow_run_id == self._workflow_run_id,
HumanInputContent.message_id == self._message_id,
HumanInputContent.form_id == form_id,
)
if session.scalar(exists_stmt) is not None:
return
content = HumanInputContent(
workflow_run_id=self._workflow_run_id,
message_id=self._message_id,
form_id=form_id,
)
session.add(content)
def _load_human_input_form_id(self, *, node_id: str) -> str | None:
form_repository = HumanInputFormRepositoryImpl(
session_factory=db.engine,
tenant_id=self._workflow_tenant_id,
)
form = form_repository.get_form(self._workflow_run_id, node_id)
if form is None:
return None
return form.id
def _handle_agent_log_event(self, event: QueueAgentLogEvent, **kwargs) -> Generator[StreamResponse, None, None]:
"""Handle agent log events."""
yield self._workflow_response_converter.handle_agent_log(
@ -659,6 +762,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
QueueWorkflowStartedEvent: self._handle_workflow_started_event,
QueueWorkflowSucceededEvent: self._handle_workflow_succeeded_event,
QueueWorkflowPartialSuccessEvent: self._handle_workflow_partial_success_event,
QueueWorkflowPausedEvent: self._handle_workflow_paused_event,
QueueWorkflowFailedEvent: self._handle_workflow_failed_event,
# Node events
QueueNodeRetryEvent: self._handle_node_retry_event,
@ -680,6 +784,8 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
QueueMessageReplaceEvent: self._handle_message_replace_event,
QueueAdvancedChatMessageEndEvent: self._handle_advanced_chat_message_end_event,
QueueAgentLogEvent: self._handle_agent_log_event,
QueueHumanInputFormFilledEvent: self._handle_human_input_form_filled_event,
QueueHumanInputFormTimeoutEvent: self._handle_human_input_form_timeout_event,
}
def _dispatch_event(
@ -747,6 +853,9 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
case QueueWorkflowFailedEvent():
yield from self._handle_workflow_failed_event(event, trace_manager=trace_manager)
break
case QueueWorkflowPausedEvent():
yield from self._handle_workflow_paused_event(event)
break
case QueueStopEvent():
yield from self._handle_stop_event(event, graph_runtime_state=None, trace_manager=trace_manager)
@ -772,6 +881,11 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
def _save_message(self, *, session: Session, graph_runtime_state: GraphRuntimeState | None = None):
message = self._get_message(session=session)
if message is None:
return
if message.status == MessageStatus.PAUSED:
message.status = MessageStatus.NORMAL
# If there are assistant files, remove markdown image links from answer
answer_text = self._task_state.answer

View File

@ -5,9 +5,14 @@ from dataclasses import dataclass
from datetime import datetime
from typing import Any, NewType, Union
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import (
QueueAgentLogEvent,
QueueHumanInputFormFilledEvent,
QueueHumanInputFormTimeoutEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
@ -19,9 +24,13 @@ from core.app.entities.queue_entities import (
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueWorkflowPausedEvent,
)
from core.app.entities.task_entities import (
AgentLogStreamResponse,
HumanInputFormFilledResponse,
HumanInputFormTimeoutResponse,
HumanInputRequiredResponse,
IterationNodeCompletedStreamResponse,
IterationNodeNextStreamResponse,
IterationNodeStartStreamResponse,
@ -31,7 +40,9 @@ from core.app.entities.task_entities import (
NodeFinishStreamResponse,
NodeRetryStreamResponse,
NodeStartStreamResponse,
StreamResponse,
WorkflowFinishStreamResponse,
WorkflowPauseStreamResponse,
WorkflowStartStreamResponse,
)
from core.file import FILE_MODEL_IDENTITY, File
@ -40,6 +51,8 @@ from core.tools.entities.tool_entities import ToolProviderType
from core.tools.tool_manager import ToolManager
from core.trigger.trigger_manager import TriggerManager
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.workflow.entities.pause_reason import HumanInputRequired
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
from core.workflow.enums import (
NodeType,
SystemVariableKey,
@ -51,8 +64,11 @@ from core.workflow.runtime import GraphRuntimeState
from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_entry import WorkflowEntry
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models import Account, EndUser
from models.human_input import HumanInputForm
from models.workflow import WorkflowRun
from services.variable_truncator import BaseTruncator, DummyVariableTruncator, VariableTruncator
NodeExecutionId = NewType("NodeExecutionId", str)
@ -191,6 +207,7 @@ class WorkflowResponseConverter:
task_id: str,
workflow_run_id: str,
workflow_id: str,
reason: WorkflowStartReason,
) -> WorkflowStartStreamResponse:
run_id = self._ensure_workflow_run_id(workflow_run_id)
started_at = naive_utc_now()
@ -204,6 +221,7 @@ class WorkflowResponseConverter:
workflow_id=workflow_id,
inputs=self._workflow_inputs,
created_at=int(started_at.timestamp()),
reason=reason,
),
)
@ -264,6 +282,160 @@ class WorkflowResponseConverter:
),
)
def workflow_pause_to_stream_response(
self,
*,
event: QueueWorkflowPausedEvent,
task_id: str,
graph_runtime_state: GraphRuntimeState,
) -> list[StreamResponse]:
run_id = self._ensure_workflow_run_id()
started_at = self._workflow_started_at
if started_at is None:
raise ValueError(
"workflow_pause_to_stream_response called before workflow_start_to_stream_response",
)
paused_at = naive_utc_now()
elapsed_time = (paused_at - started_at).total_seconds()
encoded_outputs = self._encode_outputs(event.outputs) or {}
if self._application_generate_entity.invoke_from == InvokeFrom.SERVICE_API:
encoded_outputs = {}
pause_reasons = [reason.model_dump(mode="json") for reason in event.reasons]
human_input_form_ids = [reason.form_id for reason in event.reasons if isinstance(reason, HumanInputRequired)]
expiration_times_by_form_id: dict[str, datetime] = {}
if human_input_form_ids:
stmt = select(HumanInputForm.id, HumanInputForm.expiration_time).where(
HumanInputForm.id.in_(human_input_form_ids)
)
with Session(bind=db.engine) as session:
for form_id, expiration_time in session.execute(stmt):
expiration_times_by_form_id[str(form_id)] = expiration_time
responses: list[StreamResponse] = []
for reason in event.reasons:
if isinstance(reason, HumanInputRequired):
expiration_time = expiration_times_by_form_id.get(reason.form_id)
if expiration_time is None:
raise ValueError(f"HumanInputForm not found for pause reason, form_id={reason.form_id}")
responses.append(
HumanInputRequiredResponse(
task_id=task_id,
workflow_run_id=run_id,
data=HumanInputRequiredResponse.Data(
form_id=reason.form_id,
node_id=reason.node_id,
node_title=reason.node_title,
form_content=reason.form_content,
inputs=reason.inputs,
actions=reason.actions,
display_in_ui=reason.display_in_ui,
form_token=reason.form_token,
resolved_default_values=reason.resolved_default_values,
expiration_time=int(expiration_time.timestamp()),
),
)
)
responses.append(
WorkflowPauseStreamResponse(
task_id=task_id,
workflow_run_id=run_id,
data=WorkflowPauseStreamResponse.Data(
workflow_run_id=run_id,
paused_nodes=list(event.paused_nodes),
outputs=encoded_outputs,
reasons=pause_reasons,
status=WorkflowExecutionStatus.PAUSED,
created_at=int(started_at.timestamp()),
elapsed_time=elapsed_time,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
),
)
)
return responses
def human_input_form_filled_to_stream_response(
self, *, event: QueueHumanInputFormFilledEvent, task_id: str
) -> HumanInputFormFilledResponse:
run_id = self._ensure_workflow_run_id()
return HumanInputFormFilledResponse(
task_id=task_id,
workflow_run_id=run_id,
data=HumanInputFormFilledResponse.Data(
node_id=event.node_id,
node_title=event.node_title,
rendered_content=event.rendered_content,
action_id=event.action_id,
action_text=event.action_text,
),
)
def human_input_form_timeout_to_stream_response(
self, *, event: QueueHumanInputFormTimeoutEvent, task_id: str
) -> HumanInputFormTimeoutResponse:
run_id = self._ensure_workflow_run_id()
return HumanInputFormTimeoutResponse(
task_id=task_id,
workflow_run_id=run_id,
data=HumanInputFormTimeoutResponse.Data(
node_id=event.node_id,
node_title=event.node_title,
expiration_time=int(event.expiration_time.timestamp()),
),
)
@classmethod
def workflow_run_result_to_finish_response(
cls,
*,
task_id: str,
workflow_run: WorkflowRun,
creator_user: Account | EndUser,
) -> WorkflowFinishStreamResponse:
run_id = workflow_run.id
elapsed_time = workflow_run.elapsed_time
encoded_outputs = workflow_run.outputs_dict
finished_at = workflow_run.finished_at
assert finished_at is not None
created_by: Mapping[str, object]
user = creator_user
if isinstance(user, Account):
created_by = {
"id": user.id,
"name": user.name,
"email": user.email,
}
else:
created_by = {
"id": user.id,
"user": user.session_id,
}
return WorkflowFinishStreamResponse(
task_id=task_id,
workflow_run_id=run_id,
data=WorkflowFinishStreamResponse.Data(
id=run_id,
workflow_id=workflow_run.workflow_id,
status=workflow_run.status,
outputs=encoded_outputs,
error=workflow_run.error,
elapsed_time=elapsed_time,
total_tokens=workflow_run.total_tokens,
total_steps=workflow_run.total_steps,
created_by=created_by,
created_at=int(workflow_run.created_at.timestamp()),
finished_at=int(finished_at.timestamp()),
files=cls.fetch_files_from_node_outputs(encoded_outputs),
exceptions_count=workflow_run.exceptions_count,
),
)
def workflow_node_start_to_stream_response(
self,
*,
@ -592,7 +764,8 @@ class WorkflowResponseConverter:
),
)
def fetch_files_from_node_outputs(self, outputs_dict: Mapping[str, Any] | None) -> Sequence[Mapping[str, Any]]:
@classmethod
def fetch_files_from_node_outputs(cls, outputs_dict: Mapping[str, Any] | None) -> Sequence[Mapping[str, Any]]:
"""
Fetch files from node outputs
:param outputs_dict: node outputs dict
@ -601,7 +774,7 @@ class WorkflowResponseConverter:
if not outputs_dict:
return []
files = [self._fetch_files_from_variable_value(output_value) for output_value in outputs_dict.values()]
files = [cls._fetch_files_from_variable_value(output_value) for output_value in outputs_dict.values()]
# Remove None
files = [file for file in files if file]
# Flatten list

View File

@ -1,6 +1,6 @@
import json
import logging
from collections.abc import Generator
from collections.abc import Callable, Generator, Mapping
from typing import Union, cast
from sqlalchemy import select
@ -10,12 +10,14 @@ from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppMod
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.streaming_utils import stream_topic_events
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
AgentChatAppGenerateEntity,
AppGenerateEntity,
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
ConversationAppGenerateEntity,
InvokeFrom,
)
from core.app.entities.task_entities import (
@ -27,6 +29,8 @@ from core.app.entities.task_entities import (
from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from extensions.ext_database import db
from extensions.ext_redis import get_pubsub_broadcast_channel
from libs.broadcast_channel.channel import Topic
from libs.datetime_utils import naive_utc_now
from models import Account
from models.enums import CreatorUserRole
@ -156,6 +160,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
query = application_generate_entity.query or "New conversation"
conversation_name = (query[:20] + "") if len(query) > 20 else query
created_new_conversation = conversation is None
try:
if not conversation:
conversation = Conversation(
@ -232,6 +237,10 @@ class MessageBasedAppGenerator(BaseAppGenerator):
db.session.add_all(message_files)
db.session.commit()
if isinstance(application_generate_entity, ConversationAppGenerateEntity):
application_generate_entity.conversation_id = conversation.id
application_generate_entity.is_new_conversation = created_new_conversation
return conversation, message
except Exception:
db.session.rollback()
@ -284,3 +293,29 @@ class MessageBasedAppGenerator(BaseAppGenerator):
raise MessageNotExistsError("Message not exists")
return message
@staticmethod
def _make_channel_key(app_mode: AppMode, workflow_run_id: str):
return f"channel:{app_mode}:{workflow_run_id}"
@classmethod
def get_response_topic(cls, app_mode: AppMode, workflow_run_id: str) -> Topic:
key = cls._make_channel_key(app_mode, workflow_run_id)
channel = get_pubsub_broadcast_channel()
topic = channel.topic(key)
return topic
@classmethod
def retrieve_events(
cls,
app_mode: AppMode,
workflow_run_id: str,
idle_timeout=300,
on_subscribe: Callable[[], None] | None = None,
) -> Generator[Mapping | str, None, None]:
topic = cls.get_response_topic(app_mode, workflow_run_id)
return stream_topic_events(
topic=topic,
idle_timeout=idle_timeout,
on_subscribe=on_subscribe,
)

View File

@ -0,0 +1,36 @@
from collections.abc import Callable, Generator, Mapping
from core.app.apps.streaming_utils import stream_topic_events
from extensions.ext_redis import get_pubsub_broadcast_channel
from libs.broadcast_channel.channel import Topic
from models.model import AppMode
class MessageGenerator:
@staticmethod
def _make_channel_key(app_mode: AppMode, workflow_run_id: str):
return f"channel:{app_mode}:{str(workflow_run_id)}"
@classmethod
def get_response_topic(cls, app_mode: AppMode, workflow_run_id: str) -> Topic:
key = cls._make_channel_key(app_mode, workflow_run_id)
channel = get_pubsub_broadcast_channel()
topic = channel.topic(key)
return topic
@classmethod
def retrieve_events(
cls,
app_mode: AppMode,
workflow_run_id: str,
idle_timeout=300,
ping_interval: float = 10.0,
on_subscribe: Callable[[], None] | None = None,
) -> Generator[Mapping | str, None, None]:
topic = cls.get_response_topic(app_mode, workflow_run_id)
return stream_topic_events(
topic=topic,
idle_timeout=idle_timeout,
ping_interval=ping_interval,
on_subscribe=on_subscribe,
)

View File

@ -0,0 +1,70 @@
from __future__ import annotations
import json
import time
from collections.abc import Callable, Generator, Iterable, Mapping
from typing import Any
from core.app.entities.task_entities import StreamEvent
from libs.broadcast_channel.channel import Topic
from libs.broadcast_channel.exc import SubscriptionClosedError
def stream_topic_events(
*,
topic: Topic,
idle_timeout: float,
ping_interval: float | None = None,
on_subscribe: Callable[[], None] | None = None,
terminal_events: Iterable[str | StreamEvent] | None = None,
) -> Generator[Mapping[str, Any] | str, None, None]:
# send a PING event immediately to prevent the connection staying in pending state for a long time.
#
# This simplify the debugging process as the DevTools in Chrome does not
# provide complete curl command for pending connections.
yield StreamEvent.PING.value
terminal_values = _normalize_terminal_events(terminal_events)
last_msg_time = time.time()
last_ping_time = last_msg_time
with topic.subscribe() as sub:
# on_subscribe fires only after the Redis subscription is active.
# This is used to gate task start and reduce pub/sub race for the first event.
if on_subscribe is not None:
on_subscribe()
while True:
try:
msg = sub.receive(timeout=0.1)
except SubscriptionClosedError:
return
if msg is None:
current_time = time.time()
if current_time - last_msg_time > idle_timeout:
return
if ping_interval is not None and current_time - last_ping_time >= ping_interval:
yield StreamEvent.PING.value
last_ping_time = current_time
continue
last_msg_time = time.time()
last_ping_time = last_msg_time
event = json.loads(msg)
yield event
if not isinstance(event, dict):
continue
event_type = event.get("event")
if event_type in terminal_values:
return
def _normalize_terminal_events(terminal_events: Iterable[str | StreamEvent] | None) -> set[str]:
if not terminal_events:
return {StreamEvent.WORKFLOW_FINISHED.value, StreamEvent.WORKFLOW_PAUSED.value}
values: set[str] = set()
for item in terminal_events:
if isinstance(item, StreamEvent):
values.add(item.value)
else:
values.add(str(item))
return values

View File

@ -25,6 +25,7 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
from core.db.session_factory import session_factory
from core.helper.trace_id_helper import extract_external_trace_id_from_args
from core.model_runtime.errors.invoke import InvokeAuthorizationError
@ -34,12 +35,15 @@ from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.runtime import GraphRuntimeState
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
from extensions.ext_database import db
from factories import file_factory
from libs.flask_utils import preserve_flask_contexts
from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.account import Account
from models.enums import WorkflowRunTriggeredFrom
from models.model import App, EndUser
from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
if TYPE_CHECKING:
@ -66,9 +70,11 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
streaming: Literal[True],
call_depth: int,
workflow_run_id: str | uuid.UUID | None = None,
triggered_from: WorkflowRunTriggeredFrom | None = None,
root_node_id: str | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
pause_state_config: PauseStateLayerConfig | None = None,
) -> Generator[Mapping[str, Any] | str, None, None]: ...
@overload
@ -82,9 +88,11 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
streaming: Literal[False],
call_depth: int,
workflow_run_id: str | uuid.UUID | None = None,
triggered_from: WorkflowRunTriggeredFrom | None = None,
root_node_id: str | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
pause_state_config: PauseStateLayerConfig | None = None,
) -> Mapping[str, Any]: ...
@overload
@ -98,9 +106,11 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
streaming: bool,
call_depth: int,
workflow_run_id: str | uuid.UUID | None = None,
triggered_from: WorkflowRunTriggeredFrom | None = None,
root_node_id: str | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
pause_state_config: PauseStateLayerConfig | None = None,
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: ...
def generate(
@ -113,9 +123,11 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
streaming: bool = True,
call_depth: int = 0,
workflow_run_id: str | uuid.UUID | None = None,
triggered_from: WorkflowRunTriggeredFrom | None = None,
root_node_id: str | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
pause_state_config: PauseStateLayerConfig | None = None,
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]:
files: Sequence[Mapping[str, Any]] = args.get("files") or []
@ -150,7 +162,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
extras = {
**extract_external_trace_id_from_args(args),
}
workflow_run_id = str(uuid.uuid4())
workflow_run_id = str(workflow_run_id or uuid.uuid4())
# FIXME (Yeuoly): we need to remove the SKIP_PREPARE_USER_INPUTS_KEY from the args
# trigger shouldn't prepare user inputs
if self._should_prepare_user_inputs(args):
@ -216,13 +228,40 @@ class WorkflowAppGenerator(BaseAppGenerator):
streaming=streaming,
root_node_id=root_node_id,
graph_engine_layers=graph_engine_layers,
pause_state_config=pause_state_config,
)
def resume(self, *, workflow_run_id: str) -> None:
def resume(
self,
*,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
application_generate_entity: WorkflowAppGenerateEntity,
graph_runtime_state: GraphRuntimeState,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
pause_state_config: PauseStateLayerConfig | None = None,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
"""
@TBD
Resume a paused workflow execution using the persisted runtime state.
"""
pass
return self._generate(
app_model=app_model,
workflow=workflow,
user=user,
application_generate_entity=application_generate_entity,
invoke_from=application_generate_entity.invoke_from,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=application_generate_entity.stream,
variable_loader=variable_loader,
graph_engine_layers=graph_engine_layers,
graph_runtime_state=graph_runtime_state,
pause_state_config=pause_state_config,
)
def _generate(
self,
@ -238,6 +277,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
root_node_id: str | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
graph_runtime_state: GraphRuntimeState | None = None,
pause_state_config: PauseStateLayerConfig | None = None,
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
"""
Generate App response.
@ -251,6 +292,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
:param workflow_node_execution_repository: repository for workflow node execution
:param streaming: is stream
"""
graph_layers: list[GraphEngineLayer] = list(graph_engine_layers)
# init queue manager
queue_manager = WorkflowAppQueueManager(
task_id=application_generate_entity.task_id,
@ -259,6 +302,15 @@ class WorkflowAppGenerator(BaseAppGenerator):
app_mode=app_model.mode,
)
if pause_state_config is not None:
graph_layers.append(
PauseStatePersistenceLayer(
session_factory=pause_state_config.session_factory,
generate_entity=application_generate_entity,
state_owner_user_id=pause_state_config.state_owner_user_id,
)
)
# new thread with request context and contextvars
context = contextvars.copy_context()
@ -276,7 +328,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
"root_node_id": root_node_id,
"workflow_execution_repository": workflow_execution_repository,
"workflow_node_execution_repository": workflow_node_execution_repository,
"graph_engine_layers": graph_engine_layers,
"graph_engine_layers": tuple(graph_layers),
"graph_runtime_state": graph_runtime_state,
},
)
@ -378,6 +431,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
variable_loader=var_loader,
pause_state_config=None,
)
def single_loop_generate(
@ -459,6 +513,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
variable_loader=var_loader,
pause_state_config=None,
)
def _generate_worker(
@ -472,6 +527,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
root_node_id: str | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
graph_runtime_state: GraphRuntimeState | None = None,
) -> None:
"""
Generate worker in a new thread.
@ -517,6 +573,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_node_execution_repository=workflow_node_execution_repository,
root_node_id=root_node_id,
graph_engine_layers=graph_engine_layers,
graph_runtime_state=graph_runtime_state,
)
try:

View File

@ -42,6 +42,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
graph_runtime_state: GraphRuntimeState | None = None,
):
super().__init__(
queue_manager=queue_manager,
@ -55,6 +56,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
self._root_node_id = root_node_id
self._workflow_execution_repository = workflow_execution_repository
self._workflow_node_execution_repository = workflow_node_execution_repository
self._resume_graph_runtime_state = graph_runtime_state
@trace_span(WorkflowAppRunnerHandler)
def run(self):
@ -63,23 +65,28 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
"""
app_config = self.application_generate_entity.app_config
app_config = cast(WorkflowAppConfig, app_config)
system_inputs = SystemVariable(
files=self.application_generate_entity.files,
user_id=self._sys_user_id,
app_id=app_config.app_id,
timestamp=int(naive_utc_now().timestamp()),
workflow_id=app_config.workflow_id,
workflow_execution_id=self.application_generate_entity.workflow_execution_id,
)
invoke_from = self.application_generate_entity.invoke_from
# if only single iteration or single loop run is requested
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
invoke_from = InvokeFrom.DEBUGGER
user_from = self._resolve_user_from(invoke_from)
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
resume_state = self._resume_graph_runtime_state
if resume_state is not None:
graph_runtime_state = resume_state
variable_pool = graph_runtime_state.variable_pool
graph = self._init_graph(
graph_config=self._workflow.graph_dict,
graph_runtime_state=graph_runtime_state,
workflow_id=self._workflow.id,
tenant_id=self._workflow.tenant_id,
user_id=self.application_generate_entity.user_id,
user_from=user_from,
invoke_from=invoke_from,
root_node_id=self._root_node_id,
)
elif self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
workflow=self._workflow,
single_iteration_run=self.application_generate_entity.single_iteration_run,
@ -89,7 +96,14 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
inputs = self.application_generate_entity.inputs
# Create a variable pool.
system_inputs = SystemVariable(
files=self.application_generate_entity.files,
user_id=self._sys_user_id,
app_id=app_config.app_id,
timestamp=int(naive_utc_now().timestamp()),
workflow_id=app_config.workflow_id,
workflow_execution_id=self.application_generate_entity.workflow_execution_id,
)
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
@ -98,8 +112,6 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
# init graph
graph = self._init_graph(
graph_config=self._workflow.graph_dict,
graph_runtime_state=graph_runtime_state,

View File

@ -0,0 +1,7 @@
from libs.exception import BaseHTTPException
class WorkflowPausedInBlockingModeError(BaseHTTPException):
error_code = "workflow_paused_in_blocking_mode"
description = "Workflow execution paused for human input; blocking response mode is not supported."
code = 400

View File

@ -16,6 +16,8 @@ from core.app.entities.queue_entities import (
MessageQueueMessage,
QueueAgentLogEvent,
QueueErrorEvent,
QueueHumanInputFormFilledEvent,
QueueHumanInputFormTimeoutEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
@ -32,6 +34,7 @@ from core.app.entities.queue_entities import (
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowPausedEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
WorkflowQueueMessage,
@ -46,11 +49,13 @@ from core.app.entities.task_entities import (
WorkflowAppBlockingResponse,
WorkflowAppStreamResponse,
WorkflowFinishStreamResponse,
WorkflowPauseStreamResponse,
WorkflowStartStreamResponse,
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
from core.workflow.enums import WorkflowExecutionStatus
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
from core.workflow.runtime import GraphRuntimeState
@ -132,6 +137,25 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
for stream_response in generator:
if isinstance(stream_response, ErrorStreamResponse):
raise stream_response.err
elif isinstance(stream_response, WorkflowPauseStreamResponse):
response = WorkflowAppBlockingResponse(
task_id=self._application_generate_entity.task_id,
workflow_run_id=stream_response.data.workflow_run_id,
data=WorkflowAppBlockingResponse.Data(
id=stream_response.data.workflow_run_id,
workflow_id=self._workflow.id,
status=stream_response.data.status,
outputs=stream_response.data.outputs or {},
error=None,
elapsed_time=stream_response.data.elapsed_time,
total_tokens=stream_response.data.total_tokens,
total_steps=stream_response.data.total_steps,
created_at=stream_response.data.created_at,
finished_at=None,
),
)
return response
elif isinstance(stream_response, WorkflowFinishStreamResponse):
response = WorkflowAppBlockingResponse(
task_id=self._application_generate_entity.task_id,
@ -146,7 +170,7 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
total_tokens=stream_response.data.total_tokens,
total_steps=stream_response.data.total_steps,
created_at=int(stream_response.data.created_at),
finished_at=int(stream_response.data.finished_at),
finished_at=int(stream_response.data.finished_at) if stream_response.data.finished_at else None,
),
)
@ -259,13 +283,15 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
run_id = self._extract_workflow_run_id(runtime_state)
self._workflow_execution_id = run_id
with self._database_session() as session:
self._save_workflow_app_log(session=session, workflow_run_id=self._workflow_execution_id)
if event.reason == WorkflowStartReason.INITIAL:
with self._database_session() as session:
self._save_workflow_app_log(session=session, workflow_run_id=self._workflow_execution_id)
start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run_id=run_id,
workflow_id=self._workflow.id,
reason=event.reason,
)
yield start_resp
@ -440,6 +466,21 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
)
yield workflow_finish_resp
def _handle_workflow_paused_event(
self,
event: QueueWorkflowPausedEvent,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow paused events."""
self._ensure_workflow_initialized()
validated_state = self._ensure_graph_runtime_initialized()
responses = self._workflow_response_converter.workflow_pause_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
graph_runtime_state=validated_state,
)
yield from responses
def _handle_workflow_failed_and_stop_events(
self,
event: Union[QueueWorkflowFailedEvent, QueueStopEvent],
@ -495,6 +536,22 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
task_id=self._application_generate_entity.task_id, event=event
)
def _handle_human_input_form_filled_event(
self, event: QueueHumanInputFormFilledEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle human input form filled events."""
yield self._workflow_response_converter.human_input_form_filled_to_stream_response(
event=event, task_id=self._application_generate_entity.task_id
)
def _handle_human_input_form_timeout_event(
self, event: QueueHumanInputFormTimeoutEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle human input form timeout events."""
yield self._workflow_response_converter.human_input_form_timeout_to_stream_response(
event=event, task_id=self._application_generate_entity.task_id
)
def _get_event_handlers(self) -> dict[type, Callable]:
"""Get mapping of event types to their handlers using fluent pattern."""
return {
@ -506,6 +563,7 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
QueueWorkflowStartedEvent: self._handle_workflow_started_event,
QueueWorkflowSucceededEvent: self._handle_workflow_succeeded_event,
QueueWorkflowPartialSuccessEvent: self._handle_workflow_partial_success_event,
QueueWorkflowPausedEvent: self._handle_workflow_paused_event,
# Node events
QueueNodeRetryEvent: self._handle_node_retry_event,
QueueNodeStartedEvent: self._handle_node_started_event,
@ -520,6 +578,8 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
QueueLoopCompletedEvent: self._handle_loop_completed_event,
# Agent events
QueueAgentLogEvent: self._handle_agent_log_event,
QueueHumanInputFormFilledEvent: self._handle_human_input_form_filled_event,
QueueHumanInputFormTimeoutEvent: self._handle_human_input_form_timeout_event,
}
def _dispatch_event(
@ -602,6 +662,9 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
case QueueWorkflowFailedEvent():
yield from self._handle_workflow_failed_and_stop_events(event)
break
case QueueWorkflowPausedEvent():
yield from self._handle_workflow_paused_event(event)
break
case QueueStopEvent():
yield from self._handle_workflow_failed_and_stop_events(event)

View File

@ -1,3 +1,4 @@
import logging
import time
from collections.abc import Mapping, Sequence
from typing import Any, cast
@ -7,6 +8,8 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueAgentLogEvent,
QueueHumanInputFormFilledEvent,
QueueHumanInputFormTimeoutEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
@ -22,22 +25,27 @@ from core.app.entities.queue_entities import (
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowPausedEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.app.workflow.node_factory import DifyNodeFactory
from core.workflow.entities import GraphInitParams
from core.workflow.entities.pause_reason import HumanInputRequired
from core.workflow.graph import Graph
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events import (
GraphEngineEvent,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunPausedEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
NodeRunAgentLogEvent,
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunHumanInputFormFilledEvent,
NodeRunHumanInputFormTimeoutEvent,
NodeRunIterationFailedEvent,
NodeRunIterationNextEvent,
NodeRunIterationStartedEvent,
@ -61,6 +69,9 @@ from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader,
from core.workflow.workflow_entry import WorkflowEntry
from models.enums import UserFrom
from models.workflow import Workflow
from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task
logger = logging.getLogger(__name__)
class WorkflowBasedAppRunner:
@ -327,7 +338,7 @@ class WorkflowBasedAppRunner:
:param event: event
"""
if isinstance(event, GraphRunStartedEvent):
self._publish_event(QueueWorkflowStartedEvent())
self._publish_event(QueueWorkflowStartedEvent(reason=event.reason))
elif isinstance(event, GraphRunSucceededEvent):
self._publish_event(QueueWorkflowSucceededEvent(outputs=event.outputs))
elif isinstance(event, GraphRunPartialSucceededEvent):
@ -338,6 +349,38 @@ class WorkflowBasedAppRunner:
self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count))
elif isinstance(event, GraphRunAbortedEvent):
self._publish_event(QueueWorkflowFailedEvent(error=event.reason or "Unknown error", exceptions_count=0))
elif isinstance(event, GraphRunPausedEvent):
runtime_state = workflow_entry.graph_engine.graph_runtime_state
paused_nodes = runtime_state.get_paused_nodes()
self._enqueue_human_input_notifications(event.reasons)
self._publish_event(
QueueWorkflowPausedEvent(
reasons=event.reasons,
outputs=event.outputs,
paused_nodes=paused_nodes,
)
)
elif isinstance(event, NodeRunHumanInputFormFilledEvent):
self._publish_event(
QueueHumanInputFormFilledEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_title,
rendered_content=event.rendered_content,
action_id=event.action_id,
action_text=event.action_text,
)
)
elif isinstance(event, NodeRunHumanInputFormTimeoutEvent):
self._publish_event(
QueueHumanInputFormTimeoutEvent(
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_title,
expiration_time=event.expiration_time,
)
)
elif isinstance(event, NodeRunRetryEvent):
node_run_result = event.node_run_result
inputs = node_run_result.inputs
@ -544,5 +587,19 @@ class WorkflowBasedAppRunner:
)
)
def _enqueue_human_input_notifications(self, reasons: Sequence[object]) -> None:
for reason in reasons:
if not isinstance(reason, HumanInputRequired):
continue
if not reason.form_id:
continue
try:
dispatch_human_input_email_task.apply_async(
kwargs={"form_id": reason.form_id, "node_title": reason.node_title},
queue="mail",
)
except Exception: # pragma: no cover - defensive logging
logger.exception("Failed to enqueue human input email task for form %s", reason.form_id)
def _publish_event(self, event: AppQueueEvent):
self._queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER)

View File

@ -132,7 +132,7 @@ class AppGenerateEntity(BaseModel):
extras: dict[str, Any] = Field(default_factory=dict)
# tracing instance
trace_manager: Optional["TraceQueueManager"] = None
trace_manager: Optional["TraceQueueManager"] = Field(default=None, exclude=True, repr=False)
class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
@ -156,6 +156,7 @@ class ConversationAppGenerateEntity(AppGenerateEntity):
"""
conversation_id: str | None = None
is_new_conversation: bool = False
parent_message_id: str | None = Field(
default=None,
description=(

View File

@ -8,6 +8,8 @@ from pydantic import BaseModel, ConfigDict, Field
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities import AgentNodeStrategyInit
from core.workflow.entities.pause_reason import PauseReason
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
from core.workflow.enums import WorkflowNodeExecutionMetadataKey
from core.workflow.nodes import NodeType
@ -46,6 +48,9 @@ class QueueEvent(StrEnum):
PING = "ping"
STOP = "stop"
RETRY = "retry"
PAUSE = "pause"
HUMAN_INPUT_FORM_FILLED = "human_input_form_filled"
HUMAN_INPUT_FORM_TIMEOUT = "human_input_form_timeout"
class AppQueueEvent(BaseModel):
@ -261,6 +266,8 @@ class QueueWorkflowStartedEvent(AppQueueEvent):
"""QueueWorkflowStartedEvent entity."""
event: QueueEvent = QueueEvent.WORKFLOW_STARTED
# Always present; mirrors GraphRunStartedEvent.reason for downstream consumers.
reason: WorkflowStartReason = WorkflowStartReason.INITIAL
class QueueWorkflowSucceededEvent(AppQueueEvent):
@ -484,6 +491,35 @@ class QueueStopEvent(AppQueueEvent):
return reason_mapping.get(self.stopped_by, "Stopped by unknown reason.")
class QueueHumanInputFormFilledEvent(AppQueueEvent):
"""
QueueHumanInputFormFilledEvent entity
"""
event: QueueEvent = QueueEvent.HUMAN_INPUT_FORM_FILLED
node_execution_id: str
node_id: str
node_type: NodeType
node_title: str
rendered_content: str
action_id: str
action_text: str
class QueueHumanInputFormTimeoutEvent(AppQueueEvent):
"""
QueueHumanInputFormTimeoutEvent entity
"""
event: QueueEvent = QueueEvent.HUMAN_INPUT_FORM_TIMEOUT
node_id: str
node_type: NodeType
node_title: str
expiration_time: datetime
class QueueMessage(BaseModel):
"""
QueueMessage abstract entity
@ -509,3 +545,14 @@ class WorkflowQueueMessage(QueueMessage):
"""
pass
class QueueWorkflowPausedEvent(AppQueueEvent):
"""
QueueWorkflowPausedEvent entity
"""
event: QueueEvent = QueueEvent.PAUSE
reasons: Sequence[PauseReason] = Field(default_factory=list)
outputs: Mapping[str, object] = Field(default_factory=dict)
paused_nodes: Sequence[str] = Field(default_factory=list)

View File

@ -7,7 +7,9 @@ from pydantic import BaseModel, ConfigDict, Field
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities import AgentNodeStrategyInit
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
from core.workflow.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.nodes.human_input.entities import FormInput, UserAction
class AnnotationReplyAccount(BaseModel):
@ -69,6 +71,7 @@ class StreamEvent(StrEnum):
AGENT_THOUGHT = "agent_thought"
AGENT_MESSAGE = "agent_message"
WORKFLOW_STARTED = "workflow_started"
WORKFLOW_PAUSED = "workflow_paused"
WORKFLOW_FINISHED = "workflow_finished"
NODE_STARTED = "node_started"
NODE_FINISHED = "node_finished"
@ -82,6 +85,9 @@ class StreamEvent(StrEnum):
TEXT_CHUNK = "text_chunk"
TEXT_REPLACE = "text_replace"
AGENT_LOG = "agent_log"
HUMAN_INPUT_REQUIRED = "human_input_required"
HUMAN_INPUT_FORM_FILLED = "human_input_form_filled"
HUMAN_INPUT_FORM_TIMEOUT = "human_input_form_timeout"
class StreamResponse(BaseModel):
@ -205,6 +211,8 @@ class WorkflowStartStreamResponse(StreamResponse):
workflow_id: str
inputs: Mapping[str, Any]
created_at: int
# Always present; mirrors QueueWorkflowStartedEvent.reason for SSE clients.
reason: WorkflowStartReason = WorkflowStartReason.INITIAL
event: StreamEvent = StreamEvent.WORKFLOW_STARTED
workflow_run_id: str
@ -231,7 +239,7 @@ class WorkflowFinishStreamResponse(StreamResponse):
total_steps: int
created_by: Mapping[str, object] = Field(default_factory=dict)
created_at: int
finished_at: int
finished_at: int | None
exceptions_count: int | None = 0
files: Sequence[Mapping[str, Any]] | None = []
@ -240,6 +248,85 @@ class WorkflowFinishStreamResponse(StreamResponse):
data: Data
class WorkflowPauseStreamResponse(StreamResponse):
"""
WorkflowPauseStreamResponse entity
"""
class Data(BaseModel):
"""
Data entity
"""
workflow_run_id: str
paused_nodes: Sequence[str] = Field(default_factory=list)
outputs: Mapping[str, Any] = Field(default_factory=dict)
reasons: Sequence[Mapping[str, Any]] = Field(default_factory=list)
status: WorkflowExecutionStatus
created_at: int
elapsed_time: float
total_tokens: int
total_steps: int
event: StreamEvent = StreamEvent.WORKFLOW_PAUSED
workflow_run_id: str
data: Data
class HumanInputRequiredResponse(StreamResponse):
class Data(BaseModel):
"""
Data entity
"""
form_id: str
node_id: str
node_title: str
form_content: str
inputs: Sequence[FormInput] = Field(default_factory=list)
actions: Sequence[UserAction] = Field(default_factory=list)
display_in_ui: bool = False
form_token: str | None = None
resolved_default_values: Mapping[str, Any] = Field(default_factory=dict)
expiration_time: int = Field(..., description="Unix timestamp in seconds")
event: StreamEvent = StreamEvent.HUMAN_INPUT_REQUIRED
workflow_run_id: str
data: Data
class HumanInputFormFilledResponse(StreamResponse):
class Data(BaseModel):
"""
Data entity
"""
node_id: str
node_title: str
rendered_content: str
action_id: str
action_text: str
event: StreamEvent = StreamEvent.HUMAN_INPUT_FORM_FILLED
workflow_run_id: str
data: Data
class HumanInputFormTimeoutResponse(StreamResponse):
class Data(BaseModel):
"""
Data entity
"""
node_id: str
node_title: str
expiration_time: int
event: StreamEvent = StreamEvent.HUMAN_INPUT_FORM_TIMEOUT
workflow_run_id: str
data: Data
class NodeStartStreamResponse(StreamResponse):
"""
NodeStartStreamResponse entity
@ -726,7 +813,7 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
total_tokens: int
total_steps: int
created_at: int
finished_at: int
finished_at: int | None
workflow_run_id: str
data: Data

View File

@ -1,3 +1,4 @@
import contextlib
import logging
import time
import uuid
@ -103,6 +104,14 @@ class RateLimit:
)
@contextlib.contextmanager
def rate_limit_context(rate_limit: RateLimit, request_id: str | None):
request_id = rate_limit.enter(request_id)
yield
if request_id is not None:
rate_limit.exit(request_id)
class RateLimitGenerator:
def __init__(self, rate_limit: RateLimit, generator: Generator[str, None, None], request_id: str):
self.rate_limit = rate_limit

View File

@ -1,3 +1,4 @@
from dataclasses import dataclass
from typing import Annotated, Literal, Self, TypeAlias
from pydantic import BaseModel, Field
@ -52,6 +53,14 @@ class WorkflowResumptionContext(BaseModel):
return self.generate_entity.entity
@dataclass(frozen=True)
class PauseStateLayerConfig:
"""Configuration container for instantiating pause persistence layers."""
session_factory: Engine | sessionmaker[Session]
state_owner_user_id: str
class PauseStatePersistenceLayer(GraphEngineLayer):
def __init__(
self,

View File

@ -82,10 +82,11 @@ class MessageCycleManager:
if isinstance(self._application_generate_entity, CompletionAppGenerateEntity):
return None
is_first_message = self._application_generate_entity.conversation_id is None
is_first_message = self._application_generate_entity.is_new_conversation
extras = self._application_generate_entity.extras
auto_generate_conversation_name = extras.get("auto_generate_conversation_name", True)
thread: Thread | None = None
if auto_generate_conversation_name and is_first_message:
# start generate thread
# time.sleep not block other logic
@ -101,9 +102,10 @@ class MessageCycleManager:
thread.daemon = True
thread.start()
return thread
if is_first_message:
self._application_generate_entity.is_new_conversation = False
return None
return thread
def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str):
with flask_app.app_context():