Feat/tenant model (#13072)

### What problem does this PR solve?

Add an `id` primary-key column to the `tenant_llm` table and use it in `LLMBundle`.

### Type of change

- [x] Refactoring

---------

Co-authored-by: Yingfeng <yingfeng.zhang@gmail.com>
Co-authored-by: Liu An <asiro@qq.com>
This commit is contained in:
Lynn
2026-03-05 17:27:17 +08:00
committed by GitHub
parent 47540a4147
commit 62cb292635
54 changed files with 1754 additions and 361 deletions

View File

@ -43,6 +43,7 @@ from peewee import (
Metadata,
Model,
TextField,
PrimaryKeyField,
)
from playhouse.migrate import MySQLMigrator, PostgresqlMigrator, migrate
from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase
@ -737,11 +738,17 @@ class Tenant(DataBaseModel):
name = CharField(max_length=100, null=True, help_text="Tenant name", index=True)
public_key = CharField(max_length=255, null=True, index=True)
llm_id = CharField(max_length=128, null=False, help_text="default llm ID", index=True)
tenant_llm_id = IntegerField(null=True, help_text="id in tenant_llm", index=True)
embd_id = CharField(max_length=128, null=False, help_text="default embedding model ID", index=True)
tenant_embd_id = IntegerField(null=True, help_text="id in tenant_llm", index=True)
asr_id = CharField(max_length=128, null=False, help_text="default ASR model ID", index=True)
tenant_asr_id = IntegerField(null=True, help_text="id in tenant_llm", index=True)
img2txt_id = CharField(max_length=128, null=False, help_text="default image to text model ID", index=True)
tenant_img2txt_id = IntegerField(null=True, help_text="id in tenant_llm", index=True)
rerank_id = CharField(max_length=128, null=False, help_text="default rerank model ID", index=True)
tenant_rerank_id = IntegerField(null=True, help_text="id in tenant_llm", index=True)
tts_id = CharField(max_length=256, null=True, help_text="default tts model ID", index=True)
tenant_tts_id = IntegerField(null=True, help_text="id in tenant_llm", index=True)
parser_ids = CharField(max_length=256, null=False, help_text="document processors", index=True)
credit = IntegerField(default=512, index=True)
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)
@ -808,14 +815,15 @@ class LLM(DataBaseModel):
class TenantLLM(DataBaseModel):
id = PrimaryKeyField()
tenant_id = CharField(max_length=32, null=False, index=True)
llm_factory = CharField(max_length=128, null=False, help_text="LLM factory name", index=True)
model_type = CharField(max_length=128, null=True, help_text="LLM, Text Embedding, Image2Text, ASR", index=True)
llm_name = CharField(max_length=128, null=True, help_text="LLM name", default="", index=True)
api_key = TextField(null=True, help_text="API KEY")
api_base = CharField(max_length=255, null=True, help_text="API Base")
max_tokens = IntegerField(default=8192, index=True)
used_tokens = IntegerField(default=0, index=True)
max_tokens = IntegerField(default=8192, help_text="Max context token num", index=True)
used_tokens = IntegerField(default=0, help_text="Used token num", index=True)
status = CharField(max_length=1, null=False, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)
def __str__(self):
@ -823,7 +831,9 @@ class TenantLLM(DataBaseModel):
class Meta:
db_table = "tenant_llm"
primary_key = CompositeKey("tenant_id", "llm_factory", "llm_name")
indexes = (
(("tenant_id", "llm_factory", "llm_name"), True),
)
class TenantLangfuse(DataBaseModel):
@ -847,6 +857,7 @@ class Knowledgebase(DataBaseModel):
language = CharField(max_length=32, null=True, default="Chinese" if "zh_CN" in os.getenv("LANG", "") else "English", help_text="English|Chinese", index=True)
description = TextField(null=True, help_text="KB description")
embd_id = CharField(max_length=128, null=False, help_text="default embedding model ID", index=True)
tenant_embd_id = IntegerField(null=True, help_text="id in tenant_llm", index=True)
permission = CharField(max_length=16, null=False, help_text="me|team", default="me", index=True)
created_by = CharField(max_length=32, null=False, index=True)
doc_num = IntegerField(default=0, index=True)
@ -954,6 +965,7 @@ class Dialog(DataBaseModel):
icon = TextField(null=True, help_text="icon base64 string")
language = CharField(max_length=32, null=True, default="Chinese" if "zh_CN" in os.getenv("LANG", "") else "English", help_text="English|Chinese", index=True)
llm_id = CharField(max_length=128, null=False, help_text="default llm ID")
tenant_llm_id = IntegerField(null=True, help_text="id in tenant_llm", index=True)
llm_setting = JSONField(null=False, default={"temperature": 0.1, "top_p": 0.3, "frequency_penalty": 0.7, "presence_penalty": 0.4, "max_tokens": 512})
prompt_type = CharField(max_length=16, null=False, default="simple", help_text="simple|advanced", index=True)
@ -973,7 +985,7 @@ class Dialog(DataBaseModel):
do_refer = CharField(max_length=1, null=False, default="1", help_text="it needs to insert reference index into answer or not")
rerank_id = CharField(max_length=128, null=False, help_text="default rerank model ID")
tenant_rerank_id = IntegerField(null=True, help_text="id in tenant_llm", index=True)
kb_ids = JSONField(null=False, default=[])
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)
@ -1295,7 +1307,9 @@ class Memory(DataBaseModel):
memory_type = IntegerField(null=False, default=1, index=True, help_text="Bit flags (LSB->MSB): 1=raw, 2=semantic, 4=episodic, 8=procedural. E.g., 5 enables raw + episodic.")
storage_type = CharField(max_length=32, default='table', null=False, index=True, help_text="table|graph")
embd_id = CharField(max_length=128, null=False, index=False, help_text="embedding model ID")
tenant_embd_id = IntegerField(null=True, help_text="id in tenant_llm", index=True)
llm_id = CharField(max_length=128, null=False, index=False, help_text="chat model ID")
tenant_llm_id = IntegerField(null=True, help_text="id in tenant_llm", index=True)
permissions = CharField(max_length=16, null=False, index=True, help_text="me|team", default="me")
description = TextField(null=True, help_text="description")
memory_size = IntegerField(default=5242880, null=False, index=False)
@ -1351,6 +1365,23 @@ def alter_db_rename_column(migrator, table_name, old_column_name, new_column_nam
def migrate_add_unique_email(migrator):
"""Deduplicates user emails and add UNIQUE constraint to email column (idempotent)"""
# step 0: check if UNIQUE index on email already exists
try:
cursor = DB.execute_sql("""
SELECT COUNT(*)
FROM information_schema.statistics
WHERE table_schema = DATABASE()
AND table_name = 'user'
AND index_name = 'user_email'
AND non_unique = 0
""")
result = cursor.fetchone()
if result and result[0] > 0:
logging.info("UNIQUE index on user.email already exists, skipping migration")
return
except Exception as ex:
logging.warning("Failed to check if UNIQUE index exists on user.email: %s, continuing with migration", ex)
# step 1: rename duplicate rows so the UNIQUE constraint can be applied
try:
duplicates = User.select(User.email).group_by(User.email).having(fn.COUNT(User.id) > 1).tuples()
@ -1385,6 +1416,62 @@ def migrate_add_unique_email(migrator):
logging.critical("Failed to add UNIQUE constraint on user.email: %s", ex)
def update_tenant_llm_to_id_primary_key():
    """Migrate tenant_llm from a composite primary key to an auto-increment id.

    Steps (MySQL-specific SQL: user variables, AUTO_INCREMENT, RENAME COLUMN):
      0. no-op if the 'id' column already exists (idempotent re-runs);
      1. add a nullable temp_id column;
      2. number existing rows in (tenant_id, llm_factory, llm_name) order;
      3. drop the old composite primary key;
      4. promote temp_id to AUTO_INCREMENT PRIMARY KEY;
      5. re-add the old composite uniqueness as a UNIQUE constraint;
      6. rename temp_id -> id.

    On failure the temporary column is dropped (best effort) so the
    migration can be retried.

    NOTE(review): MySQL DDL statements auto-commit, so DB.atomic() does not
    actually make steps 1-6 transactional — the cleanup in the except branch
    is what provides retryability; confirm this matches operator expectations.
    """
    try:
        with DB.atomic():
            # 0. Check whether a previous run already created the id column.
            cursor = DB.execute_sql("""
                SELECT COLUMN_NAME
                FROM INFORMATION_SCHEMA.COLUMNS
                WHERE TABLE_SCHEMA = DATABASE()
                AND TABLE_NAME = 'tenant_llm'
                AND COLUMN_NAME = 'id'
            """)
            # PEP 249 allows cursor.rowcount to be -1 for SELECT statements,
            # so test an actually fetched row instead of relying on rowcount.
            if cursor.fetchone():
                return
            # 1. Add nullable column
            DB.execute_sql("ALTER TABLE tenant_llm ADD COLUMN temp_id INT NULL")
            # 2. Number existing rows deterministically via a MySQL user variable.
            DB.execute_sql("SET @row = 0;")
            DB.execute_sql("UPDATE tenant_llm SET temp_id = (@row := @row + 1) ORDER BY tenant_id, llm_factory, llm_name;")
            # 3. Drop old primary key
            DB.execute_sql("ALTER TABLE tenant_llm DROP PRIMARY KEY")
            # 4. Update ID column to primary key
            DB.execute_sql("""
                ALTER TABLE tenant_llm
                MODIFY COLUMN temp_id INT NOT NULL AUTO_INCREMENT PRIMARY KEY
            """)
            # 5. Add unique key preserving the old composite-PK uniqueness
            DB.execute_sql("""
                ALTER TABLE tenant_llm
                ADD CONSTRAINT uk_tenant_llm UNIQUE (tenant_id, llm_factory, llm_name)
            """)
            # 6. rename
            DB.execute_sql("ALTER TABLE tenant_llm RENAME COLUMN temp_id TO id")
            logging.info("Successfully updated tenant_llm to id primary key.")
    except Exception as e:
        logging.error(str(e))
        # Best-effort cleanup: drop the temp column so a retry starts clean.
        cursor = DB.execute_sql("""
            SELECT COLUMN_NAME
            FROM INFORMATION_SCHEMA.COLUMNS
            WHERE TABLE_SCHEMA = DATABASE()
            AND TABLE_NAME = 'tenant_llm'
            AND COLUMN_NAME = 'temp_id'
        """)
        # Same rowcount portability concern as above: check a fetched row.
        if cursor.fetchone():
            DB.execute_sql("ALTER TABLE tenant_llm DROP COLUMN temp_id")
def migrate_db():
logging.disable(logging.ERROR)
migrator = DatabaseMigrator[settings.DATABASE_TYPE.upper()].value(DB)
@ -1436,6 +1523,18 @@ def migrate_db():
alter_db_add_column(migrator, "api_4_conversation", "exp_user_id", CharField(max_length=255, null=True, help_text="exp_user_id", index=True))
# Migrate system_settings.value from CharField to TextField for longer sandbox configs
alter_db_column_type(migrator, "system_settings", "value", TextField(null=False, help_text="Configuration value (JSON, string, etc.)"))
update_tenant_llm_to_id_primary_key()
alter_db_add_column(migrator, "tenant", "tenant_llm_id", IntegerField(null=True, help_text="id in tenant_llm", index=True))
alter_db_add_column(migrator, "tenant", "tenant_embd_id", IntegerField(null=True, help_text="id in tenant_llm", index=True))
alter_db_add_column(migrator, "tenant", "tenant_asr_id", IntegerField(null=True, help_text="id in tenant_llm", index=True))
alter_db_add_column(migrator, "tenant", "tenant_img2txt_id", IntegerField(null=True, help_text="id in tenant_llm", index=True))
alter_db_add_column(migrator, "tenant", "tenant_rerank_id", IntegerField(null=True, help_text="id in tenant_llm", index=True))
alter_db_add_column(migrator, "tenant", "tenant_tts_id", IntegerField(null=True, help_text="id in tenant_llm", index=True))
alter_db_add_column(migrator, "knowledgebase", "tenant_embd_id", IntegerField(null=True, help_text="id in tenant_llm", index=True))
alter_db_add_column(migrator, "dialog", "tenant_llm_id", IntegerField(null=True, help_text="id in tenant_llm", index=True))
alter_db_add_column(migrator, "dialog", "tenant_rerank_id", IntegerField(null=True, help_text="id in tenant_llm", index=True))
alter_db_add_column(migrator, "memory", "tenant_embd_id", IntegerField(null=True, help_text="id in tenant_llm", index=True))
alter_db_add_column(migrator, "memory", "tenant_llm_id", IntegerField(null=True, help_text="id in tenant_llm", index=True))
logging.disable(logging.NOTSET)
# this is after re-enabling logging to allow logging changed user emails
migrate_add_unique_email(migrator)

View File

@ -24,16 +24,19 @@ from copy import deepcopy
from peewee import IntegrityError
from api.db import UserTenantRole
from api.db.db_models import init_database_tables as init_web_db, LLMFactories, LLM, TenantLLM
from api.db.db_models import init_database_tables as init_web_db, LLMFactories, LLM, TenantLLM, Knowledgebase, Dialog, Memory
from api.db.services import UserService
from api.db.services.canvas_service import CanvasTemplateService
from api.db.services.document_service import DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.memory_service import MemoryService
from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService
from api.db.services.llm_service import LLMService, LLMBundle, get_init_tenant_llm
from api.db.services.user_service import TenantService, UserTenantService
from api.db.services.system_settings_service import SystemSettingsService
from api.db.services.dialog_service import DialogService
from api.db.joint_services.memory_message_service import init_message_id_sequence, init_memory_size_cache, fix_missing_tokenized_memory
from api.db.joint_services.tenant_model_service import get_tenant_default_model_by_type
from common.constants import LLMType
from common.file_utils import get_project_base_directory
from common import settings
@ -90,13 +93,15 @@ def init_superuser(nickname=DEFAULT_SUPERUSER_NICKNAME, email=DEFAULT_SUPERUSER_
f"Super user initialized. email: {email},A default password has been set; changing the password after login is strongly recommended.")
if tenant["llm_id"]:
chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"])
chat_model_config = get_tenant_default_model_by_type(tenant["id"], LLMType.CHAT)
chat_mdl = LLMBundle(tenant["id"], chat_model_config)
msg = asyncio.run(chat_mdl.async_chat(system="", history=[{"role": "user", "content": "Hello!"}], gen_conf={}))
if msg.find("ERROR: ") == 0:
logging.error("'{}' doesn't work. {}".format( tenant["llm_id"], msg))
if tenant["embd_id"]:
embd_mdl = LLMBundle(tenant["id"], LLMType.EMBEDDING, tenant["embd_id"])
embd_model_config = get_tenant_default_model_by_type(tenant["id"], LLMType.EMBEDDING)
embd_mdl = LLMBundle(tenant["id"], embd_model_config)
v, c = embd_mdl.encode(["Hello!"])
if c == 0:
logging.error("'{}' doesn't work!".format(tenant["embd_id"]))
@ -185,6 +190,7 @@ def init_web_data():
init_message_id_sequence()
init_memory_size_cache()
fix_missing_tokenized_memory()
fix_empty_tenant_model_id()
logging.info("init web data success:{}".format(time.time() - start_time))
def init_table():
@ -213,6 +219,105 @@ def init_table():
raise e
def _backfill_tenant_llm_ids(rows, model_attr, service, model_cls, target_field, row_noun, table_name):
    """Backfill one *tenant_<model>_id* column on one table.

    Groups *rows* by (tenant_id, <model_attr>), resolves each group's tenant_llm
    row via TenantLLMService.get_api_key, and writes its id into *target_field*.
    Rows whose model cannot be resolved are left untouched (best effort).

    :param rows: peewee rows with .id, .tenant_id and the model-name attribute
    :param model_attr: attribute holding the model name, e.g. "embd_id"
    :param service: CommonService subclass used for filter_update
    :param model_cls: peewee model class (for the id.in_ filter)
    :param target_field: column to fill, e.g. "tenant_embd_id"
    :param row_noun: noun used in the "Found ..." log line
    :param table_name: table name used in the "Update ..." log line
    """
    if not rows:
        return
    logging.info(f"Found {len(rows)} empty {target_field} {row_noun}.")
    groups: dict = {}
    for obj in rows:
        groups.setdefault((obj.tenant_id, getattr(obj, model_attr)), []).append(obj.id)
    update_cnt = 0
    for (tenant_id, model_name), ids in groups.items():
        tenant_llm = TenantLLMService.get_api_key(tenant_id, model_name)
        if tenant_llm:
            update_cnt += service.filter_update([model_cls.id.in_(ids)], {target_field: tenant_llm.id})
    logging.info(f"Update {update_cnt} {target_field} in table {table_name}.")


def fix_empty_tenant_model_id():
    """Backfill the new tenant_*_id columns from the legacy model-name columns.

    For knowledgebase/dialog/memory rows (and the tenant table itself) whose
    tenant_*_id references are still NULL, resolve the legacy model-name
    column against tenant_llm and store the row id. Idempotent: only NULL
    columns are considered, and unresolvable models are skipped.
    """
    # knowledgebase
    _backfill_tenant_llm_ids(
        KnowledgebaseService.get_null_tenant_embd_id_row(), "embd_id",
        KnowledgebaseService, Knowledgebase, "tenant_embd_id", "knowledgebase", "knowledgebase")
    # dialog
    _backfill_tenant_llm_ids(
        DialogService.get_null_tenant_llm_id_row(), "llm_id",
        DialogService, Dialog, "tenant_llm_id", "dialogs", "dialog")
    _backfill_tenant_llm_ids(
        DialogService.get_null_tenant_rerank_id_row(), "rerank_id",
        DialogService, Dialog, "tenant_rerank_id", "dialogs", "dialog")
    # memory
    _backfill_tenant_llm_ids(
        MemoryService.get_null_tenant_embd_id_row(), "embd_id",
        MemoryService, Memory, "tenant_embd_id", "memories", "memory")
    _backfill_tenant_llm_ids(
        MemoryService.get_null_tenant_llm_id_row(), "llm_id",
        MemoryService, Memory, "tenant_llm_id", "memories", "memory")
    # tenant: one row may need several tenant_*_id columns filled at once
    empty_tenant_model_id_tenants = TenantService.get_null_tenant_model_id_rows()
    if empty_tenant_model_id_tenants:
        logging.info(f"Found {len(empty_tenant_model_id_tenants)} empty tenant_model_id tenants.")
        update_cnt = 0
        for obj in empty_tenant_model_id_tenants:
            tenant_dict = obj.to_dict()
            update_dict = {}
            for key in ["llm_id", "embd_id", "asr_id", "img2txt_id", "rerank_id", "tts_id"]:
                if tenant_dict.get(key) and not tenant_dict.get(f"tenant_{key}"):
                    tenant_model = TenantLLMService.get_api_key(tenant_dict["id"], tenant_dict[key])
                    if tenant_model:
                        update_dict[f"tenant_{key}"] = tenant_model.id
            if update_dict:
                # BUGFIX: original accumulated into update_dict (`update_dict +=`),
                # which would raise TypeError; the counter was clearly intended.
                update_cnt += TenantService.update_by_id(tenant_dict["id"], update_dict)
        logging.info(f"Update {update_cnt} tenant_model_id in table tenant.")
    logging.info("Fix empty tenant_model_id done.")
if __name__ == '__main__':
init_web_db()
init_web_data()

View File

@ -25,8 +25,8 @@ from api.db.db_utils import bulk_insert_into_db
from api.db.db_models import Task
from api.db.services.task_service import TaskService
from api.db.services.memory_service import MemoryService
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.llm_service import LLMBundle
from api.db.joint_services.tenant_model_service import get_model_config_by_id, get_model_config_by_type_and_name
from api.utils.memory_utils import get_memory_type_human
from memory.services.messages import MessageService
from memory.services.query import MsgTextQuery, get_vector
@ -53,11 +53,12 @@ async def save_to_memory(memory_id: str, message_dict: dict):
tenant_id = memory.tenant_id
extracted_content = await extract_by_llm(
tenant_id,
memory.llm_id,
memory.tenant_llm_id,
{"temperature": memory.temperature},
get_memory_type_human(memory.memory_type),
message_dict.get("user_input", ""),
message_dict.get("agent_response", "")
message_dict.get("agent_response", ""),
llm_id=memory.llm_id
) if memory.memory_type != MemoryType.RAW.value else [] # if only RAW, no need to extract
raw_message_id = REDIS_CONN.generate_auto_increment_id(namespace="memory")
message_list = [{
@ -107,12 +108,13 @@ async def save_extracted_to_memory_only(memory_id: str, message_dict, source_mes
tenant_id = memory.tenant_id
extracted_content = await extract_by_llm(
tenant_id,
memory.llm_id,
memory.tenant_llm_id,
{"temperature": memory.temperature},
get_memory_type_human(memory.memory_type),
message_dict.get("user_input", ""),
message_dict.get("agent_response", ""),
task_id=task_id
task_id=task_id,
llm_id=memory.llm_id
)
message_list = [{
"message_id": REDIS_CONN.generate_auto_increment_id(namespace="memory"),
@ -139,11 +141,8 @@ async def save_extracted_to_memory_only(memory_id: str, message_dict, source_mes
return await embed_and_save(memory, message_list, task_id)
async def extract_by_llm(tenant_id: str, llm_id: str, extract_conf: dict, memory_type: List[str], user_input: str,
agent_response: str, system_prompt: str = "", user_prompt: str="", task_id: str=None) -> List[dict]:
llm_type = TenantLLMService.llm_id2llm_type(llm_id)
if not llm_type:
raise RuntimeError(f"Unknown type of LLM '{llm_id}'")
async def extract_by_llm(tenant_id: str, tenant_llm_id: int, extract_conf: dict, memory_type: List[str], user_input: str,
agent_response: str, system_prompt: str = "", user_prompt: str="", task_id: str=None, llm_id: str = "") -> List[dict]:
if not system_prompt:
system_prompt = PromptAssembler.assemble_system_prompt({"memory_type": memory_type})
conversation_content = f"User Input: {user_input}\nAgent Response: {agent_response}"
@ -154,7 +153,11 @@ async def extract_by_llm(tenant_id: str, llm_id: str, extract_conf: dict, memory
user_prompts.append({"role": "user", "content": f"Conversation: {conversation_content}\nConversation Time: {conversation_time}\nCurrent Time: {conversation_time}"})
else:
user_prompts.append({"role": "user", "content": PromptAssembler.assemble_user_prompt(conversation_content, conversation_time, conversation_time)})
llm = LLMBundle(tenant_id, llm_type, llm_id)
if tenant_llm_id:
llm_config = get_model_config_by_id(tenant_llm_id)
else:
llm_config = get_model_config_by_type_and_name(tenant_id, LLMType.CHAT, llm_id)
llm = LLMBundle(tenant_id, llm_config)
if task_id:
TaskService.update_progress(task_id, {"progress": 0.15, "progress_msg": timestamp_to_date(current_timestamp())+ " " + "Prepared prompts and LLM."})
res = await llm.async_chat(system_prompt, user_prompts, extract_conf)
@ -170,7 +173,11 @@ async def extract_by_llm(tenant_id: str, llm_id: str, extract_conf: dict, memory
async def embed_and_save(memory, message_list: list[dict], task_id: str=None):
embedding_model = LLMBundle(memory.tenant_id, llm_type=LLMType.EMBEDDING, llm_name=memory.embd_id)
if memory.tenant_embd_id:
embd_model_config = get_model_config_by_id(memory.tenant_embd_id)
else:
embd_model_config = get_model_config_by_type_and_name(memory.tenant_id, LLMType.EMBEDDING, memory.embd_id)
embedding_model = LLMBundle(memory.tenant_id, embd_model_config)
if task_id:
TaskService.update_progress(task_id, {"progress": 0.65, "progress_msg": timestamp_to_date(current_timestamp())+ " " + "Prepared embedding model."})
vector_list, _ = embedding_model.encode([msg["content"] for msg in message_list])
@ -239,7 +246,11 @@ def query_message(filter_dict: dict, params: dict):
question = params["query"]
question = question.strip()
memory = memory_list[0]
embd_model = LLMBundle(memory.tenant_id, llm_type=LLMType.EMBEDDING, llm_name=memory.embd_id)
if memory.tenant_embd_id:
embd_model_config = get_model_config_by_id(memory.tenant_embd_id)
else:
embd_model_config = get_model_config_by_type_and_name(memory.tenant_id, LLMType.EMBEDDING, memory.embd_id)
embd_model = LLMBundle(memory.tenant_id, embd_model_config)
match_dense = get_vector(question, embd_model, similarity=params["similarity_threshold"])
match_text, _ = MsgTextQuery().question(question, min_match=params["similarity_threshold"])
keywords_similarity_weight = params.get("keywords_similarity_weight", 0.7)

View File

@ -0,0 +1,91 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import enum
from common import settings
from common.constants import LLMType
from api.db.services.llm_service import LLMService
from api.db.services.tenant_llm_service import TenantLLMService, TenantService
def get_model_config_by_id(tenant_model_id: int) -> dict:
    """Load a tenant_llm row by its primary key and return it as a config dict.

    Enriches the dict with ``is_tools`` when a matching LLM catalog entry
    exists for the model name.

    :raises LookupError: when no tenant_llm row has this id.
    """
    exists, row = TenantLLMService.get_by_id(tenant_model_id)
    if not exists:
        raise LookupError(f"Tenant Model with id {tenant_model_id} not found")
    cfg = row.to_dict()
    catalog_hits = LLMService.query(llm_name=cfg["llm_name"])
    if catalog_hits:
        cfg["is_tools"] = catalog_hits[0].is_tools
    return cfg
def get_model_config_by_type_and_name(tenant_id: str, model_type: str, model_name: str):
    """Resolve a tenant model configuration from its type and (possibly
    factory-qualified) name.

    Lookup order: exact name in tenant_llm; then the name with its
    ``@factory`` suffix stripped — with a special case for the locally
    configured TEI embedding model, which is synthesized from settings.

    :raises Exception: when *model_name* is empty.
    :raises LookupError: when no matching tenant model can be found.
    """
    if not model_name:
        raise Exception("Model Name is required")

    direct_hit = TenantLLMService.get_api_key(tenant_id, model_name)
    if direct_hit:
        # model_name without @factory
        config_dict = direct_hit.to_dict()
    else:
        # model_name in format 'name@factory', split model_name and try again
        pure_name, factory_id = TenantLLMService.split_model_name_and_factory(model_name)
        is_local_tei_embedding = (
            model_type == LLMType.EMBEDDING
            and factory_id == "Builtin"
            and "tei-" in os.getenv("COMPOSE_PROFILES", "")
            and pure_name == os.getenv("TEI_MODEL", "")
        )
        if is_local_tei_embedding:
            # configured local embedding model
            tei_cfg = settings.EMBEDDING_CFG
            config_dict = {
                "llm_factory": "Builtin",
                "api_key": tei_cfg["api_key"],
                "llm_name": pure_name,
                "api_base": tei_cfg["base_url"],
                "model_type": LLMType.EMBEDDING,
            }
        else:
            fallback_hit = TenantLLMService.get_api_key(tenant_id, pure_name)
            if not fallback_hit:
                raise LookupError(f"Tenant Model with name {model_name} not found")
            config_dict = fallback_hit.to_dict()

    catalog_hits = LLMService.query(llm_name=config_dict["llm_name"])
    if catalog_hits:
        config_dict["is_tools"] = catalog_hits[0].is_tools
    return config_dict
def get_tenant_default_model_by_type(tenant_id: str, model_type: str | enum.Enum):
    """Return the config of the tenant's default model for *model_type*.

    Reads the default model name from the tenant row (embd_id, asr_id, ...)
    and resolves it via :func:`get_model_config_by_type_and_name`.

    :raises LookupError: when the tenant does not exist.
    :raises Exception: for OCR (no tenant-level default exists), for an
        unknown model type, or when no default model name is configured.
    """
    exist, tenant = TenantService.get_by_id(tenant_id)
    if not exist:
        raise LookupError("Tenant not found")
    type_val = model_type if isinstance(model_type, str) else model_type.value
    # OCR has no default column on the tenant row, so it must be explicit.
    if type_val == LLMType.OCR.value:
        raise Exception("OCR model name is required")
    # Dispatch table: model type -> tenant default-model column value.
    default_name_by_type = {
        LLMType.EMBEDDING.value: tenant.embd_id,
        LLMType.SPEECH2TEXT.value: tenant.asr_id,
        LLMType.IMAGE2TEXT.value: tenant.img2txt_id,
        LLMType.CHAT.value: tenant.llm_id,
        LLMType.RERANK.value: tenant.rerank_id,
        LLMType.TTS.value: tenant.tts_id,
    }
    if type_val not in default_name_by_type:
        raise Exception(f"Unknown model type {model_type}")
    model_name = default_name_by_type[type_val]
    if not model_name:
        raise Exception(f"No default {model_type} model is set.")
    return get_model_config_by_type_and_name(tenant_id, model_type, model_name)

View File

@ -34,6 +34,7 @@ from api.db.services.langfuse_service import TenantLangfuseService
from api.db.services.llm_service import LLMBundle
from common.metadata_utils import apply_meta_data_filter
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.joint_services.tenant_model_service import get_model_config_by_id, get_model_config_by_type_and_name, get_tenant_default_model_by_type
from common.time_utils import current_timestamp, datetime_format
from common.text_utils import normalize_arabic_digits
from rag.graphrag.general.mind_map_extractor import MindMapExtractor
@ -179,6 +180,28 @@ class DialogService(CommonService):
offset += limit
return res
@classmethod
@DB.connection_context()
def get_null_tenant_llm_id_row(cls):
    """Return dialog rows (id, tenant_id, llm_id) whose tenant_llm_id is still NULL."""
    query = cls.model.select(
        cls.model.id,
        cls.model.tenant_id,
        cls.model.llm_id,
    ).where(cls.model.tenant_llm_id.is_null())
    return list(query)
@classmethod
@DB.connection_context()
def get_null_tenant_rerank_id_row(cls):
    """Return dialog rows (id, tenant_id, rerank_id) whose tenant_rerank_id is still NULL."""
    query = cls.model.select(
        cls.model.id,
        cls.model.tenant_id,
        cls.model.rerank_id,
    ).where(cls.model.tenant_rerank_id.is_null())
    return list(query)
async def async_chat_solo(dialog, messages, stream=True):
llm_type = TenantLLMService.llm_id2llm_type(dialog.llm_id)
@ -191,22 +214,15 @@ async def async_chat_solo(dialog, messages, stream=True):
else:
text_attachments, image_files = split_file_attachments(messages[-1]["files"], raw=True)
attachments = "\n\n".join(text_attachments)
if llm_type == "image2text":
llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
else:
llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
factory = llm_model_config.get("llm_factory", "") if llm_model_config else ""
if llm_type == "image2text":
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
else:
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
model_config = get_model_config_by_id(dialog.tenant_llm_id)
chat_mdl = LLMBundle(dialog.tenant_id, model_config)
factory = model_config.get("llm_factory", "") if model_config else ""
prompt_config = dialog.prompt_config
tts_mdl = None
if prompt_config.get("tts"):
tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
default_tts_model = get_tenant_default_model_by_type(dialog.tenant_id, LLMType.TTS)
tts_mdl = LLMBundle(dialog.tenant_id, default_tts_model)
msg = [{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])} for m in messages if m["role"] != "system"]
if attachments and msg:
msg[-1]["content"] += attachments
@ -241,20 +257,27 @@ def get_models(dialog):
raise Exception("**ERROR**: Knowledge bases use different embedding models.")
if embedding_list:
embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embedding_list[0])
embd_model_config = get_model_config_by_type_and_name(dialog.tenant_id, LLMType.EMBEDDING, embedding_list[0])
embd_mdl = LLMBundle(dialog.tenant_id, embd_model_config)
if not embd_mdl:
raise LookupError("Embedding model(%s) not found" % embedding_list[0])
if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text":
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
if dialog.tenant_llm_id:
chat_model_config = get_model_config_by_id(dialog.tenant_llm_id)
elif dialog.llm_id:
chat_model_config = get_model_config_by_type_and_name(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
else:
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
chat_model_config = get_tenant_default_model_by_type(dialog.tenant_id, LLMType.CHAT)
chat_mdl = LLMBundle(dialog.tenant_id, chat_model_config)
if dialog.rerank_id:
rerank_mdl = LLMBundle(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id)
rerank_model_config = get_model_config_by_type_and_name(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id)
rerank_mdl = LLMBundle(dialog.tenant_id, rerank_model_config)
if dialog.prompt_config.get("tts"):
tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
default_tts_model_config = get_tenant_default_model_by_type(dialog.tenant_id, LLMType.TTS)
tts_mdl = LLMBundle(dialog.tenant_id, default_tts_model_config)
return kbs, embd_mdl, rerank_mdl, chat_mdl, tts_mdl
@ -603,8 +626,9 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
kbinfos["chunks"].extend(tav_res["chunks"])
kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
if prompt_config.get("use_kg"):
default_chat_model = get_tenant_default_model_by_type(dialog.tenant_id, LLMType.CHAT)
ck = await settings.kg_retriever.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl,
LLMBundle(dialog.tenant_id, LLMType.CHAT))
LLMBundle(dialog.tenant_id, default_chat_model))
if ck["content_with_weight"]:
kbinfos["chunks"].insert(0, ck)
@ -1341,11 +1365,13 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_conf
is_knowledge_graph = all([kb.parser_id == ParserType.KG for kb in kbs])
retriever = settings.retriever if not is_knowledge_graph else settings.kg_retriever
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embedding_list[0])
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, chat_llm_name)
embd_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.EMBEDDING, embedding_list[0])
embd_mdl = LLMBundle(tenant_id, embd_model_config)
chat_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.CHAT, chat_llm_name)
chat_mdl = LLMBundle(tenant_id, chat_model_config)
if rerank_id:
rerank_mdl = LLMBundle(tenant_id, LLMType.RERANK, rerank_id)
rerank_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.RERANK, rerank_id)
rerank_mdl = LLMBundle(tenant_id, rerank_model_config)
max_tokens = chat_mdl.max_length
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
@ -1417,13 +1443,22 @@ async def gen_mindmap(question, kb_ids, tenant_id, search_config={}):
kbs = KnowledgebaseService.get_by_ids(kb_ids)
if not kbs:
return {"error": "No KB selected"}
embedding_list = list(set([kb.embd_id for kb in kbs]))
tenant_embedding_list = list(set([kb.tenant_embd_id for kb in kbs]))
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, llm_name=embedding_list[0])
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
if tenant_embedding_list[0]:
embd_model_config = get_model_config_by_id(tenant_embedding_list[0])
else:
embd_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.EMBEDDING, kbs[0].embd_id)
embd_mdl = LLMBundle(tenant_id, embd_model_config)
chat_id = search_config.get("chat_id", "")
if chat_id:
chat_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.CHAT, chat_id)
else:
chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT)
chat_mdl = LLMBundle(tenant_id, chat_model_config)
if rerank_id:
rerank_mdl = LLMBundle(tenant_id, LLMType.RERANK, rerank_id)
rerank_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.RERANK, rerank_id)
rerank_mdl = LLMBundle(tenant_id, rerank_model_config)
if meta_data_filter:
metas = DocMetadataService.get_flatted_meta_by_kbs(kb_ids)

