Files
ragflow/api/db/services/document_service.py
box4wangjing 292b0b8bce chore: fix some comments to improve readability (#14756)
### What problem does this PR solve?

fix some comments to improve readability

### Type of change

- [x] Documentation Update

---------

Signed-off-by: box4wangjing <box4wangjing@outlook.com>
2026-05-11 16:48:48 +08:00

1051 lines
42 KiB
Python

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import random
from datetime import datetime
import xxhash
from peewee import fn, Case, JOIN
from api.constants import IMG_BASE64_PREFIX, FILE_NAME_LEN_LIMIT
from api.db import PIPELINE_SPECIAL_PROGRESS_FREEZE_TASK_TYPES, FileType, UserTenantRole, CanvasCategory
from api.db.db_models import DB, Document, Knowledgebase, Task, Tenant, UserTenant, File2Document, File, UserCanvas, User
from api.db.db_utils import bulk_insert_into_db
from api.db.services.common_service import CommonService, retry_deadlock_operation
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.doc_metadata_service import DocMetadataService
from common import settings
from common.constants import ParserType, StatusEnum, TaskStatus, SVR_CONSUMER_GROUP_NAME, MAXIMUM_TASK_PAGE_NUMBER
from common.doc_store.doc_store_base import OrderByExpr
from common.misc_utils import get_uuid
from common.time_utils import current_timestamp, get_format_time
from rag.nlp import search
from rag.utils.redis_conn import REDIS_CONN
# Service-layer helpers for Document rows. The classmethods that follow reach
# the table through ``cls.model``.
class DocumentService(CommonService):
# Model class the classmethods of this service query and update.
model = Document
@classmethod
def get_cls_model_fields(cls):
    """Return the model columns shared by the document listing queries.

    The list order is the order in which the columns are passed to SELECT.
    """
    column_names = (
        "id",
        "thumbnail",
        "kb_id",
        "parser_id",
        "pipeline_id",
        "parser_config",
        "source_type",
        "type",
        "created_by",
        "name",
        "location",
        "size",
        "token_num",
        "chunk_num",
        "progress",
        "progress_msg",
        "process_begin_at",
        "process_duration",
        "suffix",
        "run",
        "status",
        "create_time",
        "create_date",
        "update_time",
        "update_date",
    )
    return [getattr(cls.model, column_name) for column_name in column_names]
@classmethod
@DB.connection_context()
def get_list(cls, kb_id, page_number, items_per_page, orderby, desc, keywords, id, name, suffix=None, run=None, doc_ids=None):
    """Return one page of documents of a knowledge base plus the total match count.

    Optional filters (id, name, keywords, doc_ids, suffix, run) are applied only
    when truthy. Each returned row is a dict carrying an extra "meta_fields"
    entry looked up through DocMetadataService.

    Returns:
        (rows, count): the page as a list of dicts and the number of rows
        matching the filters before pagination.
    """
    pipeline_match = (cls.model.pipeline_id == UserCanvas.id) & (UserCanvas.canvas_category == CanvasCategory.DataFlow.value)
    query = (
        cls.model.select(*cls.get_cls_model_fields(), UserCanvas.title)
        .join(File2Document, on=(File2Document.document_id == cls.model.id))
        .join(File, on=(File.id == File2Document.file_id))
        .join(UserCanvas, on=pipeline_match, join_type=JOIN.LEFT_OUTER)
        .where(cls.model.kb_id == kb_id)
    )

    if id:
        query = query.where(cls.model.id == id)
    if name:
        query = query.where(cls.model.name == name)
    if keywords:
        query = query.where(fn.LOWER(cls.model.name).contains(keywords.lower()))
    if doc_ids:
        query = query.where(cls.model.id.in_(doc_ids))
    if suffix:
        query = query.where(cls.model.suffix.in_(suffix))
    if run:
        query = query.where(cls.model.run.in_(run))

    sort_column = cls.model.getter_by(orderby)
    query = query.order_by(sort_column.desc() if desc else sort_column.asc())

    count = query.count()
    page_rows = list(query.paginate(page_number, items_per_page).dicts())

    # Metadata is fetched only for the ids that actually landed on this page.
    page_ids = [row["id"] for row in page_rows]
    metadata_by_doc = {}
    if page_ids:
        metadata_by_doc = DocMetadataService.get_metadata_for_documents(page_ids, kb_id)
    for row in page_rows:
        row["meta_fields"] = metadata_by_doc.get(row["id"], {})
    return page_rows, count
@classmethod
@DB.connection_context()
def check_doc_health(cls, tenant_id: str, filename):
    """Check that the tenant may store one more file with this name.

    The MAX_FILE_NUM_PER_USER environment variable caps the tenant's document
    count; a value of 0 (the default) or less disables the cap.

    Raises:
        RuntimeError: the cap is reached, or the UTF-8 encoded file name is
            longer than FILE_NAME_LEN_LIMIT bytes.

    Returns:
        True when both checks pass.
    """
    import os

    file_cap = int(os.environ.get("MAX_FILE_NUM_PER_USER", 0))
    # The count query only runs when a positive cap is configured.
    if file_cap > 0 and DocumentService.get_doc_count(tenant_id) >= file_cap:
        raise RuntimeError("Exceed the maximum file number of a free user!")

    encoded_name_length = len(filename.encode("utf-8"))
    if encoded_name_length > FILE_NAME_LEN_LIMIT:
        raise RuntimeError("Exceed the maximum length of file name!")
    return True
@classmethod
@DB.connection_context()
def get_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, keywords, run_status, types, suffix, name=None, doc_ids=None, return_empty_metadata=False):
    """List documents of a knowledge base with pipeline name, creator nickname and metadata.

    Filters are applied only when truthy. When ``return_empty_metadata`` is set,
    documents that have an entry in the metadata map returned by
    DocMetadataService for the whole knowledge base are excluded and every
    returned row gets an empty "meta_fields". Pagination is applied only when
    both ``page_number`` and ``items_per_page`` are truthy.

    Returns:
        (rows, count): rows as dicts and the number of matches before pagination.
    """
    columns = [*cls.get_cls_model_fields(), UserCanvas.title.alias("pipeline_name"), User.nickname]
    linked_docs = cls.model.select(*columns).join(File2Document, on=(File2Document.document_id == cls.model.id))
    if keywords:
        query = (
            linked_docs.join(File, on=(File.id == File2Document.file_id))
            .join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER)
            .join(User, on=(cls.model.created_by == User.id), join_type=JOIN.LEFT_OUTER)
            .where((cls.model.kb_id == kb_id), (fn.LOWER(cls.model.name).contains(keywords.lower())))
        )
    else:
        query = (
            linked_docs.join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER)
            .join(File, on=(File.id == File2Document.file_id))
            .join(User, on=(cls.model.created_by == User.id), join_type=JOIN.LEFT_OUTER)
            .where(cls.model.kb_id == kb_id)
        )

    if doc_ids:
        query = query.where(cls.model.id.in_(doc_ids))
    if run_status:
        query = query.where(cls.model.run.in_(run_status))
    if types:
        query = query.where(cls.model.type.in_(types))
    if suffix:
        query = query.where(cls.model.suffix.in_(suffix))
    if name:
        query = query.where(cls.model.name == name)

    if return_empty_metadata:
        # Exclude every document that already has a metadata entry.
        ids_with_metadata = set(DocMetadataService.get_metadata_for_documents(None, kb_id).keys())
        if ids_with_metadata:
            query = query.where(cls.model.id.not_in(ids_with_metadata))

    count = query.count()

    sort_column = cls.model.getter_by(orderby)
    query = query.order_by(sort_column.desc() if desc else sort_column.asc())
    if page_number and items_per_page:
        query = query.paginate(page_number, items_per_page)
    rows = list(query.dicts())

    if return_empty_metadata:
        metadata_by_doc = {}
    else:
        page_ids = [row["id"] for row in rows]
        metadata_by_doc = DocMetadataService.get_metadata_for_documents(page_ids, kb_id) if page_ids else {}
    for row in rows:
        row["meta_fields"] = metadata_by_doc.get(row["id"], {})
    return rows, count
