mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-05-03 00:37:48 +08:00
Feat: Support get aggregated parsing status to dataset via the API (#13481)
### What problem does this PR solve? Support getting aggregated parsing status to dataset via the API Issue: #12810 ### Type of change - [x] New Feature (non-breaking change which adds functionality) Co-authored-by: heyang.why <heyang.why@alibaba-inc.com>
This commit is contained in:
@ -28,8 +28,7 @@ from peewee import fn, Case, JOIN
|
||||
|
||||
from api.constants import IMG_BASE64_PREFIX, FILE_NAME_LEN_LIMIT
|
||||
from api.db import PIPELINE_SPECIAL_PROGRESS_FREEZE_TASK_TYPES, FileType, UserTenantRole, CanvasCategory
|
||||
from api.db.db_models import DB, Document, Knowledgebase, Task, Tenant, UserTenant, File2Document, File, UserCanvas, \
|
||||
User
|
||||
from api.db.db_models import DB, Document, Knowledgebase, Task, Tenant, UserTenant, File2Document, File, UserCanvas, User
|
||||
from api.db.db_utils import bulk_insert_into_db
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
@ -78,24 +77,21 @@ class DocumentService(CommonService):
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_list(cls, kb_id, page_number, items_per_page,
|
||||
orderby, desc, keywords, id, name, suffix=None, run = None, doc_ids=None):
|
||||
def get_list(cls, kb_id, page_number, items_per_page, orderby, desc, keywords, id, name, suffix=None, run=None, doc_ids=None):
|
||||
fields = cls.get_cls_model_fields()
|
||||
docs = cls.model.select(*[*fields, UserCanvas.title]).join(File2Document, on = (File2Document.document_id == cls.model.id))\
|
||||
.join(File, on = (File.id == File2Document.file_id))\
|
||||
.join(UserCanvas, on = ((cls.model.pipeline_id == UserCanvas.id) & (UserCanvas.canvas_category == CanvasCategory.DataFlow.value)), join_type=JOIN.LEFT_OUTER)\
|
||||
docs = (
|
||||
cls.model.select(*[*fields, UserCanvas.title])
|
||||
.join(File2Document, on=(File2Document.document_id == cls.model.id))
|
||||
.join(File, on=(File.id == File2Document.file_id))
|
||||
.join(UserCanvas, on=((cls.model.pipeline_id == UserCanvas.id) & (UserCanvas.canvas_category == CanvasCategory.DataFlow.value)), join_type=JOIN.LEFT_OUTER)
|
||||
.where(cls.model.kb_id == kb_id)
|
||||
)
|
||||
if id:
|
||||
docs = docs.where(
|
||||
cls.model.id == id)
|
||||
docs = docs.where(cls.model.id == id)
|
||||
if name:
|
||||
docs = docs.where(
|
||||
cls.model.name == name
|
||||
)
|
||||
docs = docs.where(cls.model.name == name)
|
||||
if keywords:
|
||||
docs = docs.where(
|
||||
fn.LOWER(cls.model.name).contains(keywords.lower())
|
||||
)
|
||||
docs = docs.where(fn.LOWER(cls.model.name).contains(keywords.lower()))
|
||||
if doc_ids:
|
||||
docs = docs.where(cls.model.id.in_(doc_ids))
|
||||
if suffix:
|
||||
@ -120,6 +116,7 @@ class DocumentService(CommonService):
|
||||
@DB.connection_context()
|
||||
def check_doc_health(cls, tenant_id: str, filename):
|
||||
import os
|
||||
|
||||
MAX_FILE_NUM_PER_USER = int(os.environ.get("MAX_FILE_NUM_PER_USER", 0))
|
||||
if 0 < MAX_FILE_NUM_PER_USER <= DocumentService.get_doc_count(tenant_id):
|
||||
raise RuntimeError("Exceed the maximum file number of a free user!")
|
||||
@ -211,13 +208,14 @@ class DocumentService(CommonService):
|
||||
"""
|
||||
fields = cls.get_cls_model_fields()
|
||||
if keywords:
|
||||
query = cls.model.select(*fields).join(File2Document, on=(File2Document.document_id == cls.model.id)).join(File, on=(File.id == File2Document.file_id)).where(
|
||||
(cls.model.kb_id == kb_id),
|
||||
(fn.LOWER(cls.model.name).contains(keywords.lower()))
|
||||
query = (
|
||||
cls.model.select(*fields)
|
||||
.join(File2Document, on=(File2Document.document_id == cls.model.id))
|
||||
.join(File, on=(File.id == File2Document.file_id))
|
||||
.where((cls.model.kb_id == kb_id), (fn.LOWER(cls.model.name).contains(keywords.lower())))
|
||||
)
|
||||
else:
|
||||
query = cls.model.select(*fields).join(File2Document, on=(File2Document.document_id == cls.model.id)).join(File, on=(File.id == File2Document.file_id)).where(cls.model.kb_id == kb_id)
|
||||
|
||||
query = cls.model.select(*fields).join(File2Document, on=(File2Document.document_id == cls.model.id)).join(File, on=(File.id == File2Document.file_id)).where(cls.model.kb_id == kb_id)
|
||||
|
||||
if run_status:
|
||||
query = query.where(cls.model.run.in_(run_status))
|
||||
@ -272,14 +270,60 @@ class DocumentService(CommonService):
|
||||
"metadata": metadata_counter,
|
||||
}, total
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_parsing_status_by_kb_ids(cls, kb_ids: list[str]) -> dict[str, dict[str, int]]:
|
||||
"""Return aggregated document parsing status counts grouped by dataset (kb_id).
|
||||
|
||||
For each kb_id, counts documents in each run-status bucket:
|
||||
- unstart_count (run == "0")
|
||||
- running_count (run == "1")
|
||||
- cancel_count (run == "2")
|
||||
- done_count (run == "3")
|
||||
- fail_count (run == "4")
|
||||
|
||||
Returns a dict keyed by kb_id, e.g.
|
||||
{"kb-abc": {"unstart_count": 10, "running_count": 2, ...}, ...}
|
||||
"""
|
||||
if not kb_ids:
|
||||
return {}
|
||||
|
||||
status_field_map = {
|
||||
TaskStatus.UNSTART.value: "unstart_count",
|
||||
TaskStatus.RUNNING.value: "running_count",
|
||||
TaskStatus.CANCEL.value: "cancel_count",
|
||||
TaskStatus.DONE.value: "done_count",
|
||||
TaskStatus.FAIL.value: "fail_count",
|
||||
}
|
||||
|
||||
empty_status = {v: 0 for v in status_field_map.values()}
|
||||
result: dict[str, dict[str, int]] = {kb_id: dict(empty_status) for kb_id in kb_ids}
|
||||
|
||||
rows = (
|
||||
cls.model.select(
|
||||
cls.model.kb_id,
|
||||
cls.model.run,
|
||||
fn.COUNT(cls.model.id).alias("cnt"),
|
||||
)
|
||||
.where(cls.model.kb_id.in_(kb_ids))
|
||||
.group_by(cls.model.kb_id, cls.model.run)
|
||||
.dicts()
|
||||
)
|
||||
|
||||
for row in rows:
|
||||
kb_id = row["kb_id"]
|
||||
run_val = str(row["run"])
|
||||
field_name = status_field_map.get(run_val)
|
||||
if field_name and kb_id in result:
|
||||
result[kb_id][field_name] = int(row["cnt"])
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def count_by_kb_id(cls, kb_id, keywords, run_status, types):
|
||||
if keywords:
|
||||
docs = cls.model.select().where(
|
||||
(cls.model.kb_id == kb_id),
|
||||
(fn.LOWER(cls.model.name).contains(keywords.lower()))
|
||||
)
|
||||
docs = cls.model.select().where((cls.model.kb_id == kb_id), (fn.LOWER(cls.model.name).contains(keywords.lower())))
|
||||
else:
|
||||
docs = cls.model.select().where(cls.model.kb_id == kb_id)
|
||||
|
||||
@ -295,9 +339,7 @@ class DocumentService(CommonService):
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_total_size_by_kb_id(cls, kb_id, keywords="", run_status=[], types=[]):
|
||||
query = cls.model.select(fn.COALESCE(fn.SUM(cls.model.size), 0)).where(
|
||||
cls.model.kb_id == kb_id
|
||||
)
|
||||
query = cls.model.select(fn.COALESCE(fn.SUM(cls.model.size), 0)).where(cls.model.kb_id == kb_id)
|
||||
|
||||
if keywords:
|
||||
query = query.where(fn.LOWER(cls.model.name).contains(keywords.lower()))
|
||||
@ -329,12 +371,8 @@ class DocumentService(CommonService):
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_all_docs_by_creator_id(cls, creator_id):
|
||||
fields = [
|
||||
cls.model.id, cls.model.kb_id, cls.model.token_num, cls.model.chunk_num, Knowledgebase.tenant_id
|
||||
]
|
||||
docs = cls.model.select(*fields).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(
|
||||
cls.model.created_by == creator_id
|
||||
)
|
||||
fields = [cls.model.id, cls.model.kb_id, cls.model.token_num, cls.model.chunk_num, Knowledgebase.tenant_id]
|
||||
docs = cls.model.select(*fields).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.created_by == creator_id)
|
||||
docs.order_by(cls.model.create_time.asc())
|
||||
# maybe cause slow query by deep paginate, optimize later
|
||||
offset, limit = 0, 100
|
||||
@ -361,6 +399,7 @@ class DocumentService(CommonService):
|
||||
@DB.connection_context()
|
||||
def remove_document(cls, doc, tenant_id):
|
||||
from api.db.services.task_service import TaskService, cancel_all_task_of
|
||||
|
||||
if not cls.delete_document_and_update_kb_counts(doc.id):
|
||||
return True
|
||||
|
||||
@ -406,17 +445,22 @@ class DocumentService(CommonService):
|
||||
# Cleanup knowledge graph references (non-critical, log and continue)
|
||||
try:
|
||||
graph_source = settings.docStoreConn.get_fields(
|
||||
settings.docStoreConn.search(["source_id"], [], {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [doc.kb_id]), ["source_id"]
|
||||
settings.docStoreConn.search(["source_id"], [], {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [doc.kb_id]),
|
||||
["source_id"],
|
||||
)
|
||||
if len(graph_source) > 0 and doc.id in list(graph_source.values())[0]["source_id"]:
|
||||
settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "source_id": doc.id},
|
||||
{"remove": {"source_id": doc.id}},
|
||||
search.index_name(tenant_id), doc.kb_id)
|
||||
settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]},
|
||||
{"removed_kwd": "Y"},
|
||||
search.index_name(tenant_id), doc.kb_id)
|
||||
settings.docStoreConn.delete({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "must_not": {"exists": "source_id"}},
|
||||
search.index_name(tenant_id), doc.kb_id)
|
||||
settings.docStoreConn.update(
|
||||
{"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "source_id": doc.id},
|
||||
{"remove": {"source_id": doc.id}},
|
||||
search.index_name(tenant_id),
|
||||
doc.kb_id,
|
||||
)
|
||||
settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, {"removed_kwd": "Y"}, search.index_name(tenant_id), doc.kb_id)
|
||||
settings.docStoreConn.delete(
|
||||
{"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "must_not": {"exists": "source_id"}},
|
||||
search.index_name(tenant_id),
|
||||
doc.kb_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to cleanup knowledge graph for document {doc.id}: {e}")
|
||||
|
||||
@ -428,9 +472,7 @@ class DocumentService(CommonService):
|
||||
page = 0
|
||||
page_size = 1000
|
||||
while True:
|
||||
chunks = settings.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(),
|
||||
page * page_size, page_size, search.index_name(tenant_id),
|
||||
[doc.kb_id])
|
||||
chunks = settings.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(), page * page_size, page_size, search.index_name(tenant_id), [doc.kb_id])
|
||||
chunk_ids = settings.docStoreConn.get_doc_ids(chunks)
|
||||
if not chunk_ids:
|
||||
break
|
||||
@ -455,71 +497,61 @@ class DocumentService(CommonService):
|
||||
Tenant.embd_id,
|
||||
Tenant.img2txt_id,
|
||||
Tenant.asr_id,
|
||||
cls.model.update_time]
|
||||
docs = cls.model.select(*fields) \
|
||||
.join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \
|
||||
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id)) \
|
||||
cls.model.update_time,
|
||||
]
|
||||
docs = (
|
||||
cls.model.select(*fields)
|
||||
.join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id))
|
||||
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
|
||||
.where(
|
||||
cls.model.status == StatusEnum.VALID.value,
|
||||
~(cls.model.type == FileType.VIRTUAL.value),
|
||||
cls.model.progress == 0,
|
||||
cls.model.update_time >= current_timestamp() - 1000 * 600,
|
||||
cls.model.run == TaskStatus.RUNNING.value) \
|
||||
cls.model.status == StatusEnum.VALID.value,
|
||||
~(cls.model.type == FileType.VIRTUAL.value),
|
||||
cls.model.progress == 0,
|
||||
cls.model.update_time >= current_timestamp() - 1000 * 600,
|
||||
cls.model.run == TaskStatus.RUNNING.value,
|
||||
)
|
||||
.order_by(cls.model.update_time.asc())
|
||||
)
|
||||
return list(docs.dicts())
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_unfinished_docs(cls):
|
||||
fields = [cls.model.id, cls.model.process_begin_at, cls.model.parser_config, cls.model.progress_msg,
|
||||
cls.model.run, cls.model.parser_id]
|
||||
unfinished_task_query = Task.select(Task.doc_id).where(
|
||||
(Task.progress >= 0) & (Task.progress < 1)
|
||||
)
|
||||
fields = [cls.model.id, cls.model.process_begin_at, cls.model.parser_config, cls.model.progress_msg, cls.model.run, cls.model.parser_id]
|
||||
unfinished_task_query = Task.select(Task.doc_id).where((Task.progress >= 0) & (Task.progress < 1))
|
||||
|
||||
docs = cls.model.select(*fields) \
|
||||
.where(
|
||||
docs = cls.model.select(*fields).where(
|
||||
cls.model.status == StatusEnum.VALID.value,
|
||||
~(cls.model.type == FileType.VIRTUAL.value),
|
||||
((cls.model.run.is_null(True)) | (cls.model.run != TaskStatus.CANCEL.value)),
|
||||
(((cls.model.progress < 1) & (cls.model.progress > 0)) |
|
||||
(cls.model.id.in_(unfinished_task_query)))) # including unfinished tasks like GraphRAG, RAPTOR and Mindmap
|
||||
(((cls.model.progress < 1) & (cls.model.progress > 0)) | (cls.model.id.in_(unfinished_task_query))),
|
||||
) # including unfinished tasks like GraphRAG, RAPTOR and Mindmap
|
||||
return list(docs.dicts())
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duration):
|
||||
num = cls.model.update(token_num=cls.model.token_num + token_num,
|
||||
chunk_num=cls.model.chunk_num + chunk_num,
|
||||
process_duration=cls.model.process_duration + duration).where(
|
||||
cls.model.id == doc_id).execute()
|
||||
num = (
|
||||
cls.model.update(token_num=cls.model.token_num + token_num, chunk_num=cls.model.chunk_num + chunk_num, process_duration=cls.model.process_duration + duration)
|
||||
.where(cls.model.id == doc_id)
|
||||
.execute()
|
||||
)
|
||||
if num == 0:
|
||||
logging.warning("Document not found which is supposed to be there")
|
||||
num = Knowledgebase.update(
|
||||
token_num=Knowledgebase.token_num +
|
||||
token_num,
|
||||
chunk_num=Knowledgebase.chunk_num +
|
||||
chunk_num).where(
|
||||
Knowledgebase.id == kb_id).execute()
|
||||
num = Knowledgebase.update(token_num=Knowledgebase.token_num + token_num, chunk_num=Knowledgebase.chunk_num + chunk_num).where(Knowledgebase.id == kb_id).execute()
|
||||
return num
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def decrement_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duration):
|
||||
num = cls.model.update(token_num=cls.model.token_num - token_num,
|
||||
chunk_num=cls.model.chunk_num - chunk_num,
|
||||
process_duration=cls.model.process_duration + duration).where(
|
||||
cls.model.id == doc_id).execute()
|
||||
num = (
|
||||
cls.model.update(token_num=cls.model.token_num - token_num, chunk_num=cls.model.chunk_num - chunk_num, process_duration=cls.model.process_duration + duration)
|
||||
.where(cls.model.id == doc_id)
|
||||
.execute()
|
||||
)
|
||||
if num == 0:
|
||||
raise LookupError(
|
||||
"Document not found which is supposed to be there")
|
||||
num = Knowledgebase.update(
|
||||
token_num=Knowledgebase.token_num -
|
||||
token_num,
|
||||
chunk_num=Knowledgebase.chunk_num -
|
||||
chunk_num
|
||||
).where(
|
||||
Knowledgebase.id == kb_id).execute()
|
||||
raise LookupError("Document not found which is supposed to be there")
|
||||
num = Knowledgebase.update(token_num=Knowledgebase.token_num - token_num, chunk_num=Knowledgebase.chunk_num - chunk_num).where(Knowledgebase.id == kb_id).execute()
|
||||
return num
|
||||
|
||||
@classmethod
|
||||
@ -551,17 +583,13 @@ class DocumentService(CommonService):
|
||||
doc = cls.model.get_by_id(doc_id)
|
||||
assert doc, "Can't fine document in database."
|
||||
|
||||
num = Knowledgebase.update(
|
||||
token_num=Knowledgebase.token_num -
|
||||
doc.token_num,
|
||||
chunk_num=Knowledgebase.chunk_num -
|
||||
doc.chunk_num,
|
||||
doc_num=Knowledgebase.doc_num - 1
|
||||
).where(
|
||||
Knowledgebase.id == doc.kb_id).execute()
|
||||
num = (
|
||||
Knowledgebase.update(token_num=Knowledgebase.token_num - doc.token_num, chunk_num=Knowledgebase.chunk_num - doc.chunk_num, doc_num=Knowledgebase.doc_num - 1)
|
||||
.where(Knowledgebase.id == doc.kb_id)
|
||||
.execute()
|
||||
)
|
||||
return num
|
||||
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def clear_chunk_num_when_rerun(cls, doc_id):
|
||||
@ -578,15 +606,10 @@ class DocumentService(CommonService):
|
||||
)
|
||||
return num
|
||||
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_tenant_id(cls, doc_id):
|
||||
docs = cls.model.select(
|
||||
Knowledgebase.tenant_id).join(
|
||||
Knowledgebase, on=(
|
||||
Knowledgebase.id == cls.model.kb_id)).where(
|
||||
cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
|
||||
docs = cls.model.select(Knowledgebase.tenant_id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
|
||||
docs = docs.dicts()
|
||||
if not docs:
|
||||
return None
|
||||
@ -604,11 +627,7 @@ class DocumentService(CommonService):
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_tenant_id_by_name(cls, name):
|
||||
docs = cls.model.select(
|
||||
Knowledgebase.tenant_id).join(
|
||||
Knowledgebase, on=(
|
||||
Knowledgebase.id == cls.model.kb_id)).where(
|
||||
cls.model.name == name, Knowledgebase.status == StatusEnum.VALID.value)
|
||||
docs = cls.model.select(Knowledgebase.tenant_id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.name == name, Knowledgebase.status == StatusEnum.VALID.value)
|
||||
docs = docs.dicts()
|
||||
if not docs:
|
||||
return None
|
||||
@ -617,12 +636,13 @@ class DocumentService(CommonService):
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def accessible(cls, doc_id, user_id):
|
||||
docs = cls.model.select(
|
||||
cls.model.id).join(
|
||||
Knowledgebase, on=(
|
||||
Knowledgebase.id == cls.model.kb_id)
|
||||
).join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
|
||||
).where(cls.model.id == doc_id, UserTenant.user_id == user_id).paginate(0, 1)
|
||||
docs = (
|
||||
cls.model.select(cls.model.id)
|
||||
.join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id))
|
||||
.join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id))
|
||||
.where(cls.model.id == doc_id, UserTenant.user_id == user_id)
|
||||
.paginate(0, 1)
|
||||
)
|
||||
docs = docs.dicts()
|
||||
if not docs:
|
||||
return False
|
||||
@ -631,18 +651,13 @@ class DocumentService(CommonService):
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def accessible4deletion(cls, doc_id, user_id):
|
||||
docs = cls.model.select(cls.model.id
|
||||
).join(
|
||||
Knowledgebase, on=(
|
||||
Knowledgebase.id == cls.model.kb_id)
|
||||
).join(
|
||||
UserTenant, on=(
|
||||
(UserTenant.tenant_id == Knowledgebase.created_by) & (UserTenant.user_id == user_id))
|
||||
).where(
|
||||
cls.model.id == doc_id,
|
||||
UserTenant.status == StatusEnum.VALID.value,
|
||||
((UserTenant.role == UserTenantRole.NORMAL) | (UserTenant.role == UserTenantRole.OWNER))
|
||||
).paginate(0, 1)
|
||||
docs = (
|
||||
cls.model.select(cls.model.id)
|
||||
.join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id))
|
||||
.join(UserTenant, on=((UserTenant.tenant_id == Knowledgebase.created_by) & (UserTenant.user_id == user_id)))
|
||||
.where(cls.model.id == doc_id, UserTenant.status == StatusEnum.VALID.value, ((UserTenant.role == UserTenantRole.NORMAL) | (UserTenant.role == UserTenantRole.OWNER)))
|
||||
.paginate(0, 1)
|
||||
)
|
||||
docs = docs.dicts()
|
||||
if not docs:
|
||||
return False
|
||||
@ -651,11 +666,7 @@ class DocumentService(CommonService):
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_embd_id(cls, doc_id):
|
||||
docs = cls.model.select(
|
||||
Knowledgebase.embd_id).join(
|
||||
Knowledgebase, on=(
|
||||
Knowledgebase.id == cls.model.kb_id)).where(
|
||||
cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
|
||||
docs = cls.model.select(Knowledgebase.embd_id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
|
||||
docs = docs.dicts()
|
||||
if not docs:
|
||||
return None
|
||||
@ -664,11 +675,9 @@ class DocumentService(CommonService):
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_tenant_embd_id(cls, doc_id):
|
||||
docs = cls.model.select(
|
||||
Knowledgebase.tenant_embd_id).join(
|
||||
Knowledgebase, on=(
|
||||
Knowledgebase.id == cls.model.kb_id)).where(
|
||||
cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
|
||||
docs = (
|
||||
cls.model.select(Knowledgebase.tenant_embd_id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
|
||||
)
|
||||
docs = docs.dicts()
|
||||
if not docs:
|
||||
return None
|
||||
@ -705,8 +714,7 @@ class DocumentService(CommonService):
|
||||
@DB.connection_context()
|
||||
def get_doc_id_by_doc_name(cls, doc_name):
|
||||
fields = [cls.model.id]
|
||||
doc_id = cls.model.select(*fields) \
|
||||
.where(cls.model.name == doc_name)
|
||||
doc_id = cls.model.select(*fields).where(cls.model.name == doc_name)
|
||||
doc_id = doc_id.dicts()
|
||||
if not doc_id:
|
||||
return None
|
||||
@ -725,8 +733,7 @@ class DocumentService(CommonService):
|
||||
@DB.connection_context()
|
||||
def get_thumbnails(cls, docids):
|
||||
fields = [cls.model.id, cls.model.kb_id, cls.model.thumbnail]
|
||||
return list(cls.model.select(
|
||||
*fields).where(cls.model.id.in_(docids)).dicts())
|
||||
return list(cls.model.select(*fields).where(cls.model.id.in_(docids)).dicts())
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
@ -755,9 +762,7 @@ class DocumentService(CommonService):
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_doc_count(cls, tenant_id):
|
||||
docs = cls.model.select(cls.model.id).join(Knowledgebase,
|
||||
on=(Knowledgebase.id == cls.model.kb_id)).where(
|
||||
Knowledgebase.tenant_id == tenant_id)
|
||||
docs = cls.model.select(cls.model.id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(Knowledgebase.tenant_id == tenant_id)
|
||||
return len(docs)
|
||||
|
||||
@classmethod
|
||||
@ -768,7 +773,7 @@ class DocumentService(CommonService):
|
||||
"process_begin_at": get_format_time(),
|
||||
}
|
||||
if not keep_progress:
|
||||
info["progress"] = random.random() * 1 / 100.
|
||||
info["progress"] = random.random() * 1 / 100.0
|
||||
info["run"] = TaskStatus.RUNNING.value
|
||||
# keep the doc in DONE state when keep_progress=True for GraphRAG, RAPTOR and Mindmap tasks
|
||||
|
||||
@ -781,19 +786,17 @@ class DocumentService(CommonService):
|
||||
|
||||
cls._sync_progress(docs)
|
||||
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def update_progress_immediately(cls, docs:list[dict]):
|
||||
def update_progress_immediately(cls, docs: list[dict]):
|
||||
if not docs:
|
||||
return
|
||||
|
||||
cls._sync_progress(docs)
|
||||
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def _sync_progress(cls, docs:list[dict]):
|
||||
def _sync_progress(cls, docs: list[dict]):
|
||||
from api.db.services.task_service import TaskService
|
||||
|
||||
for d in docs:
|
||||
@ -841,27 +844,18 @@ class DocumentService(CommonService):
|
||||
# fallback
|
||||
cls.update_by_id(d["id"], {"process_begin_at": begin_at})
|
||||
|
||||
info = {
|
||||
"process_duration": max(datetime.timestamp(datetime.now()) - begin_at.timestamp(), 0),
|
||||
"run": status}
|
||||
info = {"process_duration": max(datetime.timestamp(datetime.now()) - begin_at.timestamp(), 0), "run": status}
|
||||
if prg != 0 and not freeze_progress:
|
||||
info["progress"] = prg
|
||||
if msg:
|
||||
info["progress_msg"] = msg
|
||||
if msg.endswith("created task graphrag") or msg.endswith("created task raptor") or msg.endswith("created task mindmap"):
|
||||
info["progress_msg"] += "\n%d tasks are ahead in the queue..."%get_queue_length(priority)
|
||||
info["progress_msg"] += "\n%d tasks are ahead in the queue..." % get_queue_length(priority)
|
||||
else:
|
||||
info["progress_msg"] = "%d tasks are ahead in the queue..."%get_queue_length(priority)
|
||||
info["progress_msg"] = "%d tasks are ahead in the queue..." % get_queue_length(priority)
|
||||
info["update_time"] = current_timestamp()
|
||||
info["update_date"] = get_format_time()
|
||||
(
|
||||
cls.model.update(info)
|
||||
.where(
|
||||
(cls.model.id == d["id"])
|
||||
& ((cls.model.run.is_null(True)) | (cls.model.run != TaskStatus.CANCEL.value))
|
||||
)
|
||||
.execute()
|
||||
)
|
||||
(cls.model.update(info).where((cls.model.id == d["id"]) & ((cls.model.run.is_null(True)) | (cls.model.run != TaskStatus.CANCEL.value))).execute())
|
||||
except Exception as e:
|
||||
if str(e).find("'0'") < 0:
|
||||
logging.exception("fetch task exception")
|
||||
@ -875,7 +869,7 @@ class DocumentService(CommonService):
|
||||
@DB.connection_context()
|
||||
def get_all_kb_doc_count(cls):
|
||||
result = {}
|
||||
rows = cls.model.select(cls.model.kb_id, fn.COUNT(cls.model.id).alias('count')).group_by(cls.model.kb_id)
|
||||
rows = cls.model.select(cls.model.kb_id, fn.COUNT(cls.model.id).alias("count")).group_by(cls.model.kb_id)
|
||||
for row in rows:
|
||||
result[row.kb_id] = row.count
|
||||
return result
|
||||
@ -890,33 +884,19 @@ class DocumentService(CommonService):
|
||||
pass
|
||||
return False
|
||||
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def knowledgebase_basic_info(cls, kb_id: str) -> dict[str, int]:
|
||||
# cancelled: run == "2"
|
||||
cancelled = (
|
||||
cls.model.select(fn.COUNT(1))
|
||||
.where((cls.model.kb_id == kb_id) & (cls.model.run == TaskStatus.CANCEL))
|
||||
.scalar()
|
||||
)
|
||||
downloaded = (
|
||||
cls.model.select(fn.COUNT(1))
|
||||
.where(
|
||||
cls.model.kb_id == kb_id,
|
||||
cls.model.source_type != "local"
|
||||
)
|
||||
.scalar()
|
||||
)
|
||||
cancelled = cls.model.select(fn.COUNT(1)).where((cls.model.kb_id == kb_id) & (cls.model.run == TaskStatus.CANCEL)).scalar()
|
||||
downloaded = cls.model.select(fn.COUNT(1)).where(cls.model.kb_id == kb_id, cls.model.source_type != "local").scalar()
|
||||
|
||||
row = (
|
||||
cls.model.select(
|
||||
# finished: progress == 1
|
||||
fn.COALESCE(fn.SUM(Case(None, [(cls.model.progress == 1, 1)], 0)), 0).alias("finished"),
|
||||
|
||||
# failed: progress == -1
|
||||
fn.COALESCE(fn.SUM(Case(None, [(cls.model.progress == -1, 1)], 0)), 0).alias("failed"),
|
||||
|
||||
# processing: 0 <= progress < 1
|
||||
fn.COALESCE(
|
||||
fn.SUM(
|
||||
@ -931,24 +911,15 @@ class DocumentService(CommonService):
|
||||
0,
|
||||
).alias("processing"),
|
||||
)
|
||||
.where(
|
||||
(cls.model.kb_id == kb_id)
|
||||
& ((cls.model.run.is_null(True)) | (cls.model.run != TaskStatus.CANCEL))
|
||||
)
|
||||
.where((cls.model.kb_id == kb_id) & ((cls.model.run.is_null(True)) | (cls.model.run != TaskStatus.CANCEL)))
|
||||
.dicts()
|
||||
.get()
|
||||
)
|
||||
|
||||
return {
|
||||
"processing": int(row["processing"]),
|
||||
"finished": int(row["finished"]),
|
||||
"failed": int(row["failed"]),
|
||||
"cancelled": int(cancelled),
|
||||
"downloaded": int(downloaded)
|
||||
}
|
||||
return {"processing": int(row["processing"]), "finished": int(row["finished"]), "failed": int(row["failed"]), "cancelled": int(cancelled), "downloaded": int(downloaded)}
|
||||
|
||||
@classmethod
|
||||
def run(cls, tenant_id:str, doc:dict, kb_table_num_map:dict):
|
||||
def run(cls, tenant_id: str, doc: dict, kb_table_num_map: dict):
|
||||
from api.db.services.task_service import queue_dataflow, queue_tasks
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
|
||||
@ -990,7 +961,7 @@ def queue_raptor_o_graphrag_tasks(sample_doc_id, ty, priority, fake_doc_id="", d
|
||||
"from_page": 100000000,
|
||||
"to_page": 100000000,
|
||||
"task_type": ty,
|
||||
"progress_msg": datetime.now().strftime("%H:%M:%S") + " created task " + ty,
|
||||
"progress_msg": datetime.now().strftime("%H:%M:%S") + " created task " + ty,
|
||||
"begin_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
}
|
||||
|
||||
@ -1032,8 +1003,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
||||
|
||||
e, dia = DialogService.get_by_id(conv.dialog_id)
|
||||
if not dia.kb_ids:
|
||||
raise LookupError("No dataset associated with this conversation. "
|
||||
"Please add a dataset before uploading documents")
|
||||
raise LookupError("No dataset associated with this conversation. Please add a dataset before uploading documents")
|
||||
kb_id = dia.kb_ids[0]
|
||||
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
if not e:
|
||||
@ -1050,12 +1020,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
||||
def dummy(prog=None, msg=""):
|
||||
pass
|
||||
|
||||
FACTORY = {
|
||||
ParserType.PRESENTATION.value: presentation,
|
||||
ParserType.PICTURE.value: picture,
|
||||
ParserType.AUDIO.value: audio,
|
||||
ParserType.EMAIL.value: email
|
||||
}
|
||||
FACTORY = {ParserType.PRESENTATION.value: presentation, ParserType.PICTURE.value: picture, ParserType.AUDIO.value: audio, ParserType.EMAIL.value: email}
|
||||
parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?;。;!?", "layout_recognize": "Plain Text", "table_context_size": 0, "image_context_size": 0}
|
||||
exe = ThreadPoolExecutor(max_workers=12)
|
||||
threads = []
|
||||
@ -1063,22 +1028,12 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
||||
for d, blob in files:
|
||||
doc_nm[d["id"]] = d["name"]
|
||||
for d, blob in files:
|
||||
kwargs = {
|
||||
"callback": dummy,
|
||||
"parser_config": parser_config,
|
||||
"from_page": 0,
|
||||
"to_page": 100000,
|
||||
"tenant_id": kb.tenant_id,
|
||||
"lang": kb.language
|
||||
}
|
||||
kwargs = {"callback": dummy, "parser_config": parser_config, "from_page": 0, "to_page": 100000, "tenant_id": kb.tenant_id, "lang": kb.language}
|
||||
threads.append(exe.submit(FACTORY.get(d["parser_id"], naive).chunk, d["name"], blob, **kwargs))
|
||||
|
||||
for (docinfo, _), th in zip(files, threads):
|
||||
docs = []
|
||||
doc = {
|
||||
"doc_id": docinfo["id"],
|
||||
"kb_id": [kb.id]
|
||||
}
|
||||
doc = {"doc_id": docinfo["id"], "kb_id": [kb.id]}
|
||||
for ck in th.result():
|
||||
d = deepcopy(doc)
|
||||
d.update(ck)
|
||||
@ -1093,7 +1048,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
||||
if isinstance(d["image"], bytes):
|
||||
output_buffer = BytesIO(d["image"])
|
||||
else:
|
||||
d["image"].save(output_buffer, format='JPEG')
|
||||
d["image"].save(output_buffer, format="JPEG")
|
||||
|
||||
settings.STORAGE_IMPL.put(kb.id, d["id"], output_buffer.getvalue())
|
||||
d["img_id"] = "{}-{}".format(kb.id, d["id"])
|
||||
@ -1110,9 +1065,9 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
||||
nonlocal embd_mdl, chunk_counts, token_counts
|
||||
vectors = []
|
||||
for i in range(0, len(cnts), batch_size):
|
||||
vts, c = embd_mdl.encode(cnts[i: i + batch_size])
|
||||
vts, c = embd_mdl.encode(cnts[i : i + batch_size])
|
||||
vectors.extend(vts.tolist())
|
||||
chunk_counts[doc_id] += len(cnts[i:i + batch_size])
|
||||
chunk_counts[doc_id] += len(cnts[i : i + batch_size])
|
||||
token_counts[doc_id] += c
|
||||
return vectors
|
||||
|
||||
@ -1127,22 +1082,25 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
||||
|
||||
if parser_ids[doc_id] != ParserType.PICTURE.value:
|
||||
from rag.graphrag.general.mind_map_extractor import MindMapExtractor
|
||||
|
||||
mindmap = MindMapExtractor(llm_bdl)
|
||||
try:
|
||||
mind_map = asyncio.run(mindmap([c["content_with_weight"] for c in docs if c["doc_id"] == doc_id]))
|
||||
mind_map = json.dumps(mind_map.output, ensure_ascii=False, indent=2)
|
||||
if len(mind_map) < 32:
|
||||
raise Exception("Few content: " + mind_map)
|
||||
cks.append({
|
||||
"id": get_uuid(),
|
||||
"doc_id": doc_id,
|
||||
"kb_id": [kb.id],
|
||||
"docnm_kwd": doc_nm[doc_id],
|
||||
"title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", doc_nm[doc_id])),
|
||||
"content_ltks": rag_tokenizer.tokenize("summary summarize 总结 概况 file 文件 概括"),
|
||||
"content_with_weight": mind_map,
|
||||
"knowledge_graph_kwd": "mind_map"
|
||||
})
|
||||
cks.append(
|
||||
{
|
||||
"id": get_uuid(),
|
||||
"doc_id": doc_id,
|
||||
"kb_id": [kb.id],
|
||||
"docnm_kwd": doc_nm[doc_id],
|
||||
"title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", doc_nm[doc_id])),
|
||||
"content_ltks": rag_tokenizer.tokenize("summary summarize 总结 概况 file 文件 概括"),
|
||||
"content_with_weight": mind_map,
|
||||
"knowledge_graph_kwd": "mind_map",
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
logging.exception("Mind map generation error")
|
||||
|
||||
@ -1156,9 +1114,8 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
||||
if not settings.docStoreConn.index_exist(idxnm, kb_id):
|
||||
settings.docStoreConn.create_idx(idxnm, kb_id, len(vectors[0]), kb.parser_id)
|
||||
try_create_idx = False
|
||||
settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)
|
||||
settings.docStoreConn.insert(cks[b : b + es_bulk_size], idxnm, kb_id)
|
||||
|
||||
DocumentService.increment_chunk_num(
|
||||
doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)
|
||||
DocumentService.increment_chunk_num(doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)
|
||||
|
||||
return [d["id"] for d, _ in files]
|
||||
|
||||
Reference in New Issue
Block a user