diff --git a/api/apps/memories_app.py b/api/apps/memories_app.py index 66fcabb4c..72e3e5d72 100644 --- a/api/apps/memories_app.py +++ b/api/apps/memories_app.py @@ -21,6 +21,7 @@ from api.db import TenantPermission from api.db.services.memory_service import MemoryService from api.db.services.user_service import UserTenantService from api.db.services.canvas_service import UserCanvasService +from api.db.services.task_service import TaskService from api.db.joint_services.memory_message_service import get_memory_size_cache, judge_system_prompt_is_default from api.utils.api_utils import validate_request, get_request_json, get_error_argument_result, get_json_result from api.utils.memory_utils import format_ret_data_from_memory, get_memory_type_human @@ -220,9 +221,17 @@ async def get_memory_detail(memory_id): messages = MessageService.list_message( memory.tenant_id, memory_id, agent_ids, keywords, page, page_size) agent_name_mapping = {} + extract_task_mapping = {} if messages["message_list"]: agent_list = UserCanvasService.get_basic_info_by_canvas_ids([message["agent_id"] for message in messages["message_list"]]) agent_name_mapping = {agent["id"]: agent["title"] for agent in agent_list} + task_list = TaskService.get_tasks_progress_by_doc_ids([memory_id]) + if task_list: + task_list.sort(key=lambda t: t["create_time"]) # asc, use newer when exist more than one task + for task in task_list: + # the 'digest' field carries the source_id when a task is created, so use 'digest' as key + extract_task_mapping.update({int(task["digest"]): task}) for message in messages["message_list"]: message["agent_name"] = agent_name_mapping.get(message["agent_id"], "Unknown") + message["task"] = extract_task_mapping.get(message["message_id"], {}) return get_json_result(data={"messages": messages, "storage_type": memory.storage_type}, message=True) diff --git a/api/db/joint_services/memory_message_service.py b/api/db/joint_services/memory_message_service.py index 79848cad5..490a16ac2 100644 
--- a/api/db/joint_services/memory_message_service.py +++ b/api/db/joint_services/memory_message_service.py @@ -16,7 +16,6 @@ import logging from typing import List -from api.db.services.task_service import TaskService from common import settings from common.time_utils import current_timestamp, timestamp_to_date, format_iso_8601_to_ymd_hms from common.constants import MemoryType, LLMType @@ -24,6 +23,7 @@ from common.doc_store.doc_store_base import FusionExpr from common.misc_utils import get_uuid from api.db.db_utils import bulk_insert_into_db from api.db.db_models import Task +from api.db.services.task_service import TaskService from api.db.services.memory_service import MemoryService from api.db.services.tenant_llm_service import TenantLLMService from api.db.services.llm_service import LLMBundle @@ -90,13 +90,19 @@ async def save_to_memory(memory_id: str, message_dict: dict): return await embed_and_save(memory, message_list) -async def save_extracted_to_memory_only(memory_id: str, message_dict, source_message_id: int): +async def save_extracted_to_memory_only(memory_id: str, message_dict, source_message_id: int, task_id: str=None): memory = MemoryService.get_by_memory_id(memory_id) if not memory: - return False, f"Memory '{memory_id}' not found." + msg = f"Memory '{memory_id}' not found." + if task_id: + TaskService.update_progress(task_id, {"progress": -1, "progress_msg": timestamp_to_date(current_timestamp())+ " " + msg}) + return False, msg if memory.memory_type == MemoryType.RAW.value: - return True, f"Memory '{memory_id}' don't need to extract." + msg = f"Memory '{memory_id}' don't need to extract." 
+ if task_id: + TaskService.update_progress(task_id, {"progress": 1.0, "progress_msg": timestamp_to_date(current_timestamp())+ " " + msg}) + return True, msg tenant_id = memory.tenant_id extracted_content = await extract_by_llm( @@ -105,7 +111,8 @@ async def save_extracted_to_memory_only(memory_id: str, message_dict, source_mes {"temperature": memory.temperature}, get_memory_type_human(memory.memory_type), message_dict.get("user_input", ""), - message_dict.get("agent_response", "") + message_dict.get("agent_response", ""), + task_id=task_id ) message_list = [{ "message_id": REDIS_CONN.generate_auto_increment_id(namespace="memory"), @@ -122,13 +129,18 @@ async def save_extracted_to_memory_only(memory_id: str, message_dict, source_mes "status": True } for content in extracted_content] if not message_list: - return True, "No memory extracted from raw message." + msg = "No memory extracted from raw message." + if task_id: + TaskService.update_progress(task_id, {"progress": 1.0, "progress_msg": timestamp_to_date(current_timestamp())+ " " + msg}) + return True, msg - return await embed_and_save(memory, message_list) + if task_id: + TaskService.update_progress(task_id, {"progress": 0.5, "progress_msg": timestamp_to_date(current_timestamp())+ " " + f"Extracted {len(message_list)} messages from raw dialogue."}) + return await embed_and_save(memory, message_list, task_id) async def extract_by_llm(tenant_id: str, llm_id: str, extract_conf: dict, memory_type: List[str], user_input: str, - agent_response: str, system_prompt: str = "", user_prompt: str="") -> List[dict]: + agent_response: str, system_prompt: str = "", user_prompt: str="", task_id: str=None) -> List[dict]: llm_type = TenantLLMService.llm_id2llm_type(llm_id) if not llm_type: raise RuntimeError(f"Unknown type of LLM '{llm_id}'") @@ -143,8 +155,12 @@ async def extract_by_llm(tenant_id: str, llm_id: str, extract_conf: dict, memory else: user_prompts.append({"role": "user", "content": 
PromptAssembler.assemble_user_prompt(conversation_content, conversation_time, conversation_time)}) llm = LLMBundle(tenant_id, llm_type, llm_id) + if task_id: + TaskService.update_progress(task_id, {"progress": 0.15, "progress_msg": timestamp_to_date(current_timestamp())+ " " + "Prepared prompts and LLM."}) res = await llm.async_chat(system_prompt, user_prompts, extract_conf) res_json = get_json_result_from_llm_response(res) + if task_id: + TaskService.update_progress(task_id, {"progress": 0.35, "progress_msg": timestamp_to_date(current_timestamp())+ " " + "Get extracted result from LLM."}) return [{ "content": extracted_content["content"], "valid_at": format_iso_8601_to_ymd_hms(extracted_content["valid_at"]), @@ -153,16 +169,23 @@ async def extract_by_llm(tenant_id: str, llm_id: str, extract_conf: dict, memory } for message_type, extracted_content_list in res_json.items() for extracted_content in extracted_content_list] -async def embed_and_save(memory, message_list: list[dict]): +async def embed_and_save(memory, message_list: list[dict], task_id: str=None): embedding_model = LLMBundle(memory.tenant_id, llm_type=LLMType.EMBEDDING, llm_name=memory.embd_id) + if task_id: + TaskService.update_progress(task_id, {"progress": 0.65, "progress_msg": timestamp_to_date(current_timestamp())+ " " + "Prepared embedding model."}) vector_list, _ = embedding_model.encode([msg["content"] for msg in message_list]) for idx, msg in enumerate(message_list): msg["content_embed"] = vector_list[idx] + if task_id: + TaskService.update_progress(task_id, {"progress": 0.85, "progress_msg": timestamp_to_date(current_timestamp())+ " " + "Embedded extracted content."}) vector_dimension = len(vector_list[0]) if not MessageService.has_index(memory.tenant_id, memory.id): created = MessageService.create_index(memory.tenant_id, memory.id, vector_size=vector_dimension) if not created: - return False, "Failed to create message index." + error_msg = "Failed to create message index." 
+ if task_id: + TaskService.update_progress(task_id, {"progress": -1, "progress_msg": timestamp_to_date(current_timestamp())+ " " + error_msg}) + return False, error_msg new_msg_size = sum([MessageService.calculate_message_size(m) for m in message_list]) current_memory_size = get_memory_size_cache(memory.tenant_id, memory.id) @@ -174,11 +197,19 @@ async def embed_and_save(memory, message_list: list[dict]): MessageService.delete_message({"message_id": message_ids_to_delete}, memory.tenant_id, memory.id) decrease_memory_size_cache(memory.id, delete_size) else: - return False, "Failed to insert message into memory. Memory size reached limit and cannot decide which to delete." + error_msg = "Failed to insert message into memory. Memory size reached limit and cannot decide which to delete." + if task_id: + TaskService.update_progress(task_id, {"progress": -1, "progress_msg": timestamp_to_date(current_timestamp())+ " " + error_msg}) + return False, error_msg fail_cases = MessageService.insert_message(message_list, memory.tenant_id, memory.id) if fail_cases: - return False, "Failed to insert message into memory. Details: " + "; ".join(fail_cases) + error_msg = "Failed to insert message into memory. Details: " + "; ".join(fail_cases) + if task_id: + TaskService.update_progress(task_id, {"progress": -1, "progress_msg": timestamp_to_date(current_timestamp())+ " " + error_msg}) + return False, error_msg + if task_id: + TaskService.update_progress(task_id, {"progress": 0.95, "progress_msg": timestamp_to_date(current_timestamp())+ " " + "Saved messages to storage."}) increase_memory_size_cache(memory.id, new_msg_size) return True, "Message saved successfully." 
@@ -379,11 +410,11 @@ async def handle_save_to_memory_task(task_param: dict): memory_id = task_param["memory_id"] source_id = task_param["source_id"] message_dict = task_param["message_dict"] - success, msg = await save_extracted_to_memory_only(memory_id, message_dict, source_id) + success, msg = await save_extracted_to_memory_only(memory_id, message_dict, source_id, task.id) if success: - TaskService.update_progress(task.id, {"progress": 1.0, "progress_msg": msg}) + TaskService.update_progress(task.id, {"progress": 1.0, "progress_msg": timestamp_to_date(current_timestamp())+ " " + msg}) return True, msg logging.error(msg) - TaskService.update_progress(task.id, {"progress": -1, "progress_msg": None}) + TaskService.update_progress(task.id, {"progress": -1, "progress_msg": timestamp_to_date(current_timestamp())+ " " + msg}) return False, msg diff --git a/api/db/services/task_service.py b/api/db/services/task_service.py index 065d2376d..028381b44 100644 --- a/api/db/services/task_service.py +++ b/api/db/services/task_service.py @@ -179,6 +179,40 @@ class TaskService(CommonService): return None return tasks + @classmethod + @DB.connection_context() + def get_tasks_progress_by_doc_ids(cls, doc_ids: list[str]): + """Retrieve all tasks associated with specific documents. + + This method fetches all processing tasks for given document ids, ordered by + creation time. It includes task progress and chunk information. + + Args: + doc_ids (list[str]): The unique identifiers of the documents. + + Returns: + list[dict]: List of task dictionaries containing task details. + Returns None if no tasks are found. 
+ """ + fields = [ + cls.model.id, + cls.model.doc_id, + cls.model.from_page, + cls.model.progress, + cls.model.progress_msg, + cls.model.digest, + cls.model.chunk_ids, + cls.model.create_time + ] + tasks = ( + cls.model.select(*fields).order_by(cls.model.create_time.desc()) + .where(cls.model.doc_id.in_(doc_ids)) + ) + tasks = list(tasks.dicts()) + if not tasks: + return None + return tasks + @classmethod @DB.connection_context() def update_chunk_ids(cls, id: str, chunk_ids: str): diff --git a/memory/services/messages.py b/memory/services/messages.py index 0b41754c8..fe855905c 100644 --- a/memory/services/messages.py +++ b/memory/services/messages.py @@ -17,6 +17,7 @@ import sys from typing import List from common import settings +from common.constants import MemoryType from common.doc_store.doc_store_base import OrderByExpr, MatchExpr @@ -69,15 +70,16 @@ class MessageService: filter_dict["agent_id"] = agent_ids if keywords: filter_dict["session_id"] = keywords + select_fields = [ + "message_id", "message_type", "source_id", "memory_id", "user_id", "agent_id", "session_id", "valid_at", + "invalid_at", "forget_at", "status" + ] order_by = OrderByExpr() order_by.desc("valid_at") res, total_count = settings.msgStoreConn.search( - select_fields=[ - "message_id", "message_type", "source_id", "memory_id", "user_id", "agent_id", "session_id", "valid_at", - "invalid_at", "forget_at", "status" - ], + select_fields=select_fields, highlight_fields=[], - condition=filter_dict, + condition={**filter_dict, "message_type": MemoryType.RAW.name.lower()}, match_expressions=[], order_by=order_by, offset=(page-1)*page_size, limit=page_size, index_names=index, memory_ids=[memory_id], agg_fields=[], hide_forgotten=False @@ -88,12 +90,30 @@ class MessageService: "total_count": 0 } - doc_mapping = settings.msgStoreConn.get_fields(res, [ - "message_id", "message_type", "source_id", "memory_id", "user_id", "agent_id", "session_id", - "valid_at", "invalid_at", "forget_at", "status" - ]) + 
raw_msg_mapping = settings.msgStoreConn.get_fields(res, select_fields) + raw_messages = list(raw_msg_mapping.values()) + extract_filter = {"source_id": [r["message_id"] for r in raw_messages]} + extract_res, _ = settings.msgStoreConn.search( + select_fields=select_fields, + highlight_fields=[], + condition=extract_filter, + match_expressions=[], order_by=order_by, + offset=0, limit=512, + index_names=index, memory_ids=[memory_id], agg_fields=[], hide_forgotten=False + ) + extract_msg = settings.msgStoreConn.get_fields(extract_res, select_fields) + grouped_extract_msg = {} + for msg in extract_msg.values(): + if grouped_extract_msg.get(msg["source_id"]): + grouped_extract_msg[msg["source_id"]].append(msg) + else: + grouped_extract_msg[msg["source_id"]] = [msg] + + for raw_msg in raw_messages: + raw_msg["extract"] = grouped_extract_msg.get(raw_msg["message_id"], []) + return { - "message_list": list(doc_mapping.values()), + "message_list": raw_messages, "total_count": total_count }