@classmethod
@DB.connection_context()
def get_filter_by_kb_id(cls, kb_id, keywords, run_status, types, suffix):
    """
    Build the filter facets (value -> document count) for a knowledge base.

    returns:
    {
        "suffix": {
            "ppt": 1,
            "docx": 2
        },
        "run_status": {
            "1": 2,
            "2": 2
        },
        "metadata": {
            "key1": {
                "key1_value1": 1,
                "key1_value2": 2,
            },
            "key2": {
                "key2_value1": 2,
                "key2_value2": 1,
            },
            "empty_metadata": {"true": <documents without any usable metadata value>},
        }
    }, total
    where "1" => RUNNING, "2" => CANCEL
    """
    rows = (
        cls.model.select(cls.model.run, cls.model.suffix, cls.model.id)
        .join(File2Document, on=(File2Document.document_id == cls.model.id))
        .join(File, on=(File.id == File2Document.file_id))
        .where(cls.model.kb_id == kb_id)
    )
    if keywords:
        rows = rows.where(fn.LOWER(cls.model.name).contains(keywords.lower()))
    if run_status:
        rows = rows.where(cls.model.run.in_(run_status))
    if types:
        rows = rows.where(cls.model.type.in_(types))
    if suffix:
        rows = rows.where(cls.model.suffix.in_(suffix))

    total = rows.count()

    matched_ids = [row.id for row in rows]
    metadata = {}
    if matched_ids:
        try:
            metadata = DocMetadataService.get_metadata_for_documents(matched_ids, kb_id)
        except Exception as e:
            # A metadata lookup failure leaves the metadata facet empty.
            logging.warning(f"Failed to fetch metadata from ES/Infinity: {e}")

    suffix_counter = {}
    run_status_counter = {}
    metadata_counter = {}
    empty_metadata_count = 0
    for row in rows:
        suffix_counter[row.suffix] = suffix_counter.get(row.suffix, 0) + 1
        run_key = str(row.run)
        run_status_counter[run_key] = run_status_counter.get(run_key, 0) + 1

        counted_any = False
        for key, raw in (metadata.get(row.id, {}) or {}).items():
            candidates = raw if isinstance(raw, list) else [raw]
            for candidate in candidates:
                is_blank = candidate is None or (isinstance(candidate, str) and not candidate.strip())
                if is_blank:
                    continue
                label = str(candidate)
                bucket = metadata_counter.setdefault(key, {})
                bucket[label] = bucket.get(label, 0) + 1
                counted_any = True
        # Missing metadata and metadata with only blank values both count as empty.
        if not counted_any:
            empty_metadata_count += 1

    metadata_counter["empty_metadata"] = {"true": empty_metadata_count}
    facets = {
        "suffix": suffix_counter,
        "run_status": run_status_counter,
        "metadata": metadata_counter,
    }
    return facets, total
@classmethod
@DB.connection_context()
def get_parsing_status_by_kb_ids(cls, kb_ids: list[str]) -> dict[str, dict[str, int]]:
    """Count documents per run status for each requested dataset (kb_id).

    Buckets per kb_id:
    - unstart_count (run == "0")
    - running_count (run == "1")
    - cancel_count (run == "2")
    - done_count (run == "3")
    - fail_count (run == "4")

    Every requested kb_id is present in the result, with zeroes where no
    document matches, e.g.
    {"kb-abc": {"unstart_count": 10, "running_count": 2, ...}, ...}
    An empty ``kb_ids`` yields an empty dict without touching the database.
    """
    if not kb_ids:
        return {}

    bucket_by_status = {
        TaskStatus.UNSTART.value: "unstart_count",
        TaskStatus.RUNNING.value: "running_count",
        TaskStatus.CANCEL.value: "cancel_count",
        TaskStatus.DONE.value: "done_count",
        TaskStatus.FAIL.value: "fail_count",
    }
    result: dict[str, dict[str, int]] = {kb_id: {bucket: 0 for bucket in bucket_by_status.values()} for kb_id in kb_ids}

    grouped = (
        cls.model.select(
            cls.model.kb_id,
            cls.model.run,
            fn.COUNT(cls.model.id).alias("cnt"),
        )
        .where(cls.model.kb_id.in_(kb_ids))
        .group_by(cls.model.kb_id, cls.model.run)
        .dicts()
    )
    for row in grouped:
        bucket = bucket_by_status.get(str(row["run"]))
        # Rows with an unknown run value or an unexpected kb_id are ignored.
        if not bucket or row["kb_id"] not in result:
            continue
        result[row["kb_id"]][bucket] = int(row["cnt"])
    return result
