mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-05-25 10:26:59 +08:00
### Related issues Closes #14744 ### What problem does this PR solve? The Memory REST endpoint `POST /api/v1/messages` previously persisted whatever `user_id` the client sent in the JSON body. Memory rows were therefore attributed to an arbitrary string, even when the caller authenticated as a normal workspace user via JWT (browser/session-style bearer token decoded into an access token). That broke attribution and audit semantics for shared memories (team visibility): any authorized writer could spoof another subject id. The Python SDK already sends an optional `user_id` for integrations using **API keys** (`APIToken`) to tag an external subject distinct from the tenant owner user. ### Solution - Record **`g.auth_via_api_token`** in `_load_user` (`api/apps/__init__.py`): set `True` only when authentication resolves via `APIToken`, otherwise `False` after JWT-based login succeeds. - In **`POST /messages`** (`memory_api.add_message`): if the request was authenticated with an API key, keep accepting optional `user_id` from the body (default empty string). For JWT-authenticated users, **always** set stored `user_id` to **`current_user.id`** and ignore the client field. - Guard reads of `g` with **`RuntimeError`** handling so isolated imports or tests without a Quart application context do not fail when resolving `user_id`. - Document on **`RAGFlow.add_message`** that `user_id` is only meaningful for API-key authentication. ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) - [ ] New Feature (non-breaking change which adds functionality) - [ ] Documentation Update - [ ] Refactoring - [ ] Performance Improvement - [ ] Other (please describe): ### Testing - `python -m py_compile` on modified modules (`api/apps/__init__.py`, `api/apps/restful_apis/memory_api.py`). - Recommended: run web/SDK memory message tests (`test_add_message`, `test_message_routes_unit`) against a full environment with `quart` and configured services. 
### Notes for reviewers - Behavior change **only** for callers using JWT-style authorization on `POST /messages`; API-key callers keep prior optional `user_id` semantics. Co-authored-by: jony376 <jony376@gmail.com> Co-authored-by: Cursor <cursoragent@cursor.com>
315 lines
12 KiB
Python
315 lines
12 KiB
Python
#
|
|
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
import logging
|
|
import os
|
|
import time
|
|
|
|
from quart import request, g
|
|
from common.constants import LLMType, RetCode
|
|
from common.exceptions import ArgumentException, NotFoundException
|
|
from api.apps import login_required, current_user
|
|
from api.utils.api_utils import validate_request, get_request_json, get_error_argument_result, get_json_result
|
|
from api.apps.services import memory_api_service
|
|
from api.utils.tenant_utils import ensure_tenant_model_id_for_params
|
|
|
|
|
|
@manager.route("/memories", methods=["POST"])  # noqa: F821
@login_required
@validate_request("name", "memory_type", "embd_id", "llm_id")
async def create_memory():
    """Create a memory from the JSON request body.

    ``validate_request`` enforces that ``name``, ``memory_type``, ``embd_id`` and
    ``llm_id`` are present. Model names are resolved to tenant model ids via
    ``ensure_tenant_model_id_for_params(current_user.id, ...)`` before the
    service is called. A missing ``tenant_llm_id`` after resolution is
    reported as an argument error.

    When the ``RAGFLOW_API_TIMING`` environment variable is set, parse and
    total durations are logged at INFO level.

    Returns:
        JSON result carrying the created memory on success, an argument error
        for ``ArgumentException``, or ``SERVER_ERROR`` otherwise.
    """
    timing_enabled = os.getenv("RAGFLOW_API_TIMING")
    t_start = time.perf_counter() if timing_enabled else None
    req = await get_request_json()
    t_parsed = time.perf_counter() if timing_enabled else None

    def _log_failure_timing(error):
        # Shared by both error branches so the log line stays identical.
        if not timing_enabled:
            return
        logging.info(
            "api_timing create_memory error=%s parse_ms=%.2f total_ms=%.2f path=%s",
            str(error),
            (t_parsed - t_start) * 1000,
            (time.perf_counter() - t_start) * 1000,
            request.path,
        )

    try:
        req = ensure_tenant_model_id_for_params(current_user.id, req)
        if not req.get("tenant_llm_id"):
            raise ArgumentException(
                f"Tenant Model with name {req['llm_id']} and type {LLMType.CHAT.value} not found"
            )
        memory_info = {
            "name": req["name"],
            "memory_type": req["memory_type"],
            "embd_id": req["embd_id"],
            "llm_id": req["llm_id"],
            "tenant_embd_id": req["tenant_embd_id"],
            "tenant_llm_id": req["tenant_llm_id"],
        }
        success, res = await memory_api_service.create_memory(memory_info)
        if timing_enabled:
            # Sample the clock once so validate_and_db_ms + parse_ms == total_ms.
            t_done = time.perf_counter()
            logging.info(
                "api_timing create_memory parse_ms=%.2f validate_and_db_ms=%.2f total_ms=%.2f path=%s",
                (t_parsed - t_start) * 1000,
                (t_done - t_parsed) * 1000,
                (t_done - t_start) * 1000,
                request.path,
            )
        if success:
            return get_json_result(message=True, data=res)
        return get_json_result(message=res, code=RetCode.SERVER_ERROR)

    except ArgumentException as arg_error:
        logging.error(arg_error)
        _log_failure_timing(arg_error)
        return get_error_argument_result(str(arg_error))

    except Exception as e:
        # logging.exception keeps the traceback; the client only sees a generic message.
        logging.exception(e)
        _log_failure_timing(e)
        return get_json_result(code=RetCode.SERVER_ERROR, message="Internal server error")
