mirror of
https://github.com/langgenius/dify.git
synced 2026-04-19 18:27:27 +08:00
fix(api): resolve multi-turn memory failure in Agent apps
- Auto-resolve `parent_message_id` when the client does not provide one, by querying the latest message in the conversation, so the thread chain that `extract_thread_messages()` relies on stays intact.
- Add `AppMode.AGENT` to the `TokenBufferMemory` mode checks so file attachments in memory are handled via the workflow branch.
- Add debug logging for memory injection in `node_factory` and `node`.

Made-with: Cursor
This commit is contained in:
@ -177,6 +177,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
# always enable retriever resource in debugger mode
|
||||
app_config.additional_features.show_retrieve_source = True # type: ignore
|
||||
|
||||
# Resolve parent_message_id for thread continuity
|
||||
if invoke_from == InvokeFrom.SERVICE_API:
|
||||
parent_message_id: str | None = UUID_NIL
|
||||
else:
|
||||
parent_message_id = args.get("parent_message_id")
|
||||
if not parent_message_id and conversation:
|
||||
parent_message_id = self._resolve_latest_message_id(conversation.id)
|
||||
|
||||
# init application generate entity
|
||||
application_generate_entity = AdvancedChatAppGenerateEntity(
|
||||
task_id=str(uuid.uuid4()),
|
||||
@ -188,7 +196,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
),
|
||||
query=query,
|
||||
files=list(file_objs),
|
||||
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
|
||||
parent_message_id=parent_message_id,
|
||||
user_id=user.id,
|
||||
stream=streaming,
|
||||
invoke_from=invoke_from,
|
||||
@ -689,3 +697,17 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
else:
|
||||
logger.exception("Failed to process generate task pipeline, conversation_id: %s", conversation.id)
|
||||
raise e
|
||||
|
||||
@staticmethod
def _resolve_latest_message_id(conversation_id: str) -> str | None:
    """Find the id of the newest message in the given conversation.

    Used to auto-resolve ``parent_message_id`` when the client omits it,
    so the reply is attached to the most recent message in the thread.

    :param conversation_id: id of the conversation to inspect.
    :return: the latest message id as a string, or ``None`` when the
        conversation has no messages yet.
    """
    from sqlalchemy import select

    newest_message_id = db.session.scalar(
        select(Message.id)
        .where(Message.conversation_id == conversation_id)
        .order_by(Message.created_at.desc())
        .limit(1)
    )
    if not newest_message_id:
        return None
    return str(newest_message_id)
