diff --git a/agent/component/message.py b/agent/component/message.py index 621240bf6..fce572c8b 100644 --- a/agent/component/message.py +++ b/agent/component/message.py @@ -433,7 +433,7 @@ class Message(ComponentBase): return True, "No memory selected." message_dict = { - "user_id": self._canvas._tenant_id, + "user_id": self._param.user_id if hasattr(self._param, "user_id") else "", "agent_id": self._canvas._id, "session_id": self._canvas.task_id, "user_input": self._canvas.get_sys_query(), diff --git a/agent/tools/retrieval.py b/agent/tools/retrieval.py index abf82fa1a..6fde0a2cc 100644 --- a/agent/tools/retrieval.py +++ b/agent/tools/retrieval.py @@ -259,6 +259,7 @@ class Retrieval(ToolBase, ABC): async def _retrieve_memory(self, query_text: str): memory_ids: list[str] = [memory_id for memory_id in self._param.memory_ids] + user_id: str = self._param.user_id if hasattr(self._param, "user_id") else None memory_list = MemoryService.get_by_ids(memory_ids) if not memory_list: raise Exception("No memory is selected.") @@ -270,7 +271,10 @@ class Retrieval(ToolBase, ABC): vars = {k: o["value"] for k, o in vars.items()} query = self.string_format(query_text, vars) # query message - message_list = memory_message_service.query_message({"memory_id": memory_ids}, { + filter_dict: dict = {"memory_id": memory_ids} + if user_id: + filter_dict["user_id"] = user_id + message_list = memory_message_service.query_message(filter_dict, { "query": query, "similarity_threshold": self._param.similarity_threshold, "keywords_similarity_weight": self._param.keywords_similarity_weight, diff --git a/api/apps/restful_apis/memory_api.py b/api/apps/restful_apis/memory_api.py index 7238e480a..79a85d631 100644 --- a/api/apps/restful_apis/memory_api.py +++ b/api/apps/restful_apis/memory_api.py @@ -249,11 +249,13 @@ async def search_message(): top_n = int(args.get("top_n", 5)) agent_id = args.get("agent_id", "") session_id = args.get("session_id", "") + user_id = args.get("user_id", "") filter_dict = { "memory_id": memory_ids, "agent_id": agent_id, - "session_id": session_id + "session_id": session_id, + "user_id": user_id } params = { "query": query, diff --git a/api/apps/services/memory_api_service.py b/api/apps/services/memory_api_service.py index e6aa0de2a..ca6627d5a 100644 --- a/api/apps/services/memory_api_service.py +++ b/api/apps/services/memory_api_service.py @@ -291,6 +291,7 @@ async def search_message(filter_dict: dict, params: dict): "memory_id": list[str], "agent_id": str, "session_id": str + "user_id": str } :param params: { "query": str, diff --git a/api/db/joint_services/memory_message_service.py b/api/db/joint_services/memory_message_service.py index 9c07f5fbb..4765b2bdb 100644 --- a/api/db/joint_services/memory_message_service.py +++ b/api/db/joint_services/memory_message_service.py @@ -66,7 +66,7 @@ async def save_to_memory(memory_id: str, message_dict: dict): "message_type": MemoryType.RAW.name.lower(), "source_id": 0, "memory_id": memory_id, - "user_id": "", + "user_id": message_dict.get("user_id", ""), "agent_id": message_dict["agent_id"], "session_id": message_dict["session_id"], "content": f"User Input: {message_dict.get('user_input')}\nAgent Response: {message_dict.get('agent_response')}", @@ -79,7 +79,7 @@ async def save_to_memory(memory_id: str, message_dict: dict): "message_type": content["message_type"], "source_id": raw_message_id, "memory_id": memory_id, - "user_id": "", + "user_id": message_dict.get("user_id", ""), "agent_id": message_dict["agent_id"], "session_id": message_dict["session_id"], "content": content["content"], @@ -121,7 +121,7 @@ async def save_extracted_to_memory_only(memory_id: str, message_dict, source_mes "message_type": content["message_type"], "source_id": source_message_id, "memory_id": memory_id, - "user_id": "", + "user_id": message_dict.get("user_id", ""), "agent_id": message_dict["agent_id"], "session_id": message_dict["session_id"], "content": content["content"], @@ -227,6 +227,7 @@ def query_message(filter_dict: dict, params: dict): "memory_id": List[str], "agent_id": optional "session_id": optional + "user_id": optional } :param params: { "query": question str, @@ -374,7 +375,7 @@ async def queue_save_to_memory_task(memory_ids: list[str], message_dict: dict): "message_type": MemoryType.RAW.name.lower(), "source_id": 0, "memory_id": memory_id, - "user_id": "", + "user_id": message_dict.get("user_id", ""), "agent_id": message_dict["agent_id"], "session_id": message_dict["session_id"], "content": f"User Input: {message_dict.get('user_input')}\nAgent Response: {message_dict.get('agent_response')}",