Files
ragflow/api/apps/sdk/doc.py
jony376 7f699d1202 Fix: enforce tenant authorization for tenant_rerank_id in retrieval flows (#14782)
### Related issues

Closes #14781 

### What problem does this PR solve?

Some retrieval endpoints accepted caller-supplied `tenant_rerank_id` and
resolved it through `get_model_config_by_id(...)`. That helper loaded
`TenantLLM` rows by global database id and returned decoded model
configuration without checking whether the model belonged to the
authenticated tenant or the dataset owner tenant.

This meant dataset access was validated, but rerank-model selection was
not. A caller who knew or could guess another tenant's
`tenant_rerank_id` could attempt retrieval with a foreign rerank model
config, creating a cross-tenant authorization gap for model usage.

This PR closes that gap by making `tenant_rerank_id` resolution
tenant-aware across the retrieval paths that accept it.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [ ] New Feature (non-breaking change which adds functionality)
- [ ] Documentation Update
- [ ] Refactoring
- [ ] Performance Improvement
- [ ] Other (please describe):

### Solution

- Extend `get_model_config_by_id(...)` to accept an optional
`allowed_tenant_ids` set and reject `TenantLLM` rows whose `tenant_id`
is outside that set.
- Pass the allowed tenant scope from retrieval endpoints that accept
`tenant_rerank_id`:
  - `api/apps/sdk/doc.py`
  - `api/apps/sdk/session.py`
  - `api/apps/services/dataset_api_service.py`
- Use the authenticated tenant plus dataset-owner tenant ids already
derived by each retrieval flow as the authorization boundary for rerank
model selection.
- Add focused unit coverage to assert unauthorized `tenant_rerank_id`
values are rejected and that the allowed tenant set is propagated
correctly.

### Testing

- `python -m py_compile` on:
  - `api/db/joint_services/tenant_model_service.py`
  - `api/apps/services/dataset_api_service.py`
  - `api/apps/sdk/doc.py`
  - `api/apps/sdk/session.py`
- Added unit tests in:
  - `test/testcases/test_http_api/test_file_management_within_dataset/test_doc_sdk_routes_unit.py`
  - `test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py`

### Notes for reviewers

- This change is intentionally narrow: it affects only the
`tenant_rerank_id` path, not the normal `rerank_id` name-based
resolution path.
- Local lint/syntax checks passed.
- Full pytest execution could not be completed in this environment
because the local test runtime is missing `strenum`, so the route-test
files fail during collection before exercising the updated cases.

---------

Co-authored-by: jony376 <jony376@gmail.com>
2026-05-13 19:53:08 +08:00

580 lines
22 KiB
Python

#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
from io import BytesIO
from quart import request, send_file
from api.db.db_models import APIToken, Document, Task
from api.db.joint_services.tenant_model_service import get_model_config_by_id, get_model_config_by_type_and_name, get_tenant_default_model_by_type
from api.db.services.doc_metadata_service import DocMetadataService
from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api.db.services.task_service import TaskService, cancel_all_task_of, queue_tasks
from api.db.services.tenant_llm_service import TenantLLMService
from api.utils.api_utils import check_duplicate_ids, construct_json_result, get_error_data_result, get_request_json, get_result, server_error_response, token_required
from common import settings
from common.constants import LLMType, RetCode, TaskStatus
from common.metadata_utils import convert_conditions, meta_filter
from rag.app.tag import label_question
from rag.nlp import search
from rag.prompts.generator import cross_languages, keyword_extraction
# NOTE(review): presumably the cap on files accepted per upload request; it is
# not referenced anywhere in the visible part of this module - confirm usage.
MAXIMUM_OF_UPLOADING_FILES = 256
from api.utils.reference_metadata_utils import (
enrich_chunks_with_document_metadata,
resolve_reference_metadata_preferences,
)
def _resolve_reference_metadata(req: dict, search_config: dict | None = None):
    """Forward to ``resolve_reference_metadata_preferences`` and return its result.

    Args:
        req: Request payload handed through unchanged.
        search_config: Optional search configuration handed through unchanged.

    Returns:
        Whatever ``resolve_reference_metadata_preferences`` returns; the
        retrieval route in this module unpacks it as
        ``(include_metadata, metadata_fields)``.
    """
    preferences = resolve_reference_metadata_preferences(req, search_config)
    return preferences
def _enrich_chunks_with_document_metadata(chunks: list[dict], metadata_fields=None) -> None:
    """Forward to ``enrich_chunks_with_document_metadata``; returns nothing.

    Args:
        chunks: Chunk dictionaries handed through unchanged.
        metadata_fields: Optional field selection handed through unchanged.
    """
    enrich_chunks_with_document_metadata(chunks, metadata_fields)
    return None
@manager.route("/datasets/<dataset_id>/documents/<document_id>", methods=["GET"])  # noqa: F821
@token_required
async def download(tenant_id, dataset_id, document_id):
    """
    Download a document from a dataset.
    ---
    tags:
      - Documents
    security:
      - ApiKeyAuth: []
    produces:
      - application/octet-stream
    parameters:
      - in: path
        name: dataset_id
        type: string
        required: true
        description: ID of the dataset.
      - in: path
        name: document_id
        type: string
        required: true
        description: ID of the document to download.
      - in: header
        name: Authorization
        type: string
        required: true
        description: Bearer token for authentication.
    responses:
      200:
        description: Document file stream.
        schema:
          type: file
      400:
        description: Error message.
        schema:
          type: object
    """
    if not document_id:
        return get_error_data_result(message="Specify document_id please.")

    # The dataset must belong to the authenticated tenant.
    owned_datasets = KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id)
    if not owned_datasets:
        return get_error_data_result(message=f"You do not own the dataset {dataset_id}.")

    # The document must live inside that dataset.
    matches = DocumentService.query(kb_id=dataset_id, id=document_id)
    if not matches:
        return get_error_data_result(message=f"The dataset not own the document {document_id}.")

    # Resolve the storage address of the document and fetch its bytes.
    storage_bucket, storage_location = File2DocumentService.get_storage_address(doc_id=document_id)
    payload = settings.STORAGE_IMPL.get(storage_bucket, storage_location)
    if not payload:
        return construct_json_result(message="This file is empty.", code=RetCode.DATA_ERROR)

    # Serve the bytes as an attachment named after the stored document.
    return await send_file(
        BytesIO(payload),
        as_attachment=True,
        attachment_filename=matches[0].name,
        mimetype="application/octet-stream",  # generic default MIME type
    )