@classmethod
@DB.connection_context()
def count_by_kb_id(cls, kb_id, keywords, run_status, types):
    """Count the documents of a knowledge base, optionally filtered.

    ``keywords`` is matched case-insensitively against the document name;
    ``run_status`` and ``types`` are membership filters. Falsy filters are skipped.
    """
    query = cls.model.select().where(cls.model.kb_id == kb_id)
    if keywords:
        query = query.where(fn.LOWER(cls.model.name).contains(keywords.lower()))
    if run_status:
        query = query.where(cls.model.run.in_(run_status))
    if types:
        query = query.where(cls.model.type.in_(types))
    return query.count()
@classmethod
@DB.connection_context()
def get_total_size_by_kb_id(cls, kb_id, keywords="", run_status=None, types=None):
    """Return the summed ``size`` of the documents in a knowledge base.

    Args:
        kb_id: knowledge base whose documents are summed.
        keywords: optional case-insensitive substring filter on the name.
        run_status: optional collection of run values to keep (None/empty = all).
        types: optional collection of document types to keep (None/empty = all).

    Returns:
        The total as an int; 0 when nothing matches.
    """
    query = cls.model.select(fn.COALESCE(fn.SUM(cls.model.size), 0)).where(cls.model.kb_id == kb_id)
    if keywords:
        query = query.where(fn.LOWER(cls.model.name).contains(keywords.lower()))
    if run_status:
        query = query.where(cls.model.run.in_(run_status))
    if types:
        query = query.where(cls.model.type.in_(types))
    # Apply the fallback before int() so a NULL/None scalar cannot raise.
    return int(query.scalar() or 0)
@classmethod
@DB.connection_context()
def get_all_doc_ids_by_kb_ids(cls, kb_ids):
    """Return ``{"id", "kb_id"}`` dicts for every document in the given knowledge bases.

    Rows are fetched in batches of 100 using OFFSET/LIMIT, ordered by
    create_time so that consecutive pages are cut from one consistent ordering.
    """
    fields = [cls.model.id, cls.model.kb_id]
    docs = cls.model.select(*fields).where(cls.model.kb_id.in_(kb_ids))
    # order_by() returns a new query; the result must be kept or the ordering is lost.
    docs = docs.order_by(cls.model.create_time.asc())
    # maybe cause slow query by deep paginate, optimize later
    offset, limit = 0, 100
    res = []
    while True:
        batch = list(docs.offset(offset).limit(limit).dicts())
        if not batch:
            break
        res.extend(batch)
        offset += limit
    return res
@classmethod
@DB.connection_context()
def list_doc_headers_by_kb_and_source_type(cls, kb_id, source_type, page_size=500):
    """Return id/kb_id/source_type/name dicts for one knowledge base and source type.

    Rows are read in ``page_size`` batches, ordered by create_time ascending.
    """
    header_query = (
        cls.model.select(cls.model.id, cls.model.kb_id, cls.model.source_type, cls.model.name)
        .where(
            cls.model.kb_id == kb_id,
            cls.model.source_type == source_type,
        )
        .order_by(cls.model.create_time.asc())
    )
    headers = []
    offset = 0
    # An empty page signals that the previous page was the last one.
    while page := list(header_query.offset(offset).limit(page_size).dicts()):
        headers.extend(page)
        offset += page_size
    return headers
@classmethod
@DB.connection_context()
def list_id_content_hash_map_by_kb_and_source_type(cls, kb_id, source_type, page_size=500):
    """Map each existing doc id of the connector to its content_hash.

    The fingerprint-bypass path uses this map to decide which keys can skip a
    re-fetch: when the connector's listing fingerprint equals content_hash,
    the body has not changed since the last sync. A missing or empty hash is
    reported as "".

    Rows are ordered by create_time so that LIMIT/OFFSET pagination stays
    stable under concurrent writes; otherwise page boundaries could drop or
    duplicate rows and the map would silently miss entries.
    """
    hash_query = (
        cls.model.select(cls.model.id, cls.model.content_hash)
        .where(
            cls.model.kb_id == kb_id,
            cls.model.source_type == source_type,
        )
        .order_by(cls.model.create_time.asc())
    )
    result: dict[str, str] = {}
    offset = 0
    while True:
        page = list(hash_query.offset(offset).limit(page_size).dicts())
        if not page:
            return result
        result.update((row["id"], row.get("content_hash") or "") for row in page)
        offset += page_size
@classmethod
@DB.connection_context()
def get_all_docs_by_creator_id(cls, creator_id):
    """Return every document created by ``creator_id`` as dicts.

    Each dict holds id, kb_id, token_num, chunk_num and the owning knowledge
    base's tenant_id. Rows are fetched in batches of 100 using OFFSET/LIMIT,
    ordered by create_time so that consecutive pages share one ordering.
    """
    fields = [cls.model.id, cls.model.kb_id, cls.model.token_num, cls.model.chunk_num, Knowledgebase.tenant_id]
    docs = cls.model.select(*fields).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.created_by == creator_id)
    # order_by() returns a new query; the result must be kept or the ordering is lost.
    docs = docs.order_by(cls.model.create_time.asc())
    # maybe cause slow query by deep paginate, optimize later
    offset, limit = 0, 100
    res = []
    while True:
        batch = list(docs.offset(offset).limit(limit).dicts())
        if not batch:
            break
        res.extend(batch)
        offset += limit
    return res
@classmethod
@DB.connection_context()
def insert(cls, doc):
    """Save a document and bump the owning knowledge base's document counter.

    Args:
        doc: dict of Document field values; must contain "kb_id".

    Raises:
        RuntimeError: when saving the row or increasing the counter reports failure.

    Returns:
        A Document instance built from ``doc``.
    """
    saved = cls.save(**doc)
    if not saved:
        raise RuntimeError("Database error (Document)!")

    counter_bumped = KnowledgebaseService.atomic_increase_doc_num_by_id(doc["kb_id"])
    if not counter_bumped:
        raise RuntimeError("Database error (Knowledgebase)!")

    return Document(**doc)
