mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 01:18:05 +08:00
feat(api): implement paused status exposure for service api
This commit is contained in:
@ -33,7 +33,7 @@ from core.workflow.graph_engine.manager import GraphEngineManager
|
||||
from extensions.ext_database import db
|
||||
from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model
|
||||
from libs import helper
|
||||
from libs.helper import TimestampField
|
||||
from libs.helper import OptionalTimestampField, TimestampField
|
||||
from models.model import App, AppMode, EndUser
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
from services.app_generate_service import AppGenerateService
|
||||
@ -63,17 +63,49 @@ class WorkflowLogQuery(BaseModel):
|
||||
|
||||
register_schema_models(service_api_ns, WorkflowRunPayload, WorkflowLogQuery)
|
||||
|
||||
|
||||
class WorkflowRunStatusField(fields.Raw):
|
||||
def output(self, key, obj, **kwargs):
|
||||
status = getattr(obj, "status", None)
|
||||
if hasattr(status, "value"):
|
||||
return status.value
|
||||
if isinstance(obj, dict):
|
||||
value = obj.get(key) or obj.get("status")
|
||||
if hasattr(value, "value"):
|
||||
return value.value
|
||||
return value
|
||||
return status
|
||||
|
||||
|
||||
class WorkflowRunOutputsField(fields.Raw):
|
||||
def output(self, key, obj, **kwargs):
|
||||
status = getattr(obj, "status", None)
|
||||
status_value = status.value if hasattr(status, "value") else status
|
||||
if status_value == WorkflowExecutionStatus.PAUSED.value:
|
||||
return {}
|
||||
|
||||
outputs = getattr(obj, "outputs_dict", None)
|
||||
if outputs is not None:
|
||||
return outputs or {}
|
||||
|
||||
if isinstance(obj, dict):
|
||||
value = obj.get(key) or obj.get("outputs")
|
||||
return value or {}
|
||||
|
||||
return {}
|
||||
|
||||
|
||||
workflow_run_fields = {
|
||||
"id": fields.String,
|
||||
"workflow_id": fields.String,
|
||||
"status": fields.String,
|
||||
"status": WorkflowRunStatusField,
|
||||
"inputs": fields.Raw,
|
||||
"outputs": fields.Raw,
|
||||
"outputs": WorkflowRunOutputsField,
|
||||
"error": fields.String,
|
||||
"total_steps": fields.Integer,
|
||||
"total_tokens": fields.Integer,
|
||||
"created_at": TimestampField,
|
||||
"finished_at": TimestampField,
|
||||
"finished_at": OptionalTimestampField,
|
||||
"elapsed_time": fields.Float,
|
||||
}
|
||||
|
||||
|
||||
@ -546,9 +546,11 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
**kwargs,
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle workflow paused events."""
|
||||
validated_state = self._ensure_graph_runtime_initialized()
|
||||
responses = self._workflow_response_converter.workflow_pause_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
graph_runtime_state=validated_state,
|
||||
)
|
||||
for reason in event.reasons:
|
||||
if isinstance(reason, HumanInputRequired):
|
||||
|
||||
@ -282,9 +282,19 @@ class WorkflowResponseConverter:
|
||||
*,
|
||||
event: QueueWorkflowPausedEvent,
|
||||
task_id: str,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
) -> list[StreamResponse]:
|
||||
run_id = self._ensure_workflow_run_id()
|
||||
started_at = self._workflow_started_at
|
||||
if started_at is None:
|
||||
raise ValueError(
|
||||
"workflow_pause_to_stream_response called before workflow_start_to_stream_response",
|
||||
)
|
||||
paused_at = naive_utc_now()
|
||||
elapsed_time = (paused_at - started_at).total_seconds()
|
||||
encoded_outputs = self._encode_outputs(event.outputs) or {}
|
||||
if self._application_generate_entity.invoke_from == InvokeFrom.SERVICE_API:
|
||||
encoded_outputs = {}
|
||||
pause_reasons = [reason.model_dump(mode="json") for reason in event.reasons]
|
||||
|
||||
responses: list[StreamResponse] = []
|
||||
@ -318,6 +328,11 @@ class WorkflowResponseConverter:
|
||||
paused_nodes=list(event.paused_nodes),
|
||||
outputs=encoded_outputs,
|
||||
reasons=pause_reasons,
|
||||
status=WorkflowExecutionStatus.PAUSED.value,
|
||||
created_at=int(started_at.timestamp()),
|
||||
elapsed_time=elapsed_time,
|
||||
total_tokens=graph_runtime_state.total_tokens,
|
||||
total_steps=graph_runtime_state.node_run_steps,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
@ -10,7 +10,6 @@ from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.apps.common.graph_runtime_state_support import GraphRuntimeStateSupport
|
||||
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
|
||||
from core.app.apps.workflow.errors import WorkflowPausedInBlockingModeError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
@ -138,7 +137,24 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
if isinstance(stream_response, ErrorStreamResponse):
|
||||
raise stream_response.err
|
||||
elif isinstance(stream_response, WorkflowPauseStreamResponse):
|
||||
raise WorkflowPausedInBlockingModeError()
|
||||
response = WorkflowAppBlockingResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run_id=stream_response.data.workflow_run_id,
|
||||
data=WorkflowAppBlockingResponse.Data(
|
||||
id=stream_response.data.workflow_run_id,
|
||||
workflow_id=self._workflow.id,
|
||||
status=stream_response.data.status,
|
||||
outputs=stream_response.data.outputs or {},
|
||||
error=None,
|
||||
elapsed_time=stream_response.data.elapsed_time,
|
||||
total_tokens=stream_response.data.total_tokens,
|
||||
total_steps=stream_response.data.total_steps,
|
||||
created_at=stream_response.data.created_at,
|
||||
finished_at=None,
|
||||
),
|
||||
)
|
||||
|
||||
return response
|
||||
elif isinstance(stream_response, WorkflowFinishStreamResponse):
|
||||
response = WorkflowAppBlockingResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
@ -455,9 +471,11 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle workflow paused events."""
|
||||
self._ensure_workflow_initialized()
|
||||
validated_state = self._ensure_graph_runtime_initialized()
|
||||
responses = self._workflow_response_converter.workflow_pause_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
graph_runtime_state=validated_state,
|
||||
)
|
||||
yield from responses
|
||||
|
||||
|
||||
@ -239,7 +239,7 @@ class WorkflowFinishStreamResponse(StreamResponse):
|
||||
total_steps: int
|
||||
created_by: Mapping[str, object] = Field(default_factory=dict)
|
||||
created_at: int
|
||||
finished_at: int
|
||||
finished_at: int | None
|
||||
exceptions_count: int | None = 0
|
||||
files: Sequence[Mapping[str, Any]] | None = []
|
||||
|
||||
@ -262,6 +262,11 @@ class WorkflowPauseStreamResponse(StreamResponse):
|
||||
paused_nodes: Sequence[str] = Field(default_factory=list)
|
||||
outputs: Mapping[str, Any] = Field(default_factory=dict)
|
||||
reasons: Sequence[Mapping[str, Any]] = Field(default_factory=list)
|
||||
status: str
|
||||
created_at: int
|
||||
elapsed_time: float
|
||||
total_tokens: int
|
||||
total_steps: int
|
||||
|
||||
event: StreamEvent = StreamEvent.WORKFLOW_PAUSED
|
||||
workflow_run_id: str
|
||||
|
||||
@ -94,6 +94,13 @@ class TimestampField(fields.Raw):
|
||||
return int(value.timestamp())
|
||||
|
||||
|
||||
class OptionalTimestampField(fields.Raw):
|
||||
def format(self, value) -> int | None:
|
||||
if value is None:
|
||||
return None
|
||||
return int(value.timestamp())
|
||||
|
||||
|
||||
def email(email):
|
||||
# Define a regex pattern for email addresses
|
||||
pattern = r"^[\w\.!#$%&'*+\-/=?^_`{|}~]+@([\w-]+\.)+[\w-]{2,}$"
|
||||
|
||||
@ -0,0 +1,39 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
from controllers.service_api.app.workflow import WorkflowRunOutputsField, WorkflowRunStatusField
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
|
||||
|
||||
def test_workflow_run_status_field_with_enum() -> None:
|
||||
field = WorkflowRunStatusField()
|
||||
obj = SimpleNamespace(status=WorkflowExecutionStatus.PAUSED)
|
||||
|
||||
assert field.output("status", obj) == "paused"
|
||||
|
||||
|
||||
def test_workflow_run_status_field_with_dict() -> None:
|
||||
field = WorkflowRunStatusField()
|
||||
payload = {"status": "running"}
|
||||
|
||||
assert field.output("status", payload) == "running"
|
||||
|
||||
|
||||
def test_workflow_run_outputs_field_paused_returns_empty() -> None:
|
||||
field = WorkflowRunOutputsField()
|
||||
obj = SimpleNamespace(status=WorkflowExecutionStatus.PAUSED, outputs_dict={"foo": "bar"})
|
||||
|
||||
assert field.output("outputs", obj) == {}
|
||||
|
||||
|
||||
def test_workflow_run_outputs_field_running_returns_outputs() -> None:
|
||||
field = WorkflowRunOutputsField()
|
||||
obj = SimpleNamespace(status=WorkflowExecutionStatus.RUNNING, outputs_dict={"foo": "bar"})
|
||||
|
||||
assert field.output("outputs", obj) == {"foo": "bar"}
|
||||
|
||||
|
||||
def test_workflow_run_outputs_field_dict_fallback() -> None:
|
||||
field = WorkflowRunOutputsField()
|
||||
payload = {"status": "succeeded", "outputs": {"answer": "ok"}}
|
||||
|
||||
assert field.output("outputs", payload) == {"answer": "ok"}
|
||||
@ -101,7 +101,12 @@ def test_handle_workflow_paused_event_persists_human_input_extra_content() -> No
|
||||
pipeline._application_generate_entity = SimpleNamespace(task_id="task-1")
|
||||
pipeline._workflow_response_converter = mock.Mock()
|
||||
pipeline._workflow_response_converter.workflow_pause_to_stream_response.return_value = []
|
||||
pipeline._ensure_graph_runtime_initialized = mock.Mock(side_effect=ValueError())
|
||||
pipeline._ensure_graph_runtime_initialized = mock.Mock(
|
||||
return_value=SimpleNamespace(
|
||||
total_tokens=0,
|
||||
node_run_steps=0,
|
||||
),
|
||||
)
|
||||
pipeline._save_message = mock.Mock()
|
||||
message = SimpleNamespace(status=MessageStatus.NORMAL)
|
||||
pipeline._get_message = mock.Mock(return_value=message)
|
||||
@ -155,7 +160,7 @@ def test_resume_appends_chunks_to_paused_answer() -> None:
|
||||
answer="before",
|
||||
status=MessageStatus.PAUSED,
|
||||
)
|
||||
user = EndUser.__new__(EndUser)
|
||||
user = EndUser()
|
||||
user.id = "user-1"
|
||||
user.session_id = "session-1"
|
||||
workflow = SimpleNamespace(id="workflow-1", tenant_id="tenant-1", features_dict={})
|
||||
|
||||
@ -138,13 +138,18 @@ def test_queue_workflow_paused_event_to_stream_responses():
|
||||
paused_nodes=["node-id"],
|
||||
)
|
||||
|
||||
responses = converter.workflow_pause_to_stream_response(event=queue_event, task_id="task")
|
||||
runtime_state = SimpleNamespace(total_tokens=0, node_run_steps=0)
|
||||
responses = converter.workflow_pause_to_stream_response(
|
||||
event=queue_event,
|
||||
task_id="task",
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
assert isinstance(responses[-1], WorkflowPauseStreamResponse)
|
||||
pause_resp = responses[-1]
|
||||
assert pause_resp.workflow_run_id == "run-id"
|
||||
assert pause_resp.data.paused_nodes == ["node-id"]
|
||||
assert pause_resp.data.outputs == {"answer": "value"}
|
||||
assert pause_resp.data.outputs == {}
|
||||
assert pause_resp.data.reasons[0]["form_id"] == "form-1"
|
||||
assert pause_resp.data.reasons[0]["display_in_ui"] is True
|
||||
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
|
||||
from libs.helper import extract_tenant_id
|
||||
from libs.helper import OptionalTimestampField, extract_tenant_id
|
||||
from models.account import Account
|
||||
from models.model import EndUser
|
||||
|
||||
@ -63,3 +65,16 @@ class TestExtractTenantId:
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid user type.*Expected Account or EndUser"):
|
||||
extract_tenant_id(dict_user)
|
||||
|
||||
|
||||
class TestOptionalTimestampField:
|
||||
def test_format_returns_none_for_none(self):
|
||||
field = OptionalTimestampField()
|
||||
|
||||
assert field.format(None) is None
|
||||
|
||||
def test_format_returns_unix_timestamp_for_datetime(self):
|
||||
field = OptionalTimestampField()
|
||||
value = datetime(2024, 1, 2, 3, 4, 5)
|
||||
|
||||
assert field.format(value) == int(value.timestamp())
|
||||
|
||||
Reference in New Issue
Block a user