mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-05-23 01:18:22 +08:00
### What problem does this PR solve?

This PR fixes missing authorization checks in the Memory API. Previously, several authenticated endpoints accepted caller-supplied `tenant_id`, `owner_ids`, or `memory_id` values and used them directly to list, read, update, delete, or search Memory data. That could allow an authenticated user to access or mutate another tenant's Memory records if they knew a tenant ID or memory ID.

The fix centralizes Memory access checks and applies them consistently across Memory and Memory-message operations. The change:

- Adds helper logic to parse list filters and compute tenant IDs accessible to `current_user`.
- Requires direct `memory_id` operations to pass Memory access checks before reading, updating, deleting, or changing message state.
- Filters list/search/recent-message requests to accessible memories only.
- Applies Memory visibility filtering before count and pagination in `MemoryService.get_by_filter`.
- Accepts `owner_ids` in the Memory list route, matching the frontend owner filter while still intersecting values with the caller's accessible tenants.

### Related issues

Closes #14534

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

Co-authored-by: jony376 <jony376@gmail.com>
305 lines
12 KiB
Python
305 lines
12 KiB
Python
#
|
|
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
import logging
|
|
import os
|
|
import time
|
|
|
|
from quart import request
|
|
from common.constants import LLMType, RetCode
|
|
from common.exceptions import ArgumentException, NotFoundException
|
|
from api.apps import login_required, current_user
|
|
from api.utils.api_utils import validate_request, get_request_json, get_error_argument_result, get_json_result
|
|
from api.apps.services import memory_api_service
|
|
from api.utils.tenant_utils import ensure_tenant_model_id_for_params
|
|
|
|
|
|
@manager.route("/memories", methods=["POST"])  # noqa: F821
@login_required
@validate_request("name", "memory_type", "embd_id", "llm_id")
async def create_memory():
    """Create a memory from the JSON request body.

    The body is passed through ``ensure_tenant_model_id_for_params`` for the
    current user. If the result carries no ``tenant_llm_id`` the request is
    rejected as an argument error; otherwise the selected fields are handed
    to ``memory_api_service.create_memory``.

    When the ``RAGFLOW_API_TIMING`` environment variable is set to a
    non-empty value, parse and total durations are logged at INFO level on
    both the success and the failure paths.

    Returns:
        A JSON result: the service payload on success, the service message
        with ``RetCode.SERVER_ERROR`` when the service reports failure, an
        argument-error result for ``ArgumentException``, or a generic
        "Internal server error" result for any other exception.
    """
    timing_on = os.getenv("RAGFLOW_API_TIMING")
    started_at = time.perf_counter() if timing_on else None
    req = await get_request_json()
    parsed_at = time.perf_counter() if timing_on else None
    try:
        req = ensure_tenant_model_id_for_params(current_user.id, req)
        if not req.get("tenant_llm_id"):
            raise ArgumentException(
                f"Tenant Model with name {req['llm_id']} and type {LLMType.CHAT.value} not found"
            )

        # Only these fields are forwarded to the service layer.
        forwarded_fields = (
            "name",
            "memory_type",
            "embd_id",
            "llm_id",
            "tenant_embd_id",
            "tenant_llm_id",
        )
        memory_info = {field: req[field] for field in forwarded_fields}
        success, res = await memory_api_service.create_memory(memory_info)

        if timing_on:
            logging.info(
                "api_timing create_memory parse_ms=%.2f validate_and_db_ms=%.2f total_ms=%.2f path=%s",
                (parsed_at - started_at) * 1000,
                (time.perf_counter() - parsed_at) * 1000,
                (time.perf_counter() - started_at) * 1000,
                request.path,
            )

        if not success:
            return get_json_result(message=res, code=RetCode.SERVER_ERROR)
        return get_json_result(message=True, data=res)

    except ArgumentException as arg_error:
        logging.error(arg_error)
        if timing_on:
            logging.info(
                "api_timing create_memory error=%s parse_ms=%.2f total_ms=%.2f path=%s",
                str(arg_error),
                (parsed_at - started_at) * 1000,
                (time.perf_counter() - started_at) * 1000,
                request.path,
            )
        return get_error_argument_result(str(arg_error))

    except Exception as err:
        logging.error(err)
        if timing_on:
            logging.info(
                "api_timing create_memory error=%s parse_ms=%.2f total_ms=%.2f path=%s",
                str(err),
                (parsed_at - started_at) * 1000,
                (time.perf_counter() - started_at) * 1000,
                request.path,
            )
        return get_json_result(code=RetCode.SERVER_ERROR, message="Internal server error")
