mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 17:08:03 +08:00
WIP: api debugging
This commit is contained in:
@ -1,3 +1,4 @@
|
||||
import json
|
||||
from typing import Literal, cast
|
||||
|
||||
from flask import request
|
||||
@ -409,7 +410,19 @@ class ConsoleWorkflowPauseDetailsApi(Resource):
|
||||
# Check if workflow is suspended
|
||||
is_paused = workflow_run.status == WorkflowExecutionStatus.PAUSED
|
||||
if not is_paused:
|
||||
return {"is_suspended": False, "paused_at": None, "paused_nodes": [], "pending_human_inputs": []}, 200
|
||||
return {
|
||||
"is_suspended": False,
|
||||
"paused_at": None,
|
||||
"paused_nodes": [],
|
||||
"pending_human_inputs": [],
|
||||
"pause_reasons": [],
|
||||
}, 200
|
||||
|
||||
pause_entity = workflow_run_repo.get_workflow_pause(workflow_run_id)
|
||||
pause_reasons: list[dict[str, object]] = []
|
||||
if pause_entity:
|
||||
for reason in pause_entity.get_pause_reasons():
|
||||
pause_reasons.append(reason.model_dump(mode="json"))
|
||||
|
||||
# Get pending Human Input forms for this workflow run
|
||||
service = HumanInputFormService(db.session())
|
||||
@ -421,6 +434,7 @@ class ConsoleWorkflowPauseDetailsApi(Resource):
|
||||
"paused_at": workflow_run.created_at.isoformat() + "Z" if workflow_run.created_at else None,
|
||||
"paused_nodes": [],
|
||||
"pending_human_inputs": [],
|
||||
"pause_reasons": pause_reasons,
|
||||
}
|
||||
|
||||
# Add pending human input forms
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
Console/Studio Human Input Form APIs.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
|
||||
@ -157,32 +158,18 @@ class ConsoleWorkflowEventsApi(Resource):
|
||||
|
||||
if workflow_run.finished_at is not None:
|
||||
response = WorkflowResponseConverter.workflow_run_result_to_finish_response(
|
||||
task_id=workflow_run.id,
|
||||
workflow_run=workflow_run,
|
||||
creator_user=user,
|
||||
)
|
||||
|
||||
# TODO: should we just return here? or yield a WorkflowFinishStreamResponse?
|
||||
payload = response.model_dump(mode="json")
|
||||
payload["event"] = response.event.value
|
||||
|
||||
def generate_events() -> Generator[str, None, None]:
|
||||
"""Generate SSE events for workflow execution."""
|
||||
try:
|
||||
# TODO: Implement actual event streaming
|
||||
# This would connect to the workflow execution engine
|
||||
# and stream real-time events
|
||||
yield f"data: {json.dumps(payload)}\n\n"
|
||||
|
||||
# For demo purposes, send a basic event
|
||||
yield f"data: {{'event': 'workflow_resumed', 'task_id': '{task_id}'}}\n\n"
|
||||
|
||||
# In real implementation, this would:
|
||||
# 1. Connect to workflow execution engine
|
||||
# 2. Stream real-time execution events
|
||||
# 3. Handle client disconnection
|
||||
# 4. Clean up resources on completion
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Error streaming events for task %s", task_id)
|
||||
yield f"data: {{'error': 'Stream error: {str(e)}'}}\n\n"
|
||||
else:
|
||||
# TODO: SSE from Redis PubSub
|
||||
msg_generator = MessageGenerator()
|
||||
if app.mode == AppMode.ADVANCED_CHAT:
|
||||
generator = AdvancedChatAppGenerator()
|
||||
|
||||
@ -42,6 +42,7 @@ from core.app.entities.queue_entities import (
|
||||
QueueTextChunkEvent,
|
||||
QueueWorkflowFailedEvent,
|
||||
QueueWorkflowPartialSuccessEvent,
|
||||
QueueWorkflowPausedEvent,
|
||||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
WorkflowQueueMessage,
|
||||
@ -525,6 +526,19 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
)
|
||||
|
||||
yield workflow_finish_resp
|
||||
|
||||
def _handle_workflow_paused_event(
|
||||
self,
|
||||
event: QueueWorkflowPausedEvent,
|
||||
**kwargs,
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle workflow paused events."""
|
||||
responses = self._workflow_response_converter.workflow_pause_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
)
|
||||
for response in responses:
|
||||
yield response
|
||||
self._base_task_pipeline.queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
|
||||
|
||||
def _handle_workflow_failed_event(
|
||||
@ -659,6 +673,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
QueueWorkflowStartedEvent: self._handle_workflow_started_event,
|
||||
QueueWorkflowSucceededEvent: self._handle_workflow_succeeded_event,
|
||||
QueueWorkflowPartialSuccessEvent: self._handle_workflow_partial_success_event,
|
||||
QueueWorkflowPausedEvent: self._handle_workflow_paused_event,
|
||||
QueueWorkflowFailedEvent: self._handle_workflow_failed_event,
|
||||
# Node events
|
||||
QueueNodeRetryEvent: self._handle_node_retry_event,
|
||||
@ -747,6 +762,9 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
case QueueWorkflowFailedEvent():
|
||||
yield from self._handle_workflow_failed_event(event, trace_manager=trace_manager)
|
||||
break
|
||||
case QueueWorkflowPausedEvent():
|
||||
yield from self._handle_workflow_paused_event(event)
|
||||
break
|
||||
|
||||
case QueueStopEvent():
|
||||
yield from self._handle_stop_event(event, graph_runtime_state=None, trace_manager=trace_manager)
|
||||
|
||||
@ -19,9 +19,11 @@ from core.app.entities.queue_entities import (
|
||||
QueueNodeRetryEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueWorkflowPausedEvent,
|
||||
)
|
||||
from core.app.entities.task_entities import (
|
||||
AgentLogStreamResponse,
|
||||
HumanInputRequiredResponse,
|
||||
IterationNodeCompletedStreamResponse,
|
||||
IterationNodeNextStreamResponse,
|
||||
IterationNodeStartStreamResponse,
|
||||
@ -31,7 +33,9 @@ from core.app.entities.task_entities import (
|
||||
NodeFinishStreamResponse,
|
||||
NodeRetryStreamResponse,
|
||||
NodeStartStreamResponse,
|
||||
StreamResponse,
|
||||
WorkflowFinishStreamResponse,
|
||||
WorkflowPauseStreamResponse,
|
||||
WorkflowStartStreamResponse,
|
||||
)
|
||||
from core.file import FILE_MODEL_IDENTITY, File
|
||||
@ -40,6 +44,7 @@ from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.trigger.trigger_manager import TriggerManager
|
||||
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
|
||||
from core.workflow.entities.pause_reason import HumanInputRequired
|
||||
from core.workflow.enums import (
|
||||
NodeType,
|
||||
SystemVariableKey,
|
||||
@ -265,10 +270,56 @@ class WorkflowResponseConverter:
|
||||
),
|
||||
)
|
||||
|
||||
def workflow_pause_to_stream_response(
|
||||
self,
|
||||
*,
|
||||
event: QueueWorkflowPausedEvent,
|
||||
task_id: str,
|
||||
) -> list[StreamResponse]:
|
||||
run_id = self._ensure_workflow_run_id()
|
||||
encoded_outputs = self._encode_outputs(event.outputs) or {}
|
||||
pause_reasons = [reason.model_dump(mode="json") for reason in event.reasons]
|
||||
|
||||
responses: list[StreamResponse] = []
|
||||
|
||||
for reason in event.reasons:
|
||||
if isinstance(reason, HumanInputRequired):
|
||||
responses.append(
|
||||
HumanInputRequiredResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=run_id,
|
||||
data=HumanInputRequiredResponse.Data(
|
||||
form_id=reason.form_id,
|
||||
node_id=reason.node_id,
|
||||
node_title=reason.node_title,
|
||||
form_content=reason.form_content,
|
||||
inputs=reason.inputs,
|
||||
actions=reason.actions,
|
||||
web_app_form_token=reason.web_app_form_token,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
responses.append(
|
||||
WorkflowPauseStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=run_id,
|
||||
data=WorkflowPauseStreamResponse.Data(
|
||||
workflow_run_id=run_id,
|
||||
paused_nodes=list(event.paused_nodes),
|
||||
outputs=encoded_outputs,
|
||||
reasons=pause_reasons,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
return responses
|
||||
|
||||
@classmethod
|
||||
def workflow_run_result_to_finish_response(
|
||||
cls,
|
||||
*,
|
||||
task_id: str,
|
||||
workflow_run: WorkflowRun,
|
||||
creator_user: Account | EndUser,
|
||||
) -> WorkflowFinishStreamResponse:
|
||||
@ -294,7 +345,7 @@ class WorkflowResponseConverter:
|
||||
}
|
||||
|
||||
return WorkflowFinishStreamResponse(
|
||||
task_id=task_id, # TODO
|
||||
task_id=task_id,
|
||||
workflow_run_id=run_id,
|
||||
data=WorkflowFinishStreamResponse.Data(
|
||||
id=run_id,
|
||||
|
||||
@ -163,6 +163,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||
query = application_generate_entity.query or "New conversation"
|
||||
conversation_name = (query[:20] + "…") if len(query) > 20 else query
|
||||
|
||||
created_new_conversation = conversation is None
|
||||
try:
|
||||
if not conversation:
|
||||
conversation = Conversation(
|
||||
@ -239,6 +240,9 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||
db.session.add_all(message_files)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
application_generate_entity.conversation_id = conversation.id
|
||||
application_generate_entity.is_new_conversation = created_new_conversation
|
||||
return conversation, message
|
||||
except Exception:
|
||||
db.session.rollback()
|
||||
|
||||
@ -32,6 +32,7 @@ from core.app.entities.queue_entities import (
|
||||
QueueTextChunkEvent,
|
||||
QueueWorkflowFailedEvent,
|
||||
QueueWorkflowPartialSuccessEvent,
|
||||
QueueWorkflowPausedEvent,
|
||||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
WorkflowQueueMessage,
|
||||
@ -440,6 +441,19 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
)
|
||||
yield workflow_finish_resp
|
||||
|
||||
def _handle_workflow_paused_event(
|
||||
self,
|
||||
event: QueueWorkflowPausedEvent,
|
||||
**kwargs,
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle workflow paused events."""
|
||||
self._ensure_workflow_initialized()
|
||||
responses = self._workflow_response_converter.workflow_pause_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
)
|
||||
yield from response
|
||||
|
||||
def _handle_workflow_failed_and_stop_events(
|
||||
self,
|
||||
event: Union[QueueWorkflowFailedEvent, QueueStopEvent],
|
||||
@ -506,6 +520,7 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
QueueWorkflowStartedEvent: self._handle_workflow_started_event,
|
||||
QueueWorkflowSucceededEvent: self._handle_workflow_succeeded_event,
|
||||
QueueWorkflowPartialSuccessEvent: self._handle_workflow_partial_success_event,
|
||||
QueueWorkflowPausedEvent: self._handle_workflow_paused_event,
|
||||
# Node events
|
||||
QueueNodeRetryEvent: self._handle_node_retry_event,
|
||||
QueueNodeStartedEvent: self._handle_node_started_event,
|
||||
@ -602,6 +617,9 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
case QueueWorkflowFailedEvent():
|
||||
yield from self._handle_workflow_failed_and_stop_events(event)
|
||||
break
|
||||
case QueueWorkflowPausedEvent():
|
||||
yield from self._handle_workflow_paused_event(event)
|
||||
break
|
||||
|
||||
case QueueStopEvent():
|
||||
yield from self._handle_workflow_failed_and_stop_events(event)
|
||||
|
||||
@ -22,6 +22,7 @@ from core.app.entities.queue_entities import (
|
||||
QueueTextChunkEvent,
|
||||
QueueWorkflowFailedEvent,
|
||||
QueueWorkflowPartialSuccessEvent,
|
||||
QueueWorkflowPausedEvent,
|
||||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
@ -32,6 +33,7 @@ from core.workflow.graph_events import (
|
||||
GraphEngineEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunPartialSucceededEvent,
|
||||
GraphRunPausedEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
NodeRunAgentLogEvent,
|
||||
@ -362,6 +364,16 @@ class WorkflowBasedAppRunner:
|
||||
self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count))
|
||||
elif isinstance(event, GraphRunAbortedEvent):
|
||||
self._publish_event(QueueWorkflowFailedEvent(error=event.reason or "Unknown error", exceptions_count=0))
|
||||
elif isinstance(event, GraphRunPausedEvent):
|
||||
runtime_state = workflow_entry.graph_engine.graph_runtime_state
|
||||
paused_nodes = runtime_state.get_paused_nodes()
|
||||
self._publish_event(
|
||||
QueueWorkflowPausedEvent(
|
||||
reasons=event.reasons,
|
||||
outputs=event.outputs,
|
||||
paused_nodes=paused_nodes,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunRetryEvent):
|
||||
node_run_result = event.node_run_result
|
||||
inputs = node_run_result.inputs
|
||||
|
||||
@ -131,7 +131,7 @@ class AppGenerateEntity(BaseModel):
|
||||
extras: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
# tracing instance
|
||||
trace_manager: Optional["TraceQueueManager"] = None
|
||||
trace_manager: Optional["TraceQueueManager"] = Field(default=None, exclude=True, repr=False)
|
||||
|
||||
|
||||
class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
|
||||
@ -155,6 +155,7 @@ class ConversationAppGenerateEntity(AppGenerateEntity):
|
||||
"""
|
||||
|
||||
conversation_id: str | None = None
|
||||
is_new_conversation: bool = False
|
||||
parent_message_id: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
|
||||
@ -8,6 +8,7 @@ from pydantic import BaseModel, ConfigDict, Field
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.entities import AgentNodeStrategyInit
|
||||
from core.workflow.entities.pause_reason import PauseReason
|
||||
from core.workflow.enums import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.nodes import NodeType
|
||||
|
||||
@ -46,6 +47,7 @@ class QueueEvent(StrEnum):
|
||||
PING = "ping"
|
||||
STOP = "stop"
|
||||
RETRY = "retry"
|
||||
PAUSE = "pause"
|
||||
|
||||
|
||||
class AppQueueEvent(BaseModel):
|
||||
@ -509,3 +511,14 @@ class WorkflowQueueMessage(QueueMessage):
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QueueWorkflowPausedEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueWorkflowPausedEvent entity
|
||||
"""
|
||||
|
||||
event: QueueEvent = QueueEvent.PAUSE
|
||||
reasons: Sequence[PauseReason] = Field(default_factory=list)
|
||||
outputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
paused_nodes: Sequence[str] = Field(default_factory=list)
|
||||
|
||||
@ -8,6 +8,7 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.entities import AgentNodeStrategyInit
|
||||
from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.human_input.entities import FormInput, UserAction
|
||||
|
||||
|
||||
class AnnotationReplyAccount(BaseModel):
|
||||
@ -83,6 +84,7 @@ class StreamEvent(StrEnum):
|
||||
TEXT_CHUNK = "text_chunk"
|
||||
TEXT_REPLACE = "text_replace"
|
||||
AGENT_LOG = "agent_log"
|
||||
HUMAN_INPUT_REQUIRED = "human_input_required"
|
||||
|
||||
|
||||
class StreamResponse(BaseModel):
|
||||
@ -241,6 +243,45 @@ class WorkflowFinishStreamResponse(StreamResponse):
|
||||
data: Data
|
||||
|
||||
|
||||
class WorkflowPauseStreamResponse(StreamResponse):
|
||||
"""
|
||||
WorkflowPauseStreamResponse entity
|
||||
"""
|
||||
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
"""
|
||||
|
||||
workflow_run_id: str
|
||||
paused_nodes: Sequence[str] = Field(default_factory=list)
|
||||
outputs: Mapping[str, Any] = Field(default_factory=dict)
|
||||
reasons: Sequence[Mapping[str, Any]] = Field(default_factory=list)
|
||||
|
||||
event: StreamEvent = StreamEvent.WORKFLOW_PAUSED
|
||||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
|
||||
class HumanInputRequiredResponse(StreamResponse):
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
"""
|
||||
|
||||
form_id: str
|
||||
node_id: str
|
||||
node_title: str
|
||||
form_content: str
|
||||
inputs: Sequence[FormInput] = Field(default_factory=list)
|
||||
actions: Sequence[UserAction] = Field(default_factory=list)
|
||||
web_app_form_token: str | None = None
|
||||
|
||||
event: StreamEvent = StreamEvent.HUMAN_INPUT_REQUIRED
|
||||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
|
||||
class NodeStartStreamResponse(StreamResponse):
|
||||
"""
|
||||
NodeStartStreamResponse entity
|
||||
|
||||
@ -79,10 +79,11 @@ class MessageCycleManager:
|
||||
if isinstance(self._application_generate_entity, CompletionAppGenerateEntity):
|
||||
return None
|
||||
|
||||
is_first_message = self._application_generate_entity.conversation_id is None
|
||||
is_first_message = self._application_generate_entity.is_new_conversation
|
||||
extras = self._application_generate_entity.extras
|
||||
auto_generate_conversation_name = extras.get("auto_generate_conversation_name", True)
|
||||
|
||||
thread: Thread | None = None
|
||||
if auto_generate_conversation_name and is_first_message:
|
||||
# start generate thread
|
||||
# time.sleep not block other logic
|
||||
@ -98,9 +99,10 @@ class MessageCycleManager:
|
||||
thread.daemon = True
|
||||
thread.start()
|
||||
|
||||
return thread
|
||||
if is_first_message:
|
||||
self._application_generate_entity.is_new_conversation = False
|
||||
|
||||
return None
|
||||
return thread
|
||||
|
||||
def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str):
|
||||
with flask_app.app_context():
|
||||
|
||||
@ -90,6 +90,10 @@ class _HumanInputFormEntityImpl(HumanInputFormEntity):
|
||||
def recipients(self) -> list[HumanInputFormRecipientEntity]:
|
||||
return list(self._recipients)
|
||||
|
||||
@property
|
||||
def rendered_content(self) -> str:
|
||||
return self._form_model.rendered_content
|
||||
|
||||
|
||||
class _FormSubmissionImpl(FormSubmission):
|
||||
def __init__(self, form_model: HumanInputForm):
|
||||
|
||||
@ -3,7 +3,7 @@ from typing import Annotated, Literal, TypeAlias
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.nodes.human_input.entities import FormInput
|
||||
from core.workflow.nodes.human_input.entities import FormInput, UserAction
|
||||
|
||||
|
||||
class PauseReasonType(StrEnum):
|
||||
@ -16,6 +16,9 @@ class HumanInputRequired(BaseModel):
|
||||
form_id: str
|
||||
form_content: str
|
||||
inputs: list[FormInput] = Field(default_factory=list)
|
||||
actions: list[UserAction] = Field(default_factory=list)
|
||||
node_id: str
|
||||
node_title: str
|
||||
web_app_form_token: str | None = None
|
||||
|
||||
|
||||
|
||||
@ -306,13 +306,13 @@ class GraphEngine:
|
||||
for layer in self._layers:
|
||||
try:
|
||||
layer.initialize(read_only_state, self._command_channel)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to initialize layer %s: %s", layer.__class__.__name__, e)
|
||||
except Exception:
|
||||
logger.exception("Failed to initialize layer %s", layer.__class__.__name__)
|
||||
|
||||
try:
|
||||
layer.on_graph_start()
|
||||
except Exception as e:
|
||||
logger.warning("Layer %s failed on_graph_start: %s", layer.__class__.__name__, e)
|
||||
except Exception:
|
||||
logger.exception("Layer %s failed on_graph_start", layer.__class__.__name__)
|
||||
|
||||
def _start_execution(self, *, resume: bool = False) -> None:
|
||||
"""Start execution subsystems."""
|
||||
@ -353,8 +353,8 @@ class GraphEngine:
|
||||
for layer in self._layers:
|
||||
try:
|
||||
layer.on_graph_end(self._graph_execution.error)
|
||||
except Exception as e:
|
||||
logger.warning("Layer %s failed on_graph_end: %s", layer.__class__.__name__, e)
|
||||
except Exception:
|
||||
logger.exception("Layer %s failed on_graph_end", layer.__class__.__name__)
|
||||
|
||||
# Public property accessors for attributes that need external access
|
||||
@property
|
||||
|
||||
0
api/core/workflow/graph_events/human_input.py
Normal file
0
api/core/workflow/graph_events/human_input.py
Normal file
@ -8,9 +8,9 @@ import uuid
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime, timedelta
|
||||
from enum import StrEnum
|
||||
from typing import Annotated, Literal, Optional
|
||||
from typing import Annotated, Literal, Optional, Self
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
from core.variables.consts import SELECTORS_LENGTH
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
@ -152,12 +152,13 @@ class FormInputPlaceholder(BaseModel):
|
||||
# TODO: How should we express JSON values?
|
||||
value: str = ""
|
||||
|
||||
@field_validator("selector")
|
||||
@classmethod
|
||||
def _validate_selector(cls, selector: Sequence[str]) -> Sequence[str]:
|
||||
if len(selector) < SELECTORS_LENGTH:
|
||||
raise ValueError(f"the length of selector should be at least {SELECTORS_LENGTH}, selector={selector}")
|
||||
return selector
|
||||
@model_validator(mode="after")
|
||||
def _validate_selector(self) -> Self:
|
||||
if self.type == PlaceholderType.CONSTANT:
|
||||
return self
|
||||
if len(self.selector) < SELECTORS_LENGTH:
|
||||
raise ValueError(f"the length of selector should be at least {SELECTORS_LENGTH}, selector={self.selector}")
|
||||
return self
|
||||
|
||||
|
||||
class FormInput(BaseModel):
|
||||
|
||||
@ -3,6 +3,7 @@ from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.repositories.human_input_reposotiry import HumanInputFormRepositoryImpl
|
||||
from core.workflow.entities.pause_reason import HumanInputRequired
|
||||
from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult, PauseRequestedEvent
|
||||
from core.workflow.node_events.base import NodeEventBase
|
||||
@ -129,6 +130,18 @@ class HumanInputNode(Node[HumanInputNodeData]):
|
||||
pause_requested_event = PauseRequestedEvent(reason=required_event)
|
||||
return pause_requested_event
|
||||
|
||||
def _human_input_required_event(self, form_entity: HumanInputFormEntity) -> HumanInputRequired:
|
||||
node_data = self._node_data
|
||||
return HumanInputRequired(
|
||||
form_id=form_entity.id,
|
||||
form_content=form_entity.rendered_content,
|
||||
inputs=node_data.inputs,
|
||||
actions=node_data.user_actions,
|
||||
node_id=self.id,
|
||||
node_title=node_data.title,
|
||||
web_app_form_token=form_entity.web_app_token,
|
||||
)
|
||||
|
||||
def _create_form(self) -> Generator[NodeEventBase, None, None] | NodeRunResult:
|
||||
try:
|
||||
params = FormCreateParams(
|
||||
@ -146,7 +159,6 @@ class HumanInputNode(Node[HumanInputNodeData]):
|
||||
self.id,
|
||||
form_entity.id,
|
||||
)
|
||||
yield self._human
|
||||
yield self._form_to_pause_event(form_entity)
|
||||
except Exception as e:
|
||||
logger.exception("Human Input node failed to execute, node_id=%s", self.id)
|
||||
|
||||
@ -52,6 +52,12 @@ class HumanInputFormEntity(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def recipients(self) -> list["HumanInputFormRecipientEntity"]: ...
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def rendered_content(self) -> str:
|
||||
"""Rendered markdown content associated with the form."""
|
||||
...
|
||||
|
||||
|
||||
class HumanInputFormRecipientEntity(abc.ABC):
|
||||
@property
|
||||
|
||||
@ -354,6 +354,11 @@ class GraphRuntimeState:
|
||||
|
||||
self._paused_nodes.add(node_id)
|
||||
|
||||
def get_paused_nodes(self) -> list[str]:
|
||||
"""Retrieve the list of paused nodes without mutating internal state."""
|
||||
|
||||
return list(self._paused_nodes)
|
||||
|
||||
def consume_paused_nodes(self) -> list[str]:
|
||||
"""Retrieve and clear the list of paused nodes awaiting resume."""
|
||||
|
||||
|
||||
@ -350,9 +350,11 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
raise ValueError(f"WorkflowRun not found: {workflow_run_id}")
|
||||
|
||||
# Check if workflow is in RUNNING status
|
||||
if workflow_run.status != WorkflowExecutionStatus.RUNNING:
|
||||
# TODO(QuantumGhost): It seems that the persistence of `WorkflowRun.status`
|
||||
# happens before the execution of GraphLayer
|
||||
if workflow_run.status not in {WorkflowExecutionStatus.RUNNING, WorkflowExecutionStatus.PAUSED}:
|
||||
raise _WorkflowRunError(
|
||||
f"Only WorkflowRun with RUNNING status can be paused, "
|
||||
f"Only WorkflowRun with RUNNING or PAUSED status can be paused, "
|
||||
f"workflow_run_id={workflow_run_id}, current_status={workflow_run.status}"
|
||||
)
|
||||
#
|
||||
|
||||
90
api/scripts/workflow_event_subscriber.py
Normal file
90
api/scripts/workflow_event_subscriber.py
Normal file
@ -0,0 +1,90 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Subscribe to workflow run events from the Redis broadcast channel."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
|
||||
from app_factory import create_flask_app_with_configs
|
||||
from core.app.apps.message_generator import MessageGenerator
|
||||
from extensions import ext_redis
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
def _parse_app_mode(value: str) -> AppMode:
|
||||
try:
|
||||
return AppMode.value_of(value)
|
||||
except ValueError as exc: # pragma: no cover - argparse rewrites the message
|
||||
raise argparse.ArgumentTypeError(str(exc)) from exc
|
||||
|
||||
|
||||
def _parse_workflow_run_id(value: str) -> str:
|
||||
try:
|
||||
workflow_uuid = uuid.UUID(value)
|
||||
except ValueError as exc: # pragma: no cover - argparse rewrites the message
|
||||
raise argparse.ArgumentTypeError("workflow run id must be a valid UUID") from exc
|
||||
return str(workflow_uuid)
|
||||
|
||||
|
||||
def _build_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Subscribe to Redis broadcast channel events for a workflow run and print them."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--workflow-run-id",
|
||||
required=True,
|
||||
type=_parse_workflow_run_id,
|
||||
help="Workflow run identifier whose stream output should be tailed.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--app-mode",
|
||||
required=True,
|
||||
type=_parse_app_mode,
|
||||
choices=list(AppMode),
|
||||
help="App mode the workflow ran under (determines the broadcast channel name).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--idle-timeout",
|
||||
type=float,
|
||||
default=300.0,
|
||||
help="Stop listening after this many seconds without events (default: 300).",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def _initialize_redis() -> None:
|
||||
# A lightweight Flask app is enough to reuse the existing Redis initialization code path.
|
||||
app = create_flask_app_with_configs()
|
||||
ext_redis.init_app(app)
|
||||
|
||||
|
||||
def _print_event(event: Mapping | str) -> None:
|
||||
if isinstance(event, Mapping):
|
||||
payload = json.dumps(event, ensure_ascii=False)
|
||||
else:
|
||||
payload = event
|
||||
print(payload)
|
||||
sys.stdout.flush()
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = _build_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
_initialize_redis()
|
||||
|
||||
events = MessageGenerator.retrieve_events(args.app_mode, args.workflow_run_id, idle_timeout=args.idle_timeout)
|
||||
try:
|
||||
for event in events:
|
||||
_print_event(event)
|
||||
except KeyboardInterrupt:
|
||||
return 130
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
@ -162,6 +162,7 @@ class AccountService:
|
||||
def get_account_jwt_token(account: Account) -> str:
|
||||
exp_dt = datetime.now(UTC) + timedelta(minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
exp = int(exp_dt.timestamp())
|
||||
breakpoint()
|
||||
payload = {
|
||||
"user_id": account.id,
|
||||
"exp": exp,
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import contextlib
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from collections.abc import Iterable, Mapping
|
||||
from enum import StrEnum
|
||||
from typing import Annotated, Any, TypeAlias, Union
|
||||
|
||||
@ -157,15 +157,7 @@ class _ChatflowRunner:
|
||||
if not exec_params.streaming:
|
||||
return response
|
||||
|
||||
topic = chat_generator.get_response_topic(AppMode.ADVANCED_CHAT, workflow_run_id)
|
||||
for event in response:
|
||||
try:
|
||||
payload = json.dumps(event)
|
||||
except TypeError:
|
||||
logging.exception("error while encoding event")
|
||||
continue
|
||||
|
||||
topic.publish(payload.encode())
|
||||
_publish_streaming_response(response, workflow_run_id)
|
||||
|
||||
def _resolve_user(self) -> Account | EndUser:
|
||||
user_params = self._exec_params.user
|
||||
@ -194,6 +186,36 @@ def _resolve_user_for_run(session: Session, workflow_run: WorkflowRun) -> Accoun
|
||||
return session.get(EndUser, workflow_run.created_by)
|
||||
|
||||
|
||||
def _coerce_uuid(value: Any) -> uuid.UUID | None:
|
||||
if isinstance(value, uuid.UUID):
|
||||
return value
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
return uuid.UUID(str(value))
|
||||
except (ValueError, TypeError):
|
||||
logger.warning("Invalid workflow_run_id value: %s", value)
|
||||
return None
|
||||
|
||||
|
||||
def _publish_streaming_response(response_stream: Iterable[Any], workflow_run_id: Any) -> None:
|
||||
workflow_run_uuid = _coerce_uuid(workflow_run_id)
|
||||
if workflow_run_uuid is None:
|
||||
logger.warning("Unable to publish streaming response without valid workflow_run_id: %s", workflow_run_id)
|
||||
return
|
||||
|
||||
topic = AdvancedChatAppGenerator.get_response_topic(AppMode.ADVANCED_CHAT, workflow_run_uuid)
|
||||
for event in response_stream:
|
||||
try:
|
||||
payload = json.dumps(event)
|
||||
except TypeError:
|
||||
logger.exception("error while encoding event")
|
||||
continue
|
||||
|
||||
topic.publish(payload.encode())
|
||||
|
||||
|
||||
@shared_task(queue="chatflow_execute")
|
||||
def chatflow_execute_task(payload: str) -> Mapping[str, Any] | None:
|
||||
exec_params = ChatflowExecutionParams.model_validate_json(payload)
|
||||
@ -300,7 +322,7 @@ def resume_chatflow_execution(payload: dict[str, Any]) -> None:
|
||||
generator = AdvancedChatAppGenerator()
|
||||
|
||||
try:
|
||||
generator.resume(
|
||||
response = generator.resume(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
@ -315,3 +337,13 @@ def resume_chatflow_execution(payload: dict[str, Any]) -> None:
|
||||
except Exception:
|
||||
logger.exception("Failed to resume chatflow execution for workflow run %s", workflow_run_id)
|
||||
raise
|
||||
|
||||
if generate_entity.stream:
|
||||
publish_uuid = _coerce_uuid(generate_entity.workflow_run_id) or _coerce_uuid(workflow_run_id)
|
||||
if publish_uuid is None:
|
||||
logger.warning(
|
||||
"Unable to publish streaming response for workflow run %s due to missing workflow_run_id",
|
||||
workflow_run_id,
|
||||
)
|
||||
else:
|
||||
_publish_streaming_response(response, publish_uuid)
|
||||
|
||||
@ -0,0 +1,139 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig
|
||||
from core.app.apps import message_based_app_generator
|
||||
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
|
||||
from core.app.task_pipeline import message_cycle_manager
|
||||
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
|
||||
from models.model import AppMode, Conversation, Message
|
||||
|
||||
|
||||
def _make_app_config() -> WorkflowUIBasedAppConfig:
|
||||
return WorkflowUIBasedAppConfig(
|
||||
tenant_id="tenant-id",
|
||||
app_id="app-id",
|
||||
app_mode=AppMode.ADVANCED_CHAT,
|
||||
workflow_id="workflow-id",
|
||||
additional_features=AppAdditionalFeatures(),
|
||||
variables=[],
|
||||
)
|
||||
|
||||
|
||||
def _make_generate_entity(app_config: WorkflowUIBasedAppConfig) -> AdvancedChatAppGenerateEntity:
|
||||
return AdvancedChatAppGenerateEntity(
|
||||
task_id="task-id",
|
||||
app_config=app_config,
|
||||
file_upload_config=None,
|
||||
conversation_id=None,
|
||||
inputs={},
|
||||
query="hello",
|
||||
files=[],
|
||||
parent_message_id=None,
|
||||
user_id="user-id",
|
||||
stream=True,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
extras={},
|
||||
workflow_run_id="workflow-run-id",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _mock_db_session(monkeypatch):
|
||||
session = MagicMock()
|
||||
|
||||
def refresh_side_effect(obj):
|
||||
if isinstance(obj, Conversation) and obj.id is None:
|
||||
obj.id = "generated-conversation-id"
|
||||
if isinstance(obj, Message) and obj.id is None:
|
||||
obj.id = "generated-message-id"
|
||||
|
||||
session.refresh.side_effect = refresh_side_effect
|
||||
session.add.return_value = None
|
||||
session.commit.return_value = None
|
||||
|
||||
monkeypatch.setattr(message_based_app_generator, "db", SimpleNamespace(session=session))
|
||||
return session
|
||||
|
||||
|
||||
def test_init_generate_records_sets_conversation_metadata():
|
||||
app_config = _make_app_config()
|
||||
entity = _make_generate_entity(app_config)
|
||||
|
||||
generator = AdvancedChatAppGenerator()
|
||||
|
||||
conversation, _ = generator._init_generate_records(entity, conversation=None)
|
||||
|
||||
assert entity.conversation_id == "generated-conversation-id"
|
||||
assert conversation.id == "generated-conversation-id"
|
||||
assert entity.is_new_conversation is True
|
||||
|
||||
|
||||
def test_init_generate_records_marks_existing_conversation():
|
||||
app_config = _make_app_config()
|
||||
entity = _make_generate_entity(app_config)
|
||||
|
||||
existing_conversation = Conversation(
|
||||
app_id=app_config.app_id,
|
||||
app_model_config_id=None,
|
||||
model_provider=None,
|
||||
override_model_configs=None,
|
||||
model_id=None,
|
||||
mode=app_config.app_mode.value,
|
||||
name="existing",
|
||||
inputs={},
|
||||
introduction="",
|
||||
system_instruction="",
|
||||
system_instruction_tokens=0,
|
||||
status="normal",
|
||||
invoke_from=InvokeFrom.WEB_APP.value,
|
||||
from_source="api",
|
||||
from_end_user_id="user-id",
|
||||
from_account_id=None,
|
||||
)
|
||||
existing_conversation.id = "existing-conversation-id"
|
||||
|
||||
generator = AdvancedChatAppGenerator()
|
||||
|
||||
conversation, _ = generator._init_generate_records(entity, conversation=existing_conversation)
|
||||
|
||||
assert entity.conversation_id == "existing-conversation-id"
|
||||
assert conversation is existing_conversation
|
||||
assert entity.is_new_conversation is False
|
||||
|
||||
|
||||
def test_message_cycle_manager_uses_new_conversation_flag(monkeypatch):
|
||||
app_config = _make_app_config()
|
||||
entity = _make_generate_entity(app_config)
|
||||
entity.conversation_id = "existing-conversation-id"
|
||||
entity.is_new_conversation = True
|
||||
entity.extras = {"auto_generate_conversation_name": True}
|
||||
|
||||
captured = {}
|
||||
|
||||
class DummyThread:
|
||||
def __init__(self, **kwargs):
|
||||
self.kwargs = kwargs
|
||||
self.started = False
|
||||
|
||||
def start(self):
|
||||
self.started = True
|
||||
|
||||
def fake_thread(**kwargs):
|
||||
thread = DummyThread(**kwargs)
|
||||
captured["thread"] = thread
|
||||
return thread
|
||||
|
||||
monkeypatch.setattr(message_cycle_manager, "Thread", fake_thread)
|
||||
|
||||
manager = MessageCycleManager(application_generate_entity=entity, task_state=MagicMock())
|
||||
thread = manager.generate_conversation_name(conversation_id="existing-conversation-id", query="hello")
|
||||
|
||||
assert thread is captured["thread"]
|
||||
assert thread.started is True
|
||||
assert entity.is_new_conversation is False
|
||||
148
api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py
Normal file
148
api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py
Normal file
@ -0,0 +1,148 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
|
||||
from core.app.apps.workflow.app_runner import WorkflowAppRunner
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import QueueWorkflowPausedEvent
|
||||
from core.app.entities.task_entities import HumanInputRequiredResponse, WorkflowPauseStreamResponse
|
||||
from core.workflow.entities.pause_reason import HumanInputRequired
|
||||
from core.workflow.graph_events.graph import GraphRunPausedEvent
|
||||
from core.workflow.nodes.human_input.entities import FormInput, FormInputType, UserAction
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from models.account import Account
|
||||
|
||||
|
||||
class _RecordingWorkflowAppRunner(WorkflowAppRunner):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.published_events = []
|
||||
|
||||
def _publish_event(self, event):
|
||||
self.published_events.append(event)
|
||||
|
||||
|
||||
class _FakeRuntimeState:
|
||||
def get_paused_nodes(self):
|
||||
return ["node-pause-1"]
|
||||
|
||||
|
||||
def _build_runner():
|
||||
app_entity = SimpleNamespace(
|
||||
app_config=SimpleNamespace(app_id="app-id"),
|
||||
inputs={},
|
||||
files=[],
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
single_iteration_run=None,
|
||||
single_loop_run=None,
|
||||
workflow_execution_id="run-id",
|
||||
user_id="user-id",
|
||||
)
|
||||
workflow = SimpleNamespace(
|
||||
graph_dict={},
|
||||
tenant_id="tenant-id",
|
||||
environment_variables={},
|
||||
id="workflow-id",
|
||||
)
|
||||
queue_manager = SimpleNamespace(publish=lambda event, pub_from: None)
|
||||
return _RecordingWorkflowAppRunner(
|
||||
application_generate_entity=app_entity,
|
||||
queue_manager=queue_manager,
|
||||
variable_loader=MagicMock(),
|
||||
workflow=workflow,
|
||||
system_user_id="sys-user",
|
||||
root_node_id=None,
|
||||
workflow_execution_repository=MagicMock(),
|
||||
workflow_node_execution_repository=MagicMock(),
|
||||
graph_engine_layers=(),
|
||||
graph_runtime_state=None,
|
||||
)
|
||||
|
||||
|
||||
def test_graph_run_paused_event_emits_queue_pause_event():
|
||||
runner = _build_runner()
|
||||
reason = HumanInputRequired(
|
||||
form_id="form-1",
|
||||
form_content="content",
|
||||
inputs=[],
|
||||
actions=[],
|
||||
node_id="node-human",
|
||||
node_title="Human Step",
|
||||
web_app_form_token="tok",
|
||||
)
|
||||
event = GraphRunPausedEvent(reasons=[reason], outputs={"foo": "bar"})
|
||||
workflow_entry = SimpleNamespace(
|
||||
graph_engine=SimpleNamespace(graph_runtime_state=_FakeRuntimeState()),
|
||||
)
|
||||
|
||||
runner._handle_event(workflow_entry, event)
|
||||
|
||||
assert len(runner.published_events) == 1
|
||||
queue_event = runner.published_events[0]
|
||||
assert isinstance(queue_event, QueueWorkflowPausedEvent)
|
||||
assert queue_event.reasons == [reason]
|
||||
assert queue_event.outputs == {"foo": "bar"}
|
||||
assert queue_event.paused_nodes == ["node-pause-1"]
|
||||
|
||||
|
||||
def _build_converter():
|
||||
application_generate_entity = SimpleNamespace(
|
||||
inputs={},
|
||||
files=[],
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
app_config=SimpleNamespace(app_id="app-id", tenant_id="tenant-id"),
|
||||
)
|
||||
system_variables = SystemVariable(
|
||||
user_id="user",
|
||||
app_id="app-id",
|
||||
workflow_id="workflow-id",
|
||||
workflow_execution_id="run-id",
|
||||
)
|
||||
user = MagicMock(spec=Account)
|
||||
user.id = "account-id"
|
||||
user.name = "Tester"
|
||||
user.email = "tester@example.com"
|
||||
return WorkflowResponseConverter(
|
||||
application_generate_entity=application_generate_entity,
|
||||
user=user,
|
||||
system_variables=system_variables,
|
||||
)
|
||||
|
||||
|
||||
def test_queue_workflow_paused_event_to_stream_responses():
|
||||
converter = _build_converter()
|
||||
converter.workflow_start_to_stream_response(task_id="task", workflow_run_id="run-id", workflow_id="workflow-id")
|
||||
|
||||
reason = HumanInputRequired(
|
||||
form_id="form-1",
|
||||
form_content="Rendered",
|
||||
inputs=[
|
||||
FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="field", placeholder=None),
|
||||
],
|
||||
actions=[UserAction(id="approve", title="Approve")],
|
||||
node_id="node-id",
|
||||
node_title="Human Step",
|
||||
web_app_form_token="token",
|
||||
)
|
||||
queue_event = QueueWorkflowPausedEvent(
|
||||
reasons=[reason],
|
||||
outputs={"answer": "value"},
|
||||
paused_nodes=["node-id"],
|
||||
)
|
||||
|
||||
responses = converter.workflow_pause_to_stream_response(event=queue_event, task_id="task")
|
||||
|
||||
assert isinstance(responses[-1], WorkflowPauseStreamResponse)
|
||||
pause_resp = responses[-1]
|
||||
assert pause_resp.workflow_run_id == "run-id"
|
||||
assert pause_resp.data.paused_nodes == ["node-id"]
|
||||
assert pause_resp.data.outputs == {"answer": "value"}
|
||||
assert pause_resp.data.reasons[0]["form_id"] == "form-1"
|
||||
|
||||
assert isinstance(responses[0], HumanInputRequiredResponse)
|
||||
hi_resp = responses[0]
|
||||
assert hi_resp.data.form_id == "form-1"
|
||||
assert hi_resp.data.node_id == "node-id"
|
||||
assert hi_resp.data.node_title == "Human Step"
|
||||
assert hi_resp.data.inputs[0].output_variable_name == "field"
|
||||
assert hi_resp.data.actions[0].id == "approve"
|
||||
@ -0,0 +1,141 @@
|
||||
import json
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.app_config.entities import WorkflowUIBasedAppConfig
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
AdvancedChatAppGenerateEntity,
|
||||
InvokeFrom,
|
||||
WorkflowAppGenerateEntity,
|
||||
)
|
||||
from core.app.layers.pause_state_persist_layer import (
|
||||
WorkflowResumptionContext,
|
||||
_AdvancedChatAppGenerateEntityWrapper,
|
||||
_WorkflowGenerateEntityWrapper,
|
||||
)
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
class TraceQueueManagerStub(TraceQueueManager):
|
||||
"""Minimal TraceQueueManager stub that avoids Flask dependencies."""
|
||||
|
||||
def __init__(self):
|
||||
# Skip parent initialization to avoid starting timers or accessing Flask globals.
|
||||
pass
|
||||
|
||||
|
||||
def _build_workflow_app_config(app_mode: AppMode) -> WorkflowUIBasedAppConfig:
|
||||
return WorkflowUIBasedAppConfig(
|
||||
tenant_id="tenant-id",
|
||||
app_id="app-id",
|
||||
app_mode=app_mode,
|
||||
workflow_id=f"{app_mode.value}-workflow-id",
|
||||
)
|
||||
|
||||
|
||||
def _create_workflow_generate_entity(trace_manager: TraceQueueManager | None = None) -> WorkflowAppGenerateEntity:
|
||||
return WorkflowAppGenerateEntity(
|
||||
task_id="workflow-task",
|
||||
app_config=_build_workflow_app_config(AppMode.WORKFLOW),
|
||||
inputs={"topic": "serialization"},
|
||||
files=[],
|
||||
user_id="user-workflow",
|
||||
stream=True,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=1,
|
||||
trace_manager=trace_manager,
|
||||
workflow_execution_id="workflow-exec-id",
|
||||
extras={"external_trace_id": "trace-id"},
|
||||
)
|
||||
|
||||
|
||||
def _create_advanced_chat_generate_entity(trace_manager: TraceQueueManager | None = None) -> AdvancedChatAppGenerateEntity:
|
||||
return AdvancedChatAppGenerateEntity(
|
||||
task_id="advanced-task",
|
||||
app_config=_build_workflow_app_config(AppMode.ADVANCED_CHAT),
|
||||
conversation_id="conversation-id",
|
||||
inputs={"topic": "roundtrip"},
|
||||
files=[],
|
||||
user_id="user-advanced",
|
||||
stream=False,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
query="Explain serialization",
|
||||
extras={"auto_generate_conversation_name": True},
|
||||
trace_manager=trace_manager,
|
||||
workflow_run_id="workflow-run-id",
|
||||
)
|
||||
|
||||
|
||||
def test_workflow_app_generate_entity_roundtrip_excludes_trace_manager():
|
||||
entity = _create_workflow_generate_entity(trace_manager=TraceQueueManagerStub())
|
||||
|
||||
serialized = entity.model_dump_json()
|
||||
payload = json.loads(serialized)
|
||||
|
||||
assert "trace_manager" not in payload
|
||||
|
||||
restored = WorkflowAppGenerateEntity.model_validate_json(serialized)
|
||||
|
||||
assert restored.model_dump() == entity.model_dump()
|
||||
assert restored.trace_manager is None
|
||||
|
||||
|
||||
def test_advanced_chat_generate_entity_roundtrip_excludes_trace_manager():
|
||||
entity = _create_advanced_chat_generate_entity(trace_manager=TraceQueueManagerStub())
|
||||
|
||||
serialized = entity.model_dump_json()
|
||||
payload = json.loads(serialized)
|
||||
|
||||
assert "trace_manager" not in payload
|
||||
|
||||
restored = AdvancedChatAppGenerateEntity.model_validate_json(serialized)
|
||||
|
||||
assert restored.model_dump() == entity.model_dump()
|
||||
assert restored.trace_manager is None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ResumptionContextCase:
|
||||
name: str
|
||||
context_factory: Callable[[], tuple[WorkflowResumptionContext, type]]
|
||||
|
||||
|
||||
def _workflow_resumption_case() -> tuple[WorkflowResumptionContext, type]:
|
||||
entity = _create_workflow_generate_entity(trace_manager=TraceQueueManagerStub())
|
||||
context = WorkflowResumptionContext(
|
||||
serialized_graph_runtime_state=json.dumps({"state": "workflow"}),
|
||||
generate_entity=_WorkflowGenerateEntityWrapper(entity=entity),
|
||||
)
|
||||
return context, WorkflowAppGenerateEntity
|
||||
|
||||
|
||||
def _advanced_chat_resumption_case() -> tuple[WorkflowResumptionContext, type]:
|
||||
entity = _create_advanced_chat_generate_entity(trace_manager=TraceQueueManagerStub())
|
||||
context = WorkflowResumptionContext(
|
||||
serialized_graph_runtime_state=json.dumps({"state": "advanced"}),
|
||||
generate_entity=_AdvancedChatAppGenerateEntityWrapper(entity=entity),
|
||||
)
|
||||
return context, AdvancedChatAppGenerateEntity
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
[
|
||||
pytest.param(ResumptionContextCase("workflow", _workflow_resumption_case), id="workflow"),
|
||||
pytest.param(ResumptionContextCase("advanced_chat", _advanced_chat_resumption_case), id="advanced_chat"),
|
||||
],
|
||||
)
|
||||
def test_workflow_resumption_context_roundtrip(case: ResumptionContextCase):
|
||||
context, expected_type = case.context_factory()
|
||||
|
||||
serialized = context.dumps()
|
||||
restored = WorkflowResumptionContext.loads(serialized)
|
||||
|
||||
assert restored.serialized_graph_runtime_state == context.serialized_graph_runtime_state
|
||||
entity = restored.get_generate_entity()
|
||||
assert isinstance(entity, expected_type)
|
||||
assert entity.model_dump() == context.get_generate_entity().model_dump()
|
||||
assert entity.trace_manager is None
|
||||
@ -34,6 +34,7 @@ class _InMemoryFormRecipient(HumanInputFormRecipientEntity):
|
||||
@dataclass
|
||||
class _InMemoryFormEntity(HumanInputFormEntity):
|
||||
form_id: str
|
||||
rendered: str
|
||||
token: str | None = None
|
||||
|
||||
@property
|
||||
@ -48,6 +49,10 @@ class _InMemoryFormEntity(HumanInputFormEntity):
|
||||
def recipients(self) -> list[HumanInputFormRecipientEntity]:
|
||||
return []
|
||||
|
||||
@property
|
||||
def rendered_content(self) -> str:
|
||||
return self.rendered
|
||||
|
||||
|
||||
class _InMemoryFormSubmission(FormSubmission):
|
||||
def __init__(self, selected_action_id: str, form_data: Mapping[str, Any]) -> None:
|
||||
@ -76,7 +81,7 @@ class InMemoryHumanInputFormRepository(HumanInputFormRepository):
|
||||
self.created_params.append(params)
|
||||
self._form_counter += 1
|
||||
form_id = f"form-{self._form_counter}"
|
||||
entity = _InMemoryFormEntity(form_id=form_id, token=f"token-{form_id}")
|
||||
entity = _InMemoryFormEntity(form_id=form_id, rendered=params.rendered_content, token=f"token-{form_id}")
|
||||
self.created_forms.append(entity)
|
||||
self._forms_by_key[(params.workflow_execution_id, params.node_id)] = entity
|
||||
return entity
|
||||
|
||||
@ -248,6 +248,7 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None:
|
||||
mock_form_entity.id = "test_form_id"
|
||||
mock_form_entity.web_app_token = "test_web_app_token"
|
||||
mock_form_entity.recipients = []
|
||||
mock_form_entity.rendered_content = "rendered"
|
||||
mock_create_repo.create_form.return_value = mock_form_entity
|
||||
|
||||
def initial_graph_factory() -> tuple[Graph, GraphRuntimeState]:
|
||||
|
||||
@ -193,6 +193,7 @@ def test_human_input_llm_streaming_order_across_pause() -> None:
|
||||
mock_form_entity.id = "test_form_id"
|
||||
mock_form_entity.web_app_token = "test_web_app_token"
|
||||
mock_form_entity.recipients = []
|
||||
mock_form_entity.rendered_content = "rendered"
|
||||
mock_create_repo.create_form.return_value = mock_form_entity
|
||||
|
||||
def graph_factory() -> tuple[Graph, GraphRuntimeState]:
|
||||
|
||||
@ -52,6 +52,7 @@ def _mock_form_repository_with_submission(action_id: str) -> HumanInputFormRepos
|
||||
form_entity.id = "test-form-id"
|
||||
form_entity.web_app_token = "test-form-token"
|
||||
form_entity.recipients = []
|
||||
form_entity.rendered_content = "rendered"
|
||||
repo.get_form.return_value = form_entity
|
||||
return repo
|
||||
|
||||
@ -63,6 +64,7 @@ def _mock_form_repository_without_submission() -> HumanInputFormRepository:
|
||||
form_entity.id = "test-form-id"
|
||||
form_entity.web_app_token = "test-form-token"
|
||||
form_entity.recipients = []
|
||||
form_entity.rendered_content = "rendered"
|
||||
repo.create_form.return_value = form_entity
|
||||
repo.get_form.return_value = None
|
||||
return repo
|
||||
|
||||
38
api/tests/unit_tests/tasks/test_workflow_execute_task.py
Normal file
38
api/tests/unit_tests/tasks/test_workflow_execute_task.py
Normal file
@ -0,0 +1,38 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from tasks.app_generate.workflow_execute_task import _publish_streaming_response
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_topic(mocker) -> MagicMock:
|
||||
topic = MagicMock()
|
||||
mocker.patch(
|
||||
"tasks.app_generate.workflow_execute_task.AdvancedChatAppGenerator.get_response_topic",
|
||||
return_value=topic,
|
||||
)
|
||||
return topic
|
||||
|
||||
|
||||
def test_publish_streaming_response_with_uuid(mock_topic: MagicMock):
|
||||
workflow_run_id = uuid.uuid4()
|
||||
response_stream = iter([{"event": "foo"}, "ping"])
|
||||
|
||||
_publish_streaming_response(response_stream, workflow_run_id)
|
||||
|
||||
payloads = [call.args[0] for call in mock_topic.publish.call_args_list]
|
||||
assert payloads == [json.dumps({"event": "foo"}).encode(), json.dumps("ping").encode()]
|
||||
|
||||
|
||||
def test_publish_streaming_response_coerces_string_uuid(mock_topic: MagicMock):
|
||||
workflow_run_id = uuid.uuid4()
|
||||
response_stream = iter([{"event": "bar"}])
|
||||
|
||||
_publish_streaming_response(response_stream, str(workflow_run_id))
|
||||
|
||||
mock_topic.publish.assert_called_once_with(json.dumps({"event": "bar"}).encode())
|
||||
Reference in New Issue
Block a user