mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-05-05 09:47:47 +08:00
Feat/tenant model (#13072)
### What problem does this PR solve? Add id for table tenant_llm and apply in LLMBundle. ### Type of change - [x] Refactoring --------- Co-authored-by: Yingfeng <yingfeng.zhang@gmail.com> Co-authored-by: Liu An <asiro@qq.com>
This commit is contained in:
@ -34,6 +34,7 @@ from api.db.services.langfuse_service import TenantLangfuseService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from common.metadata_utils import apply_meta_data_filter
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.db.joint_services.tenant_model_service import get_model_config_by_id, get_model_config_by_type_and_name, get_tenant_default_model_by_type
|
||||
from common.time_utils import current_timestamp, datetime_format
|
||||
from common.text_utils import normalize_arabic_digits
|
||||
from rag.graphrag.general.mind_map_extractor import MindMapExtractor
|
||||
@ -179,6 +180,28 @@ class DialogService(CommonService):
|
||||
offset += limit
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_null_tenant_llm_id_row(cls):
|
||||
fields = [
|
||||
cls.model.id,
|
||||
cls.model.tenant_id,
|
||||
cls.model.llm_id
|
||||
]
|
||||
objs = cls.model.select(*fields).where(cls.model.tenant_llm_id.is_null())
|
||||
return list(objs)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_null_tenant_rerank_id_row(cls):
|
||||
fields = [
|
||||
cls.model.id,
|
||||
cls.model.tenant_id,
|
||||
cls.model.rerank_id
|
||||
]
|
||||
objs = cls.model.select(*fields).where(cls.model.tenant_rerank_id.is_null())
|
||||
return list(objs)
|
||||
|
||||
|
||||
async def async_chat_solo(dialog, messages, stream=True):
|
||||
llm_type = TenantLLMService.llm_id2llm_type(dialog.llm_id)
|
||||
@ -191,22 +214,15 @@ async def async_chat_solo(dialog, messages, stream=True):
|
||||
else:
|
||||
text_attachments, image_files = split_file_attachments(messages[-1]["files"], raw=True)
|
||||
attachments = "\n\n".join(text_attachments)
|
||||
|
||||
if llm_type == "image2text":
|
||||
llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
|
||||
else:
|
||||
llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
|
||||
factory = llm_model_config.get("llm_factory", "") if llm_model_config else ""
|
||||
|
||||
if llm_type == "image2text":
|
||||
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
|
||||
else:
|
||||
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
|
||||
model_config = get_model_config_by_id(dialog.tenant_llm_id)
|
||||
chat_mdl = LLMBundle(dialog.tenant_id, model_config)
|
||||
factory = model_config.get("llm_factory", "") if model_config else ""
|
||||
|
||||
prompt_config = dialog.prompt_config
|
||||
tts_mdl = None
|
||||
if prompt_config.get("tts"):
|
||||
tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
|
||||
default_tts_model = get_tenant_default_model_by_type(dialog.tenant_id, LLMType.TTS)
|
||||
tts_mdl = LLMBundle(dialog.tenant_id, default_tts_model)
|
||||
msg = [{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])} for m in messages if m["role"] != "system"]
|
||||
if attachments and msg:
|
||||
msg[-1]["content"] += attachments
|
||||
@ -241,20 +257,27 @@ def get_models(dialog):
|
||||
raise Exception("**ERROR**: Knowledge bases use different embedding models.")
|
||||
|
||||
if embedding_list:
|
||||
embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embedding_list[0])
|
||||
embd_model_config = get_model_config_by_type_and_name(dialog.tenant_id, LLMType.EMBEDDING, embedding_list[0])
|
||||
embd_mdl = LLMBundle(dialog.tenant_id, embd_model_config)
|
||||
if not embd_mdl:
|
||||
raise LookupError("Embedding model(%s) not found" % embedding_list[0])
|
||||
|
||||
if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text":
|
||||
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
|
||||
if dialog.tenant_llm_id:
|
||||
chat_model_config = get_model_config_by_id(dialog.tenant_llm_id)
|
||||
elif dialog.llm_id:
|
||||
chat_model_config = get_model_config_by_type_and_name(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
|
||||
else:
|
||||
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
|
||||
chat_model_config = get_tenant_default_model_by_type(dialog.tenant_id, LLMType.CHAT)
|
||||
|
||||
chat_mdl = LLMBundle(dialog.tenant_id, chat_model_config)
|
||||
|
||||
if dialog.rerank_id:
|
||||
rerank_mdl = LLMBundle(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id)
|
||||
rerank_model_config = get_model_config_by_type_and_name(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id)
|
||||
rerank_mdl = LLMBundle(dialog.tenant_id, rerank_model_config)
|
||||
|
||||
if dialog.prompt_config.get("tts"):
|
||||
tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
|
||||
default_tts_model_config = get_tenant_default_model_by_type(dialog.tenant_id, LLMType.TTS)
|
||||
tts_mdl = LLMBundle(dialog.tenant_id, default_tts_model_config)
|
||||
return kbs, embd_mdl, rerank_mdl, chat_mdl, tts_mdl
|
||||
|
||||
|
||||
@ -603,8 +626,9 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
|
||||
kbinfos["chunks"].extend(tav_res["chunks"])
|
||||
kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
|
||||
if prompt_config.get("use_kg"):
|
||||
default_chat_model = get_tenant_default_model_by_type(dialog.tenant_id, LLMType.CHAT)
|
||||
ck = await settings.kg_retriever.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl,
|
||||
LLMBundle(dialog.tenant_id, LLMType.CHAT))
|
||||
LLMBundle(dialog.tenant_id, default_chat_model))
|
||||
if ck["content_with_weight"]:
|
||||
kbinfos["chunks"].insert(0, ck)
|
||||
|
||||
@ -1341,11 +1365,13 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_conf
|
||||
|
||||
is_knowledge_graph = all([kb.parser_id == ParserType.KG for kb in kbs])
|
||||
retriever = settings.retriever if not is_knowledge_graph else settings.kg_retriever
|
||||
|
||||
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embedding_list[0])
|
||||
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, chat_llm_name)
|
||||
embd_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.EMBEDDING, embedding_list[0])
|
||||
embd_mdl = LLMBundle(tenant_id, embd_model_config)
|
||||
chat_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.CHAT, chat_llm_name)
|
||||
chat_mdl = LLMBundle(tenant_id, chat_model_config)
|
||||
if rerank_id:
|
||||
rerank_mdl = LLMBundle(tenant_id, LLMType.RERANK, rerank_id)
|
||||
rerank_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.RERANK, rerank_id)
|
||||
rerank_mdl = LLMBundle(tenant_id, rerank_model_config)
|
||||
max_tokens = chat_mdl.max_length
|
||||
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
|
||||
|
||||
@ -1417,13 +1443,22 @@ async def gen_mindmap(question, kb_ids, tenant_id, search_config={}):
|
||||
kbs = KnowledgebaseService.get_by_ids(kb_ids)
|
||||
if not kbs:
|
||||
return {"error": "No KB selected"}
|
||||
embedding_list = list(set([kb.embd_id for kb in kbs]))
|
||||
tenant_embedding_list = list(set([kb.tenant_embd_id for kb in kbs]))
|
||||
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
|
||||
|
||||
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, llm_name=embedding_list[0])
|
||||
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
|
||||
if tenant_embedding_list[0]:
|
||||
embd_model_config = get_model_config_by_id(tenant_embedding_list[0])
|
||||
else:
|
||||
embd_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.EMBEDDING, kbs[0].embd_id)
|
||||
embd_mdl = LLMBundle(tenant_id, embd_model_config)
|
||||
chat_id = search_config.get("chat_id", "")
|
||||
if chat_id:
|
||||
chat_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.CHAT, chat_id)
|
||||
else:
|
||||
chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT)
|
||||
chat_mdl = LLMBundle(tenant_id, chat_model_config)
|
||||
if rerank_id:
|
||||
rerank_mdl = LLMBundle(tenant_id, LLMType.RERANK, rerank_id)
|
||||
rerank_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.RERANK, rerank_id)
|
||||
rerank_mdl = LLMBundle(tenant_id, rerank_model_config)
|
||||
|
||||
if meta_data_filter:
|
||||
metas = DocMetadataService.get_flatted_meta_by_kbs(kb_ids)
|
||||
|
||||
@ -661,6 +661,19 @@ class DocumentService(CommonService):
|
||||
return None
|
||||
return docs[0]["embd_id"]
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_tenant_embd_id(cls, doc_id):
|
||||
docs = cls.model.select(
|
||||
Knowledgebase.tenant_embd_id).join(
|
||||
Knowledgebase, on=(
|
||||
Knowledgebase.id == cls.model.kb_id)).where(
|
||||
cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
|
||||
docs = docs.dicts()
|
||||
if not docs:
|
||||
return None
|
||||
return docs[0]["tenant_embd_id"]
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_chunking_config(cls, doc_id):
|
||||
@ -1007,6 +1020,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.user_service import TenantService
|
||||
from api.db.joint_services.tenant_model_service import get_model_config_by_id, get_model_config_by_type_and_name, get_tenant_default_model_by_type
|
||||
from rag.app import audio, email, naive, picture, presentation
|
||||
|
||||
e, conv = ConversationService.get_by_id(conversation_id)
|
||||
@ -1022,8 +1036,11 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
||||
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
if not e:
|
||||
raise LookupError("Can't find this dataset!")
|
||||
|
||||
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id, lang=kb.language)
|
||||
if kb.tenant_embd_id:
|
||||
embd_model_config = get_model_config_by_id(kb.tenant_embd_id)
|
||||
else:
|
||||
embd_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.EMBEDDING, kb.embd_id)
|
||||
embd_mdl = LLMBundle(kb.tenant_id, embd_model_config, lang=kb.language)
|
||||
|
||||
err, files = FileService.upload_document(kb, file_objs, user_id)
|
||||
assert not err, "\n".join(err)
|
||||
@ -1101,7 +1118,8 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
||||
try_create_idx = True
|
||||
|
||||
_, tenant = TenantService.get_by_id(kb.tenant_id)
|
||||
llm_bdl = LLMBundle(kb.tenant_id, LLMType.CHAT, tenant.llm_id)
|
||||
tenant_llm_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT)
|
||||
llm_bdl = LLMBundle(kb.tenant_id, tenant_llm_config)
|
||||
for doc_id in docids:
|
||||
cks = [c for c in docs if c["doc_id"] == doc_id]
|
||||
|
||||
|
||||
@ -564,3 +564,14 @@ class KnowledgebaseService(CommonService):
|
||||
'update_date': datetime_format(datetime.now())
|
||||
}
|
||||
return cls.model.update(update_dict).where(cls.model.id == kb_id).execute()
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_null_tenant_embd_id_row(cls):
|
||||
fields = [
|
||||
cls.model.id,
|
||||
cls.model.tenant_id,
|
||||
cls.model.embd_id
|
||||
]
|
||||
objs = cls.model.select(*fields).where(cls.model.tenant_embd_id.is_null())
|
||||
return list(objs)
|
||||
|
||||
@ -83,18 +83,18 @@ def get_init_tenant_llm(user_id):
|
||||
|
||||
|
||||
class LLMBundle(LLM4Tenant):
|
||||
def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs):
|
||||
super().__init__(tenant_id, llm_type, llm_name, lang, **kwargs)
|
||||
def __init__(self, tenant_id: str, model_config: dict, lang="Chinese", **kwargs):
|
||||
super().__init__(tenant_id, model_config, lang, **kwargs)
|
||||
|
||||
def bind_tools(self, toolcall_session, tools):
|
||||
if not self.is_tools:
|
||||
logging.warning(f"Model {self.llm_name} does not support tool call, but you have assigned one or more tools to it!")
|
||||
logging.warning(f"Model {self.model_config['llm_name']} does not support tool call, but you have assigned one or more tools to it!")
|
||||
return
|
||||
self.mdl.bind_tools(toolcall_session, tools)
|
||||
|
||||
def encode(self, texts: list):
|
||||
if self.langfuse:
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="encode", model=self.llm_name, input={"texts": texts})
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="encode", model=self.model_config["llm_name"], input={"texts": texts})
|
||||
|
||||
safe_texts = []
|
||||
for text in texts:
|
||||
@ -106,9 +106,9 @@ class LLMBundle(LLM4Tenant):
|
||||
safe_texts.append(text)
|
||||
|
||||
embeddings, used_tokens = self.mdl.encode(safe_texts)
|
||||
|
||||
llm_name = getattr(self, "llm_name", None)
|
||||
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, llm_name):
|
||||
if self.model_config["llm_factory"] == "Builtin":
|
||||
logging.info("LLMBundle.encode_queries query: {}, emd len: {}, used_tokens: {}. Builtin model don't need to update token usage".format(texts, len(embeddings), used_tokens))
|
||||
elif not TenantLLMService.increase_usage_by_id(self.model_config["id"], used_tokens):
|
||||
logging.error("LLMBundle.encode can't update token usage for <tenant redacted>/EMBEDDING used_tokens: {}".format(used_tokens))
|
||||
|
||||
if self.langfuse:
|
||||
@ -119,11 +119,12 @@ class LLMBundle(LLM4Tenant):
|
||||
|
||||
def encode_queries(self, query: str):
|
||||
if self.langfuse:
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="encode_queries", model=self.llm_name, input={"query": query})
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="encode_queries", model=self.model_config["llm_name"], input={"query": query})
|
||||
|
||||
emd, used_tokens = self.mdl.encode_queries(query)
|
||||
llm_name = getattr(self, "llm_name", None)
|
||||
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, llm_name):
|
||||
if self.model_config["llm_factory"] == "Builtin":
|
||||
logging.info("LLMBundle.encode_queries query: {}, emd len: {}, used_tokens: {}. Builtin model don't need to update token usage".format(query, len(emd), used_tokens))
|
||||
elif not TenantLLMService.increase_usage_by_id(self.model_config["id"], used_tokens):
|
||||
logging.error("LLMBundle.encode_queries can't update token usage for <tenant redacted>/EMBEDDING used_tokens: {}".format(used_tokens))
|
||||
|
||||
if self.langfuse:
|
||||
@ -134,10 +135,10 @@ class LLMBundle(LLM4Tenant):
|
||||
|
||||
def similarity(self, query: str, texts: list):
|
||||
if self.langfuse:
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="similarity", model=self.llm_name, input={"query": query, "texts": texts})
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="similarity", model=self.model_config["llm_name"], input={"query": query, "texts": texts})
|
||||
|
||||
sim, used_tokens = self.mdl.similarity(query, texts)
|
||||
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
|
||||
if not TenantLLMService.increase_usage_by_id(self.model_config["id"], used_tokens):
|
||||
logging.error("LLMBundle.similarity can't update token usage for {}/RERANK used_tokens: {}".format(self.tenant_id, used_tokens))
|
||||
|
||||
if self.langfuse:
|
||||
@ -148,10 +149,10 @@ class LLMBundle(LLM4Tenant):
|
||||
|
||||
def describe(self, image, max_tokens=300):
|
||||
if self.langfuse:
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="describe", metadata={"model": self.llm_name})
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="describe", metadata={"model": self.model_config["llm_name"]})
|
||||
|
||||
txt, used_tokens = self.mdl.describe(image)
|
||||
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
|
||||
if not TenantLLMService.increase_usage_by_id(self.model_config["id"], used_tokens):
|
||||
logging.error("LLMBundle.describe can't update token usage for {}/IMAGE2TEXT used_tokens: {}".format(self.tenant_id, used_tokens))
|
||||
|
||||
if self.langfuse:
|
||||
@ -162,10 +163,10 @@ class LLMBundle(LLM4Tenant):
|
||||
|
||||
def describe_with_prompt(self, image, prompt):
|
||||
if self.langfuse:
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="describe_with_prompt", metadata={"model": self.llm_name, "prompt": prompt})
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="describe_with_prompt", metadata={"model": self.model_config["llm_name"], "prompt": prompt})
|
||||
|
||||
txt, used_tokens = self.mdl.describe_with_prompt(image, prompt)
|
||||
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
|
||||
if not TenantLLMService.increase_usage_by_id(self.model_config["id"], used_tokens):
|
||||
logging.error("LLMBundle.describe can't update token usage for {}/IMAGE2TEXT used_tokens: {}".format(self.tenant_id, used_tokens))
|
||||
|
||||
if self.langfuse:
|
||||
@ -176,10 +177,10 @@ class LLMBundle(LLM4Tenant):
|
||||
|
||||
def transcription(self, audio):
|
||||
if self.langfuse:
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="transcription", metadata={"model": self.llm_name})
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="transcription", metadata={"model": self.model_config["llm_name"]})
|
||||
|
||||
txt, used_tokens = self.mdl.transcription(audio)
|
||||
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
|
||||
if not TenantLLMService.increase_usage_by_id(self.model_config["id"], used_tokens):
|
||||
logging.error("LLMBundle.transcription can't update token usage for {}/SEQUENCE2TXT used_tokens: {}".format(self.tenant_id, used_tokens))
|
||||
|
||||
if self.langfuse:
|
||||
@ -196,7 +197,7 @@ class LLMBundle(LLM4Tenant):
|
||||
generation = self.langfuse.start_generation(
|
||||
trace_context=self.trace_context,
|
||||
name="stream_transcription",
|
||||
metadata={"model": self.llm_name},
|
||||
metadata={"model": self.model_config["llm_name"]},
|
||||
)
|
||||
final_text = ""
|
||||
used_tokens = 0
|
||||
@ -215,7 +216,7 @@ class LLMBundle(LLM4Tenant):
|
||||
finally:
|
||||
if final_text:
|
||||
used_tokens = num_tokens_from_string(final_text)
|
||||
TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens)
|
||||
TenantLLMService.increase_usage_by_id(self.model_config["id"], used_tokens)
|
||||
|
||||
if self.langfuse:
|
||||
generation.update(
|
||||
@ -230,11 +231,11 @@ class LLMBundle(LLM4Tenant):
|
||||
generation = self.langfuse.start_generation(
|
||||
trace_context=self.trace_context,
|
||||
name="stream_transcription",
|
||||
metadata={"model": self.llm_name},
|
||||
metadata={"model": self.model_config["llm_name"]},
|
||||
)
|
||||
|
||||
full_text, used_tokens = mdl.transcription(audio)
|
||||
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
|
||||
if not TenantLLMService.increase_usage_by_id(self.model_config["id"], used_tokens):
|
||||
logging.error(f"LLMBundle.stream_transcription can't update token usage for {self.tenant_id}/SEQUENCE2TXT used_tokens: {used_tokens}")
|
||||
|
||||
if self.langfuse:
|
||||
@ -256,7 +257,7 @@ class LLMBundle(LLM4Tenant):
|
||||
|
||||
for chunk in self.mdl.tts(text):
|
||||
if isinstance(chunk, int):
|
||||
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, chunk, self.llm_name):
|
||||
if not TenantLLMService.increase_usage_by_id(self.model_config["id"], chunk):
|
||||
logging.error("LLMBundle.tts can't update token usage for {}/TTS".format(self.tenant_id))
|
||||
return
|
||||
yield chunk
|
||||
@ -375,7 +376,7 @@ class LLMBundle(LLM4Tenant):
|
||||
|
||||
generation = None
|
||||
if self.langfuse:
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat", model=self.llm_name, input={"system": system, "history": history})
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat", model=self.model_config["llm_name"], input={"system": system, "history": history})
|
||||
|
||||
chat_partial = partial(base_fn, system, history, gen_conf)
|
||||
use_kwargs = self._clean_param(chat_partial, **kwargs)
|
||||
@ -392,8 +393,8 @@ class LLMBundle(LLM4Tenant):
|
||||
if not self.verbose_tool_use:
|
||||
txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
|
||||
|
||||
if used_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
|
||||
logging.error("LLMBundle.async_chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))
|
||||
if used_tokens and not TenantLLMService.increase_usage_by_id(self.model_config["id"], used_tokens):
|
||||
logging.error("LLMBundle.async_chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.model_config["llm_name"], used_tokens))
|
||||
|
||||
if generation:
|
||||
generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens})
|
||||
@ -413,7 +414,7 @@ class LLMBundle(LLM4Tenant):
|
||||
|
||||
generation = None
|
||||
if self.langfuse:
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat_streamly", model=self.llm_name, input={"system": system, "history": history})
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat_streamly", model=self.model_config["llm_name"], input={"system": system, "history": history})
|
||||
|
||||
if stream_fn:
|
||||
chat_partial = partial(stream_fn, system, history, gen_conf)
|
||||
@ -437,8 +438,8 @@ class LLMBundle(LLM4Tenant):
|
||||
generation.update(output={"error": str(e)})
|
||||
generation.end()
|
||||
raise
|
||||
if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
|
||||
logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
|
||||
if total_tokens and not TenantLLMService.increase_usage_by_id(self.model_config["id"], total_tokens):
|
||||
logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.model_config["llm_name"], total_tokens))
|
||||
if generation:
|
||||
generation.update(output={"output": ans}, usage_details={"total_tokens": total_tokens})
|
||||
generation.end()
|
||||
@ -456,7 +457,7 @@ class LLMBundle(LLM4Tenant):
|
||||
|
||||
generation = None
|
||||
if self.langfuse:
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat_streamly", model=self.llm_name, input={"system": system, "history": history})
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat_streamly", model=self.model_config["llm_name"], input={"system": system, "history": history})
|
||||
|
||||
if stream_fn:
|
||||
chat_partial = partial(stream_fn, system, history, gen_conf)
|
||||
@ -480,8 +481,8 @@ class LLMBundle(LLM4Tenant):
|
||||
generation.update(output={"error": str(e)})
|
||||
generation.end()
|
||||
raise
|
||||
if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
|
||||
logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
|
||||
if total_tokens and not TenantLLMService.increase_usage_by_id(self.model_config["id"], total_tokens):
|
||||
logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.model_config["llm_name"], total_tokens))
|
||||
if generation:
|
||||
generation.update(output={"output": ans}, usage_details={"total_tokens": total_tokens})
|
||||
generation.end()
|
||||
|
||||
@ -107,7 +107,7 @@ class MemoryService(CommonService):
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def create_memory(cls, tenant_id: str, name: str, memory_type: List[str], embd_id: str, llm_id: str):
|
||||
def create_memory(cls, tenant_id: str, name: str, memory_type: List[str], embd_id: str, tenant_embd_id: int, llm_id: str, tenant_llm_id: int):
|
||||
# Deduplicate name within tenant
|
||||
memory_name = duplicate_name(
|
||||
cls.query,
|
||||
@ -126,7 +126,9 @@ class MemoryService(CommonService):
|
||||
"memory_type": calculate_memory_type(memory_type),
|
||||
"tenant_id": tenant_id,
|
||||
"embd_id": embd_id,
|
||||
"tenant_embd_id": tenant_embd_id,
|
||||
"llm_id": llm_id,
|
||||
"tenant_llm_id": tenant_llm_id,
|
||||
"system_prompt": PromptAssembler.assemble_system_prompt({"memory_type": memory_type}),
|
||||
"create_time": timestamp,
|
||||
"create_date": format_time,
|
||||
@ -168,3 +170,25 @@ class MemoryService(CommonService):
|
||||
@DB.connection_context()
|
||||
def delete_memory(cls, memory_id: str):
|
||||
return cls.delete_by_id(memory_id)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_null_tenant_embd_id_row(cls):
|
||||
fields = [
|
||||
cls.model.id,
|
||||
cls.model.tenant_id,
|
||||
cls.model.embd_id
|
||||
]
|
||||
objs = cls.model.select(*fields).where(cls.model.tenant_embd_id.is_null())
|
||||
return list(objs)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_null_tenant_llm_id_row(cls):
|
||||
fields = [
|
||||
cls.model.id,
|
||||
cls.model.tenant_id,
|
||||
cls.model.llm_id
|
||||
]
|
||||
objs = cls.model.select(*fields).where(cls.model.tenant_llm_id.is_null())
|
||||
return list(objs)
|
||||
|
||||
@ -60,7 +60,7 @@ class TenantLLMService(CommonService):
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_my_llms(cls, tenant_id):
|
||||
fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name, cls.model.used_tokens, cls.model.status]
|
||||
fields = [cls.model.id, cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name, cls.model.used_tokens, cls.model.status]
|
||||
objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(cls.model.tenant_id == tenant_id, ~cls.model.api_key.is_null()).dicts()
|
||||
|
||||
return list(objs)
|
||||
@ -133,34 +133,35 @@ class TenantLLMService(CommonService):
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def model_instance(cls, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs):
|
||||
model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
|
||||
def model_instance(cls, model_config: dict, lang="Chinese", **kwargs):
|
||||
if not model_config:
|
||||
raise LookupError("Model config is required")
|
||||
kwargs.update({"provider": model_config["llm_factory"]})
|
||||
if llm_type == LLMType.EMBEDDING.value:
|
||||
if model_config["model_type"] == LLMType.EMBEDDING.value:
|
||||
if model_config["llm_factory"] not in EmbeddingModel:
|
||||
return None
|
||||
return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
|
||||
|
||||
elif llm_type == LLMType.RERANK:
|
||||
elif model_config["model_type"] == LLMType.RERANK:
|
||||
if model_config["llm_factory"] not in RerankModel:
|
||||
return None
|
||||
return RerankModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
|
||||
|
||||
elif llm_type == LLMType.IMAGE2TEXT.value:
|
||||
elif model_config["model_type"] == LLMType.IMAGE2TEXT.value:
|
||||
if model_config["llm_factory"] not in CvModel:
|
||||
return None
|
||||
return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang, base_url=model_config["api_base"], **kwargs)
|
||||
|
||||
elif llm_type == LLMType.CHAT.value:
|
||||
elif model_config["model_type"] == LLMType.CHAT.value:
|
||||
if model_config["llm_factory"] not in ChatModel:
|
||||
return None
|
||||
return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"], **kwargs)
|
||||
|
||||
elif llm_type == LLMType.SPEECH2TEXT:
|
||||
elif model_config["model_type"] == LLMType.SPEECH2TEXT:
|
||||
if model_config["llm_factory"] not in Seq2txtModel:
|
||||
return None
|
||||
return Seq2txtModel[model_config["llm_factory"]](key=model_config["api_key"], model_name=model_config["llm_name"], lang=lang, base_url=model_config["api_base"])
|
||||
elif llm_type == LLMType.TTS:
|
||||
elif model_config["model_type"] == LLMType.TTS:
|
||||
if model_config["llm_factory"] not in TTSModel:
|
||||
return None
|
||||
return TTSModel[model_config["llm_factory"]](
|
||||
@ -169,7 +170,7 @@ class TenantLLMService(CommonService):
|
||||
base_url=model_config["api_base"],
|
||||
)
|
||||
|
||||
elif llm_type == LLMType.OCR:
|
||||
elif model_config["model_type"] == LLMType.OCR:
|
||||
if model_config["llm_factory"] not in OcrModel:
|
||||
return None
|
||||
return OcrModel[model_config["llm_factory"]](
|
||||
@ -218,6 +219,16 @@ class TenantLLMService(CommonService):
|
||||
|
||||
return num
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def increase_usage_by_id(cls, tenant_model_id: int, used_tokens: int):
|
||||
try:
|
||||
update_cnt = cls.model.update(used_tokens=cls.model.used_tokens + used_tokens).where(cls.model.id == tenant_model_id).execute()
|
||||
except Exception as e:
|
||||
logging.exception(f"TenantLLMService.increase_usage got exception {e}, Failed to update used_tokens for tenant_model_id {tenant_model_id}")
|
||||
return 0
|
||||
return update_cnt
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_openai_models(cls):
|
||||
@ -376,13 +387,12 @@ class TenantLLMService(CommonService):
|
||||
|
||||
|
||||
class LLM4Tenant:
|
||||
def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs):
|
||||
def __init__(self, tenant_id: str, model_config: dict, lang="Chinese", **kwargs):
|
||||
self.tenant_id = tenant_id
|
||||
self.llm_type = llm_type
|
||||
self.llm_name = llm_name
|
||||
self.mdl = TenantLLMService.model_instance(tenant_id, llm_type, llm_name, lang=lang, **kwargs)
|
||||
assert self.mdl, "Can't find model for {}/{}/{}".format(tenant_id, llm_type, llm_name)
|
||||
model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
|
||||
self.llm_name = model_config["llm_name"]
|
||||
self.model_config = model_config
|
||||
self.mdl = TenantLLMService.model_instance(model_config, lang=lang, **kwargs)
|
||||
assert self.mdl, "Can't find model for {}/{}/{}".format(tenant_id, model_config["llm_type"], model_config["llm_name"])
|
||||
self.max_length = model_config.get("max_tokens", 8192)
|
||||
|
||||
self.is_tools = model_config.get("is_tools", False)
|
||||
|
||||
@ -226,6 +226,12 @@ class TenantService(CommonService):
|
||||
hash_obj = hashlib.sha256(tenant_id.encode("utf-8"))
|
||||
return int(hash_obj.hexdigest(), 16)%len(settings.MINIO)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_null_tenant_model_id_rows(cls):
|
||||
objs = cls.model.select().orwhere(cls.model.tenant_llm_id.is_null(), cls.model.tenant_embd_id.is_null(), cls.model.tenant_asr_id.is_null(), cls.model.tenant_tts_id.is_null(), cls.model.tenant_rerank_id.is_null(), cls.model.tenant_img2txt_id.is_null())
|
||||
return list(objs)
|
||||
|
||||
|
||||
class UserTenantService(CommonService):
|
||||
"""Service class for managing user-tenant relationship operations.
|
||||
|
||||
Reference in New Issue
Block a user