Feat: Support getting the aggregated parsing status of a dataset via the API (#13481)

### What problem does this PR solve?

Support getting the aggregated parsing status of a dataset via the API.

Issue: #12810

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

Co-authored-by: heyang.why <heyang.why@alibaba-inc.com>
This commit is contained in:
Heyang Wang
2026-03-10 18:05:45 +08:00
committed by GitHub
parent 68a623154a
commit 08f83ff331
7 changed files with 654 additions and 309 deletions

View File

@ -138,12 +138,7 @@ async def create(tenant_id):
parser_cfg["metadata"] = fields
parser_cfg["enable_metadata"] = auto_meta.get("enabled", True)
req["parser_config"] = parser_cfg
e, req = KnowledgebaseService.create_with_name(
name = req.pop("name", None),
tenant_id = tenant_id,
parser_id = req.pop("parser_id", None),
**req
)
e, req = KnowledgebaseService.create_with_name(name=req.pop("name", None), tenant_id=tenant_id, parser_id=req.pop("parser_id", None), **req)
if not e:
return req
@ -159,19 +154,19 @@ async def create(tenant_id):
if not ok:
return err
try:
if not KnowledgebaseService.save(**req):
return get_error_data_result()
ok, k = KnowledgebaseService.get_by_id(req["id"])
if not ok:
return get_error_data_result(message="Dataset created failed")
response_data = remap_dictionary_keys(k.to_dict())
return get_result(data=response_data)
if not KnowledgebaseService.save(**req):
return get_error_data_result()
ok, k = KnowledgebaseService.get_by_id(req["id"])
if not ok:
return get_error_data_result(message="Dataset created failed")
response_data = remap_dictionary_keys(k.to_dict())
return get_result(data=response_data)
except Exception as e:
logging.exception(e)
return get_error_data_result(message="Database operation failed")
@manager.route("/datasets", methods=["DELETE"]) # noqa: F821
@token_required
async def delete(tenant_id):
@ -227,8 +222,7 @@ async def delete(tenant_id):
continue
kb_id_instance_pairs.append((kb_id, kb))
if len(error_kb_ids) > 0:
return get_error_permission_result(
message=f"""User '{tenant_id}' lacks permission for datasets: '{", ".join(error_kb_ids)}'""")
return get_error_permission_result(message=f"""User '{tenant_id}' lacks permission for datasets: '{", ".join(error_kb_ids)}'""")
errors = []
success_count = 0
@ -245,12 +239,12 @@ async def delete(tenant_id):
]
)
File2DocumentService.delete_by_document_id(doc.id)
FileService.filter_delete(
[File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kb.name])
FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kb.name])
# Drop index for this dataset
try:
from rag.nlp import search
idxnm = search.index_name(kb.tenant_id)
settings.docStoreConn.delete_idx(idxnm, kb_id)
except Exception as e:
@ -352,8 +346,7 @@ async def update(tenant_id, dataset_id):
try:
kb = KnowledgebaseService.get_or_none(id=dataset_id, tenant_id=tenant_id)
if kb is None:
return get_error_permission_result(
message=f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'")
return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'")
# Map auto_metadata_config into parser_config if present
auto_meta = req.pop("auto_metadata_config", None)
@ -384,8 +377,7 @@ async def update(tenant_id, dataset_id):
del req["parser_config"]
if "name" in req and req["name"].lower() != kb.name.lower():
exists = KnowledgebaseService.get_or_none(name=req["name"], tenant_id=tenant_id,
status=StatusEnum.VALID.value)
exists = KnowledgebaseService.get_or_none(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value)
if exists:
return get_error_data_result(message=f"Dataset name '{req['name']}' already exists")
@ -393,8 +385,7 @@ async def update(tenant_id, dataset_id):
if not req["embd_id"]:
req["embd_id"] = kb.embd_id
if kb.chunk_num != 0 and req["embd_id"] != kb.embd_id:
return get_error_data_result(
message=f"When chunk_num ({kb.chunk_num}) > 0, embedding_model must remain {kb.embd_id}")
return get_error_data_result(message=f"When chunk_num ({kb.chunk_num}) > 0, embedding_model must remain {kb.embd_id}")
ok, err = verify_embedding_availability(req["embd_id"], tenant_id)
if not ok:
return err
@ -404,12 +395,10 @@ async def update(tenant_id, dataset_id):
return get_error_argument_result(message="'pagerank' can only be set when doc_engine is elasticsearch")
if req["pagerank"] > 0:
settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]},
search.index_name(kb.tenant_id), kb.id)
settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]}, search.index_name(kb.tenant_id), kb.id)
else:
# Elasticsearch requires PAGERANK_FLD be non-zero!
settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD},
search.index_name(kb.tenant_id), kb.id)
settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD}, search.index_name(kb.tenant_id), kb.id)
if not KnowledgebaseService.update_by_id(kb.id, req):
return get_error_data_result(message="Update dataset error.(Database error)")
@ -470,6 +459,15 @@ def list_datasets(tenant_id):
required: false
default: true
description: Order in descending.
- in: query
name: include_parsing_status
type: boolean
required: false
default: false
description: |
Whether to include document parsing status counts in the response.
When true, each dataset object will include: unstart_count, running_count,
cancel_count, done_count, and fail_count.
- in: header
name: Authorization
type: string
@ -487,17 +485,18 @@ def list_datasets(tenant_id):
if err is not None:
return get_error_argument_result(err)
include_parsing_status = args.get("include_parsing_status", False)
try:
kb_id = request.args.get("id")
name = args.get("name")
# check whether user has permission for the dataset with specified id
if kb_id:
kbs = KnowledgebaseService.get_kb_by_id(kb_id, tenant_id)
if not kbs:
if not KnowledgebaseService.get_kb_by_id(kb_id, tenant_id):
return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{kb_id}'")
# check whether user has permission for the dataset with specified name
if name:
kbs = KnowledgebaseService.get_kb_by_name(name, tenant_id)
if not kbs:
if not KnowledgebaseService.get_kb_by_name(name, tenant_id):
return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{name}'")
tenants = TenantService.get_joined_tenants_by_user_id(tenant_id)
@ -512,9 +511,17 @@ def list_datasets(tenant_id):
name,
)
parsing_status_map = {}
if include_parsing_status and kbs:
kb_ids = [kb["id"] for kb in kbs]
parsing_status_map = DocumentService.get_parsing_status_by_kb_ids(kb_ids)
response_data_list = []
for kb in kbs:
response_data_list.append(remap_dictionary_keys(kb))
data = remap_dictionary_keys(kb)
if include_parsing_status:
data.update(parsing_status_map.get(kb["id"], {}))
response_data_list.append(data)
return get_result(data=response_data_list, total=total)
except OperationalError as e:
logging.exception(e)
@ -530,9 +537,7 @@ def get_auto_metadata(tenant_id, dataset_id):
try:
kb = KnowledgebaseService.get_or_none(id=dataset_id, tenant_id=tenant_id)
if kb is None:
return get_error_permission_result(
message=f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'"
)
return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'")
parser_cfg = kb.parser_config or {}
metadata = parser_cfg.get("metadata") or []
@ -570,9 +575,7 @@ async def update_auto_metadata(tenant_id, dataset_id):
try:
kb = KnowledgebaseService.get_or_none(id=dataset_id, tenant_id=tenant_id)
if kb is None:
return get_error_permission_result(
message=f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'"
)
return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'")
parser_cfg = kb.parser_config or {}
fields = []
@ -598,20 +601,13 @@ async def update_auto_metadata(tenant_id, dataset_id):
return get_error_data_result(message="Database operation failed")
@manager.route('/datasets/<dataset_id>/knowledge_graph', methods=['GET']) # noqa: F821
@manager.route("/datasets/<dataset_id>/knowledge_graph", methods=["GET"]) # noqa: F821
@token_required
async def knowledge_graph(tenant_id, dataset_id):
if not KnowledgebaseService.accessible(dataset_id, tenant_id):
return get_result(
data=False,
message='No authorization.',
code=RetCode.AUTHENTICATION_ERROR
)
return get_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
_, kb = KnowledgebaseService.get_by_id(dataset_id)
req = {
"kb_id": [dataset_id],
"knowledge_graph_kwd": ["graph"]
}
req = {"kb_id": [dataset_id], "knowledge_graph_kwd": ["graph"]}
obj = {"graph": {}, "mind_map": {}}
if not settings.docStoreConn.index_exist(search.index_name(kb.tenant_id), dataset_id):
@ -633,39 +629,29 @@ async def knowledge_graph(tenant_id, dataset_id):
obj["graph"]["nodes"] = sorted(obj["graph"]["nodes"], key=lambda x: x.get("pagerank", 0), reverse=True)[:256]
if "edges" in obj["graph"]:
node_id_set = {o["id"] for o in obj["graph"]["nodes"]}
filtered_edges = [o for o in obj["graph"]["edges"] if
o["source"] != o["target"] and o["source"] in node_id_set and o["target"] in node_id_set]
filtered_edges = [o for o in obj["graph"]["edges"] if o["source"] != o["target"] and o["source"] in node_id_set and o["target"] in node_id_set]
obj["graph"]["edges"] = sorted(filtered_edges, key=lambda x: x.get("weight", 0), reverse=True)[:128]
return get_result(data=obj)
@manager.route('/datasets/<dataset_id>/knowledge_graph', methods=['DELETE']) # noqa: F821
@manager.route("/datasets/<dataset_id>/knowledge_graph", methods=["DELETE"]) # noqa: F821
@token_required
def delete_knowledge_graph(tenant_id, dataset_id):
if not KnowledgebaseService.accessible(dataset_id, tenant_id):
return get_result(
data=False,
message='No authorization.',
code=RetCode.AUTHENTICATION_ERROR
)
return get_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
_, kb = KnowledgebaseService.get_by_id(dataset_id)
settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]},
search.index_name(kb.tenant_id), dataset_id)
settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), dataset_id)
return get_result(data=True)
@manager.route("/datasets/<dataset_id>/run_graphrag", methods=["POST"]) # noqa: F821
@token_required
def run_graphrag(tenant_id,dataset_id):
def run_graphrag(tenant_id, dataset_id):
if not dataset_id:
return get_error_data_result(message='Lack of "Dataset ID"')
if not KnowledgebaseService.accessible(dataset_id, tenant_id):
return get_result(
data=False,
message='No authorization.',
code=RetCode.AUTHENTICATION_ERROR
)
return get_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
ok, kb = KnowledgebaseService.get_by_id(dataset_id)
if not ok:
@ -707,15 +693,11 @@ def run_graphrag(tenant_id,dataset_id):
@manager.route("/datasets/<dataset_id>/trace_graphrag", methods=["GET"]) # noqa: F821
@token_required
def trace_graphrag(tenant_id,dataset_id):
def trace_graphrag(tenant_id, dataset_id):
if not dataset_id:
return get_error_data_result(message='Lack of "Dataset ID"')
if not KnowledgebaseService.accessible(dataset_id, tenant_id):
return get_result(
data=False,
message='No authorization.',
code=RetCode.AUTHENTICATION_ERROR
)
return get_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
ok, kb = KnowledgebaseService.get_by_id(dataset_id)
if not ok:
@ -734,15 +716,11 @@ def trace_graphrag(tenant_id,dataset_id):
@manager.route("/datasets/<dataset_id>/run_raptor", methods=["POST"]) # noqa: F821
@token_required
def run_raptor(tenant_id,dataset_id):
def run_raptor(tenant_id, dataset_id):
if not dataset_id:
return get_error_data_result(message='Lack of "Dataset ID"')
if not KnowledgebaseService.accessible(dataset_id, tenant_id):
return get_result(
data=False,
message='No authorization.',
code=RetCode.AUTHENTICATION_ERROR
)
return get_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
ok, kb = KnowledgebaseService.get_by_id(dataset_id)
if not ok:
@ -784,7 +762,7 @@ def run_raptor(tenant_id,dataset_id):
@manager.route("/datasets/<dataset_id>/trace_raptor", methods=["GET"]) # noqa: F821
@token_required
def trace_raptor(tenant_id,dataset_id):
def trace_raptor(tenant_id, dataset_id):
if not dataset_id:
return get_error_data_result(message='Lack of "Dataset ID"')

