Refactor: split memory API into gateway and service layers (#13111)

### What problem does this PR solve?

Decouple the memory API into a gateway layer (routing and request-parameter parsing)
and a service layer (business logic).

### Type of change

- [x] Refactoring
This commit is contained in:
Lynn
2026-02-12 10:11:50 +08:00
committed by GitHub
parent 4b50b8c579
commit 30d5fc1a07
6 changed files with 413 additions and 292 deletions

View File

@ -244,6 +244,10 @@ def search_pages_path(page_path):
path for path in page_path.glob("*sdk/*.py") if not path.name.startswith(".")
]
app_path_list.extend(api_path_list)
restful_api_path_list = [
path for path in page_path.glob("*restful_apis/*.py") if not path.name.startswith(".")
]
app_path_list.extend(restful_api_path_list)
return app_path_list
@ -263,8 +267,9 @@ def register_page(page_path):
spec.loader.exec_module(page)
page_name = getattr(page, "page_name", page_name)
sdk_path = "\\sdk\\" if sys.platform.startswith("win") else "/sdk/"
restful_api_path = "\\restful_apis\\" if sys.platform.startswith("win") else "/restful_apis/"
url_prefix = (
f"/api/{API_VERSION}" if sdk_path in path else f"/{API_VERSION}/{page_name}"
f"/api/{API_VERSION}" if sdk_path in path or restful_api_path in path else f"/{API_VERSION}/{page_name}"
)
app.register_blueprint(page.manager, url_prefix=url_prefix)
@ -274,6 +279,7 @@ def register_page(page_path):
pages_dir = [
Path(__file__).parent,
Path(__file__).parent.parent / "api" / "apps",
Path(__file__).parent.parent / "api" / "apps" / "restful_apis",
Path(__file__).parent.parent / "api" / "apps" / "sdk",
]

View File

@ -0,0 +1,173 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import os
import time
from quart import request
from common.constants import RetCode
from common.exceptions import ArgumentException, NotFoundException
from api.apps import login_required
from api.utils.api_utils import validate_request, get_request_json, get_error_argument_result, get_json_result
from api.apps.services import memory_api_service
@manager.route("/memories", methods=["POST"])  # noqa: F821
@login_required
@validate_request("name", "memory_type", "embd_id", "llm_id")
async def create_memory():
    """Gateway handler for ``POST /memories``.

    Copies the four creation fields out of the JSON body and hands them to
    ``memory_api_service.create_memory``. An ``ArgumentException`` raised while
    handling the request becomes an argument-error response; any other
    exception becomes a generic server-error response. When the
    ``RAGFLOW_API_TIMING`` environment variable is set, timing lines are logged.
    """
    timing_enabled = os.getenv("RAGFLOW_API_TIMING")
    started_at = time.perf_counter() if timing_enabled else None
    payload = await get_request_json()
    parsed_at = time.perf_counter() if timing_enabled else None
    try:
        memory_info = {
            field: payload[field] for field in ("name", "memory_type", "embd_id", "llm_id")
        }
        success, res = await memory_api_service.create_memory(memory_info)
        if timing_enabled:
            logging.info(
                "api_timing create_memory parse_ms=%.2f validate_and_db_ms=%.2f total_ms=%.2f path=%s",
                (parsed_at - started_at) * 1000,
                (time.perf_counter() - parsed_at) * 1000,
                (time.perf_counter() - started_at) * 1000,
                request.path,
            )
        if not success:
            return get_json_result(message=res, code=RetCode.SERVER_ERROR)
        return get_json_result(message=True, data=res)
    except ArgumentException as invalid_argument:
        logging.error(invalid_argument)
        if timing_enabled:
            logging.info(
                "api_timing create_memory error=%s parse_ms=%.2f total_ms=%.2f path=%s",
                str(invalid_argument),
                (parsed_at - started_at) * 1000,
                (time.perf_counter() - started_at) * 1000,
                request.path,
            )
        return get_error_argument_result(str(invalid_argument))
    except Exception as unexpected:
        logging.error(unexpected)
        if timing_enabled:
            logging.info(
                "api_timing create_memory error=%s parse_ms=%.2f total_ms=%.2f path=%s",
                str(unexpected),
                (parsed_at - started_at) * 1000,
                (time.perf_counter() - started_at) * 1000,
                request.path,
            )
        # The exception text is logged above; the client only gets a generic message.
        return get_json_result(code=RetCode.SERVER_ERROR, message="Internal server error")
