mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-05-21 08:37:05 +08:00
### Related issues Closes #14781 ### What problem does this PR solve? Some retrieval endpoints accepted caller-supplied `tenant_rerank_id` and resolved it through `get_model_config_by_id(...)`. That helper loaded `TenantLLM` rows by global database id and returned decoded model configuration without checking whether the model belonged to the authenticated tenant or the dataset owner tenant. This meant dataset access was validated, but rerank-model selection was not. A caller who knew or could guess another tenant's `tenant_rerank_id` could attempt retrieval with a foreign rerank model config, creating a cross-tenant authorization gap for model usage. This PR closes that gap by making `tenant_rerank_id` resolution tenant-aware across the retrieval paths that accept it. ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) - [ ] New Feature (non-breaking change which adds functionality) - [ ] Documentation Update - [ ] Refactoring - [ ] Performance Improvement - [ ] Other (please describe): ### Solution - Extend `get_model_config_by_id(...)` to accept an optional `allowed_tenant_ids` set and reject `TenantLLM` rows whose `tenant_id` is outside that set. - Pass the allowed tenant scope from retrieval endpoints that accept `tenant_rerank_id`: - `api/apps/sdk/doc.py` - `api/apps/sdk/session.py` - `api/apps/services/dataset_api_service.py` - Use the authenticated tenant plus dataset-owner tenant ids already derived by each retrieval flow as the authorization boundary for rerank model selection. - Add focused unit coverage to assert unauthorized `tenant_rerank_id` values are rejected and that the allowed tenant set is propagated correctly. 
### Testing - `python -m py_compile` on: - `api/db/joint_services/tenant_model_service.py` - `api/apps/services/dataset_api_service.py` - `api/apps/sdk/doc.py` - `api/apps/sdk/session.py` - Added unit tests in: - `test/testcases/test_http_api/test_file_management_within_dataset/test_doc_sdk_routes_unit.py` - `test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py` ### Notes for reviewers - This change is intentionally narrow: it affects only the `tenant_rerank_id` path, not the normal `rerank_id` name-based resolution path. - Local lint/syntax checks passed. - Full pytest execution could not be completed in this environment because the local test runtime is missing `strenum`, so the route-test files fail during collection before exercising the updated cases. --------- Co-authored-by: jony376 <jony376@gmail.com>
624 lines
26 KiB
Python
624 lines
26 KiB
Python
#
|
|
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
import json
|
|
import re
|
|
|
|
import logging
|
|
|
|
from quart import Response, request
|
|
|
|
from agent.canvas import Canvas
|
|
from api.db.db_models import APIToken
|
|
from api.db.services.api_service import API4ConversationService
|
|
from api.db.services.canvas_service import UserCanvasService
|
|
from api.db.services.canvas_service import completion as agent_completion
|
|
from api.db.services.user_canvas_version import UserCanvasVersionService
|
|
from api.db.services.conversation_service import async_iframe_completion as iframe_completion
|
|
from api.db.services.dialog_service import DialogService, async_ask, gen_mindmap
|
|
from api.db.services.doc_metadata_service import DocMetadataService
|
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
|
from api.db.services.llm_service import LLMBundle
|
|
from common.metadata_utils import apply_meta_data_filter
|
|
from api.db.services.search_service import SearchService
|
|
from api.db.services.user_service import UserTenantService
|
|
from api.db.joint_services.tenant_model_service import get_tenant_default_model_by_type, get_model_config_by_id, \
|
|
get_model_config_by_type_and_name
|
|
from common.misc_utils import get_uuid, thread_pool_exec
|
|
from api.utils.api_utils import check_duplicate_ids, get_error_data_result, get_json_result, \
|
|
get_result, get_request_json, server_error_response, token_required, validate_request
|
|
from rag.app.tag import label_question
|
|
from rag.prompts.template import load_prompt
|
|
from rag.prompts.generator import cross_languages, keyword_extraction
|
|
from common.constants import RetCode, LLMType, StatusEnum
|
|
from common import settings
|
|
from api.utils.reference_metadata_utils import (
|
|
enrich_chunks_with_document_metadata,
|
|
resolve_reference_metadata_preferences,
|
|
)
|
|
|
|
# Module-level logger, named after this module, used for access-denial warnings below.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@token_required
async def create_agent_session(tenant_id, agent_id):
    """Create and persist a fresh conversation session for an agent canvas.

    ``user_id`` and ``release`` are read from the JSON body first and fall
    back to the query string; ``user_id`` finally defaults to ``tenant_id``.

    Args:
        tenant_id: Tenant the request is made for (presumably injected by
            ``token_required`` -- confirm against the decorator).
        agent_id: Identifier of the agent/canvas to open a session on.

    Returns:
        The saved session payload with ``dialog_id`` renamed to ``agent_id``,
        or an error result when the agent is not accessible or not found.
    """
    req = await get_request_json()
    user_id = req.get("user_id") or request.args.get("user_id", tenant_id)

    release_raw = req.get("release", request.args.get("release", False))
    if isinstance(release_raw, str):
        # Query-string values are text, and bool("false") is True; treat the
        # usual textual negatives as False instead.
        release_mode = release_raw.strip().lower() not in {"", "0", "false", "no", "off"}
    else:
        release_mode = bool(release_raw)

    if not await thread_pool_exec(UserCanvasService.query, user_id=tenant_id, id=agent_id):
        return get_error_data_result("You cannot access the agent.")

    try:
        cvs, dsl = await thread_pool_exec(UserCanvasService.get_agent_dsl_with_release, agent_id, release_mode, tenant_id)
    except LookupError:
        return get_error_data_result("Agent not found.")
    except PermissionError as e:
        return get_error_data_result(str(e))

    session_id = get_uuid()
    canvas = Canvas(dsl, tenant_id, agent_id, canvas_id=cvs.id)
    canvas.reset()

    # Store the DSL as serialised by the freshly reset canvas.
    cvs.dsl = json.loads(str(canvas))
    # Get the version title based on release_mode
    version_title = await thread_pool_exec(UserCanvasVersionService.get_latest_version_title, cvs.id, release_mode=release_mode)
    conv = {
        "id": session_id,
        "dialog_id": cvs.id,
        "user_id": user_id,
        "message": [{"role": "assistant", "content": canvas.get_prologue()}],
        "source": "agent",
        "dsl": cvs.dsl,
        "version_title": version_title
    }
    await thread_pool_exec(API4ConversationService.save, **conv)
    # The response exposes the canvas id as "agent_id" rather than "dialog_id".
    conv["agent_id"] = conv.pop("dialog_id")
    return get_result(data=conv)
