From 08f83ff3313344d45f57d3d2c20bbf7316f7dd79 Mon Sep 17 00:00:00 2001 From: Heyang Wang Date: Tue, 10 Mar 2026 18:05:45 +0800 Subject: [PATCH] Feat: Support get aggregated parsing status to dataset via the API (#13481) ### What problem does this PR solve? Support getting aggregated parsing status to dataset via the API Issue: #12810 ### Type of change - [x] New Feature (non-breaking change which adds functionality) Co-authored-by: heyang.why --- api/apps/sdk/dataset.py | 136 +++--- api/db/services/document_service.py | 407 ++++++++---------- api/utils/validation_utils.py | 7 +- docs/references/http_api_reference.md | 61 ++- docs/references/python_api_reference.md | 20 +- example/http/dataset_example.sh | 6 + ...est_document_service_get_parsing_status.py | 326 ++++++++++++++ 7 files changed, 654 insertions(+), 309 deletions(-) create mode 100644 test/unit_test/api/db/services/test_document_service_get_parsing_status.py diff --git a/api/apps/sdk/dataset.py b/api/apps/sdk/dataset.py index caa75ec02..58f0442b6 100644 --- a/api/apps/sdk/dataset.py +++ b/api/apps/sdk/dataset.py @@ -138,12 +138,7 @@ async def create(tenant_id): parser_cfg["metadata"] = fields parser_cfg["enable_metadata"] = auto_meta.get("enabled", True) req["parser_config"] = parser_cfg - e, req = KnowledgebaseService.create_with_name( - name = req.pop("name", None), - tenant_id = tenant_id, - parser_id = req.pop("parser_id", None), - **req - ) + e, req = KnowledgebaseService.create_with_name(name=req.pop("name", None), tenant_id=tenant_id, parser_id=req.pop("parser_id", None), **req) if not e: return req @@ -159,19 +154,19 @@ async def create(tenant_id): if not ok: return err - try: - if not KnowledgebaseService.save(**req): - return get_error_data_result() - ok, k = KnowledgebaseService.get_by_id(req["id"]) - if not ok: - return get_error_data_result(message="Dataset created failed") - response_data = remap_dictionary_keys(k.to_dict()) - return get_result(data=response_data) + if not KnowledgebaseService.save(**req): + return get_error_data_result() + ok, k = KnowledgebaseService.get_by_id(req["id"]) + if not ok: + return get_error_data_result(message="Dataset created failed") + response_data = remap_dictionary_keys(k.to_dict()) + return get_result(data=response_data) except Exception as e: logging.exception(e) return get_error_data_result(message="Database operation failed") + @manager.route("/datasets", methods=["DELETE"]) # noqa: F821 @token_required async def delete(tenant_id): @@ -227,8 +222,7 @@ async def delete(tenant_id): continue kb_id_instance_pairs.append((kb_id, kb)) if len(error_kb_ids) > 0: - return get_error_permission_result( - message=f"""User '{tenant_id}' lacks permission for datasets: '{", ".join(error_kb_ids)}'""") + return get_error_permission_result(message=f"""User '{tenant_id}' lacks permission for datasets: '{", ".join(error_kb_ids)}'""") errors = [] success_count = 0 @@ -245,12 +239,12 @@ async def delete(tenant_id): ] ) File2DocumentService.delete_by_document_id(doc.id) - FileService.filter_delete( - [File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kb.name]) + FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kb.name]) # Drop index for this dataset try: from rag.nlp import search + idxnm = search.index_name(kb.tenant_id) settings.docStoreConn.delete_idx(idxnm, kb_id) except Exception as e: @@ -352,8 +346,7 @@ async def update(tenant_id, dataset_id): try: kb = KnowledgebaseService.get_or_none(id=dataset_id, tenant_id=tenant_id) if kb is None: - return get_error_permission_result( - message=f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'") + return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'") # Map auto_metadata_config into parser_config if present auto_meta = req.pop("auto_metadata_config", None) @@ -384,8 +377,7 @@ async def update(tenant_id, dataset_id): del req["parser_config"] if "name" in req and req["name"].lower() != kb.name.lower(): - exists = KnowledgebaseService.get_or_none(name=req["name"], tenant_id=tenant_id, - status=StatusEnum.VALID.value) + exists = KnowledgebaseService.get_or_none(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value) if exists: return get_error_data_result(message=f"Dataset name '{req['name']}' already exists") @@ -393,8 +385,7 @@ async def update(tenant_id, dataset_id): if not req["embd_id"]: req["embd_id"] = kb.embd_id if kb.chunk_num != 0 and req["embd_id"] != kb.embd_id: - return get_error_data_result( - message=f"When chunk_num ({kb.chunk_num}) > 0, embedding_model must remain {kb.embd_id}") + return get_error_data_result(message=f"When chunk_num ({kb.chunk_num}) > 0, embedding_model must remain {kb.embd_id}") ok, err = verify_embedding_availability(req["embd_id"], tenant_id) if not ok: return err @@ -404,12 +395,10 @@ async def update(tenant_id, dataset_id): return get_error_argument_result(message="'pagerank' can only be set when doc_engine is elasticsearch") if req["pagerank"] > 0: - settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]}, - search.index_name(kb.tenant_id), kb.id) + settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]}, search.index_name(kb.tenant_id), kb.id) else: # Elasticsearch requires PAGERANK_FLD be non-zero! - settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD}, - search.index_name(kb.tenant_id), kb.id) + settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD}, search.index_name(kb.tenant_id), kb.id) if not KnowledgebaseService.update_by_id(kb.id, req): return get_error_data_result(message="Update dataset error.(Database error)") @@ -470,6 +459,15 @@ def list_datasets(tenant_id): required: false default: true description: Order in descending. + - in: query + name: include_parsing_status + type: boolean + required: false + default: false + description: | + Whether to include document parsing status counts in the response. + When true, each dataset object will include: unstart_count, running_count, + cancel_count, done_count, and fail_count. - in: header name: Authorization type: string @@ -487,17 +485,18 @@ def list_datasets(tenant_id): if err is not None: return get_error_argument_result(err) + include_parsing_status = args.get("include_parsing_status", False) + try: kb_id = request.args.get("id") name = args.get("name") + # check whether user has permission for the dataset with specified id if kb_id: - kbs = KnowledgebaseService.get_kb_by_id(kb_id, tenant_id) - - if not kbs: + if not KnowledgebaseService.get_kb_by_id(kb_id, tenant_id): return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{kb_id}'") + # check whether user has permission for the dataset with specified name if name: - kbs = KnowledgebaseService.get_kb_by_name(name, tenant_id) - if not kbs: + if not KnowledgebaseService.get_kb_by_name(name, tenant_id): return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{name}'") tenants = TenantService.get_joined_tenants_by_user_id(tenant_id) @@ -512,9 +511,17 @@ def list_datasets(tenant_id): name, ) + parsing_status_map = {} + if include_parsing_status and kbs: + kb_ids = [kb["id"] for kb in kbs] + parsing_status_map = DocumentService.get_parsing_status_by_kb_ids(kb_ids) + response_data_list = [] for kb in kbs: - response_data_list.append(remap_dictionary_keys(kb)) + data = remap_dictionary_keys(kb) + if include_parsing_status: + data.update(parsing_status_map.get(kb["id"], {})) + response_data_list.append(data) return get_result(data=response_data_list, total=total) except OperationalError as e: logging.exception(e) @@ -530,9 +537,7 @@ def get_auto_metadata(tenant_id, dataset_id): try: kb = KnowledgebaseService.get_or_none(id=dataset_id, tenant_id=tenant_id) if kb is None: - return get_error_permission_result( - message=f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'" - ) + return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'") parser_cfg = kb.parser_config or {} metadata = parser_cfg.get("metadata") or [] @@ -570,9 +575,7 @@ async def update_auto_metadata(tenant_id, dataset_id): try: kb = KnowledgebaseService.get_or_none(id=dataset_id, tenant_id=tenant_id) if kb is None: - return get_error_permission_result( - message=f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'" - ) + return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'") parser_cfg = kb.parser_config or {} fields = [] @@ -598,20 +601,13 @@ async def update_auto_metadata(tenant_id, dataset_id): return get_error_data_result(message="Database operation failed") -@manager.route('/datasets//knowledge_graph', methods=['GET']) # noqa: F821 +@manager.route("/datasets//knowledge_graph", methods=["GET"]) # noqa: F821 @token_required async def knowledge_graph(tenant_id, dataset_id): if not KnowledgebaseService.accessible(dataset_id, tenant_id): - return get_result( - data=False, - message='No authorization.', - code=RetCode.AUTHENTICATION_ERROR - ) + return get_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) _, kb = KnowledgebaseService.get_by_id(dataset_id) - req = { - "kb_id": [dataset_id], - "knowledge_graph_kwd": ["graph"] - } + req = {"kb_id": [dataset_id], "knowledge_graph_kwd": ["graph"]} obj = {"graph": {}, "mind_map": {}} if not settings.docStoreConn.index_exist(search.index_name(kb.tenant_id), dataset_id): @@ -633,39 +629,29 @@ async def knowledge_graph(tenant_id, dataset_id): obj["graph"]["nodes"] = sorted(obj["graph"]["nodes"], key=lambda x: x.get("pagerank", 0), reverse=True)[:256] if "edges" in obj["graph"]: node_id_set = {o["id"] for o in obj["graph"]["nodes"]} - filtered_edges = [o for o in obj["graph"]["edges"] if - o["source"] != o["target"] and o["source"] in node_id_set and o["target"] in node_id_set] + filtered_edges = [o for o in obj["graph"]["edges"] if o["source"] != o["target"] and o["source"] in node_id_set and o["target"] in node_id_set] obj["graph"]["edges"] = sorted(filtered_edges, key=lambda x: x.get("weight", 0), reverse=True)[:128] return get_result(data=obj) -@manager.route('/datasets//knowledge_graph', methods=['DELETE']) # noqa: F821 +@manager.route("/datasets//knowledge_graph", methods=["DELETE"]) # noqa: F821 @token_required def delete_knowledge_graph(tenant_id, dataset_id): if not KnowledgebaseService.accessible(dataset_id, tenant_id): - return get_result( - data=False, - message='No authorization.', - code=RetCode.AUTHENTICATION_ERROR - ) + return get_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) _, kb = KnowledgebaseService.get_by_id(dataset_id) - settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, - search.index_name(kb.tenant_id), dataset_id) + settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), dataset_id) return get_result(data=True) @manager.route("/datasets//run_graphrag", methods=["POST"]) # noqa: F821 @token_required -def run_graphrag(tenant_id,dataset_id): +def run_graphrag(tenant_id, dataset_id): if not dataset_id: return get_error_data_result(message='Lack of "Dataset ID"') if not KnowledgebaseService.accessible(dataset_id, tenant_id): - return get_result( - data=False, - message='No authorization.', - code=RetCode.AUTHENTICATION_ERROR - ) + return get_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) ok, kb = KnowledgebaseService.get_by_id(dataset_id) if not ok: @@ -707,15 +693,11 @@ def run_graphrag(tenant_id,dataset_id): @manager.route("/datasets//trace_graphrag", methods=["GET"]) # noqa: F821 @token_required -def trace_graphrag(tenant_id,dataset_id): +def trace_graphrag(tenant_id, dataset_id): if not dataset_id: return get_error_data_result(message='Lack of "Dataset ID"') if not KnowledgebaseService.accessible(dataset_id, tenant_id): - return get_result( - data=False, - message='No authorization.', - code=RetCode.AUTHENTICATION_ERROR - ) + return get_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) ok, kb = KnowledgebaseService.get_by_id(dataset_id) if not ok: @@ -734,15 +716,11 @@ def trace_graphrag(tenant_id,dataset_id): @manager.route("/datasets//run_raptor", methods=["POST"]) # noqa: F821 @token_required -def run_raptor(tenant_id,dataset_id): +def run_raptor(tenant_id, dataset_id): if not dataset_id: return get_error_data_result(message='Lack of "Dataset ID"') if not KnowledgebaseService.accessible(dataset_id, tenant_id): - return get_result( - data=False, - message='No authorization.', - code=RetCode.AUTHENTICATION_ERROR - ) + return get_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) ok, kb = KnowledgebaseService.get_by_id(dataset_id) if not ok: @@ -784,7 +762,7 @@ def run_raptor(tenant_id,dataset_id): @manager.route("/datasets//trace_raptor", methods=["GET"]) # noqa: F821 @token_required -def trace_raptor(tenant_id,dataset_id): +def trace_raptor(tenant_id, dataset_id): if not dataset_id: return get_error_data_result(message='Lack of "Dataset ID"') diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py index f5b2d9d51..fa4fc27ec 100644 --- a/api/db/services/document_service.py +++ b/api/db/services/document_service.py @@ -28,8 +28,7 @@ from peewee import fn, Case, JOIN from api.constants import IMG_BASE64_PREFIX, FILE_NAME_LEN_LIMIT from api.db import PIPELINE_SPECIAL_PROGRESS_FREEZE_TASK_TYPES, FileType, UserTenantRole, CanvasCategory -from api.db.db_models import DB, Document, Knowledgebase, Task, Tenant, UserTenant, File2Document, File, UserCanvas, \ - User +from api.db.db_models import DB, Document, Knowledgebase, Task, Tenant, UserTenant, File2Document, File, UserCanvas, User from api.db.db_utils import bulk_insert_into_db from api.db.services.common_service import CommonService from api.db.services.knowledgebase_service import KnowledgebaseService @@ -78,24 +77,21 @@ class DocumentService(CommonService): @classmethod @DB.connection_context() - def get_list(cls, kb_id, page_number, items_per_page, - orderby, desc, keywords, id, name, suffix=None, run = None, doc_ids=None): + def get_list(cls, kb_id, page_number, items_per_page, orderby, desc, keywords, id, name, suffix=None, run=None, doc_ids=None): fields = cls.get_cls_model_fields() - docs = cls.model.select(*[*fields, UserCanvas.title]).join(File2Document, on = (File2Document.document_id == cls.model.id))\ - .join(File, on = (File.id == File2Document.file_id))\ - .join(UserCanvas, on = ((cls.model.pipeline_id == UserCanvas.id) & (UserCanvas.canvas_category == CanvasCategory.DataFlow.value)), join_type=JOIN.LEFT_OUTER)\ + docs = ( + cls.model.select(*[*fields, UserCanvas.title]) + .join(File2Document, on=(File2Document.document_id == cls.model.id)) + .join(File, on=(File.id == File2Document.file_id)) + .join(UserCanvas, on=((cls.model.pipeline_id == UserCanvas.id) & (UserCanvas.canvas_category == CanvasCategory.DataFlow.value)), join_type=JOIN.LEFT_OUTER) .where(cls.model.kb_id == kb_id) + ) if id: - docs = docs.where( - cls.model.id == id) + docs = docs.where(cls.model.id == id) if name: - docs = docs.where( - cls.model.name == name - ) + docs = docs.where(cls.model.name == name) if keywords: - docs = docs.where( - fn.LOWER(cls.model.name).contains(keywords.lower()) - ) + docs = docs.where(fn.LOWER(cls.model.name).contains(keywords.lower())) if doc_ids: docs = docs.where(cls.model.id.in_(doc_ids)) if suffix: @@ -120,6 +116,7 @@ class DocumentService(CommonService): @DB.connection_context() def check_doc_health(cls, tenant_id: str, filename): import os + MAX_FILE_NUM_PER_USER = int(os.environ.get("MAX_FILE_NUM_PER_USER", 0)) if 0 < MAX_FILE_NUM_PER_USER <= DocumentService.get_doc_count(tenant_id): raise RuntimeError("Exceed the maximum file number of a free user!") @@ -211,13 +208,14 @@ class DocumentService(CommonService): """ fields = cls.get_cls_model_fields() if keywords: - query = cls.model.select(*fields).join(File2Document, on=(File2Document.document_id == cls.model.id)).join(File, on=(File.id == File2Document.file_id)).where( - (cls.model.kb_id == kb_id), - (fn.LOWER(cls.model.name).contains(keywords.lower())) + query = ( + cls.model.select(*fields) + .join(File2Document, on=(File2Document.document_id == cls.model.id)) + .join(File, on=(File.id == File2Document.file_id)) + .where((cls.model.kb_id == kb_id), (fn.LOWER(cls.model.name).contains(keywords.lower()))) ) else: - query = cls.model.select(*fields).join(File2Document, on=(File2Document.document_id == cls.model.id)).join(File, on=(File.id == File2Document.file_id)).where(cls.model.kb_id == kb_id) - + query = cls.model.select(*fields).join(File2Document, on=(File2Document.document_id == cls.model.id)).join(File, on=(File.id == File2Document.file_id)).where(cls.model.kb_id == kb_id) if run_status: query = query.where(cls.model.run.in_(run_status)) @@ -272,14 +270,60 @@ class DocumentService(CommonService): "metadata": metadata_counter, }, total + @classmethod + @DB.connection_context() + def get_parsing_status_by_kb_ids(cls, kb_ids: list[str]) -> dict[str, dict[str, int]]: + """Return aggregated document parsing status counts grouped by dataset (kb_id). + + For each kb_id, counts documents in each run-status bucket: + - unstart_count (run == "0") + - running_count (run == "1") + - cancel_count (run == "2") + - done_count (run == "3") + - fail_count (run == "4") + + Returns a dict keyed by kb_id, e.g. + {"kb-abc": {"unstart_count": 10, "running_count": 2, ...}, ...} + """ + if not kb_ids: + return {} + + status_field_map = { + TaskStatus.UNSTART.value: "unstart_count", + TaskStatus.RUNNING.value: "running_count", + TaskStatus.CANCEL.value: "cancel_count", + TaskStatus.DONE.value: "done_count", + TaskStatus.FAIL.value: "fail_count", + } + + empty_status = {v: 0 for v in status_field_map.values()} + result: dict[str, dict[str, int]] = {kb_id: dict(empty_status) for kb_id in kb_ids} + + rows = ( + cls.model.select( + cls.model.kb_id, + cls.model.run, + fn.COUNT(cls.model.id).alias("cnt"), + ) + .where(cls.model.kb_id.in_(kb_ids)) + .group_by(cls.model.kb_id, cls.model.run) + .dicts() + ) + + for row in rows: + kb_id = row["kb_id"] + run_val = str(row["run"]) + field_name = status_field_map.get(run_val) + if field_name and kb_id in result: + result[kb_id][field_name] = int(row["cnt"]) + + return result + @classmethod @DB.connection_context() def count_by_kb_id(cls, kb_id, keywords, run_status, types): if keywords: - docs = cls.model.select().where( - (cls.model.kb_id == kb_id), - (fn.LOWER(cls.model.name).contains(keywords.lower())) - ) + docs = cls.model.select().where((cls.model.kb_id == kb_id), (fn.LOWER(cls.model.name).contains(keywords.lower()))) else: docs = cls.model.select().where(cls.model.kb_id == kb_id) @@ -295,9 +339,7 @@ class DocumentService(CommonService): @classmethod @DB.connection_context() def get_total_size_by_kb_id(cls, kb_id, keywords="", run_status=[], types=[]): - query = cls.model.select(fn.COALESCE(fn.SUM(cls.model.size), 0)).where( - cls.model.kb_id == kb_id - ) + query = cls.model.select(fn.COALESCE(fn.SUM(cls.model.size), 0)).where(cls.model.kb_id == kb_id) if keywords: query = query.where(fn.LOWER(cls.model.name).contains(keywords.lower())) @@ -329,12 +371,8 @@ class DocumentService(CommonService): @classmethod @DB.connection_context() def get_all_docs_by_creator_id(cls, creator_id): - fields = [ - cls.model.id, cls.model.kb_id, cls.model.token_num, cls.model.chunk_num, Knowledgebase.tenant_id - ] - docs = cls.model.select(*fields).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where( - cls.model.created_by == creator_id - ) + fields = [cls.model.id, cls.model.kb_id, cls.model.token_num, cls.model.chunk_num, Knowledgebase.tenant_id] + docs = cls.model.select(*fields).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.created_by == creator_id) docs.order_by(cls.model.create_time.asc()) # maybe cause slow query by deep paginate, optimize later offset, limit = 0, 100 @@ -361,6 +399,7 @@ class DocumentService(CommonService): @DB.connection_context() def remove_document(cls, doc, tenant_id): from api.db.services.task_service import TaskService, cancel_all_task_of + if not cls.delete_document_and_update_kb_counts(doc.id): return True @@ -406,17 +445,22 @@ class DocumentService(CommonService): # Cleanup knowledge graph references (non-critical, log and continue) try: graph_source = settings.docStoreConn.get_fields( - settings.docStoreConn.search(["source_id"], [], {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [doc.kb_id]), ["source_id"] + settings.docStoreConn.search(["source_id"], [], {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [doc.kb_id]), + ["source_id"], ) if len(graph_source) > 0 and doc.id in list(graph_source.values())[0]["source_id"]: - settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "source_id": doc.id}, - {"remove": {"source_id": doc.id}}, - search.index_name(tenant_id), doc.kb_id) - settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, - {"removed_kwd": "Y"}, - search.index_name(tenant_id), doc.kb_id) - settings.docStoreConn.delete({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "must_not": {"exists": "source_id"}}, - search.index_name(tenant_id), doc.kb_id) + settings.docStoreConn.update( + {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "source_id": doc.id}, + {"remove": {"source_id": doc.id}}, + search.index_name(tenant_id), + doc.kb_id, + ) + settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, {"removed_kwd": "Y"}, search.index_name(tenant_id), doc.kb_id) + settings.docStoreConn.delete( + {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "must_not": {"exists": "source_id"}}, + search.index_name(tenant_id), + doc.kb_id, + ) except Exception as e: logging.warning(f"Failed to cleanup knowledge graph for document {doc.id}: {e}") @@ -428,9 +472,7 @@ class DocumentService(CommonService): page = 0 page_size = 1000 while True: - chunks = settings.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(), - page * page_size, page_size, search.index_name(tenant_id), - [doc.kb_id]) + chunks = settings.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(), page * page_size, page_size, search.index_name(tenant_id), [doc.kb_id]) chunk_ids = settings.docStoreConn.get_doc_ids(chunks) if not chunk_ids: break @@ -455,71 +497,61 @@ class DocumentService(CommonService): Tenant.embd_id, Tenant.img2txt_id, Tenant.asr_id, - cls.model.update_time] - docs = cls.model.select(*fields) \ - .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \ - .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id)) \ + cls.model.update_time, + ] + docs = ( + cls.model.select(*fields) + .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) + .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id)) .where( - cls.model.status == StatusEnum.VALID.value, - ~(cls.model.type == FileType.VIRTUAL.value), - cls.model.progress == 0, - cls.model.update_time >= current_timestamp() - 1000 * 600, - cls.model.run == TaskStatus.RUNNING.value) \ + cls.model.status == StatusEnum.VALID.value, + ~(cls.model.type == FileType.VIRTUAL.value), + cls.model.progress == 0, + cls.model.update_time >= current_timestamp() - 1000 * 600, + cls.model.run == TaskStatus.RUNNING.value, + ) .order_by(cls.model.update_time.asc()) + ) return list(docs.dicts()) @classmethod @DB.connection_context() def get_unfinished_docs(cls): - fields = [cls.model.id, cls.model.process_begin_at, cls.model.parser_config, cls.model.progress_msg, - cls.model.run, cls.model.parser_id] - unfinished_task_query = Task.select(Task.doc_id).where( - (Task.progress >= 0) & (Task.progress < 1) - ) + fields = [cls.model.id, cls.model.process_begin_at, cls.model.parser_config, cls.model.progress_msg, cls.model.run, cls.model.parser_id] + unfinished_task_query = Task.select(Task.doc_id).where((Task.progress >= 0) & (Task.progress < 1)) - docs = cls.model.select(*fields) \ - .where( + docs = cls.model.select(*fields).where( cls.model.status == StatusEnum.VALID.value, ~(cls.model.type == FileType.VIRTUAL.value), ((cls.model.run.is_null(True)) | (cls.model.run != TaskStatus.CANCEL.value)), - (((cls.model.progress < 1) & (cls.model.progress > 0)) | - (cls.model.id.in_(unfinished_task_query)))) # including unfinished tasks like GraphRAG, RAPTOR and Mindmap + (((cls.model.progress < 1) & (cls.model.progress > 0)) | (cls.model.id.in_(unfinished_task_query))), + ) # including unfinished tasks like GraphRAG, RAPTOR and Mindmap return list(docs.dicts()) @classmethod @DB.connection_context() def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duration): - num = cls.model.update(token_num=cls.model.token_num + token_num, - chunk_num=cls.model.chunk_num + chunk_num, - process_duration=cls.model.process_duration + duration).where( - cls.model.id == doc_id).execute() + num = ( + cls.model.update(token_num=cls.model.token_num + token_num, chunk_num=cls.model.chunk_num + chunk_num, process_duration=cls.model.process_duration + duration) + .where(cls.model.id == doc_id) + .execute() + ) if num == 0: logging.warning("Document not found which is supposed to be there") - num = Knowledgebase.update( - token_num=Knowledgebase.token_num + - token_num, - chunk_num=Knowledgebase.chunk_num + - chunk_num).where( - Knowledgebase.id == kb_id).execute() + num = Knowledgebase.update(token_num=Knowledgebase.token_num + token_num, chunk_num=Knowledgebase.chunk_num + chunk_num).where(Knowledgebase.id == kb_id).execute() return num @classmethod @DB.connection_context() def decrement_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duration): - num = cls.model.update(token_num=cls.model.token_num - token_num, - chunk_num=cls.model.chunk_num - chunk_num, - process_duration=cls.model.process_duration + duration).where( - cls.model.id == doc_id).execute() + num = ( + cls.model.update(token_num=cls.model.token_num - token_num, chunk_num=cls.model.chunk_num - chunk_num, process_duration=cls.model.process_duration + duration) + .where(cls.model.id == doc_id) + .execute() + ) if num == 0: - raise LookupError( - "Document not found which is supposed to be there") - num = Knowledgebase.update( - token_num=Knowledgebase.token_num - - token_num, - chunk_num=Knowledgebase.chunk_num - - chunk_num - ).where( - Knowledgebase.id == kb_id).execute() + raise LookupError("Document not found which is supposed to be there") + num = Knowledgebase.update(token_num=Knowledgebase.token_num - token_num, chunk_num=Knowledgebase.chunk_num - chunk_num).where(Knowledgebase.id == kb_id).execute() return num @classmethod @@ -551,17 +583,13 @@ class DocumentService(CommonService): doc = cls.model.get_by_id(doc_id) assert doc, "Can't fine document in database." - num = Knowledgebase.update( - token_num=Knowledgebase.token_num - - doc.token_num, - chunk_num=Knowledgebase.chunk_num - - doc.chunk_num, - doc_num=Knowledgebase.doc_num - 1 - ).where( - Knowledgebase.id == doc.kb_id).execute() + num = ( + Knowledgebase.update(token_num=Knowledgebase.token_num - doc.token_num, chunk_num=Knowledgebase.chunk_num - doc.chunk_num, doc_num=Knowledgebase.doc_num - 1) + .where(Knowledgebase.id == doc.kb_id) + .execute() + ) return num - @classmethod @DB.connection_context() def clear_chunk_num_when_rerun(cls, doc_id): @@ -578,15 +606,10 @@ class DocumentService(CommonService): ) return num - @classmethod @DB.connection_context() def get_tenant_id(cls, doc_id): - docs = cls.model.select( - Knowledgebase.tenant_id).join( - Knowledgebase, on=( - Knowledgebase.id == cls.model.kb_id)).where( - cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value) + docs = cls.model.select(Knowledgebase.tenant_id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value) docs = docs.dicts() if not docs: return None @@ -604,11 +627,7 @@ class DocumentService(CommonService): @classmethod @DB.connection_context() def get_tenant_id_by_name(cls, name): - docs = cls.model.select( - Knowledgebase.tenant_id).join( - Knowledgebase, on=( - Knowledgebase.id == cls.model.kb_id)).where( - cls.model.name == name, Knowledgebase.status == StatusEnum.VALID.value) + docs = cls.model.select(Knowledgebase.tenant_id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.name == name, Knowledgebase.status == StatusEnum.VALID.value) docs = docs.dicts() if not docs: return None @@ -617,12 +636,13 @@ class DocumentService(CommonService): @classmethod @DB.connection_context() def accessible(cls, doc_id, user_id): - docs = cls.model.select( - cls.model.id).join( - Knowledgebase, on=( - Knowledgebase.id == cls.model.kb_id) - ).join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id) - ).where(cls.model.id == doc_id, UserTenant.user_id == user_id).paginate(0, 1) + docs = ( + cls.model.select(cls.model.id) + .join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)) + .join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)) + .where(cls.model.id == doc_id, UserTenant.user_id == user_id) + .paginate(0, 1) + ) docs = docs.dicts() if not docs: return False @@ -631,18 +651,13 @@ class DocumentService(CommonService): @classmethod @DB.connection_context() def accessible4deletion(cls, doc_id, user_id): - docs = cls.model.select(cls.model.id - ).join( - Knowledgebase, on=( - Knowledgebase.id == cls.model.kb_id) - ).join( - UserTenant, on=( - (UserTenant.tenant_id == Knowledgebase.created_by) & (UserTenant.user_id == user_id)) - ).where( - cls.model.id == doc_id, - UserTenant.status == StatusEnum.VALID.value, - ((UserTenant.role == UserTenantRole.NORMAL) | (UserTenant.role == UserTenantRole.OWNER)) - ).paginate(0, 1) + docs = ( + cls.model.select(cls.model.id) + .join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)) + .join(UserTenant, on=((UserTenant.tenant_id == Knowledgebase.created_by) & (UserTenant.user_id == user_id))) + .where(cls.model.id == doc_id, UserTenant.status == StatusEnum.VALID.value, ((UserTenant.role == UserTenantRole.NORMAL) | (UserTenant.role == UserTenantRole.OWNER))) + .paginate(0, 1) + ) docs = docs.dicts() if not docs: return False @@ -651,11 +666,7 @@ class DocumentService(CommonService): @classmethod @DB.connection_context() def get_embd_id(cls, doc_id): - docs = cls.model.select( - Knowledgebase.embd_id).join( - Knowledgebase, on=( - Knowledgebase.id == cls.model.kb_id)).where( - cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value) + docs = cls.model.select(Knowledgebase.embd_id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value) docs = docs.dicts() if not docs: return None @@ -664,11 +675,9 @@ class DocumentService(CommonService): @classmethod @DB.connection_context() def get_tenant_embd_id(cls, doc_id): - docs = cls.model.select( - Knowledgebase.tenant_embd_id).join( - Knowledgebase, on=( - Knowledgebase.id == cls.model.kb_id)).where( - cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value) + docs = ( + cls.model.select(Knowledgebase.tenant_embd_id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value) + ) docs = docs.dicts() if not docs: return None @@ -705,8 +714,7 @@ class DocumentService(CommonService): @DB.connection_context() def get_doc_id_by_doc_name(cls, doc_name): fields = [cls.model.id] - doc_id = cls.model.select(*fields) \ - .where(cls.model.name == doc_name) + doc_id = cls.model.select(*fields).where(cls.model.name == doc_name) doc_id = doc_id.dicts() if not doc_id: return None @@ -725,8 +733,7 @@ class DocumentService(CommonService): @DB.connection_context() def get_thumbnails(cls, docids): fields = [cls.model.id, cls.model.kb_id, cls.model.thumbnail] - return list(cls.model.select( - *fields).where(cls.model.id.in_(docids)).dicts()) + return list(cls.model.select(*fields).where(cls.model.id.in_(docids)).dicts()) @classmethod @DB.connection_context() @@ -755,9 +762,7 @@ class DocumentService(CommonService): @classmethod @DB.connection_context() def get_doc_count(cls, tenant_id): - docs = cls.model.select(cls.model.id).join(Knowledgebase, - on=(Knowledgebase.id == cls.model.kb_id)).where( - Knowledgebase.tenant_id == tenant_id) + docs = cls.model.select(cls.model.id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(Knowledgebase.tenant_id == tenant_id) return len(docs) @classmethod @@ -768,7 +773,7 @@ class DocumentService(CommonService): "process_begin_at": get_format_time(), } if not keep_progress: - info["progress"] = random.random() * 1 / 100. + info["progress"] = random.random() * 1 / 100.0 info["run"] = TaskStatus.RUNNING.value # keep the doc in DONE state when keep_progress=True for GraphRAG, RAPTOR and Mindmap tasks @@ -781,19 +786,17 @@ class DocumentService(CommonService): cls._sync_progress(docs) - @classmethod @DB.connection_context() - def update_progress_immediately(cls, docs:list[dict]): + def update_progress_immediately(cls, docs: list[dict]): if not docs: return cls._sync_progress(docs) - @classmethod @DB.connection_context() - def _sync_progress(cls, docs:list[dict]): + def _sync_progress(cls, docs: list[dict]): from api.db.services.task_service import TaskService for d in docs: @@ -841,27 +844,18 @@ class DocumentService(CommonService): # fallback cls.update_by_id(d["id"], {"process_begin_at": begin_at}) - info = { - "process_duration": max(datetime.timestamp(datetime.now()) - begin_at.timestamp(), 0), - "run": status} + info = {"process_duration": max(datetime.timestamp(datetime.now()) - begin_at.timestamp(), 0), "run": status} if prg != 0 and not freeze_progress: info["progress"] = prg if msg: info["progress_msg"] = msg if msg.endswith("created task graphrag") or msg.endswith("created task raptor") or msg.endswith("created task mindmap"): - info["progress_msg"] += "\n%d tasks are ahead in the queue..."%get_queue_length(priority) + info["progress_msg"] += "\n%d tasks are ahead in the queue..." % get_queue_length(priority) else: - info["progress_msg"] = "%d tasks are ahead in the queue..."%get_queue_length(priority) + info["progress_msg"] = "%d tasks are ahead in the queue..." % get_queue_length(priority) info["update_time"] = current_timestamp() info["update_date"] = get_format_time() - ( - cls.model.update(info) - .where( - (cls.model.id == d["id"]) - & ((cls.model.run.is_null(True)) | (cls.model.run != TaskStatus.CANCEL.value)) - ) - .execute() - ) + (cls.model.update(info).where((cls.model.id == d["id"]) & ((cls.model.run.is_null(True)) | (cls.model.run != TaskStatus.CANCEL.value))).execute()) except Exception as e: if str(e).find("'0'") < 0: logging.exception("fetch task exception") @@ -875,7 +869,7 @@ class DocumentService(CommonService): @DB.connection_context() def get_all_kb_doc_count(cls): result = {} - rows = cls.model.select(cls.model.kb_id, fn.COUNT(cls.model.id).alias('count')).group_by(cls.model.kb_id) + rows = cls.model.select(cls.model.kb_id, fn.COUNT(cls.model.id).alias("count")).group_by(cls.model.kb_id) for row in rows: result[row.kb_id] = row.count return result @@ -890,33 +884,19 @@ class DocumentService(CommonService): pass return False - @classmethod @DB.connection_context() def knowledgebase_basic_info(cls, kb_id: str) -> dict[str, int]: # cancelled: run == "2" - cancelled = ( - cls.model.select(fn.COUNT(1)) - .where((cls.model.kb_id == kb_id) & (cls.model.run == TaskStatus.CANCEL)) - .scalar() - ) - downloaded = ( - cls.model.select(fn.COUNT(1)) - .where( - cls.model.kb_id == kb_id, - cls.model.source_type != "local" - ) - .scalar() - ) + cancelled = cls.model.select(fn.COUNT(1)).where((cls.model.kb_id == kb_id) & (cls.model.run == TaskStatus.CANCEL)).scalar() + downloaded = cls.model.select(fn.COUNT(1)).where(cls.model.kb_id == kb_id, cls.model.source_type != "local").scalar() row = ( cls.model.select( # finished: progress == 1 fn.COALESCE(fn.SUM(Case(None, [(cls.model.progress == 1, 1)], 0)), 0).alias("finished"), - # failed: progress == -1 fn.COALESCE(fn.SUM(Case(None, [(cls.model.progress == -1, 1)], 0)), 0).alias("failed"), - # processing: 0 <= progress < 1 fn.COALESCE( fn.SUM( @@ -931,24 +911,15 @@ class DocumentService(CommonService): 0, ).alias("processing"), ) - .where( - (cls.model.kb_id == kb_id) - & ((cls.model.run.is_null(True)) | (cls.model.run != TaskStatus.CANCEL)) - ) + .where((cls.model.kb_id == kb_id) & ((cls.model.run.is_null(True)) | (cls.model.run != TaskStatus.CANCEL))) .dicts() .get() ) - return { - "processing": int(row["processing"]), - "finished": int(row["finished"]), - "failed": int(row["failed"]), - "cancelled": int(cancelled), - "downloaded": int(downloaded) - } + return {"processing": int(row["processing"]), "finished": int(row["finished"]), "failed": int(row["failed"]), "cancelled": int(cancelled), "downloaded": int(downloaded)} @classmethod - def run(cls, tenant_id:str, doc:dict, kb_table_num_map:dict): + def run(cls, tenant_id: str, doc: dict, kb_table_num_map: dict): from api.db.services.task_service import queue_dataflow, queue_tasks from api.db.services.file2document_service import File2DocumentService @@ -990,7 +961,7 @@ def queue_raptor_o_graphrag_tasks(sample_doc_id, ty, priority, fake_doc_id="", d "from_page": 100000000, "to_page": 100000000, "task_type": ty, - "progress_msg": datetime.now().strftime("%H:%M:%S") + " created task " + ty, + "progress_msg": datetime.now().strftime("%H:%M:%S") + " created task " + ty, "begin_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), } @@ -1032,8 +1003,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id): e, dia = DialogService.get_by_id(conv.dialog_id) if not dia.kb_ids: - raise LookupError("No dataset associated with this conversation. " - "Please add a dataset before uploading documents") + raise LookupError("No dataset associated with this conversation. Please add a dataset before uploading documents") kb_id = dia.kb_ids[0] e, kb = KnowledgebaseService.get_by_id(kb_id) if not e: @@ -1050,12 +1020,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id): def dummy(prog=None, msg=""): pass - FACTORY = { - ParserType.PRESENTATION.value: presentation, - ParserType.PICTURE.value: picture, - ParserType.AUDIO.value: audio, - ParserType.EMAIL.value: email - } + FACTORY = {ParserType.PRESENTATION.value: presentation, ParserType.PICTURE.value: picture, ParserType.AUDIO.value: audio, ParserType.EMAIL.value: email} parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?;。;!?", "layout_recognize": "Plain Text", "table_context_size": 0, "image_context_size": 0} exe = ThreadPoolExecutor(max_workers=12) threads = [] @@ -1063,22 +1028,12 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id): for d, blob in files: doc_nm[d["id"]] = d["name"] for d, blob in files: - kwargs = { - "callback": dummy, - "parser_config": parser_config, - "from_page": 0, - "to_page": 100000, - "tenant_id": kb.tenant_id, - "lang": kb.language - } + kwargs = {"callback": dummy, "parser_config": parser_config, "from_page": 0, "to_page": 100000, "tenant_id": kb.tenant_id, "lang": kb.language} threads.append(exe.submit(FACTORY.get(d["parser_id"], naive).chunk, d["name"], blob, **kwargs)) for (docinfo, _), th in zip(files, threads): docs = [] - doc = { - "doc_id": docinfo["id"], - "kb_id": [kb.id] - } + doc = {"doc_id": docinfo["id"], "kb_id": [kb.id]} for ck in th.result(): d = deepcopy(doc) d.update(ck) @@ -1093,7 +1048,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id): if isinstance(d["image"], bytes): output_buffer = BytesIO(d["image"]) else: - d["image"].save(output_buffer, format='JPEG') + d["image"].save(output_buffer, format="JPEG") settings.STORAGE_IMPL.put(kb.id, d["id"], output_buffer.getvalue()) d["img_id"] = "{}-{}".format(kb.id, d["id"]) @@ -1110,9 +1065,9 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id): nonlocal embd_mdl, chunk_counts, token_counts vectors = [] for i in range(0, len(cnts), batch_size): - vts, c = embd_mdl.encode(cnts[i: i + batch_size]) + vts, c = embd_mdl.encode(cnts[i : i + batch_size]) vectors.extend(vts.tolist()) - chunk_counts[doc_id] += len(cnts[i:i + batch_size]) + chunk_counts[doc_id] += len(cnts[i : i + batch_size]) token_counts[doc_id] += c return vectors @@ -1127,22 +1082,25 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id): if parser_ids[doc_id] != ParserType.PICTURE.value: from rag.graphrag.general.mind_map_extractor import MindMapExtractor + mindmap = MindMapExtractor(llm_bdl) try: mind_map = asyncio.run(mindmap([c["content_with_weight"] for c in docs if c["doc_id"] == doc_id])) mind_map = json.dumps(mind_map.output, ensure_ascii=False, indent=2) if len(mind_map) < 32: raise Exception("Few content: " + mind_map) - cks.append({ - "id": get_uuid(), - "doc_id": doc_id, - "kb_id": [kb.id], - "docnm_kwd": doc_nm[doc_id], - "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", doc_nm[doc_id])), - "content_ltks": rag_tokenizer.tokenize("summary summarize 总结 概况 file 文件 概括"), - "content_with_weight": mind_map, - "knowledge_graph_kwd": "mind_map" - }) + cks.append( + { + "id": get_uuid(), + "doc_id": doc_id, + "kb_id": [kb.id], + "docnm_kwd": doc_nm[doc_id], + "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", doc_nm[doc_id])), + "content_ltks": rag_tokenizer.tokenize("summary summarize 总结 概况 file 文件 概括"), + "content_with_weight": mind_map, + "knowledge_graph_kwd": "mind_map", + } + ) except Exception: logging.exception("Mind map generation error") @@ -1156,9 +1114,8 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id): if not settings.docStoreConn.index_exist(idxnm, kb_id): settings.docStoreConn.create_idx(idxnm, kb_id, len(vectors[0]), kb.parser_id) try_create_idx = False - settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id) + settings.docStoreConn.insert(cks[b : b + es_bulk_size], idxnm, kb_id) - DocumentService.increment_chunk_num( - doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0) + DocumentService.increment_chunk_num(doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0) return [d["id"] for d, _ in files] diff --git a/api/utils/validation_utils.py b/api/utils/validation_utils.py index 9e0b39aae..5864e6b4d 100644 --- a/api/utils/validation_utils.py +++ b/api/utils/validation_utils.py @@ -34,7 +34,9 @@ from werkzeug.exceptions import BadRequest, UnsupportedMediaType from api.constants import DATASET_NAME_LIMIT -async def validate_and_parse_json_request(request: Request, validator: type[BaseModel], *, extras: dict[str, Any] | None = None, exclude_unset: bool = False) -> tuple[dict[str, Any] | None, str | None]: +async def validate_and_parse_json_request( + request: Request, validator: type[BaseModel], *, extras: dict[str, Any] | None = None, exclude_unset: bool = False +) -> tuple[dict[str, Any] | None, str | None]: """ Validates and parses JSON requests through a multi-stage validation pipeline. @@ -742,4 +744,5 @@ class BaseListReq(BaseModel): return validate_uuid1_hex(v) -class ListDatasetReq(BaseListReq): ... +class ListDatasetReq(BaseListReq): + include_parsing_status: Annotated[bool, Field(default=False)] diff --git a/docs/references/http_api_reference.md b/docs/references/http_api_reference.md index 1f50d5753..8e7199b11 100644 --- a/docs/references/http_api_reference.md +++ b/docs/references/http_api_reference.md @@ -835,14 +835,14 @@ Failure: ### List datasets -**GET** `/api/v1/datasets?page={page}&page_size={page_size}&orderby={orderby}&desc={desc}&name={dataset_name}&id={dataset_id}` +**GET** `/api/v1/datasets?page={page}&page_size={page_size}&orderby={orderby}&desc={desc}&name={dataset_name}&id={dataset_id}&include_parsing_status={include_parsing_status}` Lists datasets. #### Request - Method: GET -- URL: `/api/v1/datasets?page={page}&page_size={page_size}&orderby={orderby}&desc={desc}&name={dataset_name}&id={dataset_id}` +- URL: `/api/v1/datasets?page={page}&page_size={page_size}&orderby={orderby}&desc={desc}&name={dataset_name}&id={dataset_id}&include_parsing_status={include_parsing_status}` - Headers: - `'Authorization: Bearer '` @@ -854,6 +854,13 @@ curl --request GET \ --header 'Authorization: Bearer ' ``` +```bash +# List datasets with parsing status +curl --request GET \ + --url 'http://{address}/api/v1/datasets?include_parsing_status=true' \ + --header 'Authorization: Bearer ' +``` + ##### Request parameters - `page`: (*Filter parameter*) @@ -870,6 +877,13 @@ curl --request GET \ The name of the dataset to retrieve. - `id`: (*Filter parameter*) The ID of the dataset to retrieve. +- `include_parsing_status`: (*Filter parameter*) + Whether to include document parsing status counts in the response. Defaults to `false`. When set to `true`, each dataset object in the response will include the following additional fields: + - `unstart_count`: Number of documents not yet started parsing. + - `running_count`: Number of documents currently being parsed. + - `cancel_count`: Number of documents whose parsing was cancelled. + - `done_count`: Number of documents that have been successfully parsed. + - `fail_count`: Number of documents whose parsing failed. #### Response @@ -917,6 +931,49 @@ Success: } ``` +Success (with `include_parsing_status=true`): + +```json +{ + "code": 0, + "data": [ + { + "avatar": null, + "cancel_count": 0, + "chunk_count": 30, + "chunk_method": "qa", + "create_date": "2026-03-09T18:57:13", + "create_time": 1773053833094, + "created_by": "928f92a210b911f1ac4cc39e0b8fa3ad", + "description": null, + "document_count": 1, + "done_count": 1, + "embedding_model": "text-embedding-v2@Tongyi-Qianwen", + "fail_count": 0, + "id": "ba6586c21ba611f1a3dc476f0709e75e", + "language": "English", + "name": "Test Dataset", + "parser_config": { + "graphrag": { "use_graphrag": false }, + "llm_id": "deepseek-chat@DeepSeek", + "raptor": { "use_raptor": false } + }, + "permission": "me", + "running_count": 0, + "similarity_threshold": 0.2, + "status": "1", + "tenant_id": "928f92a210b911f1ac4cc39e0b8fa3ad", + "token_num": 1746, + "unstart_count": 0, + "update_date": "2026-03-09T18:59:32", + "update_time": 1773053972723, + "vector_similarity_weight": 0.3 + } + ], + "total_datasets": 1 +} +``` + Failure: ```json diff --git a/docs/references/python_api_reference.md b/docs/references/python_api_reference.md index 430e58a0f..8fa97f3d5 100644 --- a/docs/references/python_api_reference.md +++ b/docs/references/python_api_reference.md @@ -266,7 +266,8 @@ RAGFlow.list_datasets( orderby: str = "create_time", desc: bool = True, id: str = None, - name: str = None + name: str = None, + include_parsing_status: bool = False ) -> list[DataSet] ``` @@ -301,6 +302,16 @@ The ID of the dataset to retrieve. Defaults to `None`. The name of the dataset to retrieve. Defaults to `None`. +##### include_parsing_status: `bool` + +Whether to include document parsing status counts in each returned `DataSet` object. Defaults to `False`. When set to `True`, each `DataSet` object will include the following additional attributes: + +- `unstart_count`: `int` Number of documents not yet started parsing. +- `running_count`: `int` Number of documents currently being parsed. +- `cancel_count`: `int` Number of documents whose parsing was cancelled. +- `done_count`: `int` Number of documents that have been successfully parsed. +- `fail_count`: `int` Number of documents whose parsing failed. + #### Returns - Success: A list of `DataSet` objects. @@ -322,6 +333,13 @@ dataset = rag_object.list_datasets(id = "id_1") print(dataset[0]) ``` +##### List datasets with parsing status + +```python +for dataset in rag_object.list_datasets(include_parsing_status=True): + print(dataset.done_count, dataset.fail_count, dataset.running_count) +``` + --- ### Update dataset diff --git a/example/http/dataset_example.sh b/example/http/dataset_example.sh index 492d902d0..1d2e8fa68 100644 --- a/example/http/dataset_example.sh +++ b/example/http/dataset_example.sh @@ -41,6 +41,12 @@ curl --request GET \ --url http://127.0.0.1:9380/api/v1/datasets \ --header 'Authorization: Bearer ragflow-IzZmY1MGVhYTBhMjExZWZiYTdjMDI0Mm' +# List datasets with parsing status +echo -e "\n-- List datasets with parsing status" +curl --request GET \ + --url 'http://127.0.0.1:9380/api/v1/datasets?include_parsing_status=true' \ + --header 'Authorization: Bearer ragflow-IzZmY1MGVhYTBhMjExZWZiYTdjMDI0Mm' + # Delete datasets echo -e "\n-- Delete datasets" curl --request DELETE \ diff --git a/test/unit_test/api/db/services/test_document_service_get_parsing_status.py b/test/unit_test/api/db/services/test_document_service_get_parsing_status.py new file mode 100644 index 000000000..997fe6f86 --- /dev/null +++ b/test/unit_test/api/db/services/test_document_service_get_parsing_status.py @@ -0,0 +1,326 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import sys +import types +import warnings + +import pytest + +# xgboost imports pkg_resources and emits a deprecation warning that is promoted +# to error in our pytest configuration; ignore it for this unit test module. +warnings.filterwarnings( + "ignore", + message="pkg_resources is deprecated as an API.*", + category=UserWarning, +) + + +def _install_cv2_stub_if_unavailable(): + try: + import cv2 # noqa: F401 + return + except Exception: + pass + + stub = types.ModuleType("cv2") + + stub.INTER_LINEAR = 1 + stub.INTER_CUBIC = 2 + stub.BORDER_CONSTANT = 0 + stub.BORDER_REPLICATE = 1 + stub.COLOR_BGR2RGB = 0 + stub.COLOR_BGR2GRAY = 1 + stub.COLOR_GRAY2BGR = 2 + stub.IMREAD_IGNORE_ORIENTATION = 128 + stub.IMREAD_COLOR = 1 + stub.RETR_LIST = 1 + stub.CHAIN_APPROX_SIMPLE = 2 + + def _missing(*_args, **_kwargs): + raise RuntimeError("cv2 runtime call is unavailable in this test environment") + + def _module_getattr(name): + if name.isupper(): + return 0 + return _missing + + stub.__getattr__ = _module_getattr + sys.modules["cv2"] = stub + + +_install_cv2_stub_if_unavailable() + +from api.db.services.document_service import DocumentService # noqa: E402 +from common.constants import TaskStatus # noqa: E402 + +# --------------------------------------------------------------------------- +# Helpers to access the original function bypassing @DB.connection_context() +# --------------------------------------------------------------------------- + +def _unwrapped_get_parsing_status(): + """Return the original (un-decorated) get_parsing_status_by_kb_ids function. + + @classmethod + @DB.connection_context() together means: + DocumentService.get_parsing_status_by_kb_ids.__func__ -> connection_context wrapper + ....__func__.__wrapped__ -> original function + """ + return DocumentService.get_parsing_status_by_kb_ids.__func__.__wrapped__ + + +# --------------------------------------------------------------------------- +# Fake ORM helpers – mimic the minimal peewee query chain used by the function +# --------------------------------------------------------------------------- + +class _FieldStub: + """Minimal stand-in for a peewee model field used in select/where/group_by.""" + + def in_(self, values): + """Called by .where(cls.model.kb_id.in_(kb_ids)) – no-op in tests.""" + return self + + def alias(self, name): + return self + + +class _FakeQuery: + """Chains .where(), .group_by(), .dicts() without touching a real database.""" + + def __init__(self, rows): + self._rows = rows + + def where(self, *_args, **_kwargs): + return self + + def group_by(self, *_args, **_kwargs): + return self + + def dicts(self): + return list(self._rows) + + +def _make_fake_model(rows): + """Create a fake Document model class whose select() returns *rows*.""" + + class _FakeModel: + id = _FieldStub() + kb_id = _FieldStub() + run = _FieldStub() + + @classmethod + def select(cls, *_args): + return _FakeQuery(rows) + + return _FakeModel + + +# --------------------------------------------------------------------------- +# Pytest fixture – patch DocumentService.model per test +# --------------------------------------------------------------------------- + +@pytest.fixture() +def call_with_rows(monkeypatch): + """Return a helper that runs get_parsing_status_by_kb_ids with fake DB rows.""" + + def _call(rows, kb_ids): + monkeypatch.setattr(DocumentService, "model", _make_fake_model(rows)) + fn = _unwrapped_get_parsing_status() + return fn(DocumentService, kb_ids) + + return _call + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + +_ALL_STATUS_FIELDS = frozenset( + ["unstart_count", "running_count", "cancel_count", "done_count", "fail_count"] +) + + +@pytest.mark.p2 +class TestGetParsingStatusByKbIds: + + # ------------------------------------------------------------------ + # Edge-case: empty input list – must short-circuit before any DB call + # ------------------------------------------------------------------ + + def test_empty_kb_ids_returns_empty_dict(self, call_with_rows): + result = call_with_rows([], []) + assert result == {} + + # ------------------------------------------------------------------ + # A kb_id present in the input but with no matching documents + # ------------------------------------------------------------------ + + def test_single_kb_id_no_documents(self, call_with_rows): + result = call_with_rows(rows=[], kb_ids=["kb-1"]) + + assert set(result.keys()) == {"kb-1"} + assert set(result["kb-1"].keys()) == _ALL_STATUS_FIELDS + assert all(v == 0 for v in result["kb-1"].values()) + + # ------------------------------------------------------------------ + # A single kb_id with one document in each run-status bucket + # ------------------------------------------------------------------ + + def test_single_kb_id_all_five_statuses(self, call_with_rows): + rows = [ + {"kb_id": "kb-1", "run": TaskStatus.UNSTART.value, "cnt": 3}, + {"kb_id": "kb-1", "run": TaskStatus.RUNNING.value, "cnt": 1}, + {"kb_id": "kb-1", "run": TaskStatus.CANCEL.value, "cnt": 2}, + {"kb_id": "kb-1", "run": TaskStatus.DONE.value, "cnt": 10}, + {"kb_id": "kb-1", "run": TaskStatus.FAIL.value, "cnt": 4}, + ] + result = call_with_rows(rows=rows, kb_ids=["kb-1"]) + + assert result["kb-1"]["unstart_count"] == 3 + assert result["kb-1"]["running_count"] == 1 + assert result["kb-1"]["cancel_count"] == 2 + assert result["kb-1"]["done_count"] == 10 + assert result["kb-1"]["fail_count"] == 4 + + # ------------------------------------------------------------------ + # Two kb_ids – counts must be independent per dataset + # ------------------------------------------------------------------ + + def test_multiple_kb_ids_aggregated_separately(self, call_with_rows): + rows = [ + {"kb_id": "kb-a", "run": TaskStatus.DONE.value, "cnt": 5}, + {"kb_id": "kb-a", "run": TaskStatus.FAIL.value, "cnt": 1}, + {"kb_id": "kb-b", "run": TaskStatus.UNSTART.value, "cnt": 7}, + {"kb_id": "kb-b", "run": TaskStatus.DONE.value, "cnt": 2}, + ] + result = call_with_rows(rows=rows, kb_ids=["kb-a", "kb-b"]) + + assert set(result.keys()) == {"kb-a", "kb-b"} + + assert result["kb-a"]["done_count"] == 5 + assert result["kb-a"]["fail_count"] == 1 + assert result["kb-a"]["unstart_count"] == 0 + assert result["kb-a"]["running_count"] == 0 + assert result["kb-a"]["cancel_count"] == 0 + + assert result["kb-b"]["unstart_count"] == 7 + assert result["kb-b"]["done_count"] == 2 + assert result["kb-b"]["fail_count"] == 0 + + # ------------------------------------------------------------------ + # An unrecognised run value must be silently ignored + # ------------------------------------------------------------------ + + def test_unknown_run_value_ignored(self, call_with_rows): + rows = [ + {"kb_id": "kb-1", "run": "9", "cnt": 99}, # "9" is not a TaskStatus + {"kb_id": "kb-1", "run": TaskStatus.DONE.value, "cnt": 4}, + ] + result = call_with_rows(rows=rows, kb_ids=["kb-1"]) + + assert result["kb-1"]["done_count"] == 4 + assert all( + result["kb-1"][f] == 0 + for f in _ALL_STATUS_FIELDS - {"done_count"} + ) + + # ------------------------------------------------------------------ + # A row whose kb_id was NOT requested must not appear in the output + # ------------------------------------------------------------------ + + def test_row_with_unrequested_kb_id_is_filtered_out(self, call_with_rows): + rows = [ + {"kb_id": "kb-requested", "run": TaskStatus.DONE.value, "cnt": 3}, + {"kb_id": "kb-unexpected", "run": TaskStatus.DONE.value, "cnt": 100}, + ] + result = call_with_rows(rows=rows, kb_ids=["kb-requested"]) + + assert "kb-unexpected" not in result + assert result["kb-requested"]["done_count"] == 3 + + # ------------------------------------------------------------------ + # cnt values must be treated as integers regardless of DB type hints + # ------------------------------------------------------------------ + + def test_cnt_is_cast_to_int(self, call_with_rows): + rows = [ + {"kb_id": "kb-1", "run": TaskStatus.RUNNING.value, "cnt": "7"}, + ] + result = call_with_rows(rows=rows, kb_ids=["kb-1"]) + + assert result["kb-1"]["running_count"] == 7 + assert isinstance(result["kb-1"]["running_count"], int) + + # ------------------------------------------------------------------ + # run value stored as integer in DB (some adapters may omit str cast) + # ------------------------------------------------------------------ + + def test_run_value_as_integer_is_handled(self, call_with_rows): + rows = [ + {"kb_id": "kb-1", "run": int(TaskStatus.DONE.value), "cnt": 5}, + ] + result = call_with_rows(rows=rows, kb_ids=["kb-1"]) + + assert result["kb-1"]["done_count"] == 5 + + # ------------------------------------------------------------------ + # All five status fields are initialised to 0 even when no rows exist + # ------------------------------------------------------------------ + + def test_all_five_fields_initialised_to_zero(self, call_with_rows): + result = call_with_rows(rows=[], kb_ids=["kb-empty"]) + + assert result["kb-empty"] == { + "unstart_count": 0, + "running_count": 0, + "cancel_count": 0, + "done_count": 0, + "fail_count": 0, + } + + # ------------------------------------------------------------------ + # Multiple kb_ids in the input – all should appear in the result + # even when no documents exist for some of them + # ------------------------------------------------------------------ + + def test_requested_kb_ids_all_present_in_result(self, call_with_rows): + rows = [ + {"kb_id": "kb-with-data", "run": TaskStatus.DONE.value, "cnt": 1}, + ] + result = call_with_rows( + rows=rows, kb_ids=["kb-with-data", "kb-empty-1", "kb-empty-2"] + ) + + assert set(result.keys()) == {"kb-with-data", "kb-empty-1", "kb-empty-2"} + assert result["kb-empty-1"] == {f: 0 for f in _ALL_STATUS_FIELDS} + assert result["kb-empty-2"] == {f: 0 for f in _ALL_STATUS_FIELDS} + + # ------------------------------------------------------------------ + # SCHEDULE (run=="5") is not mapped – must be silently ignored + # ------------------------------------------------------------------ + + def test_schedule_status_is_not_mapped(self, call_with_rows): + rows = [ + {"kb_id": "kb-1", "run": TaskStatus.SCHEDULE.value, "cnt": 3}, + {"kb_id": "kb-1", "run": TaskStatus.DONE.value, "cnt": 2}, + ] + result = call_with_rows(rows=rows, kb_ids=["kb-1"]) + + assert result["kb-1"]["done_count"] == 2 + # SCHEDULE is not a tracked bucket + assert "schedule_count" not in result["kb-1"] + assert all( + result["kb-1"][f] == 0 + for f in _ALL_STATUS_FIELDS - {"done_count"} + )