|
|
|
|
|
|
@manager.route("/memories/<memory_id>", methods=["PUT"])  # noqa: F821
@login_required
async def update_memory(memory_id):
    """Update settings of the memory identified by ``memory_id``.

    Only whitelisted keys present in the JSON body are forwarded to
    ``memory_api_service.update_memory``; any other key is ignored.

    Returns:
        A JSON result: the service payload on success, the service message
        with ``RetCode.SERVER_ERROR`` when the service reports failure,
        ``RetCode.NOT_FOUND`` for ``NotFoundException``, an argument-error
        result for ``ArgumentException``, or a generic "Internal server
        error" result for any other exception.
    """
    req = await get_request_json()
    updatable_keys = (
        "name", "permissions", "llm_id", "embd_id", "memory_type", "memory_size", "forgetting_policy", "temperature",
        "avatar", "description", "system_prompt", "user_prompt", "tenant_llm_id", "tenant_embd_id",
    )
    new_settings = {key: req[key] for key in updatable_keys if key in req}
    try:
        success, res = await memory_api_service.update_memory(memory_id, new_settings)
        if not success:
            return get_json_result(message=res, code=RetCode.SERVER_ERROR)
        return get_json_result(message=True, data=res)
    except NotFoundException as missing:
        logging.error(missing)
        return get_json_result(code=RetCode.NOT_FOUND, message=str(missing))
    except ArgumentException as bad_argument:
        logging.error(bad_argument)
        return get_error_argument_result(str(bad_argument))
    except Exception as err:
        logging.error(err)
        return get_json_result(code=RetCode.SERVER_ERROR, message="Internal server error")
|
|
|
|
|
|
@manager.route("/memories/<memory_id>", methods=["DELETE"])  # noqa: F821
@login_required
async def delete_memory(memory_id):
    """Delete the memory identified by ``memory_id``.

    Delegates to ``memory_api_service.delete_memory``.

    Returns:
        A JSON result with ``message=True`` on success,
        ``RetCode.NOT_FOUND`` for ``NotFoundException``, or a generic
        "Internal server error" result for any other exception.
    """
    try:
        await memory_api_service.delete_memory(memory_id)
        return get_json_result(message=True)
    except NotFoundException as missing:
        logging.error(missing)
        return get_json_result(code=RetCode.NOT_FOUND, message=str(missing))
    except Exception as err:
        logging.error(err)
        return get_json_result(code=RetCode.SERVER_ERROR, message="Internal server error")
|
|
|
|
|
|
@manager.route("/memories", methods=["GET"])  # noqa: F821
@login_required
async def list_memory():
    """List memories matching the query-string filters, with pagination.

    Query parameters:
        memory_type, tenant_id, owner_ids, storage_type: optional filters,
            forwarded only when present in the query string.
        keywords: optional keyword filter, forwarded as-is.
        page: page number, defaults to 1.
        page_size: page size, defaults to 50.

    Returns:
        A JSON result with the service payload on success, an
        argument-error result when ``page`` or ``page_size`` is not an
        integer, or a generic "Internal server error" result when the
        service call raises.
    """
    args = request.args
    filter_params = {
        key: args.get(key)
        for key in ("memory_type", "tenant_id", "owner_ids", "storage_type")
        if key in args
    }
    keywords = args.get("keywords")

    # Reject malformed pagination values explicitly instead of letting the
    # ValueError from int() escape the handler.
    try:
        page = int(args.get("page", 1))
        page_size = int(args.get("page_size", 50))
    except (TypeError, ValueError):
        return get_error_argument_result("page and page_size must be integers.")

    try:
        res = await memory_api_service.list_memory(filter_params, keywords, page, page_size)
        return get_json_result(message=True, data=res)
    except Exception as e:
        logging.error(e)
        return get_json_result(code=RetCode.SERVER_ERROR, message="Internal server error")
