mirror of
https://github.com/langgenius/dify.git
synced 2026-06-18 05:38:17 +08:00
Compare commits
1 Commits
main
...
fix/dify-2
| Author | SHA1 | Date | |
|---|---|---|---|
| 75ad0e11b7 |
@ -72,13 +72,48 @@ def publish_text_answer(
|
||||
both the backend-produced answer and short-circuited answers (moderation /
|
||||
annotation reply) share the exact same persistence + SSE path.
|
||||
"""
|
||||
publish_text_delta(
|
||||
queue_manager=queue_manager,
|
||||
model_name=model_name,
|
||||
delta=answer,
|
||||
user_query=user_query,
|
||||
)
|
||||
publish_message_end(
|
||||
queue_manager=queue_manager,
|
||||
model_name=model_name,
|
||||
answer=answer,
|
||||
user_query=user_query,
|
||||
)
|
||||
|
||||
|
||||
def publish_text_delta(
    *,
    queue_manager: AppQueueManager,
    model_name: str,
    delta: str,
    user_query: str | None = None,
) -> None:
    """Publish one assistant text delta through the EasyUI chat pipeline.

    A falsy ``delta`` (e.g. the empty string) is dropped and nothing is
    published.

    Args:
        queue_manager: Queue manager the chunk event is published on.
        model_name: Model name recorded on the emitted chunk.
        delta: Text fragment to emit as the chunk's assistant message content.
        user_query: Optional user query handed to
            ``_prompt_messages_from_query`` to build the chunk's prompt
            messages.
    """
    if not delta:
        # Nothing to stream; do not emit an empty chunk event.
        return

    prompt_messages = _prompt_messages_from_query(user_query)
    chunk = LLMResultChunk(
        model=model_name,
        prompt_messages=prompt_messages,
        # The chunk carries only this fragment (``delta``), not a full answer.
        delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=delta)),
    )
    queue_manager.publish(QueueLLMChunkEvent(chunk=chunk), PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
|
||||
def publish_message_end(
|
||||
*,
|
||||
queue_manager: AppQueueManager,
|
||||
model_name: str,
|
||||
answer: str,
|
||||
user_query: str | None = None,
|
||||
) -> None:
|
||||
"""Publish the terminal assistant result without emitting another delta."""
|
||||
prompt_messages = _prompt_messages_from_query(user_query)
|
||||
queue_manager.publish(
|
||||
QueueMessageEndEvent(
|
||||
llm_result=LLMResult(
|
||||
@ -151,7 +186,12 @@ class AgentAppRunner:
|
||||
)
|
||||
|
||||
create_response = self._agent_backend_client.create_run(runtime.request)
|
||||
terminal = self._consume_stream(create_response.run_id, queue_manager=queue_manager)
|
||||
terminal, streamed_answer = self._consume_stream(
|
||||
create_response.run_id,
|
||||
queue_manager=queue_manager,
|
||||
model_name=model_name,
|
||||
query=query,
|
||||
)
|
||||
|
||||
if isinstance(terminal, AgentBackendDeferredToolCallInternalEvent):
|
||||
# ENG-635: the agent asked a human. End this turn with the question and
|
||||
@ -175,7 +215,13 @@ class AgentAppRunner:
|
||||
raise AgentBackendError(str(error))
|
||||
|
||||
answer = self._extract_answer(terminal.output)
|
||||
self._publish_answer(queue_manager=queue_manager, model_name=model_name, answer=answer, query=query)
|
||||
self._publish_terminal_answer(
|
||||
queue_manager=queue_manager,
|
||||
model_name=model_name,
|
||||
answer=answer,
|
||||
query=query,
|
||||
streamed_answer=streamed_answer,
|
||||
)
|
||||
self._save_session(
|
||||
scope=scope,
|
||||
backend_run_id=terminal.run_id,
|
||||
@ -272,8 +318,16 @@ class AgentAppRunner:
|
||||
parts.append(args.markdown)
|
||||
return "\n\n".join(parts)
|
||||
|
||||
def _consume_stream(self, run_id: str, *, queue_manager: AppQueueManager):
|
||||
def _consume_stream(
|
||||
self,
|
||||
run_id: str,
|
||||
*,
|
||||
queue_manager: AppQueueManager,
|
||||
model_name: str,
|
||||
query: str | None,
|
||||
):
|
||||
terminal = None
|
||||
streamed_answer_parts: list[str] = []
|
||||
for public_event in self._agent_backend_client.stream_events(run_id):
|
||||
if queue_manager.is_stopped():
|
||||
self._cancel_run(run_id)
|
||||
@ -286,16 +340,23 @@ class AgentAppRunner:
|
||||
AgentBackendInternalEventType.RUN_STARTED,
|
||||
AgentBackendInternalEventType.STREAM_EVENT,
|
||||
):
|
||||
# Stream deltas are accumulated by the backend into the
|
||||
# terminal output; token-level forwarding is an S3 refinement.
|
||||
if isinstance(internal_event, AgentBackendStreamInternalEvent):
|
||||
text_delta = self._extract_stream_text_delta(internal_event)
|
||||
if text_delta:
|
||||
streamed_answer_parts.append(text_delta)
|
||||
publish_text_delta(
|
||||
queue_manager=queue_manager,
|
||||
model_name=model_name,
|
||||
delta=text_delta,
|
||||
user_query=query,
|
||||
)
|
||||
continue
|
||||
continue
|
||||
terminal = internal_event
|
||||
break
|
||||
if terminal is not None:
|
||||
break
|
||||
return terminal
|
||||
return terminal, "".join(streamed_answer_parts)
|
||||
|
||||
def _cancel_run(self, run_id: str) -> None:
|
||||
try:
|
||||
@ -310,6 +371,35 @@ class AgentAppRunner:
|
||||
# task pipeline streams the chunk over SSE and persists the message.
|
||||
publish_text_answer(queue_manager=queue_manager, model_name=model_name, answer=answer, user_query=query)
|
||||
|
||||
def _publish_terminal_answer(
    self,
    *,
    queue_manager: AppQueueManager,
    model_name: str,
    answer: str,
    query: str | None,
    streamed_answer: str,
) -> None:
    """Finish a successful streamed turn without duplicating the final text.

    Args:
        queue_manager: Queue manager the events are published on.
        model_name: Model name forwarded to the publish helpers.
        answer: Terminal answer; always the text passed to the message-end
            event.
        query: Optional user query forwarded to the publish helpers.
        streamed_answer: Concatenation of the text deltas already published
            while the stream was consumed; empty when nothing was streamed.
    """
    if not streamed_answer:
        # Nothing was streamed: fall back to the non-streaming publish path.
        self._publish_answer(queue_manager=queue_manager, model_name=model_name, answer=answer, query=query)
        return

    if answer.startswith(streamed_answer):
        # Emit only the part of the terminal answer that was not streamed yet.
        # When both are equal the suffix is empty and publish_text_delta
        # ignores it.
        publish_text_delta(
            queue_manager=queue_manager,
            model_name=model_name,
            delta=answer[len(streamed_answer) :],
            user_query=query,
        )
    else:
        # Reaching this branch already implies answer != streamed_answer
        # (a string always starts with itself), so no extra comparison is
        # needed.
        logger.warning(
            "Agent App streamed answer does not match terminal output; "
            "using terminal output for message persistence."
        )

    publish_message_end(queue_manager=queue_manager, model_name=model_name, answer=answer, user_query=query)
|
||||
|
||||
def _save_session(
|
||||
self,
|
||||
*,
|
||||
@ -357,5 +447,27 @@ class AgentAppRunner:
|
||||
return json.dumps(output, ensure_ascii=False)
|
||||
return json.dumps(output, ensure_ascii=False)
|
||||
|
||||
@staticmethod
def _extract_stream_text_delta(event: AgentBackendStreamInternalEvent) -> str | None:
    """Return the text fragment carried by a backend stream event, if any.

    Two payload shapes are recognised, both read from ``event.data`` when it
    is a dict:

    * ``event_kind == "part_delta"`` with a dict ``delta`` whose
      ``part_delta_kind`` is ``"text"``: returns ``delta["content_delta"]``.
    * ``event_kind == "part_start"`` with a dict ``part`` whose ``part_kind``
      is ``"text"``: returns ``part["content"]``.

    Args:
        event: Stream event whose ``data`` attribute is inspected.

    Returns:
        The text as a ``str``, or ``None`` when the payload is not a dict,
        is of another kind, or the text field is not a string.
    """
    data = event.data
    if not isinstance(data, dict):
        return None

    event_kind = data.get("event_kind")

    if event_kind == "part_delta":
        delta = data.get("delta")
        if isinstance(delta, dict) and delta.get("part_delta_kind") == "text":
            content_delta = delta.get("content_delta")
            if isinstance(content_delta, str):
                return content_delta

    if event_kind == "part_start":
        part = data.get("part")
        if isinstance(part, dict) and part.get("part_kind") == "text":
            content = part.get("content")
            if isinstance(content, str):
                return content

    return None
|
||||
|
||||
|
||||
# Names exported by a star-import of this module: the runner class and the
# module-level publish helpers.
__all__ = ["AgentAppRunner", "publish_message_end", "publish_text_answer", "publish_text_delta"]
|
||||
|
||||
@ -4,6 +4,8 @@ saved, using the deterministic fake backend client (no live stack)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, override
|
||||
from unittest.mock import MagicMock
|
||||
@ -11,7 +13,17 @@ from unittest.mock import MagicMock
|
||||
import pytest
|
||||
from agenton.compositor import CompositorSessionSnapshot
|
||||
from dify_agent.layers.ask_human import AskHumanToolResult
|
||||
from dify_agent.protocol import CancelRunRequest, CancelRunResponse, RuntimeLayerSpec
|
||||
from dify_agent.protocol import (
|
||||
CancelRunRequest,
|
||||
CancelRunResponse,
|
||||
PydanticAIStreamRunEvent,
|
||||
RunEvent,
|
||||
RunStartedEvent,
|
||||
RunSucceededEvent,
|
||||
RunSucceededEventData,
|
||||
RuntimeLayerSpec,
|
||||
)
|
||||
from pydantic_ai.messages import PartDeltaEvent, PartStartEvent, TextPart, TextPartDelta
|
||||
|
||||
from clients.agent_backend import (
|
||||
AgentBackendError,
|
||||
@ -67,6 +79,58 @@ class _RecordingFakeAgentBackendRunClient(FakeAgentBackendRunClient):
|
||||
return super().cancel_run(run_id, request=request)
|
||||
|
||||
|
||||
class _StreamingFakeAgentBackendRunClient(FakeAgentBackendRunClient):
    """Fake client whose event stream contains two text part deltas.

    The stream is: run started, delta ``"hello "``, delta ``"agent"``, then a
    success event whose output is ``{"text": "hello agent"}``.
    """

    @override
    def stream_events(self, run_id: str, *, after: str | None = None) -> Iterator[RunEvent]:
        """Yield the fixed event sequence for ``run_id``; ``after`` is ignored."""
        del after
        timestamp = datetime(2026, 1, 1, tzinfo=UTC)

        yield RunStartedEvent(id="1-0", run_id=run_id, created_at=timestamp)

        # Two consecutive text deltas that together spell the terminal output.
        yield PydanticAIStreamRunEvent(
            id="2-0",
            run_id=run_id,
            created_at=timestamp,
            data=PartDeltaEvent(index=0, delta=TextPartDelta(content_delta="hello ")),
        )
        yield PydanticAIStreamRunEvent(
            id="3-0",
            run_id=run_id,
            created_at=timestamp,
            data=PartDeltaEvent(index=0, delta=TextPartDelta(content_delta="agent")),
        )

        yield RunSucceededEvent(
            id="4-0",
            run_id=run_id,
            created_at=timestamp,
            data=RunSucceededEventData(
                output={"text": "hello agent"},
                session_snapshot=CompositorSessionSnapshot(layers=[]),
            ),
        )
|
||||
|
||||
|
||||
class _StreamingPartStartFakeAgentBackendRunClient(FakeAgentBackendRunClient):
    """Fake client whose stream carries text only in a part-start event.

    The stream is: run started, a part-start event with text ``"hello"``,
    then a success event whose output is ``{"text": "hello agent"}``. The
    streamed text is therefore a strict prefix of the terminal output.
    """

    @override
    def stream_events(self, run_id: str, *, after: str | None = None) -> Iterator[RunEvent]:
        """Yield the fixed event sequence for ``run_id``; ``after`` is ignored."""
        del after
        timestamp = datetime(2026, 1, 1, tzinfo=UTC)

        yield RunStartedEvent(id="1-0", run_id=run_id, created_at=timestamp)

        # Text arrives as the initial content of a started part, not as a delta.
        yield PydanticAIStreamRunEvent(
            id="2-0",
            run_id=run_id,
            created_at=timestamp,
            data=PartStartEvent(index=0, part=TextPart(content="hello")),
        )

        yield RunSucceededEvent(
            id="3-0",
            run_id=run_id,
            created_at=timestamp,
            data=RunSucceededEventData(
                output={"text": "hello agent"},
                session_snapshot=CompositorSessionSnapshot(layers=[]),
            ),
        )
|
||||
|
||||
|
||||
class _FakeSessionStore:
|
||||
def __init__(
|
||||
self,
|
||||
@ -165,9 +229,13 @@ def _message_end(qm: _FakeQueueManager) -> QueueMessageEndEvent:
|
||||
|
||||
|
||||
def _saved_user_query(qm: _FakeQueueManager) -> str:
    """Return the string content of the single prompt message on the message-end event.

    Asserts that the message-end event has an LLM result, that it holds
    exactly one prompt message, and that the message content is a ``str``.
    """
    llm_result = _message_end(qm).llm_result
    assert llm_result is not None
    prompt_messages = llm_result.prompt_messages
    assert len(prompt_messages) == 1
    content = prompt_messages[0].content
    # Narrow the content type so the declared ``str`` return type holds.
    assert isinstance(content, str)
    return content
|
||||
|
||||
|
||||
def test_successful_turn_publishes_chunk_and_message_end_and_saves_session():
|
||||
@ -204,6 +272,35 @@ def test_successful_turn_publishes_chunk_and_message_end_and_saves_session():
|
||||
]
|
||||
|
||||
|
||||
def test_successful_turn_forwards_agent_backend_stream_text_deltas_without_duplicate_terminal_chunk():
    """Streamed deltas are forwarded as chunks and the terminal text is not re-sent."""
    backend_client = _StreamingFakeAgentBackendRunClient()
    session_store = _FakeSessionStore()
    queue = _FakeQueueManager()

    _run(_runner(backend_client, session_store), queue)

    streamed_chunks = [ev for ev in queue.events if isinstance(ev, QueueLLMChunkEvent)]
    message_ends = [ev for ev in queue.events if isinstance(ev, QueueMessageEndEvent)]

    # Exactly the two streamed deltas, with no extra chunk for the terminal answer.
    streamed_texts = [ev.chunk.delta.message.content for ev in streamed_chunks]
    assert streamed_texts == ["hello ", "agent"]

    assert len(message_ends) == 1
    assert message_ends[0].llm_result.message.content == "hello agent"
    assert session_store.saved
|
||||
|
||||
|
||||
def test_successful_turn_forwards_part_start_text_and_publishes_missing_terminal_suffix():
    """Part-start text is forwarded and the unstreamed suffix is published as a final chunk."""
    backend_client = _StreamingPartStartFakeAgentBackendRunClient()
    session_store = _FakeSessionStore()
    queue = _FakeQueueManager()

    _run(_runner(backend_client, session_store), queue)

    streamed_chunks = [ev for ev in queue.events if isinstance(ev, QueueLLMChunkEvent)]
    message_ends = [ev for ev in queue.events if isinstance(ev, QueueMessageEndEvent)]

    # "hello" comes from the stream; " agent" is the remainder of the terminal output.
    streamed_texts = [ev.chunk.delta.message.content for ev in streamed_chunks]
    assert streamed_texts == ["hello", " agent"]

    assert len(message_ends) == 1
    assert message_ends[0].llm_result.message.content == "hello agent"
|
||||
|
||||
|
||||
def test_prior_session_snapshot_is_threaded_into_request():
|
||||
prior = CompositorSessionSnapshot(layers=[])
|
||||
client = FakeAgentBackendRunClient()
|
||||
|
||||
Reference in New Issue
Block a user