Feat/tenant model (#13072)

### What problem does this PR solve?

Add an `id` primary-key column to the `tenant_llm` table and use it in `LLMBundle` (models are now resolved by row id / config dict instead of by type and name).

### Type of change

- [x] Refactoring

---------

Co-authored-by: Yingfeng <yingfeng.zhang@gmail.com>
Co-authored-by: Liu An <asiro@qq.com>
This commit is contained in:
Lynn
2026-03-05 17:27:17 +08:00
committed by GitHub
parent 47540a4147
commit 62cb292635
54 changed files with 1754 additions and 361 deletions

View File

@ -34,6 +34,7 @@ from api.db.services.langfuse_service import TenantLangfuseService
from api.db.services.llm_service import LLMBundle
from common.metadata_utils import apply_meta_data_filter
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.joint_services.tenant_model_service import get_model_config_by_id, get_model_config_by_type_and_name, get_tenant_default_model_by_type
from common.time_utils import current_timestamp, datetime_format
from common.text_utils import normalize_arabic_digits
from rag.graphrag.general.mind_map_extractor import MindMapExtractor
@ -179,6 +180,28 @@ class DialogService(CommonService):
offset += limit
return res
@classmethod
@DB.connection_context()
def get_null_tenant_llm_id_row(cls):
    """Fetch rows whose ``tenant_llm_id`` is still NULL.

    Returns:
        list: rows carrying ``id``, ``tenant_id`` and the legacy ``llm_id``,
        i.e. the records that have not been linked to a tenant_llm record yet.
    """
    query = cls.model.select(
        cls.model.id,
        cls.model.tenant_id,
        cls.model.llm_id,
    ).where(cls.model.tenant_llm_id.is_null())
    return list(query)
@classmethod
@DB.connection_context()
def get_null_tenant_rerank_id_row(cls):
    """Fetch rows whose ``tenant_rerank_id`` is still NULL.

    Returns:
        list: rows carrying ``id``, ``tenant_id`` and the legacy ``rerank_id``,
        i.e. the records that have not been linked to a tenant_llm record yet.
    """
    query = cls.model.select(
        cls.model.id,
        cls.model.tenant_id,
        cls.model.rerank_id,
    ).where(cls.model.tenant_rerank_id.is_null())
    return list(query)
async def async_chat_solo(dialog, messages, stream=True):
llm_type = TenantLLMService.llm_id2llm_type(dialog.llm_id)
@ -191,22 +214,15 @@ async def async_chat_solo(dialog, messages, stream=True):
else:
text_attachments, image_files = split_file_attachments(messages[-1]["files"], raw=True)
attachments = "\n\n".join(text_attachments)
if llm_type == "image2text":
llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
else:
llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
factory = llm_model_config.get("llm_factory", "") if llm_model_config else ""
if llm_type == "image2text":
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
else:
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
model_config = get_model_config_by_id(dialog.tenant_llm_id)
chat_mdl = LLMBundle(dialog.tenant_id, model_config)
factory = model_config.get("llm_factory", "") if model_config else ""
prompt_config = dialog.prompt_config
tts_mdl = None
if prompt_config.get("tts"):
tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
default_tts_model = get_tenant_default_model_by_type(dialog.tenant_id, LLMType.TTS)
tts_mdl = LLMBundle(dialog.tenant_id, default_tts_model)
msg = [{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])} for m in messages if m["role"] != "system"]
if attachments and msg:
msg[-1]["content"] += attachments
@ -241,20 +257,27 @@ def get_models(dialog):
raise Exception("**ERROR**: Knowledge bases use different embedding models.")
if embedding_list:
embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embedding_list[0])
embd_model_config = get_model_config_by_type_and_name(dialog.tenant_id, LLMType.EMBEDDING, embedding_list[0])
embd_mdl = LLMBundle(dialog.tenant_id, embd_model_config)
if not embd_mdl:
raise LookupError("Embedding model(%s) not found" % embedding_list[0])
if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text":
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
if dialog.tenant_llm_id:
chat_model_config = get_model_config_by_id(dialog.tenant_llm_id)
elif dialog.llm_id:
chat_model_config = get_model_config_by_type_and_name(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
else:
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
chat_model_config = get_tenant_default_model_by_type(dialog.tenant_id, LLMType.CHAT)
chat_mdl = LLMBundle(dialog.tenant_id, chat_model_config)
if dialog.rerank_id:
rerank_mdl = LLMBundle(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id)
rerank_model_config = get_model_config_by_type_and_name(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id)
rerank_mdl = LLMBundle(dialog.tenant_id, rerank_model_config)
if dialog.prompt_config.get("tts"):
tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
default_tts_model_config = get_tenant_default_model_by_type(dialog.tenant_id, LLMType.TTS)
tts_mdl = LLMBundle(dialog.tenant_id, default_tts_model_config)
return kbs, embd_mdl, rerank_mdl, chat_mdl, tts_mdl
@ -603,8 +626,9 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
kbinfos["chunks"].extend(tav_res["chunks"])
kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
if prompt_config.get("use_kg"):
default_chat_model = get_tenant_default_model_by_type(dialog.tenant_id, LLMType.CHAT)
ck = await settings.kg_retriever.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl,
LLMBundle(dialog.tenant_id, LLMType.CHAT))
LLMBundle(dialog.tenant_id, default_chat_model))
if ck["content_with_weight"]:
kbinfos["chunks"].insert(0, ck)
@ -1341,11 +1365,13 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_conf
is_knowledge_graph = all([kb.parser_id == ParserType.KG for kb in kbs])
retriever = settings.retriever if not is_knowledge_graph else settings.kg_retriever
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embedding_list[0])
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, chat_llm_name)
embd_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.EMBEDDING, embedding_list[0])
embd_mdl = LLMBundle(tenant_id, embd_model_config)
chat_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.CHAT, chat_llm_name)
chat_mdl = LLMBundle(tenant_id, chat_model_config)
if rerank_id:
rerank_mdl = LLMBundle(tenant_id, LLMType.RERANK, rerank_id)
rerank_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.RERANK, rerank_id)
rerank_mdl = LLMBundle(tenant_id, rerank_model_config)
max_tokens = chat_mdl.max_length
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
@ -1417,13 +1443,22 @@ async def gen_mindmap(question, kb_ids, tenant_id, search_config={}):
kbs = KnowledgebaseService.get_by_ids(kb_ids)
if not kbs:
return {"error": "No KB selected"}
embedding_list = list(set([kb.embd_id for kb in kbs]))
tenant_embedding_list = list(set([kb.tenant_embd_id for kb in kbs]))
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, llm_name=embedding_list[0])
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
if tenant_embedding_list[0]:
embd_model_config = get_model_config_by_id(tenant_embedding_list[0])
else:
embd_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.EMBEDDING, kbs[0].embd_id)
embd_mdl = LLMBundle(tenant_id, embd_model_config)
chat_id = search_config.get("chat_id", "")
if chat_id:
chat_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.CHAT, chat_id)
else:
chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT)
chat_mdl = LLMBundle(tenant_id, chat_model_config)
if rerank_id:
rerank_mdl = LLMBundle(tenant_id, LLMType.RERANK, rerank_id)
rerank_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.RERANK, rerank_id)
rerank_mdl = LLMBundle(tenant_id, rerank_model_config)
if meta_data_filter:
metas = DocMetadataService.get_flatted_meta_by_kbs(kb_ids)

