fix(api): resolve multi-turn memory failure in Agent apps

- Auto-resolve parent_message_id when not provided by client,
  querying the latest message in the conversation to maintain
  the thread chain that extract_thread_messages() relies on.
- Add AppMode.AGENT to TokenBufferMemory mode checks so file
  attachments in memory are handled via the workflow branch.
- Add debug logging for memory injection in node_factory and node.

Made-with: Cursor
This commit is contained in:
Yansong Zhang
2026-04-09 16:27:38 +08:00
parent e2e16772a1
commit 2de2a8fd3a
4 changed files with 99 additions and 16 deletions

View File

@ -177,6 +177,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# always enable retriever resource in debugger mode
app_config.additional_features.show_retrieve_source = True # type: ignore
# Resolve parent_message_id for thread continuity
if invoke_from == InvokeFrom.SERVICE_API:
parent_message_id: str | None = UUID_NIL
else:
parent_message_id = args.get("parent_message_id")
if not parent_message_id and conversation:
parent_message_id = self._resolve_latest_message_id(conversation.id)
# init application generate entity
application_generate_entity = AdvancedChatAppGenerateEntity(
task_id=str(uuid.uuid4()),
@ -188,7 +196,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
),
query=query,
files=list(file_objs),
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
parent_message_id=parent_message_id,
user_id=user.id,
stream=streaming,
invoke_from=invoke_from,
@ -689,3 +697,17 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
else:
logger.exception("Failed to process generate task pipeline, conversation_id: %s", conversation.id)
raise e
@staticmethod
def _resolve_latest_message_id(conversation_id: str) -> str | None:
    """Resolve ``parent_message_id`` to the newest message in a conversation.

    Used when the client omits ``parent_message_id`` so the message thread
    chain that ``extract_thread_messages()`` walks stays intact.

    :param conversation_id: id of the conversation to inspect.
    :return: id of the most recently created message, or ``None`` when the
        conversation has no messages yet.
    """
    from sqlalchemy import select

    stmt = (
        select(Message.id)
        .where(Message.conversation_id == conversation_id)
        # Tie-break on id so two messages sharing the same created_at
        # timestamp resolve deterministically to a single "latest" row.
        .order_by(Message.created_at.desc(), Message.id.desc())
        .limit(1)
    )
    latest_id = db.session.scalar(stmt)
    return str(latest_id) if latest_id else None

View File