|
|
|
|
|
|
@manager.route("/memories/<memory_id>/config", methods=["GET"])  # noqa: F821
@login_required
async def get_memory_config(memory_id):
    """Return the configuration of the memory identified by ``memory_id``.

    Delegates to ``memory_api_service.get_memory_config``.

    Returns:
        A JSON result with the service payload on success,
        ``RetCode.NOT_FOUND`` for ``NotFoundException``, or a generic
        "Internal server error" result for any other exception.
    """
    try:
        config = await memory_api_service.get_memory_config(memory_id)
        return get_json_result(message=True, data=config)
    except NotFoundException as missing:
        logging.error(missing)
        return get_json_result(code=RetCode.NOT_FOUND, message=str(missing))
    except Exception as err:
        logging.error(err)
        return get_json_result(code=RetCode.SERVER_ERROR, message="Internal server error")
|
|
|
|
|
|
@manager.route("/memories/<memory_id>", methods=["GET"])  # noqa: F821
@login_required
async def get_memory_messages(memory_id):
    """Return a page of messages stored in the memory ``memory_id``.

    Query parameters:
        agent_id: may be repeated, or given once as a comma-separated list.
        keywords: optional keyword filter; surrounding whitespace is
            stripped, defaults to an empty string.
        page: page number, defaults to 1.
        page_size: page size, defaults to 50.

    Returns:
        A JSON result with the service payload on success, an
        argument-error result when ``page`` or ``page_size`` is not an
        integer, ``RetCode.NOT_FOUND`` for ``NotFoundException``, or a
        generic "Internal server error" result for any other exception.
    """
    args = request.args
    agent_ids = args.getlist("agent_id")
    # A single "a,b,c" value is treated as a list of agent ids.
    if len(agent_ids) == 1 and ',' in agent_ids[0]:
        agent_ids = agent_ids[0].split(',')
    keywords = args.get("keywords", "").strip()

    # Reject malformed pagination values explicitly instead of letting the
    # ValueError from int() escape the handler.
    try:
        page = int(args.get("page", 1))
        page_size = int(args.get("page_size", 50))
    except (TypeError, ValueError):
        return get_error_argument_result("page and page_size must be integers.")

    try:
        res = await memory_api_service.get_memory_messages(
            memory_id, agent_ids, keywords, page, page_size
        )
        return get_json_result(message=True, data=res)
    except NotFoundException as not_found_exception:
        logging.error(not_found_exception)
        return get_json_result(code=RetCode.NOT_FOUND, message=str(not_found_exception))
    except Exception as e:
        logging.error(e)
        return get_json_result(code=RetCode.SERVER_ERROR, message="Internal server error")
|
|
|
|
|
|
@manager.route("/messages", methods=["POST"])  # noqa: F821
@login_required
@validate_request("memory_id", "agent_id", "session_id", "user_input", "agent_response")
async def add_message():
    """Add one message to the memory or memories named in the JSON body.

    The body's ``memory_id`` value is passed unchanged to
    ``memory_api_service.add_message`` together with a message dict built
    from ``user_id`` (optional), ``agent_id``, ``session_id``,
    ``user_input`` and ``agent_response``.

    Returns:
        A JSON result carrying the service message on success, or a
        ``RetCode.SERVER_ERROR`` result with the failure detail appended
        when the service reports failure.
    """
    req = await get_request_json()
    memory_ids = req["memory_id"]

    required_fields = ("agent_id", "session_id", "user_input", "agent_response")
    message_dict = {"user_id": req.get("user_id")}
    message_dict.update((field, req[field]) for field in required_fields)

    ok, detail = await memory_api_service.add_message(memory_ids, message_dict)
    if not ok:
        return get_json_result(message="Some messages failed to add. Detail:" + detail, code=RetCode.SERVER_ERROR)

    return get_json_result(message=detail)