View File

@ -661,6 +661,19 @@ class DocumentService(CommonService):
return None
return docs[0]["embd_id"]
@classmethod
@DB.connection_context()
def get_tenant_embd_id(cls, doc_id):
    """Resolve the ``tenant_embd_id`` of the knowledge base owning a document.

    Args:
        doc_id: primary key of the document.

    Returns:
        The knowledge base's ``tenant_embd_id``, or ``None`` when the
        document is missing or its knowledge base is not in VALID status.
    """
    rows = (
        cls.model.select(Knowledgebase.tenant_embd_id)
        .join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id))
        .where(cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
        .dicts()
    )
    if not rows:
        return None
    return rows[0]["tenant_embd_id"]
@classmethod
@DB.connection_context()
def get_chunking_config(cls, doc_id):
@ -1007,6 +1020,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
from api.db.services.file_service import FileService
from api.db.services.llm_service import LLMBundle
from api.db.services.user_service import TenantService
from api.db.joint_services.tenant_model_service import get_model_config_by_id, get_model_config_by_type_and_name, get_tenant_default_model_by_type
from rag.app import audio, email, naive, picture, presentation
e, conv = ConversationService.get_by_id(conversation_id)
@ -1022,8 +1036,11 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
e, kb = KnowledgebaseService.get_by_id(kb_id)
if not e:
raise LookupError("Can't find this dataset!")
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id, lang=kb.language)
if kb.tenant_embd_id:
embd_model_config = get_model_config_by_id(kb.tenant_embd_id)
else:
embd_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.EMBEDDING, kb.embd_id)
embd_mdl = LLMBundle(kb.tenant_id, embd_model_config, lang=kb.language)
err, files = FileService.upload_document(kb, file_objs, user_id)
assert not err, "\n".join(err)
@ -1101,7 +1118,8 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
try_create_idx = True
_, tenant = TenantService.get_by_id(kb.tenant_id)
llm_bdl = LLMBundle(kb.tenant_id, LLMType.CHAT, tenant.llm_id)
tenant_llm_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT)
llm_bdl = LLMBundle(kb.tenant_id, tenant_llm_config)
for doc_id in docids:
cks = [c for c in docs if c["doc_id"] == doc_id]

