Compare commits

..

1 Commits

Author SHA1 Message Date
75ad0e11b7 fix: stream agent app backend deltas 2026-06-18 01:47:46 +08:00
2 changed files with 220 additions and 11 deletions

View File

@ -72,13 +72,48 @@ def publish_text_answer(
both the backend-produced answer and short-circuited answers (moderation /
annotation reply) share the exact same persistence + SSE path.
"""
publish_text_delta(
queue_manager=queue_manager,
model_name=model_name,
delta=answer,
user_query=user_query,
)
publish_message_end(
queue_manager=queue_manager,
model_name=model_name,
answer=answer,
user_query=user_query,
)
def publish_text_delta(
    *,
    queue_manager: AppQueueManager,
    model_name: str,
    delta: str,
    user_query: str | None = None,
) -> None:
    """Publish one assistant text delta through the EasyUI chat pipeline.

    The delta is wrapped in an ``LLMResultChunk`` (index 0, assistant message)
    and published as a ``QueueLLMChunkEvent`` from
    ``PublishFrom.APPLICATION_MANAGER``.

    Args:
        queue_manager: Queue the chunk event is published to.
        model_name: Value stored in the chunk's ``model`` field.
        delta: Text fragment to emit. An empty string publishes nothing.
        user_query: Passed to ``_prompt_messages_from_query`` to build the
            chunk's ``prompt_messages``.
    """
    if not delta:
        # Nothing to stream; avoid emitting an empty chunk event.
        return
    chunk = LLMResultChunk(
        model=model_name,
        prompt_messages=_prompt_messages_from_query(user_query),
        # The chunk carries this call's own fragment (``delta``); ``answer`` is
        # not a parameter of this function.
        delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=delta)),
    )
    queue_manager.publish(QueueLLMChunkEvent(chunk=chunk), PublishFrom.APPLICATION_MANAGER)
def publish_message_end(
*,
queue_manager: AppQueueManager,
model_name: str,
answer: str,
user_query: str | None = None,
) -> None:
"""Publish the terminal assistant result without emitting another delta."""
prompt_messages = _prompt_messages_from_query(user_query)
queue_manager.publish(
QueueMessageEndEvent(
llm_result=LLMResult(
@ -151,7 +186,12 @@ class AgentAppRunner:
)
create_response = self._agent_backend_client.create_run(runtime.request)
terminal = self._consume_stream(create_response.run_id, queue_manager=queue_manager)
terminal, streamed_answer = self._consume_stream(
create_response.run_id,
queue_manager=queue_manager,
model_name=model_name,
query=query,
)
if isinstance(terminal, AgentBackendDeferredToolCallInternalEvent):
# ENG-635: the agent asked a human. End this turn with the question and
@ -175,7 +215,13 @@ class AgentAppRunner:
raise AgentBackendError(str(error))
answer = self._extract_answer(terminal.output)
self._publish_answer(queue_manager=queue_manager, model_name=model_name, answer=answer, query=query)
self._publish_terminal_answer(
queue_manager=queue_manager,
model_name=model_name,
answer=answer,
query=query,
streamed_answer=streamed_answer,
)
self._save_session(
scope=scope,
backend_run_id=terminal.run_id,
@ -272,8 +318,16 @@ class AgentAppRunner:
parts.append(args.markdown)
return "\n\n".join(parts)
def _consume_stream(self, run_id: str, *, queue_manager: AppQueueManager):
def _consume_stream(
self,
run_id: str,
*,
queue_manager: AppQueueManager,
model_name: str,
query: str | None,
):
terminal = None
streamed_answer_parts: list[str] = []
for public_event in self._agent_backend_client.stream_events(run_id):
if queue_manager.is_stopped():
self._cancel_run(run_id)
@ -286,16 +340,23 @@ class AgentAppRunner:
AgentBackendInternalEventType.RUN_STARTED,
AgentBackendInternalEventType.STREAM_EVENT,
):
# Stream deltas are accumulated by the backend into the
# terminal output; token-level forwarding is an S3 refinement.
if isinstance(internal_event, AgentBackendStreamInternalEvent):
text_delta = self._extract_stream_text_delta(internal_event)
if text_delta:
streamed_answer_parts.append(text_delta)
publish_text_delta(
queue_manager=queue_manager,
model_name=model_name,
delta=text_delta,
user_query=query,
)
continue
continue
terminal = internal_event
break
if terminal is not None:
break
return terminal
return terminal, "".join(streamed_answer_parts)
def _cancel_run(self, run_id: str) -> None:
try:
@ -310,6 +371,35 @@ class AgentAppRunner:
# task pipeline streams the chunk over SSE and persists the message.
publish_text_answer(queue_manager=queue_manager, model_name=model_name, answer=answer, user_query=query)
def _publish_terminal_answer(
    self,
    *,
    queue_manager: AppQueueManager,
    model_name: str,
    answer: str,
    query: str | None,
    streamed_answer: str,
) -> None:
    """Close out a successful turn without re-sending text that was already streamed.

    * Nothing streamed: delegate to ``_publish_answer`` and stop.
    * ``answer`` starts with ``streamed_answer``: publish only the part of
      ``answer`` that follows the streamed text.
    * Otherwise, if the two differ: log a warning and publish no extra delta.

    In the last two cases the message-end event is then published with the
    full ``answer``.
    """
    if streamed_answer:
        if answer.startswith(streamed_answer):
            already_sent = len(streamed_answer)
            remainder = answer[already_sent:]
            publish_text_delta(
                queue_manager=queue_manager,
                model_name=model_name,
                delta=remainder,
                user_query=query,
            )
        elif answer != streamed_answer:
            logger.warning(
                "Agent App streamed answer does not match terminal output; "
                "using terminal output for message persistence."
            )
        publish_message_end(
            queue_manager=queue_manager,
            model_name=model_name,
            answer=answer,
            user_query=query,
        )
    else:
        self._publish_answer(
            queue_manager=queue_manager,
            model_name=model_name,
            answer=answer,
            query=query,
        )
def _save_session(
self,
*,
@ -357,5 +447,27 @@ class AgentAppRunner:
return json.dumps(output, ensure_ascii=False)
return json.dumps(output, ensure_ascii=False)
@staticmethod
def _extract_stream_text_delta(event: AgentBackendStreamInternalEvent) -> str | None:
data = event.data
if not isinstance(data, dict):
return None
__all__ = ["AgentAppRunner", "publish_text_answer"]
if data.get("event_kind") == "part_delta":
delta = data.get("delta")
if isinstance(delta, dict) and delta.get("part_delta_kind") == "text":
content_delta = delta.get("content_delta")
if isinstance(content_delta, str):
return content_delta
if data.get("event_kind") == "part_start":
part = data.get("part")
if isinstance(part, dict) and part.get("part_kind") == "text":
content = part.get("content")
if isinstance(content, str):
return content
return None
__all__ = ["AgentAppRunner", "publish_message_end", "publish_text_answer", "publish_text_delta"]

View File

@ -4,6 +4,8 @@ saved, using the deterministic fake backend client (no live stack)."""
from __future__ import annotations
from collections.abc import Iterator
from datetime import UTC, datetime
from types import SimpleNamespace
from typing import Any, override
from unittest.mock import MagicMock
@ -11,7 +13,17 @@ from unittest.mock import MagicMock
import pytest
from agenton.compositor import CompositorSessionSnapshot
from dify_agent.layers.ask_human import AskHumanToolResult
from dify_agent.protocol import CancelRunRequest, CancelRunResponse, RuntimeLayerSpec
from dify_agent.protocol import (
CancelRunRequest,
CancelRunResponse,
PydanticAIStreamRunEvent,
RunEvent,
RunStartedEvent,
RunSucceededEvent,
RunSucceededEventData,
RuntimeLayerSpec,
)
from pydantic_ai.messages import PartDeltaEvent, PartStartEvent, TextPart, TextPartDelta
from clients.agent_backend import (
AgentBackendError,
@ -67,6 +79,58 @@ class _RecordingFakeAgentBackendRunClient(FakeAgentBackendRunClient):
return super().cancel_run(run_id, request=request)
class _StreamingFakeAgentBackendRunClient(FakeAgentBackendRunClient):
    """Fake client whose stream is: run started, two text deltas, run succeeded.

    The deltas are ``"hello "`` and ``"agent"``; the terminal output is
    ``{"text": "hello agent"}``.
    """

    @override
    def stream_events(self, run_id: str, *, after: str | None = None) -> Iterator[RunEvent]:
        del after
        stamp = datetime(2026, 1, 1, tzinfo=UTC)

        yield RunStartedEvent(id="1-0", run_id=run_id, created_at=stamp)

        # Delta events take ids "2-0" and "3-0", directly after the start event.
        for sequence, fragment in enumerate(("hello ", "agent"), start=2):
            yield PydanticAIStreamRunEvent(
                id=f"{sequence}-0",
                run_id=run_id,
                created_at=stamp,
                data=PartDeltaEvent(index=0, delta=TextPartDelta(content_delta=fragment)),
            )

        success_payload = RunSucceededEventData(
            output={"text": "hello agent"},
            session_snapshot=CompositorSessionSnapshot(layers=[]),
        )
        yield RunSucceededEvent(id="4-0", run_id=run_id, created_at=stamp, data=success_payload)
class _StreamingPartStartFakeAgentBackendRunClient(FakeAgentBackendRunClient):
    """Fake client whose only streamed text arrives in a part-start event.

    The part-start event carries ``"hello"`` while the terminal output is
    ``{"text": "hello agent"}``, so the streamed text is a strict prefix of
    the final answer.
    """

    @override
    def stream_events(self, run_id: str, *, after: str | None = None) -> Iterator[RunEvent]:
        del after
        shared: dict[str, Any] = {
            "run_id": run_id,
            "created_at": datetime(2026, 1, 1, tzinfo=UTC),
        }

        yield RunStartedEvent(id="1-0", **shared)
        yield PydanticAIStreamRunEvent(
            id="2-0",
            data=PartStartEvent(index=0, part=TextPart(content="hello")),
            **shared,
        )
        yield RunSucceededEvent(
            id="3-0",
            data=RunSucceededEventData(
                output={"text": "hello agent"},
                session_snapshot=CompositorSessionSnapshot(layers=[]),
            ),
            **shared,
        )
class _FakeSessionStore:
def __init__(
self,
@ -165,9 +229,13 @@ def _message_end(qm: _FakeQueueManager) -> QueueMessageEndEvent:
def _saved_user_query(qm: _FakeQueueManager) -> str:
    """Return the text of the single prompt message on the message-end event.

    The event is looked up with ``_message_end(qm)``. The assertions both
    check the expected shape (a non-``None`` ``llm_result`` with exactly one
    prompt message whose content is a string) and narrow the types for the
    declared ``str`` return.
    """
    llm_result = _message_end(qm).llm_result
    assert llm_result is not None
    prompt_messages = llm_result.prompt_messages
    assert len(prompt_messages) == 1
    content = prompt_messages[0].content
    assert isinstance(content, str)
    return content
def test_successful_turn_publishes_chunk_and_message_end_and_saves_session():
@ -204,6 +272,35 @@ def test_successful_turn_publishes_chunk_and_message_end_and_saves_session():
]
def test_successful_turn_forwards_agent_backend_stream_text_deltas_without_duplicate_terminal_chunk():
    """Streamed deltas are forwarded one-to-one and the final text is not re-sent as a chunk."""
    backend = _StreamingFakeAgentBackendRunClient()
    session_store = _FakeSessionStore()
    queue = _FakeQueueManager()

    _run(_runner(backend, session_store), queue)

    streamed_texts = [
        published.chunk.delta.message.content
        for published in queue.events
        if isinstance(published, QueueLLMChunkEvent)
    ]
    message_ends = [published for published in queue.events if isinstance(published, QueueMessageEndEvent)]

    # Exactly the two backend deltas; no extra chunk repeating the full answer.
    assert streamed_texts == ["hello ", "agent"]
    assert len(message_ends) == 1
    assert message_ends[0].llm_result.message.content == "hello agent"
    assert session_store.saved
def test_successful_turn_forwards_part_start_text_and_publishes_missing_terminal_suffix():
    """Part-start text is forwarded and the unstreamed tail of the answer follows as one chunk."""
    backend = _StreamingPartStartFakeAgentBackendRunClient()
    session_store = _FakeSessionStore()
    queue = _FakeQueueManager()

    _run(_runner(backend, session_store), queue)

    streamed_texts = [
        published.chunk.delta.message.content
        for published in queue.events
        if isinstance(published, QueueLLMChunkEvent)
    ]
    message_ends = [published for published in queue.events if isinstance(published, QueueMessageEndEvent)]

    # "hello" came from the stream; " agent" is the suffix the terminal output adds.
    assert streamed_texts == ["hello", " agent"]
    assert len(message_ends) == 1
    assert message_ends[0].llm_result.message.content == "hello agent"
def test_prior_session_snapshot_is_threaded_into_request():
prior = CompositorSessionSnapshot(layers=[])
client = FakeAgentBackendRunClient()