@manager.route("/memories/<memory_id>", methods=["PUT"])  # noqa: F821
@login_required
async def update_memory(memory_id):
    """Gateway handler for ``PUT /memories/<memory_id>``.

    Only the recognised setting keys that are present in the JSON body are
    forwarded to ``memory_api_service.update_memory``; validation is left to
    the service, whose exceptions are mapped to not-found, argument-error or
    generic server-error responses.
    """
    payload = await get_request_json()
    accepted_fields = (
        "name", "permissions", "llm_id", "embd_id", "memory_type", "memory_size", "forgetting_policy", "temperature",
        "avatar", "description", "system_prompt", "user_prompt",
    )
    new_settings = {}
    for field in accepted_fields:
        if field in payload:
            new_settings[field] = payload[field]
    try:
        success, res = await memory_api_service.update_memory(memory_id, new_settings)
        if not success:
            return get_json_result(message=res, code=RetCode.SERVER_ERROR)
        return get_json_result(message=True, data=res)
    except NotFoundException as missing:
        logging.error(missing)
        return get_json_result(code=RetCode.NOT_FOUND, message=str(missing))
    except ArgumentException as invalid_argument:
        logging.error(invalid_argument)
        return get_error_argument_result(str(invalid_argument))
    except Exception as unexpected:
        logging.error(unexpected)
        return get_json_result(code=RetCode.SERVER_ERROR, message="Internal server error")
@manager.route("/memories/<memory_id>", methods=["DELETE"])  # noqa: F821
@login_required
async def delete_memory(memory_id):
    """Gateway handler for ``DELETE /memories/<memory_id>``.

    Delegates to ``memory_api_service.delete_memory``. A ``NotFoundException``
    becomes a NOT_FOUND response carrying the exception text; any other
    exception becomes a generic server-error response.
    """
    try:
        await memory_api_service.delete_memory(memory_id)
        return get_json_result(message=True)
    except NotFoundException as missing:
        logging.error(missing)
        return get_json_result(message=str(missing), code=RetCode.NOT_FOUND)
    except Exception as unexpected:
        logging.error(unexpected)
        return get_json_result(message="Internal server error", code=RetCode.SERVER_ERROR)
@manager.route("/memories", methods=["GET"])  # noqa: F821
@login_required
async def list_memory():
    """Gateway handler for ``GET /memories``.

    Query parameters:
        tenant_id, memory_type: may be repeated (or comma separated); they are
            forwarded as lists because the service indexes and measures them
            as lists.
        storage_type: single value, forwarded only when present.
        keywords: defaults to an empty string.
        page, page_size: integers, defaulting to 1 and 50.

    Returns an argument-error response when ``page`` / ``page_size`` are not
    integers, and a generic server-error response when the service raises.
    """
    args = request.args
    filter_params = {}
    for key in ("memory_type", "tenant_id"):
        # getlist keeps every occurrence of a repeated parameter; get() would
        # return only the first one, as a plain string.
        values = args.getlist(key)
        if values:
            filter_params[key] = values
    if "storage_type" in args:
        filter_params["storage_type"] = args.get("storage_type")
    keywords = args.get("keywords", "")
    try:
        page = int(args.get("page", 1))
        page_size = int(args.get("page_size", 50))
    except ValueError:
        return get_error_argument_result("'page' and 'page_size' must be integers.")
    try:
        res = await memory_api_service.list_memory(filter_params, keywords, page, page_size)
        return get_json_result(message=True, data=res)
    except Exception as e:
        logging.error(e)
        return get_json_result(code=RetCode.SERVER_ERROR, message="Internal server error")
@manager.route("/memories/<memory_id>/config", methods=["GET"])  # noqa: F821
@login_required
async def get_memory_config(memory_id):
    """Gateway handler for ``GET /memories/<memory_id>/config``.

    Returns whatever ``memory_api_service.get_memory_config`` produces as the
    response data. A ``NotFoundException`` becomes a NOT_FOUND response; any
    other exception becomes a generic server-error response.
    """
    try:
        config = await memory_api_service.get_memory_config(memory_id)
        return get_json_result(message=True, data=config)
    except NotFoundException as missing:
        logging.error(missing)
        return get_json_result(message=str(missing), code=RetCode.NOT_FOUND)
    except Exception as unexpected:
        logging.error(unexpected)
        return get_json_result(message="Internal server error", code=RetCode.SERVER_ERROR)
