mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-01-19 11:45:10 +08:00
Feat: process memory (#12445)
### What problem does this PR solve?

Add a task status to each raw message, and move the extracted messages into a nested property under the raw message they were derived from.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -21,6 +21,7 @@ from api.db import TenantPermission
|
||||
from api.db.services.memory_service import MemoryService
|
||||
from api.db.services.user_service import UserTenantService
|
||||
from api.db.services.canvas_service import UserCanvasService
|
||||
from api.db.services.task_service import TaskService
|
||||
from api.db.joint_services.memory_message_service import get_memory_size_cache, judge_system_prompt_is_default
|
||||
from api.utils.api_utils import validate_request, get_request_json, get_error_argument_result, get_json_result
|
||||
from api.utils.memory_utils import format_ret_data_from_memory, get_memory_type_human
|
||||
@ -220,9 +221,17 @@ async def get_memory_detail(memory_id):
|
||||
messages = MessageService.list_message(
|
||||
memory.tenant_id, memory_id, agent_ids, keywords, page, page_size)
|
||||
agent_name_mapping = {}
|
||||
extract_task_mapping = {}
|
||||
if messages["message_list"]:
|
||||
agent_list = UserCanvasService.get_basic_info_by_canvas_ids([message["agent_id"] for message in messages["message_list"]])
|
||||
agent_name_mapping = {agent["id"]: agent["title"] for agent in agent_list}
|
||||
task_list = TaskService.get_tasks_progress_by_doc_ids([memory_id])
|
||||
if task_list:
|
||||
task_list.sort(key=lambda t: t["create_time"]) # asc, use newer when exist more than one task
|
||||
for task in task_list:
|
||||
# the 'digest' field carries the source_id when a task is created, so use 'digest' as key
|
||||
extract_task_mapping.update({int(task["digest"]): task})
|
||||
for message in messages["message_list"]:
|
||||
message["agent_name"] = agent_name_mapping.get(message["agent_id"], "Unknown")
|
||||
message["task"] = extract_task_mapping.get(message["message_id"], {})
|
||||
return get_json_result(data={"messages": messages, "storage_type": memory.storage_type}, message=True)
|
||||
|
||||
@ -16,7 +16,6 @@
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from api.db.services.task_service import TaskService
|
||||
from common import settings
|
||||
from common.time_utils import current_timestamp, timestamp_to_date, format_iso_8601_to_ymd_hms
|
||||
from common.constants import MemoryType, LLMType
|
||||
@ -24,6 +23,7 @@ from common.doc_store.doc_store_base import FusionExpr
|
||||
from common.misc_utils import get_uuid
|
||||
from api.db.db_utils import bulk_insert_into_db
|
||||
from api.db.db_models import Task
|
||||
from api.db.services.task_service import TaskService
|
||||
from api.db.services.memory_service import MemoryService
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
@ -90,13 +90,19 @@ async def save_to_memory(memory_id: str, message_dict: dict):
|
||||
return await embed_and_save(memory, message_list)
|
||||
|
||||
|
||||
async def save_extracted_to_memory_only(memory_id: str, message_dict, source_message_id: int):
|
||||
async def save_extracted_to_memory_only(memory_id: str, message_dict, source_message_id: int, task_id: str=None):
|
||||
memory = MemoryService.get_by_memory_id(memory_id)
|
||||
if not memory:
|
||||
return False, f"Memory '{memory_id}' not found."
|
||||
msg = f"Memory '{memory_id}' not found."
|
||||
if task_id:
|
||||
TaskService.update_progress(task_id, {"progress": -1, "progress_msg": timestamp_to_date(current_timestamp())+ " " + msg})
|
||||
return False, msg
|
||||
|
||||
if memory.memory_type == MemoryType.RAW.value:
|
||||
return True, f"Memory '{memory_id}' don't need to extract."
|
||||
msg = f"Memory '{memory_id}' don't need to extract."
|
||||
if task_id:
|
||||
TaskService.update_progress(task_id, {"progress": 1.0, "progress_msg": timestamp_to_date(current_timestamp())+ " " + msg})
|
||||
return True, msg
|
||||
|
||||
tenant_id = memory.tenant_id
|
||||
extracted_content = await extract_by_llm(
|
||||
@ -105,7 +111,8 @@ async def save_extracted_to_memory_only(memory_id: str, message_dict, source_mes
|
||||
{"temperature": memory.temperature},
|
||||
get_memory_type_human(memory.memory_type),
|
||||
message_dict.get("user_input", ""),
|
||||
message_dict.get("agent_response", "")
|
||||
message_dict.get("agent_response", ""),
|
||||
task_id=task_id
|
||||
)
|
||||
message_list = [{
|
||||
"message_id": REDIS_CONN.generate_auto_increment_id(namespace="memory"),
|
||||
@ -122,13 +129,18 @@ async def save_extracted_to_memory_only(memory_id: str, message_dict, source_mes
|
||||
"status": True
|
||||
} for content in extracted_content]
|
||||
if not message_list:
|
||||
return True, "No memory extracted from raw message."
|
||||
msg = "No memory extracted from raw message."
|
||||
if task_id:
|
||||
TaskService.update_progress(task_id, {"progress": 1.0, "progress_msg": timestamp_to_date(current_timestamp())+ " " + msg})
|
||||
return True, msg
|
||||
|
||||
return await embed_and_save(memory, message_list)
|
||||
if task_id:
|
||||
TaskService.update_progress(task_id, {"progress": 0.5, "progress_msg": timestamp_to_date(current_timestamp())+ " " + f"Extracted {len(message_list)} messages from raw dialogue."})
|
||||
return await embed_and_save(memory, message_list, task_id)
|
||||
|
||||
|
||||
async def extract_by_llm(tenant_id: str, llm_id: str, extract_conf: dict, memory_type: List[str], user_input: str,
|
||||
agent_response: str, system_prompt: str = "", user_prompt: str="") -> List[dict]:
|
||||
agent_response: str, system_prompt: str = "", user_prompt: str="", task_id: str=None) -> List[dict]:
|
||||
llm_type = TenantLLMService.llm_id2llm_type(llm_id)
|
||||
if not llm_type:
|
||||
raise RuntimeError(f"Unknown type of LLM '{llm_id}'")
|
||||
@ -143,8 +155,12 @@ async def extract_by_llm(tenant_id: str, llm_id: str, extract_conf: dict, memory
|
||||
else:
|
||||
user_prompts.append({"role": "user", "content": PromptAssembler.assemble_user_prompt(conversation_content, conversation_time, conversation_time)})
|
||||
llm = LLMBundle(tenant_id, llm_type, llm_id)
|
||||
if task_id:
|
||||
TaskService.update_progress(task_id, {"progress": 0.15, "progress_msg": timestamp_to_date(current_timestamp())+ " " + "Prepared prompts and LLM."})
|
||||
res = await llm.async_chat(system_prompt, user_prompts, extract_conf)
|
||||
res_json = get_json_result_from_llm_response(res)
|
||||
if task_id:
|
||||
TaskService.update_progress(task_id, {"progress": 0.35, "progress_msg": timestamp_to_date(current_timestamp())+ " " + "Get extracted result from LLM."})
|
||||
return [{
|
||||
"content": extracted_content["content"],
|
||||
"valid_at": format_iso_8601_to_ymd_hms(extracted_content["valid_at"]),
|
||||
@ -153,16 +169,23 @@ async def extract_by_llm(tenant_id: str, llm_id: str, extract_conf: dict, memory
|
||||
} for message_type, extracted_content_list in res_json.items() for extracted_content in extracted_content_list]
|
||||
|
||||
|
||||
async def embed_and_save(memory, message_list: list[dict]):
|
||||
async def embed_and_save(memory, message_list: list[dict], task_id: str=None):
|
||||
embedding_model = LLMBundle(memory.tenant_id, llm_type=LLMType.EMBEDDING, llm_name=memory.embd_id)
|
||||
if task_id:
|
||||
TaskService.update_progress(task_id, {"progress": 0.65, "progress_msg": timestamp_to_date(current_timestamp())+ " " + "Prepared embedding model."})
|
||||
vector_list, _ = embedding_model.encode([msg["content"] for msg in message_list])
|
||||
for idx, msg in enumerate(message_list):
|
||||
msg["content_embed"] = vector_list[idx]
|
||||
if task_id:
|
||||
TaskService.update_progress(task_id, {"progress": 0.85, "progress_msg": timestamp_to_date(current_timestamp())+ " " + "Embedded extracted content."})
|
||||
vector_dimension = len(vector_list[0])
|
||||
if not MessageService.has_index(memory.tenant_id, memory.id):
|
||||
created = MessageService.create_index(memory.tenant_id, memory.id, vector_size=vector_dimension)
|
||||
if not created:
|
||||
return False, "Failed to create message index."
|
||||
error_msg = "Failed to create message index."
|
||||
if task_id:
|
||||
TaskService.update_progress(task_id, {"progress": -1, "progress_msg": timestamp_to_date(current_timestamp())+ " " + error_msg})
|
||||
return False, error_msg
|
||||
|
||||
new_msg_size = sum([MessageService.calculate_message_size(m) for m in message_list])
|
||||
current_memory_size = get_memory_size_cache(memory.tenant_id, memory.id)
|
||||
@ -174,11 +197,19 @@ async def embed_and_save(memory, message_list: list[dict]):
|
||||
MessageService.delete_message({"message_id": message_ids_to_delete}, memory.tenant_id, memory.id)
|
||||
decrease_memory_size_cache(memory.id, delete_size)
|
||||
else:
|
||||
return False, "Failed to insert message into memory. Memory size reached limit and cannot decide which to delete."
|
||||
error_msg = "Failed to insert message into memory. Memory size reached limit and cannot decide which to delete."
|
||||
if task_id:
|
||||
TaskService.update_progress(task_id, {"progress": -1, "progress_msg": timestamp_to_date(current_timestamp())+ " " + error_msg})
|
||||
return False, error_msg
|
||||
fail_cases = MessageService.insert_message(message_list, memory.tenant_id, memory.id)
|
||||
if fail_cases:
|
||||
return False, "Failed to insert message into memory. Details: " + "; ".join(fail_cases)
|
||||
error_msg = "Failed to insert message into memory. Details: " + "; ".join(fail_cases)
|
||||
if task_id:
|
||||
TaskService.update_progress(task_id, {"progress": -1, "progress_msg": timestamp_to_date(current_timestamp())+ " " + error_msg})
|
||||
return False, error_msg
|
||||
|
||||
if task_id:
|
||||
TaskService.update_progress(task_id, {"progress": 0.95, "progress_msg": timestamp_to_date(current_timestamp())+ " " + "Saved messages to storage."})
|
||||
increase_memory_size_cache(memory.id, new_msg_size)
|
||||
return True, "Message saved successfully."
|
||||
|
||||
@ -379,11 +410,11 @@ async def handle_save_to_memory_task(task_param: dict):
|
||||
memory_id = task_param["memory_id"]
|
||||
source_id = task_param["source_id"]
|
||||
message_dict = task_param["message_dict"]
|
||||
success, msg = await save_extracted_to_memory_only(memory_id, message_dict, source_id)
|
||||
success, msg = await save_extracted_to_memory_only(memory_id, message_dict, source_id, task.id)
|
||||
if success:
|
||||
TaskService.update_progress(task.id, {"progress": 1.0, "progress_msg": msg})
|
||||
TaskService.update_progress(task.id, {"progress": 1.0, "progress_msg": timestamp_to_date(current_timestamp())+ " " + msg})
|
||||
return True, msg
|
||||
|
||||
logging.error(msg)
|
||||
TaskService.update_progress(task.id, {"progress": -1, "progress_msg": None})
|
||||
TaskService.update_progress(task.id, {"progress": -1, "progress_msg": timestamp_to_date(current_timestamp())+ " " + msg})
|
||||
return False, msg
|
||||
|
||||
@ -179,6 +179,40 @@ class TaskService(CommonService):
|
||||
return None
|
||||
return tasks
|
||||
|
||||
@classmethod
@DB.connection_context()
def get_tasks_progress_by_doc_ids(cls, doc_ids: list[str]):
    """Fetch progress records for every task attached to the given documents.

    Selects a fixed subset of task columns (id, doc_id, from_page, progress,
    progress_msg, digest, chunk_ids, create_time) for all tasks whose
    ``doc_id`` is one of ``doc_ids``, sorted by ``create_time`` descending
    (newest task first).

    Args:
        doc_ids (list[str]): Document ids whose tasks should be looked up.

    Returns:
        list[dict] | None: One dictionary per matching task holding the
        selected columns, or ``None`` when no task matches.
    """
    task = cls.model
    # Only the columns needed to report progress are selected.
    query = (
        task.select(
            task.id,
            task.doc_id,
            task.from_page,
            task.progress,
            task.progress_msg,
            task.digest,
            task.chunk_ids,
            task.create_time,
        )
        .where(task.doc_id.in_(doc_ids))
        .order_by(task.create_time.desc())
    )
    rows = list(query.dicts())
    # Callers rely on None (not an empty list) to signal "no tasks found".
    return rows or None
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def update_chunk_ids(cls, id: str, chunk_ids: str):
|
||||
|
||||
Reference in New Issue
Block a user