|
|
|
|
|
|
@manager.route("/memories/<memory_id>", methods=["PUT"])  # noqa: F821
@login_required
async def update_memory(memory_id):
    """Apply a partial settings update to the memory identified by ``memory_id``.

    Only the whitelisted keys below that are present in the JSON body are
    forwarded to the service; any other body keys are ignored.

    Returns:
        The updated data on success. ``NOT_FOUND`` for ``NotFoundException``,
        an argument error for ``ArgumentException``, and ``SERVER_ERROR`` for
        any other failure or an unsuccessful service result.
    """
    payload = await get_request_json()
    editable_fields = (
        "name", "permissions", "llm_id", "embd_id", "memory_type", "memory_size", "forgetting_policy", "temperature",
        "avatar", "description", "system_prompt", "user_prompt", "tenant_llm_id", "tenant_embd_id",
    )
    new_settings = {}
    for field in editable_fields:
        if field in payload:
            new_settings[field] = payload[field]

    try:
        ok, outcome = await memory_api_service.update_memory(memory_id, new_settings)
        if not ok:
            return get_json_result(message=outcome, code=RetCode.SERVER_ERROR)
        return get_json_result(message=True, data=outcome)
    except NotFoundException as missing:
        logging.error(missing)
        return get_json_result(code=RetCode.NOT_FOUND, message=str(missing))
    except ArgumentException as bad_argument:
        logging.error(bad_argument)
        return get_error_argument_result(str(bad_argument))
    except Exception as failure:
        logging.error(failure)
        return get_json_result(code=RetCode.SERVER_ERROR, message="Internal server error")
|
|
|
|
|
|
@manager.route("/memories/<memory_id>", methods=["DELETE"])  # noqa: F821
@login_required
async def delete_memory(memory_id):
    """Delete the memory identified by ``memory_id``.

    Returns:
        ``message=True`` on success, ``NOT_FOUND`` when the service raises
        ``NotFoundException``, otherwise a generic ``SERVER_ERROR``.
    """
    try:
        await memory_api_service.delete_memory(memory_id)
        response = get_json_result(message=True)
    except NotFoundException as missing:
        logging.error(missing)
        response = get_json_result(code=RetCode.NOT_FOUND, message=str(missing))
    except Exception as failure:
        logging.error(failure)
        response = get_json_result(code=RetCode.SERVER_ERROR, message="Internal server error")
    return response