|
|
|
|
|
|
@manager.route("/messages/<memory_id>:<message_id>", methods=["DELETE"])  # noqa: F821
@login_required
async def forget_message(memory_id: str, message_id: int):
    """Forget the message ``message_id`` in the memory ``memory_id``.

    Delegates to ``memory_api_service.forget_message`` and returns its
    result as the response message.

    NOTE(review): the route declares no converter for ``message_id``, so it
    presumably arrives as a string despite the ``int`` annotation — confirm.

    Returns:
        A JSON result with the service result as ``message`` on success,
        ``RetCode.NOT_FOUND`` for ``NotFoundException``, or a generic
        "Internal server error" result for any other exception.
    """
    try:
        outcome = await memory_api_service.forget_message(memory_id, message_id)
        return get_json_result(message=outcome)
    except NotFoundException as missing:
        logging.error(missing)
        return get_json_result(code=RetCode.NOT_FOUND, message=str(missing))
    except Exception as err:
        logging.error(err)
        return get_json_result(code=RetCode.SERVER_ERROR, message="Internal server error")
|
|
|
|
|
|
@manager.route("/messages/<memory_id>:<message_id>", methods=["PUT"])  # noqa: F821
@login_required
@validate_request("status")
async def update_message(memory_id: str, message_id: int):
    """Set the boolean status of the message ``message_id`` in ``memory_id``.

    The JSON body must contain ``status`` as a real boolean; any other type
    is rejected before the service is called.

    NOTE(review): the route declares no converter for ``message_id``, so it
    presumably arrives as a string despite the ``int`` annotation — confirm.

    Returns:
        A JSON result with the service result as ``message`` on success, an
        argument-error result when ``status`` is not a bool,
        ``RetCode.SERVER_ERROR`` when the service reports a falsy result,
        ``RetCode.NOT_FOUND`` for ``NotFoundException``, or a generic
        "Internal server error" result for any other exception.
    """
    req = await get_request_json()
    status = req["status"]
    if not isinstance(status, bool):
        return get_error_argument_result("Status must be a boolean.")

    try:
        updated = await memory_api_service.update_message_status(memory_id, message_id, status)
        if not updated:
            return get_json_result(code=RetCode.SERVER_ERROR, message=f"Failed to set status for message '{message_id}' in memory '{memory_id}'.")
        return get_json_result(message=updated)
    except NotFoundException as missing:
        logging.error(missing)
        return get_json_result(code=RetCode.NOT_FOUND, message=str(missing))
    except Exception as err:
        logging.error(err)
        return get_json_result(code=RetCode.SERVER_ERROR, message="Internal server error")