View File

@ -661,6 +661,19 @@ class DocumentService(CommonService):
return None
return docs[0]["embd_id"]
@classmethod
@DB.connection_context()
def get_tenant_embd_id(cls, doc_id):
    """Return the tenant_llm row id of the embedding model bound to a document.

    Resolves the document's knowledge base (only KBs with VALID status are
    considered) and returns its ``tenant_embd_id``. Returns ``None`` when no
    matching document/KB pair exists.
    """
    query = (
        cls.model.select(Knowledgebase.tenant_embd_id)
        .join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id))
        .where(
            (cls.model.id == doc_id)
            & (Knowledgebase.status == StatusEnum.VALID.value)
        )
        .dicts()
    )
    rows = list(query)
    return rows[0]["tenant_embd_id"] if rows else None
@classmethod
@DB.connection_context()
def get_chunking_config(cls, doc_id):
@ -1007,6 +1020,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
from api.db.services.file_service import FileService
from api.db.services.llm_service import LLMBundle
from api.db.services.user_service import TenantService
from api.db.joint_services.tenant_model_service import get_model_config_by_id, get_model_config_by_type_and_name, get_tenant_default_model_by_type
from rag.app import audio, email, naive, picture, presentation
e, conv = ConversationService.get_by_id(conversation_id)
@ -1022,8 +1036,11 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
e, kb = KnowledgebaseService.get_by_id(kb_id)
if not e:
raise LookupError("Can't find this dataset!")
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id, lang=kb.language)
if kb.tenant_embd_id:
embd_model_config = get_model_config_by_id(kb.tenant_embd_id)
else:
embd_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.EMBEDDING, kb.embd_id)
embd_mdl = LLMBundle(kb.tenant_id, embd_model_config, lang=kb.language)
err, files = FileService.upload_document(kb, file_objs, user_id)
assert not err, "\n".join(err)
@ -1101,7 +1118,8 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
try_create_idx = True
_, tenant = TenantService.get_by_id(kb.tenant_id)
llm_bdl = LLMBundle(kb.tenant_id, LLMType.CHAT, tenant.llm_id)
tenant_llm_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT)
llm_bdl = LLMBundle(kb.tenant_id, tenant_llm_config)
for doc_id in docids:
cks = [c for c in docs if c["doc_id"] == doc_id]

