mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-05-05 09:47:47 +08:00
Feat/tenant model (#13072)
### What problem does this PR solve? Adds an `id` primary-key column to the `tenant_llm` table and applies it in `LLMBundle`. ### Type of change - [x] Refactoring --------- Co-authored-by: Yingfeng <yingfeng.zhang@gmail.com> Co-authored-by: Liu An <asiro@qq.com>
This commit is contained in:
@ -43,6 +43,7 @@ from peewee import (
|
||||
Metadata,
|
||||
Model,
|
||||
TextField,
|
||||
PrimaryKeyField,
|
||||
)
|
||||
from playhouse.migrate import MySQLMigrator, PostgresqlMigrator, migrate
|
||||
from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase
|
||||
@ -737,11 +738,17 @@ class Tenant(DataBaseModel):
|
||||
name = CharField(max_length=100, null=True, help_text="Tenant name", index=True)
|
||||
public_key = CharField(max_length=255, null=True, index=True)
|
||||
llm_id = CharField(max_length=128, null=False, help_text="default llm ID", index=True)
|
||||
tenant_llm_id = IntegerField(null=True, help_text="id in tenant_llm", index=True)
|
||||
embd_id = CharField(max_length=128, null=False, help_text="default embedding model ID", index=True)
|
||||
tenant_embd_id = IntegerField(null=True, help_text="id in tenant_llm", index=True)
|
||||
asr_id = CharField(max_length=128, null=False, help_text="default ASR model ID", index=True)
|
||||
tenant_asr_id = IntegerField(null=True, help_text="id in tenant_llm", index=True)
|
||||
img2txt_id = CharField(max_length=128, null=False, help_text="default image to text model ID", index=True)
|
||||
tenant_img2txt_id = IntegerField(null=True, help_text="id in tenant_llm", index=True)
|
||||
rerank_id = CharField(max_length=128, null=False, help_text="default rerank model ID", index=True)
|
||||
tenant_rerank_id = IntegerField(null=True, help_text="id in tenant_llm", index=True)
|
||||
tts_id = CharField(max_length=256, null=True, help_text="default tts model ID", index=True)
|
||||
tenant_tts_id = IntegerField(null=True, help_text="id in tenant_llm", index=True)
|
||||
parser_ids = CharField(max_length=256, null=False, help_text="document processors", index=True)
|
||||
credit = IntegerField(default=512, index=True)
|
||||
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)
|
||||
@ -808,14 +815,15 @@ class LLM(DataBaseModel):
|
||||
|
||||
|
||||
class TenantLLM(DataBaseModel):
|
||||
id = PrimaryKeyField()
|
||||
tenant_id = CharField(max_length=32, null=False, index=True)
|
||||
llm_factory = CharField(max_length=128, null=False, help_text="LLM factory name", index=True)
|
||||
model_type = CharField(max_length=128, null=True, help_text="LLM, Text Embedding, Image2Text, ASR", index=True)
|
||||
llm_name = CharField(max_length=128, null=True, help_text="LLM name", default="", index=True)
|
||||
api_key = TextField(null=True, help_text="API KEY")
|
||||
api_base = CharField(max_length=255, null=True, help_text="API Base")
|
||||
max_tokens = IntegerField(default=8192, index=True)
|
||||
used_tokens = IntegerField(default=0, index=True)
|
||||
max_tokens = IntegerField(default=8192, help_text="Max context token num", index=True)
|
||||
used_tokens = IntegerField(default=0, help_text="Used token num", index=True)
|
||||
status = CharField(max_length=1, null=False, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)
|
||||
|
||||
def __str__(self):
|
||||
@ -823,7 +831,9 @@ class TenantLLM(DataBaseModel):
|
||||
|
||||
class Meta:
|
||||
db_table = "tenant_llm"
|
||||
primary_key = CompositeKey("tenant_id", "llm_factory", "llm_name")
|
||||
indexes = (
|
||||
(("tenant_id", "llm_factory", "llm_name"), True),
|
||||
)
|
||||
|
||||
|
||||
class TenantLangfuse(DataBaseModel):
|
||||
@ -847,6 +857,7 @@ class Knowledgebase(DataBaseModel):
|
||||
language = CharField(max_length=32, null=True, default="Chinese" if "zh_CN" in os.getenv("LANG", "") else "English", help_text="English|Chinese", index=True)
|
||||
description = TextField(null=True, help_text="KB description")
|
||||
embd_id = CharField(max_length=128, null=False, help_text="default embedding model ID", index=True)
|
||||
tenant_embd_id = IntegerField(null=True, help_text="id in tenant_llm", index=True)
|
||||
permission = CharField(max_length=16, null=False, help_text="me|team", default="me", index=True)
|
||||
created_by = CharField(max_length=32, null=False, index=True)
|
||||
doc_num = IntegerField(default=0, index=True)
|
||||
@ -954,6 +965,7 @@ class Dialog(DataBaseModel):
|
||||
icon = TextField(null=True, help_text="icon base64 string")
|
||||
language = CharField(max_length=32, null=True, default="Chinese" if "zh_CN" in os.getenv("LANG", "") else "English", help_text="English|Chinese", index=True)
|
||||
llm_id = CharField(max_length=128, null=False, help_text="default llm ID")
|
||||
tenant_llm_id = IntegerField(null=True, help_text="id in tenant_llm", index=True)
|
||||
|
||||
llm_setting = JSONField(null=False, default={"temperature": 0.1, "top_p": 0.3, "frequency_penalty": 0.7, "presence_penalty": 0.4, "max_tokens": 512})
|
||||
prompt_type = CharField(max_length=16, null=False, default="simple", help_text="simple|advanced", index=True)
|
||||
@ -973,7 +985,7 @@ class Dialog(DataBaseModel):
|
||||
do_refer = CharField(max_length=1, null=False, default="1", help_text="it needs to insert reference index into answer or not")
|
||||
|
||||
rerank_id = CharField(max_length=128, null=False, help_text="default rerank model ID")
|
||||
|
||||
tenant_rerank_id = IntegerField(null=True, help_text="id in tenant_llm", index=True)
|
||||
kb_ids = JSONField(null=False, default=[])
|
||||
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)
|
||||
|
||||
@ -1295,7 +1307,9 @@ class Memory(DataBaseModel):
|
||||
memory_type = IntegerField(null=False, default=1, index=True, help_text="Bit flags (LSB->MSB): 1=raw, 2=semantic, 4=episodic, 8=procedural. E.g., 5 enables raw + episodic.")
|
||||
storage_type = CharField(max_length=32, default='table', null=False, index=True, help_text="table|graph")
|
||||
embd_id = CharField(max_length=128, null=False, index=False, help_text="embedding model ID")
|
||||
tenant_embd_id = IntegerField(null=True, help_text="id in tenant_llm", index=True)
|
||||
llm_id = CharField(max_length=128, null=False, index=False, help_text="chat model ID")
|
||||
tenant_llm_id = IntegerField(null=True, help_text="id in tenant_llm", index=True)
|
||||
permissions = CharField(max_length=16, null=False, index=True, help_text="me|team", default="me")
|
||||
description = TextField(null=True, help_text="description")
|
||||
memory_size = IntegerField(default=5242880, null=False, index=False)
|
||||
@ -1351,6 +1365,23 @@ def alter_db_rename_column(migrator, table_name, old_column_name, new_column_nam
|
||||
|
||||
def migrate_add_unique_email(migrator):
|
||||
"""Deduplicates user emails and add UNIQUE constraint to email column (idempotent)"""
|
||||
# step 0: check if UNIQUE index on email already exists
|
||||
try:
|
||||
cursor = DB.execute_sql("""
|
||||
SELECT COUNT(*)
|
||||
FROM information_schema.statistics
|
||||
WHERE table_schema = DATABASE()
|
||||
AND table_name = 'user'
|
||||
AND index_name = 'user_email'
|
||||
AND non_unique = 0
|
||||
""")
|
||||
result = cursor.fetchone()
|
||||
if result and result[0] > 0:
|
||||
logging.info("UNIQUE index on user.email already exists, skipping migration")
|
||||
return
|
||||
except Exception as ex:
|
||||
logging.warning("Failed to check if UNIQUE index exists on user.email: %s, continuing with migration", ex)
|
||||
|
||||
# step 1: rename duplicate rows so the UNIQUE constraint can be applied
|
||||
try:
|
||||
duplicates = User.select(User.email).group_by(User.email).having(fn.COUNT(User.id) > 1).tuples()
|
||||
@ -1385,6 +1416,62 @@ def migrate_add_unique_email(migrator):
|
||||
logging.critical("Failed to add UNIQUE constraint on user.email: %s", ex)
|
||||
|
||||
|
||||
|
||||
def update_tenant_llm_to_id_primary_key():
    """Add an auto-increment ``id`` primary key to ``tenant_llm``, step by step (idempotent).

    Replaces the composite primary key (tenant_id, llm_factory, llm_name) with a
    surrogate ``id`` column and keeps the old triple as a UNIQUE constraint.

    NOTE(review): the raw SQL here (INFORMATION_SCHEMA, ``SET @row``, ``UPDATE ...
    ORDER BY``) is MySQL-specific — confirm this migration is never run against
    the PostgreSQL backend.
    """
    try:
        with DB.atomic():
            # 0. Idempotence check: bail out if the `id` column already exists.
            cursor = DB.execute_sql("""
                SELECT COLUMN_NAME
                FROM INFORMATION_SCHEMA.COLUMNS
                WHERE TABLE_SCHEMA = DATABASE()
                AND TABLE_NAME = 'tenant_llm'
                AND COLUMN_NAME = 'id'
            """)
            # BUG FIX: use fetchone() instead of cursor.rowcount — PEP 249 allows
            # rowcount to be -1/undefined after a SELECT, which could wrongly
            # re-run (or wrongly skip) this migration. migrate_add_unique_email
            # already uses fetchone() for the same kind of check.
            if cursor.fetchone():
                return

            # 1. Add nullable column
            DB.execute_sql("ALTER TABLE tenant_llm ADD COLUMN temp_id INT NULL")

            # 2. Number existing rows deterministically, ordered by the old composite key
            DB.execute_sql("SET @row = 0;")
            DB.execute_sql("UPDATE tenant_llm SET temp_id = (@row := @row + 1) ORDER BY tenant_id, llm_factory, llm_name;")

            # 3. Drop old primary key
            DB.execute_sql("ALTER TABLE tenant_llm DROP PRIMARY KEY")

            # 4. Promote the new column to an auto-increment primary key
            DB.execute_sql("""
                ALTER TABLE tenant_llm
                MODIFY COLUMN temp_id INT NOT NULL AUTO_INCREMENT PRIMARY KEY
            """)

            # 5. Keep the old composite key unique so existing lookups stay unambiguous
            DB.execute_sql("""
                ALTER TABLE tenant_llm
                ADD CONSTRAINT uk_tenant_llm UNIQUE (tenant_id, llm_factory, llm_name)
            """)

            # 6. Rename to the final column name
            DB.execute_sql("ALTER TABLE tenant_llm RENAME COLUMN temp_id TO id")

            logging.info("Successfully updated tenant_llm to id primary key.")

    except Exception as e:
        # Best-effort cleanup of the half-added column: atomic() rolls back
        # transactional work, but MySQL DDL implicitly commits, so temp_id may
        # have survived the failure.
        logging.error(str(e))
        cursor = DB.execute_sql("""
            SELECT COLUMN_NAME
            FROM INFORMATION_SCHEMA.COLUMNS
            WHERE TABLE_SCHEMA = DATABASE()
            AND TABLE_NAME = 'tenant_llm'
            AND COLUMN_NAME = 'temp_id'
        """)
        # Same fetchone()-over-rowcount fix as above.
        if cursor.fetchone():
            DB.execute_sql("ALTER TABLE tenant_llm DROP COLUMN temp_id")
|
||||
|
||||
|
||||
def migrate_db():
|
||||
logging.disable(logging.ERROR)
|
||||
migrator = DatabaseMigrator[settings.DATABASE_TYPE.upper()].value(DB)
|
||||
@ -1436,6 +1523,18 @@ def migrate_db():
|
||||
alter_db_add_column(migrator, "api_4_conversation", "exp_user_id", CharField(max_length=255, null=True, help_text="exp_user_id", index=True))
|
||||
# Migrate system_settings.value from CharField to TextField for longer sandbox configs
|
||||
alter_db_column_type(migrator, "system_settings", "value", TextField(null=False, help_text="Configuration value (JSON, string, etc.)"))
|
||||
update_tenant_llm_to_id_primary_key()
|
||||
alter_db_add_column(migrator, "tenant", "tenant_llm_id", IntegerField(null=True, help_text="id in tenant_llm", index=True))
|
||||
alter_db_add_column(migrator, "tenant", "tenant_embd_id", IntegerField(null=True, help_text="id in tenant_llm", index=True))
|
||||
alter_db_add_column(migrator, "tenant", "tenant_asr_id", IntegerField(null=True, help_text="id in tenant_llm", index=True))
|
||||
alter_db_add_column(migrator, "tenant", "tenant_img2txt_id", IntegerField(null=True, help_text="id in tenant_llm", index=True))
|
||||
alter_db_add_column(migrator, "tenant", "tenant_rerank_id", IntegerField(null=True, help_text="id in tenant_llm", index=True))
|
||||
alter_db_add_column(migrator, "tenant", "tenant_tts_id", IntegerField(null=True, help_text="id in tenant_llm", index=True))
|
||||
alter_db_add_column(migrator, "knowledgebase", "tenant_embd_id", IntegerField(null=True, help_text="id in tenant_llm", index=True))
|
||||
alter_db_add_column(migrator, "dialog", "tenant_llm_id", IntegerField(null=True, help_text="id in tenant_llm", index=True))
|
||||
alter_db_add_column(migrator, "dialog", "tenant_rerank_id", IntegerField(null=True, help_text="id in tenant_llm", index=True))
|
||||
alter_db_add_column(migrator, "memory", "tenant_embd_id", IntegerField(null=True, help_text="id in tenant_llm", index=True))
|
||||
alter_db_add_column(migrator, "memory", "tenant_llm_id", IntegerField(null=True, help_text="id in tenant_llm", index=True))
|
||||
logging.disable(logging.NOTSET)
|
||||
# this is after re-enabling logging to allow logging changed user emails
|
||||
migrate_add_unique_email(migrator)
|
||||
|
||||
@ -24,16 +24,19 @@ from copy import deepcopy
|
||||
from peewee import IntegrityError
|
||||
|
||||
from api.db import UserTenantRole
|
||||
from api.db.db_models import init_database_tables as init_web_db, LLMFactories, LLM, TenantLLM
|
||||
from api.db.db_models import init_database_tables as init_web_db, LLMFactories, LLM, TenantLLM, Knowledgebase, Dialog, Memory
|
||||
from api.db.services import UserService
|
||||
from api.db.services.canvas_service import CanvasTemplateService
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.memory_service import MemoryService
|
||||
from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService
|
||||
from api.db.services.llm_service import LLMService, LLMBundle, get_init_tenant_llm
|
||||
from api.db.services.user_service import TenantService, UserTenantService
|
||||
from api.db.services.system_settings_service import SystemSettingsService
|
||||
from api.db.services.dialog_service import DialogService
|
||||
from api.db.joint_services.memory_message_service import init_message_id_sequence, init_memory_size_cache, fix_missing_tokenized_memory
|
||||
from api.db.joint_services.tenant_model_service import get_tenant_default_model_by_type
|
||||
from common.constants import LLMType
|
||||
from common.file_utils import get_project_base_directory
|
||||
from common import settings
|
||||
@ -90,13 +93,15 @@ def init_superuser(nickname=DEFAULT_SUPERUSER_NICKNAME, email=DEFAULT_SUPERUSER_
|
||||
f"Super user initialized. email: {email},A default password has been set; changing the password after login is strongly recommended.")
|
||||
|
||||
if tenant["llm_id"]:
|
||||
chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"])
|
||||
chat_model_config = get_tenant_default_model_by_type(tenant["id"], LLMType.CHAT)
|
||||
chat_mdl = LLMBundle(tenant["id"], chat_model_config)
|
||||
msg = asyncio.run(chat_mdl.async_chat(system="", history=[{"role": "user", "content": "Hello!"}], gen_conf={}))
|
||||
if msg.find("ERROR: ") == 0:
|
||||
logging.error("'{}' doesn't work. {}".format( tenant["llm_id"], msg))
|
||||
|
||||
if tenant["embd_id"]:
|
||||
embd_mdl = LLMBundle(tenant["id"], LLMType.EMBEDDING, tenant["embd_id"])
|
||||
embd_model_config = get_tenant_default_model_by_type(tenant["id"], LLMType.EMBEDDING)
|
||||
embd_mdl = LLMBundle(tenant["id"], embd_model_config)
|
||||
v, c = embd_mdl.encode(["Hello!"])
|
||||
if c == 0:
|
||||
logging.error("'{}' doesn't work!".format(tenant["embd_id"]))
|
||||
@ -185,6 +190,7 @@ def init_web_data():
|
||||
init_message_id_sequence()
|
||||
init_memory_size_cache()
|
||||
fix_missing_tokenized_memory()
|
||||
fix_empty_tenant_model_id()
|
||||
logging.info("init web data success:{}".format(time.time() - start_time))
|
||||
|
||||
def init_table():
|
||||
@ -213,6 +219,105 @@ def init_table():
|
||||
raise e
|
||||
|
||||
|
||||
def _backfill_model_refs(rows, model_id_attr, service, model_cls, target_field, table_name, row_noun):
    """Back-fill one tenant_llm id-reference column for rows where it is still NULL.

    Args:
        rows: rows exposing ``.id``, ``.tenant_id`` and the attribute named by model_id_attr.
        model_id_attr: name of the column holding the model name (e.g. "llm_id").
        service: the service class whose ``filter_update`` performs the write.
        model_cls: the peewee model class (for the ``id.in_`` filter).
        target_field: the tenant_llm reference column to fill (e.g. "tenant_llm_id").
        table_name: table name, used only for the final log message.
        row_noun: plural noun for the "Found ..." log message.
    """
    if not rows:
        return
    logging.info(f"Found {len(rows)} empty {target_field} {row_noun}.")
    # Group row ids by (tenant_id, model_name) so each distinct pair needs one lookup.
    groups: dict = {}
    for obj in rows:
        groups.setdefault((obj.tenant_id, getattr(obj, model_id_attr)), []).append(obj.id)
    update_cnt = 0
    for (tenant_id, model_name), ids in groups.items():
        tenant_llm = TenantLLMService.get_api_key(tenant_id, model_name)
        if tenant_llm:
            update_cnt += service.filter_update([model_cls.id.in_(ids)], {target_field: tenant_llm.id})
    logging.info(f"Update {update_cnt} {target_field} in table {table_name}.")


def fix_empty_tenant_model_id():
    """Back-fill the tenant_llm id references (tenant_*_id columns) left NULL by the migration."""
    # knowledgebase
    _backfill_model_refs(KnowledgebaseService.get_null_tenant_embd_id_row(), "embd_id",
                         KnowledgebaseService, Knowledgebase, "tenant_embd_id", "knowledgebase", "knowledgebase")
    # dialog
    _backfill_model_refs(DialogService.get_null_tenant_llm_id_row(), "llm_id",
                         DialogService, Dialog, "tenant_llm_id", "dialog", "dialogs")
    _backfill_model_refs(DialogService.get_null_tenant_rerank_id_row(), "rerank_id",
                         DialogService, Dialog, "tenant_rerank_id", "dialog", "dialogs")
    # memory
    _backfill_model_refs(MemoryService.get_null_tenant_embd_id_row(), "embd_id",
                         MemoryService, Memory, "tenant_embd_id", "memory", "memories")
    _backfill_model_refs(MemoryService.get_null_tenant_llm_id_row(), "llm_id",
                         MemoryService, Memory, "tenant_llm_id", "memory", "memories")
    # tenant: one row may need several tenant_*_id columns filled in a single update
    empty_tenant_model_id_tenants = TenantService.get_null_tenant_model_id_rows()
    if empty_tenant_model_id_tenants:
        logging.info(f"Found {len(empty_tenant_model_id_tenants)} empty tenant_model_id tenants.")
        update_cnt = 0
        for obj in empty_tenant_model_id_tenants:
            tenant_dict = obj.to_dict()
            update_dict = {}
            for key in ["llm_id", "embd_id", "asr_id", "img2txt_id", "rerank_id", "tts_id"]:
                if tenant_dict.get(key) and not tenant_dict.get(f"tenant_{key}"):
                    tenant_model = TenantLLMService.get_api_key(tenant_dict["id"], tenant_dict[key])
                    if tenant_model:
                        update_dict[f"tenant_{key}"] = tenant_model.id
            if update_dict:
                # BUG FIX: original did `update_dict += TenantService.update_by_id(...)`,
                # which raises TypeError (dict += int) and never advanced the counter.
                update_cnt += TenantService.update_by_id(tenant_dict["id"], update_dict)
        logging.info(f"Update {update_cnt} tenant_model_id in table tenant.")
    logging.info("Fix empty tenant_model_id done.")
|
||||
|
||||
if __name__ == '__main__':
|
||||
init_web_db()
|
||||
init_web_data()
|
||||
|
||||
@ -25,8 +25,8 @@ from api.db.db_utils import bulk_insert_into_db
|
||||
from api.db.db_models import Task
|
||||
from api.db.services.task_service import TaskService
|
||||
from api.db.services.memory_service import MemoryService
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.joint_services.tenant_model_service import get_model_config_by_id, get_model_config_by_type_and_name
|
||||
from api.utils.memory_utils import get_memory_type_human
|
||||
from memory.services.messages import MessageService
|
||||
from memory.services.query import MsgTextQuery, get_vector
|
||||
@ -53,11 +53,12 @@ async def save_to_memory(memory_id: str, message_dict: dict):
|
||||
tenant_id = memory.tenant_id
|
||||
extracted_content = await extract_by_llm(
|
||||
tenant_id,
|
||||
memory.llm_id,
|
||||
memory.tenant_llm_id,
|
||||
{"temperature": memory.temperature},
|
||||
get_memory_type_human(memory.memory_type),
|
||||
message_dict.get("user_input", ""),
|
||||
message_dict.get("agent_response", "")
|
||||
message_dict.get("agent_response", ""),
|
||||
llm_id=memory.llm_id
|
||||
) if memory.memory_type != MemoryType.RAW.value else [] # if only RAW, no need to extract
|
||||
raw_message_id = REDIS_CONN.generate_auto_increment_id(namespace="memory")
|
||||
message_list = [{
|
||||
@ -107,12 +108,13 @@ async def save_extracted_to_memory_only(memory_id: str, message_dict, source_mes
|
||||
tenant_id = memory.tenant_id
|
||||
extracted_content = await extract_by_llm(
|
||||
tenant_id,
|
||||
memory.llm_id,
|
||||
memory.tenant_llm_id,
|
||||
{"temperature": memory.temperature},
|
||||
get_memory_type_human(memory.memory_type),
|
||||
message_dict.get("user_input", ""),
|
||||
message_dict.get("agent_response", ""),
|
||||
task_id=task_id
|
||||
task_id=task_id,
|
||||
llm_id=memory.llm_id
|
||||
)
|
||||
message_list = [{
|
||||
"message_id": REDIS_CONN.generate_auto_increment_id(namespace="memory"),
|
||||
@ -139,11 +141,8 @@ async def save_extracted_to_memory_only(memory_id: str, message_dict, source_mes
|
||||
return await embed_and_save(memory, message_list, task_id)
|
||||
|
||||
|
||||
async def extract_by_llm(tenant_id: str, llm_id: str, extract_conf: dict, memory_type: List[str], user_input: str,
|
||||
agent_response: str, system_prompt: str = "", user_prompt: str="", task_id: str=None) -> List[dict]:
|
||||
llm_type = TenantLLMService.llm_id2llm_type(llm_id)
|
||||
if not llm_type:
|
||||
raise RuntimeError(f"Unknown type of LLM '{llm_id}'")
|
||||
async def extract_by_llm(tenant_id: str, tenant_llm_id: int, extract_conf: dict, memory_type: List[str], user_input: str,
|
||||
agent_response: str, system_prompt: str = "", user_prompt: str="", task_id: str=None, llm_id: str = "") -> List[dict]:
|
||||
if not system_prompt:
|
||||
system_prompt = PromptAssembler.assemble_system_prompt({"memory_type": memory_type})
|
||||
conversation_content = f"User Input: {user_input}\nAgent Response: {agent_response}"
|
||||
@ -154,7 +153,11 @@ async def extract_by_llm(tenant_id: str, llm_id: str, extract_conf: dict, memory
|
||||
user_prompts.append({"role": "user", "content": f"Conversation: {conversation_content}\nConversation Time: {conversation_time}\nCurrent Time: {conversation_time}"})
|
||||
else:
|
||||
user_prompts.append({"role": "user", "content": PromptAssembler.assemble_user_prompt(conversation_content, conversation_time, conversation_time)})
|
||||
llm = LLMBundle(tenant_id, llm_type, llm_id)
|
||||
if tenant_llm_id:
|
||||
llm_config = get_model_config_by_id(tenant_llm_id)
|
||||
else:
|
||||
llm_config = get_model_config_by_type_and_name(tenant_id, LLMType.CHAT, llm_id)
|
||||
llm = LLMBundle(tenant_id, llm_config)
|
||||
if task_id:
|
||||
TaskService.update_progress(task_id, {"progress": 0.15, "progress_msg": timestamp_to_date(current_timestamp())+ " " + "Prepared prompts and LLM."})
|
||||
res = await llm.async_chat(system_prompt, user_prompts, extract_conf)
|
||||
@ -170,7 +173,11 @@ async def extract_by_llm(tenant_id: str, llm_id: str, extract_conf: dict, memory
|
||||
|
||||
|
||||
async def embed_and_save(memory, message_list: list[dict], task_id: str=None):
|
||||
embedding_model = LLMBundle(memory.tenant_id, llm_type=LLMType.EMBEDDING, llm_name=memory.embd_id)
|
||||
if memory.tenant_embd_id:
|
||||
embd_model_config = get_model_config_by_id(memory.tenant_embd_id)
|
||||
else:
|
||||
embd_model_config = get_model_config_by_type_and_name(memory.tenant_id, LLMType.EMBEDDING, memory.embd_id)
|
||||
embedding_model = LLMBundle(memory.tenant_id, embd_model_config)
|
||||
if task_id:
|
||||
TaskService.update_progress(task_id, {"progress": 0.65, "progress_msg": timestamp_to_date(current_timestamp())+ " " + "Prepared embedding model."})
|
||||
vector_list, _ = embedding_model.encode([msg["content"] for msg in message_list])
|
||||
@ -239,7 +246,11 @@ def query_message(filter_dict: dict, params: dict):
|
||||
question = params["query"]
|
||||
question = question.strip()
|
||||
memory = memory_list[0]
|
||||
embd_model = LLMBundle(memory.tenant_id, llm_type=LLMType.EMBEDDING, llm_name=memory.embd_id)
|
||||
if memory.tenant_embd_id:
|
||||
embd_model_config = get_model_config_by_id(memory.tenant_embd_id)
|
||||
else:
|
||||
embd_model_config = get_model_config_by_type_and_name(memory.tenant_id, LLMType.EMBEDDING, memory.embd_id)
|
||||
embd_model = LLMBundle(memory.tenant_id, embd_model_config)
|
||||
match_dense = get_vector(question, embd_model, similarity=params["similarity_threshold"])
|
||||
match_text, _ = MsgTextQuery().question(question, min_match=params["similarity_threshold"])
|
||||
keywords_similarity_weight = params.get("keywords_similarity_weight", 0.7)
|
||||
|
||||
91
api/db/joint_services/tenant_model_service.py
Normal file
91
api/db/joint_services/tenant_model_service.py
Normal file
@ -0,0 +1,91 @@
|
||||
#
|
||||
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import os
|
||||
import enum
|
||||
from common import settings
|
||||
from common.constants import LLMType
|
||||
from api.db.services.llm_service import LLMService
|
||||
from api.db.services.tenant_llm_service import TenantLLMService, TenantService
|
||||
|
||||
|
||||
def get_model_config_by_id(tenant_model_id: int) -> dict:
    """Load a tenant_llm row by its primary key and return it as a config dict.

    Raises:
        LookupError: if no tenant_llm row exists with that id.
    """
    found, record = TenantLLMService.get_by_id(tenant_model_id)
    if not found:
        raise LookupError(f"Tenant Model with id {tenant_model_id} not found")
    cfg = record.to_dict()
    # Annotate with tool-calling capability when the LLM catalog knows this model.
    matches = LLMService.query(llm_name=cfg["llm_name"])
    if matches:
        cfg["is_tools"] = matches[0].is_tools
    return cfg
|
||||
|
||||
|
||||
def get_model_config_by_type_and_name(tenant_id: str, model_type: str, model_name: str):
    """Resolve a tenant's model configuration from a model name.

    Tries the raw name first; on a miss, splits a 'name@factory' style name and
    either synthesizes a config for the locally-served TEI embedding model or
    retries with the bare name.

    Raises:
        Exception: if model_name is empty.
        LookupError: if no matching tenant model can be found.
    """
    if not model_name:
        raise Exception("Model Name is required")

    record = TenantLLMService.get_api_key(tenant_id, model_name)
    if record:
        # model_name without @factory — direct hit
        config_dict = record.to_dict()
    else:
        # model_name in format 'name@factory': split and retry
        pure_model_name, fid = TenantLLMService.split_model_name_and_factory(model_name)
        is_local_tei_embedding = (
            model_type == LLMType.EMBEDDING
            and fid == "Builtin"
            and "tei-" in os.getenv("COMPOSE_PROFILES", "")
            and pure_model_name == os.getenv("TEI_MODEL", "")
        )
        if is_local_tei_embedding:
            # Configured local embedding model served via TEI — build config from settings.
            embedding_cfg = settings.EMBEDDING_CFG
            config_dict = {
                "llm_factory": "Builtin",
                "api_key": embedding_cfg["api_key"],
                "llm_name": pure_model_name,
                "api_base": embedding_cfg["base_url"],
                "model_type": LLMType.EMBEDDING,
            }
        else:
            record = TenantLLMService.get_api_key(tenant_id, pure_model_name)
            if not record:
                raise LookupError(f"Tenant Model with name {model_name} not found")
            config_dict = record.to_dict()

    # Annotate with tool-calling capability when the LLM catalog knows this model.
    llm = LLMService.query(llm_name=config_dict["llm_name"])
    if llm:
        config_dict["is_tools"] = llm[0].is_tools
    return config_dict
|
||||
|
||||
|
||||
def get_tenant_default_model_by_type(tenant_id: str, model_type: str|enum.Enum):
    """Return the config of the tenant's default model for the given type.

    Raises:
        LookupError: if the tenant does not exist.
        Exception: for OCR (no per-tenant default), unknown types, or an unset default.
    """
    exist, tenant = TenantService.get_by_id(tenant_id)
    if not exist:
        raise LookupError("Tenant not found")
    model_type_val = model_type if isinstance(model_type, str) else model_type.value
    # Map each model type to the tenant column that stores its default model name.
    defaults = {
        LLMType.EMBEDDING.value: tenant.embd_id,
        LLMType.SPEECH2TEXT.value: tenant.asr_id,
        LLMType.IMAGE2TEXT.value: tenant.img2txt_id,
        LLMType.CHAT.value: tenant.llm_id,
        LLMType.RERANK.value: tenant.rerank_id,
        LLMType.TTS.value: tenant.tts_id,
    }
    if model_type_val == LLMType.OCR.value:
        # OCR has no tenant-level default; callers must name the model explicitly.
        raise Exception("OCR model name is required")
    if model_type_val not in defaults:
        raise Exception(f"Unknown model type {model_type}")
    model_name = defaults[model_type_val]
    if not model_name:
        raise Exception(f"No default {model_type} model is set.")
    return get_model_config_by_type_and_name(tenant_id, model_type, model_name)
|
||||
@ -34,6 +34,7 @@ from api.db.services.langfuse_service import TenantLangfuseService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from common.metadata_utils import apply_meta_data_filter
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.db.joint_services.tenant_model_service import get_model_config_by_id, get_model_config_by_type_and_name, get_tenant_default_model_by_type
|
||||
from common.time_utils import current_timestamp, datetime_format
|
||||
from common.text_utils import normalize_arabic_digits
|
||||
from rag.graphrag.general.mind_map_extractor import MindMapExtractor
|
||||
@ -179,6 +180,28 @@ class DialogService(CommonService):
|
||||
offset += limit
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_null_tenant_llm_id_row(cls):
|
||||
fields = [
|
||||
cls.model.id,
|
||||
cls.model.tenant_id,
|
||||
cls.model.llm_id
|
||||
]
|
||||
objs = cls.model.select(*fields).where(cls.model.tenant_llm_id.is_null())
|
||||
return list(objs)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_null_tenant_rerank_id_row(cls):
|
||||
fields = [
|
||||
cls.model.id,
|
||||
cls.model.tenant_id,
|
||||
cls.model.rerank_id
|
||||
]
|
||||
objs = cls.model.select(*fields).where(cls.model.tenant_rerank_id.is_null())
|
||||
return list(objs)
|
||||
|
||||
|
||||
async def async_chat_solo(dialog, messages, stream=True):
|
||||
llm_type = TenantLLMService.llm_id2llm_type(dialog.llm_id)
|
||||
@ -191,22 +214,15 @@ async def async_chat_solo(dialog, messages, stream=True):
|
||||
else:
|
||||
text_attachments, image_files = split_file_attachments(messages[-1]["files"], raw=True)
|
||||
attachments = "\n\n".join(text_attachments)
|
||||
|
||||
if llm_type == "image2text":
|
||||
llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
|
||||
else:
|
||||
llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
|
||||
factory = llm_model_config.get("llm_factory", "") if llm_model_config else ""
|
||||
|
||||
if llm_type == "image2text":
|
||||
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
|
||||
else:
|
||||
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
|
||||
model_config = get_model_config_by_id(dialog.tenant_llm_id)
|
||||
chat_mdl = LLMBundle(dialog.tenant_id, model_config)
|
||||
factory = model_config.get("llm_factory", "") if model_config else ""
|
||||
|
||||
prompt_config = dialog.prompt_config
|
||||
tts_mdl = None
|
||||
if prompt_config.get("tts"):
|
||||
tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
|
||||
default_tts_model = get_tenant_default_model_by_type(dialog.tenant_id, LLMType.TTS)
|
||||
tts_mdl = LLMBundle(dialog.tenant_id, default_tts_model)
|
||||
msg = [{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])} for m in messages if m["role"] != "system"]
|
||||
if attachments and msg:
|
||||
msg[-1]["content"] += attachments
|
||||
@ -241,20 +257,27 @@ def get_models(dialog):
|
||||
raise Exception("**ERROR**: Knowledge bases use different embedding models.")
|
||||
|
||||
if embedding_list:
|
||||
embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embedding_list[0])
|
||||
embd_model_config = get_model_config_by_type_and_name(dialog.tenant_id, LLMType.EMBEDDING, embedding_list[0])
|
||||
embd_mdl = LLMBundle(dialog.tenant_id, embd_model_config)
|
||||
if not embd_mdl:
|
||||
raise LookupError("Embedding model(%s) not found" % embedding_list[0])
|
||||
|
||||
if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text":
|
||||
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
|
||||
if dialog.tenant_llm_id:
|
||||
chat_model_config = get_model_config_by_id(dialog.tenant_llm_id)
|
||||
elif dialog.llm_id:
|
||||
chat_model_config = get_model_config_by_type_and_name(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
|
||||
else:
|
||||
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
|
||||
chat_model_config = get_tenant_default_model_by_type(dialog.tenant_id, LLMType.CHAT)
|
||||
|
||||
chat_mdl = LLMBundle(dialog.tenant_id, chat_model_config)
|
||||
|
||||
if dialog.rerank_id:
|
||||
rerank_mdl = LLMBundle(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id)
|
||||
rerank_model_config = get_model_config_by_type_and_name(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id)
|
||||
rerank_mdl = LLMBundle(dialog.tenant_id, rerank_model_config)
|
||||
|
||||
if dialog.prompt_config.get("tts"):
|
||||
tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
|
||||
default_tts_model_config = get_tenant_default_model_by_type(dialog.tenant_id, LLMType.TTS)
|
||||
tts_mdl = LLMBundle(dialog.tenant_id, default_tts_model_config)
|
||||
return kbs, embd_mdl, rerank_mdl, chat_mdl, tts_mdl
|
||||
|
||||
|
||||
@ -603,8 +626,9 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
|
||||
kbinfos["chunks"].extend(tav_res["chunks"])
|
||||
kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
|
||||
if prompt_config.get("use_kg"):
|
||||
default_chat_model = get_tenant_default_model_by_type(dialog.tenant_id, LLMType.CHAT)
|
||||
ck = await settings.kg_retriever.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl,
|
||||
LLMBundle(dialog.tenant_id, LLMType.CHAT))
|
||||
LLMBundle(dialog.tenant_id, default_chat_model))
|
||||
if ck["content_with_weight"]:
|
||||
kbinfos["chunks"].insert(0, ck)
|
||||
|
||||
@ -1341,11 +1365,13 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_conf
|
||||
|
||||
is_knowledge_graph = all([kb.parser_id == ParserType.KG for kb in kbs])
|
||||
retriever = settings.retriever if not is_knowledge_graph else settings.kg_retriever
|
||||
|
||||
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embedding_list[0])
|
||||
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, chat_llm_name)
|
||||
embd_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.EMBEDDING, embedding_list[0])
|
||||
embd_mdl = LLMBundle(tenant_id, embd_model_config)
|
||||
chat_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.CHAT, chat_llm_name)
|
||||
chat_mdl = LLMBundle(tenant_id, chat_model_config)
|
||||
if rerank_id:
|
||||
rerank_mdl = LLMBundle(tenant_id, LLMType.RERANK, rerank_id)
|
||||
rerank_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.RERANK, rerank_id)
|
||||
rerank_mdl = LLMBundle(tenant_id, rerank_model_config)
|
||||
max_tokens = chat_mdl.max_length
|
||||
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
|
||||
|
||||
@ -1417,13 +1443,22 @@ async def gen_mindmap(question, kb_ids, tenant_id, search_config={}):
|
||||
kbs = KnowledgebaseService.get_by_ids(kb_ids)
|
||||
if not kbs:
|
||||
return {"error": "No KB selected"}
|
||||
embedding_list = list(set([kb.embd_id for kb in kbs]))
|
||||
tenant_embedding_list = list(set([kb.tenant_embd_id for kb in kbs]))
|
||||
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
|
||||
|
||||
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, llm_name=embedding_list[0])
|
||||
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
|
||||
if tenant_embedding_list[0]:
|
||||
embd_model_config = get_model_config_by_id(tenant_embedding_list[0])
|
||||
else:
|
||||
embd_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.EMBEDDING, kbs[0].embd_id)
|
||||
embd_mdl = LLMBundle(tenant_id, embd_model_config)
|
||||
chat_id = search_config.get("chat_id", "")
|
||||
if chat_id:
|
||||
chat_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.CHAT, chat_id)
|
||||
else:
|
||||
chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT)
|
||||
chat_mdl = LLMBundle(tenant_id, chat_model_config)
|
||||
if rerank_id:
|
||||
rerank_mdl = LLMBundle(tenant_id, LLMType.RERANK, rerank_id)
|
||||
rerank_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.RERANK, rerank_id)
|
||||
rerank_mdl = LLMBundle(tenant_id, rerank_model_config)
|
||||
|
||||
if meta_data_filter:
|
||||
metas = DocMetadataService.get_flatted_meta_by_kbs(kb_ids)
|
||||
|
||||
@ -661,6 +661,19 @@ class DocumentService(CommonService):
|
||||
return None
|
||||
return docs[0]["embd_id"]
|
||||
|
||||
@classmethod
@DB.connection_context()
def get_tenant_embd_id(cls, doc_id):
    """Return the ``tenant_embd_id`` of the knowledgebase owning *doc_id*.

    The document is joined to its knowledgebase; only knowledgebases whose
    status is VALID are considered. Returns ``None`` when no matching row
    exists (unknown document or invalid knowledgebase).
    """
    rows = (
        cls.model.select(Knowledgebase.tenant_embd_id)
        .join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id))
        .where(cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
        .dicts()
    )
    if not rows:
        return None
    return rows[0]["tenant_embd_id"]
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_chunking_config(cls, doc_id):
|
||||
@ -1007,6 +1020,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.user_service import TenantService
|
||||
from api.db.joint_services.tenant_model_service import get_model_config_by_id, get_model_config_by_type_and_name, get_tenant_default_model_by_type
|
||||
from rag.app import audio, email, naive, picture, presentation
|
||||
|
||||
e, conv = ConversationService.get_by_id(conversation_id)
|
||||
@ -1022,8 +1036,11 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
||||
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
if not e:
|
||||
raise LookupError("Can't find this dataset!")
|
||||
|
||||
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id, lang=kb.language)
|
||||
if kb.tenant_embd_id:
|
||||
embd_model_config = get_model_config_by_id(kb.tenant_embd_id)
|
||||
else:
|
||||
embd_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.EMBEDDING, kb.embd_id)
|
||||
embd_mdl = LLMBundle(kb.tenant_id, embd_model_config, lang=kb.language)
|
||||
|
||||
err, files = FileService.upload_document(kb, file_objs, user_id)
|
||||
assert not err, "\n".join(err)
|
||||
@ -1101,7 +1118,8 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
||||
try_create_idx = True
|
||||
|
||||
_, tenant = TenantService.get_by_id(kb.tenant_id)
|
||||
llm_bdl = LLMBundle(kb.tenant_id, LLMType.CHAT, tenant.llm_id)
|
||||
tenant_llm_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT)
|
||||
llm_bdl = LLMBundle(kb.tenant_id, tenant_llm_config)
|
||||
for doc_id in docids:
|
||||
cks = [c for c in docs if c["doc_id"] == doc_id]
|
||||
|
||||
|
||||
@ -564,3 +564,14 @@ class KnowledgebaseService(CommonService):
|
||||
'update_date': datetime_format(datetime.now())
|
||||
}
|
||||
return cls.model.update(update_dict).where(cls.model.id == kb_id).execute()
|
||||
|
||||
@classmethod
@DB.connection_context()
def get_null_tenant_embd_id_row(cls):
    """Return rows whose ``tenant_embd_id`` is still NULL.

    Only ``id``, ``tenant_id`` and ``embd_id`` are selected — presumably
    enough to backfill the tenant_llm linkage (TODO confirm against the
    migration caller).
    """
    query = cls.model.select(
        cls.model.id,
        cls.model.tenant_id,
        cls.model.embd_id,
    ).where(cls.model.tenant_embd_id.is_null())
    return list(query)