View File

@ -564,3 +564,14 @@ class KnowledgebaseService(CommonService):
'update_date': datetime_format(datetime.now())
}
return cls.model.update(update_dict).where(cls.model.id == kb_id).execute()
@classmethod
@DB.connection_context()
def get_null_tenant_embd_id_row(cls):
    """Fetch rows whose ``tenant_embd_id`` is still NULL.

    Returns:
        list: rows carrying ``id``, ``tenant_id`` and the legacy ``embd_id``,
        i.e. the records that have not been linked to a tenant_llm record yet.
    """
    query = cls.model.select(
        cls.model.id,
        cls.model.tenant_id,
        cls.model.embd_id,
    ).where(cls.model.tenant_embd_id.is_null())
    return list(query)

View File

@ -83,18 +83,18 @@ def get_init_tenant_llm(user_id):
class LLMBundle(LLM4Tenant):
def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs):
super().__init__(tenant_id, llm_type, llm_name, lang, **kwargs)
def __init__(self, tenant_id: str, model_config: dict, lang="Chinese", **kwargs):
super().__init__(tenant_id, model_config, lang, **kwargs)
def bind_tools(self, toolcall_session, tools):
if not self.is_tools:
logging.warning(f"Model {self.llm_name} does not support tool call, but you have assigned one or more tools to it!")
logging.warning(f"Model {self.model_config['llm_name']} does not support tool call, but you have assigned one or more tools to it!")
return
self.mdl.bind_tools(toolcall_session, tools)
def encode(self, texts: list):
if self.langfuse:
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="encode", model=self.llm_name, input={"texts": texts})
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="encode", model=self.model_config["llm_name"], input={"texts": texts})
safe_texts = []
for text in texts:
@ -106,9 +106,9 @@ class LLMBundle(LLM4Tenant):
safe_texts.append(text)
embeddings, used_tokens = self.mdl.encode(safe_texts)
llm_name = getattr(self, "llm_name", None)
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, llm_name):
if self.model_config["llm_factory"] == "Builtin":
logging.info("LLMBundle.encode_queries query: {}, emd len: {}, used_tokens: {}. Builtin model don't need to update token usage".format(texts, len(embeddings), used_tokens))
elif not TenantLLMService.increase_usage_by_id(self.model_config["id"], used_tokens):
logging.error("LLMBundle.encode can't update token usage for <tenant redacted>/EMBEDDING used_tokens: {}".format(used_tokens))
if self.langfuse:
@ -119,11 +119,12 @@ class LLMBundle(LLM4Tenant):
def encode_queries(self, query: str):
if self.langfuse:
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="encode_queries", model=self.llm_name, input={"query": query})
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="encode_queries", model=self.model_config["llm_name"], input={"query": query})
emd, used_tokens = self.mdl.encode_queries(query)
llm_name = getattr(self, "llm_name", None)
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, llm_name):
if self.model_config["llm_factory"] == "Builtin":
logging.info("LLMBundle.encode_queries query: {}, emd len: {}, used_tokens: {}. Builtin model don't need to update token usage".format(query, len(emd), used_tokens))
elif not TenantLLMService.increase_usage_by_id(self.model_config["id"], used_tokens):
logging.error("LLMBundle.encode_queries can't update token usage for <tenant redacted>/EMBEDDING used_tokens: {}".format(used_tokens))
if self.langfuse:
@ -134,10 +135,10 @@ class LLMBundle(LLM4Tenant):
def similarity(self, query: str, texts: list):
if self.langfuse:
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="similarity", model=self.llm_name, input={"query": query, "texts": texts})
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="similarity", model=self.model_config["llm_name"], input={"query": query, "texts": texts})
sim, used_tokens = self.mdl.similarity(query, texts)
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
if not TenantLLMService.increase_usage_by_id(self.model_config["id"], used_tokens):
logging.error("LLMBundle.similarity can't update token usage for {}/RERANK used_tokens: {}".format(self.tenant_id, used_tokens))
if self.langfuse:
@ -148,10 +149,10 @@ class LLMBundle(LLM4Tenant):
def describe(self, image, max_tokens=300):
if self.langfuse:
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="describe", metadata={"model": self.llm_name})
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="describe", metadata={"model": self.model_config["llm_name"]})
txt, used_tokens = self.mdl.describe(image)
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
if not TenantLLMService.increase_usage_by_id(self.model_config["id"], used_tokens):
logging.error("LLMBundle.describe can't update token usage for {}/IMAGE2TEXT used_tokens: {}".format(self.tenant_id, used_tokens))
if self.langfuse:
@ -162,10 +163,10 @@ class LLMBundle(LLM4Tenant):
def describe_with_prompt(self, image, prompt):
if self.langfuse:
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="describe_with_prompt", metadata={"model": self.llm_name, "prompt": prompt})
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="describe_with_prompt", metadata={"model": self.model_config["llm_name"], "prompt": prompt})
txt, used_tokens = self.mdl.describe_with_prompt(image, prompt)
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
if not TenantLLMService.increase_usage_by_id(self.model_config["id"], used_tokens):
logging.error("LLMBundle.describe can't update token usage for {}/IMAGE2TEXT used_tokens: {}".format(self.tenant_id, used_tokens))
if self.langfuse:
@ -176,10 +177,10 @@ class LLMBundle(LLM4Tenant):
def transcription(self, audio):
if self.langfuse:
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="transcription", metadata={"model": self.llm_name})
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="transcription", metadata={"model": self.model_config["llm_name"]})
txt, used_tokens = self.mdl.transcription(audio)
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
if not TenantLLMService.increase_usage_by_id(self.model_config["id"], used_tokens):
logging.error("LLMBundle.transcription can't update token usage for {}/SEQUENCE2TXT used_tokens: {}".format(self.tenant_id, used_tokens))
if self.langfuse:
@ -196,7 +197,7 @@ class LLMBundle(LLM4Tenant):
generation = self.langfuse.start_generation(
trace_context=self.trace_context,
name="stream_transcription",
metadata={"model": self.llm_name},
metadata={"model": self.model_config["llm_name"]},
)
final_text = ""
used_tokens = 0
@ -215,7 +216,7 @@ class LLMBundle(LLM4Tenant):
finally:
if final_text:
used_tokens = num_tokens_from_string(final_text)
TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens)
TenantLLMService.increase_usage_by_id(self.model_config["id"], used_tokens)
if self.langfuse:
generation.update(
@ -230,11 +231,11 @@ class LLMBundle(LLM4Tenant):
generation = self.langfuse.start_generation(
trace_context=self.trace_context,
name="stream_transcription",
metadata={"model": self.llm_name},
metadata={"model": self.model_config["llm_name"]},
)
full_text, used_tokens = mdl.transcription(audio)
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
if not TenantLLMService.increase_usage_by_id(self.model_config["id"], used_tokens):
logging.error(f"LLMBundle.stream_transcription can't update token usage for {self.tenant_id}/SEQUENCE2TXT used_tokens: {used_tokens}")
if self.langfuse:
@ -256,7 +257,7 @@ class LLMBundle(LLM4Tenant):
for chunk in self.mdl.tts(text):
if isinstance(chunk, int):
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, chunk, self.llm_name):
if not TenantLLMService.increase_usage_by_id(self.model_config["id"], chunk):
logging.error("LLMBundle.tts can't update token usage for {}/TTS".format(self.tenant_id))
return
yield chunk
@ -375,7 +376,7 @@ class LLMBundle(LLM4Tenant):
generation = None
if self.langfuse:
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat", model=self.llm_name, input={"system": system, "history": history})
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat", model=self.model_config["llm_name"], input={"system": system, "history": history})
chat_partial = partial(base_fn, system, history, gen_conf)
use_kwargs = self._clean_param(chat_partial, **kwargs)
@ -392,8 +393,8 @@ class LLMBundle(LLM4Tenant):
if not self.verbose_tool_use:
txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
if used_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
logging.error("LLMBundle.async_chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))
if used_tokens and not TenantLLMService.increase_usage_by_id(self.model_config["id"], used_tokens):
logging.error("LLMBundle.async_chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.model_config["llm_name"], used_tokens))
if generation:
generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens})
@ -413,7 +414,7 @@ class LLMBundle(LLM4Tenant):
generation = None
if self.langfuse:
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat_streamly", model=self.llm_name, input={"system": system, "history": history})
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat_streamly", model=self.model_config["llm_name"], input={"system": system, "history": history})
if stream_fn:
chat_partial = partial(stream_fn, system, history, gen_conf)
@ -437,8 +438,8 @@ class LLMBundle(LLM4Tenant):
generation.update(output={"error": str(e)})
generation.end()
raise
if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
if total_tokens and not TenantLLMService.increase_usage_by_id(self.model_config["id"], total_tokens):
logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.model_config["llm_name"], total_tokens))
if generation:
generation.update(output={"output": ans}, usage_details={"total_tokens": total_tokens})
generation.end()
@ -456,7 +457,7 @@ class LLMBundle(LLM4Tenant):
generation = None
if self.langfuse:
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat_streamly", model=self.llm_name, input={"system": system, "history": history})
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat_streamly", model=self.model_config["llm_name"], input={"system": system, "history": history})
if stream_fn:
chat_partial = partial(stream_fn, system, history, gen_conf)
@ -480,8 +481,8 @@ class LLMBundle(LLM4Tenant):
generation.update(output={"error": str(e)})
generation.end()
raise
if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
if total_tokens and not TenantLLMService.increase_usage_by_id(self.model_config["id"], total_tokens):
logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.model_config["llm_name"], total_tokens))
if generation:
generation.update(output={"output": ans}, usage_details={"total_tokens": total_tokens})
generation.end()

