Feat/tenant model (#13072)

### What problem does this PR solve?

Add id for table tenant_llm and apply in LLMBundle.

### Type of change

- [x] Refactoring

---------

Co-authored-by: Yingfeng <yingfeng.zhang@gmail.com>
Co-authored-by: Liu An <asiro@qq.com>
This commit is contained in:
Lynn
2026-03-05 17:27:17 +08:00
committed by GitHub
parent 47540a4147
commit 62cb292635
54 changed files with 1754 additions and 361 deletions

View File

@ -28,6 +28,7 @@ from api.db.services.llm_service import LLMBundle
from common.metadata_utils import apply_meta_data_filter
from api.db.services.search_service import SearchService
from api.db.services.user_service import UserTenantService
from api.db.joint_services.tenant_model_service import get_model_config_by_id, get_tenant_default_model_by_type, get_model_config_by_type_and_name
from api.utils.api_utils import (
get_data_error_result,
get_json_result,
@ -165,13 +166,21 @@ async def set():
if not tenant_id:
return get_data_error_result(message="Tenant not found!")
embd_id = DocumentService.get_embd_id(req["doc_id"])
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_id)
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
return get_data_error_result(message="Document not found!")
tenant_embd_id = DocumentService.get_tenant_embd_id(req["doc_id"])
if tenant_embd_id:
embd_model_config = get_model_config_by_id(tenant_embd_id)
else:
embd_id = DocumentService.get_embd_id(req["doc_id"])
if embd_id:
embd_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.EMBEDDING, embd_id)
else:
embd_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.EMBEDDING)
embd_mdl = LLMBundle(tenant_id, embd_model_config)
_d = d
if doc.parser_id == ParserType.QA:
arr = [
@ -324,8 +333,16 @@ async def create():
if kb.pagerank:
d[PAGERANK_FLD] = kb.pagerank
embd_id = DocumentService.get_embd_id(req["doc_id"])
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING.value, embd_id)
tenant_embd_id = DocumentService.get_tenant_embd_id(req["doc_id"])
if tenant_embd_id:
embd_model_config = get_model_config_by_id(tenant_embd_id)
else:
embd_id = DocumentService.get_embd_id(req["doc_id"])
if embd_id:
embd_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.EMBEDDING, embd_id)
else:
embd_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.EMBEDDING)
embd_mdl = LLMBundle(tenant_id, embd_model_config)
v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d["question_kwd"] else "\n".join(d["question_kwd"])])
v = 0.1 * v[0] + 0.9 * v[1]
@ -375,11 +392,17 @@ async def retrieval_test():
search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {})
meta_data_filter = search_config.get("meta_data_filter", {})
if meta_data_filter.get("method") in ["auto", "semi_auto"]:
chat_mdl = LLMBundle(user_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
chat_id = search_config.get("chat_id", "")
if chat_id:
chat_model_config = get_model_config_by_type_and_name(user_id, LLMType.CHAT, search_config["chat_id"])
else:
chat_model_config = get_tenant_default_model_by_type(user_id, LLMType.CHAT)
chat_mdl = LLMBundle(user_id, chat_model_config)
else:
meta_data_filter = req.get("meta_data_filter") or {}
if meta_data_filter.get("method") in ["auto", "semi_auto"]:
chat_mdl = LLMBundle(user_id, LLMType.CHAT)
chat_model_config = get_tenant_default_model_by_type(user_id, LLMType.CHAT)
chat_mdl = LLMBundle(user_id, chat_model_config)
if meta_data_filter:
metas = DocMetadataService.get_flatted_meta_by_kbs(kb_ids)
@ -404,15 +427,25 @@ async def retrieval_test():
_question = question
if langs:
_question = await cross_languages(kb.tenant_id, None, _question, langs)
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
if kb.tenant_embd_id:
embd_model_config = get_model_config_by_id(kb.tenant_embd_id)
elif kb.embd_id:
embd_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.EMBEDDING, kb.embd_id)
else:
embd_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.EMBEDDING)
embd_mdl = LLMBundle(kb.tenant_id, embd_model_config)
rerank_mdl = None
if req.get("rerank_id"):
rerank_mdl = LLMBundle(kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
if req.get("tenant_rerank_id"):
rerank_model_config = get_model_config_by_id(req["tenant_rerank_id"])
rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config)
elif req.get("rerank_id"):
rerank_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.RERANK.value, req["rerank_id"])
rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config)
if req.get("keyword", False):
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
default_chat_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT)
chat_mdl = LLMBundle(kb.tenant_id, default_chat_model_config)
_question += await keyword_extraction(chat_mdl, _question)
labels = label_question(_question, [kb])
@ -432,11 +465,12 @@ async def retrieval_test():
)
if use_kg:
default_chat_model_config = get_tenant_default_model_by_type(user_id, LLMType.CHAT)
ck = await settings.kg_retriever.retrieval(_question,
tenant_ids,
kb_ids,
embd_mdl,
LLMBundle(kb.tenant_id, LLMType.CHAT))
LLMBundle(kb.tenant_id, default_chat_model_config))
if ck["content_with_weight"]:
ranks["chunks"].insert(0, ck)
ranks["chunks"] = settings.retriever.retrieval_by_children(ranks["chunks"], tenant_ids)