|
|
|
|
|
|
@manager.route("/agents/<agent_id>/sessions", methods=["DELETE"])  # noqa: F821
@token_required
async def delete_agent_session(tenant_id, agent_id):
    """Delete some or all sessions that belong to an agent.

    The JSON body may carry ``ids`` (a list of session ids) or
    ``delete_all: true``. With neither, the call is a no-op success.

    Args:
        tenant_id: Tenant that must own the agent.
        agent_id: Agent whose sessions are deleted.

    Returns:
        A success result, a partial-success result listing the failures, or
        an error result when nothing could be deleted.
    """
    payload = await get_request_json()
    owned = await thread_pool_exec(UserCanvasService.query, user_id=tenant_id, id=agent_id)
    if not owned:
        return get_error_data_result(f"You don't own the agent {agent_id}")

    if not payload:
        return get_result()

    session_ids = payload.get("ids")
    if not session_ids:
        # Without explicit ids, only an explicit delete_all=True does anything.
        if payload.get("delete_all") is not True:
            return get_result()
        sessions = await thread_pool_exec(API4ConversationService.query, dialog_id=agent_id)
        session_ids = [session.id for session in sessions]
        if not session_ids:
            return get_result()

    session_ids, duplicate_messages = check_duplicate_ids(session_ids, "session")

    failures = []
    deleted = 0
    for session_id in session_ids:
        found = await thread_pool_exec(API4ConversationService.query, id=session_id, dialog_id=agent_id)
        if found:
            await thread_pool_exec(API4ConversationService.delete_by_id, session_id)
            deleted += 1
        else:
            failures.append(f"The agent doesn't own the session {session_id}")

    # Ownership failures take precedence over duplicate-id messages.
    if failures:
        if deleted > 0:
            return get_result(data={"success_count": deleted, "errors": failures},
                              message=f"Partially deleted {deleted} sessions with {len(failures)} errors")
        return get_error_data_result(message="; ".join(failures))

    if duplicate_messages:
        if deleted > 0:
            return get_result(
                message=f"Partially deleted {deleted} sessions with {len(duplicate_messages)} errors",
                data={"success_count": deleted, "errors": duplicate_messages})
        return get_error_data_result(message=";".join(duplicate_messages))

    return get_result()