@ -63,7 +63,7 @@ class TokenBufferMemory:
"""
if self.conversation.mode in {AppMode.AGENT_CHAT, AppMode.COMPLETION, AppMode.CHAT}:
file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
elif self.conversation.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
elif self.conversation.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT}:
app = self.conversation.app
if not app:
raise ValueError("App not found for conversation")

View File

@ -402,14 +402,7 @@ class DifyNodeFactory(NodeFactory):
"runtime_support": self._agent_runtime_support,
"message_transformer": self._agent_message_transformer,
},
AGENT_V2_NODE_TYPE: lambda: {
"tool_manager": AgentV2ToolManager(
tenant_id=self._dify_context.tenant_id,
app_id=self._dify_context.app_id,
),
"event_adapter": AgentV2EventAdapter(),
"sandbox": self._resolve_sandbox(),
},
AGENT_V2_NODE_TYPE: lambda: self._build_agent_v2_kwargs(node_data),
}
node_init_kwargs = node_init_kwargs_factories.get(node_type, lambda: {})()
return node_class(
@ -420,6 +413,55 @@ class DifyNodeFactory(NodeFactory):
**node_init_kwargs,
)
def _build_agent_v2_kwargs(self, node_data: BaseNodeData) -> dict[str, object]:
    """Build initialization kwargs for an Agent V2 node.

    Mirrors the LLM node's memory wiring: when the node config enables
    memory and a conversation id is present in the variable pool, a memory
    object is fetched and injected alongside the standard tool_manager,
    event_adapter, and sandbox dependencies.

    :param node_data: raw node data; re-validated as ``AgentV2NodeData``
        to get typed access to ``memory`` and ``model``.
    :return: kwargs consumed by ``AgentV2Node.__init__``.
    """
    import logging

    from core.workflow.nodes.agent_v2.entities import AgentV2NodeData

    logger = logging.getLogger(__name__)
    validated = AgentV2NodeData.model_validate(node_data.model_dump())

    memory = None
    if validated.memory is not None:
        conversation_id = get_system_text(
            self.graph_runtime_state.variable_pool, SystemVariableKey.CONVERSATION_ID
        )
        logger.info(
            "[AGENT_V2_MEMORY] memory_config=%s, conversation_id=%s",
            validated.memory,
            conversation_id,
        )
        if conversation_id:
            from core.model_manager import ModelManager
            from graphon.model_runtime.entities.model_entities import ModelType

            # Resolve the node's own LLM; fetch_memory requires a model
            # instance (presumably for token counting when truncating
            # history — confirm against fetch_memory's contract).
            model_instance = ModelManager.for_tenant(
                tenant_id=self._dify_context.tenant_id
            ).get_model_instance(
                tenant_id=self._dify_context.tenant_id,
                provider=validated.model.provider,
                model_type=ModelType.LLM,
                model=validated.model.name,
            )
            memory = fetch_memory(
                conversation_id=conversation_id,
                app_id=self._dify_context.app_id,
                node_data_memory=validated.memory,
                model_instance=model_instance,
            )

    return {
        "tool_manager": AgentV2ToolManager(
            tenant_id=self._dify_context.tenant_id,
            app_id=self._dify_context.app_id,
        ),
        "event_adapter": AgentV2EventAdapter(),
        "sandbox": self._resolve_sandbox(),
        "memory": memory,
    }
def _resolve_sandbox(self) -> Any:
    """Fetch the sandbox registered in run_context; ``None`` when absent."""
    run_context = self.graph_init_params.run_context
    return run_context.get(DIFY_SANDBOX_CONTEXT_KEY)

View File

@ -73,6 +73,7 @@ class AgentV2Node(Node[AgentV2NodeData]):
tool_manager: AgentV2ToolManager,
event_adapter: AgentV2EventAdapter,
sandbox: Any | None = None,
memory: Any | None = None,
) -> None:
super().__init__(
id=id,
@ -83,6 +84,7 @@ class AgentV2Node(Node[AgentV2NodeData]):
self._tool_manager = tool_manager
self._event_adapter = event_adapter
self._sandbox = sandbox
self._memory = memory
@classmethod
def version(cls) -> str:
@ -332,12 +334,29 @@ class AgentV2Node(Node[AgentV2NodeData]):
resolved = self._resolve_variable_template(text_content, variable_pool)
messages.append(UserPromptMessage(content=resolved))
if self.node_data.memory:
history = self._load_memory_messages(dify_ctx)
if history:
system_msgs = [m for m in messages if isinstance(m, SystemPromptMessage)]
other_msgs = [m for m in messages if not isinstance(m, SystemPromptMessage)]
messages = system_msgs + history + other_msgs
if self._memory is not None:
try:
window_size = None
if self.node_data.memory and hasattr(self.node_data.memory, "window"):
w = self.node_data.memory.window
if w and w.enabled:
window_size = w.size
history = self._memory.get_history_prompt_messages(
max_token_limit=2000,
message_limit=window_size or 50,
)
history_list = list(history)
logger.info("[AGENT_V2_MEMORY] Loaded %d history messages from memory", len(history_list))
if history_list:
system_msgs = [m for m in messages if isinstance(m, SystemPromptMessage)]
other_msgs = [m for m in messages if not isinstance(m, SystemPromptMessage)]
messages = system_msgs + history_list + other_msgs
logger.info("[AGENT_V2_MEMORY] Total prompt messages after memory injection: %d", len(messages))
except Exception:
logger.warning("Failed to load memory for agent-v2 node", exc_info=True)
else:
logger.info("[AGENT_V2_MEMORY] No memory injected (self._memory is None)")
return messages