WIP: api debugging

This commit is contained in:
QuantumGhost
2025-11-26 00:33:44 +08:00
parent f368155995
commit dddcf1de6c
31 changed files with 847 additions and 55 deletions

View File

@ -1,3 +1,4 @@
import json
from typing import Literal, cast
from flask import request
@ -409,7 +410,19 @@ class ConsoleWorkflowPauseDetailsApi(Resource):
# Check if workflow is suspended
is_paused = workflow_run.status == WorkflowExecutionStatus.PAUSED
if not is_paused:
return {"is_suspended": False, "paused_at": None, "paused_nodes": [], "pending_human_inputs": []}, 200
return {
"is_suspended": False,
"paused_at": None,
"paused_nodes": [],
"pending_human_inputs": [],
"pause_reasons": [],
}, 200
pause_entity = workflow_run_repo.get_workflow_pause(workflow_run_id)
pause_reasons: list[dict[str, object]] = []
if pause_entity:
for reason in pause_entity.get_pause_reasons():
pause_reasons.append(reason.model_dump(mode="json"))
# Get pending Human Input forms for this workflow run
service = HumanInputFormService(db.session())
@ -421,6 +434,7 @@ class ConsoleWorkflowPauseDetailsApi(Resource):
"paused_at": workflow_run.created_at.isoformat() + "Z" if workflow_run.created_at else None,
"paused_nodes": [],
"pending_human_inputs": [],
"pause_reasons": pause_reasons,
}
# Add pending human input forms

View File