|
|
|
|
|
|
|
|
@manager.route("/chatbots/<dialog_id>/completions", methods=["POST"])  # noqa: F821
async def chatbot_completions(dialog_id):
    """Run a completion for an embedded chatbot authenticated by a beta token.

    The ``Authorization`` header must have exactly two whitespace-separated
    parts; the second is looked up as ``APIToken.beta``. The dialog must
    exist, belong to the token's tenant and be in the valid status. When the
    body carries ``session_id`` the session must belong to this dialog and,
    if it has a ``user_id``, to this tenant.

    Args:
        dialog_id: Identifier of the chatbot dialog taken from the URL.

    Returns:
        A ``text/event-stream`` response when ``stream`` is truthy (default),
        otherwise the first answer yielded by ``iframe_completion`` (or
        ``None`` if it yields nothing); an error result on any access failure.
    """
    req = await get_request_json()

    # Default to "" so a missing header is rejected below rather than
    # raising AttributeError on None.
    token = request.headers.get("Authorization", "").split()
    if len(token) != 2:
        return get_error_data_result(message='Authorization is not valid!')
    token = token[1]
    objs = await thread_pool_exec(APIToken.query, beta=token)
    if not objs:
        return get_error_data_result(message='Authentication error: API key is invalid!"')
    tenant_id = objs[0].tenant_id

    def _deny(log_template):
        # Log the denial with request context and build the uniform error reply.
        logger.warning(
            log_template,
            "no access to this chatbot",
            tenant_id,
            dialog_id,
            req.get("user_id"),
            req.get("session_id"),
        )
        return get_error_data_result(message="Authentication error: no access to this chatbot!")

    # Run the lookup through the thread pool, as chatbots_inputs does, so it
    # does not block the event loop.
    exists, dialog = await thread_pool_exec(DialogService.get_by_id, dialog_id)
    if (not exists
            or getattr(dialog, "tenant_id", None) != tenant_id
            or str(getattr(dialog, "status", "")) != StatusEnum.VALID.value):
        return _deny("Denied chatbot access: reason=%s tenant_id=%s dialog_id=%s user_id=%s session_id=%s")

    if "quote" not in req:
        req["quote"] = False

    def _validate_iframe_access():
        # Raises AssertionError when a supplied session is unknown or foreign.
        if req.get("session_id"):
            exists, conv = API4ConversationService.get_by_id(req.get("session_id"))
            if not exists:
                raise AssertionError("Session not found!")
            if conv.dialog_id != dialog_id:
                raise AssertionError("Session does not belong to this dialog")
            if tenant_id and conv.user_id and conv.user_id != tenant_id:
                raise AssertionError("Session does not belong to this tenant")

    if req.get("stream", True):
        try:
            _validate_iframe_access()
        except AssertionError:
            return _deny("Denied chatbot completion stream: reason=%s tenant_id=%s dialog_id=%s user_id=%s session_id=%s")

        resp = Response(iframe_completion(dialog_id, tenant_id=tenant_id, **req), mimetype="text/event-stream")
        resp.headers.add_header("Cache-control", "no-cache")
        resp.headers.add_header("Connection", "keep-alive")
        resp.headers.add_header("X-Accel-Buffering", "no")
        resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
        return resp

    try:
        _validate_iframe_access()
        # Non-streaming mode replies with the first yielded answer only.
        async for answer in iframe_completion(dialog_id, tenant_id=tenant_id, **req):
            return get_result(data=answer)
    except AssertionError:
        return _deny("Denied chatbot completion: reason=%s tenant_id=%s dialog_id=%s user_id=%s session_id=%s")

    return None