|
||||
|
||||
@ -63,7 +63,7 @@ class TokenBufferMemory:
|
||||
"""
|
||||
if self.conversation.mode in {AppMode.AGENT_CHAT, AppMode.COMPLETION, AppMode.CHAT}:
|
||||
file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
|
||||
elif self.conversation.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
elif self.conversation.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT}:
|
||||
app = self.conversation.app
|
||||
if not app:
|
||||
raise ValueError("App not found for conversation")
|
||||
|
||||
@ -402,14 +402,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
"runtime_support": self._agent_runtime_support,
|
||||
"message_transformer": self._agent_message_transformer,
|
||||
},
|
||||
AGENT_V2_NODE_TYPE: lambda: {
|
||||
"tool_manager": AgentV2ToolManager(
|
||||
tenant_id=self._dify_context.tenant_id,
|
||||
app_id=self._dify_context.app_id,
|
||||
),
|
||||
"event_adapter": AgentV2EventAdapter(),
|
||||
"sandbox": self._resolve_sandbox(),
|
||||
},
|
||||
AGENT_V2_NODE_TYPE: lambda: self._build_agent_v2_kwargs(node_data),
|
||||
}
|
||||
node_init_kwargs = node_init_kwargs_factories.get(node_type, lambda: {})()
|
||||
return node_class(
|
||||
@ -420,6 +413,55 @@ class DifyNodeFactory(NodeFactory):
|
||||
**node_init_kwargs,
|
||||
)
|
||||
|
||||
def _build_agent_v2_kwargs(self, node_data: BaseNodeData) -> dict[str, object]:
    """Build initialization kwargs for an Agent V2 node.

    Mirrors the LLM node's memory mechanism: when the node configuration
    enables memory and the current run carries a conversation id, a memory
    object is fetched and injected alongside the usual ``tool_manager``,
    ``event_adapter`` and ``sandbox`` kwargs. When memory is disabled or
    there is no conversation, ``memory`` is ``None``.

    :param node_data: raw node data; re-validated as ``AgentV2NodeData``.
    :return: kwargs dict consumed by the Agent V2 node constructor.
    """
    import logging

    # Function-local import — presumably avoids a circular dependency at
    # module import time; TODO confirm before hoisting to module level.
    from core.workflow.nodes.agent_v2.entities import AgentV2NodeData

    log = logging.getLogger(__name__)

    validated = AgentV2NodeData.model_validate(node_data.model_dump())

    memory = None
    if validated.memory is not None:
        conversation_id = get_system_text(
            self.graph_runtime_state.variable_pool, SystemVariableKey.CONVERSATION_ID
        )
        log.info("[AGENT_V2_MEMORY] memory_config=%s, conversation_id=%s", validated.memory, conversation_id)
        if conversation_id:
            # Lazy imports, same rationale as above.
            from graphon.model_runtime.entities.model_entities import ModelType

            from core.model_manager import ModelManager

            # Resolve the model instance configured on the node; the
            # memory implementation needs it to size/trim history.
            model_instance = ModelManager.for_tenant(
                tenant_id=self._dify_context.tenant_id
            ).get_model_instance(
                tenant_id=self._dify_context.tenant_id,
                provider=validated.model.provider,
                model_type=ModelType.LLM,
                model=validated.model.name,
            )
            memory = fetch_memory(
                conversation_id=conversation_id,
                app_id=self._dify_context.app_id,
                node_data_memory=validated.memory,
                model_instance=model_instance,
            )

    return {
        "tool_manager": AgentV2ToolManager(
            tenant_id=self._dify_context.tenant_id,
            app_id=self._dify_context.app_id,
        ),
        "event_adapter": AgentV2EventAdapter(),
        "sandbox": self._resolve_sandbox(),
        "memory": memory,
    }
|
||||
|
||||
def _resolve_sandbox(self) -> Any:
    """Return the sandbox stored in the graph run context, or ``None`` if absent."""
    run_context = self.graph_init_params.run_context
    return run_context.get(DIFY_SANDBOX_CONTEXT_KEY)
|
||||
|
||||
@ -73,6 +73,7 @@ class AgentV2Node(Node[AgentV2NodeData]):
|
||||
tool_manager: AgentV2ToolManager,
|
||||
event_adapter: AgentV2EventAdapter,
|
||||
sandbox: Any | None = None,
|
||||
memory: Any | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
id=id,
|
||||
@ -83,6 +84,7 @@ class AgentV2Node(Node[AgentV2NodeData]):
|
||||
self._tool_manager = tool_manager
|
||||
self._event_adapter = event_adapter
|
||||
self._sandbox = sandbox
|
||||
self._memory = memory
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
@ -332,12 +334,29 @@ class AgentV2Node(Node[AgentV2NodeData]):
|
||||
resolved = self._resolve_variable_template(text_content, variable_pool)
|
||||
messages.append(UserPromptMessage(content=resolved))
|
||||
|
||||
if self.node_data.memory:
|
||||
history = self._load_memory_messages(dify_ctx)
|
||||
if history:
|
||||
system_msgs = [m for m in messages if isinstance(m, SystemPromptMessage)]
|
||||
other_msgs = [m for m in messages if not isinstance(m, SystemPromptMessage)]
|
||||
messages = system_msgs + history + other_msgs
|
||||
if self._memory is not None:
|
||||
try:
|
||||
window_size = None
|
||||
if self.node_data.memory and hasattr(self.node_data.memory, "window"):
|
||||
w = self.node_data.memory.window
|
||||
if w and w.enabled:
|
||||
window_size = w.size
|
||||
|
||||
history = self._memory.get_history_prompt_messages(
|
||||
max_token_limit=2000,
|
||||
message_limit=window_size or 50,
|
||||
)
|
||||
history_list = list(history)
|
||||
logger.info("[AGENT_V2_MEMORY] Loaded %d history messages from memory", len(history_list))
|
||||
if history_list:
|
||||
system_msgs = [m for m in messages if isinstance(m, SystemPromptMessage)]
|
||||
other_msgs = [m for m in messages if not isinstance(m, SystemPromptMessage)]
|
||||
messages = system_msgs + history_list + other_msgs
|
||||
logger.info("[AGENT_V2_MEMORY] Total prompt messages after memory injection: %d", len(messages))
|
||||
except Exception:
|
||||
logger.warning("Failed to load memory for agent-v2 node", exc_info=True)
|
||||
else:
|
||||
logger.info("[AGENT_V2_MEMORY] No memory injected (self._memory is None)")
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
Reference in New Issue
Block a user