Feat: Support get aggregated parsing status to dataset via the API (#13481)

### What problem does this PR solve?

Support getting aggregated parsing status to dataset via the API

Issue: #12810

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

Co-authored-by: heyang.why <heyang.why@alibaba-inc.com>
This commit is contained in:
Heyang Wang
2026-03-10 18:05:45 +08:00
committed by GitHub
parent 68a623154a
commit 08f83ff331
7 changed files with 654 additions and 309 deletions

View File

@ -28,8 +28,7 @@ from peewee import fn, Case, JOIN
from api.constants import IMG_BASE64_PREFIX, FILE_NAME_LEN_LIMIT
from api.db import PIPELINE_SPECIAL_PROGRESS_FREEZE_TASK_TYPES, FileType, UserTenantRole, CanvasCategory
from api.db.db_models import DB, Document, Knowledgebase, Task, Tenant, UserTenant, File2Document, File, UserCanvas, \
User
from api.db.db_models import DB, Document, Knowledgebase, Task, Tenant, UserTenant, File2Document, File, UserCanvas, User
from api.db.db_utils import bulk_insert_into_db
from api.db.services.common_service import CommonService
from api.db.services.knowledgebase_service import KnowledgebaseService
@ -78,24 +77,21 @@ class DocumentService(CommonService):
@classmethod
@DB.connection_context()
def get_list(cls, kb_id, page_number, items_per_page,
orderby, desc, keywords, id, name, suffix=None, run = None, doc_ids=None):
def get_list(cls, kb_id, page_number, items_per_page, orderby, desc, keywords, id, name, suffix=None, run=None, doc_ids=None):
fields = cls.get_cls_model_fields()
docs = cls.model.select(*[*fields, UserCanvas.title]).join(File2Document, on = (File2Document.document_id == cls.model.id))\
.join(File, on = (File.id == File2Document.file_id))\
.join(UserCanvas, on = ((cls.model.pipeline_id == UserCanvas.id) & (UserCanvas.canvas_category == CanvasCategory.DataFlow.value)), join_type=JOIN.LEFT_OUTER)\
docs = (
cls.model.select(*[*fields, UserCanvas.title])
.join(File2Document, on=(File2Document.document_id == cls.model.id))
.join(File, on=(File.id == File2Document.file_id))
.join(UserCanvas, on=((cls.model.pipeline_id == UserCanvas.id) & (UserCanvas.canvas_category == CanvasCategory.DataFlow.value)), join_type=JOIN.LEFT_OUTER)
.where(cls.model.kb_id == kb_id)
)
if id:
docs = docs.where(
cls.model.id == id)
docs = docs.where(cls.model.id == id)
if name:
docs = docs.where(
cls.model.name == name
)
docs = docs.where(cls.model.name == name)
if keywords:
docs = docs.where(
fn.LOWER(cls.model.name).contains(keywords.lower())
)
docs = docs.where(fn.LOWER(cls.model.name).contains(keywords.lower()))
if doc_ids:
docs = docs.where(cls.model.id.in_(doc_ids))
if suffix:
@ -120,6 +116,7 @@ class DocumentService(CommonService):
@DB.connection_context()
def check_doc_health(cls, tenant_id: str, filename):
import os
MAX_FILE_NUM_PER_USER = int(os.environ.get("MAX_FILE_NUM_PER_USER", 0))
if 0 < MAX_FILE_NUM_PER_USER <= DocumentService.get_doc_count(tenant_id):
raise RuntimeError("Exceed the maximum file number of a free user!")
@ -211,13 +208,14 @@ class DocumentService(CommonService):
"""
fields = cls.get_cls_model_fields()
if keywords:
query = cls.model.select(*fields).join(File2Document, on=(File2Document.document_id == cls.model.id)).join(File, on=(File.id == File2Document.file_id)).where(
(cls.model.kb_id == kb_id),
(fn.LOWER(cls.model.name).contains(keywords.lower()))
query = (
cls.model.select(*fields)
.join(File2Document, on=(File2Document.document_id == cls.model.id))
.join(File, on=(File.id == File2Document.file_id))
.where((cls.model.kb_id == kb_id), (fn.LOWER(cls.model.name).contains(keywords.lower())))
)
else:
query = cls.model.select(*fields).join(File2Document, on=(File2Document.document_id == cls.model.id)).join(File, on=(File.id == File2Document.file_id)).where(cls.model.kb_id == kb_id)
query = cls.model.select(*fields).join(File2Document, on=(File2Document.document_id == cls.model.id)).join(File, on=(File.id == File2Document.file_id)).where(cls.model.kb_id == kb_id)
if run_status:
query = query.where(cls.model.run.in_(run_status))
@ -272,14 +270,60 @@ class DocumentService(CommonService):
"metadata": metadata_counter,
}, total
@classmethod
@DB.connection_context()
def get_parsing_status_by_kb_ids(cls, kb_ids: list[str]) -> dict[str, dict[str, int]]:
"""Return aggregated document parsing status counts grouped by dataset (kb_id).
For each kb_id, counts documents in each run-status bucket:
- unstart_count (run == "0")
- running_count (run == "1")
- cancel_count (run == "2")
- done_count (run == "3")
- fail_count (run == "4")
Returns a dict keyed by kb_id, e.g.
{"kb-abc": {"unstart_count": 10, "running_count": 2, ...}, ...}
"""
if not kb_ids:
return {}
status_field_map = {
TaskStatus.UNSTART.value: "unstart_count",
TaskStatus.RUNNING.value: "running_count",
TaskStatus.CANCEL.value: "cancel_count",
TaskStatus.DONE.value: "done_count",
TaskStatus.FAIL.value: "fail_count",
}
empty_status = {v: 0 for v in status_field_map.values()}
result: dict[str, dict[str, int]] = {kb_id: dict(empty_status) for kb_id in kb_ids}
rows = (
cls.model.select(
cls.model.kb_id,
cls.model.run,
fn.COUNT(cls.model.id).alias("cnt"),
)
.where(cls.model.kb_id.in_(kb_ids))
.group_by(cls.model.kb_id, cls.model.run)
.dicts()
)
for row in rows:
kb_id = row["kb_id"]
run_val = str(row["run"])
field_name = status_field_map.get(run_val)
if field_name and kb_id in result:
result[kb_id][field_name] = int(row["cnt"])
return result
@classmethod
@DB.connection_context()
def count_by_kb_id(cls, kb_id, keywords, run_status, types):
if keywords:
docs = cls.model.select().where(
(cls.model.kb_id == kb_id),
(fn.LOWER(cls.model.name).contains(keywords.lower()))
)
docs = cls.model.select().where((cls.model.kb_id == kb_id), (fn.LOWER(cls.model.name).contains(keywords.lower())))
else:
docs = cls.model.select().where(cls.model.kb_id == kb_id)
@ -295,9 +339,7 @@ class DocumentService(CommonService):
@classmethod
@DB.connection_context()
def get_total_size_by_kb_id(cls, kb_id, keywords="", run_status=[], types=[]):
query = cls.model.select(fn.COALESCE(fn.SUM(cls.model.size), 0)).where(
cls.model.kb_id == kb_id
)
query = cls.model.select(fn.COALESCE(fn.SUM(cls.model.size), 0)).where(cls.model.kb_id == kb_id)
if keywords:
query = query.where(fn.LOWER(cls.model.name).contains(keywords.lower()))
@ -329,12 +371,8 @@ class DocumentService(CommonService):
@classmethod
@DB.connection_context()
def get_all_docs_by_creator_id(cls, creator_id):
fields = [
cls.model.id, cls.model.kb_id, cls.model.token_num, cls.model.chunk_num, Knowledgebase.tenant_id
]
docs = cls.model.select(*fields).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(
cls.model.created_by == creator_id
)
fields = [cls.model.id, cls.model.kb_id, cls.model.token_num, cls.model.chunk_num, Knowledgebase.tenant_id]
docs = cls.model.select(*fields).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.created_by == creator_id)
docs.order_by(cls.model.create_time.asc())
# maybe cause slow query by deep paginate, optimize later
offset, limit = 0, 100
@ -361,6 +399,7 @@ class DocumentService(CommonService):
@DB.connection_context()
def remove_document(cls, doc, tenant_id):
from api.db.services.task_service import TaskService, cancel_all_task_of
if not cls.delete_document_and_update_kb_counts(doc.id):
return True
@ -406,17 +445,22 @@ class DocumentService(CommonService):
# Cleanup knowledge graph references (non-critical, log and continue)
try:
graph_source = settings.docStoreConn.get_fields(
settings.docStoreConn.search(["source_id"], [], {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [doc.kb_id]), ["source_id"]
settings.docStoreConn.search(["source_id"], [], {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [doc.kb_id]),
["source_id"],
)
if len(graph_source) > 0 and doc.id in list(graph_source.values())[0]["source_id"]:
settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "source_id": doc.id},
{"remove": {"source_id": doc.id}},
search.index_name(tenant_id), doc.kb_id)
settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]},
{"removed_kwd": "Y"},
search.index_name(tenant_id), doc.kb_id)
settings.docStoreConn.delete({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "must_not": {"exists": "source_id"}},
search.index_name(tenant_id), doc.kb_id)
settings.docStoreConn.update(
{"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "source_id": doc.id},
{"remove": {"source_id": doc.id}},
search.index_name(tenant_id),
doc.kb_id,
)
settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, {"removed_kwd": "Y"}, search.index_name(tenant_id), doc.kb_id)
settings.docStoreConn.delete(
{"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "must_not": {"exists": "source_id"}},
search.index_name(tenant_id),
doc.kb_id,
)
except Exception as e:
logging.warning(f"Failed to cleanup knowledge graph for document {doc.id}: {e}")
@ -428,9 +472,7 @@ class DocumentService(CommonService):
page = 0
page_size = 1000
while True:
chunks = settings.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(),
page * page_size, page_size, search.index_name(tenant_id),
[doc.kb_id])
chunks = settings.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(), page * page_size, page_size, search.index_name(tenant_id), [doc.kb_id])
chunk_ids = settings.docStoreConn.get_doc_ids(chunks)
if not chunk_ids:
break
@ -455,71 +497,61 @@ class DocumentService(CommonService):
Tenant.embd_id,
Tenant.img2txt_id,
Tenant.asr_id,
cls.model.update_time]
docs = cls.model.select(*fields) \
.join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id)) \
cls.model.update_time,
]
docs = (
cls.model.select(*fields)
.join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id))
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
.where(
cls.model.status == StatusEnum.VALID.value,
~(cls.model.type == FileType.VIRTUAL.value),
cls.model.progress == 0,
cls.model.update_time >= current_timestamp() - 1000 * 600,
cls.model.run == TaskStatus.RUNNING.value) \
cls.model.status == StatusEnum.VALID.value,
~(cls.model.type == FileType.VIRTUAL.value),
cls.model.progress == 0,
cls.model.update_time >= current_timestamp() - 1000 * 600,
cls.model.run == TaskStatus.RUNNING.value,
)
.order_by(cls.model.update_time.asc())
)
return list(docs.dicts())
@classmethod
@DB.connection_context()
def get_unfinished_docs(cls):
fields = [cls.model.id, cls.model.process_begin_at, cls.model.parser_config, cls.model.progress_msg,
cls.model.run, cls.model.parser_id]
unfinished_task_query = Task.select(Task.doc_id).where(
(Task.progress >= 0) & (Task.progress < 1)
)
fields = [cls.model.id, cls.model.process_begin_at, cls.model.parser_config, cls.model.progress_msg, cls.model.run, cls.model.parser_id]
unfinished_task_query = Task.select(Task.doc_id).where((Task.progress >= 0) & (Task.progress < 1))
docs = cls.model.select(*fields) \
.where(
docs = cls.model.select(*fields).where(
cls.model.status == StatusEnum.VALID.value,
~(cls.model.type == FileType.VIRTUAL.value),
((cls.model.run.is_null(True)) | (cls.model.run != TaskStatus.CANCEL.value)),
(((cls.model.progress < 1) & (cls.model.progress > 0)) |
(cls.model.id.in_(unfinished_task_query)))) # including unfinished tasks like GraphRAG, RAPTOR and Mindmap
(((cls.model.progress < 1) & (cls.model.progress > 0)) | (cls.model.id.in_(unfinished_task_query))),
) # including unfinished tasks like GraphRAG, RAPTOR and Mindmap
return list(docs.dicts())
@classmethod
@DB.connection_context()
def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duration):
num = cls.model.update(token_num=cls.model.token_num + token_num,
chunk_num=cls.model.chunk_num + chunk_num,
process_duration=cls.model.process_duration + duration).where(
cls.model.id == doc_id).execute()
num = (
cls.model.update(token_num=cls.model.token_num + token_num, chunk_num=cls.model.chunk_num + chunk_num, process_duration=cls.model.process_duration + duration)
.where(cls.model.id == doc_id)
.execute()
)
if num == 0:
logging.warning("Document not found which is supposed to be there")
num = Knowledgebase.update(
token_num=Knowledgebase.token_num +
token_num,
chunk_num=Knowledgebase.chunk_num +
chunk_num).where(
Knowledgebase.id == kb_id).execute()
num = Knowledgebase.update(token_num=Knowledgebase.token_num + token_num, chunk_num=Knowledgebase.chunk_num + chunk_num).where(Knowledgebase.id == kb_id).execute()
return num
@classmethod
@DB.connection_context()
def decrement_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duration):
num = cls.model.update(token_num=cls.model.token_num - token_num,
chunk_num=cls.model.chunk_num - chunk_num,
process_duration=cls.model.process_duration + duration).where(
cls.model.id == doc_id).execute()
num = (
cls.model.update(token_num=cls.model.token_num - token_num, chunk_num=cls.model.chunk_num - chunk_num, process_duration=cls.model.process_duration + duration)
.where(cls.model.id == doc_id)
.execute()
)
if num == 0:
raise LookupError(
"Document not found which is supposed to be there")
num = Knowledgebase.update(
token_num=Knowledgebase.token_num -
token_num,
chunk_num=Knowledgebase.chunk_num -
chunk_num
).where(
Knowledgebase.id == kb_id).execute()
raise LookupError("Document not found which is supposed to be there")
num = Knowledgebase.update(token_num=Knowledgebase.token_num - token_num, chunk_num=Knowledgebase.chunk_num - chunk_num).where(Knowledgebase.id == kb_id).execute()
return num
@classmethod
@ -551,17 +583,13 @@ class DocumentService(CommonService):
doc = cls.model.get_by_id(doc_id)
assert doc, "Can't fine document in database."
num = Knowledgebase.update(
token_num=Knowledgebase.token_num -
doc.token_num,
chunk_num=Knowledgebase.chunk_num -
doc.chunk_num,
doc_num=Knowledgebase.doc_num - 1
).where(
Knowledgebase.id == doc.kb_id).execute()
num = (
Knowledgebase.update(token_num=Knowledgebase.token_num - doc.token_num, chunk_num=Knowledgebase.chunk_num - doc.chunk_num, doc_num=Knowledgebase.doc_num - 1)
.where(Knowledgebase.id == doc.kb_id)
.execute()
)
return num
@classmethod
@DB.connection_context()
def clear_chunk_num_when_rerun(cls, doc_id):
@ -578,15 +606,10 @@ class DocumentService(CommonService):
)
return num
@classmethod
@DB.connection_context()
def get_tenant_id(cls, doc_id):
docs = cls.model.select(
Knowledgebase.tenant_id).join(
Knowledgebase, on=(
Knowledgebase.id == cls.model.kb_id)).where(
cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
docs = cls.model.select(Knowledgebase.tenant_id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
docs = docs.dicts()
if not docs:
return None
@ -604,11 +627,7 @@ class DocumentService(CommonService):
@classmethod
@DB.connection_context()
def get_tenant_id_by_name(cls, name):
docs = cls.model.select(
Knowledgebase.tenant_id).join(
Knowledgebase, on=(
Knowledgebase.id == cls.model.kb_id)).where(
cls.model.name == name, Knowledgebase.status == StatusEnum.VALID.value)
docs = cls.model.select(Knowledgebase.tenant_id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.name == name, Knowledgebase.status == StatusEnum.VALID.value)
docs = docs.dicts()
if not docs:
return None
@ -617,12 +636,13 @@ class DocumentService(CommonService):
@classmethod
@DB.connection_context()
def accessible(cls, doc_id, user_id):
docs = cls.model.select(
cls.model.id).join(
Knowledgebase, on=(
Knowledgebase.id == cls.model.kb_id)
).join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
).where(cls.model.id == doc_id, UserTenant.user_id == user_id).paginate(0, 1)
docs = (
cls.model.select(cls.model.id)
.join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id))
.join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id))
.where(cls.model.id == doc_id, UserTenant.user_id == user_id)
.paginate(0, 1)
)
docs = docs.dicts()
if not docs:
return False
@ -631,18 +651,13 @@ class DocumentService(CommonService):
@classmethod
@DB.connection_context()
def accessible4deletion(cls, doc_id, user_id):
docs = cls.model.select(cls.model.id
).join(
Knowledgebase, on=(
Knowledgebase.id == cls.model.kb_id)
).join(
UserTenant, on=(
(UserTenant.tenant_id == Knowledgebase.created_by) & (UserTenant.user_id == user_id))
).where(
cls.model.id == doc_id,
UserTenant.status == StatusEnum.VALID.value,
((UserTenant.role == UserTenantRole.NORMAL) | (UserTenant.role == UserTenantRole.OWNER))
).paginate(0, 1)
docs = (
cls.model.select(cls.model.id)
.join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id))
.join(UserTenant, on=((UserTenant.tenant_id == Knowledgebase.created_by) & (UserTenant.user_id == user_id)))
.where(cls.model.id == doc_id, UserTenant.status == StatusEnum.VALID.value, ((UserTenant.role == UserTenantRole.NORMAL) | (UserTenant.role == UserTenantRole.OWNER)))
.paginate(0, 1)
)
docs = docs.dicts()
if not docs:
return False
@ -651,11 +666,7 @@ class DocumentService(CommonService):
@classmethod
@DB.connection_context()
def get_embd_id(cls, doc_id):
docs = cls.model.select(
Knowledgebase.embd_id).join(
Knowledgebase, on=(
Knowledgebase.id == cls.model.kb_id)).where(
cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
docs = cls.model.select(Knowledgebase.embd_id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
docs = docs.dicts()
if not docs:
return None
@ -664,11 +675,9 @@ class DocumentService(CommonService):
@classmethod
@DB.connection_context()
def get_tenant_embd_id(cls, doc_id):
docs = cls.model.select(
Knowledgebase.tenant_embd_id).join(
Knowledgebase, on=(
Knowledgebase.id == cls.model.kb_id)).where(
cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
docs = (
cls.model.select(Knowledgebase.tenant_embd_id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
)
docs = docs.dicts()
if not docs:
return None
@ -705,8 +714,7 @@ class DocumentService(CommonService):
@DB.connection_context()
def get_doc_id_by_doc_name(cls, doc_name):
fields = [cls.model.id]
doc_id = cls.model.select(*fields) \
.where(cls.model.name == doc_name)
doc_id = cls.model.select(*fields).where(cls.model.name == doc_name)
doc_id = doc_id.dicts()
if not doc_id:
return None
@ -725,8 +733,7 @@ class DocumentService(CommonService):
@DB.connection_context()
def get_thumbnails(cls, docids):
fields = [cls.model.id, cls.model.kb_id, cls.model.thumbnail]
return list(cls.model.select(
*fields).where(cls.model.id.in_(docids)).dicts())
return list(cls.model.select(*fields).where(cls.model.id.in_(docids)).dicts())
@classmethod
@DB.connection_context()
@ -755,9 +762,7 @@ class DocumentService(CommonService):
@classmethod
@DB.connection_context()
def get_doc_count(cls, tenant_id):
docs = cls.model.select(cls.model.id).join(Knowledgebase,
on=(Knowledgebase.id == cls.model.kb_id)).where(
Knowledgebase.tenant_id == tenant_id)
docs = cls.model.select(cls.model.id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(Knowledgebase.tenant_id == tenant_id)
return len(docs)
@classmethod
@ -768,7 +773,7 @@ class DocumentService(CommonService):
"process_begin_at": get_format_time(),
}
if not keep_progress:
info["progress"] = random.random() * 1 / 100.
info["progress"] = random.random() * 1 / 100.0
info["run"] = TaskStatus.RUNNING.value
# keep the doc in DONE state when keep_progress=True for GraphRAG, RAPTOR and Mindmap tasks
@ -781,19 +786,17 @@ class DocumentService(CommonService):
cls._sync_progress(docs)
@classmethod
@DB.connection_context()
def update_progress_immediately(cls, docs:list[dict]):
def update_progress_immediately(cls, docs: list[dict]):
if not docs:
return
cls._sync_progress(docs)
@classmethod
@DB.connection_context()
def _sync_progress(cls, docs:list[dict]):
def _sync_progress(cls, docs: list[dict]):
from api.db.services.task_service import TaskService
for d in docs:
@ -841,27 +844,18 @@ class DocumentService(CommonService):
# fallback
cls.update_by_id(d["id"], {"process_begin_at": begin_at})
info = {
"process_duration": max(datetime.timestamp(datetime.now()) - begin_at.timestamp(), 0),
"run": status}
info = {"process_duration": max(datetime.timestamp(datetime.now()) - begin_at.timestamp(), 0), "run": status}
if prg != 0 and not freeze_progress:
info["progress"] = prg
if msg:
info["progress_msg"] = msg
if msg.endswith("created task graphrag") or msg.endswith("created task raptor") or msg.endswith("created task mindmap"):
info["progress_msg"] += "\n%d tasks are ahead in the queue..."%get_queue_length(priority)
info["progress_msg"] += "\n%d tasks are ahead in the queue..." % get_queue_length(priority)
else:
info["progress_msg"] = "%d tasks are ahead in the queue..."%get_queue_length(priority)
info["progress_msg"] = "%d tasks are ahead in the queue..." % get_queue_length(priority)
info["update_time"] = current_timestamp()
info["update_date"] = get_format_time()
(
cls.model.update(info)
.where(
(cls.model.id == d["id"])
& ((cls.model.run.is_null(True)) | (cls.model.run != TaskStatus.CANCEL.value))
)
.execute()
)
(cls.model.update(info).where((cls.model.id == d["id"]) & ((cls.model.run.is_null(True)) | (cls.model.run != TaskStatus.CANCEL.value))).execute())
except Exception as e:
if str(e).find("'0'") < 0:
logging.exception("fetch task exception")
@ -875,7 +869,7 @@ class DocumentService(CommonService):
@DB.connection_context()
def get_all_kb_doc_count(cls):
result = {}
rows = cls.model.select(cls.model.kb_id, fn.COUNT(cls.model.id).alias('count')).group_by(cls.model.kb_id)
rows = cls.model.select(cls.model.kb_id, fn.COUNT(cls.model.id).alias("count")).group_by(cls.model.kb_id)
for row in rows:
result[row.kb_id] = row.count
return result
@ -890,33 +884,19 @@ class DocumentService(CommonService):
pass
return False
@classmethod
@DB.connection_context()
def knowledgebase_basic_info(cls, kb_id: str) -> dict[str, int]:
# cancelled: run == "2"
cancelled = (
cls.model.select(fn.COUNT(1))
.where((cls.model.kb_id == kb_id) & (cls.model.run == TaskStatus.CANCEL))
.scalar()
)
downloaded = (
cls.model.select(fn.COUNT(1))
.where(
cls.model.kb_id == kb_id,
cls.model.source_type != "local"
)
.scalar()
)
cancelled = cls.model.select(fn.COUNT(1)).where((cls.model.kb_id == kb_id) & (cls.model.run == TaskStatus.CANCEL)).scalar()
downloaded = cls.model.select(fn.COUNT(1)).where(cls.model.kb_id == kb_id, cls.model.source_type != "local").scalar()
row = (
cls.model.select(
# finished: progress == 1
fn.COALESCE(fn.SUM(Case(None, [(cls.model.progress == 1, 1)], 0)), 0).alias("finished"),
# failed: progress == -1
fn.COALESCE(fn.SUM(Case(None, [(cls.model.progress == -1, 1)], 0)), 0).alias("failed"),
# processing: 0 <= progress < 1
fn.COALESCE(
fn.SUM(
@ -931,24 +911,15 @@ class DocumentService(CommonService):
0,
).alias("processing"),
)
.where(
(cls.model.kb_id == kb_id)
& ((cls.model.run.is_null(True)) | (cls.model.run != TaskStatus.CANCEL))
)
.where((cls.model.kb_id == kb_id) & ((cls.model.run.is_null(True)) | (cls.model.run != TaskStatus.CANCEL)))
.dicts()
.get()
)
return {
"processing": int(row["processing"]),
"finished": int(row["finished"]),
"failed": int(row["failed"]),
"cancelled": int(cancelled),
"downloaded": int(downloaded)
}
return {"processing": int(row["processing"]), "finished": int(row["finished"]), "failed": int(row["failed"]), "cancelled": int(cancelled), "downloaded": int(downloaded)}
@classmethod
def run(cls, tenant_id:str, doc:dict, kb_table_num_map:dict):
def run(cls, tenant_id: str, doc: dict, kb_table_num_map: dict):
from api.db.services.task_service import queue_dataflow, queue_tasks
from api.db.services.file2document_service import File2DocumentService
@ -990,7 +961,7 @@ def queue_raptor_o_graphrag_tasks(sample_doc_id, ty, priority, fake_doc_id="", d
"from_page": 100000000,
"to_page": 100000000,
"task_type": ty,
"progress_msg": datetime.now().strftime("%H:%M:%S") + " created task " + ty,
"progress_msg": datetime.now().strftime("%H:%M:%S") + " created task " + ty,
"begin_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
}
@ -1032,8 +1003,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
e, dia = DialogService.get_by_id(conv.dialog_id)
if not dia.kb_ids:
raise LookupError("No dataset associated with this conversation. "
"Please add a dataset before uploading documents")
raise LookupError("No dataset associated with this conversation. Please add a dataset before uploading documents")
kb_id = dia.kb_ids[0]
e, kb = KnowledgebaseService.get_by_id(kb_id)
if not e:
@ -1050,12 +1020,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
def dummy(prog=None, msg=""):
pass
FACTORY = {
ParserType.PRESENTATION.value: presentation,
ParserType.PICTURE.value: picture,
ParserType.AUDIO.value: audio,
ParserType.EMAIL.value: email
}
FACTORY = {ParserType.PRESENTATION.value: presentation, ParserType.PICTURE.value: picture, ParserType.AUDIO.value: audio, ParserType.EMAIL.value: email}
parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?;。;!?", "layout_recognize": "Plain Text", "table_context_size": 0, "image_context_size": 0}
exe = ThreadPoolExecutor(max_workers=12)
threads = []
@ -1063,22 +1028,12 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
for d, blob in files:
doc_nm[d["id"]] = d["name"]
for d, blob in files:
kwargs = {
"callback": dummy,
"parser_config": parser_config,
"from_page": 0,
"to_page": 100000,
"tenant_id": kb.tenant_id,
"lang": kb.language
}
kwargs = {"callback": dummy, "parser_config": parser_config, "from_page": 0, "to_page": 100000, "tenant_id": kb.tenant_id, "lang": kb.language}
threads.append(exe.submit(FACTORY.get(d["parser_id"], naive).chunk, d["name"], blob, **kwargs))
for (docinfo, _), th in zip(files, threads):
docs = []
doc = {
"doc_id": docinfo["id"],
"kb_id": [kb.id]
}
doc = {"doc_id": docinfo["id"], "kb_id": [kb.id]}
for ck in th.result():
d = deepcopy(doc)
d.update(ck)
@ -1093,7 +1048,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
if isinstance(d["image"], bytes):
output_buffer = BytesIO(d["image"])
else:
d["image"].save(output_buffer, format='JPEG')
d["image"].save(output_buffer, format="JPEG")
settings.STORAGE_IMPL.put(kb.id, d["id"], output_buffer.getvalue())
d["img_id"] = "{}-{}".format(kb.id, d["id"])
@ -1110,9 +1065,9 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
nonlocal embd_mdl, chunk_counts, token_counts
vectors = []
for i in range(0, len(cnts), batch_size):
vts, c = embd_mdl.encode(cnts[i: i + batch_size])
vts, c = embd_mdl.encode(cnts[i : i + batch_size])
vectors.extend(vts.tolist())
chunk_counts[doc_id] += len(cnts[i:i + batch_size])
chunk_counts[doc_id] += len(cnts[i : i + batch_size])
token_counts[doc_id] += c
return vectors
@ -1127,22 +1082,25 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
if parser_ids[doc_id] != ParserType.PICTURE.value:
from rag.graphrag.general.mind_map_extractor import MindMapExtractor
mindmap = MindMapExtractor(llm_bdl)
try:
mind_map = asyncio.run(mindmap([c["content_with_weight"] for c in docs if c["doc_id"] == doc_id]))
mind_map = json.dumps(mind_map.output, ensure_ascii=False, indent=2)
if len(mind_map) < 32:
raise Exception("Few content: " + mind_map)
cks.append({
"id": get_uuid(),
"doc_id": doc_id,
"kb_id": [kb.id],
"docnm_kwd": doc_nm[doc_id],
"title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", doc_nm[doc_id])),
"content_ltks": rag_tokenizer.tokenize("summary summarize 总结 概况 file 文件 概括"),
"content_with_weight": mind_map,
"knowledge_graph_kwd": "mind_map"
})
cks.append(
{
"id": get_uuid(),
"doc_id": doc_id,
"kb_id": [kb.id],
"docnm_kwd": doc_nm[doc_id],
"title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", doc_nm[doc_id])),
"content_ltks": rag_tokenizer.tokenize("summary summarize 总结 概况 file 文件 概括"),
"content_with_weight": mind_map,
"knowledge_graph_kwd": "mind_map",
}
)
except Exception:
logging.exception("Mind map generation error")
@ -1156,9 +1114,8 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
if not settings.docStoreConn.index_exist(idxnm, kb_id):
settings.docStoreConn.create_idx(idxnm, kb_id, len(vectors[0]), kb.parser_id)
try_create_idx = False
settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)
settings.docStoreConn.insert(cks[b : b + es_bulk_size], idxnm, kb_id)
DocumentService.increment_chunk_num(
doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)
DocumentService.increment_chunk_num(doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)
return [d["id"] for d, _ in files]