mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-05-05 17:57:47 +08:00
Feat/tenant model (#13072)
### What problem does this PR solve? Add id for table tenant_llm and apply in LLMBundle. ### Type of change - [x] Refactoring --------- Co-authored-by: Yingfeng <yingfeng.zhang@gmail.com> Co-authored-by: Liu An <asiro@qq.com>
This commit is contained in:
@ -36,6 +36,7 @@ from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.db.services.task_service import TaskService, queue_tasks, cancel_all_task_of
|
||||
from api.db.joint_services.tenant_model_service import get_model_config_by_id, get_tenant_default_model_by_type, get_model_config_by_type_and_name
|
||||
from common.metadata_utils import meta_filter, convert_conditions
|
||||
from api.utils.api_utils import check_duplicate_ids, construct_json_result, get_error_data_result, get_parser_config, get_result, server_error_response, token_required, \
|
||||
get_request_json
|
||||
@ -1248,8 +1249,13 @@ async def add_chunk(tenant_id, dataset_id, document_id):
|
||||
d["tag_kwd"] = req["tag_kwd"]
|
||||
if "tag_feas" in req:
|
||||
d["tag_feas"] = req["tag_feas"]
|
||||
embd_id = DocumentService.get_embd_id(document_id)
|
||||
embd_mdl = TenantLLMService.model_instance(tenant_id, LLMType.EMBEDDING.value, embd_id)
|
||||
tenant_embd_id = DocumentService.get_tenant_embd_id(document_id)
|
||||
if tenant_embd_id:
|
||||
model_config = get_model_config_by_id(tenant_embd_id)
|
||||
else:
|
||||
embd_id = DocumentService.get_embd_id(document_id)
|
||||
model_config = get_model_config_by_type_and_name(tenant_id, LLMType.EMBEDDING.value, embd_id)
|
||||
embd_mdl = TenantLLMService.model_instance(model_config)
|
||||
v, c = embd_mdl.encode([doc.name, req["content"] if not d["question_kwd"] else "\n".join(d["question_kwd"])])
|
||||
v = 0.1 * v[0] + 0.9 * v[1]
|
||||
d["q_%d_vec" % len(v)] = v.tolist()
|
||||
@ -1446,8 +1452,13 @@ async def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
|
||||
d["tag_kwd"] = req["tag_kwd"]
|
||||
if "tag_feas" in req:
|
||||
d["tag_feas"] = req["tag_feas"]
|
||||
embd_id = DocumentService.get_embd_id(document_id)
|
||||
embd_mdl = TenantLLMService.model_instance(tenant_id, LLMType.EMBEDDING.value, embd_id)
|
||||
tenant_embd_id = DocumentService.get_tenant_embd_id(document_id)
|
||||
if tenant_embd_id:
|
||||
model_config = get_model_config_by_id(tenant_embd_id)
|
||||
else:
|
||||
embd_id = DocumentService.get_embd_id(document_id)
|
||||
model_config = get_model_config_by_type_and_name(tenant_id, LLMType.EMBEDDING.value, embd_id)
|
||||
embd_mdl = TenantLLMService.model_instance(model_config)
|
||||
if doc.parser_id == ParserType.QA:
|
||||
arr = [t for t in re.split(r"[\n\t]", d["content_with_weight"]) if len(t) > 1]
|
||||
if len(arr) != 2:
|
||||
@ -1616,17 +1627,26 @@ async def retrieval_test(tenant_id):
|
||||
e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
|
||||
if not e:
|
||||
return get_error_data_result(message="Dataset not found!")
|
||||
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id)
|
||||
if kb.tenant_embd_id:
|
||||
embd_model_config = get_model_config_by_id(kb.tenant_embd_id)
|
||||
else:
|
||||
embd_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.EMBEDDING, kb.embd_id)
|
||||
embd_mdl = LLMBundle(kb.tenant_id, embd_model_config)
|
||||
|
||||
rerank_mdl = None
|
||||
if req.get("rerank_id"):
|
||||
rerank_mdl = LLMBundle(kb.tenant_id, LLMType.RERANK, llm_name=req["rerank_id"])
|
||||
if req.get("tenant_rerank_id"):
|
||||
rerank_model_config = get_model_config_by_id(req["tenant_rerank_id"])
|
||||
rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config)
|
||||
elif req.get("rerank_id"):
|
||||
rerank_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.RERANK, req["rerank_id"])
|
||||
rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config)
|
||||
|
||||
if langs:
|
||||
question = await cross_languages(kb.tenant_id, None, question, langs)
|
||||
|
||||
if req.get("keyword", False):
|
||||
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
|
||||
chat_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT)
|
||||
chat_mdl = LLMBundle(kb.tenant_id, chat_model_config)
|
||||
question += await keyword_extraction(chat_mdl, question)
|
||||
|
||||
ranks = await settings.retriever.retrieval(
|
||||
@ -1645,13 +1665,15 @@ async def retrieval_test(tenant_id):
|
||||
rank_feature=label_question(question, kbs),
|
||||
)
|
||||
if toc_enhance:
|
||||
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
|
||||
chat_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT)
|
||||
chat_mdl = LLMBundle(kb.tenant_id, chat_model_config)
|
||||
cks = await settings.retriever.retrieval_by_toc(question, ranks["chunks"], tenant_ids, chat_mdl, size)
|
||||
if cks:
|
||||
ranks["chunks"] = cks
|
||||
ranks["chunks"] = settings.retriever.retrieval_by_children(ranks["chunks"], tenant_ids)
|
||||
if use_kg:
|
||||
ck = await settings.kg_retriever.retrieval(question, [k.tenant_id for k in kbs], kb_ids, embd_mdl, LLMBundle(kb.tenant_id, LLMType.CHAT))
|
||||
chat_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT)
|
||||
ck = await settings.kg_retriever.retrieval(question, [k.tenant_id for k in kbs], kb_ids, embd_mdl, LLMBundle(kb.tenant_id, chat_model_config))
|
||||
if ck["content_with_weight"]:
|
||||
ranks["chunks"].insert(0, ck)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user