@manager.route("/memories/<memory_id>", methods=["GET"])  # noqa: F821
@login_required
async def get_memory_messages(memory_id):
    """Gateway handler for ``GET /memories/<memory_id>``.

    Query parameters:
        agent_id: may be repeated, or given once as a comma-separated list.
        keywords: optional; surrounding whitespace is stripped.
        page, page_size: integers, defaulting to 1 and 50.

    Returns an argument-error response when ``page`` / ``page_size`` are not
    integers, NOT_FOUND when the service raises ``NotFoundException``, and a
    generic server-error response for any other exception.
    """
    args = request.args
    agent_ids = args.getlist("agent_id")
    if len(agent_ids) == 1 and ',' in agent_ids[0]:
        agent_ids = agent_ids[0].split(',')
    keywords = args.get("keywords", "").strip()
    try:
        page = int(args.get("page", 1))
        page_size = int(args.get("page_size", 50))
    except ValueError:
        # Reject malformed pagination here instead of letting ValueError
        # escape the handler.
        return get_error_argument_result("'page' and 'page_size' must be integers.")
    try:
        res = await memory_api_service.get_memory_messages(
            memory_id, agent_ids, keywords, page, page_size
        )
        return get_json_result(message=True, data=res)
    except NotFoundException as not_found_exception:
        logging.error(not_found_exception)
        return get_json_result(code=RetCode.NOT_FOUND, message=str(not_found_exception))
    except Exception as e:
        logging.error(e)
        return get_json_result(code=RetCode.SERVER_ERROR, message="Internal server error")

View File

@ -1,291 +0,0 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import os
import time
from quart import request
from api.apps import login_required, current_user
from api.db import TenantPermission
from api.db.services.memory_service import MemoryService
from api.db.services.user_service import UserTenantService
from api.db.services.canvas_service import UserCanvasService
from api.db.services.task_service import TaskService
from api.db.joint_services.memory_message_service import get_memory_size_cache, judge_system_prompt_is_default
from api.utils.api_utils import validate_request, get_request_json, get_error_argument_result, get_json_result
from api.utils.memory_utils import format_ret_data_from_memory, get_memory_type_human
from api.constants import MEMORY_NAME_LIMIT, MEMORY_SIZE_LIMIT
from memory.services.messages import MessageService
from memory.utils.prompt_util import PromptAssembler
from common.constants import MemoryType, RetCode, ForgettingPolicy
@manager.route("/memories", methods=["POST"])  # noqa: F821
@login_required
@validate_request("name", "memory_type", "embd_id", "llm_id")
async def create_memory():
    """Validate the request body and create a memory for the current user.

    Validation failures (blank or over-long name, ``memory_type`` that is not
    a list or contains unsupported entries) return an argument-error response.
    Exceptions raised while creating the record are logged and returned as a
    server-error response. When the ``RAGFLOW_API_TIMING`` environment
    variable is set, timing lines are logged for each outcome.
    """
    timing_enabled = os.getenv("RAGFLOW_API_TIMING")
    t_start = time.perf_counter() if timing_enabled else None
    req = await get_request_json()
    t_parsed = time.perf_counter() if timing_enabled else None

    def reject(reason, message):
        # Single exit path for validation failures: optional timing line,
        # then the argument-error response.
        if timing_enabled:
            logging.info(
                "api_timing create_memory %s parse_ms=%.2f total_ms=%.2f path=%s",
                reason,
                (t_parsed - t_start) * 1000,
                (time.perf_counter() - t_start) * 1000,
                request.path,
            )
        return get_error_argument_result(message)

    # check name length
    memory_name = req["name"].strip()
    if not memory_name:
        return reject("invalid_name", "Memory name cannot be empty or whitespace.")
    if len(memory_name) > MEMORY_NAME_LIMIT:
        return reject("invalid_name", f"Memory name '{memory_name}' exceeds limit of {MEMORY_NAME_LIMIT}.")
    # check memory_type valid
    if not isinstance(req["memory_type"], list):
        return reject("invalid_memory_type", "Memory type must be a list.")
    memory_type = set(req["memory_type"])
    invalid_type = memory_type - {e.name.lower() for e in MemoryType}
    if invalid_type:
        return reject("invalid_memory_type", f"Memory type '{invalid_type}' is not supported.")
    memory_type = list(memory_type)

    try:
        t_before_db = time.perf_counter() if timing_enabled else None
        res, memory = MemoryService.create_memory(
            tenant_id=current_user.id,
            name=memory_name,
            memory_type=memory_type,
            embd_id=req["embd_id"],
            llm_id=req["llm_id"]
        )
        if timing_enabled:
            logging.info(
                "api_timing create_memory parse_ms=%.2f validate_ms=%.2f db_ms=%.2f total_ms=%.2f path=%s",
                (t_parsed - t_start) * 1000,
                (t_before_db - t_parsed) * 1000,
                (time.perf_counter() - t_before_db) * 1000,
                (time.perf_counter() - t_start) * 1000,
                request.path,
            )
        if res:
            return get_json_result(message=True, data=format_ret_data_from_memory(memory))
        return get_json_result(message=memory, code=RetCode.SERVER_ERROR)
    except Exception as e:
        # Log with traceback; previously the failure was returned without
        # leaving any trace in the server log.
        logging.exception(e)
        return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)
