mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-05-04 17:27:50 +08:00
Refactor: split memory API into gateway and service layers (#13111)
### What problem does this PR solve? Decouple the memory API into a gateway layer (for routing/param parse) and a service layer (for business logic). ### Type of change - [x] Refactoring
This commit is contained in:
@ -244,6 +244,10 @@ def search_pages_path(page_path):
|
||||
path for path in page_path.glob("*sdk/*.py") if not path.name.startswith(".")
|
||||
]
|
||||
app_path_list.extend(api_path_list)
|
||||
restful_api_path_list = [
|
||||
path for path in page_path.glob("*restful_apis/*.py") if not path.name.startswith(".")
|
||||
]
|
||||
app_path_list.extend(restful_api_path_list)
|
||||
return app_path_list
|
||||
|
||||
|
||||
@ -263,8 +267,9 @@ def register_page(page_path):
|
||||
spec.loader.exec_module(page)
|
||||
page_name = getattr(page, "page_name", page_name)
|
||||
sdk_path = "\\sdk\\" if sys.platform.startswith("win") else "/sdk/"
|
||||
restful_api_path = "\\restful_apis\\" if sys.platform.startswith("win") else "/restful_apis/"
|
||||
url_prefix = (
|
||||
f"/api/{API_VERSION}" if sdk_path in path else f"/{API_VERSION}/{page_name}"
|
||||
f"/api/{API_VERSION}" if sdk_path in path or restful_api_path in path else f"/{API_VERSION}/{page_name}"
|
||||
)
|
||||
|
||||
app.register_blueprint(page.manager, url_prefix=url_prefix)
|
||||
@ -274,6 +279,7 @@ def register_page(page_path):
|
||||
pages_dir = [
|
||||
Path(__file__).parent,
|
||||
Path(__file__).parent.parent / "api" / "apps",
|
||||
Path(__file__).parent.parent / "api" / "apps" / "restful_apis",
|
||||
Path(__file__).parent.parent / "api" / "apps" / "sdk",
|
||||
]
|
||||
|
||||
|
||||
173
api/apps/restful_apis/memory_api.py
Normal file
173
api/apps/restful_apis/memory_api.py
Normal file
@ -0,0 +1,173 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
|
||||
from quart import request
|
||||
from common.constants import RetCode
|
||||
from common.exceptions import ArgumentException, NotFoundException
|
||||
from api.apps import login_required
|
||||
from api.utils.api_utils import validate_request, get_request_json, get_error_argument_result, get_json_result
|
||||
from api.apps.services import memory_api_service
|
||||
|
||||
|
||||
@manager.route("/memories", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("name", "memory_type", "embd_id", "llm_id")
|
||||
async def create_memory():
|
||||
timing_enabled = os.getenv("RAGFLOW_API_TIMING")
|
||||
t_start = time.perf_counter() if timing_enabled else None
|
||||
req = await get_request_json()
|
||||
t_parsed = time.perf_counter() if timing_enabled else None
|
||||
try:
|
||||
memory_info = {
|
||||
"name": req["name"],
|
||||
"memory_type": req["memory_type"],
|
||||
"embd_id": req["embd_id"],
|
||||
"llm_id": req["llm_id"]
|
||||
}
|
||||
success, res = await memory_api_service.create_memory(memory_info)
|
||||
if timing_enabled:
|
||||
logging.info(
|
||||
"api_timing create_memory parse_ms=%.2f validate_and_db_ms=%.2f total_ms=%.2f path=%s",
|
||||
(t_parsed - t_start) * 1000,
|
||||
(time.perf_counter() - t_parsed) * 1000,
|
||||
(time.perf_counter() - t_start) * 1000,
|
||||
request.path,
|
||||
)
|
||||
if success:
|
||||
return get_json_result(message=True, data=res)
|
||||
else:
|
||||
return get_json_result(message=res, code=RetCode.SERVER_ERROR)
|
||||
|
||||
except ArgumentException as arg_error:
|
||||
logging.error(arg_error)
|
||||
if timing_enabled:
|
||||
logging.info(
|
||||
"api_timing create_memory error=%s parse_ms=%.2f total_ms=%.2f path=%s",
|
||||
str(arg_error),
|
||||
(t_parsed - t_start) * 1000,
|
||||
(time.perf_counter() - t_start) * 1000,
|
||||
request.path,
|
||||
)
|
||||
return get_error_argument_result(str(arg_error))
|
||||
|
||||
except Exception as e:
|
||||
logging.error(e)
|
||||
if timing_enabled:
|
||||
logging.info(
|
||||
"api_timing create_memory error=%s parse_ms=%.2f total_ms=%.2f path=%s",
|
||||
str(e),
|
||||
(t_parsed - t_start) * 1000,
|
||||
(time.perf_counter() - t_start) * 1000,
|
||||
request.path,
|
||||
)
|
||||
return get_json_result(code=RetCode.SERVER_ERROR, message="Internal server error")
|
||||
|
||||
|
||||
@manager.route("/memories/<memory_id>", methods=["PUT"]) # noqa: F821
|
||||
@login_required
|
||||
async def update_memory(memory_id):
|
||||
req = await get_request_json()
|
||||
new_settings = {k: req[k] for k in [
|
||||
"name", "permissions", "llm_id", "embd_id", "memory_type", "memory_size", "forgetting_policy", "temperature",
|
||||
"avatar", "description", "system_prompt", "user_prompt"
|
||||
] if k in req}
|
||||
try:
|
||||
success, res = await memory_api_service.update_memory(memory_id, new_settings)
|
||||
if success:
|
||||
return get_json_result(message=True, data=res)
|
||||
else:
|
||||
return get_json_result(message=res, code=RetCode.SERVER_ERROR)
|
||||
except NotFoundException as not_found_exception:
|
||||
logging.error(not_found_exception)
|
||||
return get_json_result(code=RetCode.NOT_FOUND, message=str(not_found_exception))
|
||||
except ArgumentException as arg_error:
|
||||
logging.error(arg_error)
|
||||
return get_error_argument_result(str(arg_error))
|
||||
except Exception as e:
|
||||
logging.error(e)
|
||||
return get_json_result(code=RetCode.SERVER_ERROR, message="Internal server error")
|
||||
|
||||
|
||||
@manager.route("/memories/<memory_id>", methods=["DELETE"]) # noqa: F821
|
||||
@login_required
|
||||
async def delete_memory(memory_id):
|
||||
try:
|
||||
await memory_api_service.delete_memory(memory_id)
|
||||
return get_json_result(message=True)
|
||||
except NotFoundException as not_found_exception:
|
||||
logging.error(not_found_exception)
|
||||
return get_json_result(code=RetCode.NOT_FOUND, message=str(not_found_exception))
|
||||
except Exception as e:
|
||||
logging.error(e)
|
||||
return get_json_result(code=RetCode.SERVER_ERROR, message="Internal server error")
|
||||
|
||||
|
||||
@manager.route("/memories", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
async def list_memory():
|
||||
filter_params = {
|
||||
k: request.args.get(k) for k in ["memory_type", "tenant_id", "storage_type"] if k in request.args
|
||||
}
|
||||
keywords = request.args.get("keywords")
|
||||
page = int(request.args.get("page", 1))
|
||||
page_size = int(request.args.get("page_size", 50))
|
||||
try:
|
||||
res = await memory_api_service.list_memory(filter_params, keywords, page, page_size)
|
||||
return get_json_result(message=True, data=res)
|
||||
except Exception as e:
|
||||
logging.error(e)
|
||||
return get_json_result(code=RetCode.SERVER_ERROR, message="Internal server error")
|
||||
|
||||
|
||||
@manager.route("/memories/<memory_id>/config", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
async def get_memory_config(memory_id):
|
||||
try:
|
||||
res = await memory_api_service.get_memory_config(memory_id)
|
||||
return get_json_result(message=True, data=res)
|
||||
except NotFoundException as not_found_exception:
|
||||
logging.error(not_found_exception)
|
||||
return get_json_result(code=RetCode.NOT_FOUND, message=str(not_found_exception))
|
||||
except Exception as e:
|
||||
logging.error(e)
|
||||
return get_json_result(code=RetCode.SERVER_ERROR, message="Internal server error")
|
||||
|
||||
|
||||
@manager.route("/memories/<memory_id>", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
async def get_memory_messages(memory_id):
|
||||
args = request.args
|
||||
agent_ids = args.getlist("agent_id")
|
||||
if len(agent_ids) == 1 and ',' in agent_ids[0]:
|
||||
agent_ids = agent_ids[0].split(',')
|
||||
keywords = args.get("keywords", "")
|
||||
keywords = keywords.strip()
|
||||
page = int(args.get("page", 1))
|
||||
page_size = int(args.get("page_size", 50))
|
||||
try:
|
||||
res = await memory_api_service.get_memory_messages(
|
||||
memory_id, agent_ids, keywords, page, page_size
|
||||
)
|
||||
return get_json_result(message=True, data=res)
|
||||
except NotFoundException as not_found_exception:
|
||||
logging.error(not_found_exception)
|
||||
return get_json_result(code=RetCode.NOT_FOUND, message=str(not_found_exception))
|
||||
except Exception as e:
|
||||
logging.error(e)
|
||||
return get_json_result(code=RetCode.SERVER_ERROR, message="Internal server error")
|
||||
@ -1,291 +0,0 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
|
||||
from quart import request
|
||||
from api.apps import login_required, current_user
|
||||
from api.db import TenantPermission
|
||||
from api.db.services.memory_service import MemoryService
|
||||
from api.db.services.user_service import UserTenantService
|
||||
from api.db.services.canvas_service import UserCanvasService
|
||||
from api.db.services.task_service import TaskService
|
||||
from api.db.joint_services.memory_message_service import get_memory_size_cache, judge_system_prompt_is_default
|
||||
from api.utils.api_utils import validate_request, get_request_json, get_error_argument_result, get_json_result
|
||||
from api.utils.memory_utils import format_ret_data_from_memory, get_memory_type_human
|
||||
from api.constants import MEMORY_NAME_LIMIT, MEMORY_SIZE_LIMIT
|
||||
from memory.services.messages import MessageService
|
||||
from memory.utils.prompt_util import PromptAssembler
|
||||
from common.constants import MemoryType, RetCode, ForgettingPolicy
|
||||
|
||||
|
||||
@manager.route("/memories", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("name", "memory_type", "embd_id", "llm_id")
|
||||
async def create_memory():
|
||||
timing_enabled = os.getenv("RAGFLOW_API_TIMING")
|
||||
t_start = time.perf_counter() if timing_enabled else None
|
||||
req = await get_request_json()
|
||||
t_parsed = time.perf_counter() if timing_enabled else None
|
||||
# check name length
|
||||
name = req["name"]
|
||||
memory_name = name.strip()
|
||||
if len(memory_name) == 0:
|
||||
if timing_enabled:
|
||||
logging.info(
|
||||
"api_timing create_memory invalid_name parse_ms=%.2f total_ms=%.2f path=%s",
|
||||
(t_parsed - t_start) * 1000,
|
||||
(time.perf_counter() - t_start) * 1000,
|
||||
request.path,
|
||||
)
|
||||
return get_error_argument_result("Memory name cannot be empty or whitespace.")
|
||||
if len(memory_name) > MEMORY_NAME_LIMIT:
|
||||
if timing_enabled:
|
||||
logging.info(
|
||||
"api_timing create_memory invalid_name parse_ms=%.2f total_ms=%.2f path=%s",
|
||||
(t_parsed - t_start) * 1000,
|
||||
(time.perf_counter() - t_start) * 1000,
|
||||
request.path,
|
||||
)
|
||||
return get_error_argument_result(f"Memory name '{memory_name}' exceeds limit of {MEMORY_NAME_LIMIT}.")
|
||||
# check memory_type valid
|
||||
if not isinstance(req["memory_type"], list):
|
||||
if timing_enabled:
|
||||
logging.info(
|
||||
"api_timing create_memory invalid_memory_type parse_ms=%.2f total_ms=%.2f path=%s",
|
||||
(t_parsed - t_start) * 1000,
|
||||
(time.perf_counter() - t_start) * 1000,
|
||||
request.path,
|
||||
)
|
||||
return get_error_argument_result("Memory type must be a list.")
|
||||
memory_type = set(req["memory_type"])
|
||||
invalid_type = memory_type - {e.name.lower() for e in MemoryType}
|
||||
if invalid_type:
|
||||
if timing_enabled:
|
||||
logging.info(
|
||||
"api_timing create_memory invalid_memory_type parse_ms=%.2f total_ms=%.2f path=%s",
|
||||
(t_parsed - t_start) * 1000,
|
||||
(time.perf_counter() - t_start) * 1000,
|
||||
request.path,
|
||||
)
|
||||
return get_error_argument_result(f"Memory type '{invalid_type}' is not supported.")
|
||||
memory_type = list(memory_type)
|
||||
|
||||
try:
|
||||
t_before_db = time.perf_counter() if timing_enabled else None
|
||||
res, memory = MemoryService.create_memory(
|
||||
tenant_id=current_user.id,
|
||||
name=memory_name,
|
||||
memory_type=memory_type,
|
||||
embd_id=req["embd_id"],
|
||||
llm_id=req["llm_id"]
|
||||
)
|
||||
if timing_enabled:
|
||||
logging.info(
|
||||
"api_timing create_memory parse_ms=%.2f validate_ms=%.2f db_ms=%.2f total_ms=%.2f path=%s",
|
||||
(t_parsed - t_start) * 1000,
|
||||
(t_before_db - t_parsed) * 1000,
|
||||
(time.perf_counter() - t_before_db) * 1000,
|
||||
(time.perf_counter() - t_start) * 1000,
|
||||
request.path,
|
||||
)
|
||||
|
||||
if res:
|
||||
return get_json_result(message=True, data=format_ret_data_from_memory(memory))
|
||||
else:
|
||||
return get_json_result(message=memory, code=RetCode.SERVER_ERROR)
|
||||
|
||||
except Exception as e:
|
||||
return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)
|
||||
|
||||
|
||||
@manager.route("/memories/<memory_id>", methods=["PUT"]) # noqa: F821
|
||||
@login_required
|
||||
async def update_memory(memory_id):
|
||||
req = await get_request_json()
|
||||
update_dict = {}
|
||||
# check name length
|
||||
if "name" in req:
|
||||
name = req["name"]
|
||||
memory_name = name.strip()
|
||||
if len(memory_name) == 0:
|
||||
return get_error_argument_result("Memory name cannot be empty or whitespace.")
|
||||
if len(memory_name) > MEMORY_NAME_LIMIT:
|
||||
return get_error_argument_result(f"Memory name '{memory_name}' exceeds limit of {MEMORY_NAME_LIMIT}.")
|
||||
update_dict["name"] = memory_name
|
||||
# check permissions valid
|
||||
if req.get("permissions"):
|
||||
if req["permissions"] not in [e.value for e in TenantPermission]:
|
||||
return get_error_argument_result(f"Unknown permission '{req['permissions']}'.")
|
||||
update_dict["permissions"] = req["permissions"]
|
||||
if req.get("llm_id"):
|
||||
update_dict["llm_id"] = req["llm_id"]
|
||||
if req.get("embd_id"):
|
||||
update_dict["embd_id"] = req["embd_id"]
|
||||
if req.get("memory_type"):
|
||||
memory_type = set(req["memory_type"])
|
||||
invalid_type = memory_type - {e.name.lower() for e in MemoryType}
|
||||
if invalid_type:
|
||||
return get_error_argument_result(f"Memory type '{invalid_type}' is not supported.")
|
||||
update_dict["memory_type"] = list(memory_type)
|
||||
# check memory_size valid
|
||||
if req.get("memory_size"):
|
||||
if not 0 < int(req["memory_size"]) <= MEMORY_SIZE_LIMIT:
|
||||
return get_error_argument_result(f"Memory size should be in range (0, {MEMORY_SIZE_LIMIT}] Bytes.")
|
||||
update_dict["memory_size"] = req["memory_size"]
|
||||
# check forgetting_policy valid
|
||||
if req.get("forgetting_policy"):
|
||||
if req["forgetting_policy"] not in [e.value for e in ForgettingPolicy]:
|
||||
return get_error_argument_result(f"Forgetting policy '{req['forgetting_policy']}' is not supported.")
|
||||
update_dict["forgetting_policy"] = req["forgetting_policy"]
|
||||
# check temperature valid
|
||||
if "temperature" in req:
|
||||
temperature = float(req["temperature"])
|
||||
if not 0 <= temperature <= 1:
|
||||
return get_error_argument_result("Temperature should be in range [0, 1].")
|
||||
update_dict["temperature"] = temperature
|
||||
# allow update to empty fields
|
||||
for field in ["avatar", "description", "system_prompt", "user_prompt"]:
|
||||
if field in req:
|
||||
update_dict[field] = req[field]
|
||||
current_memory = MemoryService.get_by_memory_id(memory_id)
|
||||
if not current_memory:
|
||||
return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.")
|
||||
|
||||
memory_dict = current_memory.to_dict()
|
||||
memory_dict.update({"memory_type": get_memory_type_human(current_memory.memory_type)})
|
||||
to_update = {}
|
||||
for k, v in update_dict.items():
|
||||
if isinstance(v, list) and set(memory_dict[k]) != set(v):
|
||||
to_update[k] = v
|
||||
elif memory_dict[k] != v:
|
||||
to_update[k] = v
|
||||
|
||||
if not to_update:
|
||||
return get_json_result(message=True, data=memory_dict)
|
||||
# check memory empty when update embd_id, memory_type
|
||||
memory_size = get_memory_size_cache(memory_id, current_memory.tenant_id)
|
||||
not_allowed_update = [f for f in ["embd_id", "memory_type"] if f in to_update and memory_size > 0]
|
||||
if not_allowed_update:
|
||||
return get_error_argument_result(f"Can't update {not_allowed_update} when memory isn't empty.")
|
||||
if "memory_type" in to_update:
|
||||
if "system_prompt" not in to_update and judge_system_prompt_is_default(current_memory.system_prompt, current_memory.memory_type):
|
||||
# update old default prompt, assemble a new one
|
||||
to_update["system_prompt"] = PromptAssembler.assemble_system_prompt({"memory_type": to_update["memory_type"]})
|
||||
|
||||
try:
|
||||
MemoryService.update_memory(current_memory.tenant_id, memory_id, to_update)
|
||||
updated_memory = MemoryService.get_by_memory_id(memory_id)
|
||||
return get_json_result(message=True, data=format_ret_data_from_memory(updated_memory))
|
||||
|
||||
except Exception as e:
|
||||
logging.error(e)
|
||||
return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)
|
||||
|
||||
|
||||
@manager.route("/memories/<memory_id>", methods=["DELETE"]) # noqa: F821
|
||||
@login_required
|
||||
async def delete_memory(memory_id):
|
||||
memory = MemoryService.get_by_memory_id(memory_id)
|
||||
if not memory:
|
||||
return get_json_result(message=True, code=RetCode.NOT_FOUND)
|
||||
try:
|
||||
MemoryService.delete_memory(memory_id)
|
||||
if MessageService.has_index(memory.tenant_id, memory_id):
|
||||
MessageService.delete_message({"memory_id": memory_id}, memory.tenant_id, memory_id)
|
||||
return get_json_result(message=True)
|
||||
except Exception as e:
|
||||
logging.error(e)
|
||||
return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)
|
||||
|
||||
|
||||
@manager.route("/memories", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
async def list_memory():
|
||||
args = request.args
|
||||
try:
|
||||
tenant_ids = args.getlist("tenant_id")
|
||||
memory_types = args.getlist("memory_type")
|
||||
storage_type = args.get("storage_type")
|
||||
keywords = args.get("keywords", "")
|
||||
page = int(args.get("page", 1))
|
||||
page_size = int(args.get("page_size", 50))
|
||||
# make filter dict
|
||||
filter_dict: dict = {"storage_type": storage_type}
|
||||
if not tenant_ids:
|
||||
# restrict to current user's tenants
|
||||
user_tenants = UserTenantService.get_user_tenant_relation_by_user_id(current_user.id)
|
||||
filter_dict["tenant_id"] = [tenant["tenant_id"] for tenant in user_tenants]
|
||||
else:
|
||||
if len(tenant_ids) == 1 and ',' in tenant_ids[0]:
|
||||
tenant_ids = tenant_ids[0].split(',')
|
||||
filter_dict["tenant_id"] = tenant_ids
|
||||
if memory_types and len(memory_types) == 1 and ',' in memory_types[0]:
|
||||
memory_types = memory_types[0].split(',')
|
||||
filter_dict["memory_type"] = memory_types
|
||||
|
||||
memory_list, count = MemoryService.get_by_filter(filter_dict, keywords, page, page_size)
|
||||
[memory.update({"memory_type": get_memory_type_human(memory["memory_type"])}) for memory in memory_list]
|
||||
return get_json_result(message=True, data={"memory_list": memory_list, "total_count": count})
|
||||
|
||||
except Exception as e:
|
||||
logging.error(e)
|
||||
return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)
|
||||
|
||||
|
||||
@manager.route("/memories/<memory_id>/config", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
async def get_memory_config(memory_id):
|
||||
memory = MemoryService.get_with_owner_name_by_id(memory_id)
|
||||
if not memory:
|
||||
return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.")
|
||||
return get_json_result(message=True, data=format_ret_data_from_memory(memory))
|
||||
|
||||
|
||||
@manager.route("/memories/<memory_id>", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
async def get_memory_detail(memory_id):
|
||||
args = request.args
|
||||
agent_ids = args.getlist("agent_id")
|
||||
if len(agent_ids) == 1 and ',' in agent_ids[0]:
|
||||
agent_ids = agent_ids[0].split(',')
|
||||
keywords = args.get("keywords", "")
|
||||
keywords = keywords.strip()
|
||||
page = int(args.get("page", 1))
|
||||
page_size = int(args.get("page_size", 50))
|
||||
memory = MemoryService.get_by_memory_id(memory_id)
|
||||
if not memory:
|
||||
return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.")
|
||||
messages = MessageService.list_message(
|
||||
memory.tenant_id, memory_id, agent_ids, keywords, page, page_size)
|
||||
agent_name_mapping = {}
|
||||
extract_task_mapping = {}
|
||||
if messages["message_list"]:
|
||||
agent_list = UserCanvasService.get_basic_info_by_canvas_ids([message["agent_id"] for message in messages["message_list"]])
|
||||
agent_name_mapping = {agent["id"]: agent["title"] for agent in agent_list}
|
||||
task_list = TaskService.get_tasks_progress_by_doc_ids([memory_id])
|
||||
if task_list:
|
||||
task_list.sort(key=lambda t: t["create_time"]) # asc, use newer when exist more than one task
|
||||
for task in task_list:
|
||||
# the 'digest' field carries the source_id when a task is created, so use 'digest' as key
|
||||
extract_task_mapping.update({int(task["digest"]): task})
|
||||
for message in messages["message_list"]:
|
||||
message["agent_name"] = agent_name_mapping.get(message["agent_id"], "Unknown")
|
||||
message["task"] = extract_task_mapping.get(message["message_id"], {})
|
||||
for extract_msg in message["extract"]:
|
||||
extract_msg["agent_name"] = agent_name_mapping.get(extract_msg["agent_id"], "Unknown")
|
||||
return get_json_result(data={"messages": messages, "storage_type": memory.storage_type}, message=True)
|
||||
0
api/apps/services/__init__.py
Normal file
0
api/apps/services/__init__.py
Normal file
223
api/apps/services/memory_api_service.py
Normal file
223
api/apps/services/memory_api_service.py
Normal file
@ -0,0 +1,223 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from api.apps import current_user
|
||||
from api.db import TenantPermission
|
||||
from api.db.services.memory_service import MemoryService
|
||||
from api.db.services.user_service import UserTenantService
|
||||
from api.db.services.canvas_service import UserCanvasService
|
||||
from api.db.services.task_service import TaskService
|
||||
from api.db.joint_services.memory_message_service import get_memory_size_cache, judge_system_prompt_is_default
|
||||
from api.utils.memory_utils import format_ret_data_from_memory, get_memory_type_human
|
||||
from api.constants import MEMORY_NAME_LIMIT, MEMORY_SIZE_LIMIT
|
||||
from memory.services.messages import MessageService
|
||||
from memory.utils.prompt_util import PromptAssembler
|
||||
from common.constants import MemoryType, ForgettingPolicy
|
||||
from common.exceptions import ArgumentException, NotFoundException
|
||||
|
||||
|
||||
async def create_memory(memory_info: dict):
|
||||
"""
|
||||
:param memory_info: {
|
||||
"name": str,
|
||||
"memory_type": list[str],
|
||||
"embd_id": str,
|
||||
"llm_id": str
|
||||
}
|
||||
"""
|
||||
# check name length
|
||||
name = memory_info["name"]
|
||||
memory_name = name.strip()
|
||||
if len(memory_name) == 0:
|
||||
raise ArgumentException("Memory name cannot be empty or whitespace.")
|
||||
if len(memory_name) > MEMORY_NAME_LIMIT:
|
||||
raise ArgumentException(f"Memory name '{memory_name}' exceeds limit of {MEMORY_NAME_LIMIT}.")
|
||||
# check memory_type valid
|
||||
if not isinstance(memory_info["memory_type"], list):
|
||||
raise ArgumentException("Memory type must be a list.")
|
||||
memory_type = set(memory_info["memory_type"])
|
||||
invalid_type = memory_type - {e.name.lower() for e in MemoryType}
|
||||
if invalid_type:
|
||||
raise ArgumentException(f"Memory type '{invalid_type}' is not supported.")
|
||||
memory_type = list(memory_type)
|
||||
success, res = MemoryService.create_memory(
|
||||
tenant_id=current_user.id,
|
||||
name=memory_name,
|
||||
memory_type=memory_type,
|
||||
embd_id=memory_info["embd_id"],
|
||||
llm_id=memory_info["llm_id"]
|
||||
)
|
||||
if success:
|
||||
return True, format_ret_data_from_memory(res)
|
||||
else:
|
||||
return False, res
|
||||
|
||||
|
||||
async def update_memory(memory_id: str, new_memory_setting: dict):
|
||||
"""
|
||||
:param memory_id: str
|
||||
:param new_memory_setting: {
|
||||
"name": str,
|
||||
"permissions": str,
|
||||
"llm_id": str,
|
||||
"embd_id": str,
|
||||
"memory_type": list[str],
|
||||
"memory_size": int,
|
||||
"forgetting_policy": str,
|
||||
"temperature": float,
|
||||
"avatar": str,
|
||||
"description": str,
|
||||
"system_prompt": str,
|
||||
"user_prompt": str
|
||||
}
|
||||
"""
|
||||
update_dict = {}
|
||||
# check name length
|
||||
if "name" in new_memory_setting:
|
||||
name = new_memory_setting["name"]
|
||||
memory_name = name.strip()
|
||||
if len(memory_name) == 0:
|
||||
raise ArgumentException("Memory name cannot be empty or whitespace.")
|
||||
if len(memory_name) > MEMORY_NAME_LIMIT:
|
||||
raise ArgumentException(f"Memory name '{memory_name}' exceeds limit of {MEMORY_NAME_LIMIT}.")
|
||||
update_dict["name"] = memory_name
|
||||
# check permissions valid
|
||||
if new_memory_setting.get("permissions"):
|
||||
if new_memory_setting["permissions"] not in [e.value for e in TenantPermission]:
|
||||
raise ArgumentException(f"Unknown permission '{new_memory_setting['permissions']}'.")
|
||||
update_dict["permissions"] = new_memory_setting["permissions"]
|
||||
if new_memory_setting.get("llm_id"):
|
||||
update_dict["llm_id"] = new_memory_setting["llm_id"]
|
||||
if new_memory_setting.get("embd_id"):
|
||||
update_dict["embd_id"] = new_memory_setting["embd_id"]
|
||||
if new_memory_setting.get("memory_type"):
|
||||
memory_type = set(new_memory_setting["memory_type"])
|
||||
invalid_type = memory_type - {e.name.lower() for e in MemoryType}
|
||||
if invalid_type:
|
||||
raise ArgumentException(f"Memory type '{invalid_type}' is not supported.")
|
||||
update_dict["memory_type"] = list(memory_type)
|
||||
# check memory_size valid
|
||||
if new_memory_setting.get("memory_size"):
|
||||
if not 0 < int(new_memory_setting["memory_size"]) <= MEMORY_SIZE_LIMIT:
|
||||
raise ArgumentException(f"Memory size should be in range (0, {MEMORY_SIZE_LIMIT}] Bytes.")
|
||||
update_dict["memory_size"] = new_memory_setting["memory_size"]
|
||||
# check forgetting_policy valid
|
||||
if new_memory_setting.get("forgetting_policy"):
|
||||
if new_memory_setting["forgetting_policy"] not in [e.value for e in ForgettingPolicy]:
|
||||
raise ArgumentException(f"Forgetting policy '{new_memory_setting['forgetting_policy']}' is not supported.")
|
||||
update_dict["forgetting_policy"] = new_memory_setting["forgetting_policy"]
|
||||
# check temperature valid
|
||||
if "temperature" in new_memory_setting:
|
||||
temperature = float(new_memory_setting["temperature"])
|
||||
if not 0 <= temperature <= 1:
|
||||
raise ArgumentException("Temperature should be in range [0, 1].")
|
||||
update_dict["temperature"] = temperature
|
||||
# allow update to empty fields
|
||||
for field in ["avatar", "description", "system_prompt", "user_prompt"]:
|
||||
if field in new_memory_setting:
|
||||
update_dict[field] = new_memory_setting[field]
|
||||
current_memory = MemoryService.get_by_memory_id(memory_id)
|
||||
if not current_memory:
|
||||
raise NotFoundException(f"Memory '{memory_id}' not found.")
|
||||
|
||||
memory_dict = current_memory.to_dict()
|
||||
memory_dict.update({"memory_type": get_memory_type_human(current_memory.memory_type)})
|
||||
to_update = {}
|
||||
for k, v in update_dict.items():
|
||||
if isinstance(v, list) and set(memory_dict[k]) != set(v):
|
||||
to_update[k] = v
|
||||
elif memory_dict[k] != v:
|
||||
to_update[k] = v
|
||||
|
||||
if not to_update:
|
||||
return True, memory_dict
|
||||
# check memory empty when update embd_id, memory_type
|
||||
memory_size = get_memory_size_cache(memory_id, current_memory.tenant_id)
|
||||
not_allowed_update = [f for f in ["embd_id", "memory_type"] if f in to_update and memory_size > 0]
|
||||
if not_allowed_update:
|
||||
raise ArgumentException(f"Can't update {not_allowed_update} when memory isn't empty.")
|
||||
if "memory_type" in to_update:
|
||||
if "system_prompt" not in to_update and judge_system_prompt_is_default(current_memory.system_prompt, current_memory.memory_type):
|
||||
# update old default prompt, assemble a new one
|
||||
to_update["system_prompt"] = PromptAssembler.assemble_system_prompt({"memory_type": to_update["memory_type"]})
|
||||
|
||||
MemoryService.update_memory(current_memory.tenant_id, memory_id, to_update)
|
||||
updated_memory = MemoryService.get_by_memory_id(memory_id)
|
||||
return True, format_ret_data_from_memory(updated_memory)
|
||||
|
||||
|
||||
async def delete_memory(memory_id):
|
||||
memory = MemoryService.get_by_memory_id(memory_id)
|
||||
if not memory:
|
||||
raise NotFoundException(f"Memory '{memory_id}' not found.")
|
||||
MemoryService.delete_memory(memory_id)
|
||||
if MessageService.has_index(memory.tenant_id, memory_id):
|
||||
MessageService.delete_message({"memory_id": memory_id}, memory.tenant_id, memory_id)
|
||||
return True
|
||||
|
||||
|
||||
async def list_memory(filter_params: dict, keywords: str, page: int=1, page_size: int = 50):
|
||||
filter_dict: dict = {"storage_type": filter_params.get("storage_type")}
|
||||
tenant_ids = filter_params.get("tenant_id")
|
||||
if not filter_params.get("tenant_id"):
|
||||
# restrict to current user's tenants
|
||||
user_tenants = UserTenantService.get_user_tenant_relation_by_user_id(current_user.id)
|
||||
filter_dict["tenant_id"] = [tenant["tenant_id"] for tenant in user_tenants]
|
||||
else:
|
||||
if len(tenant_ids) == 1 and ',' in tenant_ids[0]:
|
||||
tenant_ids = tenant_ids[0].split(',')
|
||||
filter_dict["tenant_id"] = tenant_ids
|
||||
memory_types = filter_params.get("memory_type")
|
||||
if memory_types and len(memory_types) == 1 and ',' in memory_types[0]:
|
||||
memory_types = memory_types[0].split(',')
|
||||
filter_dict["memory_type"] = memory_types
|
||||
|
||||
memory_list, count = MemoryService.get_by_filter(filter_dict, keywords, page, page_size)
|
||||
[memory.update({"memory_type": get_memory_type_human(memory["memory_type"])}) for memory in memory_list]
|
||||
return {
|
||||
"memory_list": memory_list, "total_count": count
|
||||
}
|
||||
|
||||
|
||||
async def get_memory_config(memory_id):
|
||||
memory = MemoryService.get_with_owner_name_by_id(memory_id)
|
||||
if not memory:
|
||||
raise NotFoundException(f"Memory '{memory_id}' not found.")
|
||||
return format_ret_data_from_memory(memory)
|
||||
|
||||
|
||||
async def get_memory_messages(memory_id, agent_ids: list[str], keywords: str, page: int=1, page_size: int = 50):
|
||||
memory = MemoryService.get_by_memory_id(memory_id)
|
||||
if not memory:
|
||||
raise NotFoundException(f"Memory '{memory_id}' not found.")
|
||||
messages = MessageService.list_message(
|
||||
memory.tenant_id, memory_id, agent_ids, keywords, page, page_size)
|
||||
agent_name_mapping = {}
|
||||
extract_task_mapping = {}
|
||||
if messages["message_list"]:
|
||||
agent_list = UserCanvasService.get_basic_info_by_canvas_ids([message["agent_id"] for message in messages["message_list"]])
|
||||
agent_name_mapping = {agent["id"]: agent["title"] for agent in agent_list}
|
||||
task_list = TaskService.get_tasks_progress_by_doc_ids([memory_id])
|
||||
if task_list:
|
||||
task_list.sort(key=lambda t: t["create_time"]) # asc, use newer when exist more than one task
|
||||
for task in task_list:
|
||||
# the 'digest' field carries the source_id when a task is created, so use 'digest' as key
|
||||
extract_task_mapping.update({int(task["digest"]): task})
|
||||
for message in messages["message_list"]:
|
||||
message["agent_name"] = agent_name_mapping.get(message["agent_id"], "Unknown")
|
||||
message["task"] = extract_task_mapping.get(message["message_id"], {})
|
||||
for extract_msg in message["extract"]:
|
||||
extract_msg["agent_name"] = agent_name_mapping.get(extract_msg["agent_id"], "Unknown")
|
||||
return {"messages": messages, "storage_type": memory.storage_type}
|
||||
Reference in New Issue
Block a user