Feat/tenant model (#13072)

### What problem does this PR solve?

Add id for table tenant_llm and apply in LLMBundle.

### Type of change

- [x] Refactoring

---------

Co-authored-by: Yingfeng <yingfeng.zhang@gmail.com>
Co-authored-by: Liu An <asiro@qq.com>
This commit is contained in:
Lynn
2026-03-05 17:27:17 +08:00
committed by GitHub
parent 47540a4147
commit 62cb292635
54 changed files with 1754 additions and 361 deletions

View File

@ -24,7 +24,7 @@ from common.constants import LLMType
from api.db.services.document_service import DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api.db.services.user_service import TenantService
from api.db.joint_services.tenant_model_service import get_tenant_default_model_by_type, get_model_config_by_id, get_model_config_by_type_and_name
from rag.graphrag.general.graph_extractor import GraphExtractor
from rag.graphrag.general.index import update_graph, with_resolution, with_community
from common import settings
@ -71,10 +71,14 @@ async def main():
)
]
_, tenant = TenantService.get_by_id(args.tenant_id)
llm_bdl = LLMBundle(args.tenant_id, LLMType.CHAT, tenant.llm_id)
llm_config = get_tenant_default_model_by_type(args.tenant_id, LLMType.CHAT)
llm_bdl = LLMBundle(args.tenant_id, llm_config)
_, kb = KnowledgebaseService.get_by_id(kb_id)
embed_bdl = LLMBundle(args.tenant_id, LLMType.EMBEDDING, kb.embd_id)
if kb.tenant_embd_id:
embd_model_config = get_model_config_by_id(kb.tenant_embd_id)
else:
embd_model_config = get_model_config_by_type_and_name(args.tenant_id, LLMType.EMBEDDING, kb.embd_id)
embed_bdl = LLMBundle(args.tenant_id, embd_model_config)
graph, doc_ids = await update_graph(
GraphExtractor,

View File

@ -24,7 +24,7 @@ from common.constants import LLMType
from api.db.services.document_service import DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api.db.services.user_service import TenantService
from api.db.joint_services.tenant_model_service import get_model_config_by_id, get_model_config_by_type_and_name, get_tenant_default_model_by_type
from rag.graphrag.general.index import update_graph
from rag.graphrag.light.graph_extractor import GraphExtractor
from common import settings
@ -72,10 +72,14 @@ async def main():
)
]
_, tenant = TenantService.get_by_id(args.tenant_id)
llm_bdl = LLMBundle(args.tenant_id, LLMType.CHAT, tenant.llm_id)
llm_config = get_tenant_default_model_by_type(args.tenant_id, LLMType.CHAT)
llm_bdl = LLMBundle(args.tenant_id, llm_config)
_, kb = KnowledgebaseService.get_by_id(kb_id)
embed_bdl = LLMBundle(args.tenant_id, LLMType.EMBEDDING, kb.embd_id)
if kb.tenant_embd_id:
embd_model_config = get_model_config_by_id(kb.tenant_embd_id)
else:
embd_model_config = get_model_config_by_type_and_name(args.tenant_id, LLMType.EMBEDDING, kb.embd_id)
embed_bdl = LLMBundle(args.tenant_id, embd_model_config)
graph, doc_ids = await update_graph(
GraphExtractor,

View File

@ -318,7 +318,7 @@ if __name__ == "__main__":
from common.constants import LLMType
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api.db.services.user_service import TenantService
from api.db.joint_services.tenant_model_service import get_tenant_default_model_by_type, get_model_config_by_id, get_model_config_by_type_and_name
from rag.nlp import search
settings.init_settings()
@ -329,10 +329,14 @@ if __name__ == "__main__":
args = parser.parse_args()
kb_id = args.kb_id
_, tenant = TenantService.get_by_id(args.tenant_id)
llm_bdl = LLMBundle(args.tenant_id, LLMType.CHAT, tenant.llm_id)
llm_config = get_tenant_default_model_by_type(args.tenant_id, LLMType.CHAT)
llm_bdl = LLMBundle(args.tenant_id, llm_config)
_, kb = KnowledgebaseService.get_by_id(kb_id)
embed_bdl = LLMBundle(args.tenant_id, LLMType.EMBEDDING, kb.embd_id)
if kb.tenant_embd_id:
embd_model_config = get_model_config_by_id(kb.tenant_embd_id)
else:
embd_model_config = get_model_config_by_type_and_name(args.tenant_id, LLMType.EMBEDDING, kb.embd_id)
embed_bdl = LLMBundle(args.tenant_id, embd_model_config)
kg = KGSearch(settings.docStoreConn)
print(asyncio.run(kg.retrieval({"question": args.question, "kb_ids": [kb_id]},