feat(api): implement paused status exposure for service api

This commit is contained in:
QuantumGhost
2026-01-27 09:26:16 +08:00
parent 116ec9dd04
commit 7c33e5107b
10 changed files with 155 additions and 12 deletions

View File

@ -33,7 +33,7 @@ from core.workflow.graph_engine.manager import GraphEngineManager
from extensions.ext_database import db
from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model
from libs import helper
from libs.helper import TimestampField
from libs.helper import OptionalTimestampField, TimestampField
from models.model import App, AppMode, EndUser
from repositories.factory import DifyAPIRepositoryFactory
from services.app_generate_service import AppGenerateService
@ -63,17 +63,49 @@ class WorkflowLogQuery(BaseModel):
register_schema_models(service_api_ns, WorkflowRunPayload, WorkflowLogQuery)
class WorkflowRunStatusField(fields.Raw):
def output(self, key, obj, **kwargs):
status = getattr(obj, "status", None)
if hasattr(status, "value"):
return status.value
if isinstance(obj, dict):
value = obj.get(key) or obj.get("status")
if hasattr(value, "value"):
return value.value
return value
return status
class WorkflowRunOutputsField(fields.Raw):
def output(self, key, obj, **kwargs):
status = getattr(obj, "status", None)
status_value = status.value if hasattr(status, "value") else status
if status_value == WorkflowExecutionStatus.PAUSED.value:
return {}
outputs = getattr(obj, "outputs_dict", None)
if outputs is not None:
return outputs or {}
if isinstance(obj, dict):
value = obj.get(key) or obj.get("outputs")
return value or {}
return {}
workflow_run_fields = {
"id": fields.String,
"workflow_id": fields.String,
"status": fields.String,
"status": WorkflowRunStatusField,
"inputs": fields.Raw,
"outputs": fields.Raw,
"outputs": WorkflowRunOutputsField,
"error": fields.String,
"total_steps": fields.Integer,
"total_tokens": fields.Integer,
"created_at": TimestampField,
"finished_at": TimestampField,
"finished_at": OptionalTimestampField,
"elapsed_time": fields.Float,
}

View File

