Mirror of https://github.com/infiniflow/ragflow.git (synced 2026-05-06 02:07:49 +08:00)
Feat/tenant model (#13072)

### What problem does this PR solve?

Add an `id` column to the `tenant_llm` table and use it in `LLMBundle`.

### Type of change

- [x] Refactoring

---------

Co-authored-by: Yingfeng <yingfeng.zhang@gmail.com>
Co-authored-by: Liu An <asiro@qq.com>
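In outline, `LLMBundle` used to be constructed from a model type plus a model name; after this change it takes a resolved model config, which can be fetched either by the new `tenant_llm` row id or by type and name. Below is a minimal sketch assembled from the call shapes visible in the hunks that follow; the `tenant_model_service` import and both `LLMBundle` call forms appear in the diff, while the wrapper function, its parameter names, and the `LLMBundle` import path are assumptions.

```python
# Hedged sketch: pre-PR vs post-PR construction of an embedding bundle.
# Only the lookup helpers and the two LLMBundle call shapes are taken from
# the diff; the wrapper itself is hypothetical.
from api.db import LLMType
from api.db.joint_services.tenant_model_service import (
    get_model_config_by_id,
    get_model_config_by_type_and_name,
)
from api.db.services.llm_service import LLMBundle  # assumed import path


def build_embedding_bundle(tenant_id: str, tenant_embd_id: str | None, embd_id: str):
    if tenant_embd_id:
        # New path: an explicit tenant_llm row id pins one exact model record,
        # even if the same model name exists under several factories.
        cfg = get_model_config_by_id(tenant_embd_id)
    else:
        # Fallback path: the familiar lookup by model type and name.
        cfg = get_model_config_by_type_and_name(tenant_id, LLMType.EMBEDDING, embd_id)
    # Old call:  LLMBundle(tenant_id, LLMType.EMBEDDING, embd_id)
    # New call:  LLMBundle(tenant_id, cfg)
    return LLMBundle(tenant_id, cfg)
```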
```diff
@@ -31,6 +31,7 @@ from api.db.services.file_service import FileService
 from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
 from api.db.services.task_service import TaskService, GRAPH_RAPTOR_FAKE_DOC_ID
 from api.db.services.user_service import TenantService, UserTenantService
+from api.db.joint_services.tenant_model_service import get_model_config_by_type_and_name, get_model_config_by_id
 from api.utils.api_utils import (
     get_error_data_result,
     server_error_response,
@@ -44,6 +45,7 @@ from api.db import VALID_FILE_TYPES
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.db_models import File
 from api.utils.api_utils import get_json_result
+from api.utils.tenant_utils import ensure_tenant_model_id_for_params
 from rag.nlp import search
 from api.constants import DATASET_NAME_LIMIT
 from rag.utils.redis_conn import REDIS_CONN
@@ -57,11 +59,12 @@ from api.apps import login_required, current_user
 @validate_request("name")
 async def create():
     req = await get_request_json()
+    create_dict = ensure_tenant_model_id_for_params(current_user.id, req)
     e, res = KnowledgebaseService.create_with_name(
-        name = req.pop("name", None),
+        name = create_dict.pop("name", None),
         tenant_id = current_user.id,
-        parser_id = req.pop("parser_id", None),
-        **req
+        parser_id = create_dict.pop("parser_id", None),
+        **create_dict
     )
 
     if not e:
```
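The diff only shows `ensure_tenant_model_id_for_params` being imported and called as `(current_user.id, req)`, returning a dict that replaces `req` downstream. A purely illustrative guess at its shape follows, under the assumption that it copies the request body and normalizes model-name parameters into `tenant_llm` row ids; the real helper lives in `api/utils/tenant_utils` and its body is not part of this commit view.

```python
# Hypothetical sketch only: the real helper is in api/utils/tenant_utils and
# is not shown in this diff. Assumed behaviour: copy the request dict and
# rewrite model-name fields into tenant_llm row ids.
from api.db import LLMType
from api.db.joint_services.tenant_model_service import get_model_config_by_type_and_name


def ensure_tenant_model_id_for_params(tenant_id: str, req: dict) -> dict:
    params = dict(req)  # leave the raw request body untouched
    embd_name = params.get("embd_id")  # "embd_id" as a request key is a guess
    if embd_name:
        cfg = get_model_config_by_type_and_name(tenant_id, LLMType.EMBEDDING, embd_name)
        params["embd_id"] = cfg.id  # assumption: config exposes the new tenant_llm row id
    return params
```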
```diff
@@ -81,30 +84,31 @@ async def create():
 @not_allowed_parameters("id", "tenant_id", "created_by", "create_time", "update_time", "create_date", "update_date", "created_by")
 async def update():
     req = await get_request_json()
-    if not isinstance(req["name"], str):
+    update_dict = ensure_tenant_model_id_for_params(current_user.id, req)
+    if not isinstance(update_dict["name"], str):
         return get_data_error_result(message="Dataset name must be string.")
-    if req["name"].strip() == "":
+    if update_dict["name"].strip() == "":
         return get_data_error_result(message="Dataset name can't be empty.")
-    if len(req["name"].encode("utf-8")) > DATASET_NAME_LIMIT:
+    if len(update_dict["name"].encode("utf-8")) > DATASET_NAME_LIMIT:
         return get_data_error_result(
-            message=f"Dataset name length is {len(req['name'])} which is large than {DATASET_NAME_LIMIT}")
-    req["name"] = req["name"].strip()
+            message=f"Dataset name length is {len(update_dict['name'])} which is large than {DATASET_NAME_LIMIT}")
+    update_dict["name"] = update_dict["name"].strip()
     if settings.DOC_ENGINE_INFINITY:
-        parser_id = req.get("parser_id")
+        parser_id = update_dict.get("parser_id")
         if isinstance(parser_id, str) and parser_id.lower() == "tag":
             return get_json_result(
                 code=RetCode.OPERATING_ERROR,
                 message="The chunking method Tag has not been supported by Infinity yet.",
                 data=False,
             )
-        if "pagerank" in req and req["pagerank"] > 0:
+        if "pagerank" in update_dict and update_dict["pagerank"] > 0:
             return get_json_result(
                 code=RetCode.DATA_ERROR,
                 message="'pagerank' can only be set when doc_engine is elasticsearch",
                 data=False,
             )
 
-    if not KnowledgebaseService.accessible4deletion(req["kb_id"], current_user.id):
+    if not KnowledgebaseService.accessible4deletion(update_dict["kb_id"], current_user.id):
         return get_json_result(
             data=False,
             message='No authorization.',
@@ -112,15 +116,15 @@ async def update():
         )
     try:
         if not KnowledgebaseService.query(
-                created_by=current_user.id, id=req["kb_id"]):
+                created_by=current_user.id, id=update_dict["kb_id"]):
             return get_json_result(
                 data=False, message='Only owner of dataset authorized for this operation.',
                 code=RetCode.OPERATING_ERROR)
 
-        e, kb = KnowledgebaseService.get_by_id(req["kb_id"])
+        e, kb = KnowledgebaseService.get_by_id(update_dict["kb_id"])
 
         # Rename folder in FileService
-        if e and req["name"].lower() != kb.name.lower():
+        if e and update_dict["name"].lower() != kb.name.lower():
             FileService.filter_update(
                 [
                     File.tenant_id == kb.tenant_id,
@@ -128,33 +132,33 @@ async def update():
                     File.type == "folder",
                     File.name == kb.name,
                 ],
-                {"name": req["name"]},
+                {"name": update_dict["name"]},
             )
 
         if not e:
             return get_data_error_result(
                 message="Can't find this dataset!")
 
-        if req["name"].lower() != kb.name.lower() \
+        if update_dict["name"].lower() != kb.name.lower() \
                 and len(
-            KnowledgebaseService.query(name=req["name"], tenant_id=current_user.id, status=StatusEnum.VALID.value)) >= 1:
+            KnowledgebaseService.query(name=update_dict["name"], tenant_id=current_user.id, status=StatusEnum.VALID.value)) >= 1:
             return get_data_error_result(
                 message="Duplicated dataset name.")
 
-        del req["kb_id"]
+        del update_dict["kb_id"]
         connectors = []
-        if "connectors" in req:
-            connectors = req["connectors"]
-            del req["connectors"]
-        if not KnowledgebaseService.update_by_id(kb.id, req):
+        if "connectors" in update_dict:
+            connectors = update_dict["connectors"]
+            del update_dict["connectors"]
+        if not KnowledgebaseService.update_by_id(kb.id, update_dict):
             return get_data_error_result()
 
-        if kb.pagerank != req.get("pagerank", 0):
-            if req.get("pagerank", 0) > 0:
+        if kb.pagerank != update_dict.get("pagerank", 0):
+            if update_dict.get("pagerank", 0) > 0:
                 await thread_pool_exec(
                     settings.docStoreConn.update,
                     {"kb_id": kb.id},
-                    {PAGERANK_FLD: req["pagerank"]},
+                    {PAGERANK_FLD: update_dict["pagerank"]},
                     search.index_name(kb.tenant_id),
                     kb.id,
                 )
@@ -176,7 +180,7 @@ async def update():
         if errors:
             logging.error("Link KB errors: ", errors)
         kb = kb.to_dict()
-        kb.update(req)
+        kb.update(update_dict)
         kb["connectors"] = connectors
 
         return get_json_result(data=kb)
```
```diff
@@ -943,12 +947,18 @@ async def check_embedding():
         return s if s else "None"
     req = await get_request_json()
     kb_id = req.get("kb_id", "")
+    tenant_embd_id = req.get("tenant_embd_id")
    embd_id = req.get("embd_id", "")
     n = int(req.get("check_num", 5))
     _, kb = KnowledgebaseService.get_by_id(kb_id)
     tenant_id = kb.tenant_id
-
-    emb_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_id)
+    if tenant_embd_id:
+        embd_model_config = get_model_config_by_id(tenant_embd_id)
+    elif embd_id:
+        embd_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.EMBEDDING, embd_id)
+    else:
+        return get_error_data_result("`tenant_embd_id` or `embd_id` is required.")
+    emb_mdl = LLMBundle(tenant_id, embd_model_config)
     samples = sample_random_chunks_with_vectors(settings.docStoreConn, tenant_id=tenant_id, kb_id=kb_id, n=n)
 
     results, eff_sims = [], []
```
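Usage-wise, `check_embedding` now accepts either key, with the row id taking precedence. A hedged example of the two request payloads follows; the keys come from the hunk above, while the placeholder values and the endpoint transport are assumptions.

```python
# Resolve by tenant_llm row id (new path, takes precedence):
payload_by_row_id = {"kb_id": "<kb id>", "tenant_embd_id": "<tenant_llm row id>", "check_num": 5}

# Resolve by model type + name (pre-existing fallback):
payload_by_name = {"kb_id": "<kb id>", "embd_id": "<embedding model name>", "check_num": 5}

# Supplying neither now fails fast with:
# "`tenant_embd_id` or `embd_id` is required."
```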