View File

@ -28,8 +28,7 @@ from peewee import fn, Case, JOIN
from api.constants import IMG_BASE64_PREFIX, FILE_NAME_LEN_LIMIT
from api.db import PIPELINE_SPECIAL_PROGRESS_FREEZE_TASK_TYPES, FileType, UserTenantRole, CanvasCategory
from api.db.db_models import DB, Document, Knowledgebase, Task, Tenant, UserTenant, File2Document, File, UserCanvas, \
User
from api.db.db_models import DB, Document, Knowledgebase, Task, Tenant, UserTenant, File2Document, File, UserCanvas, User
from api.db.db_utils import bulk_insert_into_db
from api.db.services.common_service import CommonService
from api.db.services.knowledgebase_service import KnowledgebaseService
@ -78,24 +77,21 @@ class DocumentService(CommonService):
@classmethod
@DB.connection_context()
def get_list(cls, kb_id, page_number, items_per_page,
orderby, desc, keywords, id, name, suffix=None, run = None, doc_ids=None):
def get_list(cls, kb_id, page_number, items_per_page, orderby, desc, keywords, id, name, suffix=None, run=None, doc_ids=None):
fields = cls.get_cls_model_fields()
docs = cls.model.select(*[*fields, UserCanvas.title]).join(File2Document, on = (File2Document.document_id == cls.model.id))\
.join(File, on = (File.id == File2Document.file_id))\
.join(UserCanvas, on = ((cls.model.pipeline_id == UserCanvas.id) & (UserCanvas.canvas_category == CanvasCategory.DataFlow.value)), join_type=JOIN.LEFT_OUTER)\
docs = (
cls.model.select(*[*fields, UserCanvas.title])
.join(File2Document, on=(File2Document.document_id == cls.model.id))
.join(File, on=(File.id == File2Document.file_id))
.join(UserCanvas, on=((cls.model.pipeline_id == UserCanvas.id) & (UserCanvas.canvas_category == CanvasCategory.DataFlow.value)), join_type=JOIN.LEFT_OUTER)
.where(cls.model.kb_id == kb_id)
)
if id:
docs = docs.where(
cls.model.id == id)
docs = docs.where(cls.model.id == id)
if name:
docs = docs.where(
cls.model.name == name
)
docs = docs.where(cls.model.name == name)
if keywords:
docs = docs.where(
fn.LOWER(cls.model.name).contains(keywords.lower())
)
docs = docs.where(fn.LOWER(cls.model.name).contains(keywords.lower()))
if doc_ids:
docs = docs.where(cls.model.id.in_(doc_ids))
if suffix:
@ -120,6 +116,7 @@ class DocumentService(CommonService):
@DB.connection_context()
def check_doc_health(cls, tenant_id: str, filename):
import os
MAX_FILE_NUM_PER_USER = int(os.environ.get("MAX_FILE_NUM_PER_USER", 0))
if 0 < MAX_FILE_NUM_PER_USER <= DocumentService.get_doc_count(tenant_id):
raise RuntimeError("Exceed the maximum file number of a free user!")
@ -211,13 +208,14 @@ class DocumentService(CommonService):
"""
fields = cls.get_cls_model_fields()
if keywords:
query = cls.model.select(*fields).join(File2Document, on=(File2Document.document_id == cls.model.id)).join(File, on=(File.id == File2Document.file_id)).where(
(cls.model.kb_id == kb_id),
(fn.LOWER(cls.model.name).contains(keywords.lower()))
query = (
cls.model.select(*fields)
.join(File2Document, on=(File2Document.document_id == cls.model.id))
.join(File, on=(File.id == File2Document.file_id))
.where((cls.model.kb_id == kb_id), (fn.LOWER(cls.model.name).contains(keywords.lower())))
)
else:
query = cls.model.select(*fields).join(File2Document, on=(File2Document.document_id == cls.model.id)).join(File, on=(File.id == File2Document.file_id)).where(cls.model.kb_id == kb_id)
query = cls.model.select(*fields).join(File2Document, on=(File2Document.document_id == cls.model.id)).join(File, on=(File.id == File2Document.file_id)).where(cls.model.kb_id == kb_id)
if run_status:
query = query.where(cls.model.run.in_(run_status))
@ -272,14 +270,60 @@ class DocumentService(CommonService):
"metadata": metadata_counter,
}, total
@classmethod
@DB.connection_context()
def get_parsing_status_by_kb_ids(cls, kb_ids: list[str]) -> dict[str, dict[str, int]]:
    """Aggregate per-dataset document parsing status counts.

    Buckets each dataset's documents by their ``run`` status value:
    - unstart_count (run == "0")
    - running_count (run == "1")
    - cancel_count (run == "2")
    - done_count   (run == "3")
    - fail_count   (run == "4")

    Returns a mapping keyed by kb_id, e.g.::

        {"kb-abc": {"unstart_count": 10, "running_count": 2, ...}, ...}

    Every requested kb_id is present in the result; datasets with no
    documents keep all counters at zero.
    """
    if not kb_ids:
        return {}

    run_to_field = {
        TaskStatus.UNSTART.value: "unstart_count",
        TaskStatus.RUNNING.value: "running_count",
        TaskStatus.CANCEL.value: "cancel_count",
        TaskStatus.DONE.value: "done_count",
        TaskStatus.FAIL.value: "fail_count",
    }
    # Pre-seed a zeroed bucket set for every requested kb so callers always
    # receive a complete structure, even when a dataset has no documents.
    counts: dict[str, dict[str, int]] = {kb_id: {field: 0 for field in run_to_field.values()} for kb_id in kb_ids}

    query = (
        cls.model.select(
            cls.model.kb_id,
            cls.model.run,
            fn.COUNT(cls.model.id).alias("cnt"),
        )
        .where(cls.model.kb_id.in_(kb_ids))
        .group_by(cls.model.kb_id, cls.model.run)
        .dicts()
    )
    for row in query:
        bucket = counts.get(row["kb_id"])
        field = run_to_field.get(str(row["run"]))
        # Ignore rows for unknown run values or kb_ids we did not ask about.
        if bucket is not None and field:
            bucket[field] = int(row["cnt"])
    return counts
@classmethod
@DB.connection_context()
def count_by_kb_id(cls, kb_id, keywords, run_status, types):
if keywords:
docs = cls.model.select().where(
(cls.model.kb_id == kb_id),
(fn.LOWER(cls.model.name).contains(keywords.lower()))
)
docs = cls.model.select().where((cls.model.kb_id == kb_id), (fn.LOWER(cls.model.name).contains(keywords.lower())))
else:
docs = cls.model.select().where(cls.model.kb_id == kb_id)
@ -295,9 +339,7 @@ class DocumentService(CommonService):
@classmethod
@DB.connection_context()
def get_total_size_by_kb_id(cls, kb_id, keywords="", run_status=[], types=[]):
query = cls.model.select(fn.COALESCE(fn.SUM(cls.model.size), 0)).where(
cls.model.kb_id == kb_id
)
query = cls.model.select(fn.COALESCE(fn.SUM(cls.model.size), 0)).where(cls.model.kb_id == kb_id)
if keywords:
query = query.where(fn.LOWER(cls.model.name).contains(keywords.lower()))
@ -329,12 +371,8 @@ class DocumentService(CommonService):
@classmethod
@DB.connection_context()
def get_all_docs_by_creator_id(cls, creator_id):
fields = [
cls.model.id, cls.model.kb_id, cls.model.token_num, cls.model.chunk_num, Knowledgebase.tenant_id
]
docs = cls.model.select(*fields).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(
cls.model.created_by == creator_id
)
fields = [cls.model.id, cls.model.kb_id, cls.model.token_num, cls.model.chunk_num, Knowledgebase.tenant_id]
docs = cls.model.select(*fields).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.created_by == creator_id)
docs.order_by(cls.model.create_time.asc())
# maybe cause slow query by deep paginate, optimize later
offset, limit = 0, 100
@ -361,6 +399,7 @@ class DocumentService(CommonService):
@DB.connection_context()
def remove_document(cls, doc, tenant_id):
from api.db.services.task_service import TaskService, cancel_all_task_of
if not cls.delete_document_and_update_kb_counts(doc.id):
return True
@ -406,17 +445,22 @@ class DocumentService(CommonService):
# Cleanup knowledge graph references (non-critical, log and continue)
try:
graph_source = settings.docStoreConn.get_fields(
settings.docStoreConn.search(["source_id"], [], {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [doc.kb_id]), ["source_id"]
settings.docStoreConn.search(["source_id"], [], {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [doc.kb_id]),
["source_id"],
)
if len(graph_source) > 0 and doc.id in list(graph_source.values())[0]["source_id"]:
settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "source_id": doc.id},
{"remove": {"source_id": doc.id}},
search.index_name(tenant_id), doc.kb_id)
settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]},
{"removed_kwd": "Y"},
search.index_name(tenant_id), doc.kb_id)
settings.docStoreConn.delete({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "must_not": {"exists": "source_id"}},
search.index_name(tenant_id), doc.kb_id)
settings.docStoreConn.update(
{"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "source_id": doc.id},
{"remove": {"source_id": doc.id}},
search.index_name(tenant_id),
doc.kb_id,
)
settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, {"removed_kwd": "Y"}, search.index_name(tenant_id), doc.kb_id)
settings.docStoreConn.delete(
{"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "must_not": {"exists": "source_id"}},
search.index_name(tenant_id),
doc.kb_id,
)
except Exception as e:
logging.warning(f"Failed to cleanup knowledge graph for document {doc.id}: {e}")
@ -428,9 +472,7 @@ class DocumentService(CommonService):
page = 0
page_size = 1000
while True:
chunks = settings.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(),
page * page_size, page_size, search.index_name(tenant_id),
[doc.kb_id])
chunks = settings.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(), page * page_size, page_size, search.index_name(tenant_id), [doc.kb_id])
chunk_ids = settings.docStoreConn.get_doc_ids(chunks)
if not chunk_ids:
break
@ -455,71 +497,61 @@ class DocumentService(CommonService):
Tenant.embd_id,
Tenant.img2txt_id,
Tenant.asr_id,
cls.model.update_time]
docs = cls.model.select(*fields) \
.join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id)) \
cls.model.update_time,
]
docs = (
cls.model.select(*fields)
.join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id))
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
.where(
cls.model.status == StatusEnum.VALID.value,
~(cls.model.type == FileType.VIRTUAL.value),
cls.model.progress == 0,
cls.model.update_time >= current_timestamp() - 1000 * 600,
cls.model.run == TaskStatus.RUNNING.value) \
cls.model.status == StatusEnum.VALID.value,
~(cls.model.type == FileType.VIRTUAL.value),
cls.model.progress == 0,
cls.model.update_time >= current_timestamp() - 1000 * 600,
cls.model.run == TaskStatus.RUNNING.value,
)
.order_by(cls.model.update_time.asc())
)
return list(docs.dicts())
@classmethod
@DB.connection_context()
def get_unfinished_docs(cls):
fields = [cls.model.id, cls.model.process_begin_at, cls.model.parser_config, cls.model.progress_msg,
cls.model.run, cls.model.parser_id]
unfinished_task_query = Task.select(Task.doc_id).where(
(Task.progress >= 0) & (Task.progress < 1)
)
fields = [cls.model.id, cls.model.process_begin_at, cls.model.parser_config, cls.model.progress_msg, cls.model.run, cls.model.parser_id]
unfinished_task_query = Task.select(Task.doc_id).where((Task.progress >= 0) & (Task.progress < 1))
docs = cls.model.select(*fields) \
.where(
docs = cls.model.select(*fields).where(
cls.model.status == StatusEnum.VALID.value,
~(cls.model.type == FileType.VIRTUAL.value),
((cls.model.run.is_null(True)) | (cls.model.run != TaskStatus.CANCEL.value)),
(((cls.model.progress < 1) & (cls.model.progress > 0)) |
(cls.model.id.in_(unfinished_task_query)))) # including unfinished tasks like GraphRAG, RAPTOR and Mindmap
(((cls.model.progress < 1) & (cls.model.progress > 0)) | (cls.model.id.in_(unfinished_task_query))),
) # including unfinished tasks like GraphRAG, RAPTOR and Mindmap
return list(docs.dicts())
@classmethod
@DB.connection_context()
def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duration):
num = cls.model.update(token_num=cls.model.token_num + token_num,
chunk_num=cls.model.chunk_num + chunk_num,
process_duration=cls.model.process_duration + duration).where(
cls.model.id == doc_id).execute()
num = (
cls.model.update(token_num=cls.model.token_num + token_num, chunk_num=cls.model.chunk_num + chunk_num, process_duration=cls.model.process_duration + duration)
.where(cls.model.id == doc_id)
.execute()
)
if num == 0:
logging.warning("Document not found which is supposed to be there")
num = Knowledgebase.update(
token_num=Knowledgebase.token_num +
token_num,
chunk_num=Knowledgebase.chunk_num +
chunk_num).where(
Knowledgebase.id == kb_id).execute()
num = Knowledgebase.update(token_num=Knowledgebase.token_num + token_num, chunk_num=Knowledgebase.chunk_num + chunk_num).where(Knowledgebase.id == kb_id).execute()
return num
@classmethod
@DB.connection_context()
def decrement_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duration):
num = cls.model.update(token_num=cls.model.token_num - token_num,
chunk_num=cls.model.chunk_num - chunk_num,
process_duration=cls.model.process_duration + duration).where(
cls.model.id == doc_id).execute()
num = (
cls.model.update(token_num=cls.model.token_num - token_num, chunk_num=cls.model.chunk_num - chunk_num, process_duration=cls.model.process_duration + duration)
.where(cls.model.id == doc_id)
.execute()
)
if num == 0:
raise LookupError(
"Document not found which is supposed to be there")
num = Knowledgebase.update(
token_num=Knowledgebase.token_num -
token_num,
chunk_num=Knowledgebase.chunk_num -
chunk_num
).where(
Knowledgebase.id == kb_id).execute()
raise LookupError("Document not found which is supposed to be there")
num = Knowledgebase.update(token_num=Knowledgebase.token_num - token_num, chunk_num=Knowledgebase.chunk_num - chunk_num).where(Knowledgebase.id == kb_id).execute()
return num
@classmethod
@ -551,17 +583,13 @@ class DocumentService(CommonService):
doc = cls.model.get_by_id(doc_id)
assert doc, "Can't fine document in database."
num = Knowledgebase.update(
token_num=Knowledgebase.token_num -
doc.token_num,
chunk_num=Knowledgebase.chunk_num -
doc.chunk_num,
doc_num=Knowledgebase.doc_num - 1
).where(
Knowledgebase.id == doc.kb_id).execute()
num = (
Knowledgebase.update(token_num=Knowledgebase.token_num - doc.token_num, chunk_num=Knowledgebase.chunk_num - doc.chunk_num, doc_num=Knowledgebase.doc_num - 1)
.where(Knowledgebase.id == doc.kb_id)
.execute()
)
return num
@classmethod
@DB.connection_context()
def clear_chunk_num_when_rerun(cls, doc_id):
@ -578,15 +606,10 @@ class DocumentService(CommonService):
)
return num
@classmethod
@DB.connection_context()
def get_tenant_id(cls, doc_id):
docs = cls.model.select(
Knowledgebase.tenant_id).join(
Knowledgebase, on=(
Knowledgebase.id == cls.model.kb_id)).where(
cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
docs = cls.model.select(Knowledgebase.tenant_id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
docs = docs.dicts()
if not docs:
return None
@ -604,11 +627,7 @@ class DocumentService(CommonService):
@classmethod
@DB.connection_context()
def get_tenant_id_by_name(cls, name):
docs = cls.model.select(
Knowledgebase.tenant_id).join(
Knowledgebase, on=(
Knowledgebase.id == cls.model.kb_id)).where(
cls.model.name == name, Knowledgebase.status == StatusEnum.VALID.value)
docs = cls.model.select(Knowledgebase.tenant_id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.name == name, Knowledgebase.status == StatusEnum.VALID.value)
docs = docs.dicts()
if not docs:
return None
@ -617,12 +636,13 @@ class DocumentService(CommonService):
@classmethod
@DB.connection_context()
def accessible(cls, doc_id, user_id):
docs = cls.model.select(
cls.model.id).join(
Knowledgebase, on=(
Knowledgebase.id == cls.model.kb_id)
).join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
).where(cls.model.id == doc_id, UserTenant.user_id == user_id).paginate(0, 1)
docs = (
cls.model.select(cls.model.id)
.join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id))
.join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id))
.where(cls.model.id == doc_id, UserTenant.user_id == user_id)
.paginate(0, 1)
)
docs = docs.dicts()
if not docs:
return False
@ -631,18 +651,13 @@ class DocumentService(CommonService):
@classmethod
@DB.connection_context()
def accessible4deletion(cls, doc_id, user_id):
docs = cls.model.select(cls.model.id
).join(
Knowledgebase, on=(
Knowledgebase.id == cls.model.kb_id)
).join(
UserTenant, on=(
(UserTenant.tenant_id == Knowledgebase.created_by) & (UserTenant.user_id == user_id))
).where(
cls.model.id == doc_id,
UserTenant.status == StatusEnum.VALID.value,
((UserTenant.role == UserTenantRole.NORMAL) | (UserTenant.role == UserTenantRole.OWNER))
).paginate(0, 1)
docs = (
cls.model.select(cls.model.id)
.join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id))
.join(UserTenant, on=((UserTenant.tenant_id == Knowledgebase.created_by) & (UserTenant.user_id == user_id)))
.where(cls.model.id == doc_id, UserTenant.status == StatusEnum.VALID.value, ((UserTenant.role == UserTenantRole.NORMAL) | (UserTenant.role == UserTenantRole.OWNER)))
.paginate(0, 1)
)
docs = docs.dicts()
if not docs:
return False
@ -651,11 +666,7 @@ class DocumentService(CommonService):
@classmethod
@DB.connection_context()
def get_embd_id(cls, doc_id):
docs = cls.model.select(
Knowledgebase.embd_id).join(
Knowledgebase, on=(
Knowledgebase.id == cls.model.kb_id)).where(
cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
docs = cls.model.select(Knowledgebase.embd_id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
docs = docs.dicts()
if not docs:
return None
@ -664,11 +675,9 @@ class DocumentService(CommonService):
@classmethod
@DB.connection_context()
def get_tenant_embd_id(cls, doc_id):
docs = cls.model.select(
Knowledgebase.tenant_embd_id).join(
Knowledgebase, on=(
Knowledgebase.id == cls.model.kb_id)).where(
cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
docs = (
cls.model.select(Knowledgebase.tenant_embd_id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
)
docs = docs.dicts()
if not docs:
return None
@ -705,8 +714,7 @@ class DocumentService(CommonService):
@DB.connection_context()
def get_doc_id_by_doc_name(cls, doc_name):
fields = [cls.model.id]
doc_id = cls.model.select(*fields) \
.where(cls.model.name == doc_name)
doc_id = cls.model.select(*fields).where(cls.model.name == doc_name)
doc_id = doc_id.dicts()
if not doc_id:
return None
@ -725,8 +733,7 @@ class DocumentService(CommonService):
@DB.connection_context()
def get_thumbnails(cls, docids):
fields = [cls.model.id, cls.model.kb_id, cls.model.thumbnail]
return list(cls.model.select(
*fields).where(cls.model.id.in_(docids)).dicts())
return list(cls.model.select(*fields).where(cls.model.id.in_(docids)).dicts())
@classmethod
@DB.connection_context()
@ -755,9 +762,7 @@ class DocumentService(CommonService):
@classmethod
@DB.connection_context()
def get_doc_count(cls, tenant_id):
docs = cls.model.select(cls.model.id).join(Knowledgebase,
on=(Knowledgebase.id == cls.model.kb_id)).where(
Knowledgebase.tenant_id == tenant_id)
docs = cls.model.select(cls.model.id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(Knowledgebase.tenant_id == tenant_id)
return len(docs)
@classmethod
@ -768,7 +773,7 @@ class DocumentService(CommonService):
"process_begin_at": get_format_time(),
}
if not keep_progress:
info["progress"] = random.random() * 1 / 100.
info["progress"] = random.random() * 1 / 100.0
info["run"] = TaskStatus.RUNNING.value
# keep the doc in DONE state when keep_progress=True for GraphRAG, RAPTOR and Mindmap tasks
@ -781,19 +786,17 @@ class DocumentService(CommonService):
cls._sync_progress(docs)
@classmethod
@DB.connection_context()
def update_progress_immediately(cls, docs:list[dict]):
def update_progress_immediately(cls, docs: list[dict]):
if not docs:
return
cls._sync_progress(docs)
@classmethod
@DB.connection_context()
def _sync_progress(cls, docs:list[dict]):
def _sync_progress(cls, docs: list[dict]):
from api.db.services.task_service import TaskService
for d in docs:
@ -841,27 +844,18 @@ class DocumentService(CommonService):
# fallback
cls.update_by_id(d["id"], {"process_begin_at": begin_at})
info = {
"process_duration": max(datetime.timestamp(datetime.now()) - begin_at.timestamp(), 0),
"run": status}
info = {"process_duration": max(datetime.timestamp(datetime.now()) - begin_at.timestamp(), 0), "run": status}
if prg != 0 and not freeze_progress:
info["progress"] = prg
if msg:
info["progress_msg"] = msg
if msg.endswith("created task graphrag") or msg.endswith("created task raptor") or msg.endswith("created task mindmap"):
info["progress_msg"] += "\n%d tasks are ahead in the queue..."%get_queue_length(priority)
info["progress_msg"] += "\n%d tasks are ahead in the queue..." % get_queue_length(priority)
else:
info["progress_msg"] = "%d tasks are ahead in the queue..."%get_queue_length(priority)
info["progress_msg"] = "%d tasks are ahead in the queue..." % get_queue_length(priority)
info["update_time"] = current_timestamp()
info["update_date"] = get_format_time()
(
cls.model.update(info)
.where(
(cls.model.id == d["id"])
& ((cls.model.run.is_null(True)) | (cls.model.run != TaskStatus.CANCEL.value))
)
.execute()
)
(cls.model.update(info).where((cls.model.id == d["id"]) & ((cls.model.run.is_null(True)) | (cls.model.run != TaskStatus.CANCEL.value))).execute())
except Exception as e:
if str(e).find("'0'") < 0:
logging.exception("fetch task exception")
@ -875,7 +869,7 @@ class DocumentService(CommonService):
@DB.connection_context()
def get_all_kb_doc_count(cls):
result = {}
rows = cls.model.select(cls.model.kb_id, fn.COUNT(cls.model.id).alias('count')).group_by(cls.model.kb_id)
rows = cls.model.select(cls.model.kb_id, fn.COUNT(cls.model.id).alias("count")).group_by(cls.model.kb_id)
for row in rows:
result[row.kb_id] = row.count
return result
@ -890,33 +884,19 @@ class DocumentService(CommonService):
pass
return False
@classmethod
@DB.connection_context()
def knowledgebase_basic_info(cls, kb_id: str) -> dict[str, int]:
# cancelled: run == "2"
cancelled = (
cls.model.select(fn.COUNT(1))
.where((cls.model.kb_id == kb_id) & (cls.model.run == TaskStatus.CANCEL))
.scalar()
)
downloaded = (
cls.model.select(fn.COUNT(1))
.where(
cls.model.kb_id == kb_id,
cls.model.source_type != "local"
)
.scalar()
)
cancelled = cls.model.select(fn.COUNT(1)).where((cls.model.kb_id == kb_id) & (cls.model.run == TaskStatus.CANCEL)).scalar()
downloaded = cls.model.select(fn.COUNT(1)).where(cls.model.kb_id == kb_id, cls.model.source_type != "local").scalar()
row = (
cls.model.select(
# finished: progress == 1
fn.COALESCE(fn.SUM(Case(None, [(cls.model.progress == 1, 1)], 0)), 0).alias("finished"),
# failed: progress == -1
fn.COALESCE(fn.SUM(Case(None, [(cls.model.progress == -1, 1)], 0)), 0).alias("failed"),
# processing: 0 <= progress < 1
fn.COALESCE(
fn.SUM(
@ -931,24 +911,15 @@ class DocumentService(CommonService):
0,
).alias("processing"),
)
.where(
(cls.model.kb_id == kb_id)
& ((cls.model.run.is_null(True)) | (cls.model.run != TaskStatus.CANCEL))
)
.where((cls.model.kb_id == kb_id) & ((cls.model.run.is_null(True)) | (cls.model.run != TaskStatus.CANCEL)))
.dicts()
.get()
)
return {
"processing": int(row["processing"]),
"finished": int(row["finished"]),
"failed": int(row["failed"]),
"cancelled": int(cancelled),
"downloaded": int(downloaded)
}
return {"processing": int(row["processing"]), "finished": int(row["finished"]), "failed": int(row["failed"]), "cancelled": int(cancelled), "downloaded": int(downloaded)}
@classmethod
def run(cls, tenant_id:str, doc:dict, kb_table_num_map:dict):
def run(cls, tenant_id: str, doc: dict, kb_table_num_map: dict):
from api.db.services.task_service import queue_dataflow, queue_tasks
from api.db.services.file2document_service import File2DocumentService
@ -990,7 +961,7 @@ def queue_raptor_o_graphrag_tasks(sample_doc_id, ty, priority, fake_doc_id="", d
"from_page": 100000000,
"to_page": 100000000,
"task_type": ty,
"progress_msg": datetime.now().strftime("%H:%M:%S") + " created task " + ty,
"progress_msg": datetime.now().strftime("%H:%M:%S") + " created task " + ty,
"begin_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
}
@ -1032,8 +1003,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
e, dia = DialogService.get_by_id(conv.dialog_id)
if not dia.kb_ids:
raise LookupError("No dataset associated with this conversation. "
"Please add a dataset before uploading documents")
raise LookupError("No dataset associated with this conversation. Please add a dataset before uploading documents")
kb_id = dia.kb_ids[0]
e, kb = KnowledgebaseService.get_by_id(kb_id)
if not e:
@ -1050,12 +1020,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
def dummy(prog=None, msg=""):
pass
FACTORY = {
ParserType.PRESENTATION.value: presentation,
ParserType.PICTURE.value: picture,
ParserType.AUDIO.value: audio,
ParserType.EMAIL.value: email
}
FACTORY = {ParserType.PRESENTATION.value: presentation, ParserType.PICTURE.value: picture, ParserType.AUDIO.value: audio, ParserType.EMAIL.value: email}
parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?;。;!?", "layout_recognize": "Plain Text", "table_context_size": 0, "image_context_size": 0}
exe = ThreadPoolExecutor(max_workers=12)
threads = []
@ -1063,22 +1028,12 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
for d, blob in files:
doc_nm[d["id"]] = d["name"]
for d, blob in files:
kwargs = {
"callback": dummy,
"parser_config": parser_config,
"from_page": 0,
"to_page": 100000,
"tenant_id": kb.tenant_id,
"lang": kb.language
}
kwargs = {"callback": dummy, "parser_config": parser_config, "from_page": 0, "to_page": 100000, "tenant_id": kb.tenant_id, "lang": kb.language}
threads.append(exe.submit(FACTORY.get(d["parser_id"], naive).chunk, d["name"], blob, **kwargs))
for (docinfo, _), th in zip(files, threads):
docs = []
doc = {
"doc_id": docinfo["id"],
"kb_id": [kb.id]
}
doc = {"doc_id": docinfo["id"], "kb_id": [kb.id]}
for ck in th.result():
d = deepcopy(doc)
d.update(ck)
@ -1093,7 +1048,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
if isinstance(d["image"], bytes):
output_buffer = BytesIO(d["image"])
else:
d["image"].save(output_buffer, format='JPEG')
d["image"].save(output_buffer, format="JPEG")
settings.STORAGE_IMPL.put(kb.id, d["id"], output_buffer.getvalue())
d["img_id"] = "{}-{}".format(kb.id, d["id"])
@ -1110,9 +1065,9 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
nonlocal embd_mdl, chunk_counts, token_counts
vectors = []
for i in range(0, len(cnts), batch_size):
vts, c = embd_mdl.encode(cnts[i: i + batch_size])
vts, c = embd_mdl.encode(cnts[i : i + batch_size])
vectors.extend(vts.tolist())
chunk_counts[doc_id] += len(cnts[i:i + batch_size])
chunk_counts[doc_id] += len(cnts[i : i + batch_size])
token_counts[doc_id] += c
return vectors
@ -1127,22 +1082,25 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
if parser_ids[doc_id] != ParserType.PICTURE.value:
from rag.graphrag.general.mind_map_extractor import MindMapExtractor
mindmap = MindMapExtractor(llm_bdl)
try:
mind_map = asyncio.run(mindmap([c["content_with_weight"] for c in docs if c["doc_id"] == doc_id]))
mind_map = json.dumps(mind_map.output, ensure_ascii=False, indent=2)
if len(mind_map) < 32:
raise Exception("Few content: " + mind_map)
cks.append({
"id": get_uuid(),
"doc_id": doc_id,
"kb_id": [kb.id],
"docnm_kwd": doc_nm[doc_id],
"title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", doc_nm[doc_id])),
"content_ltks": rag_tokenizer.tokenize("summary summarize 总结 概况 file 文件 概括"),
"content_with_weight": mind_map,
"knowledge_graph_kwd": "mind_map"
})
cks.append(
{
"id": get_uuid(),
"doc_id": doc_id,
"kb_id": [kb.id],
"docnm_kwd": doc_nm[doc_id],
"title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", doc_nm[doc_id])),
"content_ltks": rag_tokenizer.tokenize("summary summarize 总结 概况 file 文件 概括"),
"content_with_weight": mind_map,
"knowledge_graph_kwd": "mind_map",
}
)
except Exception:
logging.exception("Mind map generation error")
@ -1156,9 +1114,8 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
if not settings.docStoreConn.index_exist(idxnm, kb_id):
settings.docStoreConn.create_idx(idxnm, kb_id, len(vectors[0]), kb.parser_id)
try_create_idx = False
settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)
settings.docStoreConn.insert(cks[b : b + es_bulk_size], idxnm, kb_id)
DocumentService.increment_chunk_num(
doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)
DocumentService.increment_chunk_num(doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)
return [d["id"] for d, _ in files]

View File

@ -34,7 +34,9 @@ from werkzeug.exceptions import BadRequest, UnsupportedMediaType
from api.constants import DATASET_NAME_LIMIT
async def validate_and_parse_json_request(request: Request, validator: type[BaseModel], *, extras: dict[str, Any] | None = None, exclude_unset: bool = False) -> tuple[dict[str, Any] | None, str | None]:
async def validate_and_parse_json_request(
request: Request, validator: type[BaseModel], *, extras: dict[str, Any] | None = None, exclude_unset: bool = False
) -> tuple[dict[str, Any] | None, str | None]:
"""
Validates and parses JSON requests through a multi-stage validation pipeline.
@ -742,4 +744,5 @@ class BaseListReq(BaseModel):
return validate_uuid1_hex(v)
class ListDatasetReq(BaseListReq): ...
class ListDatasetReq(BaseListReq):
include_parsing_status: Annotated[bool, Field(default=False)]

