# # Copyright 2026 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from io import BytesIO from quart import request, send_file from api.db.db_models import APIToken, Document, Task from api.db.joint_services.tenant_model_service import get_model_config_by_id, get_model_config_by_type_and_name, get_tenant_default_model_by_type from api.db.services.doc_metadata_service import DocMetadataService from api.db.services.document_service import DocumentService from api.db.services.file2document_service import File2DocumentService from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.llm_service import LLMBundle from api.db.services.task_service import TaskService, cancel_all_task_of, queue_tasks from api.db.services.tenant_llm_service import TenantLLMService from api.utils.api_utils import check_duplicate_ids, construct_json_result, get_error_data_result, get_request_json, get_result, server_error_response, token_required from common import settings from common.constants import LLMType, RetCode, TaskStatus from common.metadata_utils import convert_conditions, meta_filter from rag.app.tag import label_question from rag.nlp import search from rag.prompts.generator import cross_languages, keyword_extraction MAXIMUM_OF_UPLOADING_FILES = 256 @manager.route("/datasets//documents/", methods=["GET"]) # noqa: F821 @token_required async def download(tenant_id, dataset_id, document_id): """ Download a document from a dataset. --- tags: - Documents security: - ApiKeyAuth: [] produces: - application/octet-stream parameters: - in: path name: dataset_id type: string required: true description: ID of the dataset. - in: path name: document_id type: string required: true description: ID of the document to download. - in: header name: Authorization type: string required: true description: Bearer token for authentication. responses: 200: description: Document file stream. schema: type: file 400: description: Error message. schema: type: object """ if not document_id: return get_error_data_result(message="Specify document_id please.") if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id): return get_error_data_result(message=f"You do not own the dataset {dataset_id}.") doc = DocumentService.query(kb_id=dataset_id, id=document_id) if not doc: return get_error_data_result(message=f"The dataset not own the document {document_id}.") # The process of downloading doc_id, doc_location = File2DocumentService.get_storage_address(doc_id=document_id) # minio address file_stream = settings.STORAGE_IMPL.get(doc_id, doc_location) if not file_stream: return construct_json_result(message="This file is empty.", code=RetCode.DATA_ERROR) file = BytesIO(file_stream) # Use send_file with a proper filename and MIME type return await send_file( file, as_attachment=True, attachment_filename=doc[0].name, mimetype="application/octet-stream", # Set a default MIME type ) @manager.route("/documents/", methods=["GET"]) # noqa: F821 async def download_doc(document_id): token = request.headers.get("Authorization").split() if len(token) != 2: return get_error_data_result(message="Authorization is not valid!") token = token[1] objs = APIToken.query(beta=token) if not objs: return get_error_data_result(message='Authentication error: API key is invalid!"') if not document_id: return get_error_data_result(message="Specify document_id please.") doc = DocumentService.query(id=document_id) if not doc: return get_error_data_result(message=f"The dataset not own the document {document_id}.") # The process of downloading doc_id, doc_location = File2DocumentService.get_storage_address(doc_id=document_id) # minio address file_stream = settings.STORAGE_IMPL.get(doc_id, doc_location) if not file_stream: return construct_json_result(message="This file is empty.", code=RetCode.DATA_ERROR) file = BytesIO(file_stream) # Use send_file with a proper filename and MIME type return await send_file( file, as_attachment=True, attachment_filename=doc[0].name, mimetype="application/octet-stream", # Set a default MIME type ) DOC_STOP_PARSING_INVALID_STATE_MESSAGE = "Can't stop parsing document that has not started or already completed" DOC_STOP_PARSING_INVALID_STATE_ERROR_CODE = "DOC_STOP_PARSING_INVALID_STATE" @manager.route("/datasets//chunks", methods=["POST"]) # noqa: F821 @token_required async def parse(tenant_id, dataset_id): """ Start parsing documents into chunks. --- tags: - Chunks security: - ApiKeyAuth: [] parameters: - in: path name: dataset_id type: string required: true description: ID of the dataset. - in: body name: body description: Parsing parameters. required: true schema: type: object properties: document_ids: type: array items: type: string description: List of document IDs to parse. - in: header name: Authorization type: string required: true description: Bearer token for authentication. responses: 200: description: Parsing started successfully. schema: type: object """ if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): return get_error_data_result(message=f"You don't own the dataset {dataset_id}.") req = await get_request_json() if not req.get("document_ids"): return get_error_data_result("`document_ids` is required") doc_list = req.get("document_ids") unique_doc_ids, duplicate_messages = check_duplicate_ids(doc_list, "document") doc_list = unique_doc_ids not_found = [] success_count = 0 for id in doc_list: doc = DocumentService.query(id=id, kb_id=dataset_id) if not doc: not_found.append(id) continue if not doc: return get_error_data_result(message=f"You don't own the document {id}.") info = {"run": "1", "progress": 0, "progress_msg": "", "chunk_num": 0, "token_num": 0} if ( DocumentService.filter_update( [ Document.id == id, ((Document.run.is_null(True)) | (Document.run != TaskStatus.RUNNING.value)), ], info, ) == 0 ): return get_error_data_result("Can't parse document that is currently being processed") settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), dataset_id) TaskService.filter_delete([Task.doc_id == id]) e, doc = DocumentService.get_by_id(id) doc = doc.to_dict() doc["tenant_id"] = tenant_id bucket, name = File2DocumentService.get_storage_address(doc_id=doc["id"]) queue_tasks(doc, bucket, name, 0) success_count += 1 if not_found: return get_result(message=f"Documents not found: {not_found}", code=RetCode.DATA_ERROR) if duplicate_messages: if success_count > 0: return get_result( message=f"Partially parsed {success_count} documents with {len(duplicate_messages)} errors", data={"success_count": success_count, "errors": duplicate_messages}, ) else: return get_error_data_result(message=";".join(duplicate_messages)) return get_result() @manager.route("/datasets//chunks", methods=["DELETE"]) # noqa: F821 @token_required async def stop_parsing(tenant_id, dataset_id): """ Stop parsing documents into chunks. --- tags: - Chunks security: - ApiKeyAuth: [] parameters: - in: path name: dataset_id type: string required: true description: ID of the dataset. - in: body name: body description: Stop parsing parameters. required: true schema: type: object properties: document_ids: type: array items: type: string description: List of document IDs to stop parsing. - in: header name: Authorization type: string required: true description: Bearer token for authentication. responses: 200: description: Parsing stopped successfully. schema: type: object """ if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): return get_error_data_result(message=f"You don't own the dataset {dataset_id}.") req = await get_request_json() if not req.get("document_ids"): return get_error_data_result("`document_ids` is required") doc_list = req.get("document_ids") unique_doc_ids, duplicate_messages = check_duplicate_ids(doc_list, "document") doc_list = unique_doc_ids success_count = 0 for id in doc_list: doc = DocumentService.query(id=id, kb_id=dataset_id) if not doc: return get_error_data_result(message=f"You don't own the document {id}.") if doc[0].run != TaskStatus.RUNNING.value: return construct_json_result( code=RetCode.DATA_ERROR, message=DOC_STOP_PARSING_INVALID_STATE_MESSAGE, data={"error_code": DOC_STOP_PARSING_INVALID_STATE_ERROR_CODE}, ) # Send cancellation signal via Redis to stop background task cancel_all_task_of(id) info = {"run": "2", "progress": 0, "chunk_num": 0} DocumentService.update_by_id(id, info) settings.docStoreConn.delete({"doc_id": doc[0].id}, search.index_name(tenant_id), dataset_id) success_count += 1 if duplicate_messages: if success_count > 0: return get_result( message=f"Partially stopped {success_count} documents with {len(duplicate_messages)} errors", data={"success_count": success_count, "errors": duplicate_messages}, ) else: return get_error_data_result(message=";".join(duplicate_messages)) return get_result() @manager.route("/retrieval", methods=["POST"]) # noqa: F821 @token_required async def retrieval_test(tenant_id): """ Retrieve chunks based on a query. --- tags: - Retrieval security: - ApiKeyAuth: [] parameters: - in: body name: body description: Retrieval parameters. required: true schema: type: object properties: dataset_ids: type: array items: type: string required: true description: List of dataset IDs to search in. question: type: string required: true description: Query string. document_ids: type: array items: type: string description: List of document IDs to filter. similarity_threshold: type: number format: float description: Similarity threshold. vector_similarity_weight: type: number format: float description: Vector similarity weight. top_k: type: integer description: Maximum number of chunks to return. highlight: type: boolean description: Whether to highlight matched content. metadata_condition: type: object description: metadata filter condition. - in: header name: Authorization type: string required: true description: Bearer token for authentication. responses: 200: description: Retrieval results. schema: type: object properties: chunks: type: array items: type: object properties: id: type: string description: Chunk ID. content: type: string description: Chunk content. document_id: type: string description: ID of the document. dataset_id: type: string description: ID of the dataset. similarity: type: number format: float description: Similarity score. """ req = await get_request_json() if not req.get("dataset_ids"): return get_error_data_result("`dataset_ids` is required.") kb_ids = req["dataset_ids"] if not isinstance(kb_ids, list): return get_error_data_result("`dataset_ids` should be a list") for id in kb_ids: if not KnowledgebaseService.accessible(kb_id=id, user_id=tenant_id): return get_error_data_result(f"You don't own the dataset {id}.") kbs = KnowledgebaseService.get_by_ids(kb_ids) embd_nms = list(set([TenantLLMService.split_model_name_and_factory(kb.embd_id)[0] for kb in kbs])) # remove vendor suffix for comparison if len(embd_nms) != 1: return get_result( message='Datasets use different embedding models."', code=RetCode.DATA_ERROR, ) if "question" not in req: return get_error_data_result("`question` is required.") page = int(req.get("page", 1)) size = int(req.get("page_size", 30)) question = req["question"] # Trim whitespace and validate question if isinstance(question, str): question = question.strip() # Return empty result if question is empty or whitespace-only if not question: return get_result(data={"total": 0, "chunks": [], "doc_aggs": {}}) doc_ids = req.get("document_ids", []) use_kg = req.get("use_kg", False) toc_enhance = req.get("toc_enhance", False) langs = req.get("cross_languages", []) if not isinstance(doc_ids, list): return get_error_data_result("`documents` should be a list") if doc_ids: doc_ids_list = KnowledgebaseService.list_documents_by_ids(kb_ids) for doc_id in doc_ids: if doc_id not in doc_ids_list: return get_error_data_result(f"The datasets don't own the document {doc_id}") if not doc_ids: metadata_condition = req.get("metadata_condition") if metadata_condition: metas = DocMetadataService.get_flatted_meta_by_kbs(kb_ids) doc_ids = meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and")) # If metadata_condition has conditions but no docs match, return empty result if not doc_ids and metadata_condition.get("conditions"): return get_result(data={"total": 0, "chunks": [], "doc_aggs": {}}) if metadata_condition and not doc_ids: doc_ids = ["-999"] else: # If doc_ids is None all documents of the datasets are used doc_ids = None similarity_threshold = float(req.get("similarity_threshold", 0.2)) vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3)) top = int(req.get("top_k", 1024)) highlight_val = req.get("highlight", None) if highlight_val is None: highlight = False elif isinstance(highlight_val, bool): highlight = highlight_val elif isinstance(highlight_val, str): if highlight_val.lower() in ["true", "false"]: highlight = highlight_val.lower() == "true" else: return get_error_data_result("`highlight` should be a boolean") else: return get_error_data_result("`highlight` should be a boolean") try: tenant_ids = list(set([kb.tenant_id for kb in kbs])) e, kb = KnowledgebaseService.get_by_id(kb_ids[0]) if not e: return get_error_data_result(message="Dataset not found!") if kb.tenant_embd_id: embd_model_config = get_model_config_by_id(kb.tenant_embd_id) else: embd_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.EMBEDDING, kb.embd_id) embd_mdl = LLMBundle(kb.tenant_id, embd_model_config) rerank_mdl = None if req.get("tenant_rerank_id"): rerank_model_config = get_model_config_by_id(req["tenant_rerank_id"]) rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config) elif req.get("rerank_id"): rerank_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.RERANK, req["rerank_id"]) rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config) if langs: question = await cross_languages(kb.tenant_id, None, question, langs) if req.get("keyword", False): chat_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT) chat_mdl = LLMBundle(kb.tenant_id, chat_model_config) question += await keyword_extraction(chat_mdl, question) ranks = await settings.retriever.retrieval( question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top, doc_ids, rerank_mdl=rerank_mdl, highlight=highlight, rank_feature=label_question(question, kbs), ) if toc_enhance: chat_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT) chat_mdl = LLMBundle(kb.tenant_id, chat_model_config) cks = await settings.retriever.retrieval_by_toc(question, ranks["chunks"], tenant_ids, chat_mdl, size) if cks: ranks["chunks"] = cks ranks["chunks"] = settings.retriever.retrieval_by_children(ranks["chunks"], tenant_ids) if use_kg: chat_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT) ck = await settings.kg_retriever.retrieval(question, [k.tenant_id for k in kbs], kb_ids, embd_mdl, LLMBundle(kb.tenant_id, chat_model_config)) if ck["content_with_weight"]: ranks["chunks"].insert(0, ck) for c in ranks["chunks"]: c.pop("vector", None) ##rename keys renamed_chunks = [] for chunk in ranks["chunks"]: key_mapping = { "chunk_id": "id", "content_with_weight": "content", "doc_id": "document_id", "important_kwd": "important_keywords", "question_kwd": "questions", "docnm_kwd": "document_keyword", "kb_id": "dataset_id", } rename_chunk = {} for key, value in chunk.items(): new_key = key_mapping.get(key, key) rename_chunk[new_key] = value renamed_chunks.append(rename_chunk) ranks["chunks"] = renamed_chunks return get_result(data=ranks) except Exception as e: if str(e).find("not_found") > 0: return get_result( message="No chunk found! Check the chunk status please!", code=RetCode.DATA_ERROR, ) return server_error_response(e)