|
|
|
|
@manager.route("/chatbots/<dialog_id>/info", methods=["GET"])  # noqa: F821
async def chatbots_inputs(dialog_id):
    """Return display information for an embedded chatbot.

    Authentication uses the second part of the ``Authorization`` header as
    an ``APIToken.beta`` value. The dialog must exist, belong to the token's
    tenant and be in the valid status.

    Args:
        dialog_id: Identifier of the chatbot dialog taken from the URL.

    Returns:
        A result with ``title``, ``avatar``, ``prologue`` and
        ``has_tavily_key``, or an error result on failure.
    """
    # Default to "" so a missing header is rejected below rather than
    # raising AttributeError on None.
    token = request.headers.get("Authorization", "").split()
    if len(token) != 2:
        return get_error_data_result(message='Authorization is not valid!')
    token = token[1]
    objs = await thread_pool_exec(APIToken.query, beta=token)
    if not objs:
        return get_error_data_result(message='Authentication error: API key is invalid!"')
    tenant_id = objs[0].tenant_id
    exists, dialog = await thread_pool_exec(DialogService.get_by_id, dialog_id)
    if (not exists
            or getattr(dialog, "tenant_id", None) != tenant_id
            or str(getattr(dialog, "status", "")) != StatusEnum.VALID.value):
        # Read optional context defensively; it is only used for the log line.
        request_args = getattr(request, "args", {}) or {}
        request_user_id = request_args.get("user_id") if hasattr(request_args, "get") else None
        request_session_id = request_args.get("session_id") if hasattr(request_args, "get") else None
        logger.warning(
            "Denied chatbot access: reason=%s tenant_id=%s dialog_id=%s user_id=%s session_id=%s",
            "no access to this chatbot",
            tenant_id,
            dialog_id,
            request_user_id,
            request_session_id,
        )
        return get_error_data_result(message="Authentication error: no access to this chatbot!")

    # A key stored as None must count as "no key" instead of crashing .strip().
    tavily_key = dialog.prompt_config.get("tavily_api_key") or ""
    return get_result(
        data={
            "title": dialog.name,
            "avatar": dialog.icon,
            "prologue": dialog.prompt_config.get("prologue", ""),
            "has_tavily_key": bool(str(tavily_key).strip()),
        }
    )
|
|
|
|
|
|
@manager.route("/agentbots/<agent_id>/completions", methods=["POST"])  # noqa: F821
async def agent_bot_completions(agent_id):
    """Run a completion for an embedded agent authenticated by a beta token.

    Args:
        agent_id: Identifier of the agent taken from the URL.

    Returns:
        A ``text/event-stream`` response when ``stream`` is truthy (default);
        otherwise the first answer yielded by ``agent_completion`` (or
        ``None`` if it yields nothing). Failures produce an error result, or
        an error event inside the stream.
    """
    req = await get_request_json()

    # Default to "" so a missing header is rejected below rather than
    # raising AttributeError on None.
    token = request.headers.get("Authorization", "").split()
    if len(token) != 2:
        return get_error_data_result(message='Authorization is not valid!')
    token = token[1]
    objs = await thread_pool_exec(APIToken.query, beta=token)
    if not objs:
        return get_error_data_result(message='Authentication error: API key is invalid!"')

    if req.get("stream", True):
        async def stream():
            # Forward answers as they arrive; report a failure as one final event.
            try:
                async for answer in agent_completion(objs[0].tenant_id, agent_id, **req):
                    yield answer
            except Exception as e:
                logging.exception(e)
                error_result = get_error_data_result(message=str(e) or "Unknown error")
                yield "data:" + json.dumps(
                    {
                        "event": "message",
                        "data": {"content": f"Error {error_result['code']}: {error_result['message']}\n\n"},
                        **error_result,
                    },
                    ensure_ascii=False,
                ) + "\n\n"

        resp = Response(stream(), mimetype="text/event-stream")
        resp.headers.add_header("Cache-control", "no-cache")
        resp.headers.add_header("Connection", "keep-alive")
        resp.headers.add_header("X-Accel-Buffering", "no")
        resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
        return resp

    try:
        # Non-streaming mode replies with the first yielded answer only.
        async for answer in agent_completion(objs[0].tenant_id, agent_id, **req):
            return get_result(data=answer)
    except Exception as e:
        logging.exception(e)
        return get_error_data_result(message=str(e) or "Unknown error")

    return None