@manager.route("/memories/<memory_id>", methods=["PUT"])  # noqa: F821
@login_required
async def update_memory(memory_id):
    """Validate the requested settings and update the memory if anything changed.

    Returns an argument-error response for invalid values, NOT_FOUND when the
    memory does not exist, the unchanged memory when the request does not
    differ from the stored settings, and the re-read memory after a
    successful update.
    """
    req = await get_request_json()
    update_dict = {}
    # check name length
    if "name" in req:
        memory_name = req["name"].strip()
        if len(memory_name) == 0:
            return get_error_argument_result("Memory name cannot be empty or whitespace.")
        if len(memory_name) > MEMORY_NAME_LIMIT:
            return get_error_argument_result(f"Memory name '{memory_name}' exceeds limit of {MEMORY_NAME_LIMIT}.")
        update_dict["name"] = memory_name
    # check permissions valid
    if req.get("permissions"):
        if req["permissions"] not in [e.value for e in TenantPermission]:
            return get_error_argument_result(f"Unknown permission '{req['permissions']}'.")
        update_dict["permissions"] = req["permissions"]
    if req.get("llm_id"):
        update_dict["llm_id"] = req["llm_id"]
    if req.get("embd_id"):
        update_dict["embd_id"] = req["embd_id"]
    if req.get("memory_type"):
        memory_type = set(req["memory_type"])
        invalid_type = memory_type - {e.name.lower() for e in MemoryType}
        if invalid_type:
            return get_error_argument_result(f"Memory type '{invalid_type}' is not supported.")
        update_dict["memory_type"] = list(memory_type)
    # check memory_size valid
    if req.get("memory_size"):
        try:
            memory_size_value = int(req["memory_size"])
        except (TypeError, ValueError):
            memory_size_value = None
        if memory_size_value is None or not 0 < memory_size_value <= MEMORY_SIZE_LIMIT:
            return get_error_argument_result(f"Memory size should be in range (0, {MEMORY_SIZE_LIMIT}] Bytes.")
        # Store the validated integer (as is done for temperature) rather
        # than the raw request value.
        update_dict["memory_size"] = memory_size_value
    # check forgetting_policy valid
    if req.get("forgetting_policy"):
        if req["forgetting_policy"] not in [e.value for e in ForgettingPolicy]:
            return get_error_argument_result(f"Forgetting policy '{req['forgetting_policy']}' is not supported.")
        update_dict["forgetting_policy"] = req["forgetting_policy"]
    # check temperature valid
    if "temperature" in req:
        try:
            temperature = float(req["temperature"])
        except (TypeError, ValueError):
            return get_error_argument_result("Temperature should be in range [0, 1].")
        if not 0 <= temperature <= 1:
            return get_error_argument_result("Temperature should be in range [0, 1].")
        update_dict["temperature"] = temperature
    # allow update to empty fields
    for field in ["avatar", "description", "system_prompt", "user_prompt"]:
        if field in req:
            update_dict[field] = req[field]

    current_memory = MemoryService.get_by_memory_id(memory_id)
    if not current_memory:
        return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.")

    memory_dict = current_memory.to_dict()
    memory_dict.update({"memory_type": get_memory_type_human(current_memory.memory_type)})
    to_update = {}
    for k, v in update_dict.items():
        current_value = memory_dict[k]
        if isinstance(v, list):
            # Lists are compared as sets only: the requested list was rebuilt
            # from a set, so its order is arbitrary and must not count as a
            # change.
            changed = set(current_value) != set(v)
        else:
            changed = current_value != v
        if changed:
            to_update[k] = v

    if not to_update:
        return get_json_result(message=True, data=memory_dict)
    # check memory empty when update embd_id, memory_type
    memory_size = get_memory_size_cache(memory_id, current_memory.tenant_id)
    not_allowed_update = [f for f in ["embd_id", "memory_type"] if f in to_update and memory_size > 0]
    if not_allowed_update:
        return get_error_argument_result(f"Can't update {not_allowed_update} when memory isn't empty.")
    if "memory_type" in to_update:
        if "system_prompt" not in to_update and judge_system_prompt_is_default(current_memory.system_prompt, current_memory.memory_type):
            # update old default prompt, assemble a new one
            to_update["system_prompt"] = PromptAssembler.assemble_system_prompt({"memory_type": to_update["memory_type"]})
    try:
        MemoryService.update_memory(current_memory.tenant_id, memory_id, to_update)
        updated_memory = MemoryService.get_by_memory_id(memory_id)
        return get_json_result(message=True, data=format_ret_data_from_memory(updated_memory))
    except Exception as e:
        logging.error(e)
        return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)