@manager.route("/documents/<document_id>", methods=["GET"])  # noqa: F821
async def download_doc(document_id):
    """Download a document, authenticating with a beta API token.

    The token is taken from the second whitespace-separated part of the
    ``Authorization`` header and looked up via ``APIToken.query(beta=...)``.
    The document is only served when its dataset belongs to the tenant that
    owns the token.

    Args:
        document_id: ID of the document to download (from the URL path).

    Returns:
        The file as an attachment on success; otherwise an error result
        describing the authentication, ownership or storage problem.
    """
    # A missing header yields None; fall back to "" so the request is rejected
    # as malformed instead of crashing on None.split().
    auth_parts = (request.headers.get("Authorization") or "").split()
    if len(auth_parts) != 2:
        return get_error_data_result(message="Authorization is not valid!")
    token = auth_parts[1]

    logging.info("Beta API token lookup attempted for document download")
    objs = APIToken.query(beta=token)
    if not objs:
        logging.warning("Beta API token lookup failed for document download: invalid API key")
        return get_error_data_result(message='Authentication error: API key is invalid!"')
    if len(objs) > 1:
        # More than one row for the same beta token: refuse rather than guess a tenant.
        logging.error("Beta API token lookup is ambiguous for document download: matches=%s", len(objs))
        return get_error_data_result(message="Authentication error: API key configuration is ambiguous.")
    tenant_id = objs[0].tenant_id
    logging.info("Beta API token authorized for document download: tenant_id=%s", tenant_id)

    if not document_id:
        return get_error_data_result(message="Specify document_id please.")
    doc = DocumentService.query(id=document_id)
    if not doc:
        return get_error_data_result(message=f"The dataset not own the document {document_id}.")

    # The document's dataset must belong to the token's tenant.
    if not KnowledgebaseService.query(id=doc[0].kb_id, tenant_id=tenant_id):
        logging.warning(
            "cross-tenant access denied for document download: tenant_id=%s kb_id=%s document_id=%s",
            tenant_id,
            doc[0].kb_id,
            document_id,
        )
        return get_error_data_result(message="You do not have access to this document.")

    # Resolve the storage address of the document and fetch its bytes.
    doc_id, doc_location = File2DocumentService.get_storage_address(doc_id=document_id)
    file_stream = settings.STORAGE_IMPL.get(doc_id, doc_location)
    if not file_stream:
        return construct_json_result(message="This file is empty.", code=RetCode.DATA_ERROR)
    file = BytesIO(file_stream)

    # Serve the bytes as an attachment named after the stored document.
    return await send_file(
        file,
        as_attachment=True,
        attachment_filename=doc[0].name,
        mimetype="application/octet-stream",  # generic default MIME type
    )