|
|
|
|
|
|
@manager.route("/messages/search", methods=["GET"])  # noqa: F821
@login_required
async def search_message():
    """Search messages across one or more memories.

    Query parameters:
        memory_id: may be repeated, or given once as a comma-separated list.
        query: the search query, forwarded as-is (``None`` when absent).
        similarity_threshold: float, defaults to 0.2.
        keywords_similarity_weight: float, defaults to 0.7.
        top_n: integer, defaults to 5.
        agent_id, session_id, user_id: optional filters, default to "".

    Returns:
        A JSON result with the payload from
        ``memory_api_service.search_message``, or an argument-error result
        when one of the numeric parameters cannot be parsed. Exceptions
        raised by the service call are not caught here.
    """
    args = request.args
    memory_ids = args.getlist("memory_id")
    # A single "a,b,c" value is treated as a list of memory ids.
    if len(memory_ids) == 1 and ',' in memory_ids[0]:
        memory_ids = memory_ids[0].split(',')
    query = args.get("query")

    # Reject malformed numeric values explicitly instead of letting the
    # ValueError from float()/int() escape the handler.
    try:
        similarity_threshold = float(args.get("similarity_threshold", 0.2))
        keywords_similarity_weight = float(args.get("keywords_similarity_weight", 0.7))
        top_n = int(args.get("top_n", 5))
    except (TypeError, ValueError):
        return get_error_argument_result(
            "similarity_threshold and keywords_similarity_weight must be numbers, and top_n must be an integer."
        )

    agent_id = args.get("agent_id", "")
    session_id = args.get("session_id", "")
    user_id = args.get("user_id", "")

    filter_dict = {
        "memory_id": memory_ids,
        "agent_id": agent_id,
        "session_id": session_id,
        "user_id": user_id,
    }
    params = {
        "query": query,
        "similarity_threshold": similarity_threshold,
        "keywords_similarity_weight": keywords_similarity_weight,
        "top_n": top_n,
    }
    res = await memory_api_service.search_message(filter_dict, params)
    return get_json_result(message=True, data=res)
|
|
|
|
@manager.route("/messages", methods=["GET"])  # noqa: F821
@login_required
async def get_messages():
    """Return up to ``limit`` messages from the given memories.

    Query parameters:
        memory_id: required; may be repeated, or given once as a
            comma-separated list.
        agent_id, session_id: optional filters, default to "".
        limit: integer, defaults to 10.

    Returns:
        A JSON result with the service payload on success, an
        argument-error result when ``limit`` is not an integer or no
        ``memory_id`` was supplied, or a generic "Internal server error"
        result when the service call raises.
    """
    args = request.args
    memory_ids = args.getlist("memory_id")
    # A single "a,b,c" value is treated as a list of memory ids.
    if len(memory_ids) == 1 and ',' in memory_ids[0]:
        memory_ids = memory_ids[0].split(',')
    agent_id = args.get("agent_id", "")
    session_id = args.get("session_id", "")

    # Reject a malformed limit explicitly instead of letting the ValueError
    # from int() escape the handler.
    try:
        limit = int(args.get("limit", 10))
    except (TypeError, ValueError):
        return get_error_argument_result("limit must be an integer.")

    if not memory_ids:
        return get_error_argument_result("memory_ids is required.")
    try:
        res = await memory_api_service.get_messages(memory_ids, agent_id, session_id, limit)
        return get_json_result(message=True, data=res)
    except Exception as e:
        logging.error(e)
        return get_json_result(code=RetCode.SERVER_ERROR, message="Internal server error")
|
|
|
|
|
|
@manager.route("/messages/<memory_id>:<message_id>/content", methods=["GET"])  # noqa: F821
@login_required
async def get_message_content(memory_id: str, message_id: int):
    """Return the content of the message ``message_id`` in ``memory_id``.

    Delegates to ``memory_api_service.get_message_content``.

    NOTE(review): the route declares no converter for ``message_id``, so it
    presumably arrives as a string despite the ``int`` annotation — confirm.

    Returns:
        A JSON result with the service payload on success,
        ``RetCode.NOT_FOUND`` for ``NotFoundException``, or a generic
        "Internal server error" result for any other exception.
    """
    try:
        content = await memory_api_service.get_message_content(memory_id, message_id)
        return get_json_result(message=True, data=content)
    except NotFoundException as missing:
        logging.error(missing)
        return get_json_result(code=RetCode.NOT_FOUND, message=str(missing))
    except Exception as err:
        logging.error(err)
        return get_json_result(code=RetCode.SERVER_ERROR, message="Internal server error")
|