View File

@ -107,7 +107,7 @@ class MemoryService(CommonService):
@classmethod
@DB.connection_context()
def create_memory(cls, tenant_id: str, name: str, memory_type: List[str], embd_id: str, llm_id: str):
def create_memory(cls, tenant_id: str, name: str, memory_type: List[str], embd_id: str, tenant_embd_id: int, llm_id: str, tenant_llm_id: int):
# Deduplicate name within tenant
memory_name = duplicate_name(
cls.query,
@ -126,7 +126,9 @@ class MemoryService(CommonService):
"memory_type": calculate_memory_type(memory_type),
"tenant_id": tenant_id,
"embd_id": embd_id,
"tenant_embd_id": tenant_embd_id,
"llm_id": llm_id,
"tenant_llm_id": tenant_llm_id,
"system_prompt": PromptAssembler.assemble_system_prompt({"memory_type": memory_type}),
"create_time": timestamp,
"create_date": format_time,
@ -168,3 +170,25 @@ class MemoryService(CommonService):
@DB.connection_context()
def delete_memory(cls, memory_id: str):
return cls.delete_by_id(memory_id)
@classmethod
@DB.connection_context()
def get_null_tenant_embd_id_row(cls):
    """Fetch rows whose ``tenant_embd_id`` is still NULL.

    Returns:
        list: rows carrying ``id``, ``tenant_id`` and the legacy ``embd_id``,
        i.e. the records that have not been linked to a tenant_llm record yet.
    """
    query = cls.model.select(
        cls.model.id,
        cls.model.tenant_id,
        cls.model.embd_id,
    ).where(cls.model.tenant_embd_id.is_null())
    return list(query)
@classmethod
@DB.connection_context()
def get_null_tenant_llm_id_row(cls):
    """Fetch rows whose ``tenant_llm_id`` is still NULL.

    Returns:
        list: rows carrying ``id``, ``tenant_id`` and the legacy ``llm_id``,
        i.e. the records that have not been linked to a tenant_llm record yet.
    """
    query = cls.model.select(
        cls.model.id,
        cls.model.tenant_id,
        cls.model.llm_id,
    ).where(cls.model.tenant_llm_id.is_null())
    return list(query)

View File

@ -60,7 +60,7 @@ class TenantLLMService(CommonService):
@classmethod
@DB.connection_context()
def get_my_llms(cls, tenant_id):
fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name, cls.model.used_tokens, cls.model.status]
fields = [cls.model.id, cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name, cls.model.used_tokens, cls.model.status]
objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(cls.model.tenant_id == tenant_id, ~cls.model.api_key.is_null()).dicts()
return list(objs)
@ -133,34 +133,35 @@ class TenantLLMService(CommonService):
@classmethod
@DB.connection_context()
def model_instance(cls, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs):
model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
def model_instance(cls, model_config: dict, lang="Chinese", **kwargs):
if not model_config:
raise LookupError("Model config is required")
kwargs.update({"provider": model_config["llm_factory"]})
if llm_type == LLMType.EMBEDDING.value:
if model_config["model_type"] == LLMType.EMBEDDING.value:
if model_config["llm_factory"] not in EmbeddingModel:
return None
return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
elif llm_type == LLMType.RERANK:
elif model_config["model_type"] == LLMType.RERANK:
if model_config["llm_factory"] not in RerankModel:
return None
return RerankModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
elif llm_type == LLMType.IMAGE2TEXT.value:
elif model_config["model_type"] == LLMType.IMAGE2TEXT.value:
if model_config["llm_factory"] not in CvModel:
return None
return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang, base_url=model_config["api_base"], **kwargs)
elif llm_type == LLMType.CHAT.value:
elif model_config["model_type"] == LLMType.CHAT.value:
if model_config["llm_factory"] not in ChatModel:
return None
return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"], **kwargs)
elif llm_type == LLMType.SPEECH2TEXT:
elif model_config["model_type"] == LLMType.SPEECH2TEXT:
if model_config["llm_factory"] not in Seq2txtModel:
return None
return Seq2txtModel[model_config["llm_factory"]](key=model_config["api_key"], model_name=model_config["llm_name"], lang=lang, base_url=model_config["api_base"])
elif llm_type == LLMType.TTS:
elif model_config["model_type"] == LLMType.TTS:
if model_config["llm_factory"] not in TTSModel:
return None
return TTSModel[model_config["llm_factory"]](
@ -169,7 +170,7 @@ class TenantLLMService(CommonService):
base_url=model_config["api_base"],
)
elif llm_type == LLMType.OCR:
elif model_config["model_type"] == LLMType.OCR:
if model_config["llm_factory"] not in OcrModel:
return None
return OcrModel[model_config["llm_factory"]](
@ -218,6 +219,16 @@ class TenantLLMService(CommonService):
return num
@classmethod
@DB.connection_context()
def increase_usage_by_id(cls, tenant_model_id: int, used_tokens: int):
    """Add ``used_tokens`` to the usage counter of one tenant model row.

    The increment is pushed into the UPDATE statement
    (``used_tokens = used_tokens + n``) so concurrent callers do not lose
    updates to a read-modify-write race.

    Args:
        tenant_model_id: primary key (``id``) of the tenant_llm row.
        used_tokens: number of tokens to add.

    Returns:
        int: number of rows updated (0 when the id does not exist or the
        update failed with a database error).
    """
    try:
        update_cnt = (
            cls.model.update(used_tokens=cls.model.used_tokens + used_tokens)
            .where(cls.model.id == tenant_model_id)
            .execute()
        )
    except Exception as e:
        # Fixed: the message previously named increase_usage, which made logs
        # point at the wrong method during incident debugging.
        logging.exception(f"TenantLLMService.increase_usage_by_id got exception {e}, Failed to update used_tokens for tenant_model_id {tenant_model_id}")
        return 0
    return update_cnt
@classmethod
@DB.connection_context()
def get_openai_models(cls):
@ -376,13 +387,12 @@ class TenantLLMService(CommonService):
class LLM4Tenant:
def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs):
def __init__(self, tenant_id: str, model_config: dict, lang="Chinese", **kwargs):
self.tenant_id = tenant_id
self.llm_type = llm_type
self.llm_name = llm_name
self.mdl = TenantLLMService.model_instance(tenant_id, llm_type, llm_name, lang=lang, **kwargs)
assert self.mdl, "Can't find model for {}/{}/{}".format(tenant_id, llm_type, llm_name)
model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
self.llm_name = model_config["llm_name"]
self.model_config = model_config
self.mdl = TenantLLMService.model_instance(model_config, lang=lang, **kwargs)
assert self.mdl, "Can't find model for {}/{}/{}".format(tenant_id, model_config["llm_type"], model_config["llm_name"])
self.max_length = model_config.get("max_tokens", 8192)
self.is_tools = model_config.get("is_tools", False)