# Message and machine-readable error code returned by `stop_parsing` when a
# requested document's run state is not RUNNING.
DOC_STOP_PARSING_INVALID_STATE_MESSAGE = "Can't stop parsing document that has not started or already completed"
DOC_STOP_PARSING_INVALID_STATE_ERROR_CODE = "DOC_STOP_PARSING_INVALID_STATE"
@manager.route("/datasets/<dataset_id>/chunks", methods=["POST"])  # noqa: F821
@token_required
async def parse(tenant_id, dataset_id):
    """
    Start parsing documents into chunks.
    ---
    tags:
      - Chunks
    security:
      - ApiKeyAuth: []
    parameters:
      - in: path
        name: dataset_id
        type: string
        required: true
        description: ID of the dataset.
      - in: body
        name: body
        description: Parsing parameters.
        required: true
        schema:
          type: object
          properties:
            document_ids:
              type: array
              items:
                type: string
              description: List of document IDs to parse.
      - in: header
        name: Authorization
        type: string
        required: true
        description: Bearer token for authentication.
    responses:
      200:
        description: Parsing started successfully.
        schema:
          type: object
    """
    if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
        return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
    req = await get_request_json()
    if not req.get("document_ids"):
        return get_error_data_result("`document_ids` is required")

    # Duplicates are dropped here and reported back to the caller at the end.
    doc_list, duplicate_messages = check_duplicate_ids(req.get("document_ids"), "document")

    not_found = []
    success_count = 0
    for doc_id in doc_list:
        # Documents missing from this dataset are collected and reported after
        # the loop; the remaining documents are still processed.
        if not DocumentService.query(id=doc_id, kb_id=dataset_id):
            not_found.append(doc_id)
            continue

        # Reset the run state in a single conditional update that only matches
        # documents not currently RUNNING; zero updated rows means it is busy.
        info = {"run": "1", "progress": 0, "progress_msg": "", "chunk_num": 0, "token_num": 0}
        updated = DocumentService.filter_update(
            [
                Document.id == doc_id,
                ((Document.run.is_null(True)) | (Document.run != TaskStatus.RUNNING.value)),
            ],
            info,
        )
        if updated == 0:
            return get_error_data_result("Can't parse document that is currently being processed")

        # Remove the document's entries from the doc store and its old tasks
        # before queueing new ones.
        settings.docStoreConn.delete({"doc_id": doc_id}, search.index_name(tenant_id), dataset_id)
        TaskService.filter_delete([Task.doc_id == doc_id])

        _, doc = DocumentService.get_by_id(doc_id)
        doc = doc.to_dict()
        doc["tenant_id"] = tenant_id
        bucket, name = File2DocumentService.get_storage_address(doc_id=doc["id"])
        queue_tasks(doc, bucket, name, 0)
        success_count += 1

    if not_found:
        return get_result(message=f"Documents not found: {not_found}", code=RetCode.DATA_ERROR)
    if duplicate_messages:
        if success_count > 0:
            return get_result(
                message=f"Partially parsed {success_count} documents with {len(duplicate_messages)} errors",
                data={"success_count": success_count, "errors": duplicate_messages},
            )
        return get_error_data_result(message=";".join(duplicate_messages))
    return get_result()