@classmethod
@DB.connection_context()
def remove_document(cls, doc, tenant_id):
"""Delete a document row, then clean up the data derived from it.

The row and the knowledge-base counters are removed first through
delete_document_and_update_kb_counts. When that call reports that nothing
was deleted, the method returns True right away and skips all cleanup.
Each later cleanup step (tasks, chunk images, thumbnail, chunks, metadata,
knowledge-graph references) catches and logs its own exceptions.

Args:
doc: document object; id, kb_id and thumbnail are read from it.
tenant_id: tenant used to derive the chunk index name.

Returns:
True on every path that returns.
"""
from api.db.services.task_service import TaskService, cancel_all_task_of
# Nothing was deleted by this call, so there is nothing left to clean up here.
if not cls.delete_document_and_update_kb_counts(doc.id):
return True
chunk_index_name = search.index_name(tenant_id)
# Looked up once; the image and knowledge-graph steps below are skipped when the index is absent.
chunk_index_exists = settings.docStoreConn.index_exist(chunk_index_name, doc.kb_id)
# Cancel all running tasks first using preset function in task_service.py --- set cancel flag in Redis
try:
cancel_all_task_of(doc.id)
logging.info(f"Cancelled all tasks for document {doc.id}")
except Exception as e:
logging.warning(f"Failed to cancel tasks for document {doc.id}: {e}")
# Delete tasks from database
try:
TaskService.filter_delete([Task.doc_id == doc.id])
except Exception as e:
logging.warning(f"Failed to delete tasks for document {doc.id}: {e}")
# Delete chunk images (non-critical, log and continue)
try:
if chunk_index_exists:
cls.delete_chunk_images(doc, tenant_id)
except Exception as e:
logging.warning(f"Failed to delete chunk images for document {doc.id}: {e}")
# Delete thumbnail (non-critical, log and continue)
# Thumbnails starting with IMG_BASE64_PREFIX are skipped; only the others are looked up in storage.
try:
if doc.thumbnail and not doc.thumbnail.startswith(IMG_BASE64_PREFIX):
if settings.STORAGE_IMPL.obj_exist(doc.kb_id, doc.thumbnail):
settings.STORAGE_IMPL.rm(doc.kb_id, doc.thumbnail)
except Exception as e:
logging.warning(f"Failed to delete thumbnail for document {doc.id}: {e}")
# Delete chunks from doc store - this is critical, log errors
try:
settings.docStoreConn.delete({"doc_id": doc.id}, chunk_index_name, doc.kb_id)
except Exception as e:
logging.error(f"Failed to delete chunks from doc store for document {doc.id}: {e}")
# Delete document metadata (non-critical, log and continue)
try:
DocMetadataService.delete_document_metadata(doc.id, doc.kb_id, tenant_id)
except Exception as e:
logging.warning(f"Failed to delete metadata for document {doc.id}: {e}")
# Cleanup knowledge graph references (non-critical, log and continue)
try:
if chunk_index_exists:
# Fetch at most one "graph" record of this knowledge base and read its source_id list.
graph_source = settings.docStoreConn.get_fields(
settings.docStoreConn.search(["source_id"], [], {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, [], OrderByExpr(), 0, 1, chunk_index_name, [doc.kb_id]),
["source_id"],
)
if len(graph_source) > 0 and doc.id in list(graph_source.values())[0]["source_id"]:
# Remove this document's id from the source_id of the graph-related records.
settings.docStoreConn.update(
{"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "source_id": doc.id},
{"remove": {"source_id": doc.id}},
chunk_index_name,
doc.kb_id,
)
# Set removed_kwd to "Y" on the "graph" records of this knowledge base.
settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, {"removed_kwd": "Y"}, chunk_index_name, doc.kb_id)
# Delete graph-related records that no longer have any source_id.
settings.docStoreConn.delete(
{"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "must_not": {"exists": "source_id"}},
chunk_index_name,
doc.kb_id,
)
except Exception as e:
logging.warning(f"Failed to cleanup knowledge graph for document {doc.id}: {e}")
return True
@classmethod
@DB.connection_context()
def delete_chunk_images(cls, doc, tenant_id):
    """Remove the stored objects keyed by the ids of a document's chunks.

    Chunks are read from the doc store in pages of 1000; for each returned
    chunk id, the object with that key in the doc.kb_id bucket is removed if
    it exists. Iteration stops at the first page that yields no ids.
    """
    page_size = 1000
    offset = 0
    while True:
        hits = settings.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(), offset, page_size, search.index_name(tenant_id), [doc.kb_id])
        chunk_ids = settings.docStoreConn.get_doc_ids(hits)
        if not chunk_ids:
            return
        for chunk_id in chunk_ids:
            if settings.STORAGE_IMPL.obj_exist(doc.kb_id, chunk_id):
                settings.STORAGE_IMPL.rm(doc.kb_id, chunk_id)
        offset += page_size
@classmethod
@DB.connection_context()
def get_newly_uploaded(cls):
    """Return recently updated documents that are marked RUNNING but have zero progress.

    A row qualifies when it is valid, not virtual, has progress == 0, run ==
    RUNNING and an update_time no older than ``1000 * 600`` timestamp units
    before now. Rows are returned as dicts, oldest update first, enriched with
    the tenant id and the tenant's embd/img2txt/asr model ids.
    """
    selected = (
        cls.model.id,
        cls.model.kb_id,
        cls.model.parser_id,
        cls.model.parser_config,
        cls.model.name,
        cls.model.type,
        cls.model.location,
        cls.model.size,
        Knowledgebase.tenant_id,
        Tenant.embd_id,
        Tenant.img2txt_id,
        Tenant.asr_id,
        cls.model.update_time,
    )
    recent_cutoff = current_timestamp() - 1000 * 600
    query = (
        cls.model.select(*selected)
        .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id))
        .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
        .where(
            cls.model.status == StatusEnum.VALID.value,
            ~(cls.model.type == FileType.VIRTUAL.value),
            cls.model.progress == 0,
            cls.model.update_time >= recent_cutoff,
            cls.model.run == TaskStatus.RUNNING.value,
        )
        .order_by(cls.model.update_time.asc())
    )
    return list(query.dicts())
