mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-05-05 09:47:47 +08:00
Feat/tenant model (#13072)
### What problem does this PR solve?

Add an id to the tenant_llm table and apply it in LLMBundle.

### Type of change

- [x] Refactoring

---------

Co-authored-by: Yingfeng <yingfeng.zhang@gmail.com>
Co-authored-by: Liu An <asiro@qq.com>
This commit is contained in:
@ -28,6 +28,7 @@ from api.db.services.llm_service import LLMBundle
|
||||
from common.metadata_utils import apply_meta_data_filter
|
||||
from api.db.services.search_service import SearchService
|
||||
from api.db.services.user_service import UserTenantService
|
||||
from api.db.joint_services.tenant_model_service import get_model_config_by_id, get_tenant_default_model_by_type, get_model_config_by_type_and_name
|
||||
from api.utils.api_utils import (
|
||||
get_data_error_result,
|
||||
get_json_result,
|
||||
@ -165,13 +166,21 @@ async def set():
|
||||
if not tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
|
||||
embd_id = DocumentService.get_embd_id(req["doc_id"])
|
||||
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_id)
|
||||
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
|
||||
tenant_embd_id = DocumentService.get_tenant_embd_id(req["doc_id"])
|
||||
if tenant_embd_id:
|
||||
embd_model_config = get_model_config_by_id(tenant_embd_id)
|
||||
else:
|
||||
embd_id = DocumentService.get_embd_id(req["doc_id"])
|
||||
if embd_id:
|
||||
embd_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.EMBEDDING, embd_id)
|
||||
else:
|
||||
embd_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.EMBEDDING)
|
||||
embd_mdl = LLMBundle(tenant_id, embd_model_config)
|
||||
|
||||
_d = d
|
||||
if doc.parser_id == ParserType.QA:
|
||||
arr = [
|
||||
@ -324,8 +333,16 @@ async def create():
|
||||
if kb.pagerank:
|
||||
d[PAGERANK_FLD] = kb.pagerank
|
||||
|
||||
embd_id = DocumentService.get_embd_id(req["doc_id"])
|
||||
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING.value, embd_id)
|
||||
tenant_embd_id = DocumentService.get_tenant_embd_id(req["doc_id"])
|
||||
if tenant_embd_id:
|
||||
embd_model_config = get_model_config_by_id(tenant_embd_id)
|
||||
else:
|
||||
embd_id = DocumentService.get_embd_id(req["doc_id"])
|
||||
if embd_id:
|
||||
embd_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.EMBEDDING, embd_id)
|
||||
else:
|
||||
embd_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.EMBEDDING)
|
||||
embd_mdl = LLMBundle(tenant_id, embd_model_config)
|
||||
|
||||
v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d["question_kwd"] else "\n".join(d["question_kwd"])])
|
||||
v = 0.1 * v[0] + 0.9 * v[1]
|
||||
@ -375,11 +392,17 @@ async def retrieval_test():
|
||||
search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {})
|
||||
meta_data_filter = search_config.get("meta_data_filter", {})
|
||||
if meta_data_filter.get("method") in ["auto", "semi_auto"]:
|
||||
chat_mdl = LLMBundle(user_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
|
||||
chat_id = search_config.get("chat_id", "")
|
||||
if chat_id:
|
||||
chat_model_config = get_model_config_by_type_and_name(user_id, LLMType.CHAT, search_config["chat_id"])
|
||||
else:
|
||||
chat_model_config = get_tenant_default_model_by_type(user_id, LLMType.CHAT)
|
||||
chat_mdl = LLMBundle(user_id, chat_model_config)
|
||||
else:
|
||||
meta_data_filter = req.get("meta_data_filter") or {}
|
||||
if meta_data_filter.get("method") in ["auto", "semi_auto"]:
|
||||
chat_mdl = LLMBundle(user_id, LLMType.CHAT)
|
||||
chat_model_config = get_tenant_default_model_by_type(user_id, LLMType.CHAT)
|
||||
chat_mdl = LLMBundle(user_id, chat_model_config)
|
||||
|
||||
if meta_data_filter:
|
||||
metas = DocMetadataService.get_flatted_meta_by_kbs(kb_ids)
|
||||
@ -404,15 +427,25 @@ async def retrieval_test():
|
||||
_question = question
|
||||
if langs:
|
||||
_question = await cross_languages(kb.tenant_id, None, _question, langs)
|
||||
|
||||
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
|
||||
if kb.tenant_embd_id:
|
||||
embd_model_config = get_model_config_by_id(kb.tenant_embd_id)
|
||||
elif kb.embd_id:
|
||||
embd_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.EMBEDDING, kb.embd_id)
|
||||
else:
|
||||
embd_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.EMBEDDING)
|
||||
embd_mdl = LLMBundle(kb.tenant_id, embd_model_config)
|
||||
|
||||
rerank_mdl = None
|
||||
if req.get("rerank_id"):
|
||||
rerank_mdl = LLMBundle(kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
|
||||
if req.get("tenant_rerank_id"):
|
||||
rerank_model_config = get_model_config_by_id(req["tenant_rerank_id"])
|
||||
rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config)
|
||||
elif req.get("rerank_id"):
|
||||
rerank_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.RERANK.value, req["rerank_id"])
|
||||
rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config)
|
||||
|
||||
if req.get("keyword", False):
|
||||
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
|
||||
default_chat_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT)
|
||||
chat_mdl = LLMBundle(kb.tenant_id, default_chat_model_config)
|
||||
_question += await keyword_extraction(chat_mdl, _question)
|
||||
|
||||
labels = label_question(_question, [kb])
|
||||
@ -432,11 +465,12 @@ async def retrieval_test():
|
||||
)
|
||||
|
||||
if use_kg:
|
||||
default_chat_model_config = get_tenant_default_model_by_type(user_id, LLMType.CHAT)
|
||||
ck = await settings.kg_retriever.retrieval(_question,
|
||||
tenant_ids,
|
||||
kb_ids,
|
||||
embd_mdl,
|
||||
LLMBundle(kb.tenant_id, LLMType.CHAT))
|
||||
LLMBundle(kb.tenant_id, default_chat_model_config))
|
||||
if ck["content_with_weight"]:
|
||||
ranks["chunks"].insert(0, ck)
|
||||
ranks["chunks"] = settings.retriever.retrieval_by_children(ranks["chunks"], tenant_ids)
|
||||
|
||||
Reference in New Issue
Block a user