|
||||
|
||||
@ -83,18 +83,18 @@ def get_init_tenant_llm(user_id):
|
||||
|
||||
|
||||
class LLMBundle(LLM4Tenant):
|
||||
def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs):
|
||||
super().__init__(tenant_id, llm_type, llm_name, lang, **kwargs)
|
||||
def __init__(self, tenant_id: str, model_config: dict, lang="Chinese", **kwargs):
|
||||
super().__init__(tenant_id, model_config, lang, **kwargs)
|
||||
|
||||
def bind_tools(self, toolcall_session, tools):
|
||||
if not self.is_tools:
|
||||
logging.warning(f"Model {self.llm_name} does not support tool call, but you have assigned one or more tools to it!")
|
||||
logging.warning(f"Model {self.model_config['llm_name']} does not support tool call, but you have assigned one or more tools to it!")
|
||||
return
|
||||
self.mdl.bind_tools(toolcall_session, tools)
|
||||
|
||||
def encode(self, texts: list):
|
||||
if self.langfuse:
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="encode", model=self.llm_name, input={"texts": texts})
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="encode", model=self.model_config["llm_name"], input={"texts": texts})
|
||||
|
||||
safe_texts = []
|
||||
for text in texts:
|
||||
@ -106,9 +106,9 @@ class LLMBundle(LLM4Tenant):
|
||||
safe_texts.append(text)
|
||||
|
||||
embeddings, used_tokens = self.mdl.encode(safe_texts)
|
||||
|
||||
llm_name = getattr(self, "llm_name", None)
|
||||
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, llm_name):
|
||||
if self.model_config["llm_factory"] == "Builtin":
|
||||
logging.info("LLMBundle.encode_queries query: {}, emd len: {}, used_tokens: {}. Builtin model don't need to update token usage".format(texts, len(embeddings), used_tokens))
|
||||
elif not TenantLLMService.increase_usage_by_id(self.model_config["id"], used_tokens):
|
||||
logging.error("LLMBundle.encode can't update token usage for <tenant redacted>/EMBEDDING used_tokens: {}".format(used_tokens))
|
||||
|
||||
if self.langfuse:
|
||||
@ -119,11 +119,12 @@ class LLMBundle(LLM4Tenant):
|
||||
|
||||
def encode_queries(self, query: str):
|
||||
if self.langfuse:
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="encode_queries", model=self.llm_name, input={"query": query})
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="encode_queries", model=self.model_config["llm_name"], input={"query": query})
|
||||
|
||||
emd, used_tokens = self.mdl.encode_queries(query)
|
||||
llm_name = getattr(self, "llm_name", None)
|
||||
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, llm_name):
|
||||
if self.model_config["llm_factory"] == "Builtin":
|
||||
logging.info("LLMBundle.encode_queries query: {}, emd len: {}, used_tokens: {}. Builtin model don't need to update token usage".format(query, len(emd), used_tokens))
|
||||
elif not TenantLLMService.increase_usage_by_id(self.model_config["id"], used_tokens):
|
||||
logging.error("LLMBundle.encode_queries can't update token usage for <tenant redacted>/EMBEDDING used_tokens: {}".format(used_tokens))
|
||||
|
||||
if self.langfuse:
|
||||
@ -134,10 +135,10 @@ class LLMBundle(LLM4Tenant):
|
||||
|
||||
def similarity(self, query: str, texts: list):
|
||||
if self.langfuse:
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="similarity", model=self.llm_name, input={"query": query, "texts": texts})
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="similarity", model=self.model_config["llm_name"], input={"query": query, "texts": texts})
|
||||
|
||||
sim, used_tokens = self.mdl.similarity(query, texts)
|
||||
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
|
||||
if not TenantLLMService.increase_usage_by_id(self.model_config["id"], used_tokens):
|
||||
logging.error("LLMBundle.similarity can't update token usage for {}/RERANK used_tokens: {}".format(self.tenant_id, used_tokens))
|
||||
|
||||
if self.langfuse:
|
||||
@ -148,10 +149,10 @@ class LLMBundle(LLM4Tenant):
|
||||
|
||||
def describe(self, image, max_tokens=300):
|
||||
if self.langfuse:
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="describe", metadata={"model": self.llm_name})
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="describe", metadata={"model": self.model_config["llm_name"]})
|
||||
|
||||
txt, used_tokens = self.mdl.describe(image)
|
||||
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
|
||||
if not TenantLLMService.increase_usage_by_id(self.model_config["id"], used_tokens):
|
||||
logging.error("LLMBundle.describe can't update token usage for {}/IMAGE2TEXT used_tokens: {}".format(self.tenant_id, used_tokens))
|
||||
|
||||
if self.langfuse:
|
||||
@ -162,10 +163,10 @@ class LLMBundle(LLM4Tenant):
|
||||
|
||||
def describe_with_prompt(self, image, prompt):
|
||||
if self.langfuse:
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="describe_with_prompt", metadata={"model": self.llm_name, "prompt": prompt})
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="describe_with_prompt", metadata={"model": self.model_config["llm_name"], "prompt": prompt})
|
||||
|
||||
txt, used_tokens = self.mdl.describe_with_prompt(image, prompt)
|
||||
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
|
||||
if not TenantLLMService.increase_usage_by_id(self.model_config["id"], used_tokens):
|
||||
logging.error("LLMBundle.describe can't update token usage for {}/IMAGE2TEXT used_tokens: {}".format(self.tenant_id, used_tokens))
|
||||
|
||||
if self.langfuse:
|
||||
@ -176,10 +177,10 @@ class LLMBundle(LLM4Tenant):
|
||||
|
||||
def transcription(self, audio):
|
||||
if self.langfuse:
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="transcription", metadata={"model": self.llm_name})
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="transcription", metadata={"model": self.model_config["llm_name"]})
|
||||
|
||||
txt, used_tokens = self.mdl.transcription(audio)
|
||||
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
|
||||
if not TenantLLMService.increase_usage_by_id(self.model_config["id"], used_tokens):
|
||||
logging.error("LLMBundle.transcription can't update token usage for {}/SEQUENCE2TXT used_tokens: {}".format(self.tenant_id, used_tokens))
|
||||
|
||||
if self.langfuse:
|
||||
@ -196,7 +197,7 @@ class LLMBundle(LLM4Tenant):
|
||||
generation = self.langfuse.start_generation(
|
||||
trace_context=self.trace_context,
|
||||
name="stream_transcription",
|
||||
metadata={"model": self.llm_name},
|
||||
metadata={"model": self.model_config["llm_name"]},
|
||||
)
|
||||
final_text = ""
|
||||
used_tokens = 0
|
||||
@ -215,7 +216,7 @@ class LLMBundle(LLM4Tenant):
|
||||
finally:
|
||||
if final_text:
|
||||
used_tokens = num_tokens_from_string(final_text)
|
||||
TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens)
|
||||
TenantLLMService.increase_usage_by_id(self.model_config["id"], used_tokens)
|
||||
|
||||
if self.langfuse:
|
||||
generation.update(
|
||||
@ -230,11 +231,11 @@ class LLMBundle(LLM4Tenant):
|
||||
generation = self.langfuse.start_generation(
|
||||
trace_context=self.trace_context,
|
||||
name="stream_transcription",
|
||||
metadata={"model": self.llm_name},
|
||||
metadata={"model": self.model_config["llm_name"]},
|
||||
)
|
||||
|
||||
full_text, used_tokens = mdl.transcription(audio)
|
||||
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
|
||||
if not TenantLLMService.increase_usage_by_id(self.model_config["id"], used_tokens):
|
||||
logging.error(f"LLMBundle.stream_transcription can't update token usage for {self.tenant_id}/SEQUENCE2TXT used_tokens: {used_tokens}")
|
||||
|
||||
if self.langfuse:
|
||||
@ -256,7 +257,7 @@ class LLMBundle(LLM4Tenant):
|
||||
|
||||
for chunk in self.mdl.tts(text):
|
||||
if isinstance(chunk, int):
|
||||
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, chunk, self.llm_name):
|
||||
if not TenantLLMService.increase_usage_by_id(self.model_config["id"], chunk):
|
||||
logging.error("LLMBundle.tts can't update token usage for {}/TTS".format(self.tenant_id))
|
||||
return
|
||||
yield chunk
|
||||
@ -375,7 +376,7 @@ class LLMBundle(LLM4Tenant):
|
||||
|
||||
generation = None
|
||||
if self.langfuse:
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat", model=self.llm_name, input={"system": system, "history": history})
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat", model=self.model_config["llm_name"], input={"system": system, "history": history})
|
||||
|
||||
chat_partial = partial(base_fn, system, history, gen_conf)
|
||||
use_kwargs = self._clean_param(chat_partial, **kwargs)
|
||||
@ -392,8 +393,8 @@ class LLMBundle(LLM4Tenant):
|
||||
if not self.verbose_tool_use:
|
||||
txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
|
||||
|
||||
if used_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
|
||||
logging.error("LLMBundle.async_chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))
|
||||
if used_tokens and not TenantLLMService.increase_usage_by_id(self.model_config["id"], used_tokens):
|
||||
logging.error("LLMBundle.async_chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.model_config["llm_name"], used_tokens))
|
||||
|
||||
if generation:
|
||||
generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens})
|
||||
@ -413,7 +414,7 @@ class LLMBundle(LLM4Tenant):
|
||||
|
||||
generation = None
|
||||
if self.langfuse:
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat_streamly", model=self.llm_name, input={"system": system, "history": history})
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat_streamly", model=self.model_config["llm_name"], input={"system": system, "history": history})
|
||||
|
||||
if stream_fn:
|
||||
chat_partial = partial(stream_fn, system, history, gen_conf)
|
||||
@ -437,8 +438,8 @@ class LLMBundle(LLM4Tenant):
|
||||
generation.update(output={"error": str(e)})
|
||||
generation.end()
|
||||
raise
|
||||
if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
|
||||
logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
|
||||
if total_tokens and not TenantLLMService.increase_usage_by_id(self.model_config["id"], total_tokens):
|
||||
logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.model_config["llm_name"], total_tokens))
|
||||
if generation:
|
||||
generation.update(output={"output": ans}, usage_details={"total_tokens": total_tokens})
|
||||
generation.end()
|
||||
@ -456,7 +457,7 @@ class LLMBundle(LLM4Tenant):
|
||||
|
||||
generation = None
|
||||
if self.langfuse:
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat_streamly", model=self.llm_name, input={"system": system, "history": history})
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat_streamly", model=self.model_config["llm_name"], input={"system": system, "history": history})
|
||||
|
||||
if stream_fn:
|
||||
chat_partial = partial(stream_fn, system, history, gen_conf)
|
||||
@ -480,8 +481,8 @@ class LLMBundle(LLM4Tenant):
|
||||
generation.update(output={"error": str(e)})
|
||||
generation.end()
|
||||
raise
|
||||
if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
|
||||
logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
|
||||
if total_tokens and not TenantLLMService.increase_usage_by_id(self.model_config["id"], total_tokens):
|
||||
logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.model_config["llm_name"], total_tokens))
|
||||
if generation:
|
||||
generation.update(output={"output": ans}, usage_details={"total_tokens": total_tokens})
|
||||
generation.end()
|
||||
|
||||
@ -107,7 +107,7 @@ class MemoryService(CommonService):
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def create_memory(cls, tenant_id: str, name: str, memory_type: List[str], embd_id: str, llm_id: str):
|
||||
def create_memory(cls, tenant_id: str, name: str, memory_type: List[str], embd_id: str, tenant_embd_id: int, llm_id: str, tenant_llm_id: int):
|
||||
# Deduplicate name within tenant
|
||||
memory_name = duplicate_name(
|
||||
cls.query,
|
||||
@ -126,7 +126,9 @@ class MemoryService(CommonService):
|
||||
"memory_type": calculate_memory_type(memory_type),
|
||||
"tenant_id": tenant_id,
|
||||
"embd_id": embd_id,
|
||||
"tenant_embd_id": tenant_embd_id,
|
||||
"llm_id": llm_id,
|
||||
"tenant_llm_id": tenant_llm_id,
|
||||
"system_prompt": PromptAssembler.assemble_system_prompt({"memory_type": memory_type}),
|
||||
"create_time": timestamp,
|
||||
"create_date": format_time,
|
||||
@ -168,3 +170,25 @@ class MemoryService(CommonService):
|
||||
@DB.connection_context()
|
||||
def delete_memory(cls, memory_id: str):
|
||||
return cls.delete_by_id(memory_id)
|
||||
|
||||
@classmethod
@DB.connection_context()
def get_null_tenant_embd_id_row(cls):
    """Return memory rows whose ``tenant_embd_id`` is still NULL.

    Selects only ``id``, ``tenant_id`` and ``embd_id`` — presumably the
    fields needed to backfill the tenant_llm linkage (TODO confirm against
    the migration caller).
    """
    query = cls.model.select(
        cls.model.id,
        cls.model.tenant_id,
        cls.model.embd_id,
    ).where(cls.model.tenant_embd_id.is_null())
    return list(query)
|
||||
|
||||
@classmethod
@DB.connection_context()
def get_null_tenant_llm_id_row(cls):
    """Return memory rows whose ``tenant_llm_id`` is still NULL.

    Selects only ``id``, ``tenant_id`` and ``llm_id`` — presumably the
    fields needed to backfill the tenant_llm linkage (TODO confirm against
    the migration caller).
    """
    query = cls.model.select(
        cls.model.id,
        cls.model.tenant_id,
        cls.model.llm_id,
    ).where(cls.model.tenant_llm_id.is_null())
    return list(query)
|
||||
|
||||
@ -60,7 +60,7 @@ class TenantLLMService(CommonService):
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_my_llms(cls, tenant_id):
|
||||
fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name, cls.model.used_tokens, cls.model.status]
|
||||
fields = [cls.model.id, cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name, cls.model.used_tokens, cls.model.status]
|
||||
objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(cls.model.tenant_id == tenant_id, ~cls.model.api_key.is_null()).dicts()
|
||||
|
||||
return list(objs)
|
||||
@ -133,34 +133,35 @@ class TenantLLMService(CommonService):
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def model_instance(cls, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs):
|
||||
model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
|
||||
def model_instance(cls, model_config: dict, lang="Chinese", **kwargs):
|
||||
if not model_config:
|
||||
raise LookupError("Model config is required")
|
||||
kwargs.update({"provider": model_config["llm_factory"]})
|
||||
if llm_type == LLMType.EMBEDDING.value:
|
||||
if model_config["model_type"] == LLMType.EMBEDDING.value:
|
||||
if model_config["llm_factory"] not in EmbeddingModel:
|
||||
return None
|
||||
return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
|
||||
|
||||
elif llm_type == LLMType.RERANK:
|
||||
elif model_config["model_type"] == LLMType.RERANK:
|
||||
if model_config["llm_factory"] not in RerankModel:
|
||||
return None
|
||||
return RerankModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
|
||||
|
||||
elif llm_type == LLMType.IMAGE2TEXT.value:
|
||||
elif model_config["model_type"] == LLMType.IMAGE2TEXT.value:
|
||||
if model_config["llm_factory"] not in CvModel:
|
||||
return None
|
||||
return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang, base_url=model_config["api_base"], **kwargs)
|
||||
|
||||
elif llm_type == LLMType.CHAT.value:
|
||||
elif model_config["model_type"] == LLMType.CHAT.value:
|
||||
if model_config["llm_factory"] not in ChatModel:
|
||||
return None
|
||||
return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"], **kwargs)
|
||||
|
||||
elif llm_type == LLMType.SPEECH2TEXT:
|
||||
elif model_config["model_type"] == LLMType.SPEECH2TEXT:
|
||||
if model_config["llm_factory"] not in Seq2txtModel:
|
||||
return None
|
||||
return Seq2txtModel[model_config["llm_factory"]](key=model_config["api_key"], model_name=model_config["llm_name"], lang=lang, base_url=model_config["api_base"])
|
||||
elif llm_type == LLMType.TTS:
|
||||
elif model_config["model_type"] == LLMType.TTS:
|
||||
if model_config["llm_factory"] not in TTSModel:
|
||||
return None
|
||||
return TTSModel[model_config["llm_factory"]](
|
||||
@ -169,7 +170,7 @@ class TenantLLMService(CommonService):
|
||||
base_url=model_config["api_base"],
|
||||
)
|
||||
|
||||
elif llm_type == LLMType.OCR:
|
||||
elif model_config["model_type"] == LLMType.OCR:
|
||||
if model_config["llm_factory"] not in OcrModel:
|
||||
return None
|
||||
return OcrModel[model_config["llm_factory"]](
|
||||
@ -218,6 +219,16 @@ class TenantLLMService(CommonService):
|
||||
|
||||
return num
|
||||
|
||||
@classmethod
@DB.connection_context()
def increase_usage_by_id(cls, tenant_model_id: int, used_tokens: int) -> int:
    """Atomically add *used_tokens* to the ``used_tokens`` counter of the
    tenant_llm row identified by *tenant_model_id*.

    Returns the number of rows updated (0 when the id does not exist or the
    update fails). Errors are logged rather than raised so usage accounting
    never breaks the calling request path.
    """
    try:
        # Single UPDATE ... SET used_tokens = used_tokens + N — avoids a
        # read-modify-write race between concurrent requests.
        update_cnt = (
            cls.model.update(used_tokens=cls.model.used_tokens + used_tokens)
            .where(cls.model.id == tenant_model_id)
            .execute()
        )
    except Exception:
        # Fixed: the message previously named "increase_usage" although this
        # method is increase_usage_by_id; logging.exception already appends
        # the traceback, so the exception text need not be interpolated.
        logging.exception(
            "TenantLLMService.increase_usage_by_id failed to update used_tokens for tenant_model_id %s",
            tenant_model_id,
        )
        return 0
    return update_cnt
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_openai_models(cls):
|
||||
@ -376,13 +387,12 @@ class TenantLLMService(CommonService):
|
||||
|
||||
|
||||
class LLM4Tenant:
|
||||
def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs):
|
||||
def __init__(self, tenant_id: str, model_config: dict, lang="Chinese", **kwargs):
|
||||
self.tenant_id = tenant_id
|
||||
self.llm_type = llm_type
|
||||
self.llm_name = llm_name
|
||||
self.mdl = TenantLLMService.model_instance(tenant_id, llm_type, llm_name, lang=lang, **kwargs)
|
||||
assert self.mdl, "Can't find model for {}/{}/{}".format(tenant_id, llm_type, llm_name)
|
||||
model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
|
||||
self.llm_name = model_config["llm_name"]
|
||||
self.model_config = model_config
|
||||
self.mdl = TenantLLMService.model_instance(model_config, lang=lang, **kwargs)
|
||||
assert self.mdl, "Can't find model for {}/{}/{}".format(tenant_id, model_config["llm_type"], model_config["llm_name"])
|
||||
self.max_length = model_config.get("max_tokens", 8192)
|
||||
|
||||
self.is_tools = model_config.get("is_tools", False)
|
||||
|
||||
@ -226,6 +226,12 @@ class TenantService(CommonService):
|
||||
hash_obj = hashlib.sha256(tenant_id.encode("utf-8"))
|
||||
return int(hash_obj.hexdigest(), 16)%len(settings.MINIO)
|
||||
|
||||
@classmethod
@DB.connection_context()
def get_null_tenant_model_id_rows(cls):
    """Return tenants for which at least one ``tenant_*_id`` link is unset.

    Fix: the previous ``select().orwhere(a, b, c, ...)`` call combined the
    expressions with AND (peewee reduces multiple where/orwhere arguments
    with AND before OR-ing them into a — here nonexistent — prior WHERE),
    so it only matched tenants where EVERY id was NULL. The backfill intent
    is "any id missing", which needs an explicit OR of the conditions.
    """
    any_id_missing = (
        cls.model.tenant_llm_id.is_null()
        | cls.model.tenant_embd_id.is_null()
        | cls.model.tenant_asr_id.is_null()
        | cls.model.tenant_tts_id.is_null()
        | cls.model.tenant_rerank_id.is_null()
        | cls.model.tenant_img2txt_id.is_null()
    )
    return list(cls.model.select().where(any_id_missing))
|
||||
|
||||
|
||||
class UserTenantService(CommonService):
|
||||
"""Service class for managing user-tenant relationship operations.
|
||||
|
||||
Reference in New Issue
Block a user