@classmethod
@DB.connection_context()
def get_unfinished_docs(cls):
    """Return the documents whose progress still needs to be synchronised from their tasks.

    A valid, non-virtual, non-cancelled document is returned when any of these holds:
    - its own progress is strictly between 0 and 1;
    - it has a task with 0 <= progress < 1 (including GraphRAG/RAPTOR/Mindmap);
    - it is marked failed (progress == -1, run == FAIL) but has a task with
      progress >= 0, so the failure can be re-synced.
    """
    doc = cls.model
    unfinished_task_doc_ids = Task.select(Task.doc_id).where((Task.progress >= 0) & (Task.progress < 1))
    non_failed_task_doc_ids = Task.select(Task.doc_id).where(Task.progress >= 0).distinct()

    not_cancelled = (doc.run.is_null(True)) | (doc.run != TaskStatus.CANCEL.value)
    partially_parsed = (doc.progress < 1) & (doc.progress > 0)
    has_unfinished_task = doc.id.in_(unfinished_task_doc_ids)
    failed_but_resyncable = (doc.progress == -1) & (doc.run == TaskStatus.FAIL.value) & (doc.id.in_(non_failed_task_doc_ids))

    query = doc.select(
        doc.id,
        doc.process_begin_at,
        doc.parser_config,
        doc.progress_msg,
        doc.run,
        doc.parser_id,
    ).where(
        doc.status == StatusEnum.VALID.value,
        ~(doc.type == FileType.VIRTUAL.value),
        not_cancelled,
        (partially_parsed | has_unfinished_task | failed_but_resyncable),
    )
    return list(query.dicts())
@classmethod
@DB.connection_context()
def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duration):
    """Add token/chunk counts to a document and to its knowledge base.

    ``duration`` is added to the document's process_duration. A missing
    document is only logged; the knowledge base counters are updated anyway.

    Returns:
        The number of knowledge base rows updated.
    """
    doc_changes = {
        "token_num": cls.model.token_num + token_num,
        "chunk_num": cls.model.chunk_num + chunk_num,
        "process_duration": cls.model.process_duration + duration,
    }
    docs_updated = cls.model.update(**doc_changes).where(cls.model.id == doc_id).execute()
    if docs_updated == 0:
        logging.warning("Document not found which is supposed to be there")

    kb_changes = {
        "token_num": Knowledgebase.token_num + token_num,
        "chunk_num": Knowledgebase.chunk_num + chunk_num,
    }
    return Knowledgebase.update(**kb_changes).where(Knowledgebase.id == kb_id).execute()
@classmethod
@DB.connection_context()
def decrement_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duration):
    """Subtract token/chunk counts from a document and from its knowledge base.

    ``duration`` is still added to the document's process_duration.

    Raises:
        LookupError: no document row was updated.

    Returns:
        The number of knowledge base rows updated.
    """
    doc_changes = {
        "token_num": cls.model.token_num - token_num,
        "chunk_num": cls.model.chunk_num - chunk_num,
        "process_duration": cls.model.process_duration + duration,
    }
    docs_updated = cls.model.update(**doc_changes).where(cls.model.id == doc_id).execute()
    if docs_updated == 0:
        raise LookupError("Document not found which is supposed to be there")

    kb_changes = {
        "token_num": Knowledgebase.token_num - token_num,
        "chunk_num": Knowledgebase.chunk_num - chunk_num,
    }
    return Knowledgebase.update(**kb_changes).where(Knowledgebase.id == kb_id).execute()
@classmethod
@retry_deadlock_operation()
@DB.connection_context()
def delete_document_and_update_kb_counts(cls, doc_id) -> bool:
    """Delete the document row and adjust the KB counters in one transaction.

    The row is selected FOR UPDATE, deleted, and its token/chunk counts plus
    one document are subtracted from the owning knowledge base.

    Returns:
        True if this call deleted the document; False if the row was already
        gone (for example removed by a concurrent request), which makes the
        operation idempotent.
    """
    with DB.atomic():
        locked_doc = (
            cls.model.select(
                cls.model.id,
                cls.model.kb_id,
                cls.model.token_num,
                cls.model.chunk_num,
            )
            .where(cls.model.id == doc_id)
            .for_update()
            .get_or_none()
        )
        if locked_doc is None:
            return False

        rows_deleted = cls.model.delete().where(cls.model.id == doc_id).execute()
        if not rows_deleted:
            return False

        counter_changes = {
            "token_num": Knowledgebase.token_num - locked_doc.token_num,
            "chunk_num": Knowledgebase.chunk_num - locked_doc.chunk_num,
            "doc_num": Knowledgebase.doc_num - 1,
        }
        Knowledgebase.update(**counter_changes).where(Knowledgebase.id == locked_doc.kb_id).execute()
    return True
@classmethod
@DB.connection_context()
def clear_chunk_num(cls, doc_id):
    """Deprecated: use delete_document_and_update_kb_counts instead.

    Subtracts the document's token_num and chunk_num, and one document, from
    the counters of its knowledge base. The document row itself is untouched.

    Returns:
        The number of knowledge base rows updated.
    """
    doc = cls.model.get_by_id(doc_id)
    # NOTE(review): if the model is a plain peewee Model, get_by_id raises
    # DoesNotExist for a missing row, so this assert is likely unreachable — confirm.
    assert doc, "Can't find document in database."
    num = (
        Knowledgebase.update(token_num=Knowledgebase.token_num - doc.token_num, chunk_num=Knowledgebase.chunk_num - doc.chunk_num, doc_num=Knowledgebase.doc_num - 1)
        .where(Knowledgebase.id == doc.kb_id)
        .execute()
    )
    return num
@classmethod
@DB.connection_context()
def clear_chunk_num_when_rerun(cls, doc_id):
    """Subtract a document's token_num and chunk_num from its knowledge base.

    Unlike clear_chunk_num, the knowledge base's doc_num is left unchanged.

    Returns:
        The number of knowledge base rows updated.
    """
    doc = cls.model.get_by_id(doc_id)
    # NOTE(review): if the model is a plain peewee Model, get_by_id raises
    # DoesNotExist for a missing row, so this assert is likely unreachable — confirm.
    assert doc, "Can't find document in database."
    num = (
        Knowledgebase.update(
            token_num=Knowledgebase.token_num - doc.token_num,
            chunk_num=Knowledgebase.chunk_num - doc.chunk_num,
        )
        .where(Knowledgebase.id == doc.kb_id)
        .execute()
    )
    return num