@manager.route("/memories/<memory_id>", methods=["DELETE"])  # noqa: F821
@login_required
async def delete_memory(memory_id):
    """Delete a memory and, when a message index exists for it, its messages.

    Returns NOT_FOUND when the memory does not exist and a server-error
    response carrying the exception text when deletion raises.
    """
    memory = MemoryService.get_by_memory_id(memory_id)
    if not memory:
        # Report the failure with an explanatory message, matching the other
        # not-found responses in this module, instead of ``message=True``.
        return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.")
    try:
        MemoryService.delete_memory(memory_id)
        if MessageService.has_index(memory.tenant_id, memory_id):
            MessageService.delete_message({"memory_id": memory_id}, memory.tenant_id, memory_id)
        return get_json_result(message=True)
    except Exception as e:
        logging.error(e)
        return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)
@manager.route("/memories", methods=["GET"])  # noqa: F821
@login_required
async def list_memory():
    """List memories matching the query-string filters.

    ``tenant_id`` and ``memory_type`` may be repeated or given once as a
    comma-separated value. Without ``tenant_id`` the listing is restricted to
    the tenants related to the current user. Any exception is logged and
    returned as a server-error response carrying the exception text.
    """
    query = request.args
    try:
        tenant_ids = query.getlist("tenant_id")
        memory_types = query.getlist("memory_type")
        storage_type = query.get("storage_type")
        keywords = query.get("keywords", "")
        page = int(query.get("page", 1))
        page_size = int(query.get("page_size", 50))
        # make filter dict
        filter_dict: dict = {"storage_type": storage_type}
        if tenant_ids:
            if len(tenant_ids) == 1 and ',' in tenant_ids[0]:
                tenant_ids = tenant_ids[0].split(',')
            filter_dict["tenant_id"] = tenant_ids
        else:
            # restrict to current user's tenants
            relations = UserTenantService.get_user_tenant_relation_by_user_id(current_user.id)
            filter_dict["tenant_id"] = [relation["tenant_id"] for relation in relations]
        if memory_types and len(memory_types) == 1 and ',' in memory_types[0]:
            memory_types = memory_types[0].split(',')
        filter_dict["memory_type"] = memory_types

        memory_list, count = MemoryService.get_by_filter(filter_dict, keywords, page, page_size)
        # Replace each stored memory_type with its human-readable form.
        for memory in memory_list:
            memory.update({"memory_type": get_memory_type_human(memory["memory_type"])})
        return get_json_result(message=True, data={"memory_list": memory_list, "total_count": count})
    except Exception as error:
        logging.error(error)
        return get_json_result(message=str(error), code=RetCode.SERVER_ERROR)
@manager.route("/memories/<memory_id>/config", methods=["GET"])  # noqa: F821
@login_required
async def get_memory_config(memory_id):
    """Return the formatted configuration of one memory, or NOT_FOUND if it is missing."""
    record = MemoryService.get_with_owner_name_by_id(memory_id)
    if record:
        return get_json_result(message=True, data=format_ret_data_from_memory(record))
    return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.")
@manager.route("/memories/<memory_id>", methods=["GET"])  # noqa: F821
@login_required
async def get_memory_detail(memory_id):
    """Return one page of a memory's messages, annotated with agent names and tasks.

    Query parameters:
        agent_id: may be repeated, or given once as a comma-separated list.
        keywords: optional; surrounding whitespace is stripped.
        page, page_size: integers, defaulting to 1 and 50.

    Returns an argument-error response when ``page`` / ``page_size`` are not
    integers and NOT_FOUND when the memory does not exist.
    """
    args = request.args
    agent_ids = args.getlist("agent_id")
    if len(agent_ids) == 1 and ',' in agent_ids[0]:
        agent_ids = agent_ids[0].split(',')
    keywords = args.get("keywords", "").strip()
    try:
        page = int(args.get("page", 1))
        page_size = int(args.get("page_size", 50))
    except ValueError:
        # Reject malformed pagination instead of letting ValueError escape.
        return get_error_argument_result("'page' and 'page_size' must be integers.")
    memory = MemoryService.get_by_memory_id(memory_id)
    if not memory:
        return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.")
    messages = MessageService.list_message(
        memory.tenant_id, memory_id, agent_ids, keywords, page, page_size)
    agent_name_mapping = {}
    extract_task_mapping = {}
    if messages["message_list"]:
        agent_list = UserCanvasService.get_basic_info_by_canvas_ids([message["agent_id"] for message in messages["message_list"]])
        agent_name_mapping = {agent["id"]: agent["title"] for agent in agent_list}
        task_list = TaskService.get_tasks_progress_by_doc_ids([memory_id])
        if task_list:
            task_list.sort(key=lambda t: t["create_time"])  # asc, use newer when exist more than one task
            for task in task_list:
                # the 'digest' field carries the source_id when a task is created, so use 'digest' as key
                extract_task_mapping.update({int(task["digest"]): task})
    for message in messages["message_list"]:
        message["agent_name"] = agent_name_mapping.get(message["agent_id"], "Unknown")
        message["task"] = extract_task_mapping.get(message["message_id"], {})
        for extract_msg in message["extract"]:
            extract_msg["agent_name"] = agent_name_mapping.get(extract_msg["agent_id"], "Unknown")
    return get_json_result(data={"messages": messages, "storage_type": memory.storage_type}, message=True)