|
|
|
|
|
|
@manager.route("/memories", methods=["GET"])  # noqa: F821
@login_required
async def list_memory():
    """List memories, optionally filtered by query-string parameters.

    Query parameters:
        memory_type, tenant_id, owner_ids, storage_type: forwarded as filters
            only when present in the query string.
        keywords: optional keyword filter (``None`` when absent).
        page, page_size: integer pagination controls (defaults 1 and 50).

    Returns:
        The service result on success. An argument error when ``page`` or
        ``page_size`` is not an integer (previously an unhandled
        ``ValueError``). ``SERVER_ERROR`` on any service failure.
    """
    args = request.args
    filter_params = {
        k: args.get(k) for k in ["memory_type", "tenant_id", "owner_ids", "storage_type"] if k in args
    }
    keywords = args.get("keywords")
    try:
        page = int(args.get("page", 1))
        page_size = int(args.get("page_size", 50))
    except ValueError:
        return get_error_argument_result("page and page_size must be integers.")

    try:
        res = await memory_api_service.list_memory(filter_params, keywords, page, page_size)
        return get_json_result(message=True, data=res)
    except Exception as e:
        logging.error(e)
        return get_json_result(code=RetCode.SERVER_ERROR, message="Internal server error")
|
|
|
|
|
|
@manager.route("/memories/<memory_id>/config", methods=["GET"])  # noqa: F821
@login_required
async def get_memory_config(memory_id):
    """Return the configuration of the memory identified by ``memory_id``.

    Returns:
        The service result as ``data``. ``NOT_FOUND`` when the service raises
        ``NotFoundException``, otherwise ``SERVER_ERROR`` on failure.
    """
    try:
        config = await memory_api_service.get_memory_config(memory_id)
        response = get_json_result(message=True, data=config)
    except NotFoundException as missing:
        logging.error(missing)
        response = get_json_result(code=RetCode.NOT_FOUND, message=str(missing))
    except Exception as failure:
        logging.error(failure)
        response = get_json_result(code=RetCode.SERVER_ERROR, message="Internal server error")
    return response
|
|
|
|
|
|
@manager.route("/memories/<memory_id>", methods=["GET"])  # noqa: F821
@login_required
async def get_memory_messages(memory_id):
    """Return a page of messages stored in the memory ``memory_id``.

    Query parameters:
        agent_id: repeatable, or one comma-separated value.
        keywords: optional filter, whitespace-stripped (default ``""``).
        page, page_size: integer pagination controls (defaults 1 and 50).

    Returns:
        The service result on success. An argument error when ``page`` or
        ``page_size`` is not an integer (previously an unhandled
        ``ValueError``). ``NOT_FOUND`` for ``NotFoundException``, and
        ``SERVER_ERROR`` otherwise.
    """
    args = request.args
    agent_ids = args.getlist("agent_id")
    if len(agent_ids) == 1 and ',' in agent_ids[0]:
        agent_ids = agent_ids[0].split(',')
    keywords = args.get("keywords", "").strip()
    try:
        page = int(args.get("page", 1))
        page_size = int(args.get("page_size", 50))
    except ValueError:
        return get_error_argument_result("page and page_size must be integers.")

    try:
        res = await memory_api_service.get_memory_messages(
            memory_id, agent_ids, keywords, page, page_size
        )
        return get_json_result(message=True, data=res)
    except NotFoundException as not_found_exception:
        logging.error(not_found_exception)
        return get_json_result(code=RetCode.NOT_FOUND, message=str(not_found_exception))
    except Exception as e:
        logging.error(e)
        return get_json_result(code=RetCode.SERVER_ERROR, message="Internal server error")
|
|
|
|
|
|
def _resolve_message_user_id(req):
    """Return the ``user_id`` to store for a message added via ``POST /messages``.

    Only callers authenticated with an API key (``g.auth_via_api_token`` truthy)
    may supply an external subject id in the body. For them, an omitted or
    ``null`` ``user_id`` is stored as ``""``. Every other caller (e.g. JWT /
    session users) is attributed to ``current_user.id``, and any client-supplied
    ``user_id`` is ignored so attribution cannot be spoofed.
    """
    try:
        via_api_token = bool(getattr(g, "auth_via_api_token", False))
    except RuntimeError:
        # Raised when `g` is read without an application context (e.g. isolated
        # imports or tests); fall back to the non-trusting path.
        via_api_token = False
    if not via_api_token:
        return current_user.id
    client_user_id = req.get("user_id")
    # Treat an explicit JSON null like an omitted field.
    return "" if client_user_id is None else client_user_id