View File

@ -564,3 +564,14 @@ class KnowledgebaseService(CommonService):
'update_date': datetime_format(datetime.now())
}
return cls.model.update(update_dict).where(cls.model.id == kb_id).execute()
@classmethod
@DB.connection_context()
def get_null_tenant_embd_id_row(cls):
    """Fetch knowledge bases whose ``tenant_embd_id`` has not been backfilled.

    Returns a list of model rows carrying only the columns needed by the
    backfill job: id, tenant_id and the legacy ``embd_id`` name.
    """
    query = cls.model.select(
        cls.model.id,
        cls.model.tenant_id,
        cls.model.embd_id,
    ).where(cls.model.tenant_embd_id.is_null())
    return [row for row in query]

View File

@ -83,18 +83,18 @@ def get_init_tenant_llm(user_id):
class LLMBundle(LLM4Tenant):
def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs):
super().__init__(tenant_id, llm_type, llm_name, lang, **kwargs)
def __init__(self, tenant_id: str, model_config: dict, lang="Chinese", **kwargs):
super().__init__(tenant_id, model_config, lang, **kwargs)
def bind_tools(self, toolcall_session, tools):
if not self.is_tools:
logging.warning(f"Model {self.llm_name} does not support tool call, but you have assigned one or more tools to it!")
logging.warning(f"Model {self.model_config['llm_name']} does not support tool call, but you have assigned one or more tools to it!")
return
self.mdl.bind_tools(toolcall_session, tools)
def encode(self, texts: list):
if self.langfuse:
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="encode", model=self.llm_name, input={"texts": texts})
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="encode", model=self.model_config["llm_name"], input={"texts": texts})
safe_texts = []
for text in texts:
@ -106,9 +106,9 @@ class LLMBundle(LLM4Tenant):
safe_texts.append(text)
embeddings, used_tokens = self.mdl.encode(safe_texts)
llm_name = getattr(self, "llm_name", None)
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, llm_name):
if self.model_config["llm_factory"] == "Builtin":
logging.info("LLMBundle.encode_queries query: {}, emd len: {}, used_tokens: {}. Builtin model don't need to update token usage".format(texts, len(embeddings), used_tokens))
elif not TenantLLMService.increase_usage_by_id(self.model_config["id"], used_tokens):
logging.error("LLMBundle.encode can't update token usage for <tenant redacted>/EMBEDDING used_tokens: {}".format(used_tokens))
if self.langfuse:
@ -119,11 +119,12 @@ class LLMBundle(LLM4Tenant):
def encode_queries(self, query: str):
if self.langfuse:
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="encode_queries", model=self.llm_name, input={"query": query})
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="encode_queries", model=self.model_config["llm_name"], input={"query": query})
emd, used_tokens = self.mdl.encode_queries(query)
llm_name = getattr(self, "llm_name", None)
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, llm_name):
if self.model_config["llm_factory"] == "Builtin":
logging.info("LLMBundle.encode_queries query: {}, emd len: {}, used_tokens: {}. Builtin model don't need to update token usage".format(query, len(emd), used_tokens))
elif not TenantLLMService.increase_usage_by_id(self.model_config["id"], used_tokens):
logging.error("LLMBundle.encode_queries can't update token usage for <tenant redacted>/EMBEDDING used_tokens: {}".format(used_tokens))
if self.langfuse:
@ -134,10 +135,10 @@ class LLMBundle(LLM4Tenant):
def similarity(self, query: str, texts: list):
if self.langfuse:
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="similarity", model=self.llm_name, input={"query": query, "texts": texts})
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="similarity", model=self.model_config["llm_name"], input={"query": query, "texts": texts})
sim, used_tokens = self.mdl.similarity(query, texts)
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
if not TenantLLMService.increase_usage_by_id(self.model_config["id"], used_tokens):
logging.error("LLMBundle.similarity can't update token usage for {}/RERANK used_tokens: {}".format(self.tenant_id, used_tokens))
if self.langfuse:
@ -148,10 +149,10 @@ class LLMBundle(LLM4Tenant):
def describe(self, image, max_tokens=300):
if self.langfuse:
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="describe", metadata={"model": self.llm_name})
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="describe", metadata={"model": self.model_config["llm_name"]})
txt, used_tokens = self.mdl.describe(image)
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
if not TenantLLMService.increase_usage_by_id(self.model_config["id"], used_tokens):
logging.error("LLMBundle.describe can't update token usage for {}/IMAGE2TEXT used_tokens: {}".format(self.tenant_id, used_tokens))
if self.langfuse:
@ -162,10 +163,10 @@ class LLMBundle(LLM4Tenant):
def describe_with_prompt(self, image, prompt):
if self.langfuse:
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="describe_with_prompt", metadata={"model": self.llm_name, "prompt": prompt})
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="describe_with_prompt", metadata={"model": self.model_config["llm_name"], "prompt": prompt})
txt, used_tokens = self.mdl.describe_with_prompt(image, prompt)
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
if not TenantLLMService.increase_usage_by_id(self.model_config["id"], used_tokens):
logging.error("LLMBundle.describe can't update token usage for {}/IMAGE2TEXT used_tokens: {}".format(self.tenant_id, used_tokens))
if self.langfuse:
@ -176,10 +177,10 @@ class LLMBundle(LLM4Tenant):
def transcription(self, audio):
if self.langfuse:
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="transcription", metadata={"model": self.llm_name})
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="transcription", metadata={"model": self.model_config["llm_name"]})
txt, used_tokens = self.mdl.transcription(audio)
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
if not TenantLLMService.increase_usage_by_id(self.model_config["id"], used_tokens):
logging.error("LLMBundle.transcription can't update token usage for {}/SEQUENCE2TXT used_tokens: {}".format(self.tenant_id, used_tokens))
if self.langfuse:
@ -196,7 +197,7 @@ class LLMBundle(LLM4Tenant):
generation = self.langfuse.start_generation(
trace_context=self.trace_context,
name="stream_transcription",
metadata={"model": self.llm_name},
metadata={"model": self.model_config["llm_name"]},
)
final_text = ""
used_tokens = 0
@ -215,7 +216,7 @@ class LLMBundle(LLM4Tenant):
finally:
if final_text:
used_tokens = num_tokens_from_string(final_text)
TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens)
TenantLLMService.increase_usage_by_id(self.model_config["id"], used_tokens)
if self.langfuse:
generation.update(
@ -230,11 +231,11 @@ class LLMBundle(LLM4Tenant):
generation = self.langfuse.start_generation(
trace_context=self.trace_context,
name="stream_transcription",
metadata={"model": self.llm_name},
metadata={"model": self.model_config["llm_name"]},
)
full_text, used_tokens = mdl.transcription(audio)
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
if not TenantLLMService.increase_usage_by_id(self.model_config["id"], used_tokens):
logging.error(f"LLMBundle.stream_transcription can't update token usage for {self.tenant_id}/SEQUENCE2TXT used_tokens: {used_tokens}")
if self.langfuse:
@ -256,7 +257,7 @@ class LLMBundle(LLM4Tenant):
for chunk in self.mdl.tts(text):
if isinstance(chunk, int):
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, chunk, self.llm_name):
if not TenantLLMService.increase_usage_by_id(self.model_config["id"], chunk):
logging.error("LLMBundle.tts can't update token usage for {}/TTS".format(self.tenant_id))
return
yield chunk
@ -375,7 +376,7 @@ class LLMBundle(LLM4Tenant):
generation = None
if self.langfuse:
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat", model=self.llm_name, input={"system": system, "history": history})
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat", model=self.model_config["llm_name"], input={"system": system, "history": history})
chat_partial = partial(base_fn, system, history, gen_conf)
use_kwargs = self._clean_param(chat_partial, **kwargs)
@ -392,8 +393,8 @@ class LLMBundle(LLM4Tenant):
if not self.verbose_tool_use:
txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
if used_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
logging.error("LLMBundle.async_chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))
if used_tokens and not TenantLLMService.increase_usage_by_id(self.model_config["id"], used_tokens):
logging.error("LLMBundle.async_chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.model_config["llm_name"], used_tokens))
if generation:
generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens})
@ -413,7 +414,7 @@ class LLMBundle(LLM4Tenant):
generation = None
if self.langfuse:
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat_streamly", model=self.llm_name, input={"system": system, "history": history})
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat_streamly", model=self.model_config["llm_name"], input={"system": system, "history": history})
if stream_fn:
chat_partial = partial(stream_fn, system, history, gen_conf)
@ -437,8 +438,8 @@ class LLMBundle(LLM4Tenant):
generation.update(output={"error": str(e)})
generation.end()
raise
if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
if total_tokens and not TenantLLMService.increase_usage_by_id(self.model_config["id"], total_tokens):
logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.model_config["llm_name"], total_tokens))
if generation:
generation.update(output={"output": ans}, usage_details={"total_tokens": total_tokens})
generation.end()
@ -456,7 +457,7 @@ class LLMBundle(LLM4Tenant):
generation = None
if self.langfuse:
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat_streamly", model=self.llm_name, input={"system": system, "history": history})
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat_streamly", model=self.model_config["llm_name"], input={"system": system, "history": history})
if stream_fn:
chat_partial = partial(stream_fn, system, history, gen_conf)
@ -480,8 +481,8 @@ class LLMBundle(LLM4Tenant):
generation.update(output={"error": str(e)})
generation.end()
raise
if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
if total_tokens and not TenantLLMService.increase_usage_by_id(self.model_config["id"], total_tokens):
logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.model_config["llm_name"], total_tokens))
if generation:
generation.update(output={"output": ans}, usage_details={"total_tokens": total_tokens})
generation.end()