View File

@ -835,14 +835,14 @@ Failure:
### List datasets
**GET** `/api/v1/datasets?page={page}&page_size={page_size}&orderby={orderby}&desc={desc}&name={dataset_name}&id={dataset_id}`
**GET** `/api/v1/datasets?page={page}&page_size={page_size}&orderby={orderby}&desc={desc}&name={dataset_name}&id={dataset_id}&include_parsing_status={include_parsing_status}`
Lists datasets.
#### Request
- Method: GET
- URL: `/api/v1/datasets?page={page}&page_size={page_size}&orderby={orderby}&desc={desc}&name={dataset_name}&id={dataset_id}`
- URL: `/api/v1/datasets?page={page}&page_size={page_size}&orderby={orderby}&desc={desc}&name={dataset_name}&id={dataset_id}&include_parsing_status={include_parsing_status}`
- Headers:
- `'Authorization: Bearer <YOUR_API_KEY>'`
@ -854,6 +854,13 @@ curl --request GET \
--header 'Authorization: Bearer <YOUR_API_KEY>'
```
```bash
# List datasets with parsing status
curl --request GET \
--url 'http://{address}/api/v1/datasets?include_parsing_status=true' \
--header 'Authorization: Bearer <YOUR_API_KEY>'
```
##### Request parameters
- `page`: (*Filter parameter*)
@ -870,6 +877,13 @@ curl --request GET \
The name of the dataset to retrieve.
- `id`: (*Filter parameter*)
The ID of the dataset to retrieve.
- `include_parsing_status`: (*Filter parameter*)
Whether to include document parsing status counts in the response. Defaults to `false`. When set to `true`, each dataset object in the response will include the following additional fields:
- `unstart_count`: Number of documents not yet started parsing.
- `running_count`: Number of documents currently being parsed.
- `cancel_count`: Number of documents whose parsing was cancelled.
- `done_count`: Number of documents that have been successfully parsed.
- `fail_count`: Number of documents whose parsing failed.
#### Response
@ -917,6 +931,49 @@ Success:
}
```
Success (with `include_parsing_status=true`):
```json
{
"code": 0,
"data": [
{
"avatar": null,
"cancel_count": 0,
"chunk_count": 30,
"chunk_method": "qa",
"create_date": "2026-03-09T18:57:13",
"create_time": 1773053833094,
"created_by": "928f92a210b911f1ac4cc39e0b8fa3ad",
"description": null,
"document_count": 1,
"done_count": 1,
"embedding_model": "text-embedding-v2@Tongyi-Qianwen",
"fail_count": 0,
"id": "ba6586c21ba611f1a3dc476f0709e75e",
"language": "English",
"name": "Test Dataset",
"parser_config": {
"graphrag": { "use_graphrag": false },
"llm_id": "deepseek-chat@DeepSeek",
"raptor": { "use_raptor": false }
},
"permission": "me",
"running_count": 0,
"similarity_threshold": 0.2,
"status": "1",
"tenant_id": "928f92a210b911f1ac4cc39e0b8fa3ad",
"token_num": 1746,
"unstart_count": 0,
"update_date": "2026-03-09T18:59:32",
"update_time": 1773053972723,
"vector_similarity_weight": 0.3
}
],
"total_datasets": 1
}
```
Failure:
```json

View File

@ -266,7 +266,8 @@ RAGFlow.list_datasets(
orderby: str = "create_time",
desc: bool = True,
id: str = None,
name: str = None
name: str = None,
include_parsing_status: bool = False
) -> list[DataSet]
```
@ -301,6 +302,16 @@ The ID of the dataset to retrieve. Defaults to `None`.
The name of the dataset to retrieve. Defaults to `None`.
##### include_parsing_status: `bool`
Whether to include document parsing status counts in each returned `DataSet` object. Defaults to `False`. When set to `True`, each `DataSet` object will include the following additional attributes:
- `unstart_count`: `int` Number of documents not yet started parsing.
- `running_count`: `int` Number of documents currently being parsed.
- `cancel_count`: `int` Number of documents whose parsing was cancelled.
- `done_count`: `int` Number of documents that have been successfully parsed.
- `fail_count`: `int` Number of documents whose parsing failed.
#### Returns
- Success: A list of `DataSet` objects.
@ -322,6 +333,13 @@ dataset = rag_object.list_datasets(id = "id_1")
print(dataset[0])
```
##### List datasets with parsing status
```python
for dataset in rag_object.list_datasets(include_parsing_status=True):
print(dataset.done_count, dataset.fail_count, dataset.running_count)
```
---
### Update dataset

View File

@ -41,6 +41,12 @@ curl --request GET \
--url http://127.0.0.1:9380/api/v1/datasets \
--header 'Authorization: Bearer ragflow-IzZmY1MGVhYTBhMjExZWZiYTdjMDI0Mm'
# List datasets with parsing status
echo -e "\n-- List datasets with parsing status"
curl --request GET \
--url 'http://127.0.0.1:9380/api/v1/datasets?include_parsing_status=true' \
--header 'Authorization: Bearer ragflow-IzZmY1MGVhYTBhMjExZWZiYTdjMDI0Mm'
# Delete datasets
echo -e "\n-- Delete datasets"
curl --request DELETE \

View File

@ -0,0 +1,326 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import sys
import types
import warnings
import pytest
# xgboost imports pkg_resources and emits a deprecation warning that is promoted
# to error in our pytest configuration; ignore it for this unit test module.
warnings.filterwarnings(
"ignore",
message="pkg_resources is deprecated as an API.*",
category=UserWarning,
)
def _install_cv2_stub_if_unavailable():
try:
import cv2 # noqa: F401
return
except Exception:
pass
stub = types.ModuleType("cv2")
stub.INTER_LINEAR = 1
stub.INTER_CUBIC = 2
stub.BORDER_CONSTANT = 0
stub.BORDER_REPLICATE = 1
stub.COLOR_BGR2RGB = 0
stub.COLOR_BGR2GRAY = 1
stub.COLOR_GRAY2BGR = 2
stub.IMREAD_IGNORE_ORIENTATION = 128
stub.IMREAD_COLOR = 1
stub.RETR_LIST = 1
stub.CHAIN_APPROX_SIMPLE = 2
def _missing(*_args, **_kwargs):
raise RuntimeError("cv2 runtime call is unavailable in this test environment")
def _module_getattr(name):
if name.isupper():
return 0
return _missing
stub.__getattr__ = _module_getattr
sys.modules["cv2"] = stub
_install_cv2_stub_if_unavailable()
from api.db.services.document_service import DocumentService # noqa: E402
from common.constants import TaskStatus # noqa: E402
# ---------------------------------------------------------------------------
# Helpers to access the original function bypassing @DB.connection_context()
# ---------------------------------------------------------------------------
def _unwrapped_get_parsing_status():
    """Return the raw ``get_parsing_status_by_kb_ids`` implementation.

    The method is stacked as ``@classmethod`` + ``@DB.connection_context()``:
    ``__func__`` peels off the classmethod binding, then ``__wrapped__``
    (exposed by the connection-context decorator) yields the original,
    un-decorated function so tests can call it without a database.
    """
    wrapper = DocumentService.get_parsing_status_by_kb_ids.__func__
    return wrapper.__wrapped__
# ---------------------------------------------------------------------------
# Fake ORM helpers mimic the minimal peewee query chain used by the function
# ---------------------------------------------------------------------------
class _FieldStub:
    """Minimal stand-in for a peewee model field used in select/where/group_by."""

    def in_(self, values):
        """Accept ``.where(cls.model.kb_id.in_(kb_ids))`` — a no-op for tests."""
        return self

    def alias(self, name):
        """Accept ``.alias(...)`` on aggregate expressions — a no-op for tests."""
        return self
class _FakeQuery:
    """Chains ``.where()``, ``.group_by()``, ``.dicts()`` without a real database."""

    def __init__(self, rows):
        # Canned result rows handed back verbatim by .dicts().
        self._rows = rows

    def _chain(self, *_args, **_kwargs):
        # Any intermediate query-builder call just returns the query itself.
        return self

    where = _chain
    group_by = _chain

    def dicts(self):
        return list(self._rows)
def _make_fake_model(rows):
    """Build a fake Document model class whose ``select()`` yields *rows*."""

    class _FakeModel:
        @classmethod
        def select(cls, *_args):
            return _FakeQuery(rows)

    # The real function only touches these field attributes; field stubs
    # absorb .in_() / .alias() calls made while building the query.
    for field_name in ("id", "kb_id", "run"):
        setattr(_FakeModel, field_name, _FieldStub())

    return _FakeModel
# ---------------------------------------------------------------------------
# Pytest fixture — patches DocumentService.model per test
# ---------------------------------------------------------------------------
@pytest.fixture()
def call_with_rows(monkeypatch):
    """Yield a helper that runs get_parsing_status_by_kb_ids against fake DB rows."""

    def _call(rows, kb_ids):
        # Swap the real peewee model for a fake returning the canned rows,
        # then invoke the un-decorated implementation (no DB connection).
        monkeypatch.setattr(DocumentService, "model", _make_fake_model(rows))
        target = _unwrapped_get_parsing_status()
        return target(DocumentService, kb_ids)

    return _call
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
# Every per-dataset status bucket the aggregation is expected to report.
_ALL_STATUS_FIELDS = frozenset(
    {"unstart_count", "running_count", "cancel_count", "done_count", "fail_count"}
)
@pytest.mark.p2
class TestGetParsingStatusByKbIds:
    """Unit tests for ``DocumentService.get_parsing_status_by_kb_ids``.

    Every test drives the un-decorated implementation through the
    ``call_with_rows`` fixture, which substitutes a fake peewee model, so no
    real database connection is involved.
    """

    # ------------------------------------------------------------------
    # Edge case: an empty kb_ids list must short-circuit before any DB call
    # ------------------------------------------------------------------
    def test_empty_kb_ids_returns_empty_dict(self, call_with_rows):
        result = call_with_rows([], [])
        assert result == {}

    # ------------------------------------------------------------------
    # A kb_id present in the input but with no matching documents
    # still gets a zero-filled entry in the result
    # ------------------------------------------------------------------
    def test_single_kb_id_no_documents(self, call_with_rows):
        result = call_with_rows(rows=[], kb_ids=["kb-1"])
        assert set(result.keys()) == {"kb-1"}
        assert set(result["kb-1"].keys()) == _ALL_STATUS_FIELDS
        assert all(v == 0 for v in result["kb-1"].values())

    # ------------------------------------------------------------------
    # A single kb_id with one document in each run-status bucket
    # ------------------------------------------------------------------
    def test_single_kb_id_all_five_statuses(self, call_with_rows):
        rows = [
            {"kb_id": "kb-1", "run": TaskStatus.UNSTART.value, "cnt": 3},
            {"kb_id": "kb-1", "run": TaskStatus.RUNNING.value, "cnt": 1},
            {"kb_id": "kb-1", "run": TaskStatus.CANCEL.value, "cnt": 2},
            {"kb_id": "kb-1", "run": TaskStatus.DONE.value, "cnt": 10},
            {"kb_id": "kb-1", "run": TaskStatus.FAIL.value, "cnt": 4},
        ]
        result = call_with_rows(rows=rows, kb_ids=["kb-1"])
        assert result["kb-1"]["unstart_count"] == 3
        assert result["kb-1"]["running_count"] == 1
        assert result["kb-1"]["cancel_count"] == 2
        assert result["kb-1"]["done_count"] == 10
        assert result["kb-1"]["fail_count"] == 4

    # ------------------------------------------------------------------
    # Two kb_ids — counts must be aggregated independently per dataset
    # ------------------------------------------------------------------
    def test_multiple_kb_ids_aggregated_separately(self, call_with_rows):
        rows = [
            {"kb_id": "kb-a", "run": TaskStatus.DONE.value, "cnt": 5},
            {"kb_id": "kb-a", "run": TaskStatus.FAIL.value, "cnt": 1},
            {"kb_id": "kb-b", "run": TaskStatus.UNSTART.value, "cnt": 7},
            {"kb_id": "kb-b", "run": TaskStatus.DONE.value, "cnt": 2},
        ]
        result = call_with_rows(rows=rows, kb_ids=["kb-a", "kb-b"])
        assert set(result.keys()) == {"kb-a", "kb-b"}
        assert result["kb-a"]["done_count"] == 5
        assert result["kb-a"]["fail_count"] == 1
        assert result["kb-a"]["unstart_count"] == 0
        assert result["kb-a"]["running_count"] == 0
        assert result["kb-a"]["cancel_count"] == 0
        assert result["kb-b"]["unstart_count"] == 7
        assert result["kb-b"]["done_count"] == 2
        assert result["kb-b"]["fail_count"] == 0

    # ------------------------------------------------------------------
    # An unrecognised run value must be silently ignored
    # ------------------------------------------------------------------
    def test_unknown_run_value_ignored(self, call_with_rows):
        rows = [
            {"kb_id": "kb-1", "run": "9", "cnt": 99},  # "9" is not a TaskStatus
            {"kb_id": "kb-1", "run": TaskStatus.DONE.value, "cnt": 4},
        ]
        result = call_with_rows(rows=rows, kb_ids=["kb-1"])
        assert result["kb-1"]["done_count"] == 4
        assert all(
            result["kb-1"][f] == 0
            for f in _ALL_STATUS_FIELDS - {"done_count"}
        )

    # ------------------------------------------------------------------
    # A row whose kb_id was NOT requested must not appear in the output
    # ------------------------------------------------------------------
    def test_row_with_unrequested_kb_id_is_filtered_out(self, call_with_rows):
        rows = [
            {"kb_id": "kb-requested", "run": TaskStatus.DONE.value, "cnt": 3},
            {"kb_id": "kb-unexpected", "run": TaskStatus.DONE.value, "cnt": 100},
        ]
        result = call_with_rows(rows=rows, kb_ids=["kb-requested"])
        assert "kb-unexpected" not in result
        assert result["kb-requested"]["done_count"] == 3

    # ------------------------------------------------------------------
    # cnt values must be treated as integers regardless of DB type hints
    # ------------------------------------------------------------------
    def test_cnt_is_cast_to_int(self, call_with_rows):
        rows = [
            {"kb_id": "kb-1", "run": TaskStatus.RUNNING.value, "cnt": "7"},
        ]
        result = call_with_rows(rows=rows, kb_ids=["kb-1"])
        assert result["kb-1"]["running_count"] == 7
        assert isinstance(result["kb-1"]["running_count"], int)

    # ------------------------------------------------------------------
    # run value stored as an integer in the DB (some adapters may omit the str cast)
    # ------------------------------------------------------------------
    def test_run_value_as_integer_is_handled(self, call_with_rows):
        rows = [
            {"kb_id": "kb-1", "run": int(TaskStatus.DONE.value), "cnt": 5},
        ]
        result = call_with_rows(rows=rows, kb_ids=["kb-1"])
        assert result["kb-1"]["done_count"] == 5

    # ------------------------------------------------------------------
    # All five status fields are initialised to 0 even when no rows exist
    # ------------------------------------------------------------------
    def test_all_five_fields_initialised_to_zero(self, call_with_rows):
        result = call_with_rows(rows=[], kb_ids=["kb-empty"])
        assert result["kb-empty"] == {
            "unstart_count": 0,
            "running_count": 0,
            "cancel_count": 0,
            "done_count": 0,
            "fail_count": 0,
        }

    # ------------------------------------------------------------------
    # All requested kb_ids must appear in the result, even when no
    # documents exist for some of them
    # ------------------------------------------------------------------
    def test_requested_kb_ids_all_present_in_result(self, call_with_rows):
        rows = [
            {"kb_id": "kb-with-data", "run": TaskStatus.DONE.value, "cnt": 1},
        ]
        result = call_with_rows(
            rows=rows, kb_ids=["kb-with-data", "kb-empty-1", "kb-empty-2"]
        )
        assert set(result.keys()) == {"kb-with-data", "kb-empty-1", "kb-empty-2"}
        assert result["kb-empty-1"] == {f: 0 for f in _ALL_STATUS_FIELDS}
        assert result["kb-empty-2"] == {f: 0 for f in _ALL_STATUS_FIELDS}

    # ------------------------------------------------------------------
    # SCHEDULE (run == "5") is not mapped to any bucket — it must be
    # silently ignored
    # ------------------------------------------------------------------
    def test_schedule_status_is_not_mapped(self, call_with_rows):
        rows = [
            {"kb_id": "kb-1", "run": TaskStatus.SCHEDULE.value, "cnt": 3},
            {"kb_id": "kb-1", "run": TaskStatus.DONE.value, "cnt": 2},
        ]
        result = call_with_rows(rows=rows, kb_ids=["kb-1"])
        assert result["kb-1"]["done_count"] == 2
        # SCHEDULE is not a tracked bucket
        assert "schedule_count" not in result["kb-1"]
        assert all(
            result["kb-1"][f] == 0
            for f in _ALL_STATUS_FIELDS - {"done_count"}
        )