mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-05-06 02:07:49 +08:00
Feat/tenant model (#13072)
### What problem does this PR solve?

Add an `id` column to the `tenant_llm` table and use it in `LLMBundle`.

### Type of change

- [x] Refactoring

---------

Co-authored-by: Yingfeng <yingfeng.zhang@gmail.com>
Co-authored-by: Liu An <asiro@qq.com>
This commit is contained in:
@ -65,6 +65,7 @@ from api.db.services.doc_metadata_service import DocMetadataService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.task_service import TaskService, has_canceled, CANVAS_DEBUG_DOC_ID, GRAPH_RAPTOR_FAKE_DOC_ID
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.db.joint_services.tenant_model_service import get_model_config_by_type_and_name, get_tenant_default_model_by_type
|
||||
from common.versions import get_ragflow_version
|
||||
from api.db.db_models import close_connection
|
||||
from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, \
|
||||
@ -342,7 +343,8 @@ async def build_chunks(task, progress_callback):
|
||||
if task["parser_config"].get("auto_keywords", 0):
|
||||
st = timer()
|
||||
progress_callback(msg="Start to generate keywords for every chunk ...")
|
||||
chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"])
|
||||
chat_model_config = get_model_config_by_type_and_name(task["tenant_id"], LLMType.CHAT, task["llm_id"])
|
||||
chat_mdl = LLMBundle(task["tenant_id"], chat_model_config, lang=task["language"])
|
||||
|
||||
async def doc_keyword_extraction(chat_mdl, d, topn):
|
||||
cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "keywords", {"topn": topn})
|
||||
@ -375,7 +377,8 @@ async def build_chunks(task, progress_callback):
|
||||
if task["parser_config"].get("auto_questions", 0):
|
||||
st = timer()
|
||||
progress_callback(msg="Start to generate questions for every chunk ...")
|
||||
chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"])
|
||||
chat_model_config = get_model_config_by_type_and_name(task["tenant_id"], LLMType.CHAT, task["llm_id"])
|
||||
chat_mdl = LLMBundle(task["tenant_id"], chat_model_config, lang=task["language"])
|
||||
|
||||
async def doc_question_proposal(chat_mdl, d, topn):
|
||||
cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "question", {"topn": topn})
|
||||
@ -407,7 +410,8 @@ async def build_chunks(task, progress_callback):
|
||||
if task["parser_config"].get("enable_metadata", False) and task["parser_config"].get("metadata"):
|
||||
st = timer()
|
||||
progress_callback(msg="Start to generate meta-data for every chunk ...")
|
||||
chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"])
|
||||
chat_model_config = get_model_config_by_type_and_name(task["tenant_id"], LLMType.CHAT, task["llm_id"])
|
||||
chat_mdl = LLMBundle(task["tenant_id"], chat_model_config, lang=task["language"])
|
||||
|
||||
async def gen_metadata_task(chat_mdl, d):
|
||||
cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "metadata",
|
||||
@ -461,8 +465,8 @@ async def build_chunks(task, progress_callback):
|
||||
set_tags_to_cache(kb_ids, all_tags)
|
||||
else:
|
||||
all_tags = json.loads(all_tags)
|
||||
|
||||
chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"])
|
||||
chat_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.CHAT, task["llm_id"])
|
||||
chat_mdl = LLMBundle(task["tenant_id"], chat_model_config, lang=task["language"])
|
||||
|
||||
docs_to_tag = []
|
||||
for d in docs:
|
||||
@ -517,7 +521,8 @@ async def build_chunks(task, progress_callback):
|
||||
|
||||
def build_TOC(task, docs, progress_callback):
|
||||
progress_callback(msg="Start to generate table of content ...")
|
||||
chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"])
|
||||
chat_model_config = get_model_config_by_type_and_name(task["tenant_id"], LLMType.CHAT, task["llm_id"])
|
||||
chat_mdl = LLMBundle(task["tenant_id"], chat_model_config, lang=task["language"])
|
||||
docs = sorted(docs, key=lambda d: (
|
||||
d.get("page_num_int", 0)[0] if isinstance(d.get("page_num_int", 0), list) else d.get("page_num_int", 0),
|
||||
d.get("top_int", 0)[0] if isinstance(d.get("top_int", 0), list) else d.get("top_int", 0)
|
||||
@ -668,7 +673,8 @@ async def run_dataflow(task: dict):
|
||||
set_progress(task_id, prog=0.82, msg="\n-------------------------------------\nStart to embedding...")
|
||||
e, kb = KnowledgebaseService.get_by_id(task["kb_id"])
|
||||
embedding_id = kb.embd_id
|
||||
embedding_model = LLMBundle(task["tenant_id"], LLMType.EMBEDDING, llm_name=embedding_id)
|
||||
embd_model_config = get_model_config_by_type_and_name(task["tenant_id"], LLMType.EMBEDDING, embedding_id)
|
||||
embedding_model = LLMBundle(task["tenant_id"], embd_model_config)
|
||||
|
||||
@timeout(60)
|
||||
def batch_encode(txts):
|
||||
@ -985,7 +991,11 @@ async def do_handle_task(task):
|
||||
|
||||
try:
|
||||
# bind embedding model
|
||||
embedding_model = LLMBundle(task_tenant_id, LLMType.EMBEDDING, llm_name=task_embedding_id, lang=task_language)
|
||||
if task_embedding_id:
|
||||
embd_model_config = get_model_config_by_type_and_name(task_tenant_id, LLMType.EMBEDDING, task_embedding_id)
|
||||
else:
|
||||
embd_model_config = get_tenant_default_model_by_type(task_tenant_id, LLMType.EMBEDDING)
|
||||
embedding_model = LLMBundle(task_tenant_id, embd_model_config, lang=task_language)
|
||||
vts, _ = embedding_model.encode(["ok"])
|
||||
vector_size = len(vts[0])
|
||||
except Exception as e:
|
||||
@ -1037,7 +1047,8 @@ async def do_handle_task(task):
|
||||
return
|
||||
|
||||
# bind LLM for raptor
|
||||
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=kb_task_llm_id, lang=task_language)
|
||||
# Model configs are looked up per tenant: pass the tenant id (the same id the
# LLMBundle below is built with), not the dataset id.
chat_model_config = get_model_config_by_type_and_name(task_tenant_id, LLMType.CHAT, kb_task_llm_id)
|
||||
chat_model = LLMBundle(task_tenant_id, chat_model_config, lang=task_language)
|
||||
# run RAPTOR
|
||||
async with kg_limiter:
|
||||
chunks, token_count = await run_raptor_for_kb(
|
||||
@ -1081,7 +1092,8 @@ async def do_handle_task(task):
|
||||
|
||||
graphrag_conf = kb_parser_config.get("graphrag", {})
|
||||
start_ts = timer()
|
||||
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=kb_task_llm_id, lang=task_language)
|
||||
chat_model_config = get_model_config_by_type_and_name(task_tenant_id, LLMType.CHAT, kb_task_llm_id)
|
||||
chat_model = LLMBundle(task_tenant_id, chat_model_config, lang=task_language)
|
||||
with_resolution = graphrag_conf.get("resolution", False)
|
||||
with_community = graphrag_conf.get("community", False)
|
||||
async with kg_limiter:
|
||||
|
||||
Reference in New Issue
Block a user