@classmethod
@DB.connection_context()
def get_tenant_id(cls, doc_id):
    """Return the tenant_id of the valid knowledge base owning the document, or None."""
    query = (
        cls.model.select(Knowledgebase.tenant_id)
        .join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id))
        .where(cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
    )
    rows = list(query.dicts())
    return rows[0]["tenant_id"] if rows else None
@classmethod
@DB.connection_context()
def get_knowledgebase_id(cls, doc_id):
    """Return the kb_id stored on the document, or None when the document does not exist."""
    rows = list(cls.model.select(cls.model.kb_id).where(cls.model.id == doc_id).dicts())
    return rows[0]["kb_id"] if rows else None
@classmethod
@DB.connection_context()
def get_tenant_id_by_name(cls, name):
    """Return the tenant_id for the first document with this name in a valid knowledge base, or None."""
    query = (
        cls.model.select(Knowledgebase.tenant_id)
        .join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id))
        .where(cls.model.name == name, Knowledgebase.status == StatusEnum.VALID.value)
    )
    rows = list(query.dicts())
    return rows[0]["tenant_id"] if rows else None
@classmethod
@DB.connection_context()
def accessible(cls, doc_id, user_id):
    """Tell whether the user may access the document.

    Returns False when the document cannot be loaded; otherwise the decision
    is delegated to KnowledgebaseService.accessible for the document's kb_id.
    """
    found, doc = cls.get_by_id(doc_id)
    return KnowledgebaseService.accessible(doc.kb_id, user_id) if found else False
@classmethod
@DB.connection_context()
def accessible4deletion(cls, doc_id, user_id):
    """Tell whether the user may delete the document.

    True when the user has a valid UserTenant row, with role NORMAL or OWNER,
    for the tenant recorded as creator of the document's knowledge base.
    """
    membership = (UserTenant.tenant_id == Knowledgebase.created_by) & (UserTenant.user_id == user_id)
    allowed_role = (UserTenant.role == UserTenantRole.NORMAL) | (UserTenant.role == UserTenantRole.OWNER)
    matches = (
        cls.model.select(cls.model.id)
        .join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id))
        .join(UserTenant, on=membership)
        .where(cls.model.id == doc_id, UserTenant.status == StatusEnum.VALID.value, allowed_role)
        .paginate(0, 1)
        .dicts()
    )
    return bool(matches)
@classmethod
@DB.connection_context()
def get_embd_id(cls, doc_id):
    """Return the embd_id of the valid knowledge base owning the document, or None."""
    query = (
        cls.model.select(Knowledgebase.embd_id)
        .join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id))
        .where(cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
    )
    rows = list(query.dicts())
    return rows[0]["embd_id"] if rows else None
@classmethod
@DB.connection_context()
def get_tenant_embd_id(cls, doc_id):
    """Return the tenant_embd_id of the valid knowledge base owning the document, or None."""
    query = (
        cls.model.select(Knowledgebase.tenant_embd_id)
        .join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id))
        .where(cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
    )
    rows = list(query.dicts())
    return rows[0]["tenant_embd_id"] if rows else None
@classmethod
@DB.connection_context()
def get_chunking_config(cls, doc_id):
    """Return the parsing-related settings of a document as one dict, or None.

    The dict combines document columns (id, kb_id, parser_id, parser_config,
    size, content_hash), knowledge base columns (language, embd_id) and tenant
    columns (tenant_id, img2txt_id, asr_id, llm_id).
    """
    selected = (
        cls.model.id,
        cls.model.kb_id,
        cls.model.parser_id,
        cls.model.parser_config,
        cls.model.size,
        cls.model.content_hash,
        Knowledgebase.language,
        Knowledgebase.embd_id,
        Tenant.id.alias("tenant_id"),
        Tenant.img2txt_id,
        Tenant.asr_id,
        Tenant.llm_id,
    )
    query = (
        cls.model.select(*selected)
        .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id))
        .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
        .where(cls.model.id == doc_id)
    )
    rows = list(query.dicts())
    return rows[0] if rows else None
@classmethod
@DB.connection_context()
def get_doc_id_by_doc_name(cls, doc_name):
    """Return the id of the first document with this exact name, or None."""
    rows = list(cls.model.select(cls.model.id).where(cls.model.name == doc_name).dicts())
    return rows[0]["id"] if rows else None
@classmethod
@DB.connection_context()
def get_doc_ids_by_doc_names(cls, doc_names):
    """Return the ids of all documents whose name is in ``doc_names``.

    An empty or missing ``doc_names`` returns [] without querying.
    """
    if not doc_names:
        return []
    id_query = cls.model.select(cls.model.id).where(cls.model.name.in_(doc_names))
    id_stream = id_query.scalars().iterator()
    return list(id_stream)
@classmethod
@DB.connection_context()
def get_thumbnails(cls, docids):
    """Return id/kb_id/thumbnail dicts for the given document ids."""
    thumbnail_query = cls.model.select(cls.model.id, cls.model.kb_id, cls.model.thumbnail).where(cls.model.id.in_(docids))
    return list(thumbnail_query.dicts())
@classmethod
@DB.connection_context()
def update_parser_config(cls, id, config):
    """Deep-merge ``config`` into the document's stored parser_config and save it.

    Nested dicts are merged key by key; any other value overwrites the stored
    one. If the merged config still has a truthy "raptor" entry while
    ``config`` has none, that entry is removed. A falsy ``config`` is a no-op.

    Raises:
        LookupError: the document does not exist.
    """
    if not config:
        return
    found, doc = cls.get_by_id(id)
    if not found:
        raise LookupError(f"Document({id}) not found.")

    def merge_into(target, patch):
        for key, value in patch.items():
            # Recurse only when both sides hold a dict; otherwise the patch wins.
            if isinstance(value, dict) and isinstance(target.get(key), dict):
                merge_into(target[key], value)
            else:
                target[key] = value

    merged = doc.parser_config
    merge_into(merged, config)
    if merged.get("raptor") and not config.get("raptor"):
        del merged["raptor"]
    cls.update_by_id(id, {"parser_config": merged})