@manager.route("/messages", methods=["POST"])  # noqa: F821
@login_required
@validate_request("memory_id", "agent_id", "session_id", "user_input", "agent_response")
async def add_message():
    """Add a message built from the request body to the memory/memories in ``memory_id``.

    Required body fields: ``memory_id``, ``agent_id``, ``session_id``,
    ``user_input``, ``agent_response``. The stored ``user_id`` comes from
    ``_resolve_message_user_id``. A body ``user_id`` is honoured only for
    API-key callers.

    Exceptions raised by the service are not caught here.

    Returns:
        ``message=msg`` from the service on success, otherwise ``SERVER_ERROR``
        with the service's detail appended.
    """
    req = await get_request_json()
    memory_ids = req["memory_id"]
    message_dict = {
        "user_id": _resolve_message_user_id(req),
        "agent_id": req["agent_id"],
        "session_id": req["session_id"],
        "user_input": req["user_input"],
        "agent_response": req["agent_response"],
    }

    res, msg = await memory_api_service.add_message(memory_ids, message_dict)
    if res:
        return get_json_result(message=msg)
    # f-string instead of `+` so a non-str detail cannot raise TypeError.
    return get_json_result(message=f"Some messages failed to add. Detail:{msg}", code=RetCode.SERVER_ERROR)
|
|
|
|
|
|
@manager.route("/messages/<memory_id>:<message_id>", methods=["DELETE"])  # noqa: F821
@login_required
async def forget_message(memory_id: str, message_id: int):
    """Forget (delete) message ``message_id`` from memory ``memory_id``.

    NOTE(review): the route declares no converter for ``message_id``, so it
    presumably arrives as a string despite the ``int`` annotation — confirm
    what the service expects.

    Returns:
        The service result as ``message``. ``NOT_FOUND`` for
        ``NotFoundException``, otherwise ``SERVER_ERROR`` on failure.
    """
    try:
        outcome = await memory_api_service.forget_message(memory_id, message_id)
        response = get_json_result(message=outcome)
    except NotFoundException as missing:
        logging.error(missing)
        response = get_json_result(code=RetCode.NOT_FOUND, message=str(missing))
    except Exception as failure:
        logging.error(failure)
        response = get_json_result(code=RetCode.SERVER_ERROR, message="Internal server error")
    return response
|
|
|
|
|
|
@manager.route("/messages/<memory_id>:<message_id>", methods=["PUT"])  # noqa: F821
@login_required
@validate_request("status")
async def update_message(memory_id: str, message_id: int):
    """Set the boolean ``status`` of message ``message_id`` in memory ``memory_id``.

    The JSON body must contain ``status``. A non-boolean value is rejected
    before the service is called.

    NOTE(review): ``message_id`` has no route converter and presumably arrives
    as a string despite the ``int`` annotation — confirm against the service.

    Returns:
        The service result on success. ``SERVER_ERROR`` with a descriptive
        message when the service reports failure, ``NOT_FOUND`` for
        ``NotFoundException``, and a generic ``SERVER_ERROR`` otherwise.
    """
    payload = await get_request_json()
    new_status = payload["status"]
    if not isinstance(new_status, bool):
        return get_error_argument_result("Status must be a boolean.")

    try:
        updated = await memory_api_service.update_message_status(memory_id, message_id, new_status)
        if not updated:
            failure_text = f"Failed to set status for message '{message_id}' in memory '{memory_id}'."
            return get_json_result(code=RetCode.SERVER_ERROR, message=failure_text)
        return get_json_result(message=updated)
    except NotFoundException as missing:
        logging.error(missing)
        return get_json_result(code=RetCode.NOT_FOUND, message=str(missing))
    except Exception as failure:
        logging.error(failure)
        return get_json_result(code=RetCode.SERVER_ERROR, message="Internal server error")