@manager.route("/datasets/<dataset_id>/chunks", methods=["DELETE"])  # noqa: F821
@token_required
async def stop_parsing(tenant_id, dataset_id):
    """
    Stop parsing documents into chunks.
    ---
    tags:
      - Chunks
    security:
      - ApiKeyAuth: []
    parameters:
      - in: path
        name: dataset_id
        type: string
        required: true
        description: ID of the dataset.
      - in: body
        name: body
        description: Stop parsing parameters.
        required: true
        schema:
          type: object
          properties:
            document_ids:
              type: array
              items:
                type: string
              description: List of document IDs to stop parsing.
      - in: header
        name: Authorization
        type: string
        required: true
        description: Bearer token for authentication.
    responses:
      200:
        description: Parsing stopped successfully.
        schema:
          type: object
    """
    if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
        return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")

    payload = await get_request_json()
    requested_ids = payload.get("document_ids")
    if not requested_ids:
        return get_error_data_result("`document_ids` is required")

    # Duplicates are dropped here and reported back to the caller at the end.
    unique_ids, duplicate_messages = check_duplicate_ids(requested_ids, "document")

    stopped = 0
    for doc_id in unique_ids:
        matches = DocumentService.query(id=doc_id, kb_id=dataset_id)
        if not matches:
            return get_error_data_result(message=f"You don't own the document {doc_id}.")
        target = matches[0]
        # Only a document whose run state is RUNNING can be stopped.
        if target.run != TaskStatus.RUNNING.value:
            return construct_json_result(
                code=RetCode.DATA_ERROR,
                message=DOC_STOP_PARSING_INVALID_STATE_MESSAGE,
                data={"error_code": DOC_STOP_PARSING_INVALID_STATE_ERROR_CODE},
            )
        # Send cancellation signal via Redis to stop background task
        cancel_all_task_of(doc_id)
        DocumentService.update_by_id(doc_id, {"run": "2", "progress": 0, "chunk_num": 0})
        settings.docStoreConn.delete({"doc_id": target.id}, search.index_name(tenant_id), dataset_id)
        stopped += 1

    if not duplicate_messages:
        return get_result()
    if stopped > 0:
        return get_result(
            message=f"Partially stopped {stopped} documents with {len(duplicate_messages)} errors",
            data={"success_count": stopped, "errors": duplicate_messages},
        )
    return get_error_data_result(message=";".join(duplicate_messages))