@classmethod
@DB.connection_context()
def get_doc_count(cls, tenant_id):
    """Return how many documents belong to knowledge bases of the given tenant."""
    docs = cls.model.select(cls.model.id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(Knowledgebase.tenant_id == tenant_id)
    # Count in the database instead of fetching every id just to take len().
    return docs.count()
@classmethod
@DB.connection_context()
def begin2parse(cls, doc_id, keep_progress=False):
    """Mark a document as queued for parsing.

    Always resets progress_msg and process_begin_at. Unless ``keep_progress``
    is set, progress is also reset to a small random value below 0.01 and run
    is set to RUNNING. With ``keep_progress=True`` (GraphRAG, RAPTOR and
    Mindmap tasks) the document keeps its current progress and run state.
    """
    info = {
        "progress_msg": "Task is queued...",
        "process_begin_at": get_format_time(),
    }
    if not keep_progress:
        info.update(
            progress=random.random() * 1 / 100.0,
            run=TaskStatus.RUNNING.value,
        )
    cls.update_by_id(doc_id, info)
@classmethod
@DB.connection_context()
def update_progress(cls):
    """Synchronise progress for every document reported by get_unfinished_docs."""
    cls._sync_progress(cls.get_unfinished_docs())
@classmethod
@DB.connection_context()
def update_progress_immediately(cls, docs: list[dict]):
    """Synchronise progress for the given document dicts; an empty list is a no-op."""
    if docs:
        cls._sync_progress(docs)
@classmethod
@DB.connection_context()
def _sync_progress(cls, docs: list[dict]):
"""Recompute each document's run state and progress from its tasks.

For every dict in ``docs`` (read keys: "id", "process_begin_at") the tasks
of that document are loaded ordered by Task.create_time. Their progress
values are aggregated, and run, process_duration, progress_msg,
update_time/update_date and (conditionally) progress are written back.
Documents without tasks and documents whose run is CANCEL are skipped.
Exceptions are caught per document; see the handler at the end.
"""
from api.db.services.task_service import TaskService
for d in docs:
try:
tsks = TaskService.query(doc_id=d["id"], order_by=Task.create_time)
if not tsks:
continue
# msg: task messages; prg: summed progress; bad: number of failed tasks (progress == -1).
msg = []
prg = 0
finished = True
bad = 0
e, doc = DocumentService.get_by_id(d["id"])
# NOTE(review): ``doc`` is dereferenced here before the ``doc and ...`` guard two
# statements below; a missing document would raise AttributeError — confirm intent.
status = doc.run # TaskStatus.RUNNING.value
if status == TaskStatus.CANCEL.value:
continue
doc_progress = doc.progress if doc and doc.progress else 0.0
special_task_running = False
priority = 0
for t in tsks:
task_type = (t.task_type or "").lower()
if task_type in PIPELINE_SPECIAL_PROGRESS_FREEZE_TASK_TYPES:
special_task_running = True
# Any task with 0 <= progress < 1 means the document is not finished yet.
if 0 <= t.progress < 1:
finished = False
if t.progress == -1:
bad += 1
# Failed tasks (negative progress) contribute 0 to the sum.
prg += t.progress if t.progress >= 0 else 0
if t.progress_msg.strip():
msg.append(t.progress_msg)
priority = max(priority, t.priority)
# Average progress over all tasks of the document.
prg /= len(tsks)
if finished and bad:
prg = -1
status = TaskStatus.FAIL.value
elif finished:
prg = 1
status = TaskStatus.DONE.value
elif not finished:
status = TaskStatus.RUNNING.value
# only for special task and parsed docs and unfinished
freeze_progress = special_task_running and doc_progress >= 1 and not finished
msg = "\n".join(sorted(msg))
begin_at = d.get("process_begin_at")
if not begin_at:
begin_at = datetime.now()
# fallback
cls.update_by_id(d["id"], {"process_begin_at": begin_at})
# Duration is clamped at 0 in case begin_at lies in the future.
info = {"process_duration": max(datetime.timestamp(datetime.now()) - begin_at.timestamp(), 0), "run": status}
# progress is left untouched when the average is exactly 0 or when it is frozen.
if prg != 0 and not freeze_progress:
info["progress"] = prg
if msg:
info["progress_msg"] = msg
if msg.endswith("created task graphrag") or msg.endswith("created task raptor") or msg.endswith("created task mindmap"):
info["progress_msg"] += "\n%d tasks are ahead in the queue..." % get_queue_length(priority)
else:
info["progress_msg"] = "%d tasks are ahead in the queue..." % get_queue_length(priority)
info["update_time"] = current_timestamp()
info["update_date"] = get_format_time()
# The WHERE clause re-checks run so a document cancelled in the meantime is not overwritten.
(cls.model.update(info).where((cls.model.id == d["id"]) & ((cls.model.run.is_null(True)) | (cls.model.run != TaskStatus.CANCEL.value))).execute())
except Exception as e:
# Exceptions whose text contains "'0'" are ignored silently; all others are logged.
if str(e).find("'0'") < 0:
logging.exception("fetch task exception")
@classmethod
@DB.connection_context()
def get_kb_doc_count(cls, kb_id):
    """Return the number of documents stored under the given knowledge base."""
    docs_in_kb = cls.model.select().where(cls.model.kb_id == kb_id)
    return docs_in_kb.count()
@classmethod
@DB.connection_context()
def get_all_kb_doc_count(cls):
    """Return a ``{kb_id: document count}`` dict covering every knowledge base that has documents."""
    grouped = cls.model.select(cls.model.kb_id, fn.COUNT(cls.model.id).alias("count")).group_by(cls.model.kb_id)
    return {row.kb_id: row.count for row in grouped}
@classmethod
@DB.connection_context()
def do_cancel(cls, doc_id):
    """Tell whether processing of the document should stop.

    Returns:
        True when the document's run is CANCEL or its progress is negative.
        False otherwise, including when the document cannot be loaded or the
        check itself fails (the failure is logged, never raised).
    """
    try:
        found, doc = DocumentService.get_by_id(doc_id)
        if not found:
            return False
        if doc.run == TaskStatus.CANCEL.value:
            return True
        # A missing progress value is treated as "not failed".
        return doc.progress is not None and doc.progress < 0
    except Exception:
        # Best effort: callers only need a boolean, but the failure should be visible.
        logging.exception("Failed to check cancel state for document %s", doc_id)
    return False
@classmethod
@DB.connection_context()
def knowledgebase_basic_info(cls, kb_id: str) -> dict[str, int]:
    """Return document counters for one knowledge base.

    Keys of the returned dict:
    - "cancelled": documents whose run is CANCEL ("2");
    - "downloaded": documents whose source_type is not "local";
    - "finished" / "failed" / "processing": non-cancelled documents with
      progress == 1, progress == -1 and 0 <= progress < 1 respectively.
    """
    # Use the enum's value, as the rest of this service does, so the comparison
    # does not depend on how the enum type itself is adapted for SQL.
    cancel_value = TaskStatus.CANCEL.value

    def count_when(condition):
        # SUM over a 1/0 CASE; COALESCE covers the no-rows case where SUM is NULL.
        return fn.COALESCE(fn.SUM(Case(None, [(condition, 1)], 0)), 0)

    cancelled = cls.model.select(fn.COUNT(1)).where((cls.model.kb_id == kb_id) & (cls.model.run == cancel_value)).scalar()
    downloaded = cls.model.select(fn.COUNT(1)).where(cls.model.kb_id == kb_id, cls.model.source_type != "local").scalar()

    processing_condition = (cls.model.progress == 0) | ((cls.model.progress > 0) & (cls.model.progress < 1))
    row = (
        cls.model.select(
            count_when(cls.model.progress == 1).alias("finished"),
            count_when(cls.model.progress == -1).alias("failed"),
            count_when(processing_condition).alias("processing"),
        )
        .where((cls.model.kb_id == kb_id) & ((cls.model.run.is_null(True)) | (cls.model.run != cancel_value)))
        .dicts()
        .get()
    )
    return {"processing": int(row["processing"]), "finished": int(row["finished"]), "failed": int(row["failed"]), "cancelled": int(cancelled), "downloaded": int(downloaded)}
@classmethod
def run(cls, tenant_id: str, doc: dict, kb_table_num_map: dict):
"""Queue parsing work for one document.

Args:
tenant_id: written into ``doc["tenant_id"]`` (the dict is mutated).
doc: document dict; keys read are "parser_id", "kb_id", "pipeline_id" and "id".
kb_table_num_map: caller-owned cache of kb_id -> count of DONE documents,
filled in here for TABLE-parser documents.

For TABLE-parser documents without a kb_id the method returns without
queueing anything. A truthy "pipeline_id" queues a dataflow; otherwise
regular tasks are queued against the document's storage address.
"""
from api.db.services.task_service import queue_dataflow, queue_tasks
from api.db.services.file2document_service import File2DocumentService
doc["tenant_id"] = tenant_id
doc_parser = doc.get("parser_id", ParserType.NAIVE)
if doc_parser == ParserType.TABLE:
kb_id = doc.get("kb_id")
if not kb_id:
return
if kb_id not in kb_table_num_map:
# NOTE(review): other queries in this file pass TaskStatus.<X>.value; here the
# enum member itself is passed as run_status — confirm this is intended.
count = DocumentService.count_by_kb_id(kb_id=kb_id, keywords="", run_status=[TaskStatus.DONE], types=[])
kb_table_num_map[kb_id] = count
# The field map is deleted when the cached DONE count for this knowledge base is not positive.
if kb_table_num_map[kb_id] <= 0:
KnowledgebaseService.delete_field_map(kb_id)
if doc.get("pipeline_id", ""):
queue_dataflow(tenant_id, flow_id=doc["pipeline_id"], task_id=get_uuid(), doc_id=doc["id"])
else:
bucket, name = File2DocumentService.get_storage_address(doc_id=doc["id"])
queue_tasks(doc, bucket, name, 0)
def queue_raptor_o_graphrag_tasks(sample_doc, ty, priority, fake_doc_id="", doc_ids=None):
    """
    Create and enqueue one graphrag / raptor / mindmap task.

    You can provide a fake_doc_id to bypass the restriction of tasks at the knowledgebase level.
    Optionally, specify a list of doc_ids to determine which documents participate in the task.

    Args:
        sample_doc: dict whose "id" selects the document the chunking config is read from.
        ty: one of "graphrag", "raptor" or "mindmap".
        priority: used to pick the queue name.
        fake_doc_id: stored as the task's doc_id.
        doc_ids: documents taking part in the task; None means an empty list.

    Returns:
        The id of the created task.

    Raises:
        AssertionError: ``ty`` is not supported, or the message could not be queued.
    """
    assert ty in ["graphrag", "raptor", "mindmap"], "type should be graphrag, raptor or mindmap"
    # A fresh list per call: the queued message must never alias a shared default.
    if doc_ids is None:
        doc_ids = []

    chunking_config = DocumentService.get_chunking_config(sample_doc["id"])
    hasher = xxhash.xxh64()
    for field in sorted(chunking_config.keys()):
        hasher.update(str(chunking_config[field]).encode("utf-8"))

    def new_task():
        return {
            "id": get_uuid(),
            "doc_id": fake_doc_id,
            "from_page": MAXIMUM_TASK_PAGE_NUMBER,
            "to_page": MAXIMUM_TASK_PAGE_NUMBER,
            "task_type": ty,
            "progress_msg": datetime.now().strftime("%H:%M:%S") + " created task " + ty,
            "begin_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        }

    task = new_task()
    # The digest covers the chunking config, the task's identifying fields and its type.
    for field in ["doc_id", "from_page", "to_page"]:
        hasher.update(str(task.get(field, "")).encode("utf-8"))
    hasher.update(ty.encode("utf-8"))
    task["digest"] = hasher.hexdigest()
    bulk_insert_into_db(Task, [task], True)

    # doc_ids is attached only after the DB insert, so it travels in the queue message only.
    task["doc_ids"] = doc_ids
    DocumentService.begin2parse(task["doc_id"], keep_progress=True)

    # Enqueue outside of ``assert``: under ``python -O`` asserts are stripped and the
    # message would never be sent. AssertionError is kept so callers see the same type.
    queued = REDIS_CONN.queue_product(settings.get_svr_queue_name(priority), message=task)
    if not queued:
        raise AssertionError("Can't access Redis. Please check the Redis' status.")
    return task["id"]
def get_queue_length(priority):
    """Return the consumer group's "lag" for the queue of this priority.

    Returns 0 when no group info is available or the lag is missing/empty.
    """
    queue_name = settings.get_svr_queue_name(priority)
    group_info = REDIS_CONN.queue_info(queue_name, SVR_CONSUMER_GROUP_NAME)
    return int(group_info.get("lag", 0) or 0) if group_info else 0