|
|
|
|
|
|
@manager.route("/messages/search", methods=["GET"])  # noqa: F821
@login_required
async def search_message():
    """Search messages across one or more memories.

    Query parameters:
        memory_id: repeatable, or one comma-separated value.
        query: search text, forwarded as-is (``None`` when absent).
        similarity_threshold: float, default 0.2.
        keywords_similarity_weight: float, default 0.7.
        top_n: int, default 5.
        agent_id, session_id, user_id: optional filters (default ``""``).

    Returns:
        The service result on success. An argument error for non-numeric
        tuning parameters (previously an unhandled ``ValueError``).
        ``NOT_FOUND`` for ``NotFoundException``, and ``SERVER_ERROR`` for any
        other service failure. This handler previously had no exception
        handling, unlike its sibling routes.
    """
    args = request.args
    memory_ids = args.getlist("memory_id")
    if len(memory_ids) == 1 and ',' in memory_ids[0]:
        memory_ids = memory_ids[0].split(',')
    query = args.get("query")
    try:
        similarity_threshold = float(args.get("similarity_threshold", 0.2))
        keywords_similarity_weight = float(args.get("keywords_similarity_weight", 0.7))
        top_n = int(args.get("top_n", 5))
    except ValueError:
        return get_error_argument_result(
            "similarity_threshold and keywords_similarity_weight must be numbers and top_n must be an integer."
        )

    filter_dict = {
        "memory_id": memory_ids,
        "agent_id": args.get("agent_id", ""),
        "session_id": args.get("session_id", ""),
        "user_id": args.get("user_id", ""),
    }
    params = {
        "query": query,
        "similarity_threshold": similarity_threshold,
        "keywords_similarity_weight": keywords_similarity_weight,
        "top_n": top_n,
    }
    try:
        res = await memory_api_service.search_message(filter_dict, params)
        return get_json_result(message=True, data=res)
    except NotFoundException as not_found_exception:
        logging.error(not_found_exception)
        return get_json_result(code=RetCode.NOT_FOUND, message=str(not_found_exception))
    except Exception as e:
        logging.error(e)
        return get_json_result(code=RetCode.SERVER_ERROR, message="Internal server error")
|
|
|
|
@manager.route("/messages", methods=["GET"])  # noqa: F821
@login_required
async def get_messages():
    """Return messages from one or more memories.

    Query parameters:
        memory_id: repeatable, or one comma-separated value; required.
        agent_id, session_id: optional filters (default ``""``).
        limit: integer passed to the service (default 10).

    Returns:
        The service result on success. An argument error when ``memory_id`` is
        missing or ``limit`` is not an integer (previously an unhandled
        ``ValueError``). ``SERVER_ERROR`` on any service failure.
    """
    args = request.args
    memory_ids = args.getlist("memory_id")
    if len(memory_ids) == 1 and ',' in memory_ids[0]:
        memory_ids = memory_ids[0].split(',')
    agent_id = args.get("agent_id", "")
    session_id = args.get("session_id", "")
    if not memory_ids:
        return get_error_argument_result("memory_ids is required.")
    try:
        limit = int(args.get("limit", 10))
    except ValueError:
        return get_error_argument_result("limit must be an integer.")

    try:
        res = await memory_api_service.get_messages(memory_ids, agent_id, session_id, limit)
        return get_json_result(message=True, data=res)
    except Exception as e:
        logging.error(e)
        return get_json_result(code=RetCode.SERVER_ERROR, message="Internal server error")
|
|
|
|
|
|
@manager.route("/messages/<memory_id>:<message_id>/content", methods=["GET"])  # noqa: F821
@login_required
async def get_message_content(memory_id: str, message_id: int):
    """Return the content of message ``message_id`` in memory ``memory_id``.

    NOTE(review): ``message_id`` has no route converter and presumably arrives
    as a string despite the ``int`` annotation — confirm against the service.

    Returns:
        The service result as ``data``. ``NOT_FOUND`` for
        ``NotFoundException``, otherwise ``SERVER_ERROR`` on failure.
    """
    try:
        content = await memory_api_service.get_message_content(memory_id, message_id)
        response = get_json_result(message=True, data=content)
    except NotFoundException as missing:
        logging.error(missing)
        response = get_json_result(code=RetCode.NOT_FOUND, message=str(missing))
    except Exception as failure:
        logging.error(failure)
        response = get_json_result(code=RetCode.SERVER_ERROR, message="Internal server error")
    return response
|