View File

@ -226,6 +226,12 @@ class TenantService(CommonService):
hash_obj = hashlib.sha256(tenant_id.encode("utf-8"))
return int(hash_obj.hexdigest(), 16)%len(settings.MINIO)
@classmethod
@DB.connection_context()
def get_null_tenant_model_id_rows(cls):
    """Fetch tenants where ANY of the new tenant model references is NULL.

    Bug fix: the previous implementation passed all six expressions to
    ``orwhere(...)``. In peewee, multiple expressions given to a single
    ``where``/``orwhere`` call are joined with AND (the OR only applies
    against a pre-existing WHERE clause, absent here), so it matched only
    tenants with ALL references NULL. The backfill needs rows with at
    least one unresolved reference, hence the explicit OR chain.

    Returns:
        list: tenant model instances with at least one NULL reference.
    """
    condition = (
        cls.model.tenant_llm_id.is_null()
        | cls.model.tenant_embd_id.is_null()
        | cls.model.tenant_asr_id.is_null()
        | cls.model.tenant_tts_id.is_null()
        | cls.model.tenant_rerank_id.is_null()
        | cls.model.tenant_img2txt_id.is_null()
    )
    objs = cls.model.select().where(condition)
    return list(objs)
class UserTenantService(CommonService):
"""Service class for managing user-tenant relationship operations.