mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-05-20 16:26:42 +08:00
### What problem does this PR solve?

This PR fixes missing authorization checks in the Memory API. Previously, several authenticated endpoints accepted caller-supplied `tenant_id`, `owner_ids`, or `memory_id` values and used them directly to list, read, update, delete, or search Memory data. That could allow an authenticated user to access or mutate another tenant's Memory records if they knew a tenant ID or memory ID.

The fix centralizes Memory access checks and applies them consistently across Memory and Memory-message operations.

The change:

- Adds helper logic to parse list filters and compute tenant IDs accessible to `current_user`.
- Requires direct `memory_id` operations to pass Memory access checks before reading, updating, deleting, or changing message state.
- Filters list/search/recent-message requests to accessible memories only.
- Applies Memory visibility filtering before count and pagination in `MemoryService.get_by_filter`.
- Accepts `owner_ids` in the Memory list route, matching the frontend owner filter while still intersecting values with the caller's accessible tenants.

### Related issues

Closes #14534

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

Co-authored-by: jony376 <jony376@gmail.com>
200 lines
7.2 KiB
Python
200 lines
7.2 KiB
Python
#
|
|
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
from typing import List
|
|
|
|
from api.db.db_models import DB, Memory, User
|
|
from api.db.services import duplicate_name
|
|
from api.db.services.common_service import CommonService
|
|
from api.utils.memory_utils import calculate_memory_type
|
|
from api.constants import MEMORY_NAME_LIMIT
|
|
from common.misc_utils import get_uuid
|
|
from common.time_utils import get_format_time, current_timestamp
|
|
from memory.utils.prompt_util import PromptAssembler
|
|
|
|
|
|
class MemoryService(CommonService):
    """Service class for managing Memory records (lookup, listing, create, update, delete)."""

    model = Memory

    @classmethod
    @DB.connection_context()
    def get_by_memory_id(cls, memory_id: str):
        """Return the first Memory row whose ``id`` equals ``memory_id``.

        The result comes from peewee's ``.first()``, which yields ``None``
        when no row matches.
        """
        return cls.model.select().where(cls.model.id == memory_id).first()

    @classmethod
    @DB.connection_context()
    def get_by_tenant_id(cls, tenant_id: str):
        """Return a select query for the Memory rows owned by ``tenant_id``.

        NOTE(review): the query object is returned without being evaluated
        here, so it runs when the caller iterates it -- presumably after this
        connection context has exited. Confirm that is intended.
        """
        return cls.model.select().where(cls.model.tenant_id == tenant_id)

    @classmethod
    @DB.connection_context()
    def get_all_memory(cls):
        """Return every Memory row as a list."""
        return list(cls.model.select())

    @classmethod
    @DB.connection_context()
    def get_with_owner_name_by_id(cls, memory_id: str):
        """Return one Memory row joined with its owner's nickname.

        The row is joined to ``User`` on ``Memory.tenant_id == User.id`` and
        the user's nickname is exposed under the alias ``owner_name``.

        Args:
            memory_id: id of the Memory row to load.

        Returns:
            The first matching row (``None`` when nothing matches).
        """
        fields = [
            cls.model.id,
            cls.model.name,
            cls.model.avatar,
            cls.model.tenant_id,
            User.nickname.alias("owner_name"),
            cls.model.memory_type,
            cls.model.storage_type,
            cls.model.embd_id,
            cls.model.llm_id,
            cls.model.permissions,
            cls.model.description,
            cls.model.memory_size,
            cls.model.forgetting_policy,
            cls.model.temperature,
            cls.model.system_prompt,
            cls.model.user_prompt,
            cls.model.create_date,
            cls.model.create_time,
        ]
        return (
            cls.model.select(*fields)
            .join(User, on=(cls.model.tenant_id == User.id))
            .where(cls.model.id == memory_id)
            .first()
        )

    @classmethod
    @DB.connection_context()
    def get_by_filter(cls, filter_dict: dict, keywords: str, page: int = 1, page_size: int = 50):
        """List Memory rows matching the given filters, one page at a time.

        Every filter is optional; a missing or falsy value skips that filter.

        Args:
            filter_dict: may contain the keys
                ``tenant_id`` (collection of tenant ids, matched with ``IN``),
                ``accessible_user_id`` (keeps rows whose ``tenant_id`` equals
                this value or whose ``permissions`` equals ``"team"``),
                ``memory_type`` (passed to ``calculate_memory_type`` and
                matched as a bitmask against the stored ``memory_type``),
                ``storage_type`` (exact match).
            keywords: substring that the memory name must contain.
            page: page number handed to peewee's ``paginate``.
            page_size: rows per page handed to peewee's ``paginate``.

        Returns:
            ``(rows, count)`` where ``rows`` is a list of dicts for the
            requested page, ordered by ``update_time`` descending, and
            ``count`` is the number of rows matching all filters before
            pagination.
        """
        fields = [
            cls.model.id,
            cls.model.name,
            cls.model.avatar,
            cls.model.tenant_id,
            User.nickname.alias("owner_name"),
            cls.model.memory_type,
            cls.model.storage_type,
            cls.model.permissions,
            cls.model.description,
            cls.model.create_time,
            cls.model.create_date,
        ]
        memories = cls.model.select(*fields).join(User, on=(cls.model.tenant_id == User.id))

        tenant_ids = filter_dict.get("tenant_id")
        if tenant_ids:
            memories = memories.where(cls.model.tenant_id.in_(tenant_ids))

        accessible_user_id = filter_dict.get("accessible_user_id")
        if accessible_user_id:
            # Visibility predicate: the user's own memories or any memory
            # shared with "team" permissions.
            # NOTE(review): presumably callers combine this with a
            # ``tenant_id`` list so that "team" rows are limited to tenants
            # the user belongs to -- confirm against the route handlers.
            memories = memories.where(
                (cls.model.tenant_id == accessible_user_id)
                | (cls.model.permissions == "team")
            )

        memory_types = filter_dict.get("memory_type")
        if memory_types:
            memory_type_int = calculate_memory_type(memory_types)
            memories = memories.where(cls.model.memory_type.bin_and(memory_type_int) > 0)

        storage_type = filter_dict.get("storage_type")
        if storage_type:
            memories = memories.where(cls.model.storage_type == storage_type)

        if keywords:
            memories = memories.where(cls.model.name.contains(keywords))

        # Count after every filter but before pagination so the total
        # reflects the full filtered set rather than a single page.
        count = memories.count()
        memories = memories.order_by(cls.model.update_time.desc()).paginate(page, page_size)

        return list(memories.dicts()), count

    @classmethod
    @DB.connection_context()
    def create_memory(cls, tenant_id: str, name: str, memory_type: List[str], embd_id: str, tenant_embd_id: int, llm_id: str, tenant_llm_id: int):
        """Insert a new Memory row for ``tenant_id``.

        The requested name is first passed through ``duplicate_name`` to
        deduplicate it within the tenant; the resulting name is what gets
        length-checked and stored.

        Returns:
            ``(False, error_message)`` when the deduplicated name is longer
            than ``MEMORY_NAME_LIMIT`` or when the insert reports failure;
            otherwise ``(save_result, row)`` where ``save_result`` is the
            truthy value returned by ``Model.save`` and ``row`` is the newly
            inserted row re-read from the database.
        """
        # Deduplicate name within tenant
        memory_name = duplicate_name(
            cls.query,
            name=name,
            tenant_id=tenant_id,
        )
        if len(memory_name) > MEMORY_NAME_LIMIT:
            return False, f"Memory name {memory_name} exceeds limit of {MEMORY_NAME_LIMIT}."

        timestamp = current_timestamp()
        format_time = get_format_time()
        # build create dict
        memory_info = {
            "id": get_uuid(),
            "name": memory_name,
            "memory_type": calculate_memory_type(memory_type),
            "tenant_id": tenant_id,
            "embd_id": embd_id,
            "tenant_embd_id": tenant_embd_id,
            "llm_id": llm_id,
            "tenant_llm_id": tenant_llm_id,
            "system_prompt": PromptAssembler.assemble_system_prompt({"memory_type": memory_type}),
            "create_time": timestamp,
            "create_date": format_time,
            "update_time": timestamp,
            "update_date": format_time,
        }
        # force_insert is needed because the primary key is supplied
        # explicitly; without it peewee would attempt an UPDATE.
        saved = cls.model(**memory_info).save(force_insert=True)

        if not saved:
            return False, "Could not create new memory."

        db_row = cls.model.select().where(cls.model.id == memory_info["id"]).first()

        return saved, db_row

    @classmethod
    @DB.connection_context()
    def update_memory(cls, tenant_id: str, memory_id: str, update_dict: dict):
        """Apply ``update_dict`` to the Memory row identified by ``memory_id``.

        Before the UPDATE is issued the values are normalised:
        a string ``temperature`` is converted with ``float``, a list
        ``memory_type`` is converted with ``calculate_memory_type``, and a
        ``name`` is passed through ``duplicate_name`` for ``tenant_id``.
        ``update_time`` and ``update_date`` are always refreshed.

        The caller's ``update_dict`` is left untouched; normalisation happens
        on a shallow copy.

        Args:
            tenant_id: tenant used for name deduplication.
            memory_id: id of the row to update.
            update_dict: column name -> new value.

        Returns:
            ``0`` when ``update_dict`` is empty, otherwise the result of
            executing the UPDATE query.

        Raises:
            ValueError: if ``temperature`` is a string that ``float`` cannot
                parse.
        """
        if not update_dict:
            return 0

        # Shallow copy so normalisation and the timestamp fields do not leak
        # back into the dict owned by the caller.
        changes = dict(update_dict)

        if isinstance(changes.get("temperature"), str):
            changes["temperature"] = float(changes["temperature"])
        if isinstance(changes.get("memory_type"), list):
            changes["memory_type"] = calculate_memory_type(changes["memory_type"])
        if "name" in changes:
            changes["name"] = duplicate_name(
                cls.query,
                name=changes["name"],
                tenant_id=tenant_id,
            )
        changes.update({
            "update_time": current_timestamp(),
            "update_date": get_format_time(),
        })

        return cls.model.update(changes).where(cls.model.id == memory_id).execute()

    @classmethod
    @DB.connection_context()
    def delete_memory(cls, memory_id: str):
        """Delete the Memory row with ``memory_id`` via ``cls.delete_by_id`` and return its result."""
        return cls.delete_by_id(memory_id)

    @classmethod
    @DB.connection_context()
    def get_null_tenant_embd_id_row(cls):
        """Return the rows whose ``tenant_embd_id`` is NULL.

        Only ``id``, ``tenant_id`` and ``embd_id`` are selected.
        """
        fields = [
            cls.model.id,
            cls.model.tenant_id,
            cls.model.embd_id,
        ]
        objs = cls.model.select(*fields).where(cls.model.tenant_embd_id.is_null())
        return list(objs)

    @classmethod
    @DB.connection_context()
    def get_null_tenant_llm_id_row(cls):
        """Return the rows whose ``tenant_llm_id`` is NULL.

        Only ``id``, ``tenant_id`` and ``llm_id`` are selected.
        """
        fields = [
            cls.model.id,
            cls.model.tenant_id,
            cls.model.llm_id,
        ]
        objs = cls.model.select(*fields).where(cls.model.tenant_llm_id.is_null())
        return list(objs)
|