View File

@ -107,7 +107,7 @@ class MemoryService(CommonService):
@classmethod
@DB.connection_context()
def create_memory(cls, tenant_id: str, name: str, memory_type: List[str], embd_id: str, llm_id: str):
def create_memory(cls, tenant_id: str, name: str, memory_type: List[str], embd_id: str, tenant_embd_id: int, llm_id: str, tenant_llm_id: int):
# Deduplicate name within tenant
memory_name = duplicate_name(
cls.query,
@ -126,7 +126,9 @@ class MemoryService(CommonService):
"memory_type": calculate_memory_type(memory_type),
"tenant_id": tenant_id,
"embd_id": embd_id,
"tenant_embd_id": tenant_embd_id,
"llm_id": llm_id,
"tenant_llm_id": tenant_llm_id,
"system_prompt": PromptAssembler.assemble_system_prompt({"memory_type": memory_type}),
"create_time": timestamp,
"create_date": format_time,
@ -168,3 +170,25 @@ class MemoryService(CommonService):
@DB.connection_context()
def delete_memory(cls, memory_id: str):
return cls.delete_by_id(memory_id)
@classmethod
@DB.connection_context()
def get_null_tenant_embd_id_row(cls):
    """Fetch memories whose ``tenant_embd_id`` has not been backfilled.

    Only the columns needed to resolve the tenant_llm id are selected:
    id, tenant_id and the legacy ``embd_id`` model name.
    """
    query = cls.model.select(
        cls.model.id,
        cls.model.tenant_id,
        cls.model.embd_id,
    ).where(cls.model.tenant_embd_id.is_null())
    return [row for row in query]
@classmethod
@DB.connection_context()
def get_null_tenant_llm_id_row(cls):
    """Fetch memories whose ``tenant_llm_id`` has not been backfilled.

    Only the columns needed to resolve the tenant_llm id are selected:
    id, tenant_id and the legacy ``llm_id`` model name.
    """
    query = cls.model.select(
        cls.model.id,
        cls.model.tenant_id,
        cls.model.llm_id,
    ).where(cls.model.tenant_llm_id.is_null())
    return [row for row in query]

View File

@ -60,7 +60,7 @@ class TenantLLMService(CommonService):
@classmethod
@DB.connection_context()
def get_my_llms(cls, tenant_id):
fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name, cls.model.used_tokens, cls.model.status]
fields = [cls.model.id, cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name, cls.model.used_tokens, cls.model.status]
objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(cls.model.tenant_id == tenant_id, ~cls.model.api_key.is_null()).dicts()
return list(objs)
@ -133,34 +133,35 @@ class TenantLLMService(CommonService):
@classmethod
@DB.connection_context()
def model_instance(cls, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs):
model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
def model_instance(cls, model_config: dict, lang="Chinese", **kwargs):
if not model_config:
raise LookupError("Model config is required")
kwargs.update({"provider": model_config["llm_factory"]})
if llm_type == LLMType.EMBEDDING.value:
if model_config["model_type"] == LLMType.EMBEDDING.value:
if model_config["llm_factory"] not in EmbeddingModel:
return None
return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
elif llm_type == LLMType.RERANK:
elif model_config["model_type"] == LLMType.RERANK:
if model_config["llm_factory"] not in RerankModel:
return None
return RerankModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
elif llm_type == LLMType.IMAGE2TEXT.value:
elif model_config["model_type"] == LLMType.IMAGE2TEXT.value:
if model_config["llm_factory"] not in CvModel:
return None
return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang, base_url=model_config["api_base"], **kwargs)
elif llm_type == LLMType.CHAT.value:
elif model_config["model_type"] == LLMType.CHAT.value:
if model_config["llm_factory"] not in ChatModel:
return None
return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"], **kwargs)
elif llm_type == LLMType.SPEECH2TEXT:
elif model_config["model_type"] == LLMType.SPEECH2TEXT:
if model_config["llm_factory"] not in Seq2txtModel:
return None
return Seq2txtModel[model_config["llm_factory"]](key=model_config["api_key"], model_name=model_config["llm_name"], lang=lang, base_url=model_config["api_base"])
elif llm_type == LLMType.TTS:
elif model_config["model_type"] == LLMType.TTS:
if model_config["llm_factory"] not in TTSModel:
return None
return TTSModel[model_config["llm_factory"]](
@ -169,7 +170,7 @@ class TenantLLMService(CommonService):
base_url=model_config["api_base"],
)
elif llm_type == LLMType.OCR:
elif model_config["model_type"] == LLMType.OCR:
if model_config["llm_factory"] not in OcrModel:
return None
return OcrModel[model_config["llm_factory"]](
@ -218,6 +219,16 @@ class TenantLLMService(CommonService):
return num
@classmethod
@DB.connection_context()
def increase_usage_by_id(cls, tenant_model_id: int, used_tokens: int):
    """Add *used_tokens* to the tenant_llm row identified by its primary key.

    The increment is performed server-side in a single UPDATE
    (``used_tokens = used_tokens + n``) to avoid a read-modify-write race.

    Returns the number of rows updated (0 when the id does not exist or the
    update fails). Never raises: token accounting must not break callers.
    """
    try:
        update_cnt = (
            cls.model.update(used_tokens=cls.model.used_tokens + used_tokens)
            .where(cls.model.id == tenant_model_id)
            .execute()
        )
    except Exception:
        # logging.exception already records the exception and traceback;
        # use lazy %-formatting and name the actual method for accurate logs.
        logging.exception(
            "TenantLLMService.increase_usage_by_id failed to update used_tokens for tenant_model_id %s",
            tenant_model_id,
        )
        return 0
    return update_cnt
@classmethod
@DB.connection_context()
def get_openai_models(cls):
@ -376,13 +387,12 @@ class TenantLLMService(CommonService):
class LLM4Tenant:
def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs):
def __init__(self, tenant_id: str, model_config: dict, lang="Chinese", **kwargs):
self.tenant_id = tenant_id
self.llm_type = llm_type
self.llm_name = llm_name
self.mdl = TenantLLMService.model_instance(tenant_id, llm_type, llm_name, lang=lang, **kwargs)
assert self.mdl, "Can't find model for {}/{}/{}".format(tenant_id, llm_type, llm_name)
model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
self.llm_name = model_config["llm_name"]
self.model_config = model_config
self.mdl = TenantLLMService.model_instance(model_config, lang=lang, **kwargs)
assert self.mdl, "Can't find model for {}/{}/{}".format(tenant_id, model_config["llm_type"], model_config["llm_name"])
self.max_length = model_config.get("max_tokens", 8192)
self.is_tools = model_config.get("is_tools", False)

View File

@ -226,6 +226,12 @@ class TenantService(CommonService):
hash_obj = hashlib.sha256(tenant_id.encode("utf-8"))
return int(hash_obj.hexdigest(), 16)%len(settings.MINIO)
@classmethod
@DB.connection_context()
def get_null_tenant_model_id_rows(cls):
    """Return tenants missing at least one backfilled tenant_llm id.

    A tenant qualifies when ANY of its per-model-type ``tenant_*_id``
    columns is still NULL (chat, embedding, ASR, TTS, rerank, image2text).
    """
    # Explicit OR chain; equivalent to orwhere(...) which OR-reduces its args.
    missing_any = (
        cls.model.tenant_llm_id.is_null()
        | cls.model.tenant_embd_id.is_null()
        | cls.model.tenant_asr_id.is_null()
        | cls.model.tenant_tts_id.is_null()
        | cls.model.tenant_rerank_id.is_null()
        | cls.model.tenant_img2txt_id.is_null()
    )
    return list(cls.model.select().where(missing_any))
class UserTenantService(CommonService):
"""Service class for managing user-tenant relationship operations.