From 6e7bcf58bc03746e7a79fa449ff15dafd061efdb Mon Sep 17 00:00:00 2001 From: Lynn Date: Thu, 12 Feb 2026 14:43:52 +0800 Subject: [PATCH] Refactor: split message apis to gateway and service (#13126) ### What problem does this PR solve? Split message apis to gateway and service ### Type of change - [x] Refactoring --- api/apps/restful_apis/memory_api.py | 121 ++++++++++++++++++ api/apps/sdk/messages.py | 158 ------------------------ api/apps/services/memory_api_service.py | 114 ++++++++++++++++- web/src/routes.tsx | 7 +- 4 files changed, 240 insertions(+), 160 deletions(-) delete mode 100644 api/apps/sdk/messages.py diff --git a/api/apps/restful_apis/memory_api.py b/api/apps/restful_apis/memory_api.py index 53c7f866e..c1cd9e5a9 100644 --- a/api/apps/restful_apis/memory_api.py +++ b/api/apps/restful_apis/memory_api.py @@ -171,3 +171,124 @@ async def get_memory_messages(memory_id): except Exception as e: logging.error(e) return get_json_result(code=RetCode.SERVER_ERROR, message="Internal server error") + + +@manager.route("/messages", methods=["POST"]) # noqa: F821 +@login_required +@validate_request("memory_id", "agent_id", "session_id", "user_input", "agent_response") +async def add_message(): + req = await get_request_json() + memory_ids = req["memory_id"] + + message_dict = { + "user_id": req.get("user_id"), + "agent_id": req["agent_id"], + "session_id": req["session_id"], + "user_input": req["user_input"], + "agent_response": req["agent_response"], + } + + res, msg = await memory_api_service.add_message(memory_ids, message_dict) + if res: + return get_json_result(message=msg) + + return get_json_result(message="Some messages failed to add. Detail:" + msg, code=RetCode.SERVER_ERROR) + + +@manager.route("/messages/:", methods=["DELETE"]) # noqa: F821 +@login_required +async def forget_message(memory_id: str, message_id: int): + try: + res = await memory_api_service.forget_message(memory_id, message_id) + return get_json_result(message=res) + except NotFoundException as not_found_exception: + logging.error(not_found_exception) + return get_json_result(code=RetCode.NOT_FOUND, message=str(not_found_exception)) + except Exception as e: + logging.error(e) + return get_json_result(code=RetCode.SERVER_ERROR, message="Internal server error") + + +@manager.route("/messages/:", methods=["PUT"]) # noqa: F821 +@login_required +@validate_request("status") +async def update_message(memory_id: str, message_id: int): + req = await get_request_json() + status = req["status"] + if not isinstance(status, bool): + return get_error_argument_result("Status must be a boolean.") + + try: + update_succeed = await memory_api_service.update_message_status(memory_id, message_id, status) + if update_succeed: + return get_json_result(message=update_succeed) + else: + return get_json_result(code=RetCode.SERVER_ERROR, message=f"Failed to set status for message '{message_id}' in memory '{memory_id}'.") + except NotFoundException as not_found_exception: + logging.error(not_found_exception) + return get_json_result(code=RetCode.NOT_FOUND, message=str(not_found_exception)) + except Exception as e: + logging.error(e) + return get_json_result(code=RetCode.SERVER_ERROR, message="Internal server error") + + +@manager.route("/messages/search", methods=["GET"]) # noqa: F821 +@login_required +async def search_message(): + args = request.args + memory_ids = args.getlist("memory_id") + if len(memory_ids) == 1 and ',' in memory_ids[0]: + memory_ids = memory_ids[0].split(',') + query = args.get("query") + similarity_threshold = float(args.get("similarity_threshold", 0.2)) + keywords_similarity_weight = float(args.get("keywords_similarity_weight", 0.7)) + top_n = int(args.get("top_n", 5)) + agent_id = args.get("agent_id", "") + session_id = args.get("session_id", "") + + filter_dict = { + "memory_id": memory_ids, + "agent_id": agent_id, + "session_id": session_id + } + params = { + "query": query, + "similarity_threshold": similarity_threshold, + "keywords_similarity_weight": keywords_similarity_weight, + "top_n": top_n + } + res = await memory_api_service.search_message(filter_dict, params) + return get_json_result(message=True, data=res) + +@manager.route("/messages", methods=["GET"]) # noqa: F821 +@login_required +async def get_messages(): + args = request.args + memory_ids = args.getlist("memory_id") + if len(memory_ids) == 1 and ',' in memory_ids[0]: + memory_ids = memory_ids[0].split(',') + agent_id = args.get("agent_id", "") + session_id = args.get("session_id", "") + limit = int(args.get("limit", 10)) + if not memory_ids: + return get_error_argument_result("memory_ids is required.") + try: + res = await memory_api_service.get_messages(memory_ids, agent_id, session_id, limit) + return get_json_result(message=True, data=res) + except Exception as e: + logging.error(e) + return get_json_result(code=RetCode.SERVER_ERROR, message="Internal server error") + + +@manager.route("/messages/:/content", methods=["GET"]) # noqa: F821 +@login_required +async def get_message_content(memory_id: str, message_id: int): + try: + res = await memory_api_service.get_message_content(memory_id, message_id) + return get_json_result(message=True, data=res) + except NotFoundException as not_found_exception: + logging.error(not_found_exception) + return get_json_result(code=RetCode.NOT_FOUND, message=str(not_found_exception)) + except Exception as e: + logging.error(e) + return get_json_result(code=RetCode.SERVER_ERROR, message="Internal server error") diff --git a/api/apps/sdk/messages.py b/api/apps/sdk/messages.py deleted file mode 100644 index 5ed590218..000000000 --- a/api/apps/sdk/messages.py +++ /dev/null @@ -1,158 +0,0 @@ -# -# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -from quart import request -from api.apps import login_required -from api.db.services.memory_service import MemoryService -from common.time_utils import current_timestamp, timestamp_to_date - -from memory.services.messages import MessageService -from api.db.joint_services import memory_message_service -from api.utils.api_utils import validate_request, get_request_json, get_error_argument_result, get_json_result -from common.constants import RetCode - - -@manager.route("/messages", methods=["POST"]) # noqa: F821 -@login_required -@validate_request("memory_id", "agent_id", "session_id", "user_input", "agent_response") -async def add_message(): - - req = await get_request_json() - memory_ids = req["memory_id"] - - message_dict = { - "user_id": req.get("user_id"), - "agent_id": req["agent_id"], - "session_id": req["session_id"], - "user_input": req["user_input"], - "agent_response": req["agent_response"], - } - - res, msg = await memory_message_service.queue_save_to_memory_task(memory_ids, message_dict) - - if res: - return get_json_result(message=msg) - - return get_json_result(code=RetCode.SERVER_ERROR, message="Some messages failed to add. Detail:" + msg) - - -@manager.route("/messages/:", methods=["DELETE"]) # noqa: F821 -@login_required -async def forget_message(memory_id: str, message_id: int): - - memory = MemoryService.get_by_memory_id(memory_id) - if not memory: - return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.") - - forget_time = timestamp_to_date(current_timestamp()) - update_succeed = MessageService.update_message( - {"memory_id": memory_id, "message_id": int(message_id)}, - {"forget_at": forget_time}, - memory.tenant_id, memory_id) - if update_succeed: - return get_json_result(message=update_succeed) - else: - return get_json_result(code=RetCode.SERVER_ERROR, message=f"Failed to forget message '{message_id}' in memory '{memory_id}'.") - - -@manager.route("/messages/:", methods=["PUT"]) # noqa: F821 -@login_required -@validate_request("status") -async def update_message(memory_id: str, message_id: int): - req = await get_request_json() - status = req["status"] - if not isinstance(status, bool): - return get_error_argument_result("Status must be a boolean.") - - memory = MemoryService.get_by_memory_id(memory_id) - if not memory: - return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.") - - update_succeed = MessageService.update_message({"memory_id": memory_id, "message_id": int(message_id)}, {"status": status}, memory.tenant_id, memory_id) - if update_succeed: - return get_json_result(message=update_succeed) - else: - return get_json_result(code=RetCode.SERVER_ERROR, message=f"Failed to set status for message '{message_id}' in memory '{memory_id}'.") - - -@manager.route("/messages/search", methods=["GET"]) # noqa: F821 -@login_required -async def search_message(): - args = request.args - empty_fields = [f for f in ["memory_id", "query"] if not args.get(f)] - if empty_fields: - return get_error_argument_result(f"{', '.join(empty_fields)} can't be empty.") - - memory_ids = args.getlist("memory_id") - if len(memory_ids) == 1 and ',' in memory_ids[0]: - memory_ids = memory_ids[0].split(',') - query = args.get("query") - similarity_threshold = float(args.get("similarity_threshold", 0.2)) - keywords_similarity_weight = float(args.get("keywords_similarity_weight", 0.7)) - top_n = int(args.get("top_n", 5)) - agent_id = args.get("agent_id", "") - session_id = args.get("session_id", "") - - filter_dict = { - "memory_id": memory_ids, - "agent_id": agent_id, - "session_id": session_id - } - params = { - "query": query, - "similarity_threshold": similarity_threshold, - "keywords_similarity_weight": keywords_similarity_weight, - "top_n": top_n - } - res = memory_message_service.query_message(filter_dict, params) - return get_json_result(message=True, data=res) - - -@manager.route("/messages", methods=["GET"]) # noqa: F821 -@login_required -async def get_messages(): - args = request.args - memory_ids = args.getlist("memory_id") - if len(memory_ids) == 1 and ',' in memory_ids[0]: - memory_ids = memory_ids[0].split(',') - agent_id = args.get("agent_id", "") - session_id = args.get("session_id", "") - limit = int(args.get("limit", 10)) - if not memory_ids: - return get_error_argument_result("memory_ids is required.") - memory_list = MemoryService.get_by_ids(memory_ids) - uids = [memory.tenant_id for memory in memory_list] - res = MessageService.get_recent_messages( - uids, - memory_ids, - agent_id, - session_id, - limit - ) - return get_json_result(message=True, data=res) - - -@manager.route("/messages/:/content", methods=["GET"]) # noqa: F821 -@login_required -async def get_message_content(memory_id:str, message_id: int): - memory = MemoryService.get_by_memory_id(memory_id) - if not memory: - return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.") - - res = MessageService.get_by_message_id(memory_id, message_id, memory.tenant_id) - if res: - return get_json_result(message=True, data=res) - else: - return get_json_result(code=RetCode.NOT_FOUND, message=f"Message '{message_id}' in memory '{memory_id}' not found.") diff --git a/api/apps/services/memory_api_service.py b/api/apps/services/memory_api_service.py index 53bb0f6e9..e49fe7ed0 100644 --- a/api/apps/services/memory_api_service.py +++ b/api/apps/services/memory_api_service.py @@ -19,13 +19,14 @@ from api.db.services.memory_service import MemoryService from api.db.services.user_service import UserTenantService from api.db.services.canvas_service import UserCanvasService from api.db.services.task_service import TaskService -from api.db.joint_services.memory_message_service import get_memory_size_cache, judge_system_prompt_is_default +from api.db.joint_services.memory_message_service import get_memory_size_cache, judge_system_prompt_is_default, queue_save_to_memory_task, query_message from api.utils.memory_utils import format_ret_data_from_memory, get_memory_type_human from api.constants import MEMORY_NAME_LIMIT, MEMORY_SIZE_LIMIT from memory.services.messages import MessageService from memory.utils.prompt_util import PromptAssembler from common.constants import MemoryType, ForgettingPolicy from common.exceptions import ArgumentException, NotFoundException +from common.time_utils import current_timestamp, timestamp_to_date async def create_memory(memory_info: dict): @@ -169,6 +170,16 @@ async def delete_memory(memory_id): async def list_memory(filter_params: dict, keywords: str, page: int=1, page_size: int = 50): + """ + :param filter_params: { + "memory_type": list[str], + "tenant_id": list[str], + "storage_type": str + } + :param keywords: str + :param page: int + :param page_size: int + """ filter_dict: dict = {"storage_type": filter_params.get("storage_type")} tenant_ids = filter_params.get("tenant_id") if not filter_params.get("tenant_id"): @@ -221,3 +232,104 @@ async def get_memory_messages(memory_id, agent_ids: list[str], keywords: str, pa for extract_msg in message["extract"]: extract_msg["agent_name"] = agent_name_mapping.get(extract_msg["agent_id"], "Unknown") return {"messages": messages, "storage_type": memory.storage_type} + + +async def add_message(memory_ids: list[str], message_dict: dict): + """ + :param memory_ids: list[str] + :param message_dict: { + "agent_id": str, + "session_id": str, + "user_input": str, + "agent_response": str, + "message_type": str + } + """ + return await queue_save_to_memory_task(memory_ids, message_dict) + + +async def forget_message(memory_id: str, message_id: int): + memory = MemoryService.get_by_memory_id(memory_id) + if not memory: + raise NotFoundException(f"Memory '{memory_id}' not found.") + + forget_time = timestamp_to_date(current_timestamp()) + update_succeed = MessageService.update_message( + {"memory_id": memory_id, "message_id": int(message_id)}, + {"forget_at": forget_time}, + memory.tenant_id, memory_id) + if update_succeed: + return True + raise Exception(f"Failed to forget message '{message_id}' in memory '{memory_id}'.") + + +async def update_message_status(memory_id: str, message_id: int, status: bool): + memory = MemoryService.get_by_memory_id(memory_id) + if not memory: + raise NotFoundException(f"Memory '{memory_id}' not found.") + + update_succeed = MessageService.update_message( + {"memory_id": memory_id, "message_id": int(message_id)}, + {"status": status}, + memory.tenant_id, memory_id) + if update_succeed: + return True + raise Exception(f"Failed to set status for message '{message_id}' in memory '{memory_id}'.") + + +async def search_message(filter_dict: dict, params: dict): + """ + :param filter_dict: { + "memory_id": list[str], + "agent_id": str, + "session_id": str + } + :param params: { + "query": str, + "similarity_threshold": float, + "keywords_similarity_weight": float, + "top_n": int + } + """ + return query_message(filter_dict, params) + + +async def get_messages(memory_ids: list[str], agent_id: str = "", session_id: str = "", limit: int = 10): + """ + Get recent messages from specified memories. + + :param memory_ids: list of memory IDs + :param agent_id: optional agent ID for filtering + :param session_id: optional session ID for filtering + :param limit: maximum number of messages to return + :return: list of recent messages + """ + memory_list = MemoryService.get_by_ids(memory_ids) + uids = [memory.tenant_id for memory in memory_list] + res = MessageService.get_recent_messages( + uids, + memory_ids, + agent_id, + session_id, + limit + ) + return res + + +async def get_message_content(memory_id: str, message_id: int): + """ + Get content of a specific message from a memory. + + :param memory_id: memory ID + :param message_id: message ID + :return: message content + :raises NotFoundException: if memory or message not found + """ + memory = MemoryService.get_by_memory_id(memory_id) + if not memory: + raise NotFoundException(f"Memory '{memory_id}' not found.") + + res = MessageService.get_by_message_id(memory_id, message_id, memory.tenant_id) + if res: + return res + raise NotFoundException(f"Message '{message_id}' in memory '{memory_id}' not found.") \ No newline at end of file diff --git a/web/src/routes.tsx b/web/src/routes.tsx index 300382dd2..8c6d538a6 100644 --- a/web/src/routes.tsx +++ b/web/src/routes.tsx @@ -1,5 +1,10 @@ import { lazy, memo, Suspense } from 'react'; -import { createBrowserRouter, Navigate, redirect, type RouteObject } from 'react-router'; +import { + createBrowserRouter, + Navigate, + redirect, + type RouteObject, +} from 'react-router'; import FallbackComponent from './components/fallback-component'; import { IS_ENTERPRISE } from './pages/admin/utils'; import authorizationUtil from './utils/authorization-util';