|
|
|
|
@manager.route("/agentbots/<agent_id>/inputs", methods=["GET"])  # noqa: F821
async def begin_inputs(agent_id):
    """Return the begin-form inputs and display data of an embedded agent.

    Args:
        agent_id: Identifier of the agent taken from the URL.

    Returns:
        A result with ``title``, ``avatar``, ``inputs``, ``prologue`` and
        ``mode``, or an error result when authentication fails or the agent
        does not exist.
    """
    # Default to "" so a missing header is rejected below rather than
    # raising AttributeError on None.
    token = request.headers.get("Authorization", "").split()
    if len(token) != 2:
        return get_error_data_result(message='Authorization is not valid!')
    token = token[1]
    objs = await thread_pool_exec(APIToken.query, beta=token)
    if not objs:
        return get_error_data_result(message='Authentication error: API key is invalid!"')

    # NOTE(review): the canvas is loaded by id without comparing its owner to
    # the token's tenant -- confirm that this is intended.
    e, cvs = await thread_pool_exec(UserCanvasService.get_by_id, agent_id)
    if not e:
        return get_error_data_result(f"Can't find agent by ID: {agent_id}")

    canvas = Canvas(json.dumps(cvs.dsl), objs[0].tenant_id, canvas_id=cvs.id)
    return get_result(
        data={"title": cvs.title, "avatar": cvs.avatar, "inputs": canvas.get_component_input_form("begin"),
              "prologue": canvas.get_prologue(), "mode": canvas.get_mode()})