@ -2,6 +2,7 @@
Console/Studio Human Input Form APIs.
"""
import json
import logging
from collections.abc import Generator
@ -157,32 +158,18 @@ class ConsoleWorkflowEventsApi(Resource):
if workflow_run.finished_at is not None:
response = WorkflowResponseConverter.workflow_run_result_to_finish_response(
task_id=workflow_run.id,
workflow_run=workflow_run,
creator_user=user,
)
# TODO: should we just return here? or yield a WorkflowFinishStreamResponse?
payload = response.model_dump(mode="json")
payload["event"] = response.event.value
def generate_events() -> Generator[str, None, None]:
"""Generate SSE events for workflow execution."""
try:
# TODO: Implement actual event streaming
# This would connect to the workflow execution engine
# and stream real-time events
yield f"data: {json.dumps(payload)}\n\n"
# For demo purposes, send a basic event
yield f"data: {{'event': 'workflow_resumed', 'task_id': '{task_id}'}}\n\n"
# In real implementation, this would:
# 1. Connect to workflow execution engine
# 2. Stream real-time execution events
# 3. Handle client disconnection
# 4. Clean up resources on completion
except Exception as e:
logger.exception("Error streaming events for task %s", task_id)
yield f"data: {{'error': 'Stream error: {str(e)}'}}\n\n"
else:
# TODO: SSE from Redis PubSub
msg_generator = MessageGenerator()
if app.mode == AppMode.ADVANCED_CHAT:
generator = AdvancedChatAppGenerator()

View File

@ -42,6 +42,7 @@ from core.app.entities.queue_entities import (
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowPausedEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
WorkflowQueueMessage,
@ -525,6 +526,19 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
)
yield workflow_finish_resp
def _handle_workflow_paused_event(
    self,
    event: QueueWorkflowPausedEvent,
    **kwargs,
) -> Generator[StreamResponse, None, None]:
    """Stream pause notifications, then signal the end of the chat message."""
    pause_responses = self._workflow_response_converter.workflow_pause_to_stream_response(
        event=event,
        task_id=self._application_generate_entity.task_id,
    )
    yield from pause_responses
    # The chat message stream must be explicitly terminated after a pause.
    self._base_task_pipeline.queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
def _handle_workflow_failed_event(
@ -659,6 +673,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
QueueWorkflowStartedEvent: self._handle_workflow_started_event,
QueueWorkflowSucceededEvent: self._handle_workflow_succeeded_event,
QueueWorkflowPartialSuccessEvent: self._handle_workflow_partial_success_event,
QueueWorkflowPausedEvent: self._handle_workflow_paused_event,
QueueWorkflowFailedEvent: self._handle_workflow_failed_event,
# Node events
QueueNodeRetryEvent: self._handle_node_retry_event,
@ -747,6 +762,9 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
case QueueWorkflowFailedEvent():
yield from self._handle_workflow_failed_event(event, trace_manager=trace_manager)
break
case QueueWorkflowPausedEvent():
yield from self._handle_workflow_paused_event(event)
break
case QueueStopEvent():
yield from self._handle_stop_event(event, graph_runtime_state=None, trace_manager=trace_manager)

View File

@ -19,9 +19,11 @@ from core.app.entities.queue_entities import (
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueWorkflowPausedEvent,
)
from core.app.entities.task_entities import (
AgentLogStreamResponse,
HumanInputRequiredResponse,
IterationNodeCompletedStreamResponse,
IterationNodeNextStreamResponse,
IterationNodeStartStreamResponse,
@ -31,7 +33,9 @@ from core.app.entities.task_entities import (
NodeFinishStreamResponse,
NodeRetryStreamResponse,
NodeStartStreamResponse,
StreamResponse,
WorkflowFinishStreamResponse,
WorkflowPauseStreamResponse,
WorkflowStartStreamResponse,
)
from core.file import FILE_MODEL_IDENTITY, File
@ -40,6 +44,7 @@ from core.tools.entities.tool_entities import ToolProviderType
from core.tools.tool_manager import ToolManager
from core.trigger.trigger_manager import TriggerManager
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.workflow.entities.pause_reason import HumanInputRequired
from core.workflow.enums import (
NodeType,
SystemVariableKey,
@ -265,10 +270,56 @@ class WorkflowResponseConverter:
),
)
def workflow_pause_to_stream_response(
    self,
    *,
    event: QueueWorkflowPausedEvent,
    task_id: str,
) -> list[StreamResponse]:
    """Translate a pause event into stream responses.

    Emits one HumanInputRequiredResponse per HumanInputRequired reason,
    followed by a single WorkflowPauseStreamResponse summarizing the pause.
    """
    workflow_run_id = self._ensure_workflow_run_id()
    outputs = self._encode_outputs(event.outputs) or {}
    serialized_reasons = [r.model_dump(mode="json") for r in event.reasons]

    stream_responses: list[StreamResponse] = [
        HumanInputRequiredResponse(
            task_id=task_id,
            workflow_run_id=workflow_run_id,
            data=HumanInputRequiredResponse.Data(
                form_id=reason.form_id,
                node_id=reason.node_id,
                node_title=reason.node_title,
                form_content=reason.form_content,
                inputs=reason.inputs,
                actions=reason.actions,
                web_app_form_token=reason.web_app_form_token,
            ),
        )
        for reason in event.reasons
        if isinstance(reason, HumanInputRequired)
    ]
    stream_responses.append(
        WorkflowPauseStreamResponse(
            task_id=task_id,
            workflow_run_id=workflow_run_id,
            data=WorkflowPauseStreamResponse.Data(
                workflow_run_id=workflow_run_id,
                paused_nodes=list(event.paused_nodes),
                outputs=outputs,
                reasons=serialized_reasons,
            ),
        )
    )
    return stream_responses
@classmethod
def workflow_run_result_to_finish_response(
cls,
*,
task_id: str,
workflow_run: WorkflowRun,
creator_user: Account | EndUser,
) -> WorkflowFinishStreamResponse:
@ -294,7 +345,7 @@ class WorkflowResponseConverter:
}
return WorkflowFinishStreamResponse(
task_id=task_id, # TODO
task_id=task_id,
workflow_run_id=run_id,
data=WorkflowFinishStreamResponse.Data(
id=run_id,

View File

@ -163,6 +163,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
query = application_generate_entity.query or "New conversation"
conversation_name = (query[:20] + "") if len(query) > 20 else query
created_new_conversation = conversation is None
try:
if not conversation:
conversation = Conversation(
@ -239,6 +240,9 @@ class MessageBasedAppGenerator(BaseAppGenerator):
db.session.add_all(message_files)
db.session.commit()
application_generate_entity.conversation_id = conversation.id
application_generate_entity.is_new_conversation = created_new_conversation
return conversation, message
except Exception:
db.session.rollback()

View File

@ -32,6 +32,7 @@ from core.app.entities.queue_entities import (
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowPausedEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
WorkflowQueueMessage,
@ -440,6 +441,19 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
)
yield workflow_finish_resp
def _handle_workflow_paused_event(
    self,
    event: QueueWorkflowPausedEvent,
    **kwargs,
) -> Generator[StreamResponse, None, None]:
    """Convert a queue-level workflow pause event into stream responses.

    Ensures the workflow run has been initialized, then yields every
    response produced by the converter (human-input prompts followed by
    the workflow-pause summary).
    """
    self._ensure_workflow_initialized()
    responses = self._workflow_response_converter.workflow_pause_to_stream_response(
        event=event,
        task_id=self._application_generate_entity.task_id,
    )
    # BUG FIX: previously `yield from response` — an undefined name that
    # raised NameError the first time a pause event was consumed.
    yield from responses
def _handle_workflow_failed_and_stop_events(
self,
event: Union[QueueWorkflowFailedEvent, QueueStopEvent],
@ -506,6 +520,7 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
QueueWorkflowStartedEvent: self._handle_workflow_started_event,
QueueWorkflowSucceededEvent: self._handle_workflow_succeeded_event,
QueueWorkflowPartialSuccessEvent: self._handle_workflow_partial_success_event,
QueueWorkflowPausedEvent: self._handle_workflow_paused_event,
# Node events
QueueNodeRetryEvent: self._handle_node_retry_event,
QueueNodeStartedEvent: self._handle_node_started_event,
@ -602,6 +617,9 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
case QueueWorkflowFailedEvent():
yield from self._handle_workflow_failed_and_stop_events(event)
break
case QueueWorkflowPausedEvent():
yield from self._handle_workflow_paused_event(event)
break
case QueueStopEvent():
yield from self._handle_workflow_failed_and_stop_events(event)

View File

@ -22,6 +22,7 @@ from core.app.entities.queue_entities import (
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowPausedEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
@ -32,6 +33,7 @@ from core.workflow.graph_events import (
GraphEngineEvent,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunPausedEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
NodeRunAgentLogEvent,
@ -362,6 +364,16 @@ class WorkflowBasedAppRunner:
self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count))
elif isinstance(event, GraphRunAbortedEvent):
self._publish_event(QueueWorkflowFailedEvent(error=event.reason or "Unknown error", exceptions_count=0))
elif isinstance(event, GraphRunPausedEvent):
runtime_state = workflow_entry.graph_engine.graph_runtime_state
paused_nodes = runtime_state.get_paused_nodes()
self._publish_event(
QueueWorkflowPausedEvent(
reasons=event.reasons,
outputs=event.outputs,
paused_nodes=paused_nodes,
)
)
elif isinstance(event, NodeRunRetryEvent):
node_run_result = event.node_run_result
inputs = node_run_result.inputs

View File

@ -131,7 +131,7 @@ class AppGenerateEntity(BaseModel):
extras: dict[str, Any] = Field(default_factory=dict)
# tracing instance
trace_manager: Optional["TraceQueueManager"] = None
trace_manager: Optional["TraceQueueManager"] = Field(default=None, exclude=True, repr=False)
class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
@ -155,6 +155,7 @@ class ConversationAppGenerateEntity(AppGenerateEntity):
"""
conversation_id: str | None = None
is_new_conversation: bool = False
parent_message_id: str | None = Field(
default=None,
description=(

View File

@ -8,6 +8,7 @@ from pydantic import BaseModel, ConfigDict, Field
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities import AgentNodeStrategyInit
from core.workflow.entities.pause_reason import PauseReason
from core.workflow.enums import WorkflowNodeExecutionMetadataKey
from core.workflow.nodes import NodeType
@ -46,6 +47,7 @@ class QueueEvent(StrEnum):
PING = "ping"
STOP = "stop"
RETRY = "retry"
PAUSE = "pause"
class AppQueueEvent(BaseModel):
@ -509,3 +511,14 @@ class WorkflowQueueMessage(QueueMessage):
"""
pass
class QueueWorkflowPausedEvent(AppQueueEvent):
    """
    QueueWorkflowPausedEvent entity

    Published when a workflow run pauses (e.g. a Human Input node is
    waiting for a user), carrying everything needed to notify clients.
    """

    # Discriminator used by queue consumers to dispatch pause handling.
    event: QueueEvent = QueueEvent.PAUSE
    # Why the run paused; serialized into stream responses downstream.
    reasons: Sequence[PauseReason] = Field(default_factory=list)
    # Outputs available at the moment the run paused (from the graph event).
    outputs: Mapping[str, object] = Field(default_factory=dict)
    # IDs of the nodes that are currently paused.
    paused_nodes: Sequence[str] = Field(default_factory=list)

View File

@ -8,6 +8,7 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities import AgentNodeStrategyInit
from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.nodes.human_input.entities import FormInput, UserAction
class AnnotationReplyAccount(BaseModel):
@ -83,6 +84,7 @@ class StreamEvent(StrEnum):
TEXT_CHUNK = "text_chunk"
TEXT_REPLACE = "text_replace"
AGENT_LOG = "agent_log"
HUMAN_INPUT_REQUIRED = "human_input_required"
class StreamResponse(BaseModel):
@ -241,6 +243,45 @@ class WorkflowFinishStreamResponse(StreamResponse):
data: Data
class WorkflowPauseStreamResponse(StreamResponse):
    """
    WorkflowPauseStreamResponse entity

    Stream event summarizing a workflow pause: which nodes paused, the
    outputs captured so far, and the serialized pause reasons.
    """

    class Data(BaseModel):
        """
        Data entity
        """

        # Identifier of the paused run (also duplicated on the outer response).
        workflow_run_id: str
        # IDs of nodes paused at the time of the event.
        paused_nodes: Sequence[str] = Field(default_factory=list)
        # Encoded run outputs captured when the run paused.
        outputs: Mapping[str, Any] = Field(default_factory=dict)
        # JSON-serialized pause reasons (PauseReason.model_dump output).
        reasons: Sequence[Mapping[str, Any]] = Field(default_factory=list)

    # NOTE(review): StreamEvent.WORKFLOW_PAUSED is referenced here but its
    # declaration is not visible in this change — confirm it exists on StreamEvent.
    event: StreamEvent = StreamEvent.WORKFLOW_PAUSED
    workflow_run_id: str
    data: Data
class HumanInputRequiredResponse(StreamResponse):
    """Stream event asking the client to collect human input for a paused node."""

    class Data(BaseModel):
        """
        Data entity
        """

        # Identifier of the persisted Human Input form to submit against.
        form_id: str
        # Node that requested the input, plus its display title.
        node_id: str
        node_title: str
        # Rendered form content (markdown per the form entity contract).
        form_content: str
        # Declared form fields and the actions the user may take.
        inputs: Sequence[FormInput] = Field(default_factory=list)
        actions: Sequence[UserAction] = Field(default_factory=list)
        # Token for web-app form access; None when not applicable.
        web_app_form_token: str | None = None

    event: StreamEvent = StreamEvent.HUMAN_INPUT_REQUIRED
    workflow_run_id: str
    data: Data
class NodeStartStreamResponse(StreamResponse):
"""
NodeStartStreamResponse entity

View File

@ -79,10 +79,11 @@ class MessageCycleManager:
if isinstance(self._application_generate_entity, CompletionAppGenerateEntity):
return None
is_first_message = self._application_generate_entity.conversation_id is None
is_first_message = self._application_generate_entity.is_new_conversation
extras = self._application_generate_entity.extras
auto_generate_conversation_name = extras.get("auto_generate_conversation_name", True)
thread: Thread | None = None
if auto_generate_conversation_name and is_first_message:
# start generate thread
# time.sleep not block other logic
@ -98,9 +99,10 @@ class MessageCycleManager:
thread.daemon = True
thread.start()
return thread
if is_first_message:
self._application_generate_entity.is_new_conversation = False
return None
return thread
def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str):
with flask_app.app_context():

View File

@ -90,6 +90,10 @@ class _HumanInputFormEntityImpl(HumanInputFormEntity):
def recipients(self) -> list[HumanInputFormRecipientEntity]:
return list(self._recipients)
@property
def rendered_content(self) -> str:
    # Delegates to the backing HumanInputForm model's rendered content
    # (the rendered markdown per the HumanInputFormEntity contract).
    return self._form_model.rendered_content
class _FormSubmissionImpl(FormSubmission):
def __init__(self, form_model: HumanInputForm):

View File

@ -3,7 +3,7 @@ from typing import Annotated, Literal, TypeAlias
from pydantic import BaseModel, Field
from core.workflow.nodes.human_input.entities import FormInput
from core.workflow.nodes.human_input.entities import FormInput, UserAction
class PauseReasonType(StrEnum):
@ -16,6 +16,9 @@ class HumanInputRequired(BaseModel):
form_id: str
form_content: str
inputs: list[FormInput] = Field(default_factory=list)
actions: list[UserAction] = Field(default_factory=list)
node_id: str
node_title: str
web_app_form_token: str | None = None

View File

@ -306,13 +306,13 @@ class GraphEngine:
for layer in self._layers:
try:
layer.initialize(read_only_state, self._command_channel)
except Exception as e:
logger.warning("Failed to initialize layer %s: %s", layer.__class__.__name__, e)
except Exception:
logger.exception("Failed to initialize layer %s", layer.__class__.__name__)
try:
layer.on_graph_start()
except Exception as e:
logger.warning("Layer %s failed on_graph_start: %s", layer.__class__.__name__, e)
except Exception:
logger.exception("Layer %s failed on_graph_start", layer.__class__.__name__)
def _start_execution(self, *, resume: bool = False) -> None:
"""Start execution subsystems."""
@ -353,8 +353,8 @@ class GraphEngine:
for layer in self._layers:
try:
layer.on_graph_end(self._graph_execution.error)
except Exception as e:
logger.warning("Layer %s failed on_graph_end: %s", layer.__class__.__name__, e)
except Exception:
logger.exception("Layer %s failed on_graph_end", layer.__class__.__name__)
# Public property accessors for attributes that need external access
@property

View File

@ -8,9 +8,9 @@ import uuid
from collections.abc import Mapping, Sequence
from datetime import datetime, timedelta
from enum import StrEnum
from typing import Annotated, Literal, Optional
from typing import Annotated, Literal, Optional, Self
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, Field, field_validator, model_validator
from core.variables.consts import SELECTORS_LENGTH
from core.workflow.nodes.base import BaseNodeData
@ -152,12 +152,13 @@ class FormInputPlaceholder(BaseModel):
# TODO: How should we express JSON values?
value: str = ""
@field_validator("selector")
@classmethod
def _validate_selector(cls, selector: Sequence[str]) -> Sequence[str]:
if len(selector) < SELECTORS_LENGTH:
raise ValueError(f"the length of selector should be at least {SELECTORS_LENGTH}, selector={selector}")
return selector
@model_validator(mode="after")
def _validate_selector(self) -> Self:
if self.type == PlaceholderType.CONSTANT:
return self
if len(self.selector) < SELECTORS_LENGTH:
raise ValueError(f"the length of selector should be at least {SELECTORS_LENGTH}, selector={self.selector}")
return self
class FormInput(BaseModel):

View File

@ -3,6 +3,7 @@ from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any
from core.repositories.human_input_reposotiry import HumanInputFormRepositoryImpl
from core.workflow.entities.pause_reason import HumanInputRequired
from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult, PauseRequestedEvent
from core.workflow.node_events.base import NodeEventBase
@ -129,6 +130,18 @@ class HumanInputNode(Node[HumanInputNodeData]):
pause_requested_event = PauseRequestedEvent(reason=required_event)
return pause_requested_event
def _human_input_required_event(self, form_entity: HumanInputFormEntity) -> HumanInputRequired:
    """Build the HumanInputRequired pause reason describing this node's form."""
    data = self._node_data
    reason_fields = {
        "form_id": form_entity.id,
        "form_content": form_entity.rendered_content,
        "inputs": data.inputs,
        "actions": data.user_actions,
        "node_id": self.id,
        "node_title": data.title,
        "web_app_form_token": form_entity.web_app_token,
    }
    return HumanInputRequired(**reason_fields)
def _create_form(self) -> Generator[NodeEventBase, None, None] | NodeRunResult:
try:
params = FormCreateParams(
@ -146,7 +159,6 @@ class HumanInputNode(Node[HumanInputNodeData]):
self.id,
form_entity.id,
)
yield self._human
yield self._form_to_pause_event(form_entity)
except Exception as e:
logger.exception("Human Input node failed to execute, node_id=%s", self.id)

View File

@ -52,6 +52,12 @@ class HumanInputFormEntity(abc.ABC):
@abc.abstractmethod
def recipients(self) -> list["HumanInputFormRecipientEntity"]: ...
@property
@abc.abstractmethod
def rendered_content(self) -> str:
    """Rendered markdown content associated with the form.

    Concrete form entities must expose the content shown to the user.
    """
    ...
class HumanInputFormRecipientEntity(abc.ABC):
@property

View File

@ -354,6 +354,11 @@ class GraphRuntimeState:
self._paused_nodes.add(node_id)
def get_paused_nodes(self) -> list[str]:
    """Return a snapshot of the ids of currently paused nodes.

    Unlike consume_paused_nodes(), this leaves internal state untouched.
    """
    return [node_id for node_id in self._paused_nodes]
def consume_paused_nodes(self) -> list[str]:
"""Retrieve and clear the list of paused nodes awaiting resume."""

View File

@ -350,9 +350,11 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
raise ValueError(f"WorkflowRun not found: {workflow_run_id}")
# Check if workflow is in RUNNING status
if workflow_run.status != WorkflowExecutionStatus.RUNNING:
# TODO(QuantumGhost): It seems that the persistence of `WorkflowRun.status`
# happens before the execution of GraphLayer
if workflow_run.status not in {WorkflowExecutionStatus.RUNNING, WorkflowExecutionStatus.PAUSED}:
raise _WorkflowRunError(
f"Only WorkflowRun with RUNNING status can be paused, "
f"Only WorkflowRun with RUNNING or PAUSED status can be paused, "
f"workflow_run_id={workflow_run_id}, current_status={workflow_run.status}"
)
#

View File

@ -0,0 +1,90 @@
#!/usr/bin/env python3
"""Subscribe to workflow run events from the Redis broadcast channel."""
from __future__ import annotations
import argparse
import json
import sys
import uuid
from collections.abc import Mapping
from app_factory import create_flask_app_with_configs
from core.app.apps.message_generator import MessageGenerator
from extensions import ext_redis
from models.model import AppMode
def _parse_app_mode(value: str) -> AppMode:
    """Convert a CLI string into an AppMode, raising an argparse-friendly error."""
    try:
        mode = AppMode.value_of(value)
    except ValueError as exc:  # pragma: no cover - argparse rewrites the message
        raise argparse.ArgumentTypeError(str(exc)) from exc
    return mode
def _parse_workflow_run_id(value: str) -> str:
try:
workflow_uuid = uuid.UUID(value)
except ValueError as exc: # pragma: no cover - argparse rewrites the message
raise argparse.ArgumentTypeError("workflow run id must be a valid UUID") from exc
return str(workflow_uuid)
def _build_parser() -> argparse.ArgumentParser:
    """Create the CLI parser for the workflow-event tailing helper."""
    parser = argparse.ArgumentParser(
        description="Subscribe to Redis broadcast channel events for a workflow run and print them."
    )
    parser.add_argument(
        "--workflow-run-id",
        type=_parse_workflow_run_id,
        required=True,
        help="Workflow run identifier whose stream output should be tailed.",
    )
    parser.add_argument(
        "--app-mode",
        type=_parse_app_mode,
        choices=list(AppMode),
        required=True,
        help="App mode the workflow ran under (determines the broadcast channel name).",
    )
    parser.add_argument(
        "--idle-timeout",
        default=300.0,
        type=float,
        help="Stop listening after this many seconds without events (default: 300).",
    )
    return parser
def _initialize_redis() -> None:
    """Initialize the shared Redis client.

    A lightweight Flask app is enough to reuse the existing Redis
    initialization code path; the app itself is discarded afterwards.
    """
    flask_app = create_flask_app_with_configs()
    ext_redis.init_app(flask_app)
def _print_event(event: Mapping | str) -> None:
if isinstance(event, Mapping):
payload = json.dumps(event, ensure_ascii=False)
else:
payload = event
print(payload)
sys.stdout.flush()
def main() -> int:
    """Entry point: parse arguments, set up Redis, and tail workflow events."""
    args = _build_parser().parse_args()
    _initialize_redis()
    event_stream = MessageGenerator.retrieve_events(args.app_mode, args.workflow_run_id, idle_timeout=args.idle_timeout)
    try:
        for evt in event_stream:
            _print_event(evt)
    except KeyboardInterrupt:
        # Conventional exit status for SIGINT (128 + 2).
        return 130
    return 0
if __name__ == "__main__":
    # Propagate main()'s exit status to the shell.
    raise SystemExit(main())

View File

@ -162,6 +162,7 @@ class AccountService:
def get_account_jwt_token(account: Account) -> str:
exp_dt = datetime.now(UTC) + timedelta(minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES)
exp = int(exp_dt.timestamp())
breakpoint()
payload = {
"user_id": account.id,
"exp": exp,

View File

@ -1,7 +1,7 @@
import contextlib
import logging
import uuid
from collections.abc import Mapping
from collections.abc import Iterable, Mapping
from enum import StrEnum
from typing import Annotated, Any, TypeAlias, Union
@ -157,15 +157,7 @@ class _ChatflowRunner:
if not exec_params.streaming:
return response
topic = chat_generator.get_response_topic(AppMode.ADVANCED_CHAT, workflow_run_id)
for event in response:
try:
payload = json.dumps(event)
except TypeError:
logging.exception("error while encoding event")
continue
topic.publish(payload.encode())
_publish_streaming_response(response, workflow_run_id)
def _resolve_user(self) -> Account | EndUser:
user_params = self._exec_params.user
@ -194,6 +186,36 @@ def _resolve_user_for_run(session: Session, workflow_run: WorkflowRun) -> Accoun
return session.get(EndUser, workflow_run.created_by)
def _coerce_uuid(value: Any) -> uuid.UUID | None:
if isinstance(value, uuid.UUID):
return value
if value is None:
return None
try:
return uuid.UUID(str(value))
except (ValueError, TypeError):
logger.warning("Invalid workflow_run_id value: %s", value)
return None
def _publish_streaming_response(response_stream: Iterable[Any], workflow_run_id: Any) -> None:
    """JSON-encode each streamed event and publish it to the run's Redis topic.

    Events that cannot be JSON-encoded are logged and skipped; nothing is
    published when workflow_run_id cannot be coerced to a UUID.
    """
    run_uuid = _coerce_uuid(workflow_run_id)
    if run_uuid is None:
        logger.warning("Unable to publish streaming response without valid workflow_run_id: %s", workflow_run_id)
        return
    topic = AdvancedChatAppGenerator.get_response_topic(AppMode.ADVANCED_CHAT, run_uuid)
    for event in response_stream:
        try:
            encoded = json.dumps(event)
        except TypeError:
            logger.exception("error while encoding event")
            continue
        topic.publish(encoded.encode())
@shared_task(queue="chatflow_execute")
def chatflow_execute_task(payload: str) -> Mapping[str, Any] | None:
exec_params = ChatflowExecutionParams.model_validate_json(payload)
@ -300,7 +322,7 @@ def resume_chatflow_execution(payload: dict[str, Any]) -> None:
generator = AdvancedChatAppGenerator()
try:
generator.resume(
response = generator.resume(
app_model=app_model,
workflow=workflow,
user=user,
@ -315,3 +337,13 @@ def resume_chatflow_execution(payload: dict[str, Any]) -> None:
except Exception:
logger.exception("Failed to resume chatflow execution for workflow run %s", workflow_run_id)
raise
if generate_entity.stream:
publish_uuid = _coerce_uuid(generate_entity.workflow_run_id) or _coerce_uuid(workflow_run_id)
if publish_uuid is None:
logger.warning(
"Unable to publish streaming response for workflow run %s due to missing workflow_run_id",
workflow_run_id,
)
else:
_publish_streaming_response(response, publish_uuid)

View File

@ -0,0 +1,139 @@
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig
from core.app.apps import message_based_app_generator
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.app.task_pipeline import message_cycle_manager
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
from models.model import AppMode, Conversation, Message
def _make_app_config() -> WorkflowUIBasedAppConfig:
    """Build a minimal advanced-chat app config for these tests."""
    config = WorkflowUIBasedAppConfig(
        tenant_id="tenant-id",
        app_id="app-id",
        app_mode=AppMode.ADVANCED_CHAT,
        workflow_id="workflow-id",
        additional_features=AppAdditionalFeatures(),
        variables=[],
    )
    return config
def _make_generate_entity(app_config: WorkflowUIBasedAppConfig) -> AdvancedChatAppGenerateEntity:
    """Build a generate entity with no conversation id, simulating a fresh chat."""
    return AdvancedChatAppGenerateEntity(
        task_id="task-id",
        app_config=app_config,
        file_upload_config=None,
        # None conversation_id means the generator must create a new conversation.
        conversation_id=None,
        inputs={},
        query="hello",
        files=[],
        parent_message_id=None,
        user_id="user-id",
        stream=True,
        invoke_from=InvokeFrom.WEB_APP,
        extras={},
        workflow_run_id="workflow-run-id",
    )
@pytest.fixture(autouse=True)
def _mock_db_session(monkeypatch):
    """Replace the module-level db handle with a mocked session.

    refresh() assigns deterministic ids so the code under test can read
    conversation/message primary keys without a real database.
    """
    session = MagicMock()

    def refresh_side_effect(obj):
        # Simulate the database assigning primary keys on refresh.
        if isinstance(obj, Conversation) and obj.id is None:
            obj.id = "generated-conversation-id"
        if isinstance(obj, Message) and obj.id is None:
            obj.id = "generated-message-id"

    session.refresh.side_effect = refresh_side_effect
    session.add.return_value = None
    session.commit.return_value = None
    monkeypatch.setattr(message_based_app_generator, "db", SimpleNamespace(session=session))
    return session
def test_init_generate_records_sets_conversation_metadata():
    """A run without an existing conversation creates one and flags it as new."""
    app_config = _make_app_config()
    entity = _make_generate_entity(app_config)
    generator = AdvancedChatAppGenerator()

    conversation, _ = generator._init_generate_records(entity, conversation=None)

    # Ids come from the mocked session's refresh side effect.
    assert entity.conversation_id == "generated-conversation-id"
    assert conversation.id == "generated-conversation-id"
    assert entity.is_new_conversation is True
def test_init_generate_records_marks_existing_conversation():
    """Reusing an existing conversation must not set the new-conversation flag."""
    app_config = _make_app_config()
    entity = _make_generate_entity(app_config)
    existing_conversation = Conversation(
        app_id=app_config.app_id,
        app_model_config_id=None,
        model_provider=None,
        override_model_configs=None,
        model_id=None,
        mode=app_config.app_mode.value,
        name="existing",
        inputs={},
        introduction="",
        system_instruction="",
        system_instruction_tokens=0,
        status="normal",
        invoke_from=InvokeFrom.WEB_APP.value,
        from_source="api",
        from_end_user_id="user-id",
        from_account_id=None,
    )
    existing_conversation.id = "existing-conversation-id"
    generator = AdvancedChatAppGenerator()

    conversation, _ = generator._init_generate_records(entity, conversation=existing_conversation)

    assert entity.conversation_id == "existing-conversation-id"
    # The same conversation object is returned, not a copy.
    assert conversation is existing_conversation
    assert entity.is_new_conversation is False
def test_message_cycle_manager_uses_new_conversation_flag(monkeypatch):
    """Name generation keys off is_new_conversation, then clears the flag."""
    app_config = _make_app_config()
    entity = _make_generate_entity(app_config)
    entity.conversation_id = "existing-conversation-id"
    # Even with a conversation_id set, the flag alone should trigger naming.
    entity.is_new_conversation = True
    entity.extras = {"auto_generate_conversation_name": True}
    captured = {}

    class DummyThread:
        # Stand-in for threading.Thread that records start() instead of running.
        def __init__(self, **kwargs):
            self.kwargs = kwargs
            self.started = False

        def start(self):
            self.started = True

    def fake_thread(**kwargs):
        thread = DummyThread(**kwargs)
        captured["thread"] = thread
        return thread

    monkeypatch.setattr(message_cycle_manager, "Thread", fake_thread)
    manager = MessageCycleManager(application_generate_entity=entity, task_state=MagicMock())

    thread = manager.generate_conversation_name(conversation_id="existing-conversation-id", query="hello")

    assert thread is captured["thread"]
    assert thread.started is True
    # The flag is single-use: consumed once the name generation kicks off.
    assert entity.is_new_conversation is False

View File

@ -0,0 +1,148 @@
from types import SimpleNamespace
from unittest.mock import MagicMock
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.apps.workflow.app_runner import WorkflowAppRunner
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueWorkflowPausedEvent
from core.app.entities.task_entities import HumanInputRequiredResponse, WorkflowPauseStreamResponse
from core.workflow.entities.pause_reason import HumanInputRequired
from core.workflow.graph_events.graph import GraphRunPausedEvent
from core.workflow.nodes.human_input.entities import FormInput, FormInputType, UserAction
from core.workflow.system_variable import SystemVariable
from models.account import Account
class _RecordingWorkflowAppRunner(WorkflowAppRunner):
    """WorkflowAppRunner that records published events instead of queueing them."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.published_events = []

    def _publish_event(self, event):
        # Capture events so tests can assert on what would have been published.
        self.published_events.append(event)
class _FakeRuntimeState:
def get_paused_nodes(self):
return ["node-pause-1"]
def _build_runner():
    """Assemble a recording runner with just enough stubbed collaborators.

    SimpleNamespace stands in for the generate entity and workflow models;
    repositories and the variable loader are MagicMocks since the test only
    inspects published events.
    """
    app_entity = SimpleNamespace(
        app_config=SimpleNamespace(app_id="app-id"),
        inputs={},
        files=[],
        invoke_from=InvokeFrom.SERVICE_API,
        single_iteration_run=None,
        single_loop_run=None,
        workflow_execution_id="run-id",
        user_id="user-id",
    )
    workflow = SimpleNamespace(
        graph_dict={},
        tenant_id="tenant-id",
        environment_variables={},
        id="workflow-id",
    )
    # Publish is a no-op; _RecordingWorkflowAppRunner captures events itself.
    queue_manager = SimpleNamespace(publish=lambda event, pub_from: None)
    return _RecordingWorkflowAppRunner(
        application_generate_entity=app_entity,
        queue_manager=queue_manager,
        variable_loader=MagicMock(),
        workflow=workflow,
        system_user_id="sys-user",
        root_node_id=None,
        workflow_execution_repository=MagicMock(),
        workflow_node_execution_repository=MagicMock(),
        graph_engine_layers=(),
        graph_runtime_state=None,
    )
def test_graph_run_paused_event_emits_queue_pause_event():
    """A GraphRunPausedEvent is translated into one QueueWorkflowPausedEvent."""
    runner = _build_runner()
    reason = HumanInputRequired(
        form_id="form-1",
        form_content="content",
        inputs=[],
        actions=[],
        node_id="node-human",
        node_title="Human Step",
        web_app_form_token="tok",
    )
    event = GraphRunPausedEvent(reasons=[reason], outputs={"foo": "bar"})
    # Only graph_engine.graph_runtime_state is consulted by the pause branch.
    workflow_entry = SimpleNamespace(
        graph_engine=SimpleNamespace(graph_runtime_state=_FakeRuntimeState()),
    )

    runner._handle_event(workflow_entry, event)

    assert len(runner.published_events) == 1
    queue_event = runner.published_events[0]
    assert isinstance(queue_event, QueueWorkflowPausedEvent)
    assert queue_event.reasons == [reason]
    assert queue_event.outputs == {"foo": "bar"}
    # Paused node ids come from the runtime state, not the graph event.
    assert queue_event.paused_nodes == ["node-pause-1"]
def _build_converter():
    """Create a WorkflowResponseConverter backed by stub entities and a mock account."""
    generate_entity = SimpleNamespace(
        inputs={},
        files=[],
        invoke_from=InvokeFrom.SERVICE_API,
        app_config=SimpleNamespace(app_id="app-id", tenant_id="tenant-id"),
    )
    sys_vars = SystemVariable(
        user_id="user",
        app_id="app-id",
        workflow_id="workflow-id",
        workflow_execution_id="run-id",
    )
    account = MagicMock(spec=Account)
    account.id = "account-id"
    account.name = "Tester"
    account.email = "tester@example.com"
    return WorkflowResponseConverter(
        application_generate_entity=generate_entity,
        user=account,
        system_variables=sys_vars,
    )
def test_queue_workflow_paused_event_to_stream_responses():
    """Pausing should emit a HumanInputRequiredResponse followed by a pause response."""
    converter = _build_converter()
    converter.workflow_start_to_stream_response(task_id="task", workflow_run_id="run-id", workflow_id="workflow-id")
    required = HumanInputRequired(
        form_id="form-1",
        form_content="Rendered",
        inputs=[
            FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="field", placeholder=None),
        ],
        actions=[UserAction(id="approve", title="Approve")],
        node_id="node-id",
        node_title="Human Step",
        web_app_form_token="token",
    )
    paused = QueueWorkflowPausedEvent(
        reasons=[required],
        outputs={"answer": "value"},
        paused_nodes=["node-id"],
    )

    stream = converter.workflow_pause_to_stream_response(event=paused, task_id="task")

    last = stream[-1]
    assert isinstance(last, WorkflowPauseStreamResponse)
    assert last.workflow_run_id == "run-id"
    assert last.data.paused_nodes == ["node-id"]
    assert last.data.outputs == {"answer": "value"}
    assert last.data.reasons[0]["form_id"] == "form-1"
    first = stream[0]
    assert isinstance(first, HumanInputRequiredResponse)
    assert first.data.form_id == "form-1"
    assert first.data.node_id == "node-id"
    assert first.data.node_title == "Human Step"
    assert first.data.inputs[0].output_variable_name == "field"
    assert first.data.actions[0].id == "approve"

View File

@ -0,0 +1,141 @@
import json
from collections.abc import Callable
from dataclasses import dataclass
import pytest
from core.app.app_config.entities import WorkflowUIBasedAppConfig
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
InvokeFrom,
WorkflowAppGenerateEntity,
)
from core.app.layers.pause_state_persist_layer import (
WorkflowResumptionContext,
_AdvancedChatAppGenerateEntityWrapper,
_WorkflowGenerateEntityWrapper,
)
from core.ops.ops_trace_manager import TraceQueueManager
from models.model import AppMode
class TraceQueueManagerStub(TraceQueueManager):
    """TraceQueueManager double that deliberately skips real initialization."""

    def __init__(self):
        """Bypass the parent constructor so no timers start and no Flask globals are touched."""
def _build_workflow_app_config(app_mode: AppMode) -> WorkflowUIBasedAppConfig:
    """Build a minimal app config whose workflow id is derived from *app_mode*."""
    derived_workflow_id = f"{app_mode.value}-workflow-id"
    return WorkflowUIBasedAppConfig(
        tenant_id="tenant-id",
        app_id="app-id",
        app_mode=app_mode,
        workflow_id=derived_workflow_id,
    )
def _create_workflow_generate_entity(trace_manager: TraceQueueManager | None = None) -> WorkflowAppGenerateEntity:
    """Create a fully populated WorkflowAppGenerateEntity for serialization round-trips."""
    config = _build_workflow_app_config(AppMode.WORKFLOW)
    return WorkflowAppGenerateEntity(
        task_id="workflow-task",
        app_config=config,
        inputs={"topic": "serialization"},
        files=[],
        user_id="user-workflow",
        stream=True,
        invoke_from=InvokeFrom.DEBUGGER,
        call_depth=1,
        trace_manager=trace_manager,
        workflow_execution_id="workflow-exec-id",
        extras={"external_trace_id": "trace-id"},
    )
def _create_advanced_chat_generate_entity(trace_manager: TraceQueueManager | None = None) -> AdvancedChatAppGenerateEntity:
    """Create a fully populated AdvancedChatAppGenerateEntity for serialization round-trips."""
    config = _build_workflow_app_config(AppMode.ADVANCED_CHAT)
    return AdvancedChatAppGenerateEntity(
        task_id="advanced-task",
        app_config=config,
        conversation_id="conversation-id",
        inputs={"topic": "roundtrip"},
        files=[],
        user_id="user-advanced",
        stream=False,
        invoke_from=InvokeFrom.DEBUGGER,
        query="Explain serialization",
        extras={"auto_generate_conversation_name": True},
        trace_manager=trace_manager,
        workflow_run_id="workflow-run-id",
    )
def test_workflow_app_generate_entity_roundtrip_excludes_trace_manager():
    """JSON serialization must drop trace_manager while preserving every other field."""
    original = _create_workflow_generate_entity(trace_manager=TraceQueueManagerStub())

    raw = original.model_dump_json()

    assert "trace_manager" not in json.loads(raw)
    rebuilt = WorkflowAppGenerateEntity.model_validate_json(raw)
    assert rebuilt.model_dump() == original.model_dump()
    assert rebuilt.trace_manager is None
def test_advanced_chat_generate_entity_roundtrip_excludes_trace_manager():
    """JSON serialization must drop trace_manager while preserving every other field."""
    original = _create_advanced_chat_generate_entity(trace_manager=TraceQueueManagerStub())

    raw = original.model_dump_json()

    assert "trace_manager" not in json.loads(raw)
    rebuilt = AdvancedChatAppGenerateEntity.model_validate_json(raw)
    assert rebuilt.model_dump() == original.model_dump()
    assert rebuilt.trace_manager is None
@dataclass(frozen=True)
class ResumptionContextCase:
    """Parametrization record pairing a case label with a factory for the context under test."""

    # Human-readable case label (pytest ids are supplied separately via pytest.param).
    name: str
    # Factory returning the resumption context plus the concrete generate-entity
    # type expected after deserialization.
    context_factory: Callable[[], tuple[WorkflowResumptionContext, type]]
def _workflow_resumption_case() -> tuple[WorkflowResumptionContext, type]:
    """Build a resumption context wrapping a workflow generate entity."""
    wrapped = _WorkflowGenerateEntityWrapper(
        entity=_create_workflow_generate_entity(trace_manager=TraceQueueManagerStub())
    )
    context = WorkflowResumptionContext(
        serialized_graph_runtime_state=json.dumps({"state": "workflow"}),
        generate_entity=wrapped,
    )
    return context, WorkflowAppGenerateEntity
def _advanced_chat_resumption_case() -> tuple[WorkflowResumptionContext, type]:
    """Build a resumption context wrapping an advanced-chat generate entity."""
    wrapped = _AdvancedChatAppGenerateEntityWrapper(
        entity=_create_advanced_chat_generate_entity(trace_manager=TraceQueueManagerStub())
    )
    context = WorkflowResumptionContext(
        serialized_graph_runtime_state=json.dumps({"state": "advanced"}),
        generate_entity=wrapped,
    )
    return context, AdvancedChatAppGenerateEntity
@pytest.mark.parametrize(
    "case",
    [
        pytest.param(ResumptionContextCase("workflow", _workflow_resumption_case), id="workflow"),
        pytest.param(ResumptionContextCase("advanced_chat", _advanced_chat_resumption_case), id="advanced_chat"),
    ],
)
def test_workflow_resumption_context_roundtrip(case: ResumptionContextCase):
    """dumps/loads must round-trip the context and rebuild the right entity type."""
    context, expected_type = case.context_factory()

    restored = WorkflowResumptionContext.loads(context.dumps())

    assert restored.serialized_graph_runtime_state == context.serialized_graph_runtime_state
    rebuilt_entity = restored.get_generate_entity()
    assert isinstance(rebuilt_entity, expected_type)
    assert rebuilt_entity.model_dump() == context.get_generate_entity().model_dump()
    assert rebuilt_entity.trace_manager is None

View File

@ -34,6 +34,7 @@ class _InMemoryFormRecipient(HumanInputFormRecipientEntity):
@dataclass
class _InMemoryFormEntity(HumanInputFormEntity):
form_id: str
rendered: str
token: str | None = None
@property
@ -48,6 +49,10 @@ class _InMemoryFormEntity(HumanInputFormEntity):
def recipients(self) -> list[HumanInputFormRecipientEntity]:
return []
@property
def rendered_content(self) -> str:
return self.rendered
class _InMemoryFormSubmission(FormSubmission):
def __init__(self, selected_action_id: str, form_data: Mapping[str, Any]) -> None:
@ -76,7 +81,7 @@ class InMemoryHumanInputFormRepository(HumanInputFormRepository):
self.created_params.append(params)
self._form_counter += 1
form_id = f"form-{self._form_counter}"
entity = _InMemoryFormEntity(form_id=form_id, token=f"token-{form_id}")
entity = _InMemoryFormEntity(form_id=form_id, rendered=params.rendered_content, token=f"token-{form_id}")
self.created_forms.append(entity)
self._forms_by_key[(params.workflow_execution_id, params.node_id)] = entity
return entity

View File

@ -248,6 +248,7 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None:
mock_form_entity.id = "test_form_id"
mock_form_entity.web_app_token = "test_web_app_token"
mock_form_entity.recipients = []
mock_form_entity.rendered_content = "rendered"
mock_create_repo.create_form.return_value = mock_form_entity
def initial_graph_factory() -> tuple[Graph, GraphRuntimeState]:

View File

@ -193,6 +193,7 @@ def test_human_input_llm_streaming_order_across_pause() -> None:
mock_form_entity.id = "test_form_id"
mock_form_entity.web_app_token = "test_web_app_token"
mock_form_entity.recipients = []
mock_form_entity.rendered_content = "rendered"
mock_create_repo.create_form.return_value = mock_form_entity
def graph_factory() -> tuple[Graph, GraphRuntimeState]:

View File

@ -52,6 +52,7 @@ def _mock_form_repository_with_submission(action_id: str) -> HumanInputFormRepos
form_entity.id = "test-form-id"
form_entity.web_app_token = "test-form-token"
form_entity.recipients = []
form_entity.rendered_content = "rendered"
repo.get_form.return_value = form_entity
return repo
@ -63,6 +64,7 @@ def _mock_form_repository_without_submission() -> HumanInputFormRepository:
form_entity.id = "test-form-id"
form_entity.web_app_token = "test-form-token"
form_entity.recipients = []
form_entity.rendered_content = "rendered"
repo.create_form.return_value = form_entity
repo.get_form.return_value = None
return repo

View File

@ -0,0 +1,38 @@
from __future__ import annotations
import json
import uuid
from unittest.mock import MagicMock
import pytest
from tasks.app_generate.workflow_execute_task import _publish_streaming_response
@pytest.fixture
def mock_topic(mocker) -> MagicMock:
    """Patch the response-topic lookup and hand the mock topic to the test."""
    fake_topic = MagicMock()
    mocker.patch(
        "tasks.app_generate.workflow_execute_task.AdvancedChatAppGenerator.get_response_topic",
        return_value=fake_topic,
    )
    return fake_topic
def test_publish_streaming_response_with_uuid(mock_topic: MagicMock):
    """Every streamed item is JSON-encoded to bytes and published in order."""
    run_id = uuid.uuid4()

    _publish_streaming_response(iter([{"event": "foo"}, "ping"]), run_id)

    published = [call.args[0] for call in mock_topic.publish.call_args_list]
    expected = [json.dumps({"event": "foo"}).encode(), json.dumps("ping").encode()]
    assert published == expected
def test_publish_streaming_response_coerces_string_uuid(mock_topic: MagicMock):
    """A workflow run id given as a string is accepted just like a UUID instance."""
    run_id = str(uuid.uuid4())

    _publish_streaming_response(iter([{"event": "bar"}]), run_id)

    mock_topic.publish.assert_called_once_with(json.dumps({"event": "bar"}).encode())