View File

View File

@ -0,0 +1,223 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from api.apps import current_user
from api.db import TenantPermission
from api.db.services.memory_service import MemoryService
from api.db.services.user_service import UserTenantService
from api.db.services.canvas_service import UserCanvasService
from api.db.services.task_service import TaskService
from api.db.joint_services.memory_message_service import get_memory_size_cache, judge_system_prompt_is_default
from api.utils.memory_utils import format_ret_data_from_memory, get_memory_type_human
from api.constants import MEMORY_NAME_LIMIT, MEMORY_SIZE_LIMIT
from memory.services.messages import MessageService
from memory.utils.prompt_util import PromptAssembler
from common.constants import MemoryType, ForgettingPolicy
from common.exceptions import ArgumentException, NotFoundException
async def create_memory(memory_info: dict):
    """Validate ``memory_info`` and create a memory owned by the current user.

    :param memory_info: {
        "name": str,
        "memory_type": list[str],
        "embd_id": str,
        "llm_id": str
    }
    :return: ``(True, formatted_memory)`` on success, otherwise ``(False, res)``
        where ``res`` is the second value returned by ``MemoryService.create_memory``.
    :raises ArgumentException: if the name is not a string, is blank or too
        long, or if ``memory_type`` is not a list of supported, hashable entries.
    """
    # check name length
    name = memory_info["name"]
    if not isinstance(name, str):
        # Without this guard a non-string name fails on .strip() with
        # AttributeError and is reported as a generic server error.
        raise ArgumentException("Memory name must be a string.")
    memory_name = name.strip()
    if not memory_name:
        raise ArgumentException("Memory name cannot be empty or whitespace.")
    if len(memory_name) > MEMORY_NAME_LIMIT:
        raise ArgumentException(f"Memory name '{memory_name}' exceeds limit of {MEMORY_NAME_LIMIT}.")
    # check memory_type valid
    requested_types = memory_info["memory_type"]
    if not isinstance(requested_types, list):
        raise ArgumentException("Memory type must be a list.")
    try:
        memory_type = set(requested_types)
    except TypeError:
        # Unhashable entries (e.g. nested lists or dicts) cannot be type names.
        raise ArgumentException("Memory type must be a list of strings.") from None
    invalid_type = memory_type - {e.name.lower() for e in MemoryType}
    if invalid_type:
        raise ArgumentException(f"Memory type '{invalid_type}' is not supported.")
    memory_type = list(memory_type)
    success, res = MemoryService.create_memory(
        tenant_id=current_user.id,
        name=memory_name,
        memory_type=memory_type,
        embd_id=memory_info["embd_id"],
        llm_id=memory_info["llm_id"]
    )
    if success:
        return True, format_ret_data_from_memory(res)
    return False, res