|
|
|
|
|
|
@manager.route("/searchbots/ask", methods=["POST"])  # noqa: F821
@validate_request("question", "kb_ids")
async def ask_about_embedded():
    """Stream answers to a question over the given datasets as SSE events.

    The JSON body must contain ``question`` and ``kb_ids``. An optional
    ``search_id`` selects a search app whose ``search_config`` is passed to
    ``async_ask``.

    Returns:
        A ``text/event-stream`` response. Each answer is a ``data:`` event
        with code 0, a failure is a code-500 event, and the stream always
        ends with an event whose ``data`` is ``True``. An error result is
        returned instead when authentication fails.
    """
    # Default to "" so a missing header is rejected below rather than
    # raising AttributeError on None.
    token = request.headers.get("Authorization", "").split()
    if len(token) != 2:
        return get_error_data_result(message='Authorization is not valid!')
    token = token[1]
    objs = await thread_pool_exec(APIToken.query, beta=token)
    if not objs:
        return get_error_data_result(message='Authentication error: API key is invalid!"')

    req = await get_request_json()
    uid = objs[0].tenant_id

    search_id = req.get("search_id", "")
    search_config = {}
    if search_id:
        if search_app := await thread_pool_exec(SearchService.get_detail, search_id):
            search_config = search_app.get("search_config", {})

    async def stream():
        # req and uid are only read here, so no nonlocal declaration is needed.
        try:
            async for ans in async_ask(req["question"], req["kb_ids"], uid, search_config=search_config):
                yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
        except Exception as e:
            yield "data:" + json.dumps(
                {"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}},
                ensure_ascii=False) + "\n\n"
        # Terminal event marking the end of the stream.
        yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"

    resp = Response(stream(), mimetype="text/event-stream")
    resp.headers.add_header("Cache-control", "no-cache")
    resp.headers.add_header("Connection", "keep-alive")
    resp.headers.add_header("X-Accel-Buffering", "no")
    resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
    return resp
|
|
|
|
|
|
@manager.route("/searchbots/retrieval_test", methods=["POST"])  # noqa: F821
@validate_request("kb_id", "question")
async def retrieval_test_embedded():
    """Run a retrieval test over one or more datasets for an embedded search.

    The JSON body must contain ``kb_id`` (a string or list) and ``question``.
    Optional keys read here: ``page``, ``size``, ``doc_ids``,
    ``similarity_threshold``, ``vector_similarity_weight``, ``use_kg``,
    ``top_k``, ``cross_languages``, ``rerank_id``, ``tenant_rerank_id``,
    ``search_id``, ``meta_data_filter``, ``keyword`` and ``highlight``.

    Returns:
        The retrieval result (chunks plus ``labels``), or an error result on
        authentication, authorization or validation failure.
    """
    # Default to "" so a missing header is rejected below rather than
    # raising AttributeError on None.
    token = request.headers.get("Authorization", "").split()
    if len(token) != 2:
        return get_error_data_result(message='Authorization is not valid!')
    token = token[1]
    objs = await thread_pool_exec(APIToken.query, beta=token)
    if not objs:
        return get_error_data_result(message='Authentication error: API key is invalid!"')

    req = await get_request_json()
    page = int(req.get("page", 1))
    size = int(req.get("size", 30))
    question = req["question"]
    kb_ids = req["kb_id"]
    if isinstance(kb_ids, str):
        kb_ids = [kb_ids]
    if not kb_ids:
        return get_json_result(data=False, message='Please specify dataset firstly.',
                               code=RetCode.DATA_ERROR)
    doc_ids = req.get("doc_ids", [])
    similarity_threshold = float(req.get("similarity_threshold", 0.0))
    vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
    use_kg = req.get("use_kg", False)
    top = int(req.get("top_k", 1024))
    if top <= 0:
        return get_error_data_result("`top_k` must be greater than 0")
    langs = req.get("cross_languages", [])
    rerank_id = req.get("rerank_id", "")
    tenant_rerank_id = req.get("tenant_rerank_id", "")
    tenant_id = objs[0].tenant_id
    if not tenant_id:
        return get_error_data_result(message="permission denined.")
    search_config = {}

    async def _retrieval():
        nonlocal similarity_threshold, vector_similarity_weight, top, rerank_id, search_config
        local_doc_ids = list(doc_ids) if doc_ids else []
        tenant_ids = []
        _question = question

        meta_data_filter = {}
        chat_mdl = None
        if req.get("search_id", ""):
            detail = await thread_pool_exec(SearchService.get_detail, req.get("search_id", ""))
            if detail:
                search_config = detail.get("search_config", {})
                meta_data_filter = search_config.get("meta_data_filter", {})
                if meta_data_filter.get("method") in ["auto", "semi_auto"]:
                    chat_id = search_config.get("chat_id", "")
                    if chat_id:
                        chat_model_config = await thread_pool_exec(get_model_config_by_type_and_name, tenant_id, LLMType.CHAT, chat_id)
                    else:
                        chat_model_config = await thread_pool_exec(get_tenant_default_model_by_type, tenant_id, LLMType.CHAT)
                    chat_mdl = LLMBundle(tenant_id, chat_model_config)
                # Apply search_config settings if not explicitly provided in request
                if not req.get("similarity_threshold"):
                    similarity_threshold = float(search_config.get("similarity_threshold", similarity_threshold))
                if not req.get("vector_similarity_weight"):
                    vector_similarity_weight = float(search_config.get("vector_similarity_weight", vector_similarity_weight))
                if not req.get("top_k"):
                    top = int(search_config.get("top_k", top))
                if not req.get("rerank_id"):
                    rerank_id = search_config.get("rerank_id", "")
        else:
            meta_data_filter = req.get("meta_data_filter") or {}
            if meta_data_filter.get("method") in ["auto", "semi_auto"]:
                chat_model_config = await thread_pool_exec(get_tenant_default_model_by_type, tenant_id, LLMType.CHAT)
                chat_mdl = LLMBundle(tenant_id, chat_model_config)

        if meta_data_filter:
            local_doc_ids = await apply_meta_data_filter(
                meta_data_filter,
                None,
                _question,
                chat_mdl,
                local_doc_ids,
                kb_ids=kb_ids,
                metas_loader=lambda: DocMetadataService.get_flatted_meta_by_kbs(kb_ids),
            )

        # Every requested dataset must belong to one of the caller's tenants;
        # the owning tenant of each dataset is collected for the retriever.
        tenants = await thread_pool_exec(UserTenantService.query, user_id=tenant_id)
        for kb_id in kb_ids:
            for tenant in tenants:
                if await thread_pool_exec(KnowledgebaseService.query, tenant_id=tenant.tenant_id, id=kb_id):
                    tenant_ids.append(tenant.tenant_id)
                    break
            else:
                return get_json_result(data=False, message="Only owner of dataset authorized for this operation.",
                                       code=RetCode.OPERATING_ERROR)

        # The first dataset decides the embedding model and the labelling.
        e, kb = await thread_pool_exec(KnowledgebaseService.get_by_id, kb_ids[0])
        if not e:
            return get_error_data_result(message="Knowledgebase not found!")

        if langs:
            _question = await cross_languages(kb.tenant_id, None, _question, langs)
        if kb.tenant_embd_id:
            embd_model_config = await thread_pool_exec(get_model_config_by_id, kb.tenant_embd_id)
        else:
            embd_model_config = await thread_pool_exec(get_model_config_by_type_and_name, kb.tenant_id, LLMType.EMBEDDING, kb.embd_id)
        embd_mdl = LLMBundle(kb.tenant_id, embd_model_config)

        rerank_mdl = None
        if tenant_rerank_id:
            # A caller-supplied rerank model id is resolved only within the
            # caller's tenant and the tenants owning the requested datasets.
            allowed_rerank_tenant_ids = {tenant_id, *tenant_ids}
            rerank_model_config = await thread_pool_exec(
                get_model_config_by_id,
                tenant_rerank_id,
                allowed_rerank_tenant_ids,
                tenant_id,
            )
            rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config)
        elif rerank_id:
            rerank_model_config = await thread_pool_exec(get_model_config_by_type_and_name, tenant_id, LLMType.RERANK, rerank_id)
            rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config)

        if req.get("keyword", False):
            default_chat_model = await thread_pool_exec(get_tenant_default_model_by_type, kb.tenant_id, LLMType.CHAT)
            chat_mdl = LLMBundle(kb.tenant_id, default_chat_model)
            _question += await keyword_extraction(chat_mdl, _question)

        labels = label_question(_question, [kb])
        ranks = await settings.retriever.retrieval(
            _question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top,
            local_doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels
        )
        if use_kg:
            default_chat_model = await thread_pool_exec(get_tenant_default_model_by_type, kb.tenant_id, LLMType.CHAT)
            ck = await settings.kg_retriever.retrieval(_question, tenant_ids, kb_ids, embd_mdl,
                                                       LLMBundle(kb.tenant_id, default_chat_model))
            if ck["content_with_weight"]:
                ranks["chunks"].insert(0, ck)

        # Drop the vectors from the response payload.
        for c in ranks["chunks"]:
            c.pop("vector", None)

        include_metadata, metadata_fields = _resolve_reference_metadata(req, search_config)
        if include_metadata:
            enrich_chunks_with_document_metadata(ranks["chunks"], metadata_fields)

        ranks["labels"] = labels

        return get_json_result(data=ranks)

    try:
        return await _retrieval()
    except Exception as e:
        # Substring test rather than find() > 0, which misses a match at index 0.
        if "not_found" in str(e):
            return get_json_result(data=False, message="No chunk found! Check the chunk status please!",
                                   code=RetCode.DATA_ERROR)
        return server_error_response(e)