@manager.route("/retrieval", methods=["POST"])  # noqa: F821
@token_required
async def retrieval_test(tenant_id):
    """
    Retrieve chunks based on a query.
    ---
    tags:
      - Retrieval
    security:
      - ApiKeyAuth: []
    parameters:
      - in: body
        name: body
        description: Retrieval parameters.
        required: true
        schema:
          type: object
          properties:
            dataset_ids:
              type: array
              items:
                type: string
              required: true
              description: List of dataset IDs to search in.
            question:
              type: string
              required: true
              description: Query string.
            document_ids:
              type: array
              items:
                type: string
              description: List of document IDs to filter.
            similarity_threshold:
              type: number
              format: float
              description: Similarity threshold.
            vector_similarity_weight:
              type: number
              format: float
              description: Vector similarity weight.
            top_k:
              type: integer
              description: Maximum number of chunks to return.
            highlight:
              type: boolean
              description: Whether to highlight matched content.
            metadata_condition:
              type: object
              description: metadata filter condition.
      - in: header
        name: Authorization
        type: string
        required: true
        description: Bearer token for authentication.
    responses:
      200:
        description: Retrieval results.
        schema:
          type: object
          properties:
            chunks:
              type: array
              items:
                type: object
                properties:
                  id:
                    type: string
                    description: Chunk ID.
                  content:
                    type: string
                    description: Chunk content.
                  document_id:
                    type: string
                    description: ID of the document.
                  dataset_id:
                    type: string
                    description: ID of the dataset.
                  similarity:
                    type: number
                    format: float
                    description: Similarity score.
    """
    req = await get_request_json()
    if not req.get("dataset_ids"):
        return get_error_data_result("`dataset_ids` is required.")
    kb_ids = req["dataset_ids"]
    if not isinstance(kb_ids, list):
        return get_error_data_result("`dataset_ids` should be a list")

    # Every requested dataset must be accessible to the caller.
    for kb_id in kb_ids:
        if not KnowledgebaseService.accessible(kb_id=kb_id, user_id=tenant_id):
            return get_error_data_result(f"You don't own the dataset {kb_id}.")
    kbs = KnowledgebaseService.get_by_ids(kb_ids)

    # All datasets must share one embedding model (vendor suffix removed for comparison).
    embd_nms = list({TenantLLMService.split_model_name_and_factory(kb.embd_id)[0] for kb in kbs})
    if len(embd_nms) != 1:
        return get_result(
            message='Datasets use different embedding models."',
            code=RetCode.DATA_ERROR,
        )

    if "question" not in req:
        return get_error_data_result("`question` is required.")
    page = int(req.get("page", 1))
    size = int(req.get("page_size", 30))
    question = req["question"]
    # Trim whitespace and validate question
    if isinstance(question, str):
        question = question.strip()
    # Return empty result if question is empty or whitespace-only
    if not question:
        return get_result(data={"total": 0, "chunks": [], "doc_aggs": {}})

    doc_ids = req.get("document_ids", [])
    use_kg = req.get("use_kg", False)
    toc_enhance = req.get("toc_enhance", False)
    langs = req.get("cross_languages", [])
    if not isinstance(doc_ids, list):
        return get_error_data_result("`documents` should be a list")
    if doc_ids:
        # Explicit document ids must all belong to the requested datasets.
        doc_ids_list = KnowledgebaseService.list_documents_by_ids(kb_ids)
        for doc_id in doc_ids:
            if doc_id not in doc_ids_list:
                return get_error_data_result(f"The datasets don't own the document {doc_id}")
    if not doc_ids:
        # No explicit ids: derive the document set from the metadata filter, if any.
        metadata_condition = req.get("metadata_condition")
        if metadata_condition:
            metas = DocMetadataService.get_flatted_meta_by_kbs(kb_ids)
            doc_ids = meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and"))
            # If metadata_condition has conditions but no docs match, return empty result
            if not doc_ids and metadata_condition.get("conditions"):
                return get_result(data={"total": 0, "chunks": [], "doc_aggs": {}})
            if metadata_condition and not doc_ids:
                # NOTE(review): "-999" looks like a placeholder id meant to match
                # no document - confirm against the retriever.
                doc_ids = ["-999"]
        else:
            # If doc_ids is None all documents of the datasets are used
            doc_ids = None

    similarity_threshold = float(req.get("similarity_threshold", 0.2))
    vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
    top = int(req.get("top_k", 1024))
    if top <= 0:
        return get_error_data_result("`top_k` must be greater than 0")

    # `highlight` may be a real boolean or the strings "true"/"false" (any case).
    highlight_val = req.get("highlight", None)
    if highlight_val is None:
        highlight = False
    elif isinstance(highlight_val, bool):
        highlight = highlight_val
    elif isinstance(highlight_val, str):
        if highlight_val.lower() in ["true", "false"]:
            highlight = highlight_val.lower() == "true"
        else:
            return get_error_data_result("`highlight` should be a boolean")
    else:
        return get_error_data_result("`highlight` should be a boolean")

    include_metadata, metadata_fields = _resolve_reference_metadata(req)
    try:
        tenant_ids = list({kb.tenant_id for kb in kbs})
        # The first dataset supplies the tenant and embedding model used below.
        found, kb = KnowledgebaseService.get_by_id(kb_ids[0])
        if not found:
            return get_error_data_result(message="Dataset not found!")
        if kb.tenant_embd_id:
            embd_model_config = get_model_config_by_id(kb.tenant_embd_id)
        else:
            embd_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.EMBEDDING, kb.embd_id)
        embd_mdl = LLMBundle(kb.tenant_id, embd_model_config)

        rerank_mdl = None
        if req.get("tenant_rerank_id"):
            # A caller-supplied rerank model id is resolved only within the
            # caller's tenant plus the tenants that own the requested datasets.
            allowed_rerank_tenant_ids = {tenant_id, *[dataset.tenant_id for dataset in kbs]}
            rerank_model_config = get_model_config_by_id(
                req["tenant_rerank_id"],
                allowed_tenant_ids=allowed_rerank_tenant_ids,
                requester_tenant_id=tenant_id,
            )
            rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config)
        elif req.get("rerank_id"):
            rerank_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.RERANK, req["rerank_id"])
            rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config)

        if langs:
            question = await cross_languages(kb.tenant_id, None, question, langs)
        if req.get("keyword", False):
            # Append extracted keywords to the question using the tenant's default chat model.
            chat_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT)
            chat_mdl = LLMBundle(kb.tenant_id, chat_model_config)
            question += await keyword_extraction(chat_mdl, question)

        ranks = await settings.retriever.retrieval(
            question,
            embd_mdl,
            tenant_ids,
            kb_ids,
            page,
            size,
            similarity_threshold,
            vector_similarity_weight,
            top,
            doc_ids,
            rerank_mdl=rerank_mdl,
            highlight=highlight,
            rank_feature=label_question(question, kbs),
        )
        if toc_enhance:
            chat_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT)
            chat_mdl = LLMBundle(kb.tenant_id, chat_model_config)
            cks = await settings.retriever.retrieval_by_toc(question, ranks["chunks"], tenant_ids, chat_mdl, size)
            if cks:
                ranks["chunks"] = cks
        ranks["chunks"] = settings.retriever.retrieval_by_children(ranks["chunks"], tenant_ids)
        if use_kg:
            chat_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT)
            ck = await settings.kg_retriever.retrieval(question, [k.tenant_id for k in kbs], kb_ids, embd_mdl, LLMBundle(kb.tenant_id, chat_model_config))
            if ck["content_with_weight"]:
                # A non-empty knowledge-graph result goes first.
                ranks["chunks"].insert(0, ck)

        # Vectors are removed from every chunk before the response is built.
        for c in ranks["chunks"]:
            c.pop("vector", None)
        if include_metadata:
            logging.info(
                "sdk.retrieval reference_metadata enabled dataset_ids=%s fields=%s chunks=%s",
                kb_ids,
                sorted(metadata_fields) if metadata_fields else None,
                len(ranks["chunks"]),
            )
            enrich_chunks_with_document_metadata(ranks["chunks"], metadata_fields)

        # Rename internal chunk keys to the public SDK names; other keys pass through.
        key_mapping = {
            "chunk_id": "id",
            "content_with_weight": "content",
            "doc_id": "document_id",
            "important_kwd": "important_keywords",
            "question_kwd": "questions",
            "docnm_kwd": "document_keyword",
            "kb_id": "dataset_id",
        }
        ranks["chunks"] = [
            {key_mapping.get(key, key): value for key, value in chunk.items()}
            for chunk in ranks["chunks"]
        ]
        return get_result(data=ranks)
    except Exception as e:
        # Substring test: the previous `find(...) > 0` missed a message that
        # starts with "not_found" (index 0).
        if "not_found" in str(e):
            return get_result(
                message="No chunk found! Check the chunk status please!",
                code=RetCode.DATA_ERROR,
            )
        return server_error_response(e)