async def update_memory(memory_id: str, new_memory_setting: dict):
    """Validate the requested settings and apply those that differ from the stored memory.

    :param memory_id: str
    :param new_memory_setting: {
        "name": str,
        "permissions": str,
        "llm_id": str,
        "embd_id": str,
        "memory_type": list[str],
        "memory_size": int,
        "forgetting_policy": str,
        "temperature": float,
        "avatar": str,
        "description": str,
        "system_prompt": str,
        "user_prompt": str
    }
    :return: ``(True, memory_dict)`` - the unchanged memory when nothing
        differs, otherwise the re-read, formatted memory after the update.
    :raises ArgumentException: if a value is invalid, or if ``embd_id`` /
        ``memory_type`` would change while the memory is not empty.
    :raises NotFoundException: if no memory with ``memory_id`` exists.
    """
    update_dict = {}
    # check name length
    if "name" in new_memory_setting:
        name = new_memory_setting["name"]
        if not isinstance(name, str):
            raise ArgumentException("Memory name must be a string.")
        memory_name = name.strip()
        if len(memory_name) == 0:
            raise ArgumentException("Memory name cannot be empty or whitespace.")
        if len(memory_name) > MEMORY_NAME_LIMIT:
            raise ArgumentException(f"Memory name '{memory_name}' exceeds limit of {MEMORY_NAME_LIMIT}.")
        update_dict["name"] = memory_name
    # check permissions valid
    if new_memory_setting.get("permissions"):
        if new_memory_setting["permissions"] not in [e.value for e in TenantPermission]:
            raise ArgumentException(f"Unknown permission '{new_memory_setting['permissions']}'.")
        update_dict["permissions"] = new_memory_setting["permissions"]
    if new_memory_setting.get("llm_id"):
        update_dict["llm_id"] = new_memory_setting["llm_id"]
    if new_memory_setting.get("embd_id"):
        update_dict["embd_id"] = new_memory_setting["embd_id"]
    if new_memory_setting.get("memory_type"):
        requested_types = new_memory_setting["memory_type"]
        if not isinstance(requested_types, list):
            raise ArgumentException("Memory type must be a list.")
        try:
            memory_type = set(requested_types)
        except TypeError:
            # Unhashable entries cannot be type names.
            raise ArgumentException("Memory type must be a list of strings.") from None
        invalid_type = memory_type - {e.name.lower() for e in MemoryType}
        if invalid_type:
            raise ArgumentException(f"Memory type '{invalid_type}' is not supported.")
        update_dict["memory_type"] = list(memory_type)
    # check memory_size valid
    if new_memory_setting.get("memory_size"):
        try:
            memory_size_value = int(new_memory_setting["memory_size"])
        except (TypeError, ValueError):
            memory_size_value = None
        if memory_size_value is None or not 0 < memory_size_value <= MEMORY_SIZE_LIMIT:
            raise ArgumentException(f"Memory size should be in range (0, {MEMORY_SIZE_LIMIT}] Bytes.")
        # Store the validated integer (as is done for temperature) rather
        # than the raw request value.
        update_dict["memory_size"] = memory_size_value
    # check forgetting_policy valid
    if new_memory_setting.get("forgetting_policy"):
        if new_memory_setting["forgetting_policy"] not in [e.value for e in ForgettingPolicy]:
            raise ArgumentException(f"Forgetting policy '{new_memory_setting['forgetting_policy']}' is not supported.")
        update_dict["forgetting_policy"] = new_memory_setting["forgetting_policy"]
    # check temperature valid
    if "temperature" in new_memory_setting:
        try:
            temperature = float(new_memory_setting["temperature"])
        except (TypeError, ValueError):
            raise ArgumentException("Temperature should be in range [0, 1].") from None
        if not 0 <= temperature <= 1:
            raise ArgumentException("Temperature should be in range [0, 1].")
        update_dict["temperature"] = temperature
    # allow update to empty fields
    for field in ["avatar", "description", "system_prompt", "user_prompt"]:
        if field in new_memory_setting:
            update_dict[field] = new_memory_setting[field]

    current_memory = MemoryService.get_by_memory_id(memory_id)
    if not current_memory:
        raise NotFoundException(f"Memory '{memory_id}' not found.")

    memory_dict = current_memory.to_dict()
    memory_dict.update({"memory_type": get_memory_type_human(current_memory.memory_type)})
    to_update = {}
    for k, v in update_dict.items():
        current_value = memory_dict[k]
        if isinstance(v, list):
            # Lists are compared as sets only: the requested list was rebuilt
            # from a set, so its order is arbitrary and must not count as a
            # change.
            changed = set(current_value) != set(v)
        else:
            changed = current_value != v
        if changed:
            to_update[k] = v

    if not to_update:
        return True, memory_dict
    # check memory empty when update embd_id, memory_type
    memory_size = get_memory_size_cache(memory_id, current_memory.tenant_id)
    not_allowed_update = [f for f in ["embd_id", "memory_type"] if f in to_update and memory_size > 0]
    if not_allowed_update:
        raise ArgumentException(f"Can't update {not_allowed_update} when memory isn't empty.")
    if "memory_type" in to_update:
        if "system_prompt" not in to_update and judge_system_prompt_is_default(current_memory.system_prompt, current_memory.memory_type):
            # update old default prompt, assemble a new one
            to_update["system_prompt"] = PromptAssembler.assemble_system_prompt({"memory_type": to_update["memory_type"]})

    MemoryService.update_memory(current_memory.tenant_id, memory_id, to_update)
    updated_memory = MemoryService.get_by_memory_id(memory_id)
    return True, format_ret_data_from_memory(updated_memory)