|
|
|
|
|
|
@manager.route("/searchbots/related_questions", methods=["POST"])  # noqa: F821
@validate_request("question")
async def related_questions_embedded():
    """Ask the chat model for search terms related to a question.

    The JSON body must contain ``question``. An optional ``search_id``
    selects a search app whose ``search_config`` may supply ``chat_id`` and
    ``llm_setting``; otherwise the tenant's default chat model is used.

    Returns:
        A result whose data is the list of numbered lines of the model's
        answer with the ``N. `` prefix removed, or an error result when
        authentication fails.
    """
    # Default to "" so a missing header is rejected below rather than
    # raising AttributeError on None.
    token = request.headers.get("Authorization", "").split()
    if len(token) != 2:
        return get_error_data_result(message='Authorization is not valid!')
    token = token[1]
    objs = await thread_pool_exec(APIToken.query, beta=token)
    if not objs:
        return get_error_data_result(message='Authentication error: API key is invalid!"')

    req = await get_request_json()
    tenant_id = objs[0].tenant_id
    if not tenant_id:
        return get_error_data_result(message="permission denined.")

    search_id = req.get("search_id", "")
    search_config = {}
    if search_id:
        if search_app := await thread_pool_exec(SearchService.get_detail, search_id):
            search_config = search_app.get("search_config", {})

    question = req["question"]

    chat_id = search_config.get("chat_id", "")
    if chat_id:
        chat_model_config = await thread_pool_exec(get_model_config_by_type_and_name, tenant_id, LLMType.CHAT, chat_id)
    else:
        chat_model_config = await thread_pool_exec(get_tenant_default_model_by_type, tenant_id, LLMType.CHAT)
    chat_mdl = LLMBundle(tenant_id, chat_model_config)

    gen_conf = search_config.get("llm_setting", {"temperature": 0.9})
    prompt = load_prompt("related_question")
    ans = await chat_mdl.async_chat(
        prompt,
        [
            {
                "role": "user",
                "content": f"""
Keywords: {question}
Related search terms:
    """,
            }
        ],
        gen_conf,
    )
    # Accept multi-digit numbering so items "10." and beyond are kept too.
    numbered = r"^[0-9]+\. "
    return get_json_result(data=[re.sub(numbered, "", a) for a in ans.split("\n") if re.match(numbered, a)])