@ -546,9 +546,11 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow paused events."""
validated_state = self._ensure_graph_runtime_initialized()
responses = self._workflow_response_converter.workflow_pause_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
graph_runtime_state=validated_state,
)
for reason in event.reasons:
if isinstance(reason, HumanInputRequired):

View File

@ -282,9 +282,19 @@ class WorkflowResponseConverter:
*,
event: QueueWorkflowPausedEvent,
task_id: str,
graph_runtime_state: GraphRuntimeState,
) -> list[StreamResponse]:
run_id = self._ensure_workflow_run_id()
started_at = self._workflow_started_at
if started_at is None:
raise ValueError(
"workflow_pause_to_stream_response called before workflow_start_to_stream_response",
)
paused_at = naive_utc_now()
elapsed_time = (paused_at - started_at).total_seconds()
encoded_outputs = self._encode_outputs(event.outputs) or {}
if self._application_generate_entity.invoke_from == InvokeFrom.SERVICE_API:
encoded_outputs = {}
pause_reasons = [reason.model_dump(mode="json") for reason in event.reasons]
responses: list[StreamResponse] = []
@ -318,6 +328,11 @@ class WorkflowResponseConverter:
paused_nodes=list(event.paused_nodes),
outputs=encoded_outputs,
reasons=pause_reasons,
status=WorkflowExecutionStatus.PAUSED.value,
created_at=int(started_at.timestamp()),
elapsed_time=elapsed_time,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
),
)
)

View File

@ -10,7 +10,6 @@ from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.common.graph_runtime_state_support import GraphRuntimeStateSupport
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.apps.workflow.errors import WorkflowPausedInBlockingModeError
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import (
AppQueueEvent,
@ -138,7 +137,24 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
if isinstance(stream_response, ErrorStreamResponse):
raise stream_response.err
elif isinstance(stream_response, WorkflowPauseStreamResponse):
raise WorkflowPausedInBlockingModeError()
response = WorkflowAppBlockingResponse(
task_id=self._application_generate_entity.task_id,
workflow_run_id=stream_response.data.workflow_run_id,
data=WorkflowAppBlockingResponse.Data(
id=stream_response.data.workflow_run_id,
workflow_id=self._workflow.id,
status=stream_response.data.status,
outputs=stream_response.data.outputs or {},
error=None,
elapsed_time=stream_response.data.elapsed_time,
total_tokens=stream_response.data.total_tokens,
total_steps=stream_response.data.total_steps,
created_at=stream_response.data.created_at,
finished_at=None,
),
)
return response
elif isinstance(stream_response, WorkflowFinishStreamResponse):
response = WorkflowAppBlockingResponse(
task_id=self._application_generate_entity.task_id,
@ -455,9 +471,11 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
) -> Generator[StreamResponse, None, None]:
"""Handle workflow paused events."""
self._ensure_workflow_initialized()
validated_state = self._ensure_graph_runtime_initialized()
responses = self._workflow_response_converter.workflow_pause_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
graph_runtime_state=validated_state,
)
yield from responses

View File

@ -239,7 +239,7 @@ class WorkflowFinishStreamResponse(StreamResponse):
total_steps: int
created_by: Mapping[str, object] = Field(default_factory=dict)
created_at: int
finished_at: int
finished_at: int | None
exceptions_count: int | None = 0
files: Sequence[Mapping[str, Any]] | None = []
@ -262,6 +262,11 @@ class WorkflowPauseStreamResponse(StreamResponse):
paused_nodes: Sequence[str] = Field(default_factory=list)
outputs: Mapping[str, Any] = Field(default_factory=dict)
reasons: Sequence[Mapping[str, Any]] = Field(default_factory=list)
status: str
created_at: int
elapsed_time: float
total_tokens: int
total_steps: int
event: StreamEvent = StreamEvent.WORKFLOW_PAUSED
workflow_run_id: str

View File

@ -94,6 +94,13 @@ class TimestampField(fields.Raw):
return int(value.timestamp())
class OptionalTimestampField(fields.Raw):
def format(self, value) -> int | None:
if value is None:
return None
return int(value.timestamp())
def email(email):
# Define a regex pattern for email addresses
pattern = r"^[\w\.!#$%&'*+\-/=?^_`{|}~]+@([\w-]+\.)+[\w-]{2,}$"

View File

@ -0,0 +1,39 @@
from types import SimpleNamespace
from controllers.service_api.app.workflow import WorkflowRunOutputsField, WorkflowRunStatusField
from core.workflow.enums import WorkflowExecutionStatus
def test_workflow_run_status_field_with_enum() -> None:
field = WorkflowRunStatusField()
obj = SimpleNamespace(status=WorkflowExecutionStatus.PAUSED)
assert field.output("status", obj) == "paused"
def test_workflow_run_status_field_with_dict() -> None:
field = WorkflowRunStatusField()
payload = {"status": "running"}
assert field.output("status", payload) == "running"
def test_workflow_run_outputs_field_paused_returns_empty() -> None:
field = WorkflowRunOutputsField()
obj = SimpleNamespace(status=WorkflowExecutionStatus.PAUSED, outputs_dict={"foo": "bar"})
assert field.output("outputs", obj) == {}
def test_workflow_run_outputs_field_running_returns_outputs() -> None:
field = WorkflowRunOutputsField()
obj = SimpleNamespace(status=WorkflowExecutionStatus.RUNNING, outputs_dict={"foo": "bar"})
assert field.output("outputs", obj) == {"foo": "bar"}
def test_workflow_run_outputs_field_dict_fallback() -> None:
field = WorkflowRunOutputsField()
payload = {"status": "succeeded", "outputs": {"answer": "ok"}}
assert field.output("outputs", payload) == {"answer": "ok"}

View File

@ -101,7 +101,12 @@ def test_handle_workflow_paused_event_persists_human_input_extra_content() -> No
pipeline._application_generate_entity = SimpleNamespace(task_id="task-1")
pipeline._workflow_response_converter = mock.Mock()
pipeline._workflow_response_converter.workflow_pause_to_stream_response.return_value = []
pipeline._ensure_graph_runtime_initialized = mock.Mock(side_effect=ValueError())
pipeline._ensure_graph_runtime_initialized = mock.Mock(
return_value=SimpleNamespace(
total_tokens=0,
node_run_steps=0,
),
)
pipeline._save_message = mock.Mock()
message = SimpleNamespace(status=MessageStatus.NORMAL)
pipeline._get_message = mock.Mock(return_value=message)
@ -155,7 +160,7 @@ def test_resume_appends_chunks_to_paused_answer() -> None:
answer="before",
status=MessageStatus.PAUSED,
)
user = EndUser.__new__(EndUser)
user = EndUser()
user.id = "user-1"
user.session_id = "session-1"
workflow = SimpleNamespace(id="workflow-1", tenant_id="tenant-1", features_dict={})

View File

@ -138,13 +138,18 @@ def test_queue_workflow_paused_event_to_stream_responses():
paused_nodes=["node-id"],
)
responses = converter.workflow_pause_to_stream_response(event=queue_event, task_id="task")
runtime_state = SimpleNamespace(total_tokens=0, node_run_steps=0)
responses = converter.workflow_pause_to_stream_response(
event=queue_event,
task_id="task",
graph_runtime_state=runtime_state,
)
assert isinstance(responses[-1], WorkflowPauseStreamResponse)
pause_resp = responses[-1]
assert pause_resp.workflow_run_id == "run-id"
assert pause_resp.data.paused_nodes == ["node-id"]
assert pause_resp.data.outputs == {"answer": "value"}
assert pause_resp.data.outputs == {}
assert pause_resp.data.reasons[0]["form_id"] == "form-1"
assert pause_resp.data.reasons[0]["display_in_ui"] is True

View File

@ -1,6 +1,8 @@
from datetime import datetime
import pytest
from libs.helper import extract_tenant_id
from libs.helper import OptionalTimestampField, extract_tenant_id
from models.account import Account
from models.model import EndUser
@ -63,3 +65,16 @@ class TestExtractTenantId:
with pytest.raises(ValueError, match="Invalid user type.*Expected Account or EndUser"):
extract_tenant_id(dict_user)
class TestOptionalTimestampField:
def test_format_returns_none_for_none(self):
field = OptionalTimestampField()
assert field.format(None) is None
def test_format_returns_unix_timestamp_for_datetime(self):
field = OptionalTimestampField()
value = datetime(2024, 1, 2, 3, 4, 5)
assert field.format(value) == int(value.timestamp())