async def delete_memory(memory_id):
    """Delete a memory and, when a message index exists for it, its messages.

    :return: ``True`` once the deletion calls have completed.
    :raises NotFoundException: if no memory with ``memory_id`` exists.
    """
    target = MemoryService.get_by_memory_id(memory_id)
    if not target:
        raise NotFoundException(f"Memory '{memory_id}' not found.")

    MemoryService.delete_memory(memory_id)
    index_exists = MessageService.has_index(target.tenant_id, memory_id)
    if index_exists:
        MessageService.delete_message({"memory_id": memory_id}, target.tenant_id, memory_id)
    return True
async def list_memory(filter_params: dict, keywords: str, page: int = 1, page_size: int = 50):
    """List memories matching the given filters.

    :param filter_params: optional keys ``tenant_id``, ``memory_type`` (each a
        list of values, or a single possibly comma-separated string) and
        ``storage_type``. Without ``tenant_id`` the listing is restricted to
        the tenants related to the current user.
    :param keywords: passed through to ``MemoryService.get_by_filter``.
    :param page: page number passed through to the query.
    :param page_size: page size passed through to the query.
    :return: ``{"memory_list": [...], "total_count": int}`` with each
        ``memory_type`` replaced by its human-readable form.
    """
    def as_value_list(value):
        # Accept a bare string as well as a list, and expand a lone
        # comma-separated entry; empty / missing values are returned as-is.
        if not value:
            return value
        if isinstance(value, str):
            value = [value]
        if len(value) == 1 and ',' in value[0]:
            return value[0].split(',')
        return value

    filter_dict: dict = {"storage_type": filter_params.get("storage_type")}
    tenant_ids = as_value_list(filter_params.get("tenant_id"))
    if not tenant_ids:
        # restrict to current user's tenants
        user_tenants = UserTenantService.get_user_tenant_relation_by_user_id(current_user.id)
        filter_dict["tenant_id"] = [tenant["tenant_id"] for tenant in user_tenants]
    else:
        filter_dict["tenant_id"] = tenant_ids
    filter_dict["memory_type"] = as_value_list(filter_params.get("memory_type"))

    memory_list, count = MemoryService.get_by_filter(filter_dict, keywords, page, page_size)
    for memory in memory_list:
        memory.update({"memory_type": get_memory_type_human(memory["memory_type"])})
    return {
        "memory_list": memory_list, "total_count": count
    }
async def get_memory_config(memory_id):
    """Return the formatted configuration of one memory.

    :raises NotFoundException: if no memory with ``memory_id`` exists.
    """
    record = MemoryService.get_with_owner_name_by_id(memory_id)
    if record:
        return format_ret_data_from_memory(record)
    raise NotFoundException(f"Memory '{memory_id}' not found.")
async def get_memory_messages(memory_id, agent_ids: list[str], keywords: str, page: int=1, page_size: int = 50):
    """Return one page of a memory's messages, annotated with agent names and tasks.

    Each listed message gets an ``agent_name`` (``"Unknown"`` when the agent
    is not found) and a ``task`` entry (``{}`` when no task matches its
    ``message_id``); each of its ``extract`` entries gets an ``agent_name``.

    :return: ``{"messages": ..., "storage_type": ...}``
    :raises NotFoundException: if no memory with ``memory_id`` exists.
    """
    memory = MemoryService.get_by_memory_id(memory_id)
    if not memory:
        raise NotFoundException(f"Memory '{memory_id}' not found.")

    messages = MessageService.list_message(
        memory.tenant_id, memory_id, agent_ids, keywords, page, page_size)
    agent_titles = {}
    task_by_source = {}
    if messages["message_list"]:
        canvas_ids = [item["agent_id"] for item in messages["message_list"]]
        agent_titles = {
            agent["id"]: agent["title"]
            for agent in UserCanvasService.get_basic_info_by_canvas_ids(canvas_ids)
        }
        tasks = TaskService.get_tasks_progress_by_doc_ids([memory_id])
        if tasks:
            # Ascending by creation time, so a newer task overwrites an older
            # one that shares the same key.
            tasks.sort(key=lambda t: t["create_time"])
            for task in tasks:
                # the 'digest' field carries the source_id when a task is created, so use 'digest' as key
                task_by_source[int(task["digest"])] = task

    for item in messages["message_list"]:
        item["agent_name"] = agent_titles.get(item["agent_id"], "Unknown")
        item["task"] = task_by_source.get(item["message_id"], {})
        for extracted in item["extract"]:
            extracted["agent_name"] = agent_titles.get(extracted["agent_id"], "Unknown")
    return {"messages": messages, "storage_type": memory.storage_type}