|
|
|
|
|
|
@manager.route("/searchbots/detail", methods=["GET"])  # noqa: F821
async def detail_share_embedded():
    """Return the detail of a search app the caller's tenants can access.

    The ``search_id`` query parameter is required. The search app must
    belong to one of the tenants returned by ``UserTenantService.query`` for
    the token's tenant id.

    Returns:
        The search app detail, or an error result when authentication fails,
        permission is missing or the app cannot be found.
    """
    # Default to "" so a missing header is rejected below rather than
    # raising AttributeError on None.
    token = request.headers.get("Authorization", "").split()
    if len(token) != 2:
        return get_error_data_result(message='Authorization is not valid!')
    token = token[1]
    objs = await thread_pool_exec(APIToken.query, beta=token)
    if not objs:
        return get_error_data_result(message='Authentication error: API key is invalid!"')

    search_id = request.args["search_id"]
    tenant_id = objs[0].tenant_id
    if not tenant_id:
        return get_error_data_result(message="permission denined.")
    try:
        tenants = await thread_pool_exec(UserTenantService.query, user_id=tenant_id)
        # The for/else rejects the request when no tenant owns the search app.
        for tenant in tenants:
            if await thread_pool_exec(SearchService.query, tenant_id=tenant.tenant_id, id=search_id):
                break
        else:
            return get_json_result(data=False, message="Has no permission for this operation.",
                                   code=RetCode.OPERATING_ERROR)

        search = await thread_pool_exec(SearchService.get_detail, search_id)
        if not search:
            return get_error_data_result(message="Can't find this Search App!")
        return get_json_result(data=search)
    except Exception as e:
        return server_error_response(e)
|
|
|
|
|
|
@manager.route("/searchbots/mindmap", methods=["POST"])  # noqa: F821
@validate_request("question", "kb_ids")
async def mindmap():
    """Generate a mind map for a question over the given datasets.

    The JSON body must contain ``question`` and ``kb_ids``. An optional
    ``search_id`` selects a search app whose ``search_config`` is passed to
    ``gen_mindmap``.

    Returns:
        The mind map, a server error response when the result contains an
        ``error`` key, or an error result when authentication fails.
    """
    # Default to "" so a missing header is rejected below rather than
    # raising AttributeError on None.
    token = request.headers.get("Authorization", "").split()
    if len(token) != 2:
        return get_error_data_result(message='Authorization is not valid!')
    token = token[1]
    objs = await thread_pool_exec(APIToken.query, beta=token)
    if not objs:
        return get_error_data_result(message='Authentication error: API key is invalid!"')

    tenant_id = objs[0].tenant_id
    req = await get_request_json()

    search_id = req.get("search_id", "")
    search_app = await thread_pool_exec(SearchService.get_detail, search_id) if search_id else {}
    # The lookup may come back empty for an unknown id (the other handlers
    # in this file guard against that), so fall back to an empty config.
    search_config = (search_app or {}).get("search_config", {})

    mind_map = await gen_mindmap(req["question"], req["kb_ids"], tenant_id, search_config)
    if "error" in mind_map:
        return server_error_response(Exception(mind_map["error"]))
    return get_json_result(data=mind_map)
|
|
|
|
|
|
def _resolve_reference_metadata(req, search_config=None):
    """Delegate to ``resolve_reference_metadata_preferences`` with the same arguments.

    Args:
        req: Request payload forwarded unchanged.
        search_config: Optional search configuration forwarded unchanged.

    Returns:
        Whatever ``resolve_reference_metadata_preferences`` returns.
    """
    preferences = resolve_reference_metadata_preferences(req, search_config)
    return preferences
|