From b1d28b5898358da80eb18e32437cc81574537ccb Mon Sep 17 00:00:00 2001 From: Liu An Date: Wed, 1 Apr 2026 11:05:29 +0800 Subject: [PATCH] Revert "Refa: Chats /chat API to RESTFul (#13871)" (#13877) ### What problem does this PR solve? This reverts commit 1a608ac411877902ad0b8d918749b541c57eb6b3. ### Type of change - [x] Other (please describe): --- admin/client/ragflow_client.py | 120 ++-- api/apps/dialog_app.py | 253 ++++++++ api/apps/restful_apis/chat_api.py | 568 ----------------- api/apps/sdk/chat.py | 329 ++++++++++ api/db/services/dialog_service.py | 44 +- docs/references/http_api_reference.md | 487 ++++----------- docs/references/python_api_reference.md | 145 ++--- sdk/python/ragflow_sdk/modules/base.py | 4 - sdk/python/ragflow_sdk/modules/chat.py | 50 +- sdk/python/ragflow_sdk/ragflow.py | 95 +-- test/benchmark/README.md | 4 +- test/benchmark/chat.py | 19 +- test/benchmark/run_chat.sh | 2 +- test/benchmark/run_retrieval_chat.sh | 2 +- test/playwright/e2e/test_next_apps_chat.py | 13 +- test/playwright/helpers/_next_apps_helpers.py | 2 +- test/testcases/test_http_api/common.py | 12 - .../conftest.py | 6 +- .../test_chat_sdk_routes_unit.py | 546 +++++++---------- .../test_delete_chat_assistants.py | 6 +- .../test_list_chat_assistants.py | 154 ++--- .../test_update_chat_assistant.py | 190 +++--- .../test_create_chat_assistant.py | 130 ++-- .../test_list_chat_assistants.py | 76 ++- .../test_update_chat_assistant.py | 163 +++-- test/testcases/test_web_api/common.py | 98 +++ .../test_web_api/test_dialog_app/conftest.py | 50 ++ .../test_dialog_app/test_create_dialog.py | 170 ++++++ .../test_dialog_app/test_delete_dialogs.py | 204 +++++++ .../test_dialog_app/test_dialog_edge_cases.py | 205 +++++++ .../test_dialog_routes_unit.py | 572 ++++++++++++++++++ .../test_dialog_app/test_get_dialog.py | 177 ++++++ .../test_dialog_app/test_list_dialogs.py | 210 +++++++ .../test_dialog_app/test_update_dialog.py | 170 ++++++ web/src/components/knowledge-base-item.tsx | 2 +- web/src/hooks/use-chat-request.ts | 140 ++--- web/src/interfaces/database/chat.ts | 10 +- web/src/pages/home/chat-list.tsx | 8 +- web/src/pages/next-chats/chat-dropdown.tsx | 8 +- .../chat/app-settings/chat-settings.tsx | 30 +- .../app-settings/use-chat-setting-schema.tsx | 2 +- .../chat/chat-box/next-multiple-chat-box.tsx | 22 +- .../chat/chat-box/single-chat-box.tsx | 7 +- .../pages/next-chats/chat/llm-select-form.tsx | 5 +- web/src/pages/next-chats/chat/sessions.tsx | 4 +- .../next-chats/chat/use-show-internet.ts | 4 +- .../pages/next-chats/hooks/use-rename-chat.ts | 32 +- .../hooks/use-select-conversation-list.ts | 6 +- web/src/pages/next-chats/index.tsx | 10 +- web/src/services/next-chat-service.ts | 48 +- web/src/utils/api.ts | 12 +- web/src/utils/llm-util.ts | 2 +- 52 files changed, 3584 insertions(+), 2044 deletions(-) create mode 100644 api/apps/dialog_app.py delete mode 100644 api/apps/restful_apis/chat_api.py create mode 100644 api/apps/sdk/chat.py create mode 100644 test/testcases/test_web_api/test_dialog_app/conftest.py create mode 100644 test/testcases/test_web_api/test_dialog_app/test_create_dialog.py create mode 100644 test/testcases/test_web_api/test_dialog_app/test_delete_dialogs.py create mode 100644 test/testcases/test_web_api/test_dialog_app/test_dialog_edge_cases.py create mode 100644 test/testcases/test_web_api/test_dialog_app/test_dialog_routes_unit.py create mode 100644 test/testcases/test_web_api/test_dialog_app/test_get_dialog.py create mode 100644 test/testcases/test_web_api/test_dialog_app/test_list_dialogs.py create mode 100644 test/testcases/test_web_api/test_dialog_app/test_update_dialog.py diff --git a/admin/client/ragflow_client.py b/admin/client/ragflow_client.py index 55e10e041..03d9b8ded 100644 --- a/admin/client/ragflow_client.py +++ b/admin/client/ragflow_client.py @@ -977,13 +977,76 @@ class RAGFlowClient: def create_user_chat(self, command): if self.server_type != "user": print("This command is only allowed in USER mode") + ''' + description + : + "" + icon + : + "" + language + : + "English" + llm_id + : + "glm-4-flash@ZHIPU-AI" + llm_setting + : + {} + name + : + "xx" + prompt_config + : + {empty_response: "", prologue: "Hi! I'm your assistant. What can I do for you?", quote: true,…} + empty_response + : + "" + keyword + : + false + parameters + : + [{key: "knowledge", optional: false}] + prologue + : + "Hi! I'm your assistant. What can I do for you?" + quote + : + true + reasoning + : + false + refine_multiturn + : + false + system + : + "You are an intelligent assistant. Your primary function is to answer questions based strictly on the provided knowledge base.\n\n **Essential Rules:**\n - Your answer must be derived **solely** from this knowledge base: `{knowledge}`.\n - **When information is available**: Summarize the content to give a detailed answer.\n - **When information is unavailable**: Your response must contain this exact sentence: \"The answer you are looking for is not found in the knowledge base!\"\n - **Always consider** the entire conversation history." + toc_enhance + : + false + tts + : + false + use_kg + : + false + similarity_threshold + : + 0.2 + top_n + : + 8 + vector_similarity_weight + : + 0.3 + ''' chat_name = command["chat_name"] - default_models = self._get_default_models() or {} payload = { - "name": chat_name, "description": "", "icon": "", - "dataset_ids": [], + "language": "English", "llm_setting": {}, "prompt_config": { "empty_response": "", @@ -1001,24 +1064,16 @@ class RAGFlowClient: "optional": False } ], - "toc_enhance": False, + "toc_enhance": False }, "similarity_threshold": 0.2, "top_n": 8, - "top_k": 1024, - "vector_similarity_weight": 0.3, - "rerank_id": default_models.get("rerank_id", ""), + "vector_similarity_weight": 0.3 } - if default_models.get("llm_id"): - payload["llm_id"] = default_models["llm_id"] - response = self.http_client.request( - "POST", - "/chats", - json_body=payload, - use_api_base=True, - auth_kind="web", - ) + payload.update({"name": chat_name}) + response = self.http_client.request("POST", "/dialog/set", json_body=payload, use_api_base=False, + auth_kind="web") res_json = response.json() if response.status_code == 200 and res_json["code"] == 0: print(f"Success to create chat: {chat_name}") @@ -1103,14 +1158,9 @@ class RAGFlowClient: for elem in res_json: if elem["name"] == chat_name: to_drop_chat_ids.append(elem["id"]) - payload = {"ids": to_drop_chat_ids} - response = self.http_client.request( - "DELETE", - "/chats", - json_body=payload, - use_api_base=True, - auth_kind="web", - ) + payload = {"dialog_ids": to_drop_chat_ids} + response = self.http_client.request("POST", "/dialog/rm", json_body=payload, use_api_base=False, + auth_kind="web") res_json = response.json() if response.status_code == 200 and res_json["code"] == 0: print(f"Success to drop chat: {chat_name}") @@ -1572,27 +1622,17 @@ class RAGFlowClient: def _list_chats(self, command): iterations = command.get("iterations", 1) if iterations > 1: - response = self.http_client.request( - "GET", - "/chats", - use_api_base=True, - auth_kind="web", - iterations=iterations, - ) + response = self.http_client.request("POST", "/dialog/next", use_api_base=False, auth_kind="web", + iterations=iterations) return response else: - response = self.http_client.request( - "GET", - "/chats", - use_api_base=True, - auth_kind="web", - iterations=iterations, - ) + response = self.http_client.request("POST", "/dialog/next", use_api_base=False, auth_kind="web", + iterations=iterations) res_json = response.json() if response.status_code == 200 and res_json["code"] == 0: - return res_json["data"]["chats"] + return res_json["data"]["dialogs"] else: - print(f"Fail to list chats, code: {res_json['code']}, message: {res_json['message']}") + print(f"Fail to list datasets, code: {res_json['code']}, message: {res_json['message']}") return None def _get_default_models(self): diff --git a/api/apps/dialog_app.py b/api/apps/dialog_app.py new file mode 100644 index 000000000..043fb39de --- /dev/null +++ b/api/apps/dialog_app.py @@ -0,0 +1,253 @@ +# +# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from quart import request +from api.db.services import duplicate_name +from api.db.services.dialog_service import DialogService +from common.constants import StatusEnum +from api.db.services.tenant_llm_service import TenantLLMService +from api.db.services.knowledgebase_service import KnowledgebaseService +from api.db.services.user_service import TenantService, UserTenantService +from api.utils.api_utils import get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request +from api.utils.tenant_utils import ensure_tenant_model_id_for_params +from common.misc_utils import get_uuid +from common.constants import RetCode +from api.apps import login_required, current_user +import logging + + +@manager.route('/set', methods=['POST']) # noqa: F821 +@validate_request("prompt_config") +@login_required +async def set_dialog(): + req = await get_request_json() + dialog_info = ensure_tenant_model_id_for_params(current_user.id, req) + dialog_id = dialog_info.get("dialog_id", "") + is_create = not dialog_id + name = dialog_info.get("name", "New Dialog") + if not isinstance(name, str): + return get_data_error_result(message="Dialog name must be string.") + if name.strip() == "": + return get_data_error_result(message="Dialog name can't be empty.") + if len(name.encode("utf-8")) > 255: + return get_data_error_result(message=f"Dialog name length is {len(name)} which is larger than 255") + + name = name.strip() + if is_create: + # only for chat creating + existing_names = { + d.name.casefold() + for d in DialogService.query(tenant_id=current_user.id, status=StatusEnum.VALID.value) + if d.name + } + if name.casefold() in existing_names: + def _name_exists(name: str, **_kwargs) -> bool: + return name.casefold() in existing_names + + name = duplicate_name(_name_exists, name=name) + + description = dialog_info.get("description", "A helpful dialog") + icon = dialog_info.get("icon", "") + top_n = dialog_info.get("top_n", 6) + top_k = dialog_info.get("top_k", 1024) + rerank_id = dialog_info.get("rerank_id", "") + if not rerank_id: + dialog_info["rerank_id"] = "" + similarity_threshold = dialog_info.get("similarity_threshold", 0.1) + vector_similarity_weight = dialog_info.get("vector_similarity_weight", 0.3) + llm_setting = dialog_info.get("llm_setting", {}) + meta_data_filter = dialog_info.get("meta_data_filter", {}) + prompt_config = dialog_info["prompt_config"] + + # Set default parameters for datasets with knowledge retrieval + # All datasets with {knowledge} in system prompt need "knowledge" parameter to enable retrieval + kb_ids = dialog_info.get("kb_ids", []) + parameters = prompt_config.get("parameters") + logging.debug(f"set_dialog: kb_ids={kb_ids}, parameters={parameters}, is_create={not is_create}") + # Check if parameters is missing, None, or empty list + if kb_ids and not parameters: + # Check if system prompt uses {knowledge} placeholder + if "{knowledge}" in prompt_config.get("system", ""): + # Set default parameters for any dataset with knowledge placeholder + prompt_config["parameters"] = [{"key": "knowledge", "optional": False}] + logging.debug(f"Set default parameters for datasets with knowledge placeholder: {kb_ids}") + + if not is_create: + # only for chat updating + if not dialog_info.get("kb_ids", []) and not prompt_config.get("tavily_api_key") and "{knowledge}" in prompt_config.get("system", ""): + return get_data_error_result(message="Please remove `{knowledge}` in system prompt since no dataset / Tavily used here.") + + for p in prompt_config.get("parameters", []): + if p["optional"]: + continue + if prompt_config.get("system", "").find("{%s}" % p["key"]) < 0: + return get_data_error_result( + message="Parameter '{}' is not used".format(p["key"])) + + try: + e, tenant = TenantService.get_by_id(current_user.id) + if not e: + return get_data_error_result(message="Tenant not found!") + kbs = KnowledgebaseService.get_by_ids(dialog_info.get("kb_ids", [])) + embd_ids = [TenantLLMService.split_model_name_and_factory(kb.embd_id)[0] for kb in kbs] # remove vendor suffix for comparison + embd_count = len(set(embd_ids)) + if embd_count > 1: + return get_data_error_result(message=f'Datasets use different embedding models: {[kb.embd_id for kb in kbs]}"') + + llm_id = dialog_info.get("llm_id", tenant.llm_id) + tenant_llm_id = dialog_info.get("tenant_llm_id", tenant.tenant_llm_id) + if not dialog_id: + dia = { + "id": get_uuid(), + "tenant_id": current_user.id, + "name": name, + "kb_ids": dialog_info.get("kb_ids", []), + "description": description, + "llm_id": llm_id, + "tenant_llm_id": tenant_llm_id, + "llm_setting": llm_setting, + "prompt_config": prompt_config, + "meta_data_filter": meta_data_filter, + "top_n": top_n, + "top_k": top_k, + "rerank_id": rerank_id, + "tenant_rerank_id": dialog_info.get("tenant_rerank_id", 0), + "similarity_threshold": similarity_threshold, + "vector_similarity_weight": vector_similarity_weight, + "icon": icon + } + if not DialogService.save(**dia): + return get_data_error_result(message="Fail to new a dialog!") + return get_json_result(data=dia) + else: + del dialog_info["dialog_id"] + if "kb_names" in dialog_info: + del dialog_info["kb_names"] + if not DialogService.update_by_id(dialog_id, dialog_info): + return get_data_error_result(message="Dialog not found!") + e, dia = DialogService.get_by_id(dialog_id) + if not e: + return get_data_error_result(message="Fail to update a dialog!") + dia = dia.to_dict() + dia.update(dialog_info) + dia["kb_ids"], dia["kb_names"] = get_kb_names(dia["kb_ids"]) + return get_json_result(data=dia) + except Exception as e: + return server_error_response(e) + + +@manager.route('/get', methods=['GET']) # noqa: F821 +@login_required +def get(): + dialog_id = request.args["dialog_id"] + try: + e, dia = DialogService.get_by_id(dialog_id) + if not e: + return get_data_error_result(message="Dialog not found!") + dia = dia.to_dict() + dia["kb_ids"], dia["kb_names"] = get_kb_names(dia["kb_ids"]) + return get_json_result(data=dia) + except Exception as e: + return server_error_response(e) + + +def get_kb_names(kb_ids): + ids, nms = [], [] + for kid in kb_ids: + e, kb = KnowledgebaseService.get_by_id(kid) + if not e or kb.status != StatusEnum.VALID.value: + continue + ids.append(kid) + nms.append(kb.name) + return ids, nms + + +@manager.route('/list', methods=['GET']) # noqa: F821 +@login_required +def list_dialogs(): + try: + conversations = DialogService.query( + tenant_id=current_user.id, + status=StatusEnum.VALID.value, + reverse=True, + order_by=DialogService.model.create_time) + conversations = [d.to_dict() for d in conversations] + for conversation in conversations: + conversation["kb_ids"], conversation["kb_names"] = get_kb_names(conversation["kb_ids"]) + return get_json_result(data=conversations) + except Exception as e: + return server_error_response(e) + + +@manager.route('/next', methods=['POST']) # noqa: F821 +@login_required +async def list_dialogs_next(): + args = request.args + keywords = args.get("keywords", "") + page_number = int(args.get("page", 0)) + items_per_page = int(args.get("page_size", 0)) + parser_id = args.get("parser_id") + orderby = args.get("orderby", "create_time") + if args.get("desc", "true").lower() == "false": + desc = False + else: + desc = True + + req = await get_request_json() + owner_ids = req.get("owner_ids", []) + try: + if not owner_ids: + # tenants = TenantService.get_joined_tenants_by_user_id(current_user.id) + # tenants = [tenant["tenant_id"] for tenant in tenants] + tenants = [] # keep it here + dialogs, total = DialogService.get_by_tenant_ids( + tenants, current_user.id, page_number, + items_per_page, orderby, desc, keywords, parser_id) + else: + tenants = owner_ids + dialogs, total = DialogService.get_by_tenant_ids( + tenants, current_user.id, 0, + 0, orderby, desc, keywords, parser_id) + dialogs = [dialog for dialog in dialogs if dialog["tenant_id"] in tenants] + total = len(dialogs) + if page_number and items_per_page: + dialogs = dialogs[(page_number-1)*items_per_page:page_number*items_per_page] + return get_json_result(data={"dialogs": dialogs, "total": total}) + except Exception as e: + return server_error_response(e) + + +@manager.route('/rm', methods=['POST']) # noqa: F821 +@login_required +@validate_request("dialog_ids") +async def rm(): + req = await get_request_json() + dialog_list=[] + tenants = UserTenantService.query(user_id=current_user.id) + try: + for id in req["dialog_ids"]: + for tenant in tenants: + if DialogService.query(tenant_id=tenant.tenant_id, id=id): + break + else: + return get_json_result( + data=False, message='Only owner of dialog authorized for this operation.', + code=RetCode.OPERATING_ERROR) + dialog_list.append({"id": id,"status":StatusEnum.INVALID.value}) + DialogService.update_many_by_id(dialog_list) + return get_json_result(data=True) + except Exception as e: + return server_error_response(e) diff --git a/api/apps/restful_apis/chat_api.py b/api/apps/restful_apis/chat_api.py deleted file mode 100644 index 3e5a63f89..000000000 --- a/api/apps/restful_apis/chat_api.py +++ /dev/null @@ -1,568 +0,0 @@ -# -# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from copy import deepcopy - -from quart import request - -from api.apps import current_user, login_required -from api.db.services.dialog_service import DialogService -from api.db.services.knowledgebase_service import KnowledgebaseService -from api.db.services.tenant_llm_service import TenantLLMService -from api.db.services.user_service import TenantService, UserTenantService -from api.utils.api_utils import ( - check_duplicate_ids, - get_data_error_result, - get_json_result, - get_request_json, - server_error_response, -) -from api.utils.tenant_utils import ensure_tenant_model_id_for_params -from common.constants import RetCode, StatusEnum -from common.misc_utils import get_uuid - -_DEFAULT_PROMPT_CONFIG = { - "system": ( - 'You are an intelligent assistant. Please summarize the content of the dataset to answer the question. ' - 'Please list the data in the dataset and answer in detail. When all dataset content is irrelevant to the ' - 'question, your answer must include the sentence "The answer you are looking for is not found in the dataset!" ' - "Answers need to consider chat history.\n" - " Here is the knowledge base:\n" - " {knowledge}\n" - " The above is the knowledge base." - ), - "prologue": "Hi! I'm your assistant. What can I do for you?", - "parameters": [{"key": "knowledge", "optional": False}], - "empty_response": "Sorry! No relevant content was found in the knowledge base!", - "quote": True, - "tts": False, - "refine_multiturn": True, -} -_DEFAULT_RERANK_MODELS = {"BAAI/bge-reranker-v2-m3", "maidalun1020/bce-reranker-base_v1"} -_READONLY_FIELDS = {"id", "tenant_id", "created_by", "create_time", "create_date", "update_time", "update_date"} -_PERSISTED_FIELDS = set(DialogService.model._meta.fields) - - -def _build_chat_response(chat): - data = chat.to_dict() if hasattr(chat, "to_dict") else dict(chat) - kb_ids, kb_names = _resolve_kb_names(data.get("kb_ids", [])) - data["dataset_ids"] = kb_ids - data.pop("kb_ids", None) - data["kb_names"] = kb_names - return data - - -def _resolve_kb_names(kb_ids): - ids, names = [], [] - for kb_id in kb_ids or []: - ok, kb = KnowledgebaseService.get_by_id(kb_id) - if not ok or kb.status != StatusEnum.VALID.value: - continue - ids.append(kb_id) - names.append(kb.name) - return ids, names - - -def _has_knowledge_placeholder(prompt_config): - return "{knowledge}" in (prompt_config or {}).get("system", "") - - -def _validate_name(name, *, required=True): - if name is None: - if required: - return None, "`name` is required." - return None, None - if not isinstance(name, str): - return None, "Chat name must be a string." - name = name.strip() - if not name: - return None, "Chat name can't be empty." if required else "`name` cannot be empty." - if len(name.encode("utf-8")) > 255: - return None, f"Chat name length is {len(name.encode('utf-8'))} which is larger than 255." - return name, None - - -def _ensure_owned_chat(chat_id): - return DialogService.query( - tenant_id=current_user.id, id=chat_id, status=StatusEnum.VALID.value - ) - - -def _validate_llm_id(llm_id, tenant_id, llm_setting=None): - if not llm_id: - return None - - llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(llm_id) - model_type = (llm_setting or {}).get("model_type") - if model_type not in {"chat", "image2text"}: - model_type = "chat" - - if not TenantLLMService.query( - tenant_id=tenant_id, - llm_name=llm_name, - llm_factory=llm_factory, - model_type=model_type, - ): - return f"`llm_id` {llm_id} doesn't exist" - return None - - -def _validate_rerank_id(rerank_id, tenant_id): - if not rerank_id: - return None - llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(rerank_id) - if llm_name in _DEFAULT_RERANK_MODELS: - return None - if TenantLLMService.query( - tenant_id=tenant_id, - llm_name=llm_name, - llm_factory=llm_factory, - model_type="rerank", - ): - return None - return f"`rerank_id` {rerank_id} doesn't exist" - - -def _validate_prompt_config(prompt_config): - for parameter in prompt_config.get("parameters", []): - if parameter.get("optional"): - continue - if prompt_config.get("system", "").find("{%s}" % parameter["key"]) < 0: - return f"Parameter '{parameter['key']}' is not used" - return None - - -def _validate_dataset_ids(dataset_ids, tenant_id): - if dataset_ids is None: - return [] - if not isinstance(dataset_ids, list): - return f"`dataset_ids` should be a list." - - normalized_ids = [dataset_id for dataset_id in dataset_ids if dataset_id] - kbs = [] - for dataset_id in normalized_ids: - if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): - return f"You don't own the dataset {dataset_id}" - matches = KnowledgebaseService.query(id=dataset_id) - if not matches: - return f"You don't own the dataset {dataset_id}" - kb = matches[0] - if kb.chunk_num == 0: - return f"The dataset {dataset_id} doesn't own parsed file" - kbs.append(kb) - - embd_ids = [TenantLLMService.split_model_name_and_factory(kb.embd_id)[0] for kb in kbs] - if len(set(embd_ids)) > 1: - return f'Datasets use different embedding models: {[kb.embd_id for kb in kbs]}' - - return normalized_ids - - -def _apply_prompt_defaults(req): - prompt_config = req.setdefault("prompt_config", {}) - for key, value in _DEFAULT_PROMPT_CONFIG.items(): - temp = prompt_config.get(key) - if (key == "system" and not temp) or key not in prompt_config: - prompt_config[key] = deepcopy(value) - - if req.get("kb_ids") and not prompt_config.get("parameters") and "{knowledge}" in prompt_config.get("system", ""): - prompt_config["parameters"] = [{"key": "knowledge", "optional": False}] - - -@manager.route("/chats", methods=["POST"]) # noqa: F821 -@login_required -async def create(): - try: - req = await get_request_json() - ok, tenant = TenantService.get_by_id(current_user.id) - if not ok: - return get_data_error_result(message="Tenant not found!") - - # Validate tenant_id should not be provided - if req.get("tenant_id"): - return get_data_error_result(message="`tenant_id` must not be provided.") - - # Validate name - name, err = _validate_name(req.get("name"), required=True) - if err: - return get_data_error_result(message=err) - req["name"] = name - - if "dataset_ids" in req: - kb_ids = _validate_dataset_ids(req.get("dataset_ids"), current_user.id) - if isinstance(kb_ids, str): - return get_data_error_result(message=kb_ids) - req["kb_ids"] = kb_ids - req.pop("dataset_ids", None) - - if "llm_id" in req: - err = _validate_llm_id(req.get("llm_id"), current_user.id, req.get("llm_setting")) - if err: - return get_data_error_result(message=err) - - if "rerank_id" in req: - err = _validate_rerank_id(req.get("rerank_id"), current_user.id) - if err: - return get_data_error_result(message=err) - - if "prompt_config" in req: - if not isinstance(req["prompt_config"], dict): - return get_data_error_result(message="`prompt_config` should be an object.") - err = _validate_prompt_config(req["prompt_config"]) - if err: - return get_data_error_result(message=err) - - req.setdefault("kb_ids", []) - req.setdefault("llm_id", tenant.llm_id) - if req["llm_id"] is None: - req["llm_id"] = tenant.llm_id - req.setdefault("llm_setting", {}) - req.setdefault("description", "A helpful Assistant") - req.setdefault("top_n", 6) - req.setdefault("top_k", 1024) - req.setdefault("rerank_id", "") - req.setdefault("similarity_threshold", 0.1) - req.setdefault("vector_similarity_weight", 0.3) - req.setdefault("icon", "") - _apply_prompt_defaults(req) - err = _validate_prompt_config(req["prompt_config"]) - if err: - return get_data_error_result(message=err) - - req = ensure_tenant_model_id_for_params(current_user.id, req) - req = {field: value for field, value in req.items() if field in _PERSISTED_FIELDS} - for field in _READONLY_FIELDS: - req.pop(field, None) - - if DialogService.query( - name=req["name"], - tenant_id=current_user.id, - status=StatusEnum.VALID.value, - ): - return get_data_error_result(message="Duplicated chat name in creating chat.") - - req["id"] = get_uuid() - req["tenant_id"] = current_user.id - if not DialogService.save(**req): - return get_data_error_result(message="Failed to create chat.") - - ok, chat = DialogService.get_by_id(req["id"]) - if not ok: - return get_data_error_result(message="Failed to retrieve created chat.") - return get_json_result(data=_build_chat_response(chat)) - except Exception as ex: - return server_error_response(ex) - - -@manager.route("/chats", methods=["GET"]) # noqa: F821 -@login_required -def list_chats(): - chat_id = request.args.get("id") - name = request.args.get("name") - keywords = request.args.get("keywords", "") - orderby = request.args.get("orderby", "create_time") - desc = request.args.get("desc", "true").lower() != "false" - owner_ids = request.args.getlist("owner_ids") - exact_filters = {"id": chat_id, "name": name} - if chat_id or name: - keywords = "" - - try: - page_number = int(request.args.get("page", 1)) - items_per_page = int(request.args.get("page_size", 0)) - if owner_ids: - chats, total = DialogService.get_by_tenant_ids( - owner_ids, current_user.id, 0, 0, orderby, desc, keywords, **exact_filters - ) - chats = [chat for chat in chats if chat["tenant_id"] in owner_ids] - total = len(chats) - if page_number and items_per_page: - start = (page_number - 1) * items_per_page - chats = chats[start : start + items_per_page] - else: - chats, total = DialogService.get_by_tenant_ids( - [], current_user.id, page_number, items_per_page, orderby, desc, keywords, **exact_filters - ) - - return get_json_result( - data={"chats": [_build_chat_response(chat) for chat in chats], "total": total} - ) - except Exception as ex: - return server_error_response(ex) - - -@manager.route("/chats/", methods=["GET"]) # noqa: F821 -@login_required -def get_chat(chat_id): - try: - tenants = UserTenantService.query(user_id=current_user.id) - for tenant in tenants: - if DialogService.query( - tenant_id=tenant.tenant_id, id=chat_id, status=StatusEnum.VALID.value - ): - break - else: - return get_json_result( - data=False, - message="No authorization.", - code=RetCode.AUTHENTICATION_ERROR, - ) - - ok, chat = DialogService.get_by_id(chat_id) - if not ok: - return get_data_error_result(message="Chat not found!") - return get_json_result(data=_build_chat_response(chat)) - except Exception as ex: - return server_error_response(ex) - - -@manager.route("/chats/", methods=["PUT"]) # noqa: F821 -@login_required -async def update_chat(chat_id): - if not _ensure_owned_chat(chat_id): - return get_json_result( - data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR - ) - - try: - req = await get_request_json() - ok, tenant = TenantService.get_by_id(current_user.id) - if not ok: - return get_data_error_result(message="Tenant not found!") - - ok, current_chat = DialogService.get_by_id(chat_id) - if not ok: - return get_data_error_result(message="Chat not found!") - current_chat = current_chat.to_dict() - - if req.get("tenant_id"): - return get_data_error_result(message="`tenant_id` must not be provided.") - - if "name" in req: - name, err = _validate_name(req.get("name"), required=True) - if err: - return get_data_error_result(message=err) - req["name"] = name - - if "dataset_ids" in req: - kb_ids = _validate_dataset_ids(req.get("dataset_ids"), current_user.id) - if isinstance(kb_ids, str): - return get_data_error_result(message=kb_ids) - req["kb_ids"] = kb_ids - req.pop("dataset_ids", None) - - if "llm_id" in req: - err = _validate_llm_id(req.get("llm_id"), current_user.id, req.get("llm_setting")) - if err: - return get_data_error_result(message=err) - - if "rerank_id" in req: - err = _validate_rerank_id(req.get("rerank_id"), current_user.id) - if err: - return get_data_error_result(message=err) - - if "prompt_config" in req: - if not isinstance(req["prompt_config"], dict): - return get_data_error_result(message="`prompt_config` should be an object.") - err = _validate_prompt_config(req["prompt_config"]) - if err: - return get_data_error_result(message=err) - - prompt_config = req.get("prompt_config", {}) - if not prompt_config: - prompt_config = current_chat.get("prompt_config", {}) - kb_ids = req.get("kb_ids", current_chat.get("kb_ids", [])) - if not kb_ids and not prompt_config.get("tavily_api_key") and _has_knowledge_placeholder(prompt_config): - return get_data_error_result(message="Please remove `{knowledge}` in system prompt since no dataset / Tavily used here.") - - req = ensure_tenant_model_id_for_params(current_user.id, req) - req = {field: value for field, value in req.items() if field in _PERSISTED_FIELDS} - for field in _READONLY_FIELDS: - req.pop(field, None) - - if ( - "name" in req - and req["name"].lower() != current_chat["name"].lower() - and DialogService.query( - name=req["name"], - tenant_id=current_user.id, - status=StatusEnum.VALID.value, - ) - ): - return get_data_error_result(message="Duplicated chat name.") - - if not DialogService.update_by_id(chat_id, req): - return get_data_error_result(message="Chat not found!") - - ok, chat = DialogService.get_by_id(chat_id) - if not ok: - return get_data_error_result(message="Failed to retrieve updated chat.") - return get_json_result(data=_build_chat_response(chat)) - except Exception as ex: - return server_error_response(ex) - - -@manager.route("/chats/", methods=["PATCH"]) # noqa: F821 -@login_required -async def patch_chat(chat_id): - if not _ensure_owned_chat(chat_id): - return get_json_result( - data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR - ) - - try: - req = await get_request_json() - ok, tenant = TenantService.get_by_id(current_user.id) - if not ok: - return get_data_error_result(message="Tenant not found!") - - ok, current_chat = DialogService.get_by_id(chat_id) - if not ok: - return get_data_error_result(message="Chat not found!") - current_chat = current_chat.to_dict() - - if req.get("tenant_id"): - return get_data_error_result(message="`tenant_id` must not be provided.") - - if "name" in req: - name, err = _validate_name(req.get("name"), required=False) - if err: - return get_data_error_result(message=err) - if name is not None: - req["name"] = name - - if "dataset_ids" in req: - kb_ids = _validate_dataset_ids(req.get("dataset_ids"), current_user.id) - if isinstance(kb_ids, str): - return get_data_error_result(message=kb_ids) - req["kb_ids"] = kb_ids - req.pop("dataset_ids", None) - - if "llm_id" in req: - err = _validate_llm_id(req.get("llm_id"), current_user.id, req.get("llm_setting")) - if err: - return get_data_error_result(message=err) - - if "rerank_id" in req: - err = _validate_rerank_id(req.get("rerank_id"), current_user.id) - if err: - return get_data_error_result(message=err) - - if "prompt_config" in req: - if not isinstance(req["prompt_config"], dict): - return get_data_error_result(message="`prompt_config` should be an object.") - prompt_config = deepcopy(current_chat.get("prompt_config", {})) - prompt_config.update(req["prompt_config"]) - req["prompt_config"] = prompt_config - err = _validate_prompt_config(prompt_config) - if err: - return get_data_error_result(message=err) - - if "llm_setting" in req: - llm_setting = deepcopy(current_chat.get("llm_setting", {})) - llm_setting.update(req["llm_setting"]) - req["llm_setting"] = llm_setting - - if "prompt_config" in req or "kb_ids" in req: - prompt_config = req.get("prompt_config", current_chat.get("prompt_config", {})) - kb_ids = req.get("kb_ids", current_chat.get("kb_ids", [])) - if not kb_ids and not prompt_config.get("tavily_api_key") and _has_knowledge_placeholder(prompt_config): - return get_data_error_result(message="Please remove `{knowledge}` in system prompt since no dataset / Tavily used here.") - - req = ensure_tenant_model_id_for_params(current_user.id, req) - req = {field: value for field, value in req.items() if field in _PERSISTED_FIELDS} - for field in _READONLY_FIELDS: - req.pop(field, None) - - if ( - "name" in req - and req["name"].lower() != current_chat["name"].lower() - and DialogService.query( - name=req["name"], - tenant_id=current_user.id, - status=StatusEnum.VALID.value, - ) - ): - return get_data_error_result(message="Duplicated chat name.") - - if not DialogService.update_by_id(chat_id, req): - return get_data_error_result(message="Failed to update chat.") - - ok, chat = DialogService.get_by_id(chat_id) - if not ok: - return get_data_error_result(message="Failed to retrieve updated chat.") - return get_json_result(data=_build_chat_response(chat)) - except Exception as ex: - return server_error_response(ex) - - -@manager.route("/chats/", methods=["DELETE"]) # noqa: F821 -@login_required -def delete_chat(chat_id): - if not _ensure_owned_chat(chat_id): - return get_json_result( - data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR - ) - - try: - if not DialogService.update_by_id(chat_id, {"status": StatusEnum.INVALID.value}): - return get_data_error_result(message=f"Failed to delete chat {chat_id}") - return get_json_result(data=True) - except Exception as ex: - return server_error_response(ex) - - -@manager.route("/chats", methods=["DELETE"]) # noqa: F821 -@login_required -async def bulk_delete_chats(): - req = await get_request_json() - if not req: - return get_json_result(data={}) - - ids = req.get("ids") - if not ids: - if req.get("delete_all") is True: - ids = [ - chat.id - for chat in DialogService.query( - tenant_id=current_user.id, status=StatusEnum.VALID.value - ) - ] - if not ids: - return get_json_result(data={}) - else: - return get_json_result(data={}) - - errors = [] - success_count = 0 - unique_ids, duplicate_messages = check_duplicate_ids(ids, "chat") - - for chat_id in unique_ids: - if not _ensure_owned_chat(chat_id): - errors.append(f"Chat({chat_id}) not found.") - continue - success_count += DialogService.update_by_id(chat_id, {"status": StatusEnum.INVALID.value}) - - all_errors = errors + duplicate_messages - if all_errors: - if success_count > 0: - return get_json_result( - data={"success_count": success_count, "errors": all_errors}, - message=f"Partially deleted {success_count} chats with {len(all_errors)} errors", - ) - return get_data_error_result(message="; ".join(all_errors)) - - return get_json_result(data={"success_count": success_count}) \ No newline at end of file diff --git a/api/apps/sdk/chat.py b/api/apps/sdk/chat.py new file mode 100644 index 000000000..aad3fb980 --- /dev/null +++ b/api/apps/sdk/chat.py @@ -0,0 +1,329 @@ +# +# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import logging +from quart import request +from api.db.services.dialog_service import DialogService +from api.db.services.knowledgebase_service import KnowledgebaseService +from api.db.services.tenant_llm_service import TenantLLMService +from api.db.services.user_service import TenantService +from common.misc_utils import get_uuid +from common.constants import RetCode, StatusEnum +from api.utils.api_utils import check_duplicate_ids, get_error_data_result, get_result, token_required, get_request_json + + +@manager.route("/chats", methods=["POST"]) # noqa: F821 +@token_required +async def create(tenant_id): + req = await get_request_json() + ids = [i for i in req.get("dataset_ids", []) if i] + for kb_id in ids: + kbs = KnowledgebaseService.accessible(kb_id=kb_id, user_id=tenant_id) + if not kbs: + return get_error_data_result(f"You don't own the dataset {kb_id}") + kbs = KnowledgebaseService.query(id=kb_id) + kb = kbs[0] + if kb.chunk_num == 0: + return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file") + + kbs = KnowledgebaseService.get_by_ids(ids) if ids else [] + embd_ids = [TenantLLMService.split_model_name_and_factory(kb.embd_id)[0] for kb in kbs] # remove vendor suffix for comparison + embd_count = list(set(embd_ids)) + if len(embd_count) > 1: + return get_result(message='Datasets use different embedding models."', code=RetCode.AUTHENTICATION_ERROR) + req["kb_ids"] = ids + # llm + llm = req.get("llm") + if llm: + if "model_name" in llm: + req["llm_id"] = llm.pop("model_name") + if req.get("llm_id") is not None: + llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(req["llm_id"]) + model_type = llm.get("model_type") + model_type = model_type if model_type in ["chat", "image2text"] else "chat" + if not TenantLLMService.query(tenant_id=tenant_id, llm_name=llm_name, llm_factory=llm_factory, model_type=model_type): + return get_error_data_result(f"`model_name` {req.get('llm_id')} doesn't exist") + req["llm_setting"] = req.pop("llm") + e, tenant = TenantService.get_by_id(tenant_id) + if not e: + return get_error_data_result(message="Tenant not found!") + # prompt + prompt = req.get("prompt") + key_mapping = {"parameters": "variables", "prologue": "opener", "quote": "show_quote", "system": "prompt", "rerank_id": "rerank_model", "vector_similarity_weight": "keywords_similarity_weight"} + key_list = ["similarity_threshold", "vector_similarity_weight", "top_n", "rerank_id", "top_k"] + if prompt: + for new_key, old_key in key_mapping.items(): + if old_key in prompt: + prompt[new_key] = prompt.pop(old_key) + for key in key_list: + if key in prompt: + req[key] = prompt.pop(key) + req["prompt_config"] = req.pop("prompt") + # init + req["id"] = get_uuid() + req["description"] = req.get("description", "A helpful Assistant") + req["icon"] = req.get("avatar", "") + req["top_n"] = req.get("top_n", 6) + req["top_k"] = req.get("top_k", 1024) + req["rerank_id"] = req.get("rerank_id", "") + if req.get("rerank_id"): + value_rerank_model = ["BAAI/bge-reranker-v2-m3", "maidalun1020/bce-reranker-base_v1"] + if req["rerank_id"] not in value_rerank_model and not TenantLLMService.query(tenant_id=tenant_id, llm_name=req.get("rerank_id"), model_type="rerank"): + return get_error_data_result(f"`rerank_model` {req.get('rerank_id')} doesn't exist") + if not req.get("llm_id"): + req["llm_id"] = tenant.llm_id + if not req.get("name"): + return get_error_data_result(message="`name` is required.") + if DialogService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value): + return get_error_data_result(message="Duplicated chat name in creating chat.") + # tenant_id + if req.get("tenant_id"): + return get_error_data_result(message="`tenant_id` must not be provided.") + req["tenant_id"] = tenant_id + # prompt more parameter + default_prompt = { + "system": """You are an intelligent assistant. Please summarize the content of the dataset to answer the question. Please list the data in the dataset and answer in detail. When all dataset content is irrelevant to the question, your answer must include the sentence "The answer you are looking for is not found in the dataset!" Answers need to consider chat history. + Here is the knowledge base: + {knowledge} + The above is the knowledge base.""", + "prologue": "Hi! I'm your assistant. What can I do for you?", + "parameters": [{"key": "knowledge", "optional": False}], + "empty_response": "Sorry! No relevant content was found in the knowledge base!", + "quote": True, + "tts": False, + "refine_multiturn": True, + } + key_list_2 = ["system", "prologue", "parameters", "empty_response", "quote", "tts", "refine_multiturn"] + if "prompt_config" not in req: + req["prompt_config"] = {} + for key in key_list_2: + temp = req["prompt_config"].get(key) + if (not temp and key == "system") or (key not in req["prompt_config"]): + req["prompt_config"][key] = default_prompt[key] + for p in req["prompt_config"]["parameters"]: + if p["optional"]: + continue + if req["prompt_config"]["system"].find("{%s}" % p["key"]) < 0: + return get_error_data_result(message="Parameter '{}' is not used".format(p["key"])) + # save + if not DialogService.save(**req): + return get_error_data_result(message="Fail to new a chat!") + # response + e, res = DialogService.get_by_id(req["id"]) + if not e: + return get_error_data_result(message="Fail to new a chat!") + res = res.to_json() + renamed_dict = {} + for key, value in res["prompt_config"].items(): + new_key = key_mapping.get(key, key) + renamed_dict[new_key] = value + res["prompt"] = renamed_dict + del res["prompt_config"] + new_dict = {"similarity_threshold": res["similarity_threshold"], "keywords_similarity_weight": 1 - res["vector_similarity_weight"], "top_n": res["top_n"], "rerank_model": res["rerank_id"]} + res["prompt"].update(new_dict) + for key in key_list: + del res[key] + res["llm"] = res.pop("llm_setting") + res["llm"]["model_name"] = res.pop("llm_id") + del res["kb_ids"] + res["dataset_ids"] = req.get("dataset_ids", []) + res["avatar"] = res.pop("icon") + return get_result(data=res) + + +@manager.route("/chats/", methods=["PUT"]) # noqa: F821 +@token_required +async def update(tenant_id, chat_id): + if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value): + return get_error_data_result(message="You do not own the chat") + req = await get_request_json() + ids = req.get("dataset_ids", []) + if "show_quotation" in req: + req["do_refer"] = req.pop("show_quotation") + if ids: + for kb_id in ids: + kbs = KnowledgebaseService.accessible(kb_id=kb_id, user_id=tenant_id) + if not kbs: + return get_error_data_result(f"You don't own the dataset {kb_id}") + kbs = KnowledgebaseService.query(id=kb_id) + kb = kbs[0] + if kb.chunk_num == 0: + return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file") + + kbs = KnowledgebaseService.get_by_ids(ids) + embd_ids = [TenantLLMService.split_model_name_and_factory(kb.embd_id)[0] for kb in kbs] # remove vendor suffix for comparison + embd_count = list(set(embd_ids)) + if len(embd_count) > 1: + return get_result(message='Datasets use different embedding models."', code=RetCode.AUTHENTICATION_ERROR) + req["kb_ids"] = ids + else: + req["kb_ids"] = [] + llm = req.get("llm") + if llm: + if "model_name" in llm: + req["llm_id"] = llm.pop("model_name") + if req.get("llm_id") is not None: + llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(req["llm_id"]) + model_type = llm.get("model_type") + model_type = model_type if model_type in ["chat", "image2text"] else "chat" + if not TenantLLMService.query(tenant_id=tenant_id, llm_name=llm_name, llm_factory=llm_factory, model_type=model_type): + return get_error_data_result(f"`model_name` {req.get('llm_id')} doesn't exist") + req["llm_setting"] = req.pop("llm") + e, tenant = TenantService.get_by_id(tenant_id) + if not e: + return get_error_data_result(message="Tenant not found!") + # prompt + prompt = req.get("prompt") + key_mapping = {"parameters": "variables", "prologue": "opener", "quote": "show_quote", "system": "prompt", "rerank_id": "rerank_model", "vector_similarity_weight": "keywords_similarity_weight"} + key_list = ["similarity_threshold", "vector_similarity_weight", "top_n", "rerank_id", "top_k"] + if prompt: + for new_key, old_key in key_mapping.items(): + if old_key in prompt: + prompt[new_key] = prompt.pop(old_key) + for key in key_list: + if key in prompt: + req[key] = prompt.pop(key) + req["prompt_config"] = req.pop("prompt") + e, res = DialogService.get_by_id(chat_id) + res = res.to_json() + if req.get("rerank_id"): + value_rerank_model = ["BAAI/bge-reranker-v2-m3", "maidalun1020/bce-reranker-base_v1"] + if req["rerank_id"] not in value_rerank_model and not TenantLLMService.query(tenant_id=tenant_id, llm_name=req.get("rerank_id"), model_type="rerank"): + return get_error_data_result(f"`rerank_model` {req.get('rerank_id')} doesn't exist") + if "name" in req: + if not req.get("name"): + return get_error_data_result(message="`name` cannot be empty.") + if req["name"].lower() != res["name"].lower() and len(DialogService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value)) > 0: + return get_error_data_result(message="Duplicated chat name in updating chat.") + if "prompt_config" in req: + res["prompt_config"].update(req["prompt_config"]) + for p in res["prompt_config"]["parameters"]: + if p["optional"]: + continue + if res["prompt_config"]["system"].find("{%s}" % p["key"]) < 0: + return get_error_data_result(message="Parameter '{}' is not used".format(p["key"])) + if "llm_setting" in req: + res["llm_setting"].update(req["llm_setting"]) + req["prompt_config"] = res["prompt_config"] + req["llm_setting"] = res["llm_setting"] + # avatar + if "avatar" in req: + req["icon"] = req.pop("avatar") + if "dataset_ids" in req: + req.pop("dataset_ids") + if not DialogService.update_by_id(chat_id, req): + return get_error_data_result(message="Chat not found!") + return get_result() + + +@manager.route("/chats", methods=["DELETE"]) # noqa: F821 +@token_required +async def delete_chats(tenant_id): + errors = [] + success_count = 0 + req = await get_request_json() + if not req: + return get_result() + + ids = req.get("ids") + if not ids: + if req.get("delete_all") is True: + ids = [d.id for d in DialogService.query(tenant_id=tenant_id, status=StatusEnum.VALID.value)] + if not ids: + return get_result() + else: + return get_result() + + id_list = ids + + unique_id_list, duplicate_messages = check_duplicate_ids(id_list, "assistant") + + for id in unique_id_list: + if not DialogService.query(tenant_id=tenant_id, id=id, status=StatusEnum.VALID.value): + errors.append(f"Assistant({id}) not found.") + continue + temp_dict = {"status": StatusEnum.INVALID.value} + success_count += DialogService.update_by_id(id, temp_dict) + + if errors: + if success_count > 0: + return get_result(data={"success_count": success_count, "errors": errors}, message=f"Partially deleted {success_count} chats with {len(errors)} errors") + else: + return get_error_data_result(message="; ".join(errors)) + + if duplicate_messages: + if success_count > 0: + return get_result(message=f"Partially deleted {success_count} chats with {len(duplicate_messages)} errors", data={"success_count": success_count, "errors": duplicate_messages}) + else: + return get_error_data_result(message=";".join(duplicate_messages)) + + return get_result() + + +@manager.route("/chats", methods=["GET"]) # noqa: F821 +@token_required +def list_chat(tenant_id): + id = request.args.get("id") + name = request.args.get("name") + if id or name: + chat = DialogService.query(id=id, name=name, status=StatusEnum.VALID.value, tenant_id=tenant_id) + if not chat: + return get_error_data_result(message="The chat doesn't exist") + page_number = int(request.args.get("page", 1)) + items_per_page = int(request.args.get("page_size", 30)) + orderby = request.args.get("orderby", "create_time") + if request.args.get("desc") == "False" or request.args.get("desc") == "false": + desc = False + else: + desc = True + chats = DialogService.get_list(tenant_id, page_number, items_per_page, orderby, desc, id, name) + if not chats: + return get_result(data=[]) + list_assistants = [] + key_mapping = { + "parameters": "variables", + "prologue": "opener", + "quote": "show_quote", + "system": "prompt", + "rerank_id": "rerank_model", + "vector_similarity_weight": "keywords_similarity_weight", + "do_refer": "show_quotation", + } + key_list = ["similarity_threshold", "vector_similarity_weight", "top_n", "rerank_id"] + for res in chats: + renamed_dict = {} + for key, value in res["prompt_config"].items(): + new_key = key_mapping.get(key, key) + renamed_dict[new_key] = value + res["prompt"] = renamed_dict + del res["prompt_config"] + new_dict = {"similarity_threshold": res["similarity_threshold"], "keywords_similarity_weight": 1 - res["vector_similarity_weight"], "top_n": res["top_n"], "rerank_model": res["rerank_id"]} + res["prompt"].update(new_dict) + for key in key_list: + del res[key] + res["llm"] = res.pop("llm_setting") + res["llm"]["model_name"] = res.pop("llm_id") + kb_list = [] + for kb_id in res["kb_ids"]: + kb = KnowledgebaseService.query(id=kb_id) + if not kb: + logging.warning(f"The kb {kb_id} does not exist.") + continue + kb_list.append(kb[0].to_json()) + del res["kb_ids"] + res["datasets"] = kb_list + res["avatar"] = res.pop("icon") + list_assistants.append(res) + return get_result(data=list_assistants) diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index 83f79c285..dcd7cd6af 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -105,18 +105,7 @@ class DialogService(CommonService): @classmethod @DB.connection_context() - def get_by_tenant_ids( - cls, - joined_tenant_ids, - user_id, - page_number, - items_per_page, - orderby, - desc, - keywords, - id=None, - name=None, - ): + def get_by_tenant_ids(cls, joined_tenant_ids, user_id, page_number, items_per_page, orderby, desc, keywords, parser_id=None): from api.db.db_models import User fields = [ @@ -143,20 +132,25 @@ class DialogService(CommonService): cls.model.update_time, cls.model.create_time, ] - dialogs = ( - cls.model.select(*fields) - .join(User, on=(cls.model.tenant_id == User.id)) - .where( - (cls.model.tenant_id.in_(joined_tenant_ids) | (cls.model.tenant_id == user_id)) - & (cls.model.status == StatusEnum.VALID.value), - ) - ) - if id: - dialogs = dialogs.where(cls.model.id == id) - if name: - dialogs = dialogs.where(cls.model.name == name) if keywords: - dialogs = dialogs.where(fn.LOWER(cls.model.name).contains(keywords.lower())) + dialogs = ( + cls.model.select(*fields) + .join(User, on=(cls.model.tenant_id == User.id)) + .where( + (cls.model.tenant_id.in_(joined_tenant_ids) | (cls.model.tenant_id == user_id)) & (cls.model.status == StatusEnum.VALID.value), + (fn.LOWER(cls.model.name).contains(keywords.lower())), + ) + ) + else: + dialogs = ( + cls.model.select(*fields) + .join(User, on=(cls.model.tenant_id == User.id)) + .where( + (cls.model.tenant_id.in_(joined_tenant_ids) | (cls.model.tenant_id == user_id)) & (cls.model.status == StatusEnum.VALID.value), + ) + ) + if parser_id: + dialogs = dialogs.where(cls.model.parser_id == parser_id) if desc: dialogs = dialogs.order_by(cls.model.getter_by(orderby).desc()) else: diff --git a/docs/references/http_api_reference.md b/docs/references/http_api_reference.md index e3f74b40c..df237c255 100644 --- a/docs/references/http_api_reference.md +++ b/docs/references/http_api_reference.md @@ -2756,11 +2756,10 @@ Creates a chat assistant. - `'Authorization: Bearer '` - Body: - `"name"`: `string` - - `"icon"`: `string` + - `"avatar"`: `string` - `"dataset_ids"`: `list[string]` - - `"llm_id"`: `string` - - `"llm_setting"`: `object` - - `"prompt_config"`: `object` + - `"llm"`: `object` + - `"prompt"`: `object` ##### Request example @@ -2779,16 +2778,27 @@ curl --request POST \ - `"name"`: (*Body parameter*), `string`, *Required* The name of the chat assistant. -- `"icon"`: (*Body parameter*), `string` +- `"avatar"`: (*Body parameter*), `string` Base64 encoding of the avatar. -- `"dataset_ids"`: (*Body parameter*), `list[string]` - The IDs of the associated datasets. If omitted or set to `[]`, an empty chat assistant is created and datasets can be attached later. -- `"llm_id"`: (*Body parameter*), `string` - The chat model name. If not set, the user's default chat model will be used. -- `"llm_setting"`: (*Body parameter*), `object` - The LLM settings for the chat assistant to create. An `llm_setting` object may contain the following attributes: +- `"dataset_ids"`: (*Body parameter*), `list[string]` + The IDs of the associated datasets. +- `"llm"`: (*Body parameter*), `object` + The LLM settings for the chat assistant to create. If it is not explicitly set, a JSON object with the following values will be generated as the default. An `llm` JSON object contains the following attributes: + - `"model_name"`, `string` + The chat model name. If not set, the user's default chat model will be used. + + :::caution WARNING + `model_type` is an *internal* parameter, serving solely as a temporary workaround for the current model-configuration design limitations. + + Its main purpose is to let *multimodal* models (stored in the database as `"image2text"`) pass backend validation/dispatching. Be mindful that: + + - Do *not* treat it as a stable public API. + - It is subject to change or removal in future releases. + ::: + - `"model_type"`: `string` A model type specifier. Only `"chat"` and `"image2text"` are recognized; any other inputs, or when omitted, are treated as `"chat"`. + - `"model_name"`, `string` - `"temperature"`: `float` Controls the randomness of the model's predictions. A lower temperature results in more conservative responses, while a higher temperature yields more creative and diverse responses. Defaults to `0.1`. - `"top_p"`: `float` @@ -2797,27 +2807,21 @@ curl --request POST \ This discourages the model from repeating the same information by penalizing words that have already appeared in the conversation. Defaults to `0.4`. - `"frequency penalty"`: `float` Similar to the presence penalty, this reduces the model’s tendency to repeat the same words frequently. Defaults to `0.7`. -- `"prompt_config"`: (*Body parameter*), `object` - Instructions for the LLM to follow. A `prompt_config` object may contain the following attributes: - - `"system"`: `string` The prompt content. - - `"prologue"`: `string` The opening greeting for the user. - - `"parameters"`: `object[]` This argument lists the variables to use in the system prompt. Note that: +- `"prompt"`: (*Body parameter*), `object` + Instructions for the LLM to follow. If it is not explicitly set, a JSON object with the following values will be generated as the default. A `prompt` JSON object contains the following attributes: + - `"similarity_threshold"`: `float` RAGFlow employs either a combination of weighted keyword similarity and weighted vector cosine similarity, or a combination of weighted keyword similarity and weighted reranking score during retrieval. This argument sets the threshold for similarities between the user query and chunks. If a similarity score falls below this threshold, the corresponding chunk will be excluded from the results. The default value is `0.2`. + - `"keywords_similarity_weight"`: `float` This argument sets the weight of keyword similarity in the hybrid similarity score with vector cosine similarity or reranking model similarity. By adjusting this weight, you can control the influence of keyword similarity in relation to other similarity measures. The default value is `0.7`. + - `"top_n"`: `int` This argument specifies the number of top chunks with similarity scores above the `similarity_threshold` that are fed to the LLM. The LLM will *only* access these 'top N' chunks. The default value is `6`. + - `"variables"`: `object[]` This argument lists the variables to use in the 'System' field of **Chat Configurations**. Note that: - `"knowledge"` is a reserved variable, which represents the retrieved chunks. - - All the variables in `"system"` should be curly bracketed. + - All the variables in 'System' should be curly bracketed. + - The default value is `[{"key": "knowledge", "optional": true}]`. + - `"rerank_model"`: `string` If it is not specified, vector cosine similarity will be used; otherwise, reranking score will be used. + - `top_k`: `int` Refers to the process of reordering or selecting the top-k items from a list or set based on a specific ranking criterion. Default to 1024. - `"empty_response"`: `string` If nothing is retrieved in the dataset for the user's question, this will be used as the response. To allow the LLM to improvise when nothing is found, leave this blank. - - `"quote"`: `boolean` Indicates whether the source of text should be displayed. Defaults to `true`. - - `"tts"`: `boolean` - - `"refine_multiturn"`: `boolean` - - `"use_kg"`: `boolean` - - `"reasoning"`: `boolean` - - `"cross_languages"`: `list[string]` - - `"tavily_api_key"`: `string` - - `"toc_enhance"`: `boolean` -- `"similarity_threshold"`: (*Body parameter*), `float` -- `"vector_similarity_weight"`: (*Body parameter*), `float` -- `"top_n"`: (*Body parameter*), `int` -- `"top_k"`: (*Body parameter*), `int` -- `"rerank_id"`: (*Body parameter*), `string` + - `"opener"`: `string` The opening greeting for the user. Defaults to `"Hi! I am your assistant, can I help you?"`. + - `"show_quote`: `boolean` Indicates whether the source of text should be displayed. Defaults to `true`. + - `"prompt"`: `string` The prompt content. #### Response @@ -2827,42 +2831,39 @@ Success: { "code": 0, "data": { - "icon": "", + "avatar": "", "create_date": "Thu, 24 Oct 2024 11:18:29 GMT", "create_time": 1729768709023, "dataset_ids": [ "527fa74891e811ef9c650242ac120006" ], - "kb_names": [ - "dataset_1" - ], "description": "A helpful Assistant", + "do_refer": "1", "id": "b1f2f15691f911ef81180242ac120003", "language": "English", - "llm_id": "qwen-plus@Tongyi-Qianwen", - "llm_setting": { + "llm": { "frequency_penalty": 0.7, + "model_name": "qwen-plus@Tongyi-Qianwen", "presence_penalty": 0.4, "temperature": 0.1, "top_p": 0.3 }, "name": "12234", - "prompt_config": { + "prompt": { "empty_response": "Sorry! No relevant content was found in the knowledge base!", - "prologue": "Hi! I'm your assistant. What can I do for you?", - "quote": true, - "system": "You are an intelligent assistant...", - "parameters": [ + "keywords_similarity_weight": 0.3, + "opener": "Hi! I'm your assistant. What can I do for you?", + "prompt": "You are an intelligent assistant. Please summarize the content of the knowledge base to answer the question. Please list the data in the knowledge base and answer in detail. When all knowledge base content is irrelevant to the question, your answer must include the sentence \"The answer you are looking for is not found in the knowledge base!\" Answers need to consider chat history.\n ", + "rerank_model": "", + "similarity_threshold": 0.2, + "top_n": 6, + "variables": [ { "key": "knowledge", "optional": false } ] }, - "rerank_id": "", - "similarity_threshold": 0.2, - "vector_similarity_weight": 0.3, - "top_n": 6, "prompt_type": "simple", "status": "1", "tenant_id": "69736c5e723611efb51b0242ac120007", @@ -2878,7 +2879,7 @@ Failure: ```json { "code": 102, - "message": "Duplicated chat name." + "message": "Duplicated chat name in creating dataset." } ``` @@ -2888,9 +2889,7 @@ Failure: **PUT** `/api/v1/chats/{chat_id}` -Replaces the persisted configuration of a specified chat assistant. - -Use this endpoint only when you intend to send the full configuration to keep. Omitted fields are reset to server defaults. For partial updates, use `PATCH /api/v1/chats/{chat_id}` instead. +Updates configurations for a specified chat assistant. #### Request @@ -2901,11 +2900,10 @@ Use this endpoint only when you intend to send the full configuration to keep. O - `'Authorization: Bearer '` - Body: - `"name"`: `string` - - `"icon"`: `string` + - `"avatar"`: `string` - `"dataset_ids"`: `list[string]` - - `"llm_id"`: `string` - - `"llm_setting"`: `object` - - `"prompt_config"`: `object` + - `"llm"`: `object` + - `"prompt"`: `object` ##### Request example @@ -2916,23 +2914,7 @@ curl --request PUT \ --header 'Authorization: Bearer ' \ --data ' { - "name":"Test", - "icon":"", - "dataset_ids":["0b2cbc8c877f11ef89070242ac120005"], - "llm_id":"qwen-plus@Tongyi-Qianwen", - "llm_setting":{"temperature":0.1,"top_p":0.3,"presence_penalty":0.4,"frequency_penalty":0.7}, - "prompt_config":{ - "system":"You are an intelligent assistant...", - "prologue":"Hi! I'\''m your assistant. What can I do for you?", - "parameters":[{"key":"knowledge","optional":false}], - "empty_response":"Sorry! No relevant content was found in the knowledge base!", - "quote":true - }, - "similarity_threshold":0.2, - "vector_similarity_weight":0.3, - "top_n":6, - "top_k":1024, - "rerank_id":"" + "name":"Test" }' ``` @@ -2942,110 +2924,36 @@ curl --request PUT \ The ID of the chat assistant to update. - `"name"`: (*Body parameter*), `string`, *Required* The revised name of the chat assistant. -- `"icon"`: (*Body parameter*), `string` +- `"avatar"`: (*Body parameter*), `string` Base64 encoding of the avatar. -- `"dataset_ids"`: (*Body parameter*), `list[string]` +- `"dataset_ids"`: (*Body parameter*), `list[string]` The IDs of the associated datasets. -- `"llm_id"`: (*Body parameter*), `string` - The chat model name. If not set, the user's default chat model will be used. -- `"llm_setting"`: (*Body parameter*), `object` - The LLM settings for the chat assistant. An `llm_setting` object contains the following attributes: - - `"model_type"`: `string` - A model type specifier. Only `"chat"` and `"image2text"` are recognized; any other inputs, or when omitted, are treated as `"chat"`. +- `"llm"`: (*Body parameter*), `object` + The LLM settings for the chat assistant to create. If it is not explicitly set, a dictionary with the following values will be generated as the default. An `llm` object contains the following attributes: + - `"model_name"`, `string` + The chat model name. If not set, the user's default chat model will be used. - `"temperature"`: `float` Controls the randomness of the model's predictions. A lower temperature results in more conservative responses, while a higher temperature yields more creative and diverse responses. Defaults to `0.1`. - `"top_p"`: `float` Also known as “nucleus sampling”, this parameter sets a threshold to select a smaller set of words to sample from. It focuses on the most likely words, cutting off the less probable ones. Defaults to `0.3` - `"presence_penalty"`: `float` - This discourages the model from repeating the same information by penalizing words that have already appeared in the conversation. Defaults to `0.4`. + This discourages the model from repeating the same information by penalizing words that have already appeared in the conversation. Defaults to `0.2`. - `"frequency penalty"`: `float` Similar to the presence penalty, this reduces the model’s tendency to repeat the same words frequently. Defaults to `0.7`. -- `"prompt_config"`: (*Body parameter*), `object` -- `"similarity_threshold"`: (*Body parameter*), `float` -- `"vector_similarity_weight"`: (*Body parameter*), `float` -- `"top_n"`: (*Body parameter*), `int` -- `"top_k"`: (*Body parameter*), `int` -- `"rerank_id"`: (*Body parameter*), `string` - -Any field omitted from the request body is reset to the server-side default value for `PUT`. - -#### Response - -Success: returns the full updated chat assistant object. - -```json -{ - "code": 0, - "data": { - "id": "04d0d8e28d1911efa3630242ac120006", - "name": "Test", - "description": "A helpful Assistant", - "icon": "", - "dataset_ids": ["527fa74891e811ef9c650242ac120006"], - "kb_names": ["dataset_1"], - "llm_id": "qwen-plus@Tongyi-Qianwen", - "llm_setting": { - "frequency_penalty": 0.7, - "presence_penalty": 0.4, - "temperature": 0.1, - "top_p": 0.3 - }, - "prompt_config": { - "empty_response": "Sorry! No relevant content was found in the knowledge base!", - "prologue": "Hi! I'm your assistant. What can I do for you?", - "quote": true, - "system": "You are an intelligent assistant...", - "parameters": [{"key": "knowledge", "optional": false}] - }, - "similarity_threshold": 0.2, - "vector_similarity_weight": 0.3, - "top_n": 6, - "top_k": 1024, - "rerank_id": "", - "status": "1", - "tenant_id": "69736c5e723611efb51b0242ac120007", - "create_time": 1729232406637, - "update_time": 1729232406638 - } -} -``` - -Failure: - -```json -{ - "code": 102, - "message": "Duplicated chat name." -} -``` - ---- - -### Get chat assistant - -**GET** `/api/v1/chats/{chat_id}` - -Retrieves a specified chat assistant. - -#### Request - -- Method: GET -- URL: `/api/v1/chats/{chat_id}` -- Headers: - - `'Authorization: Bearer '` - -##### Request example - -```bash -curl --request GET \ - --url http://{address}/api/v1/chats/{chat_id} \ - --header 'Authorization: Bearer ' -``` - -##### Request parameters - -- `chat_id`: (*Path parameter*) - The ID of the chat assistant to retrieve. +- `"prompt"`: (*Body parameter*), `object` + Instructions for the LLM to follow. A `prompt` object contains the following attributes: + - `"similarity_threshold"`: `float` RAGFlow employs either a combination of weighted keyword similarity and weighted vector cosine similarity, or a combination of weighted keyword similarity and weighted rerank score during retrieval. This argument sets the threshold for similarities between the user query and chunks. If a similarity score falls below this threshold, the corresponding chunk will be excluded from the results. The default value is `0.2`. + - `"keywords_similarity_weight"`: `float` This argument sets the weight of keyword similarity in the hybrid similarity score with vector cosine similarity or reranking model similarity. By adjusting this weight, you can control the influence of keyword similarity in relation to other similarity measures. The default value is `0.7`. + - `"top_n"`: `int` This argument specifies the number of top chunks with similarity scores above the `similarity_threshold` that are fed to the LLM. The LLM will *only* access these 'top N' chunks. The default value is `8`. + - `"variables"`: `object[]` This argument lists the variables to use in the 'System' field of **Chat Configurations**. Note that: + - `"knowledge"` is a reserved variable, which represents the retrieved chunks. + - All the variables in 'System' should be curly bracketed. + - The default value is `[{"key": "knowledge", "optional": true}]` + - `"rerank_model"`: `string` If it is not specified, vector cosine similarity will be used; otherwise, reranking score will be used. + - `"empty_response"`: `string` If nothing is retrieved in the dataset for the user's question, this will be used as the response. To allow the LLM to improvise when nothing is found, leave this blank. + - `"opener"`: `string` The opening greeting for the user. Defaults to `"Hi! I am your assistant, can I help you?"`. + - `"show_quote`: `boolean` Indicates whether the source of text should be displayed. Defaults to `true`. + - `"prompt"`: `string` The prompt content. #### Response @@ -3053,38 +2961,7 @@ Success: ```json { - "code": 0, - "data": { - "icon": "", - "create_date": "Fri, 18 Oct 2024 06:20:06 GMT", - "create_time": 1729232406637, - "description": "A helpful Assistant", - "id": "04d0d8e28d1911efa3630242ac120006", - "dataset_ids": ["527fa74891e811ef9c650242ac120006"], - "kb_names": ["dataset_1"], - "language": "English", - "llm_id": "qwen-plus@Tongyi-Qianwen", - "llm_setting": { - "temperature": 0.1, - "top_p": 0.3 - }, - "name": "my_chat", - "prompt_config": { - "empty_response": "Sorry! No relevant content was found in the knowledge base!", - "prologue": "Hi! I'm your assistant. What can I do for you?", - "quote": true, - "system": "You are an intelligent assistant...", - "parameters": [{"key": "knowledge", "optional": false}] - }, - "rerank_id": "", - "similarity_threshold": 0.2, - "vector_similarity_weight": 0.3, - "top_n": 6, - "status": "1", - "tenant_id": "69736c5e723611efb51b0242ac120007", - "update_date": "Fri, 18 Oct 2024 06:20:06 GMT", - "update_time": 1729232406638 - } + "code": 0 } ``` @@ -3093,112 +2970,7 @@ Failure: ```json { "code": 102, - "message": "No authorization." -} -``` - ---- - -### Partially update chat assistant - -**PATCH** `/api/v1/chats/{chat_id}` - -Partially updates a specified chat assistant. - -This endpoint preserves unspecified fields. Nested `llm_setting` and `prompt_config` objects are deep-merged with the existing configuration, so it is the recommended endpoint for renaming a chat assistant or updating only a subset of settings. - -#### Request - -- Method: PATCH -- URL: `/api/v1/chats/{chat_id}` -- Headers: - - `'content-Type: application/json'` - - `'Authorization: Bearer '` -- Body: any subset of the fields accepted by `PUT /api/v1/chats/{chat_id}` - -##### Request example - -```bash -curl --request PATCH \ - --url http://{address}/api/v1/chats/{chat_id} \ - --header 'Content-Type: application/json' \ - --header 'Authorization: Bearer ' \ - --data '{ - "llm_id": "gpt-4o", - "llm_setting": {"temperature": 0.5} -}' -``` - -#### Response - -Success: returns the full updated chat assistant object (same structure as `PUT /api/v1/chats/{chat_id}`). - -```json -{ - "code": 0, - "data": { - "id": "04d0d8e28d1911efa3630242ac120006", - "name": "Renamed assistant", - "llm_id": "qwen-plus@Tongyi-Qianwen", - "..." : "..." - } -} -``` - -Failure: - -```json -{ - "code": 102, - "message": "No authorization." -} -``` - ---- - -### Delete chat assistant - -**DELETE** `/api/v1/chats/{chat_id}` - -Deletes a single chat assistant by ID. - -#### Request - -- Method: DELETE -- URL: `/api/v1/chats/{chat_id}` -- Headers: - - `'Authorization: Bearer '` - -##### Request example - -```bash -curl --request DELETE \ - --url http://{address}/api/v1/chats/{chat_id} \ - --header 'Authorization: Bearer ' -``` - -##### Request parameters - -- `chat_id`: (*Path parameter*) - The ID of the chat assistant to delete. - -#### Response - -Success: - -```json -{ - "code": 0, - "data": true -} -``` - -Failure: - -```json -{ - "code": 102, - "message": "No authorization." + "message": "Duplicated chat name in updating dataset." } ``` @@ -3276,14 +3048,14 @@ Failure: ### List chat assistants -**GET** `/api/v1/chats?page={page}&page_size={page_size}&orderby={orderby}&desc={desc}&keywords={keywords}&owner_ids={owner_id}&name={chat_name}&id={chat_id}` +**GET** `/api/v1/chats?page={page}&page_size={page_size}&orderby={orderby}&desc={desc}&name={chat_name}&id={chat_id}` Lists chat assistants. #### Request - Method: GET -- URL: `/api/v1/chats?page={page}&page_size={page_size}&orderby={orderby}&desc={desc}&keywords={keywords}&owner_ids={owner_id}&name={chat_name}&id={chat_id}` +- URL: `/api/v1/chats?page={page}&page_size={page_size}&orderby={orderby}&desc={desc}&name={chat_name}&id={chat_id}` - Headers: - `'Authorization: Bearer '` @@ -3291,32 +3063,26 @@ Lists chat assistants. ```bash curl --request GET \ - --url http://{address}/api/v1/chats?page={page}&page_size={page_size}&orderby={orderby}&desc={desc}&keywords={keywords}&owner_ids={owner_id}&name={chat_name}&id={chat_id} \ + --url http://{address}/api/v1/chats?page={page}&page_size={page_size}&orderby={orderby}&desc={desc}&name={chat_name}&id={chat_id} \ --header 'Authorization: Bearer ' ``` ##### Request parameters -- `page`: (*Filter parameter*), `integer` +- `page`: (*Filter parameter*), `integer` Specifies the page on which the chat assistants will be displayed. Defaults to `1`. -- `page_size`: (*Filter parameter*), `integer` +- `page_size`: (*Filter parameter*), `integer` The number of chat assistants on each page. Defaults to `30`. -- `orderby`: (*Filter parameter*), `string` +- `orderby`: (*Filter parameter*), `string` The attribute by which the results are sorted. Available options: - `create_time` (default) - `update_time` -- `desc`: (*Filter parameter*), `boolean` +- `desc`: (*Filter parameter*), `boolean` Indicates whether the retrieved chat assistants should be sorted in descending order. Defaults to `true`. -- `keywords`: (*Filter parameter*), `string` - Case-insensitive fuzzy match against chat assistant names. -- `owner_ids`: (*Filter parameter*), `string` (repeatable) - Filter by owner tenant IDs. Can be specified multiple times: `?owner_ids=id1&owner_ids=id2`. -- `id`: (*Filter parameter*), `string` - The ID of the chat assistant to retrieve with exact match. -- `name`: (*Filter parameter*), `string` - The name of the chat assistant to retrieve with exact match. - -When `id` or `name` is provided, exact filtering takes precedence over `keywords`. +- `id`: (*Filter parameter*), `string` + The ID of the chat assistant to retrieve. +- `name`: (*Filter parameter*), `string` + The name of the chat assistant to retrieve. #### Response @@ -3325,50 +3091,47 @@ Success: ```json { "code": 0, - "data": { - "chats": [ - { - "icon": "", - "create_date": "Fri, 18 Oct 2024 06:20:06 GMT", - "create_time": 1729232406637, - "description": "A helpful Assistant", - "id": "04d0d8e28d1911efa3630242ac120006", - "dataset_ids": ["527fa74891e811ef9c650242ac120006"], - "kb_names": ["dataset_1"], - "language": "English", - "llm_id": "qwen-plus@Tongyi-Qianwen", - "llm_setting": { - "frequency_penalty": 0.7, - "presence_penalty": 0.4, - "temperature": 0.1, - "top_p": 0.3 - }, - "name": "13243", - "prompt_config": { - "empty_response": "Sorry! No relevant content was found in the knowledge base!", - "prologue": "Hi! I'm your assistant. What can I do for you?", - "quote": true, - "system": "You are an intelligent assistant...", - "parameters": [ - { - "key": "knowledge", - "optional": false - } - ] - }, - "rerank_id": "", + "data": [ + { + "avatar": "", + "create_date": "Fri, 18 Oct 2024 06:20:06 GMT", + "create_time": 1729232406637, + "description": "A helpful Assistant", + "do_refer": "1", + "id": "04d0d8e28d1911efa3630242ac120006", + "dataset_ids": ["527fa74891e811ef9c650242ac120006"], + "language": "English", + "llm": { + "frequency_penalty": 0.7, + "model_name": "qwen-plus@Tongyi-Qianwen", + "presence_penalty": 0.4, + "temperature": 0.1, + "top_p": 0.3 + }, + "name": "13243", + "prompt": { + "empty_response": "Sorry! No relevant content was found in the knowledge base!", + "keywords_similarity_weight": 0.3, + "opener": "Hi! I'm your assistant. What can I do for you?", + "prompt": "You are an intelligent assistant. Please summarize the content of the knowledge base to answer the question. Please list the data in the knowledge base and answer in detail. When all knowledge base content is irrelevant to the question, your answer must include the sentence \"The answer you are looking for is not found in the knowledge base!\" Answers need to consider chat history.\n", + "rerank_model": "", "similarity_threshold": 0.2, - "vector_similarity_weight": 0.3, "top_n": 6, - "prompt_type": "simple", - "status": "1", - "tenant_id": "69736c5e723611efb51b0242ac120007", - "update_date": "Fri, 18 Oct 2024 06:20:06 GMT", - "update_time": 1729232406638 - } - ], - "total": 1 - } + "variables": [ + { + "key": "knowledge", + "optional": false + } + ] + }, + "prompt_type": "simple", + "status": "1", + "tenant_id": "69736c5e723611efb51b0242ac120007", + "top_k": 1024, + "update_date": "Fri, 18 Oct 2024 06:20:06 GMT", + "update_time": 1729232406638 + } + ] } ``` diff --git a/docs/references/python_api_reference.md b/docs/references/python_api_reference.md index 2bc1a56b8..a03c7c7c4 100644 --- a/docs/references/python_api_reference.md +++ b/docs/references/python_api_reference.md @@ -1149,13 +1149,11 @@ for c in rag_object.retrieve(dataset_ids=[dataset.id],document_ids=[doc.id]): ```python RAGFlow.create_chat( - name: str, - icon: str = “”, - dataset_ids: list[str] | None = None, - llm_id: str | None = None, - llm_setting: dict | None = None, - prompt_config: dict | None = None, - **kwargs + name: str, + avatar: str = "", + dataset_ids: list[str] = [], + llm: Chat.LLM = None, + prompt: Chat.Prompt = None ) -> Chat ``` @@ -1167,37 +1165,46 @@ Creates a chat assistant. The name of the chat assistant. -##### icon: `str` +##### avatar: `str` -Base64 encoding of the avatar. Defaults to `””`. +Base64 encoding of the avatar. Defaults to `""`. ##### dataset_ids: `list[str]` -The IDs of the associated datasets. Defaults to `[]`. When omitted or empty, the SDK creates an empty chat assistant and you can attach datasets later. +The IDs of the associated datasets. Defaults to `[""]`. -##### llm_id: `str | None` +##### llm: `Chat.LLM` -The LLM model name/ID to use. If `None`, the user’s default chat model is used. Defaults to `None`. +The LLM settings for the chat assistant to create. Defaults to `None`. When the value is `None`, a dictionary with the following values will be generated as the default. An `LLM` object contains the following attributes: -##### llm_setting: `dict | None` +- `model_name`: `str` + The chat model name. If it is `None`, the user's default chat model will be used. +- `temperature`: `float` + Controls the randomness of the model's predictions. A lower temperature results in more conservative responses, while a higher temperature yields more creative and diverse responses. Defaults to `0.1`. +- `top_p`: `float` + Also known as “nucleus sampling”, this parameter sets a threshold to select a smaller set of words to sample from. It focuses on the most likely words, cutting off the less probable ones. Defaults to `0.3` +- `presence_penalty`: `float` + This discourages the model from repeating the same information by penalizing words that have already appeared in the conversation. Defaults to `0.2`. +- `frequency penalty`: `float` + Similar to the presence penalty, this reduces the model’s tendency to repeat the same words frequently. Defaults to `0.7`. -LLM generation settings. Defaults to `None` (server defaults apply). Supported keys: +##### prompt: `Chat.Prompt` -- `”temperature”`: `float` Controls the randomness of the model’s predictions. Defaults to `0.1`. -- `”top_p”`: `float` Nucleus sampling threshold. Defaults to `0.3`. -- `”presence_penalty”`: `float` Penalizes tokens that have already appeared. Defaults to `0.4`. -- `”frequency_penalty”`: `float` Reduces repetition of frequent tokens. Defaults to `0.7`. -- `”max_token”`: `int` Maximum number of tokens in the response. Defaults to `512`. +Instructions for the LLM to follow. A `Prompt` object contains the following attributes: -##### prompt_config: `dict | None` - -Instructions for the LLM to follow. Defaults to `None` (server defaults apply). Supported keys: - -- `”system”`: `str` The system prompt content. -- `”empty_response”`: `str` Response when nothing is retrieved. Leave blank to let the LLM improvise. Defaults to `None`. -- `”prologue”`: `str` The opening greeting shown to the user. Defaults to `”Hi! I’m your assistant. What can I do for you?”`. -- `”quote”`: `bool` Whether to display source references. Defaults to `True`. -- `”parameters”`: `list[dict]` Variables used in the system prompt. Each entry has `”key”` (`str`) and `”optional”` (`bool`). The `knowledge` variable is reserved for retrieved chunks. Default: `[{“key”: “knowledge”, “optional”: True}]`. +- `similarity_threshold`: `float` RAGFlow employs either a combination of weighted keyword similarity and weighted vector cosine similarity, or a combination of weighted keyword similarity and weighted reranking score during retrieval. If a similarity score falls below this threshold, the corresponding chunk will be excluded from the results. The default value is `0.2`. +- `keywords_similarity_weight`: `float` This argument sets the weight of keyword similarity in the hybrid similarity score with vector cosine similarity or reranking model similarity. By adjusting this weight, you can control the influence of keyword similarity in relation to other similarity measures. The default value is `0.7`. +- `top_n`: `int` This argument specifies the number of top chunks with similarity scores above the `similarity_threshold` that are fed to the LLM. The LLM will *only* access these 'top N' chunks. The default value is `8`. +- `variables`: `list[dict[]]` This argument lists the variables to use in the 'System' field of **Chat Configurations**. Note that: + - `knowledge` is a reserved variable, which represents the retrieved chunks. + - All the variables in 'System' should be curly bracketed. + - The default value is `[{"key": "knowledge", "optional": True}]`. +- `rerank_model`: `str` If it is not specified, vector cosine similarity will be used; otherwise, reranking score will be used. Defaults to `""`. +- `top_k`: `int` Refers to the process of reordering or selecting the top-k items from a list or set based on a specific ranking criterion. Default to 1024. +- `empty_response`: `str` If nothing is retrieved in the dataset for the user's question, this will be used as the response. To allow the LLM to improvise when nothing is found, leave this blank. Defaults to `None`. +- `opener`: `str` The opening greeting for the user. Defaults to `"Hi! I am your assistant, can I help you?"`. +- `show_quote`: `bool` Indicates whether the source of text should be displayed. Defaults to `True`. +- `prompt`: `str` The prompt content. #### Returns @@ -1225,37 +1232,36 @@ assistant = rag_object.create_chat("Miss R", dataset_ids=dataset_ids) Chat.update(update_message: dict) ``` -Partially updates configurations for the current chat assistant. - -`Chat.update()` uses `PATCH /api/v1/chats/{chat_id}`. Only the provided keys are changed; all other fields are preserved. +Updates configurations for the current chat assistant. #### Parameters -##### update_message: `dict`, *Required* +##### update_message: `dict[str, str|list[str]|dict[]]`, *Required* -A dictionary representing the attributes to update. Supported keys: +A dictionary representing the attributes to update, with the following keys: -- `”name”`: `str` The revised name of the chat assistant. -- `”icon”`: `str` Base64 encoding of the avatar. -- `”dataset_ids”`: `list[str]` The datasets to associate with the chat assistant. -- `”llm_id”`: `str` The LLM model name/ID to use. -- `”llm_setting”`: `dict` LLM generation settings: - - `”temperature”`: `float` Controls the randomness of the model’s predictions. - - `”top_p”`: `float` Nucleus sampling threshold. - - `”presence_penalty”`: `float` Penalizes tokens that have already appeared. - - `”frequency_penalty”`: `float` Reduces repetition of frequent tokens. - - `”max_token”`: `int` Maximum number of tokens in the response. -- `”prompt_config”`: `dict` Instructions for the LLM to follow: - - `”system”`: `str` The system prompt content. - - `”empty_response”`: `str` Response when nothing is retrieved. Leave blank to let the LLM improvise. - - `”prologue”`: `str` The opening greeting shown to the user. - - `”quote”`: `bool` Whether to display source references. - - `”parameters”`: `list[dict]` Variables used in the system prompt. -- `”similarity_threshold”`: `float` Minimum similarity score for retrieved chunks. Defaults to `0.2`. -- `”vector_similarity_weight”`: `float` Weight of vector cosine similarity in the hybrid score. Defaults to `0.3`. -- `”top_n”`: `int` Number of top chunks fed to the LLM. Defaults to `6`. -- `”top_k”`: `int` Candidate pool size for reranking. Defaults to `1024`. -- `”rerank_id”`: `str` Reranking model ID. If empty, vector cosine similarity is used. +- `"name"`: `str` The revised name of the chat assistant. +- `"avatar"`: `str` Base64 encoding of the avatar. Defaults to `""` +- `"dataset_ids"`: `list[str]` The datasets to update. +- `"llm"`: `dict` The LLM settings: + - `"model_name"`, `str` The chat model name. + - `"temperature"`, `float` Controls the randomness of the model's predictions. A lower temperature results in more conservative responses, while a higher temperature yields more creative and diverse responses. + - `"top_p"`, `float` Also known as “nucleus sampling”, this parameter sets a threshold to select a smaller set of words to sample from. + - `"presence_penalty"`, `float` This discourages the model from repeating the same information by penalizing words that have appeared in the conversation. + - `"frequency penalty"`, `float` Similar to presence penalty, this reduces the model’s tendency to repeat the same words. +- `"prompt"` : Instructions for the LLM to follow. + - `"similarity_threshold"`: `float` RAGFlow employs either a combination of weighted keyword similarity and weighted vector cosine similarity, or a combination of weighted keyword similarity and weighted rerank score during retrieval. This argument sets the threshold for similarities between the user query and chunks. If a similarity score falls below this threshold, the corresponding chunk will be excluded from the results. The default value is `0.2`. + - `"keywords_similarity_weight"`: `float` This argument sets the weight of keyword similarity in the hybrid similarity score with vector cosine similarity or reranking model similarity. By adjusting this weight, you can control the influence of keyword similarity in relation to other similarity measures. The default value is `0.7`. + - `"top_n"`: `int` This argument specifies the number of top chunks with similarity scores above the `similarity_threshold` that are fed to the LLM. The LLM will *only* access these 'top N' chunks. The default value is `8`. + - `"variables"`: `list[dict[]]` This argument lists the variables to use in the 'System' field of **Chat Configurations**. Note that: + - `knowledge` is a reserved variable, which represents the retrieved chunks. + - All the variables in 'System' should be curly bracketed. + - The default value is `[{"key": "knowledge", "optional": True}]`. + - `"rerank_model"`: `str` If it is not specified, vector cosine similarity will be used; otherwise, reranking score will be used. Defaults to `""`. + - `"empty_response"`: `str` If nothing is retrieved in the dataset for the user's question, this will be used as the response. To allow the LLM to improvise when nothing is retrieved, leave this blank. Defaults to `None`. + - `"opener"`: `str` The opening greeting for the user. Defaults to `"Hi! I am your assistant, can I help you?"`. + - `"show_quote`: `bool` Indicates whether the source of text should be displayed Defaults to `True`. + - `"prompt"`: `str` The prompt content. #### Returns @@ -1271,7 +1277,7 @@ rag_object = RAGFlow(api_key="", base_url="http://: datasets = rag_object.list_datasets(name="kb_1") dataset_id = datasets[0].id assistant = rag_object.create_chat("Miss R", dataset_ids=[dataset_id]) -assistant.update({"name": "Stefan", "llm_setting": {"temperature": 0.8}, "top_n": 8}) +assistant.update({"name": "Stefan", "llm": {"temperature": 0.8}, "prompt": {"top_n": 8}}) ``` --- @@ -1322,11 +1328,8 @@ RAGFlow.list_chats( page_size: int = 30, orderby: str = "create_time", desc: bool = True, - id: str | None = None, - name: str | None = None, - keywords: str | None = None, - owner_ids: str | list[str] | None = None, - parser_id: str | None = None + id: str = None, + name: str = None ) -> list[Chat] ``` @@ -1353,27 +1356,13 @@ The attribute by which the results are sorted. Available options: Indicates whether the retrieved chat assistants should be sorted in descending order. Defaults to `True`. -##### id: `str | None` +##### id: `str` -Exact match on chat assistant ID. Defaults to `None`. +The ID of the chat assistant to retrieve. Defaults to `None`. -##### name: `str | None` +##### name: `str` -Exact match on chat assistant name. Defaults to `None`. - -##### keywords: `str | None` - -Case-insensitive fuzzy match against chat assistant names. Defaults to `None`. - -##### owner_ids: `str | list[str] | None` - -Filter by owner tenant IDs. Defaults to `None`. - -##### parser_id: `str | None` - -Filter by parser type. Defaults to `None`. - -When `id` or `name` is provided, exact filtering takes precedence over `keywords`. +The name of the chat assistant to retrieve. Defaults to `None`. #### Returns diff --git a/sdk/python/ragflow_sdk/modules/base.py b/sdk/python/ragflow_sdk/modules/base.py index f6c77899e..6b958fb8d 100644 --- a/sdk/python/ragflow_sdk/modules/base.py +++ b/sdk/python/ragflow_sdk/modules/base.py @@ -54,9 +54,5 @@ class Base: res = self.rag.put(path, json) return res - def patch(self, path, json): - res = self.rag.patch(path, json) - return res - def __str__(self): return str(self.to_json()) diff --git a/sdk/python/ragflow_sdk/modules/chat.py b/sdk/python/ragflow_sdk/modules/chat.py index 82374b2d5..7b5c94725 100644 --- a/sdk/python/ragflow_sdk/modules/chat.py +++ b/sdk/python/ragflow_sdk/modules/chat.py @@ -23,22 +23,50 @@ class Chat(Base): def __init__(self, rag, res_dict): self.id = "" self.name = "assistant" - self.icon = "" - self.dataset_ids = [] - self.llm_id = None - self.llm_setting = {} - self.prompt_config = {} - self.similarity_threshold = 0.2 - self.vector_similarity_weight = 0.3 - self.top_n = 6 - self.top_k = 1024 - self.rerank_id = "" + self.avatar = "path/to/avatar" + self.llm = Chat.LLM(rag, {}) + self.prompt = Chat.Prompt(rag, {}) super().__init__(rag, res_dict) + class LLM(Base): + def __init__(self, rag, res_dict): + self.model_name = None + self.temperature = 0.1 + self.top_p = 0.3 + self.presence_penalty = 0.4 + self.frequency_penalty = 0.7 + self.max_tokens = 512 + super().__init__(rag, res_dict) + + class Prompt(Base): + def __init__(self, rag, res_dict): + self.similarity_threshold = 0.2 + self.keywords_similarity_weight = 0.7 + self.top_n = 8 + self.top_k = 1024 + self.variables = [{"key": "knowledge", "optional": True}] + self.rerank_model = "" + self.empty_response = None + self.opener = "Hi! I'm your assistant. What can I do for you?" + self.show_quote = True + self.prompt = ( + "You are an intelligent assistant. Your primary function is to answer questions based strictly on the provided knowledge base." + "**Essential Rules:**" + "- Your answer must be derived **solely** from this knowledge base: `{knowledge}`." + "- **When information is available**: Summarize the content to give a detailed answer." + "- **When information is unavailable**: Your response must contain this exact sentence: 'The answer you are looking for is not found in the knowledge base!' " + "- **Always consider** the entire conversation history." + ) + super().__init__(rag, res_dict) + def update(self, update_message: dict): if not isinstance(update_message, dict): raise Exception("ValueError('`update_message` must be a dict')") - res = self.patch(f"/chats/{self.id}", update_message) + if update_message.get("llm") == {}: + raise Exception("ValueError('`llm` cannot be empty')") + if update_message.get("prompt") == {}: + raise Exception("ValueError('`prompt` cannot be empty')") + res = self.put(f"/chats/{self.id}", update_message) res = res.json() if res.get("code") != 0: raise Exception(res["message"]) diff --git a/sdk/python/ragflow_sdk/ragflow.py b/sdk/python/ragflow_sdk/ragflow.py index e60a4eeab..15b571872 100644 --- a/sdk/python/ragflow_sdk/ragflow.py +++ b/sdk/python/ragflow_sdk/ragflow.py @@ -49,10 +49,6 @@ class RAGFlow: res = requests.put(url=self.api_url + path, json=json, headers=self.authorization_header) return res - def patch(self, path, json): - res = requests.patch(url=self.api_url + path, json=json, headers=self.authorization_header) - return res - def create_dataset( self, name: str, @@ -115,25 +111,55 @@ class RAGFlow: return result_list raise Exception(res["message"]) - def create_chat( - self, - name: str, - icon: str = "", - dataset_ids: list[str] | None = None, - llm_id: str | None = None, - llm_setting: dict | None = None, - prompt_config: dict | None = None, - **kwargs, - ) -> Chat: - payload = {"name": name, "icon": icon, "dataset_ids": dataset_ids or []} - if llm_id is not None: - payload["llm_id"] = llm_id - if llm_setting is not None: - payload["llm_setting"] = llm_setting - if prompt_config is not None: - payload["prompt_config"] = prompt_config - payload.update(kwargs) - res = self.post("/chats", payload) + def create_chat(self, name: str, avatar: str = "", dataset_ids=None, llm: Chat.LLM | None = None, prompt: Chat.Prompt | None = None) -> Chat: + if dataset_ids is None: + dataset_ids = [] + dataset_list = [] + for id in dataset_ids: + dataset_list.append(id) + + if llm is None: + llm = Chat.LLM( + self, + { + "model_name": None, + "temperature": 0.1, + "top_p": 0.3, + "presence_penalty": 0.4, + "frequency_penalty": 0.7, + "max_tokens": 512, + }, + ) + if prompt is None: + prompt = Chat.Prompt( + self, + { + "similarity_threshold": 0.2, + "keywords_similarity_weight": 0.7, + "top_n": 8, + "top_k": 1024, + "variables": [{"key": "knowledge", "optional": True}], + "rerank_model": "", + "empty_response": None, + "opener": None, + "show_quote": True, + "prompt": None, + }, + ) + if prompt.opener is None: + prompt.opener = "Hi! I'm your assistant. What can I do for you?" + if prompt.prompt is None: + prompt.prompt = ( + "You are an intelligent assistant. Your primary function is to answer questions based strictly on the provided knowledge base." + "**Essential Rules:**" + "- Your answer must be derived **solely** from this knowledge base: `{knowledge}`." + "- **When information is available**: Summarize the content to give a detailed answer." + "- **When information is unavailable**: Your response must contain this exact sentence: 'The answer you are looking for is not found in the knowledge base!' " + "- **Always consider** the entire conversation history." + ) + + temp_dict = {"name": name, "avatar": avatar, "dataset_ids": dataset_list if dataset_list else [], "llm": llm.to_json(), "prompt": prompt.to_json()} + res = self.post("/chats", temp_dict) res = res.json() if res.get("code") == 0: return Chat(self, res["data"]) @@ -145,24 +171,7 @@ class RAGFlow: if res.get("code") != 0: raise Exception(res["message"]) - def get_chat(self, chat_id: str) -> Chat: - res = self.get(f"/chats/{chat_id}") - res = res.json() - if res.get("code") == 0: - return Chat(self, res["data"]) - raise Exception(res["message"]) - - def list_chats( - self, - page: int = 1, - page_size: int = 30, - orderby: str = "create_time", - desc: bool = True, - id: str | None = None, - name: str | None = None, - keywords: str | None = None, - owner_ids: str | list[str] | None = None, - ) -> list[Chat]: + def list_chats(self, page: int = 1, page_size: int = 30, orderby: str = "create_time", desc: bool = True, id: str | None = None, name: str | None = None) -> list[Chat]: res = self.get( "/chats", { @@ -172,14 +181,12 @@ class RAGFlow: "desc": desc, "id": id, "name": name, - "keywords": keywords, - "owner_ids": owner_ids, }, ) res = res.json() result_list = [] if res.get("code") == 0: - for data in res["data"]["chats"]: + for data in res["data"]: result_list.append(Chat(self, data)) return result_list raise Exception(res["message"]) diff --git a/test/benchmark/README.md b/test/benchmark/README.md index 031d92d5b..847ae457b 100644 --- a/test/benchmark/README.md +++ b/test/benchmark/README.md @@ -151,7 +151,7 @@ Model selection guidance - Chat model is tied to the chat assistant. Set during chat creation using --chat-payload: ``` - {"name": "...", "llm_id": "@", "llm_setting": {}} + {"name": "...", "llm": {"model_name": "@"}} ``` Or set tenant defaults via --set-tenant-info with --tenant-llm-id. - --model is required by the OpenAI-compatible endpoint but does not override @@ -190,7 +190,7 @@ Example: chat benchmark creating dataset + upload + parse + chat (login + regist --document-path test/benchmark/test_docs/Doc2.pdf \ --document-path test/benchmark/test_docs/Doc3.pdf \ --chat-name "bench_chat" \ - --chat-payload '{"name":"bench_chat","llm_id":"glm-4-flash@ZHIPU-AI","llm_setting":{}}' \ + --chat-payload '{"name":"bench_chat","llm":{"model_name":"glm-4-flash@ZHIPU-AI"}}' \ --message "What is the purpose of RAGFlow?" \ --model "glm-4-flash@ZHIPU-AI" ``` diff --git a/test/benchmark/chat.py b/test/benchmark/chat.py index cfff29c7b..52146314c 100644 --- a/test/benchmark/chat.py +++ b/test/benchmark/chat.py @@ -26,8 +26,8 @@ def create_chat( body = dict(payload or {}) if "name" not in body: body["name"] = name - if dataset_ids is not None and "kb_ids" not in body: - body["kb_ids"] = dataset_ids + if dataset_ids is not None and "dataset_ids" not in body: + body["dataset_ids"] = dataset_ids res = client.request_json("POST", "/chats", json_body=body) if res.get("code") != 0: raise ChatError(f"Create chat failed: {res.get('message')}") @@ -35,23 +35,24 @@ def create_chat( def get_chat(client: HttpClient, chat_id: str) -> Dict[str, Any]: - res = client.request_json("GET", f"/chats/{chat_id}") + res = client.request_json("GET", "/chats", params={"id": chat_id}) if res.get("code") != 0: raise ChatError(f"Get chat failed: {res.get('message')}") - data = res.get("data", {}) + data = res.get("data", []) if not data: raise ChatError("Chat not found") - return data + return data[0] def resolve_model(model: Optional[str], chat_data: Optional[Dict[str, Any]]) -> str: if model: return model if chat_data: - llm_id = chat_data.get("llm_id") - if llm_id: - return llm_id - raise ChatError("Model name is required; provide --model or use a chat with llm_id.") + llm = chat_data.get("llm") or {} + llm_name = llm.get("model_name") + if llm_name: + return llm_name + raise ChatError("Model name is required; provide --model or use a chat with llm.model_name.") def _parse_stream_error(response) -> Optional[str]: diff --git a/test/benchmark/run_chat.sh b/test/benchmark/run_chat.sh index 4ca7fe15d..54c232748 100755 --- a/test/benchmark/run_chat.sh +++ b/test/benchmark/run_chat.sh @@ -20,7 +20,7 @@ PYTHONPATH="${REPO_ROOT}/test" uv run -m benchmark chat \ --document-path "${SCRIPT_DIR}/test_docs/Doc2.pdf" \ --document-path "${SCRIPT_DIR}/test_docs/Doc3.pdf" \ --chat-name "bench_chat" \ - --chat-payload '{"name":"bench_chat","llm_id":"glm-4-flash@ZHIPU-AI","llm_setting":{}}' \ + --chat-payload '{"name":"bench_chat","llm":{"model_name":"glm-4-flash@ZHIPU-AI"}}' \ --message "What is the purpose of RAGFlow?" \ --model "glm-4-flash@ZHIPU-AI" \ --iterations 10 \ diff --git a/test/benchmark/run_retrieval_chat.sh b/test/benchmark/run_retrieval_chat.sh index cb5d264d2..9cd531803 100755 --- a/test/benchmark/run_retrieval_chat.sh +++ b/test/benchmark/run_retrieval_chat.sh @@ -10,7 +10,7 @@ BASE_URL="http://127.0.0.1:9380" LOGIN_EMAIL="qa@infiniflow.org" LOGIN_PASSWORD="123" DATASET_PAYLOAD='{"name":"bench_dataset","embedding_model":"BAAI/bge-small-en-v1.5@Builtin"}' -CHAT_PAYLOAD='{"name":"bench_chat","llm_id":"glm-4-flash@ZHIPU-AI","llm_setting":{}}' +CHAT_PAYLOAD='{"name":"bench_chat","llm":{"model_name":"glm-4-flash@ZHIPU-AI"}}' DATASET_ID="" cleanup_dataset() { diff --git a/test/playwright/e2e/test_next_apps_chat.py b/test/playwright/e2e/test_next_apps_chat.py index 135b10af2..fa617bbc0 100644 --- a/test/playwright/e2e/test_next_apps_chat.py +++ b/test/playwright/e2e/test_next_apps_chat.py @@ -172,7 +172,7 @@ def _mm_open_and_close_embed_dialog_if_available(page) -> bool: def _mm_settings_save_request(req) -> bool: - return req.method.upper() in MM_REQUEST_METHOD_WHITELIST and "/api/v1/chats" in req.url + return req.method.upper() in MM_REQUEST_METHOD_WHITELIST and "/dialog/set" in req.url def _mm_open_settings_panel(page): @@ -559,11 +559,9 @@ def mm_step_07_settings_open_close_cancel_save(ctx: FlowContext, step, snap): with page.expect_request(_mm_settings_save_request, timeout=RESULT_TIMEOUT_MS) as req_info: page.get_by_test_id("chat-settings-save").click() payload = _mm_payload_from_request(req_info.value) - assert payload.get("name"), "missing name in /api/v1/chats payload" - assert "kb_ids" in payload, "missing kb_ids in /api/v1/chats payload" - assert payload.get("llm_id"), "missing llm_id in /api/v1/chats payload" - assert "llm_setting" in payload, "missing llm_setting in /api/v1/chats payload" - assert "prompt_config" in payload, "missing prompt_config in /api/v1/chats payload" + assert payload.get("dialog_id"), "missing dialog_id in /dialog/set payload" + assert "llm_id" in payload, "missing llm_id in /dialog/set payload" + assert "llm_setting" in payload, "missing llm_setting in /dialog/set payload" ctx.state["mm_settings_saved"] = True snap("chat_mm_settings_saved") @@ -661,7 +659,8 @@ def mm_step_11_apply_multimodel_config(ctx: FlowContext, step, snap): with page.expect_request(_mm_settings_save_request, timeout=RESULT_TIMEOUT_MS) as req_info: apply_btn.click() payload = _mm_payload_from_request(req_info.value) - assert payload.get("llm_id"), "missing llm_id in apply-config payload" + assert payload.get("dialog_id"), "missing dialog_id in apply-config payload" + assert "llm_id" in payload, "missing llm_id in apply-config payload" assert "llm_setting" in payload, "missing llm_setting in apply-config payload" ctx.state["mm_cards_configured"] = True diff --git a/test/playwright/helpers/_next_apps_helpers.py b/test/playwright/helpers/_next_apps_helpers.py index 330002275..f3607feb2 100644 --- a/test/playwright/helpers/_next_apps_helpers.py +++ b/test/playwright/helpers/_next_apps_helpers.py @@ -379,7 +379,7 @@ def _select_first_dataset_and_save( return isinstance(kb_ids, list) and len(kb_ids) > 0 response_url_pattern = ( - "/api/v1/chats" if save_testid == "chat-settings-save" else "/api/v1/searches/" + "/dialog/set" if save_testid == "chat-settings-save" else "/api/v1/searches/" ) last_payload = {} last_combobox_text = "" diff --git a/test/testcases/test_http_api/common.py b/test/testcases/test_http_api/common.py index 1e659291e..2678879f9 100644 --- a/test/testcases/test_http_api/common.py +++ b/test/testcases/test_http_api/common.py @@ -216,24 +216,12 @@ def list_chat_assistants(auth, params=None): return res.json() -def get_chat_assistant(auth, chat_assistant_id): - url = f"{HOST_ADDRESS}{CHAT_ASSISTANT_API_URL}/{chat_assistant_id}" - res = requests.get(url=url, headers=HEADERS, auth=auth) - return res.json() - - def update_chat_assistant(auth, chat_assistant_id, payload=None): url = f"{HOST_ADDRESS}{CHAT_ASSISTANT_API_URL}/{chat_assistant_id}" res = requests.put(url=url, headers=HEADERS, auth=auth, json=payload) return res.json() -def patch_chat_assistant(auth, chat_assistant_id, payload=None): - url = f"{HOST_ADDRESS}{CHAT_ASSISTANT_API_URL}/{chat_assistant_id}" - res = requests.patch(url=url, headers=HEADERS, auth=auth, json=payload) - return res.json() - - def delete_chat_assistants(auth, payload=None): url = f"{HOST_ADDRESS}{CHAT_ASSISTANT_API_URL}" res = requests.delete(url=url, headers=HEADERS, auth=auth, json=payload) diff --git a/test/testcases/test_http_api/test_chat_assistant_management/conftest.py b/test/testcases/test_http_api/test_chat_assistant_management/conftest.py index 330732db6..b81b48edc 100644 --- a/test/testcases/test_http_api/test_chat_assistant_management/conftest.py +++ b/test/testcases/test_http_api/test_chat_assistant_management/conftest.py @@ -14,7 +14,7 @@ # limitations under the License. # import pytest -from common import batch_create_chat_assistants, delete_all_chat_assistants, get_chat_assistant, list_documents, parse_documents +from common import batch_create_chat_assistants, delete_all_chat_assistants, list_chat_assistants, list_documents, parse_documents from utils import wait_for @@ -43,7 +43,7 @@ def add_chat_assistants_func(request, HttpApiAuth, add_document): @pytest.fixture(scope="function") def chat_assistant_llm_model_type(HttpApiAuth, add_chat_assistants_func): _, _, chat_assistant_ids = add_chat_assistants_func - res = get_chat_assistant(HttpApiAuth, chat_assistant_ids[0]) + res = list_chat_assistants(HttpApiAuth, {"id": chat_assistant_ids[0]}) if res.get("code") == 0 and res.get("data"): - return res["data"].get("llm_setting", {}).get("model_type", "chat") + return res["data"][0].get("llm", {}).get("model_type", "chat") return "chat" diff --git a/test/testcases/test_http_api/test_chat_assistant_management/test_chat_sdk_routes_unit.py b/test/testcases/test_http_api/test_chat_assistant_management/test_chat_sdk_routes_unit.py index ea3573a23..5ca56b925 100644 --- a/test/testcases/test_http_api/test_chat_assistant_management/test_chat_sdk_routes_unit.py +++ b/test/testcases/test_http_api/test_chat_assistant_management/test_chat_sdk_routes_unit.py @@ -44,40 +44,38 @@ class _AwaitableValue: class _DummyKB: - def __init__(self, kid="kb-1", embd_id="embd@factory", chunk_num=1, name="Dataset A", status="1"): - self.id = kid + def __init__(self, embd_id="embd@factory", chunk_num=1, tenant_embd_id=1): self.embd_id = embd_id self.chunk_num = chunk_num - self.name = name - self.status = status + self.tenant_embd_id = tenant_embd_id + + def to_json(self): + return {"id": "kb-1"} class _DummyDialogRecord: - def __init__(self, data=None): - self._data = data or { + def __init__(self): + self._data = { "id": "chat-1", "name": "chat-name", - "description": "desc", - "icon": "icon.png", - "kb_ids": ["kb-1"], - "llm_id": "glm-4", - "llm_setting": {"temperature": 0.1}, "prompt_config": { "system": "Answer with {knowledge}", "parameters": [{"key": "knowledge", "optional": False}], "prologue": "hello", "quote": True, }, + "llm_setting": {"temperature": 0.1}, + "llm_id": "glm-4", "similarity_threshold": 0.2, "vector_similarity_weight": 0.3, "top_n": 6, - "top_k": 1024, "rerank_id": "", - "meta_data_filter": {}, - "tenant_id": "tenant-1", + "top_k": 1024, + "kb_ids": ["kb-1"], + "icon": "icon.png", } - def to_dict(self): + def to_json(self): return deepcopy(self._data) @@ -87,15 +85,47 @@ def _run(coro): def _load_chat_module(monkeypatch): repo_root = Path(__file__).resolve().parents[4] - module_name = "test_chat_restful_routes_unit_module" - module_path = repo_root / "api" / "apps" / "restful_apis" / "chat_api.py" + common_pkg = ModuleType("common") + common_pkg.__path__ = [str(repo_root / "common")] + monkeypatch.setitem(sys.modules, "common", common_pkg) + + deepdoc_pkg = ModuleType("deepdoc") + deepdoc_parser_pkg = ModuleType("deepdoc.parser") + deepdoc_parser_pkg.__path__ = [] + + class _StubPdfParser: + pass + + class _StubExcelParser: + pass + + class _StubDocxParser: + pass + + deepdoc_parser_pkg.PdfParser = _StubPdfParser + deepdoc_parser_pkg.ExcelParser = _StubExcelParser + deepdoc_parser_pkg.DocxParser = _StubDocxParser + deepdoc_pkg.parser = deepdoc_parser_pkg + monkeypatch.setitem(sys.modules, "deepdoc", deepdoc_pkg) + monkeypatch.setitem(sys.modules, "deepdoc.parser", deepdoc_parser_pkg) + + deepdoc_excel_module = ModuleType("deepdoc.parser.excel_parser") + deepdoc_excel_module.RAGFlowExcelParser = _StubExcelParser + monkeypatch.setitem(sys.modules, "deepdoc.parser.excel_parser", deepdoc_excel_module) + + deepdoc_parser_utils = ModuleType("deepdoc.parser.utils") + deepdoc_parser_utils.get_text = lambda *_args, **_kwargs: "" + monkeypatch.setitem(sys.modules, "deepdoc.parser.utils", deepdoc_parser_utils) + monkeypatch.setitem(sys.modules, "xgboost", ModuleType("xgboost")) + + module_name = "test_chat_sdk_routes_unit_module" + module_path = repo_root / "api" / "apps" / "sdk" / "chat.py" spec = importlib.util.spec_from_file_location(module_name, module_path) module = importlib.util.module_from_spec(spec) module.manager = _DummyManager() monkeypatch.setitem(sys.modules, module_name, module) spec.loader.exec_module(module) - monkeypatch.setattr(module, "current_user", SimpleNamespace(id="tenant-1")) return module @@ -104,357 +134,227 @@ def _set_request_json(monkeypatch, module, payload): @pytest.mark.p2 -def test_create_chat_uses_direct_chat_fields(monkeypatch): +def test_create_internal_failure_paths(monkeypatch): module = _load_chat_module(monkeypatch) - saved = {} - _set_request_json( - monkeypatch, - module, - { - "name": "chat-a", - "icon": "icon.png", - "dataset_ids": ["kb-1"], - "llm_id": "glm-4", - "llm_setting": {"temperature": 0.8}, - "prompt_config": { - "system": "Answer with {knowledge}", - "parameters": [{"key": "knowledge", "optional": False}], - "prologue": "Hi", - }, - "vector_similarity_weight": 0.25, - }, - ) - monkeypatch.setattr(module.TenantService, "get_by_id", lambda _tid: (True, SimpleNamespace(llm_id="glm-4"))) - monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: []) - monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda **_kwargs: [SimpleNamespace(id="kb-1")]) - monkeypatch.setattr(module.KnowledgebaseService, "query", lambda **_kwargs: [_DummyKB()]) - monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _id: (True, _DummyKB())) + _set_request_json(monkeypatch, module, {"name": "chat-a", "dataset_ids": ["kb-1", "kb-2"]}) + monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda **_kwargs: [SimpleNamespace(id="kb")]) + monkeypatch.setattr(module.KnowledgebaseService, "query", lambda **_kwargs: [_DummyKB(chunk_num=1)]) + monkeypatch.setattr(module.KnowledgebaseService, "get_by_ids", lambda _ids: [_DummyKB(embd_id="embd-a@x"), _DummyKB(embd_id="embd-b@y")]) monkeypatch.setattr(module.TenantLLMService, "split_model_name_and_factory", lambda model: (model.split("@")[0], "factory")) - monkeypatch.setattr(module.TenantLLMService, "query", lambda **_kwargs: [SimpleNamespace(id="llm-1")]) + res = _run(module.create.__wrapped__("tenant-1")) + assert res["code"] == module.RetCode.AUTHENTICATION_ERROR + assert "different embedding models" in res["message"] - def _save(**kwargs): - saved.update(kwargs) - return True + _set_request_json(monkeypatch, module, {"name": "chat-a", "dataset_ids": []}) + monkeypatch.setattr(module.TenantService, "get_by_id", lambda _tid: (False, None)) + res = _run(module.create.__wrapped__("tenant-1")) + assert res["message"] == "Tenant not found!" - monkeypatch.setattr(module.DialogService, "save", _save) - monkeypatch.setattr(module.DialogService, "get_by_id", lambda _id: (True, _DummyDialogRecord(saved))) + monkeypatch.setattr(module.TenantService, "get_by_id", lambda _tid: (True, SimpleNamespace(llm_id="glm-4"))) + monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: []) + monkeypatch.setattr(module.DialogService, "save", lambda **_kwargs: False) + res = _run(module.create.__wrapped__("tenant-1")) + assert res["message"] == "Fail to new a chat!" - res = _run(module.create.__wrapped__()) - - assert res["code"] == 0 - assert saved["kb_ids"] == ["kb-1"] - assert saved["prompt_config"]["prologue"] == "Hi" - assert saved["llm_id"] == "glm-4" - assert saved["llm_setting"]["temperature"] == 0.8 - assert res["data"]["dataset_ids"] == ["kb-1"] - assert res["data"]["kb_names"] == ["Dataset A"] - assert "kb_ids" not in res["data"] - assert "prompt" not in res["data"] - assert "llm" not in res["data"] - assert "avatar" not in res["data"] - - -@pytest.mark.p1 -def test_create_chat_accepts_provider_scoped_rerank_id(monkeypatch): - module = _load_chat_module(monkeypatch) - saved = {} - query_calls = [] + monkeypatch.setattr(module.DialogService, "save", lambda **_kwargs: True) + monkeypatch.setattr(module.DialogService, "get_by_id", lambda _id: (False, None)) + res = _run(module.create.__wrapped__("tenant-1")) + assert res["message"] == "Fail to new a chat!" _set_request_json( monkeypatch, module, - { - "name": "chat-a", - "icon": "icon.png", - "dataset_ids": ["kb-1"], - "llm_id": "glm-4@ZHIPU-AI", - "llm_setting": {"temperature": 0.8}, - "prompt_config": { - "system": "Answer with {knowledge}", - "parameters": [{"key": "knowledge", "optional": False}], - "prologue": "Hi", - }, - "rerank_id": "custom-reranker@OpenAI", - "vector_similarity_weight": 0.25, - }, - ) - monkeypatch.setattr(module.TenantService, "get_by_id", lambda _tid: (True, SimpleNamespace(llm_id="glm-4@ZHIPU-AI"))) - monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: []) - monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda **_kwargs: [SimpleNamespace(id="kb-1")]) - monkeypatch.setattr(module.KnowledgebaseService, "query", lambda **_kwargs: [_DummyKB()]) - monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _id: (True, _DummyKB())) - - def _split_model_name_and_factory(model_name): - return { - "glm-4@ZHIPU-AI": ("glm-4", "ZHIPU-AI"), - "custom-reranker@OpenAI": ("custom-reranker", "OpenAI"), - }.get(model_name, (model_name, None)) - - def _query(**kwargs): - query_calls.append(kwargs) - if kwargs == { - "tenant_id": "tenant-1", - "llm_name": "glm-4", - "llm_factory": "ZHIPU-AI", - "model_type": "chat", - }: - return [SimpleNamespace(id="llm-1")] - if kwargs == { - "tenant_id": "tenant-1", - "llm_name": "custom-reranker", - "llm_factory": "OpenAI", - "model_type": "rerank", - }: - return [SimpleNamespace(id="rerank-1")] - return [] - - monkeypatch.setattr(module.TenantLLMService, "split_model_name_and_factory", _split_model_name_and_factory) - monkeypatch.setattr(module.TenantLLMService, "query", _query) - - def _save(**kwargs): - saved.update(kwargs) - return True - - monkeypatch.setattr(module.DialogService, "save", _save) - monkeypatch.setattr(module.DialogService, "get_by_id", lambda _id: (True, _DummyDialogRecord(saved))) - - res = _run(module.create.__wrapped__()) - - assert res["code"] == 0 - assert saved["rerank_id"] == "custom-reranker@OpenAI" - assert { - "tenant_id": "tenant-1", - "llm_name": "custom-reranker", - "llm_factory": "OpenAI", - "model_type": "rerank", - } in query_calls - - -@pytest.mark.p1 -def test_create_chat_allows_default_knowledge_placeholder_without_sources(monkeypatch): - module = _load_chat_module(monkeypatch) - saved = {} - - _set_request_json(monkeypatch, module, {"name": "chat-a"}) - monkeypatch.setattr(module.TenantService, "get_by_id", lambda _tid: (True, SimpleNamespace(llm_id="glm-4"))) - monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: []) - monkeypatch.setattr(module.TenantLLMService, "get_api_key", lambda *_args, **_kwargs: SimpleNamespace(id=1)) - - def _save(**kwargs): - saved.update(kwargs) - return True - - monkeypatch.setattr(module.DialogService, "save", _save) - monkeypatch.setattr(module.DialogService, "get_by_id", lambda _id: (True, _DummyDialogRecord(saved))) - - res = _run(module.create.__wrapped__()) - - assert res["code"] == 0 - assert saved["kb_ids"] == [] - assert saved["prompt_config"]["system"].find("{knowledge}") >= 0 - assert saved["prompt_config"]["parameters"] == [{"key": "knowledge", "optional": False}] - - -@pytest.mark.p1 -def test_create_chat_uses_tenant_default_llm_when_llm_id_is_null(monkeypatch): - module = _load_chat_module(monkeypatch) - saved = {} - - _set_request_json( - monkeypatch, - module, - { - "name": "chat-a", - "dataset_ids": ["kb-1"], - "llm_id": None, - "llm_setting": {"temperature": 0.8}, - "prompt_config": { - "system": "Answer with {knowledge}", - "parameters": [{"key": "knowledge", "optional": False}], - }, - }, + {"name": "chat-rerank", "dataset_ids": [], "prompt": {"rerank_model": "unknown-rerank-model"}}, ) monkeypatch.setattr(module.TenantService, "get_by_id", lambda _tid: (True, SimpleNamespace(llm_id="glm-4"))) - monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: []) - monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda **_kwargs: [SimpleNamespace(id="kb-1")]) - monkeypatch.setattr(module.KnowledgebaseService, "query", lambda **_kwargs: [_DummyKB()]) - monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _id: (True, _DummyKB())) - monkeypatch.setattr(module.TenantLLMService, "get_api_key", lambda *_args, **_kwargs: SimpleNamespace(id=1)) + rerank_query_calls = [] - def _save(**kwargs): - saved.update(kwargs) - return True + def _mock_tenant_llm_query(**kwargs): + rerank_query_calls.append(kwargs) + return False - monkeypatch.setattr(module.DialogService, "save", _save) - monkeypatch.setattr(module.DialogService, "get_by_id", lambda _id: (True, _DummyDialogRecord(saved))) + monkeypatch.setattr(module.TenantLLMService, "query", _mock_tenant_llm_query) + res = _run(module.create.__wrapped__("tenant-1")) + assert "`rerank_model` unknown-rerank-model doesn't exist" in res["message"] + assert rerank_query_calls[-1]["model_type"] == "rerank" + assert rerank_query_calls[-1]["llm_name"] == "unknown-rerank-model" - res = _run(module.create.__wrapped__()) - - assert res["code"] == 0 - assert saved["llm_id"] == "glm-4" - assert saved["llm_setting"]["temperature"] == 0.8 + _set_request_json(monkeypatch, module, {"name": "chat-tenant", "dataset_ids": [], "tenant_id": "tenant-forbidden"}) + res = _run(module.create.__wrapped__("tenant-1")) + assert res["message"] == "`tenant_id` must not be provided." @pytest.mark.p2 -def test_patch_chat_merges_prompt_and_llm_settings(monkeypatch): +def test_update_internal_failure_paths(monkeypatch): module = _load_chat_module(monkeypatch) - updated = {} - existing = _DummyDialogRecord().to_dict() - _set_request_json( - monkeypatch, - module, - { - "prompt_config": {"prologue": "updated opener"}, - "llm_setting": {"temperature": 0.9}, - }, - ) + _set_request_json(monkeypatch, module, {"name": "anything"}) + monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: []) + res = _run(module.update.__wrapped__("tenant-1", "chat-1")) + assert res["message"] == "You do not own the chat" + + _set_request_json(monkeypatch, module, {"name": "chat-name"}) monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: [SimpleNamespace(id="chat-1")]) - monkeypatch.setattr(module.DialogService, "get_by_id", lambda _id: (True, _DummyDialogRecord(existing))) - monkeypatch.setattr(module.TenantService, "get_by_id", lambda _tid: (True, SimpleNamespace(llm_id="glm-4"))) + monkeypatch.setattr(module.TenantService, "get_by_id", lambda _tid: (False, None)) + res = _run(module.update.__wrapped__("tenant-1", "chat-1")) + assert res["message"] == "Tenant not found!" - def _update(_chat_id, payload): - updated.update(payload) - return True + _set_request_json(monkeypatch, module, {"dataset_ids": ["kb-1", "kb-2"]}) + monkeypatch.setattr(module.TenantService, "get_by_id", lambda _tid: (True, SimpleNamespace(id="tenant-1"))) + monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda **_kwargs: [SimpleNamespace(id="kb")]) + monkeypatch.setattr(module.KnowledgebaseService, "query", lambda **_kwargs: [_DummyKB(chunk_num=1)]) + monkeypatch.setattr(module.KnowledgebaseService, "get_by_ids", lambda _ids: [_DummyKB(embd_id="embd-a@x"), _DummyKB(embd_id="embd-b@y")]) + monkeypatch.setattr(module.TenantLLMService, "split_model_name_and_factory", lambda model: (model.split("@")[0], "factory")) + res = _run(module.update.__wrapped__("tenant-1", "chat-1")) + assert res["code"] == module.RetCode.AUTHENTICATION_ERROR + assert "different embedding models" in res["message"] - monkeypatch.setattr(module.DialogService, "update_by_id", _update) + _set_request_json(monkeypatch, module, {"avatar": "new-avatar"}) + monkeypatch.setattr(module.DialogService, "get_by_id", lambda _id: (True, _DummyDialogRecord())) + monkeypatch.setattr(module.DialogService, "update_by_id", lambda *_args, **_kwargs: False) + res = _run(module.update.__wrapped__("tenant-1", "chat-1")) + assert res["message"] == "Chat not found!" - res = _run(module.patch_chat.__wrapped__("chat-1")) - - assert res["code"] == 0 - assert updated["prompt_config"]["system"] == "Answer with {knowledge}" - assert updated["prompt_config"]["prologue"] == "updated opener" - assert updated["llm_setting"]["temperature"] == 0.9 - - -@pytest.mark.p2 -def test_patch_chat_drops_response_only_fields_before_update(monkeypatch): - module = _load_chat_module(monkeypatch) - updated = {} - existing = _DummyDialogRecord().to_dict() - payload = { - "name": "renamed-chat", - "description": existing["description"], - "icon": existing["icon"], - "dataset_ids": existing["kb_ids"], - "kb_names": ["Dataset A"], - "llm_id": existing["llm_id"], - "llm_setting": existing["llm_setting"], - "prompt_config": existing["prompt_config"], - "similarity_threshold": existing["similarity_threshold"], - "vector_similarity_weight": existing["vector_similarity_weight"], - "top_n": existing["top_n"], - "top_k": existing["top_k"], - "rerank_id": existing["rerank_id"], - } - - _set_request_json(monkeypatch, module, payload) + monkeypatch.setattr(module.TenantService, "get_by_id", lambda _tid: (True, SimpleNamespace(id="tenant-1"))) + monkeypatch.setattr(module.DialogService, "get_by_id", lambda _id: (True, _DummyDialogRecord())) + monkeypatch.setattr(module.DialogService, "update_by_id", lambda *_args, **_kwargs: True) monkeypatch.setattr( module.DialogService, "query", - lambda **kwargs: [] if "name" in kwargs else [SimpleNamespace(id="chat-1")], + lambda **kwargs: ( + [SimpleNamespace(id="chat-1")] + if kwargs.get("id") == "chat-1" + else ([SimpleNamespace(id="dup")] if kwargs.get("name") == "dup-name" else []) + ), + ) + monkeypatch.setattr( + module.TenantLLMService, + "split_model_name_and_factory", + lambda model: (model.split("@")[0], "factory"), + ) + monkeypatch.setattr( + module.TenantLLMService, + "query", + lambda **kwargs: kwargs.get("llm_name") in {"glm-4", "allowed-rerank"}, ) - monkeypatch.setattr(module.DialogService, "get_by_id", lambda _id: (True, _DummyDialogRecord(existing))) - monkeypatch.setattr(module.TenantService, "get_by_id", lambda _tid: (True, SimpleNamespace(llm_id="glm-4"))) - monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda **_kwargs: [SimpleNamespace(id="kb-1")]) - monkeypatch.setattr(module.KnowledgebaseService, "query", lambda **_kwargs: [_DummyKB()]) - monkeypatch.setattr(module.TenantLLMService, "split_model_name_and_factory", lambda model: (model.split("@")[0], "factory")) - monkeypatch.setattr(module.TenantLLMService, "query", lambda **_kwargs: [SimpleNamespace(id="llm-1")]) - - def _update(_chat_id, req): - updated.update(req) - return True - - monkeypatch.setattr(module.DialogService, "update_by_id", _update) - - res = _run(module.patch_chat.__wrapped__("chat-1")) + _set_request_json(monkeypatch, module, {"show_quotation": True}) + res = _run(module.update.__wrapped__("tenant-1", "chat-1")) assert res["code"] == 0 - assert updated["name"] == "renamed-chat" - assert "kb_names" not in updated + _set_request_json(monkeypatch, module, {"dataset_ids": ["kb-no-owner"]}) + monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda **_kwargs: []) + res = _run(module.update.__wrapped__("tenant-1", "chat-1")) + assert "You don't own the dataset kb-no-owner" in res["message"] -@pytest.mark.p2 -def test_update_chat_rejects_knowledge_placeholder_without_sources(monkeypatch): - module = _load_chat_module(monkeypatch) - existing = _DummyDialogRecord().to_dict() + _set_request_json(monkeypatch, module, {"dataset_ids": ["kb-unparsed"]}) + monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda **_kwargs: [SimpleNamespace(id="kb-unparsed")]) + monkeypatch.setattr(module.KnowledgebaseService, "query", lambda **_kwargs: [_DummyKB(chunk_num=0)]) + res = _run(module.update.__wrapped__("tenant-1", "chat-1")) + assert "doesn't own parsed file" in res["message"] + + _set_request_json(monkeypatch, module, {"llm": {"model_name": "unknown-model", "model_type": "unsupported"}}) + res = _run(module.update.__wrapped__("tenant-1", "chat-1")) + assert "`model_name` unknown-model doesn't exist" in res["message"] _set_request_json( monkeypatch, module, - { - "name": "chat-name", - "description": "desc", - "icon": "icon.png", - "dataset_ids": [], - "llm_id": "glm-4", - "llm_setting": {"temperature": 0.1}, - "prompt_config": { - "system": "Answer with {knowledge}", - "parameters": [{"key": "knowledge", "optional": False}], - "prologue": "hello", - "quote": True, - }, - "similarity_threshold": 0.2, - "vector_similarity_weight": 0.3, - "top_n": 6, - "top_k": 1024, - "rerank_id": "", - }, + {"prompt": {"prompt": "No placeholder", "variables": [{"key": "knowledge", "optional": False}], "rerank_model": "unknown-rerank"}}, ) - monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: [SimpleNamespace(id="chat-1")]) - monkeypatch.setattr(module.DialogService, "get_by_id", lambda _id: (True, _DummyDialogRecord(existing))) - monkeypatch.setattr(module.TenantService, "get_by_id", lambda _tid: (True, SimpleNamespace(llm_id="glm-4"))) - monkeypatch.setattr(module.TenantLLMService, "split_model_name_and_factory", lambda model: (model.split("@")[0], "factory")) - monkeypatch.setattr(module.TenantLLMService, "query", lambda **_kwargs: [SimpleNamespace(id="llm-1")]) + res = _run(module.update.__wrapped__("tenant-1", "chat-1")) + assert "`rerank_model` unknown-rerank doesn't exist" in res["message"] - res = _run(module.update_chat.__wrapped__("chat-1")) + _set_request_json( + monkeypatch, + module, + {"prompt": {"prompt": "No placeholder", "variables": [{"key": "knowledge", "optional": False}]}}, + ) + res = _run(module.update.__wrapped__("tenant-1", "chat-1")) + assert "Parameter 'knowledge' is not used" in res["message"] - assert res["code"] == 102 - assert res["message"] == "Please remove `{knowledge}` in system prompt since no dataset / Tavily used here." + _set_request_json( + monkeypatch, + module, + {"prompt": {"prompt": "Optional-only prompt", "variables": [{"key": "maybe", "optional": True}]}}, + ) + res = _run(module.update.__wrapped__("tenant-1", "chat-1")) + assert res["code"] == 0 + + _set_request_json(monkeypatch, module, {"name": ""}) + res = _run(module.update.__wrapped__("tenant-1", "chat-1")) + assert res["message"] == "`name` cannot be empty." + + _set_request_json(monkeypatch, module, {"name": "dup-name"}) + res = _run(module.update.__wrapped__("tenant-1", "chat-1")) + assert res["message"] == "Duplicated chat name in updating chat." + + _set_request_json(monkeypatch, module, {"llm": {"model_name": "glm-4", "temperature": 0.9}}) + res = _run(module.update.__wrapped__("tenant-1", "chat-1")) + assert res["code"] == 0 @pytest.mark.p2 -def test_list_chats_returns_old_business_fields(monkeypatch): +def test_delete_duplicate_no_success_path(monkeypatch): module = _load_chat_module(monkeypatch) - monkeypatch.setattr( - module, - "request", - SimpleNamespace( - args=SimpleNamespace( - get=lambda key, default=None: { - "keywords": "", - "page": 1, - "page_size": 20, - "orderby": "create_time", - "desc": "true", - }.get(key, default), - getlist=lambda _key: [], - ) - ), - ) + + _set_request_json(monkeypatch, module, {}) monkeypatch.setattr( module.DialogService, - "get_by_tenant_ids", - lambda *_args, **_kwargs: ( - [_DummyDialogRecord().to_dict()], - 1, - ), + "query", + lambda **_kwargs: (_ for _ in ()).throw(AssertionError("query must not run for empty delete payload")), ) - monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _id: (True, _DummyKB())) + res = _run(module.delete_chats.__wrapped__("tenant-1")) + assert res["code"] == module.RetCode.SUCCESS - res = module.list_chats.__wrapped__() + _set_request_json(monkeypatch, module, {"ids": ["chat-1", "chat-1"]}) + monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: [SimpleNamespace(id="chat-1")]) + monkeypatch.setattr(module.DialogService, "update_by_id", lambda *_args, **_kwargs: 0) + res = _run(module.delete_chats.__wrapped__("tenant-1")) + assert res["code"] == module.RetCode.DATA_ERROR + assert "Duplicate assistant ids: chat-1" in res["message"] + + _set_request_json(monkeypatch, module, {"ids": ["missing-chat"]}) + monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: []) + res = _run(module.delete_chats.__wrapped__("tenant-1")) + assert res["code"] == module.RetCode.DATA_ERROR + assert "Assistant(missing-chat) not found." in res["message"] + + _set_request_json(monkeypatch, module, {"ids": ["chat-1", "chat-1"]}) + monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: [SimpleNamespace(id="chat-1")]) + monkeypatch.setattr(module.DialogService, "update_by_id", lambda *_args, **_kwargs: 1) + res = _run(module.delete_chats.__wrapped__("tenant-1")) + assert res["code"] == 0 + assert res["data"]["success_count"] == 1 + + +@pytest.mark.p2 +def test_list_missing_kb_warning_and_desc_false(monkeypatch, caplog): + module = _load_chat_module(monkeypatch) + + monkeypatch.setattr(module, "request", SimpleNamespace(args={"desc": "False"})) + monkeypatch.setattr(module.DialogService, "get_list", lambda *_args, **_kwargs: [ + { + "id": "chat-1", + "name": "chat-name", + "prompt_config": {"system": "Answer with {knowledge}", "parameters": [{"key": "knowledge", "optional": False}], "do_refer": True}, + "similarity_threshold": 0.2, + "vector_similarity_weight": 0.3, + "top_n": 6, + "rerank_id": "", + "llm_setting": {"temperature": 0.1}, + "llm_id": "glm-4", + "kb_ids": ["missing-kb"], + "icon": "icon.png", + } + ]) + monkeypatch.setattr(module.KnowledgebaseService, "query", lambda **_kwargs: []) + + with caplog.at_level("WARNING"): + res = module.list_chat.__wrapped__("tenant-1") assert res["code"] == 0 - chat = res["data"]["chats"][0] - assert chat["icon"] == "icon.png" - assert chat["dataset_ids"] == ["kb-1"] - assert chat["kb_names"] == ["Dataset A"] - assert "kb_ids" not in chat - assert chat["prompt_config"]["prologue"] == "hello" - assert "dataset_names" not in chat - assert "prompt" not in chat - assert "llm" not in chat - - + assert res["data"][0]["datasets"] == [] + assert res["data"][0]["avatar"] == "icon.png" + assert "does not exist" in caplog.text diff --git a/test/testcases/test_http_api/test_chat_assistant_management/test_delete_chat_assistants.py b/test/testcases/test_http_api/test_chat_assistant_management/test_delete_chat_assistants.py index a2877680f..172c66492 100644 --- a/test/testcases/test_http_api/test_chat_assistant_management/test_delete_chat_assistants.py +++ b/test/testcases/test_http_api/test_chat_assistant_management/test_delete_chat_assistants.py @@ -63,7 +63,7 @@ class TestChatAssistantsDelete: assert res["message"] == expected_message res = list_chat_assistants(HttpApiAuth) - assert len(res["data"]["chats"]) == remaining + assert len(res["data"]) == remaining @pytest.mark.parametrize( "payload", @@ -83,7 +83,7 @@ class TestChatAssistantsDelete: assert res["data"]["success_count"] == 5 res = list_chat_assistants(HttpApiAuth) - assert len(res["data"]["chats"]) == 0 + assert len(res["data"]) == 0 @pytest.mark.p3 def test_repeated_deletion(self, HttpApiAuth, add_chat_assistants_func): @@ -124,7 +124,7 @@ class TestChatAssistantsDelete: assert res["code"] == 0 res = list_chat_assistants(HttpApiAuth) - assert len(res["data"]["chats"]) == 0 + assert len(res["data"]) == 0 @pytest.mark.p2 def test_delete_all_errors_no_success_p2(self, HttpApiAuth, add_chat_assistants_func): diff --git a/test/testcases/test_http_api/test_chat_assistant_management/test_list_chat_assistants.py b/test/testcases/test_http_api/test_chat_assistant_management/test_list_chat_assistants.py index 1fd4cf7eb..d9a4697a4 100644 --- a/test/testcases/test_http_api/test_chat_assistant_management/test_list_chat_assistants.py +++ b/test/testcases/test_http_api/test_chat_assistant_management/test_list_chat_assistants.py @@ -16,16 +16,12 @@ from concurrent.futures import ThreadPoolExecutor, as_completed import pytest -from common import delete_datasets, get_chat_assistant, list_chat_assistants +from common import delete_datasets, list_chat_assistants from configs import INVALID_API_TOKEN from libs.auth import RAGFlowHttpApiAuth from utils import is_sorted -def _chat_list(res): - return res["data"]["chats"] - - @pytest.mark.p1 class TestAuthorization: @pytest.mark.parametrize( @@ -51,8 +47,7 @@ class TestChatAssistantsList: def test_default(self, HttpApiAuth): res = list_chat_assistants(HttpApiAuth) assert res["code"] == 0 - assert len(_chat_list(res)) == 5 - assert res["data"]["total"] == 5 + assert len(res["data"]) == 5 @pytest.mark.p1 @pytest.mark.parametrize( @@ -83,7 +78,7 @@ class TestChatAssistantsList: res = list_chat_assistants(HttpApiAuth, params=params) assert res["code"] == expected_code if expected_code == 0: - assert len(_chat_list(res)) == expected_page_size + assert len(res["data"]) == expected_page_size else: assert res["message"] == expected_message @@ -123,7 +118,7 @@ class TestChatAssistantsList: res = list_chat_assistants(HttpApiAuth, params=params) assert res["code"] == expected_code if expected_code == 0: - assert len(_chat_list(res)) == expected_page_size + assert len(res["data"]) == expected_page_size else: assert res["message"] == expected_message @@ -131,13 +126,13 @@ class TestChatAssistantsList: @pytest.mark.parametrize( "params, expected_code, assertions, expected_message", [ - ({"orderby": None}, 0, lambda r: is_sorted(_chat_list(r), "create_time", True), ""), - ({"orderby": "create_time"}, 0, lambda r: is_sorted(_chat_list(r), "create_time", True), ""), - ({"orderby": "update_time"}, 0, lambda r: is_sorted(_chat_list(r), "update_time", True), ""), + ({"orderby": None}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""), + ({"orderby": "create_time"}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""), + ({"orderby": "update_time"}, 0, lambda r: (is_sorted(r["data"], "update_time", True)), ""), pytest.param( {"orderby": "name", "desc": "False"}, 0, - lambda r: is_sorted(_chat_list(r), "name", False), + lambda r: (is_sorted(r["data"], "name", False)), "", marks=pytest.mark.skip(reason="issues/5851"), ), @@ -170,14 +165,14 @@ class TestChatAssistantsList: @pytest.mark.parametrize( "params, expected_code, assertions, expected_message", [ - ({"desc": None}, 0, lambda r: is_sorted(_chat_list(r), "create_time", True), ""), - ({"desc": "true"}, 0, lambda r: is_sorted(_chat_list(r), "create_time", True), ""), - ({"desc": "True"}, 0, lambda r: is_sorted(_chat_list(r), "create_time", True), ""), - ({"desc": True}, 0, lambda r: is_sorted(_chat_list(r), "create_time", True), ""), - ({"desc": "false"}, 0, lambda r: is_sorted(_chat_list(r), "create_time", False), ""), - ({"desc": "False"}, 0, lambda r: is_sorted(_chat_list(r), "create_time", False), ""), - ({"desc": False}, 0, lambda r: is_sorted(_chat_list(r), "create_time", False), ""), - ({"desc": "False", "orderby": "update_time"}, 0, lambda r: is_sorted(_chat_list(r), "update_time", False), ""), + ({"desc": None}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""), + ({"desc": "true"}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""), + ({"desc": "True"}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""), + ({"desc": True}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""), + ({"desc": "false"}, 0, lambda r: (is_sorted(r["data"], "create_time", False)), ""), + ({"desc": "False"}, 0, lambda r: (is_sorted(r["data"], "create_time", False)), ""), + ({"desc": False}, 0, lambda r: (is_sorted(r["data"], "create_time", False)), ""), + ({"desc": "False", "orderby": "update_time"}, 0, lambda r: (is_sorted(r["data"], "update_time", False)), ""), pytest.param( {"desc": "unknown"}, 102, @@ -207,81 +202,90 @@ class TestChatAssistantsList: @pytest.mark.parametrize( "params, expected_code, expected_num, expected_message", [ - ({"keywords": None}, 0, 5, ""), - ({"keywords": ""}, 0, 5, ""), - ({"keywords": "test_chat_assistant_1"}, 0, 1, ""), - ({"keywords": "unknown"}, 0, 0, ""), + ({"name": None}, 0, 5, ""), + ({"name": ""}, 0, 5, ""), + ({"name": "test_chat_assistant_1"}, 0, 1, ""), + ({"name": "unknown"}, 102, 0, "The chat doesn't exist"), ], ) - def test_keywords(self, HttpApiAuth, params, expected_code, expected_num, expected_message): + def test_name(self, HttpApiAuth, params, expected_code, expected_num, expected_message): res = list_chat_assistants(HttpApiAuth, params=params) assert res["code"] == expected_code if expected_code == 0: - if params["keywords"] in [None, ""]: - assert len(_chat_list(res)) == expected_num + if params["name"] in [None, ""]: + assert len(res["data"]) == expected_num else: - assert len(_chat_list(res)) == expected_num - if expected_num: - assert _chat_list(res)[0]["name"] == params["keywords"] + assert res["data"][0]["name"] == params["name"] else: assert res["message"] == expected_message @pytest.mark.p1 @pytest.mark.parametrize( - "chat_assistant_id, expected_code, expected_message", + "chat_assistant_id, expected_code, expected_num, expected_message", [ - (lambda r: r[0], 0, ""), - ("unknown", 401, "No authorization."), + (None, 0, 5, ""), + ("", 0, 5, ""), + (lambda r: r[0], 0, 1, ""), + ("unknown", 102, 0, "The chat doesn't exist"), ], ) - def test_get_chat_assistant( + def test_id( self, HttpApiAuth, add_chat_assistants, chat_assistant_id, expected_code, - expected_message, - ): - _, _, chat_assistant_ids = add_chat_assistants - chat_id = chat_assistant_id(chat_assistant_ids) if callable(chat_assistant_id) else chat_assistant_id - res = get_chat_assistant(HttpApiAuth, chat_id) - assert res["code"] == expected_code - if expected_code == 0: - assert res["data"]["id"] == chat_id - else: - assert res["message"] == expected_message - - @pytest.mark.p3 - @pytest.mark.parametrize( - "chat_assistant_id, keywords, expected_code, expected_num, expected_message", - [ - (lambda r: r[0], "test_chat_assistant_0", 0, 1, ""), - (lambda r: r[0], "test_chat_assistant_1", 0, 0, ""), - (lambda r: r[0], "unknown", 0, 0, ""), - ], - ) - def test_get_and_keywords_are_separate_lookups( - self, - HttpApiAuth, - add_chat_assistants, - chat_assistant_id, - keywords, - expected_code, expected_num, expected_message, ): _, _, chat_assistant_ids = add_chat_assistants - chat_id = chat_assistant_id(chat_assistant_ids) if callable(chat_assistant_id) else chat_assistant_id - - get_res = get_chat_assistant(HttpApiAuth, chat_id) - list_res = list_chat_assistants(HttpApiAuth, params={"keywords": keywords}) - - assert get_res["code"] == expected_code - assert list_res["code"] == expected_code - if expected_code == 0: - assert len(_chat_list(list_res)) == expected_num + if callable(chat_assistant_id): + params = {"id": chat_assistant_id(chat_assistant_ids)} else: - assert get_res["message"] == expected_message + params = {"id": chat_assistant_id} + + res = list_chat_assistants(HttpApiAuth, params=params) + assert res["code"] == expected_code + if expected_code == 0: + if params["id"] in [None, ""]: + assert len(res["data"]) == expected_num + else: + assert res["data"][0]["id"] == params["id"] + else: + assert res["message"] == expected_message + + @pytest.mark.p3 + @pytest.mark.parametrize( + "chat_assistant_id, name, expected_code, expected_num, expected_message", + [ + (lambda r: r[0], "test_chat_assistant_0", 0, 1, ""), + (lambda r: r[0], "test_chat_assistant_1", 102, 0, "The chat doesn't exist"), + (lambda r: r[0], "unknown", 102, 0, "The chat doesn't exist"), + ("id", "chat_assistant_0", 102, 0, "The chat doesn't exist"), + ], + ) + def test_name_and_id( + self, + HttpApiAuth, + add_chat_assistants, + chat_assistant_id, + name, + expected_code, + expected_num, + expected_message, + ): + _, _, chat_assistant_ids = add_chat_assistants + if callable(chat_assistant_id): + params = {"id": chat_assistant_id(chat_assistant_ids), "name": name} + else: + params = {"id": chat_assistant_id, "name": name} + + res = list_chat_assistants(HttpApiAuth, params=params) + assert res["code"] == expected_code + if expected_code == 0: + assert len(res["data"]) == expected_num + else: + assert res["message"] == expected_message @pytest.mark.p3 def test_concurrent_list(self, HttpApiAuth): @@ -297,7 +301,7 @@ class TestChatAssistantsList: params = {"a": "b"} res = list_chat_assistants(HttpApiAuth, params=params) assert res["code"] == 0 - assert len(_chat_list(res)) == 5 + assert len(res["data"]) == 5 @pytest.mark.p2 def test_list_chats_after_deleting_associated_dataset(self, HttpApiAuth, add_chat_assistants): @@ -307,10 +311,10 @@ class TestChatAssistantsList: res = list_chat_assistants(HttpApiAuth) assert res["code"] == 0 - assert len(_chat_list(res)) == 5 + assert len(res["data"]) == 5 @pytest.mark.p2 def test_desc_false_parse_branch_p2(self, HttpApiAuth): res = list_chat_assistants(HttpApiAuth, params={"desc": "False", "orderby": "create_time"}) assert res["code"] == 0 - assert is_sorted(_chat_list(res), "create_time", False) + assert is_sorted(res["data"], "create_time", False) diff --git a/test/testcases/test_http_api/test_chat_assistant_management/test_update_chat_assistant.py b/test/testcases/test_http_api/test_chat_assistant_management/test_update_chat_assistant.py index 1670dd73a..dbd5d0161 100644 --- a/test/testcases/test_http_api/test_chat_assistant_management/test_update_chat_assistant.py +++ b/test/testcases/test_http_api/test_chat_assistant_management/test_update_chat_assistant.py @@ -14,7 +14,7 @@ # limitations under the License. # import pytest -from common import create_chat_assistant, get_chat_assistant, patch_chat_assistant, update_chat_assistant +from common import create_chat_assistant, list_chat_assistants, update_chat_assistant from configs import CHAT_ASSISTANT_NAME_LIMIT, INVALID_API_TOKEN from libs.auth import RAGFlowHttpApiAuth from utils import encode_avatar @@ -48,18 +48,18 @@ class TestChatAssistantUpdate: pytest.param({"name": "a" * (CHAT_ASSISTANT_NAME_LIMIT + 1)}, 102, "", marks=pytest.mark.skip(reason="issues/")), pytest.param({"name": 1}, 100, "", marks=pytest.mark.skip(reason="issues/")), pytest.param({"name": ""}, 102, "`name` cannot be empty.", marks=pytest.mark.p3), - pytest.param({"name": "test_chat_assistant_1"}, 102, "Duplicated chat name.", marks=pytest.mark.p3), - pytest.param({"name": "TEST_CHAT_ASSISTANT_1"}, 102, "Duplicated chat name.", marks=pytest.mark.p3), + pytest.param({"name": "test_chat_assistant_1"}, 102, "Duplicated chat name in updating chat.", marks=pytest.mark.p3), + pytest.param({"name": "TEST_CHAT_ASSISTANT_1"}, 102, "Duplicated chat name in updating chat.", marks=pytest.mark.p3), ], ) def test_name(self, HttpApiAuth, add_chat_assistants_func, payload, expected_code, expected_message): _, _, chat_assistant_ids = add_chat_assistants_func - res = patch_chat_assistant(HttpApiAuth, chat_assistant_ids[0], payload) + res = update_chat_assistant(HttpApiAuth, chat_assistant_ids[0], payload) assert res["code"] == expected_code, res if expected_code == 0: - res = get_chat_assistant(HttpApiAuth, chat_assistant_ids[0]) - assert res["data"]["name"] == payload.get("name") + res = list_chat_assistants(HttpApiAuth, {"id": chat_assistant_ids[0]}) + assert res["data"][0]["name"] == payload.get("name") else: assert res["message"] == expected_message @@ -69,7 +69,7 @@ class TestChatAssistantUpdate: pytest.param([], 0, "", marks=pytest.mark.skip(reason="issues/")), pytest.param(lambda r: [r], 0, "", marks=pytest.mark.p1), pytest.param(["invalid_dataset_id"], 102, "You don't own the dataset invalid_dataset_id", marks=pytest.mark.p3), - pytest.param("invalid_dataset_id", 102, "`dataset_ids` should be a list.", marks=pytest.mark.p3), + pytest.param("invalid_dataset_id", 102, "You don't own the dataset i", marks=pytest.mark.p3), ], ) def test_dataset_ids(self, HttpApiAuth, add_chat_assistants_func, dataset_ids, expected_code, expected_message): @@ -83,8 +83,8 @@ class TestChatAssistantUpdate: res = update_chat_assistant(HttpApiAuth, chat_assistant_ids[0], payload) assert res["code"] == expected_code, res if expected_code == 0: - res = get_chat_assistant(HttpApiAuth, chat_assistant_ids[0]) - assert res["data"]["name"] == payload.get("name") + res = list_chat_assistants(HttpApiAuth, {"id": chat_assistant_ids[0]}) + assert res["data"][0]["name"] == payload.get("name") else: assert res["message"] == expected_message @@ -92,7 +92,7 @@ class TestChatAssistantUpdate: def test_avatar(self, HttpApiAuth, add_chat_assistants_func, tmp_path): dataset_id, _, chat_assistant_ids = add_chat_assistants_func fn = create_image_file(tmp_path / "ragflow_test.png") - payload = {"name": "avatar_test", "icon": encode_avatar(fn), "dataset_ids": [dataset_id]} + payload = {"name": "avatar_test", "avatar": encode_avatar(fn), "dataset_ids": [dataset_id]} res = update_chat_assistant(HttpApiAuth, chat_assistant_ids[0], payload) assert res["code"] == 0 @@ -101,8 +101,8 @@ class TestChatAssistantUpdate: "llm, expected_code, expected_message", [ ({}, 0, ""), - ({"llm_id": "glm-4"}, 0, ""), - ({"llm_id": "unknown"}, 102, "`llm_id` unknown doesn't exist"), + ({"model_name": "glm-4"}, 0, ""), + ({"model_name": "unknown"}, 102, "`model_name` unknown doesn't exist"), ({"temperature": 0}, 0, ""), ({"temperature": 1}, 0, ""), pytest.param({"temperature": -1}, 0, "", marks=pytest.mark.skip), @@ -133,23 +133,23 @@ class TestChatAssistantUpdate: ) def test_llm(self, HttpApiAuth, add_chat_assistants_func, chat_assistant_llm_model_type, llm, expected_code, expected_message): dataset_id, _, chat_assistant_ids = add_chat_assistants_func - llm_setting = {k: v for k, v in llm.items() if k != "llm_id"} - llm_setting.setdefault("model_type", chat_assistant_llm_model_type) - - payload = {"name": "llm_test", "dataset_ids": [dataset_id]} - if "llm_id" in llm: - payload["llm_id"] = llm["llm_id"] - payload["llm_setting"] = llm_setting - + llm_payload = dict(llm) + llm_payload.setdefault("model_type", chat_assistant_llm_model_type) + payload = {"name": "llm_test", "dataset_ids": [dataset_id], "llm": llm_payload} res = update_chat_assistant(HttpApiAuth, chat_assistant_ids[0], payload) assert res["code"] == expected_code if expected_code == 0: - res = get_chat_assistant(HttpApiAuth, chat_assistant_ids[0]) - for k, v in llm.items(): - if k == "llm_id": - assert res["data"]["llm_id"] == v - else: - assert res["data"]["llm_setting"][k] == v + res = list_chat_assistants(HttpApiAuth, {"id": chat_assistant_ids[0]}) + if llm: + for k, v in llm.items(): + assert res["data"][0]["llm"][k] == v + else: + assert res["data"][0]["llm"]["model_name"] == "glm-4-flash@ZHIPU-AI" + assert res["data"][0]["llm"]["temperature"] == 0.1 + assert res["data"][0]["llm"]["top_p"] == 0.3 + assert res["data"][0]["llm"]["presence_penalty"] == 0.4 + assert res["data"][0]["llm"]["frequency_penalty"] == 0.7 + assert res["data"][0]["llm"]["max_tokens"] == 512 else: assert expected_message in res["message"] @@ -157,18 +157,18 @@ class TestChatAssistantUpdate: @pytest.mark.parametrize( "prompt, expected_code, expected_message", [ - ({}, 0, ""), + ({}, 100, "ValueError"), ({"similarity_threshold": 0}, 0, ""), ({"similarity_threshold": 1}, 0, ""), pytest.param({"similarity_threshold": -1}, 0, "", marks=pytest.mark.skip), pytest.param({"similarity_threshold": 10}, 0, "", marks=pytest.mark.skip), pytest.param({"similarity_threshold": "a"}, 0, "", marks=pytest.mark.skip), - ({"vector_similarity_weight": 0}, 0, ""), - ({"vector_similarity_weight": 1}, 0, ""), - pytest.param({"vector_similarity_weight": -1}, 0, "", marks=pytest.mark.skip), - pytest.param({"vector_similarity_weight": 10}, 0, "", marks=pytest.mark.skip), - pytest.param({"vector_similarity_weight": "a"}, 0, "", marks=pytest.mark.skip), - ({"parameters": []}, 0, ""), + ({"keywords_similarity_weight": 0}, 0, ""), + ({"keywords_similarity_weight": 1}, 0, ""), + pytest.param({"keywords_similarity_weight": -1}, 0, "", marks=pytest.mark.skip), + pytest.param({"keywords_similarity_weight": 10}, 0, "", marks=pytest.mark.skip), + pytest.param({"keywords_similarity_weight": "a"}, 0, "", marks=pytest.mark.skip), + ({"variables": []}, 0, ""), ({"top_n": 0}, 0, ""), ({"top_n": 1}, 0, ""), pytest.param({"top_n": -1}, 0, "", marks=pytest.mark.skip), @@ -181,52 +181,52 @@ class TestChatAssistantUpdate: pytest.param({"empty_response": 123}, 0, "", marks=pytest.mark.skip), pytest.param({"empty_response": True}, 0, "", marks=pytest.mark.skip), pytest.param({"empty_response": " "}, 0, "", marks=pytest.mark.skip), - ({"prologue": "Hello World"}, 0, ""), - ({"prologue": ""}, 0, ""), - ({"prologue": "!@#$%^&*()"}, 0, ""), - ({"prologue": "中文测试"}, 0, ""), - pytest.param({"prologue": 123}, 0, "", marks=pytest.mark.skip), - pytest.param({"prologue": True}, 0, "", marks=pytest.mark.skip), - pytest.param({"prologue": " "}, 0, "", marks=pytest.mark.skip), - ({"quote": True}, 0, ""), - ({"quote": False}, 0, ""), - ({"system": "Hello World {knowledge}"}, 0, ""), - ({"system": "{knowledge}"}, 0, ""), - ({"system": "!@#$%^&*() {knowledge}"}, 0, ""), - ({"system": "中文测试 {knowledge}"}, 0, ""), - ({"system": "Hello World"}, 102, "Parameter 'knowledge' is not used"), - ({"system": "Hello World", "parameters": []}, 0, ""), - pytest.param({"system": 123}, 100, """AttributeError("\'int\' object has no attribute \'find\'")""", marks=pytest.mark.skip), - pytest.param({"system": True}, 100, """AttributeError("\'int\' object has no attribute \'find\'")""", marks=pytest.mark.skip), + ({"opener": "Hello World"}, 0, ""), + ({"opener": ""}, 0, ""), + ({"opener": "!@#$%^&*()"}, 0, ""), + ({"opener": "中文测试"}, 0, ""), + pytest.param({"opener": 123}, 0, "", marks=pytest.mark.skip), + pytest.param({"opener": True}, 0, "", marks=pytest.mark.skip), + pytest.param({"opener": " "}, 0, "", marks=pytest.mark.skip), + ({"show_quote": True}, 0, ""), + ({"show_quote": False}, 0, ""), + ({"prompt": "Hello World {knowledge}"}, 0, ""), + ({"prompt": "{knowledge}"}, 0, ""), + ({"prompt": "!@#$%^&*() {knowledge}"}, 0, ""), + ({"prompt": "中文测试 {knowledge}"}, 0, ""), + ({"prompt": "Hello World"}, 102, "Parameter 'knowledge' is not used"), + ({"prompt": "Hello World", "variables": []}, 0, ""), + pytest.param({"prompt": 123}, 100, """AttributeError("\'int\' object has no attribute \'find\'")""", marks=pytest.mark.skip), + pytest.param({"prompt": True}, 100, """AttributeError("\'int\' object has no attribute \'find\'")""", marks=pytest.mark.skip), pytest.param({"unknown": "unknown"}, 0, "", marks=pytest.mark.skip), ], ) def test_prompt(self, HttpApiAuth, add_chat_assistants_func, prompt, expected_code, expected_message): dataset_id, _, chat_assistant_ids = add_chat_assistants_func - - _PROMPT_CONFIG_KEYS = {"prologue", "quote", "system", "parameters", "empty_response"} - - payload = {"name": "prompt_test", "dataset_ids": [dataset_id]} - prompt_config = {} - for k, v in prompt.items(): - if k in _PROMPT_CONFIG_KEYS: - prompt_config[k] = v - else: - payload[k] = v - if prompt_config: - payload["prompt_config"] = prompt_config - + payload = {"name": "prompt_test", "dataset_ids": [dataset_id], "prompt": prompt} res = update_chat_assistant(HttpApiAuth, chat_assistant_ids[0], payload) assert res["code"] == expected_code if expected_code == 0: - if not prompt: - return - res = get_chat_assistant(HttpApiAuth, chat_assistant_ids[0]) - for k, v in prompt.items(): - if k in _PROMPT_CONFIG_KEYS: - assert res["data"]["prompt_config"][k] == v - else: - assert res["data"][k] == v + res = list_chat_assistants(HttpApiAuth, {"id": chat_assistant_ids[0]}) + if prompt: + for k, v in prompt.items(): + if k == "keywords_similarity_weight": + assert res["data"][0]["prompt"][k] == 1 - v + else: + assert res["data"][0]["prompt"][k] == v + else: + assert res["data"]["prompt"][0]["similarity_threshold"] == 0.2 + assert res["data"]["prompt"][0]["keywords_similarity_weight"] == 0.7 + assert res["data"]["prompt"][0]["top_n"] == 6 + assert res["data"]["prompt"][0]["variables"] == [{"key": "knowledge", "optional": False}] + assert res["data"]["prompt"][0]["rerank_model"] == "" + assert res["data"]["prompt"][0]["empty_response"] == "Sorry! No relevant content was found in the knowledge base!" + assert res["data"]["prompt"][0]["opener"] == "Hi! I'm your assistant. What can I do for you?" + assert res["data"]["prompt"][0]["show_quote"] is True + assert ( + res["data"]["prompt"][0]["prompt"] + == 'You are an intelligent assistant. Please summarize the content of the dataset to answer the question. Please list the data in the dataset and answer in detail. When all dataset content is irrelevant to the question, your answer must include the sentence "The answer you are looking for is not found in the dataset!" Answers need to consider chat history.\n Here is the knowledge base:\n {knowledge}\n The above is the knowledge base.' + ) else: assert expected_message in res["message"] @@ -235,54 +235,50 @@ class TestChatAssistantUpdate: dataset_id, _, chat_assistant_ids = add_chat_assistants_func chat_id = chat_assistant_ids[0] - # Auth: non-owned chat returns 109 "No authorization." - res = patch_chat_assistant(HttpApiAuth, "invalid-chat-id", {"name": "anything"}) - assert res["code"] == 109 - assert res["message"] == "No authorization." + res = update_chat_assistant(HttpApiAuth, "invalid-chat-id", {"name": "anything"}) + assert res["code"] == 102 + assert res["message"] == "You do not own the chat" - # PATCH: toggle quote via prompt_config - res = patch_chat_assistant(HttpApiAuth, chat_id, {"prompt_config": {"quote": False}}) + res = update_chat_assistant(HttpApiAuth, chat_id, {"show_quotation": False, "dataset_ids": [dataset_id]}) assert res["code"] == 0 - # PATCH: invalid llm_id - res = patch_chat_assistant( + res = update_chat_assistant( HttpApiAuth, chat_id, - {"llm_id": "unknown-llm-model", "llm_setting": {"model_type": chat_assistant_llm_model_type}}, + {"llm": {"model_name": "unknown-llm-model", "model_type": chat_assistant_llm_model_type}}, ) assert res["code"] == 102 - assert "`llm_id` unknown-llm-model doesn't exist" in res["message"] + assert "`model_name` unknown-llm-model doesn't exist" in res["message"] - # PATCH: invalid rerank_id - res = patch_chat_assistant(HttpApiAuth, chat_id, {"rerank_id": "unknown-rerank-model"}) + res = update_chat_assistant( + HttpApiAuth, + chat_id, + {"prompt": {"rerank_model": "unknown-rerank-model"}}, + ) assert res["code"] == 102 - assert "`rerank_id` unknown-rerank-model doesn't exist" in res["message"] + assert "`rerank_model` unknown-rerank-model doesn't exist" in res["message"] - # PATCH: empty name - res = patch_chat_assistant(HttpApiAuth, chat_id, {"name": ""}) + res = update_chat_assistant(HttpApiAuth, chat_id, {"name": ""}) assert res["code"] == 102 assert res["message"] == "`name` cannot be empty." - # PATCH: duplicate name - res = patch_chat_assistant(HttpApiAuth, chat_id, {"name": "test_chat_assistant_1"}) + res = update_chat_assistant(HttpApiAuth, chat_id, {"name": "test_chat_assistant_1"}) assert res["code"] == 102 - assert res["message"] == "Duplicated chat name." + assert res["message"] == "Duplicated chat name in updating chat." - # PATCH: prompt_config with unused parameter - res = patch_chat_assistant( + res = update_chat_assistant( HttpApiAuth, chat_id, - {"prompt_config": {"system": "No required placeholder", "parameters": [{"key": "knowledge", "optional": False}]}}, + {"prompt": {"prompt": "No required placeholder", "variables": [{"key": "knowledge", "optional": False}]}}, ) assert res["code"] == 102 assert "Parameter 'knowledge' is not used" in res["message"] - # PATCH: icon (was "avatar" in old SDK) - res = patch_chat_assistant(HttpApiAuth, chat_id, {"icon": "raw-avatar-value"}) + res = update_chat_assistant(HttpApiAuth, chat_id, {"avatar": "raw-avatar-value"}) assert res["code"] == 0 - listed = get_chat_assistant(HttpApiAuth, chat_id) + listed = list_chat_assistants(HttpApiAuth, {"id": chat_id}) assert listed["code"] == 0 - assert listed["data"]["icon"] == "raw-avatar-value" + assert listed["data"][0]["avatar"] == "raw-avatar-value" @pytest.mark.p2 def test_update_unparsed_dataset_guard_p2(self, HttpApiAuth, add_dataset_func, clear_chat_assistants): @@ -291,6 +287,6 @@ class TestChatAssistantUpdate: assert create_res["code"] == 0 chat_id = create_res["data"]["id"] - res = patch_chat_assistant(HttpApiAuth, chat_id, {"dataset_ids": [dataset_id]}) + res = update_chat_assistant(HttpApiAuth, chat_id, {"dataset_ids": [dataset_id]}) assert res["code"] == 102 assert "doesn't own parsed file" in res["message"] diff --git a/test/testcases/test_sdk_api/test_chat_assistant_management/test_create_chat_assistant.py b/test/testcases/test_sdk_api/test_chat_assistant_management/test_create_chat_assistant.py index c8861b0b4..7634e1e65 100644 --- a/test/testcases/test_sdk_api/test_chat_assistant_management/test_create_chat_assistant.py +++ b/test/testcases/test_sdk_api/test_chat_assistant_management/test_create_chat_assistant.py @@ -14,8 +14,11 @@ # limitations under the License. # +from operator import attrgetter + import pytest from configs import CHAT_ASSISTANT_NAME_LIMIT +from ragflow_sdk import Chat from utils import encode_avatar from utils.file_utils import create_image_file @@ -73,16 +76,18 @@ class TestChatAssistantCreate: assert chat_assistant.name == "ragflow test" @pytest.mark.p3 - def test_icon(self, client, tmp_path): + def test_avatar(self, client, tmp_path): fn = create_image_file(tmp_path / "ragflow_test.png") - chat_assistant = client.create_chat(name="icon_test", icon=encode_avatar(fn), dataset_ids=[]) - assert chat_assistant.name == "icon_test" + chat_assistant = client.create_chat(name="avatar_test", avatar=encode_avatar(fn), dataset_ids=[]) + assert chat_assistant.name == "avatar_test" @pytest.mark.p3 @pytest.mark.parametrize( - "llm_setting, expected_message", + "llm, expected_message", [ ({}, ""), + ({"model_name": "glm-4"}, ""), + ({"model_name": "unknown"}, "`model_name` unknown doesn't exist"), ({"temperature": 0}, ""), ({"temperature": 1}, ""), pytest.param({"temperature": -1}, "", marks=pytest.mark.skip), @@ -111,41 +116,47 @@ class TestChatAssistantCreate: pytest.param({"unknown": "unknown"}, "", marks=pytest.mark.skip), ], ) - def test_llm_setting(self, client, add_chunks, llm_setting, expected_message): + def test_llm(self, client, add_chunks, llm, expected_message): dataset, _, _ = add_chunks + llm_o = Chat.LLM(client, llm) if expected_message: with pytest.raises(Exception) as exception_info: - client.create_chat(name="llm_test", dataset_ids=[dataset.id], llm_setting=llm_setting or None) + client.create_chat(name="llm_test", dataset_ids=[dataset.id], llm=llm_o) assert expected_message in str(exception_info.value) else: - chat_assistant = client.create_chat(name="llm_test", dataset_ids=[dataset.id], llm_setting=llm_setting or None) - for k, v in llm_setting.items(): - assert getattr(chat_assistant.llm_setting, k) == v + chat_assistant = client.create_chat(name="llm_test", dataset_ids=[dataset.id], llm=llm_o) + if llm: + for k, v in llm.items(): + assert attrgetter(k)(chat_assistant.llm) == v + else: + assert attrgetter("model_name")(chat_assistant.llm) == "glm-4-flash@ZHIPU-AI" + assert attrgetter("temperature")(chat_assistant.llm) == 0.1 + assert attrgetter("top_p")(chat_assistant.llm) == 0.3 + assert attrgetter("presence_penalty")(chat_assistant.llm) == 0.4 + assert attrgetter("frequency_penalty")(chat_assistant.llm) == 0.7 + assert attrgetter("max_tokens")(chat_assistant.llm) == 512 @pytest.mark.p3 @pytest.mark.parametrize( - "llm_id, expected_message", - [ - ("glm-4", ""), - ("unknown", "`llm_id` unknown doesn't exist"), - ], - ) - def test_llm_id(self, client, add_chunks, llm_id, expected_message): - dataset, _, _ = add_chunks - - if expected_message: - with pytest.raises(Exception) as exception_info: - client.create_chat(name="llm_test", dataset_ids=[dataset.id], llm_id=llm_id) - assert expected_message in str(exception_info.value) - else: - chat_assistant = client.create_chat(name="llm_test", dataset_ids=[dataset.id], llm_id=llm_id) - assert chat_assistant.llm_id == llm_id - - @pytest.mark.p3 - @pytest.mark.parametrize( - "prompt_config, expected_message", + "prompt, expected_message", [ + ({"similarity_threshold": 0}, ""), + ({"similarity_threshold": 1}, ""), + pytest.param({"similarity_threshold": -1}, "", marks=pytest.mark.skip), + pytest.param({"similarity_threshold": 10}, "", marks=pytest.mark.skip), + pytest.param({"similarity_threshold": "a"}, "", marks=pytest.mark.skip), + ({"keywords_similarity_weight": 0}, ""), + ({"keywords_similarity_weight": 1}, ""), + pytest.param({"keywords_similarity_weight": -1}, "", marks=pytest.mark.skip), + pytest.param({"keywords_similarity_weight": 10}, "", marks=pytest.mark.skip), + pytest.param({"keywords_similarity_weight": "a"}, "", marks=pytest.mark.skip), + ({"variables": []}, ""), + ({"top_n": 0}, ""), + ({"top_n": 1}, ""), + pytest.param({"top_n": -1}, "", marks=pytest.mark.skip), + pytest.param({"top_n": 10}, "", marks=pytest.mark.skip), + pytest.param({"top_n": "a"}, "", marks=pytest.mark.skip), ({"empty_response": "Hello World"}, ""), ({"empty_response": ""}, ""), ({"empty_response": "!@#$%^&*()"}, ""), @@ -153,36 +164,55 @@ class TestChatAssistantCreate: pytest.param({"empty_response": 123}, "", marks=pytest.mark.skip), pytest.param({"empty_response": True}, "", marks=pytest.mark.skip), pytest.param({"empty_response": " "}, "", marks=pytest.mark.skip), - ({"prologue": "Hello World"}, ""), - ({"prologue": ""}, ""), - ({"prologue": "!@#$%^&*()"}, ""), - ({"prologue": "中文测试"}, ""), - pytest.param({"prologue": 123}, "", marks=pytest.mark.skip), - pytest.param({"prologue": True}, "", marks=pytest.mark.skip), - pytest.param({"prologue": " "}, "", marks=pytest.mark.skip), - ({"quote": True}, ""), - ({"quote": False}, ""), - ({"system": "Hello World {knowledge}"}, ""), - ({"system": "{knowledge}"}, ""), - ({"system": "!@#$%^&*() {knowledge}"}, ""), - ({"system": "中文测试 {knowledge}"}, ""), - ({"system": "Hello World"}, ""), - ({"system": "Hello World", "parameters": []}, ""), - pytest.param({"system": 123}, "", marks=pytest.mark.skip), + ({"opener": "Hello World"}, ""), + ({"opener": ""}, ""), + ({"opener": "!@#$%^&*()"}, ""), + ({"opener": "中文测试"}, ""), + pytest.param({"opener": 123}, "", marks=pytest.mark.skip), + pytest.param({"opener": True}, "", marks=pytest.mark.skip), + pytest.param({"opener": " "}, "", marks=pytest.mark.skip), + ({"show_quote": True}, ""), + ({"show_quote": False}, ""), + ({"prompt": "Hello World {knowledge}"}, ""), + ({"prompt": "{knowledge}"}, ""), + ({"prompt": "!@#$%^&*() {knowledge}"}, ""), + ({"prompt": "中文测试 {knowledge}"}, ""), + ({"prompt": "Hello World"}, ""), + ({"prompt": "Hello World", "variables": []}, ""), + pytest.param({"prompt": 123}, """AttributeError("\'int\' object has no attribute \'find\'")""", marks=pytest.mark.skip), + pytest.param({"prompt": True}, """AttributeError("\'int\' object has no attribute \'find\'")""", marks=pytest.mark.skip), pytest.param({"unknown": "unknown"}, "", marks=pytest.mark.skip), ], ) - def test_prompt_config(self, client, add_chunks, prompt_config, expected_message): + def test_prompt(self, client, add_chunks, prompt, expected_message): dataset, _, _ = add_chunks + prompt_o = Chat.Prompt(client, prompt) if expected_message: with pytest.raises(Exception) as exception_info: - client.create_chat(name="prompt_test", dataset_ids=[dataset.id], prompt_config=prompt_config) + client.create_chat(name="prompt_test", dataset_ids=[dataset.id], prompt=prompt_o) assert expected_message in str(exception_info.value) else: - chat_assistant = client.create_chat(name="prompt_test", dataset_ids=[dataset.id], prompt_config=prompt_config) - for k, v in prompt_config.items(): - assert getattr(chat_assistant.prompt_config, k) == v + chat_assistant = client.create_chat(name="prompt_test", dataset_ids=[dataset.id], prompt=prompt_o) + if prompt: + for k, v in prompt.items(): + if k == "keywords_similarity_weight": + assert attrgetter(k)(chat_assistant.prompt) == 1 - v + else: + assert attrgetter(k)(chat_assistant.prompt) == v + else: + assert attrgetter("similarity_threshold")(chat_assistant.prompt) == 0.2 + assert attrgetter("keywords_similarity_weight")(chat_assistant.prompt) == 0.7 + assert attrgetter("top_n")(chat_assistant.prompt) == 6 + assert attrgetter("variables")(chat_assistant.prompt) == [{"key": "knowledge", "optional": False}] + assert attrgetter("rerank_model")(chat_assistant.prompt) == "" + assert attrgetter("empty_response")(chat_assistant.prompt) == "Sorry! No relevant content was found in the knowledge base!" + assert attrgetter("opener")(chat_assistant.prompt) == "Hi! I'm your assistant. What can I do for you?" + assert attrgetter("show_quote")(chat_assistant.prompt) is True + assert ( + attrgetter("prompt")(chat_assistant.prompt) + == 'You are an intelligent assistant. Please summarize the content of the dataset to answer the question. Please list the data in the dataset and answer in detail. When all dataset content is irrelevant to the question, your answer must include the sentence "The answer you are looking for is not found in the dataset!" Answers need to consider chat history.\n Here is the knowledge base:\n {knowledge}\n The above is the knowledge base.' + ) class TestChatAssistantCreate2: diff --git a/test/testcases/test_sdk_api/test_chat_assistant_management/test_list_chat_assistants.py b/test/testcases/test_sdk_api/test_chat_assistant_management/test_list_chat_assistants.py index afff2cc86..b78faebaa 100644 --- a/test/testcases/test_sdk_api/test_chat_assistant_management/test_list_chat_assistants.py +++ b/test/testcases/test_sdk_api/test_chat_assistant_management/test_list_chat_assistants.py @@ -136,83 +136,75 @@ class TestChatAssistantsList: @pytest.mark.parametrize( "params, expected_num, expected_message", [ - ({"keywords": None}, 5, ""), - ({"keywords": ""}, 5, ""), - ({"keywords": "test_chat_assistant_1"}, 1, ""), - ({"keywords": "unknown"}, 0, ""), + ({"name": None}, 5, ""), + ({"name": ""}, 5, ""), + ({"name": "test_chat_assistant_1"}, 1, ""), + ({"name": "unknown"}, 0, "The chat doesn't exist"), ], ) - def test_keywords(self, client, params, expected_num, expected_message): + def test_name(self, client, params, expected_num, expected_message): if expected_message: with pytest.raises(Exception) as exception_info: client.list_chats(**params) assert expected_message in str(exception_info.value) else: assistants = client.list_chats(**params) - if params["keywords"] in [None, ""]: + if params["name"] in [None, ""]: assert len(assistants) == expected_num else: - assert len(assistants) == expected_num - if expected_num: - assert assistants[0].name == params["keywords"] - - @pytest.mark.p1 - def test_exact_id_and_name_filters(self, client, add_chat_assistants): - _, _, chat_assistants = add_chat_assistants - target = chat_assistants[1] - - assistants = client.list_chats(id=target.id) - assert len(assistants) == 1 - assert assistants[0].id == target.id - - assistants = client.list_chats(name=target.name) - assert len(assistants) == 1 - assert assistants[0].name == target.name - - assistants = client.list_chats(name=target.name, keywords="unknown") - assert len(assistants) == 1 - assert assistants[0].name == target.name + assert assistants[0].name == params["name"] @pytest.mark.p1 @pytest.mark.parametrize( - "chat_assistant_id, expected_message", + "chat_assistant_id, expected_num, expected_message", [ - (lambda r: r[0], ""), - ("unknown", "No authorization."), + (None, 5, ""), + ("", 5, ""), + (lambda r: r[0], 1, ""), + ("unknown", 0, "The chat doesn't exist"), ], ) - def test_get_chat(self, client, add_chat_assistants, chat_assistant_id, expected_message): + def test_id(self, client, add_chat_assistants, chat_assistant_id, expected_num, expected_message): _, _, chat_assistants = add_chat_assistants - chat_id = chat_assistant_id([chat.id for chat in chat_assistants]) if callable(chat_assistant_id) else chat_assistant_id + if callable(chat_assistant_id): + params = {"id": chat_assistant_id([chat.id for chat in chat_assistants])} + else: + params = {"id": chat_assistant_id} if expected_message: with pytest.raises(Exception) as exception_info: - client.get_chat(chat_id) + client.list_chats(**params) assert expected_message in str(exception_info.value) else: - assistant = client.get_chat(chat_id) - assert assistant.id == chat_id + assistants = client.list_chats(**params) + if params["id"] in [None, ""]: + assert len(assistants) == expected_num + else: + assert assistants[0].id == params["id"] @pytest.mark.p3 @pytest.mark.parametrize( - "chat_assistant_id, keywords, expected_num, expected_message", + "chat_assistant_id, name, expected_num, expected_message", [ (lambda r: r[0], "test_chat_assistant_0", 1, ""), - (lambda r: r[0], "test_chat_assistant_1", 0, ""), - (lambda r: r[0], "unknown", 0, ""), + (lambda r: r[0], "test_chat_assistant_1", 0, "The chat doesn't exist"), + (lambda r: r[0], "unknown", 0, "The chat doesn't exist"), + ("id", "chat_assistant_0", 0, "The chat doesn't exist"), ], ) - def test_get_and_keywords_are_separate_lookups(self, client, add_chat_assistants, chat_assistant_id, keywords, expected_num, expected_message): + def test_name_and_id(self, client, add_chat_assistants, chat_assistant_id, name, expected_num, expected_message): _, _, chat_assistants = add_chat_assistants - chat_id = chat_assistant_id([chat.id for chat in chat_assistants]) if callable(chat_assistant_id) else chat_assistant_id + if callable(chat_assistant_id): + params = {"id": chat_assistant_id([chat.id for chat in chat_assistants]), "name": name} + else: + params = {"id": chat_assistant_id, "name": name} if expected_message: with pytest.raises(Exception) as exception_info: - client.get_chat(chat_id) + client.list_chats(**params) assert expected_message in str(exception_info.value) else: - client.get_chat(chat_id) - assistants = client.list_chats(keywords=keywords) + assistants = client.list_chats(**params) assert len(assistants) == expected_num @pytest.mark.p3 diff --git a/test/testcases/test_sdk_api/test_chat_assistant_management/test_update_chat_assistant.py b/test/testcases/test_sdk_api/test_chat_assistant_management/test_update_chat_assistant.py index 31ceaea9f..7652c266e 100644 --- a/test/testcases/test_sdk_api/test_chat_assistant_management/test_update_chat_assistant.py +++ b/test/testcases/test_sdk_api/test_chat_assistant_management/test_update_chat_assistant.py @@ -13,16 +13,18 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from operator import attrgetter import pytest from configs import CHAT_ASSISTANT_NAME_LIMIT +from ragflow_sdk import Chat from utils import encode_avatar from utils.file_utils import create_image_file class TestChatAssistantUpdate: @pytest.mark.p2 - def test_update_rejects_non_dict(self, add_chat_assistants_func): + def test_update_rejects_non_dict_and_empty_llm_prompt(self, add_chat_assistants_func): _, _, chat_assistants = add_chat_assistants_func chat_assistant = chat_assistants[0] @@ -30,6 +32,14 @@ class TestChatAssistantUpdate: chat_assistant.update.__wrapped__(chat_assistant, "bad") assert "`update_message` must be a dict" in str(exception_info.value) + with pytest.raises(Exception) as exception_info: + chat_assistant.update({"llm": {}}) + assert "`llm` cannot be empty" in str(exception_info.value) + + with pytest.raises(Exception) as exception_info: + chat_assistant.update({"prompt": {}}) + assert "`prompt` cannot be empty" in str(exception_info.value) + @pytest.mark.p2 def test_update_raises_on_nonzero_response(self, add_chat_assistants_func, monkeypatch): _, _, chat_assistants = add_chat_assistants_func @@ -39,35 +49,12 @@ class TestChatAssistantUpdate: def json(self): return {"code": 1, "message": "boom"} - monkeypatch.setattr(chat_assistant, "patch", lambda *_args, **_kwargs: _DummyResponse()) + monkeypatch.setattr(chat_assistant, "put", lambda *_args, **_kwargs: _DummyResponse()) with pytest.raises(Exception) as exception_info: chat_assistant.update({"name": "error-case"}) assert "boom" in str(exception_info.value) - @pytest.mark.p1 - def test_update_uses_patch_for_partial_payload(self, add_chat_assistants_func, monkeypatch): - _, _, chat_assistants = add_chat_assistants_func - chat_assistant = chat_assistants[0] - captured = {} - - class _DummyResponse: - def json(self): - return {"code": 0, "message": "ok"} - - def _patch(path, payload): - captured["path"] = path - captured["payload"] = payload - return _DummyResponse() - - monkeypatch.setattr(chat_assistant, "patch", _patch) - monkeypatch.setattr(chat_assistant, "put", lambda *_args, **_kwargs: pytest.fail("update() should not use PUT")) - - chat_assistant.update({"name": "renamed"}) - - assert captured["path"] == f"/chats/{chat_assistant.id}" - assert captured["payload"] == {"name": "renamed"} - @pytest.mark.parametrize( "payload, expected_message", [ @@ -89,28 +76,29 @@ class TestChatAssistantUpdate: assert expected_message in str(exception_info.value) else: chat_assistant.update(payload) - updated_chat = client.get_chat(chat_assistant.id) + updated_chat = client.list_chats(id=chat_assistant.id)[0] assert updated_chat.name == payload["name"], str(updated_chat) @pytest.mark.p3 - def test_icon(self, client, add_chat_assistants_func, tmp_path): + def test_avatar(self, client, add_chat_assistants_func, tmp_path): dataset, _, chat_assistants = add_chat_assistants_func chat_assistant = chat_assistants[0] fn = create_image_file(tmp_path / "ragflow_test.png") - payload = {"name": "icon_test", "icon": encode_avatar(fn), "dataset_ids": [dataset.id]} + payload = {"name": "avatar_test", "avatar": encode_avatar(fn), "dataset_ids": [dataset.id]} chat_assistant.update(payload) - updated_chat = client.get_chat(chat_assistant.id) + updated_chat = client.list_chats(id=chat_assistant.id)[0] assert updated_chat.name == payload["name"], str(updated_chat) - assert updated_chat.icon is not None, str(updated_chat) + assert updated_chat.avatar is not None, str(updated_chat) @pytest.mark.p3 @pytest.mark.parametrize( - "llm_setting, expected_message", + "llm, expected_message", [ + ({}, "ValueError"), ({"model_name": "glm-4"}, ""), - ({"model_name": "unknown"}, "`llm_id` unknown doesn't exist"), + ({"model_name": "unknown"}, "`model_name` unknown doesn't exist"), ({"temperature": 0}, ""), ({"temperature": 1}, ""), pytest.param({"temperature": -1}, "", marks=pytest.mark.skip), @@ -139,13 +127,10 @@ class TestChatAssistantUpdate: pytest.param({"unknown": "unknown"}, "", marks=pytest.mark.skip), ], ) - def test_llm_setting(self, client, add_chat_assistants_func, llm_setting, expected_message): + def test_llm(self, client, add_chat_assistants_func, llm, expected_message): dataset, _, chat_assistants = add_chat_assistants_func chat_assistant = chat_assistants[0] - llm_id = llm_setting.pop("model_name", None) - payload = {"name": "llm_test", "dataset_ids": [dataset.id], "llm_setting": llm_setting} - if llm_id is not None: - payload["llm_id"] = llm_id + payload = {"name": "llm_test", "llm": llm, "dataset_ids": [dataset.id]} if expected_message: with pytest.raises(Exception) as exception_info: @@ -153,16 +138,45 @@ class TestChatAssistantUpdate: assert expected_message in str(exception_info.value) else: chat_assistant.update(payload) - updated_chat = client.get_chat(chat_assistant.id) - if llm_id: - assert updated_chat.llm_id == llm_id, str(updated_chat) - for k, v in llm_setting.items(): - assert getattr(updated_chat.llm_setting, k) == v, str(updated_chat) + updated_chat = client.list_chats(id=chat_assistant.id)[0] + if llm: + for k, v in llm.items(): + assert attrgetter(k)(updated_chat.llm) == v, str(updated_chat) + else: + excepted_value = Chat.LLM( + client, + { + "model_name": "glm-4-flash@ZHIPU-AI", + "temperature": 0.1, + "top_p": 0.3, + "presence_penalty": 0.4, + "frequency_penalty": 0.7, + "max_tokens": 512, + }, + ) + assert str(updated_chat.llm) == str(excepted_value), str(updated_chat) @pytest.mark.p3 @pytest.mark.parametrize( - "prompt_config, expected_message", + "prompt, expected_message", [ + ({}, "ValueError"), + ({"similarity_threshold": 0}, ""), + ({"similarity_threshold": 1}, ""), + pytest.param({"similarity_threshold": -1}, "", marks=pytest.mark.skip), + pytest.param({"similarity_threshold": 10}, "", marks=pytest.mark.skip), + pytest.param({"similarity_threshold": "a"}, "", marks=pytest.mark.skip), + ({"keywords_similarity_weight": 0}, ""), + ({"keywords_similarity_weight": 1}, ""), + pytest.param({"keywords_similarity_weight": -1}, "", marks=pytest.mark.skip), + pytest.param({"keywords_similarity_weight": 10}, "", marks=pytest.mark.skip), + pytest.param({"keywords_similarity_weight": "a"}, "", marks=pytest.mark.skip), + ({"variables": []}, ""), + ({"top_n": 0}, ""), + ({"top_n": 1}, ""), + pytest.param({"top_n": -1}, "", marks=pytest.mark.skip), + pytest.param({"top_n": 10}, "", marks=pytest.mark.skip), + pytest.param({"top_n": "a"}, "", marks=pytest.mark.skip), ({"empty_response": "Hello World"}, ""), ({"empty_response": ""}, ""), ({"empty_response": "!@#$%^&*()"}, ""), @@ -170,29 +184,30 @@ class TestChatAssistantUpdate: pytest.param({"empty_response": 123}, "", marks=pytest.mark.skip), pytest.param({"empty_response": True}, "", marks=pytest.mark.skip), pytest.param({"empty_response": " "}, "", marks=pytest.mark.skip), - ({"prologue": "Hello World"}, ""), - ({"prologue": ""}, ""), - ({"prologue": "!@#$%^&*()"}, ""), - ({"prologue": "中文测试"}, ""), - pytest.param({"prologue": 123}, "", marks=pytest.mark.skip), - pytest.param({"prologue": True}, "", marks=pytest.mark.skip), - pytest.param({"prologue": " "}, "", marks=pytest.mark.skip), - ({"quote": True}, ""), - ({"quote": False}, ""), - ({"system": "Hello World {knowledge}"}, ""), - ({"system": "{knowledge}"}, ""), - ({"system": "!@#$%^&*() {knowledge}"}, ""), - ({"system": "中文测试 {knowledge}"}, ""), - ({"system": "Hello World"}, ""), - ({"system": "Hello World", "parameters": []}, ""), - pytest.param({"system": 123}, "", marks=pytest.mark.skip), + ({"opener": "Hello World"}, ""), + ({"opener": ""}, ""), + ({"opener": "!@#$%^&*()"}, ""), + ({"opener": "中文测试"}, ""), + pytest.param({"opener": 123}, "", marks=pytest.mark.skip), + pytest.param({"opener": True}, "", marks=pytest.mark.skip), + pytest.param({"opener": " "}, "", marks=pytest.mark.skip), + ({"show_quote": True}, ""), + ({"show_quote": False}, ""), + ({"prompt": "Hello World {knowledge}"}, ""), + ({"prompt": "{knowledge}"}, ""), + ({"prompt": "!@#$%^&*() {knowledge}"}, ""), + ({"prompt": "中文测试 {knowledge}"}, ""), + ({"prompt": "Hello World"}, ""), + ({"prompt": "Hello World", "variables": []}, ""), + pytest.param({"prompt": 123}, """AttributeError("\'int\' object has no attribute \'find\'")""", marks=pytest.mark.skip), + pytest.param({"prompt": True}, """AttributeError("\'int\' object has no attribute \'find\'")""", marks=pytest.mark.skip), pytest.param({"unknown": "unknown"}, "", marks=pytest.mark.skip), ], ) - def test_prompt_config(self, client, add_chat_assistants_func, prompt_config, expected_message): + def test_prompt(self, client, add_chat_assistants_func, prompt, expected_message): dataset, _, chat_assistants = add_chat_assistants_func chat_assistant = chat_assistants[0] - payload = {"name": "prompt_test", "prompt_config": prompt_config, "dataset_ids": [dataset.id]} + payload = {"name": "prompt_test", "prompt": prompt, "dataset_ids": [dataset.id]} if expected_message: with pytest.raises(Exception) as exception_info: @@ -200,6 +215,26 @@ class TestChatAssistantUpdate: assert expected_message in str(exception_info.value) else: chat_assistant.update(payload) - updated_chat = client.get_chat(chat_assistant.id) - for k, v in prompt_config.items(): - assert getattr(updated_chat.prompt_config, k) == v, str(updated_chat) + updated_chat = client.list_chats(id=chat_assistant.id)[0] + if prompt: + for k, v in prompt.items(): + if k == "keywords_similarity_weight": + assert attrgetter(k)(updated_chat.prompt) == 1 - v, str(updated_chat) + else: + assert attrgetter(k)(updated_chat.prompt) == v, str(updated_chat) + else: + excepted_value = Chat.LLM( + client, + { + "similarity_threshold": 0.2, + "keywords_similarity_weight": 0.7, + "top_n": 6, + "variables": [{"key": "knowledge", "optional": False}], + "rerank_model": "", + "empty_response": "Sorry! No relevant content was found in the knowledge base!", + "opener": "Hi! I'm your assistant. What can I do for you?", + "show_quote": True, + "prompt": 'You are an intelligent assistant. Please summarize the content of the dataset to answer the question. Please list the data in the dataset and answer in detail. When all dataset content is irrelevant to the question, your answer must include the sentence "The answer you are looking for is not found in the dataset!" Answers need to consider chat history.\n Here is the knowledge base:\n {knowledge}\n The above is the knowledge base.', + }, + ) + assert str(updated_chat.prompt) == str(excepted_value), str(updated_chat) diff --git a/test/testcases/test_web_api/common.py b/test/testcases/test_web_api/common.py index 561a93751..12455242d 100644 --- a/test/testcases/test_web_api/common.py +++ b/test/testcases/test_web_api/common.py @@ -30,6 +30,7 @@ KB_APP_URL = f"/{VERSION}/kb" DATASETS_URL = f"/api/{VERSION}/datasets" DOCUMENT_APP_URL = f"/{VERSION}/document" CHUNK_API_URL = f"/{VERSION}/chunk" +DIALOG_APP_URL = f"/{VERSION}/dialog" # SESSION_WITH_CHAT_ASSISTANT_API_URL = "/api/v1/chats/{chat_id}/sessions" # SESSION_WITH_AGENT_API_URL = "/api/v1/agents/{agent_id}/sessions" MEMORY_API_URL = f"/api/{VERSION}/memories" @@ -468,6 +469,103 @@ def batch_add_chunks(auth, doc_id, num): return chunk_ids +# DIALOG APP +def create_dialog(auth, payload=None, *, headers=HEADERS, data=None): + if payload is None: + payload = {} + url = f"{HOST_ADDRESS}{DIALOG_APP_URL}/set" + req_id = str(uuid.uuid4()) + req_headers = dict(headers) + req_headers["X-Request-ID"] = req_id + start = time.monotonic() + res = requests.post(url=url, headers=req_headers, auth=auth, json=payload, data=data) + elapsed_ms = (time.monotonic() - start) * 1000 + resp_json = None + json_error = None + try: + resp_json = res.json() + except ValueError as exc: + json_error = exc + _log_http_debug("POST", url, req_id, payload, res.status_code, res.text, resp_json, elapsed_ms) + if _http_debug_enabled(): + if not res.ok or (resp_json is not None and resp_json.get("code") != 0): + payload_summary = _redact_payload(payload) + raise AssertionError( + "HTTP helper failure: " + f"req_id={req_id} url={url} status={res.status_code} " + f"payload={payload_summary} response={res.text}" + ) + if json_error: + raise json_error + return resp_json + + +def update_dialog(auth, payload=None, *, headers=HEADERS, data=None): + res = requests.post(url=f"{HOST_ADDRESS}{DIALOG_APP_URL}/set", headers=headers, auth=auth, json=payload, data=data) + return res.json() + + +def get_dialog(auth, params=None, *, headers=HEADERS): + res = requests.get(url=f"{HOST_ADDRESS}{DIALOG_APP_URL}/get", headers=headers, auth=auth, params=params) + return res.json() + + +def list_dialogs(auth, *, headers=HEADERS): + res = requests.get(url=f"{HOST_ADDRESS}{DIALOG_APP_URL}/list", headers=headers, auth=auth) + return res.json() + + +def delete_dialog(auth, payload=None, *, headers=HEADERS, data=None): + res = requests.post(url=f"{HOST_ADDRESS}{DIALOG_APP_URL}/rm", headers=headers, auth=auth, json=payload, data=data) + return res.json() + + +def batch_create_dialogs(auth, num, kb_ids=None): + if kb_ids is None: + kb_ids = [] + + dialog_ids = [] + for i in range(num): + if kb_ids: + prompt_config = { + "system": "You are a helpful assistant. Use the following knowledge to answer questions: {knowledge}", + "parameters": [{"key": "knowledge", "optional": False}], + } + else: + prompt_config = { + "system": "You are a helpful assistant.", + "parameters": [], + } + payload = { + "name": f"dialog_{i}", + "description": f"Test dialog {i}", + "kb_ids": kb_ids, + "prompt_config": prompt_config, + "top_n": 6, + "top_k": 1024, + "similarity_threshold": 0.1, + "vector_similarity_weight": 0.3, + "llm_setting": {"model": "gpt-3.5-turbo", "temperature": 0.7}, + } + res = create_dialog(auth, payload) + if res is None or res.get("code") != 0: + uses_knowledge = "{knowledge}" in payload["prompt_config"]["system"] + raise AssertionError( + "batch_create_dialogs failed: " + f"res={res} kb_ids_len={len(kb_ids)} uses_knowledge={uses_knowledge}" + ) + if res["code"] == 0: + dialog_ids.append(res["data"]["id"]) + return dialog_ids + + +def delete_dialogs(auth): + res = list_dialogs(auth) + if res["code"] == 0 and res["data"]: + dialog_ids = [dialog["id"] for dialog in res["data"]] + if dialog_ids: + delete_dialog(auth, {"dialog_ids": dialog_ids}) + # MEMORY APP def create_memory(auth, payload=None): url = f"{HOST_ADDRESS}{MEMORY_API_URL}" diff --git a/test/testcases/test_web_api/test_dialog_app/conftest.py b/test/testcases/test_web_api/test_dialog_app/conftest.py new file mode 100644 index 000000000..e2f142f7b --- /dev/null +++ b/test/testcases/test_web_api/test_dialog_app/conftest.py @@ -0,0 +1,50 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import pytest +from common import batch_create_dialogs, delete_dialogs + + +@pytest.fixture(scope="function") +def add_dialog_func(request, WebApiAuth, add_dataset_func): + def cleanup(): + delete_dialogs(WebApiAuth) + + request.addfinalizer(cleanup) + + dataset_id = add_dataset_func + return dataset_id, batch_create_dialogs(WebApiAuth, 1, [dataset_id])[0] + + +@pytest.fixture(scope="class") +def add_dialogs(request, WebApiAuth, add_dataset): + def cleanup(): + delete_dialogs(WebApiAuth) + + request.addfinalizer(cleanup) + + dataset_id = add_dataset + return dataset_id, batch_create_dialogs(WebApiAuth, 5, [dataset_id]) + + +@pytest.fixture(scope="function") +def add_dialogs_func(request, WebApiAuth, add_dataset_func): + def cleanup(): + delete_dialogs(WebApiAuth) + + request.addfinalizer(cleanup) + + dataset_id = add_dataset_func + return dataset_id, batch_create_dialogs(WebApiAuth, 5, [dataset_id]) diff --git a/test/testcases/test_web_api/test_dialog_app/test_create_dialog.py b/test/testcases/test_web_api/test_dialog_app/test_create_dialog.py new file mode 100644 index 000000000..71198d27b --- /dev/null +++ b/test/testcases/test_web_api/test_dialog_app/test_create_dialog.py @@ -0,0 +1,170 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from concurrent.futures import ThreadPoolExecutor, as_completed + +import pytest +from configs import CHAT_ASSISTANT_NAME_LIMIT, INVALID_API_TOKEN +from hypothesis import example, given, settings +from libs.auth import RAGFlowWebApiAuth +from utils.hypothesis_utils import valid_names + +from common import create_dialog + + +@pytest.mark.usefixtures("clear_dialogs") +class TestAuthorization: + @pytest.mark.p2 + @pytest.mark.parametrize( + "invalid_auth, expected_code, expected_message", + [ + (None, 401, ""), + (RAGFlowWebApiAuth(INVALID_API_TOKEN), 401, ""), + ], + ids=["empty_auth", "invalid_api_token"], + ) + def test_auth_invalid(self, invalid_auth, expected_code, expected_message): + payload = {"name": "auth_test", "prompt_config": {"system": "You are a helpful assistant.", "parameters": []}} + res = create_dialog(invalid_auth, payload) + assert res["code"] == expected_code, res + assert res["message"] == expected_message, res + + +@pytest.mark.usefixtures("clear_dialogs") +class TestCapability: + @pytest.mark.p3 + def test_create_dialog_100(self, WebApiAuth): + for i in range(100): + payload = {"name": f"dialog_{i}", "prompt_config": {"system": "You are a helpful assistant.", "parameters": []}} + res = create_dialog(WebApiAuth, payload) + assert res["code"] == 0, f"Failed to create dialog {i}" + + @pytest.mark.p3 + def test_create_dialog_concurrent(self, WebApiAuth): + count = 100 + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(create_dialog, WebApiAuth, {"name": f"dialog_{i}", "prompt_config": {"system": "You are a helpful assistant.", "parameters": []}}) for i in range(count)] + responses = list(as_completed(futures)) + assert len(responses) == count, responses + assert all(future.result()["code"] == 0 for future in futures) + + +@pytest.mark.usefixtures("clear_dialogs") +class TestDialogCreate: + @pytest.mark.p1 + @given(name=valid_names()) + @example("a" * CHAT_ASSISTANT_NAME_LIMIT) + @settings(max_examples=20) + def test_name(self, WebApiAuth, name): + payload = {"name": name, "prompt_config": {"system": "You are a helpful assistant.", "parameters": []}} + res = create_dialog(WebApiAuth, payload) + assert res["code"] == 0, res + + @pytest.mark.p2 + @pytest.mark.parametrize( + "name, expected_code, expected_message", + [ + ("", 102, "Dialog name can't be empty."), + (" ", 102, "Dialog name can't be empty."), + ("a" * (CHAT_ASSISTANT_NAME_LIMIT + 1), 102, "Dialog name length is 256 which is larger than 255"), + (0, 102, "Dialog name must be string."), + (None, 102, "Dialog name must be string."), + ], + ids=["empty_name", "space_name", "too_long_name", "invalid_name", "None_name"], + ) + def test_name_invalid(self, WebApiAuth, name, expected_code, expected_message): + payload = {"name": name, "prompt_config": {"system": "You are a helpful assistant.", "parameters": []}} + res = create_dialog(WebApiAuth, payload) + assert res["code"] == expected_code, res + assert res["message"] == expected_message, res + + @pytest.mark.p1 + def test_prompt_config_required(self, WebApiAuth): + payload = {"name": "test_dialog"} + res = create_dialog(WebApiAuth, payload) + assert res["code"] == 101, res + assert res["message"] == "required argument are missing: prompt_config; ", res + + @pytest.mark.p1 + def test_prompt_config_with_knowledge_no_kb(self, WebApiAuth): + payload = {"name": "test_dialog", "prompt_config": {"system": "You are a helpful assistant. Use this knowledge: {knowledge}", "parameters": [{"key": "knowledge", "optional": True}]}} + res = create_dialog(WebApiAuth, payload) + assert res["code"] == 0, res + + @pytest.mark.p1 + def test_prompt_config_parameter_not_used(self, WebApiAuth): + payload = {"name": "test_dialog", "prompt_config": {"system": "You are a helpful assistant.", "parameters": [{"key": "unused_param", "optional": False}]}} + res = create_dialog(WebApiAuth, payload) + assert res["code"] == 102, res + assert "Parameter 'unused_param' is not used" in res["message"], res + + @pytest.mark.p1 + def test_create_with_kb_ids(self, WebApiAuth, add_dataset_func): + dataset_id = add_dataset_func + payload = { + "name": "test_dialog_with_kb", + "kb_ids": [dataset_id], + "prompt_config": {"system": "You are a helpful assistant. Use this knowledge: {knowledge}", "parameters": [{"key": "knowledge", "optional": True}]}, + } + res = create_dialog(WebApiAuth, payload) + assert res["code"] == 0, res + assert res["data"]["kb_ids"] == [dataset_id], res + + @pytest.mark.p2 + def test_create_with_all_parameters(self, WebApiAuth, add_dataset_func): + dataset_id = add_dataset_func + payload = { + "name": "comprehensive_dialog", + "description": "A comprehensive test dialog", + "icon": "🤖", + "kb_ids": [dataset_id], + "top_n": 10, + "top_k": 2048, + "rerank_id": "", + "similarity_threshold": 0.2, + "vector_similarity_weight": 0.5, + "llm_setting": {"model": "gpt-4", "temperature": 0.8, "max_tokens": 1000}, + "prompt_config": {"system": "You are a helpful assistant. Use this knowledge: {knowledge}", "parameters": [{"key": "knowledge", "optional": True}]}, + } + res = create_dialog(WebApiAuth, payload) + assert res["code"] == 0, res + data = res["data"] + assert data["name"] == "comprehensive_dialog", res + assert data["description"] == "A comprehensive test dialog", res + assert data["icon"] == "🤖", res + assert data["kb_ids"] == [dataset_id], res + assert data["top_n"] == 10, res + assert data["top_k"] == 2048, res + assert data["similarity_threshold"] == 0.2, res + assert data["vector_similarity_weight"] == 0.5, res + + @pytest.mark.p3 + def test_name_duplicated(self, WebApiAuth): + name = "duplicated_dialog" + payload = {"name": name, "prompt_config": {"system": "You are a helpful assistant.", "parameters": []}} + res = create_dialog(WebApiAuth, payload) + assert res["code"] == 0, res + + res = create_dialog(WebApiAuth, payload) + assert res["code"] == 0, res + + @pytest.mark.p2 + def test_optional_parameters(self, WebApiAuth): + payload = { + "name": "test_optional_params", + "prompt_config": {"system": "You are a helpful assistant. Optional param: {optional_param}", "parameters": [{"key": "optional_param", "optional": True}]}, + } + res = create_dialog(WebApiAuth, payload) + assert res["code"] == 0, res diff --git a/test/testcases/test_web_api/test_dialog_app/test_delete_dialogs.py b/test/testcases/test_web_api/test_dialog_app/test_delete_dialogs.py new file mode 100644 index 000000000..0bb339342 --- /dev/null +++ b/test/testcases/test_web_api/test_dialog_app/test_delete_dialogs.py @@ -0,0 +1,204 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from concurrent.futures import ThreadPoolExecutor, as_completed + +import pytest +from common import batch_create_dialogs, create_dialog, delete_dialog, list_dialogs +from configs import INVALID_API_TOKEN +from libs.auth import RAGFlowWebApiAuth + + +@pytest.mark.usefixtures("clear_dialogs") +class TestAuthorization: + @pytest.mark.p2 + @pytest.mark.parametrize( + "invalid_auth, expected_code, expected_message", + [ + (None, 401, ""), + (RAGFlowWebApiAuth(INVALID_API_TOKEN), 401, ""), + ], + ids=["empty_auth", "invalid_api_token"], + ) + def test_auth_invalid(self, invalid_auth, expected_code, expected_message, add_dialog_func): + _, dialog_id = add_dialog_func + payload = {"dialog_ids": [dialog_id]} + res = delete_dialog(invalid_auth, payload) + assert res["code"] == expected_code, res + assert res["message"] == expected_message, res + + +class TestDialogDelete: + @pytest.mark.p1 + def test_delete_single_dialog(self, WebApiAuth, add_dialog_func): + _, dialog_id = add_dialog_func + + res = list_dialogs(WebApiAuth) + assert res["code"] == 0, res + assert len(res["data"]) == 1, res + + payload = {"dialog_ids": [dialog_id]} + res = delete_dialog(WebApiAuth, payload) + assert res["code"] == 0, res + assert res["data"] is True, res + + res = list_dialogs(WebApiAuth) + assert res["code"] == 0, res + assert len(res["data"]) == 0, res + + @pytest.mark.p1 + def test_delete_multiple_dialogs(self, WebApiAuth, add_dialogs_func): + _, dialog_ids = add_dialogs_func + + res = list_dialogs(WebApiAuth) + assert res["code"] == 0, res + assert len(res["data"]) == 5, res + + payload = {"dialog_ids": dialog_ids} + res = delete_dialog(WebApiAuth, payload) + assert res["code"] == 0, res + assert res["data"] is True, res + + res = list_dialogs(WebApiAuth) + assert res["code"] == 0, res + assert len(res["data"]) == 0, res + + @pytest.mark.p1 + def test_delete_partial_dialogs(self, WebApiAuth, add_dialogs_func): + _, dialog_ids = add_dialogs_func + + dialogs_to_delete = dialog_ids[:3] + payload = {"dialog_ids": dialogs_to_delete} + res = delete_dialog(WebApiAuth, payload) + assert res["code"] == 0, res + assert res["data"] is True, res + + res = list_dialogs(WebApiAuth) + assert res["code"] == 0, res + assert len(res["data"]) == 2, res + + remaining_ids = [dialog["id"] for dialog in res["data"]] + for dialog_id in dialog_ids[3:]: + assert dialog_id in remaining_ids, res + + @pytest.mark.p2 + def test_delete_nonexistent_dialog(self, WebApiAuth): + fake_dialog_id = "nonexistent_dialog_id" + payload = {"dialog_ids": [fake_dialog_id]} + res = delete_dialog(WebApiAuth, payload) + assert res["code"] == 103, res + assert "Only owner of dialog authorized for this operation." in res["message"], res + + @pytest.mark.p2 + def test_delete_empty_dialog_ids(self, WebApiAuth): + payload = {"dialog_ids": []} + res = delete_dialog(WebApiAuth, payload) + assert res["code"] == 0, res + + @pytest.mark.p2 + def test_delete_missing_dialog_ids(self, WebApiAuth): + payload = {} + res = delete_dialog(WebApiAuth, payload) + assert res["code"] == 101, res + assert res["message"] == "required argument are missing: dialog_ids; ", res + + @pytest.mark.p2 + def test_delete_invalid_dialog_ids_format(self, WebApiAuth): + payload = {"dialog_ids": "not_a_list"} + res = delete_dialog(WebApiAuth, payload) + assert res["code"] == 103, res + assert res["message"] == "Only owner of dialog authorized for this operation.", res + + @pytest.mark.p2 + def test_delete_mixed_valid_invalid_dialogs(self, WebApiAuth, add_dialog_func): + _, valid_dialog_id = add_dialog_func + invalid_dialog_id = "nonexistent_dialog_id" + + payload = {"dialog_ids": [valid_dialog_id, invalid_dialog_id]} + res = delete_dialog(WebApiAuth, payload) + assert res["code"] == 103, res + assert res["message"] == "Only owner of dialog authorized for this operation.", res + + res = list_dialogs(WebApiAuth) + assert res["code"] == 0, res + assert len(res["data"]) == 1, res + + @pytest.mark.p3 + def test_delete_dialog_concurrent(self, WebApiAuth, add_dialogs_func): + _, dialog_ids = add_dialogs_func + + count = len(dialog_ids) + with ThreadPoolExecutor(max_workers=3) as executor: + futures = [executor.submit(delete_dialog, WebApiAuth, {"dialog_ids": [dialog_id]}) for dialog_id in dialog_ids] + + responses = [future.result() for future in as_completed(futures)] + + successful_deletions = sum(1 for response in responses if response["code"] == 0) + assert successful_deletions > 0, "No dialogs were successfully deleted" + + res = list_dialogs(WebApiAuth) + assert res["code"] == 0, res + assert len(res["data"]) == count - successful_deletions, res + + @pytest.mark.p3 + def test_delete_dialog_idempotent(self, WebApiAuth, add_dialog_func): + _, dialog_id = add_dialog_func + + payload = {"dialog_ids": [dialog_id]} + res = delete_dialog(WebApiAuth, payload) + assert res["code"] == 0, res + + res = delete_dialog(WebApiAuth, payload) + assert res["code"] == 0, res + + @pytest.mark.p3 + def test_delete_large_batch_dialogs(self, WebApiAuth, add_document): + dataset_id, _ = add_document + + dialog_ids = batch_create_dialogs(WebApiAuth, 50, [dataset_id]) + assert len(dialog_ids) == 50, "Failed to create 50 dialogs" + + payload = {"dialog_ids": dialog_ids} + res = delete_dialog(WebApiAuth, payload) + assert res["code"] == 0, res + assert res["data"] is True, res + + res = list_dialogs(WebApiAuth) + assert res["code"] == 0, res + assert len(res["data"]) == 0, res + + @pytest.mark.p3 + def test_delete_dialog_with_special_characters(self, WebApiAuth): + payload = {"name": "Dialog with 特殊字符 and émojis 🤖", "description": "Test dialog with special characters", "prompt_config": {"system": "You are a helpful assistant.", "parameters": []}} + create_res = create_dialog(WebApiAuth, payload) + assert create_res["code"] == 0, create_res + dialog_id = create_res["data"]["id"] + + delete_payload = {"dialog_ids": [dialog_id]} + res = delete_dialog(WebApiAuth, delete_payload) + assert res["code"] == 0, res + assert res["data"] is True, res + + res = list_dialogs(WebApiAuth) + assert res["code"] == 0, res + assert len(res["data"]) == 0, res + + @pytest.mark.p3 + def test_delete_dialog_preserves_other_user_dialogs(self, WebApiAuth, add_dialog_func): + _, dialog_id = add_dialog_func + + payload = {"dialog_ids": [dialog_id]} + res = delete_dialog(WebApiAuth, payload) + assert res["code"] == 0, res diff --git a/test/testcases/test_web_api/test_dialog_app/test_dialog_edge_cases.py b/test/testcases/test_web_api/test_dialog_app/test_dialog_edge_cases.py new file mode 100644 index 000000000..bbbc00d65 --- /dev/null +++ b/test/testcases/test_web_api/test_dialog_app/test_dialog_edge_cases.py @@ -0,0 +1,205 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import pytest +from common import create_dialog, delete_dialog, get_dialog, update_dialog + + +@pytest.mark.usefixtures("clear_dialogs") +class TestDialogEdgeCases: + @pytest.mark.p2 + def test_create_dialog_with_tavily_api_key(self, WebApiAuth): + """Test creating dialog with Tavily API key instead of dataset""" + payload = { + "name": "tavily_dialog", + "prompt_config": {"system": "You are a helpful assistant. Use this knowledge: {knowledge}", "parameters": [{"key": "knowledge", "optional": True}], "tavily_api_key": "test_tavily_key"}, + } + res = create_dialog(WebApiAuth, payload) + assert res["code"] == 0, res + + @pytest.mark.skip + @pytest.mark.p2 + def test_create_dialog_with_different_embedding_models(self, WebApiAuth): + """Test creating dialog with knowledge bases that have different embedding models""" + # This test would require creating datasets with different embedding models + # For now, we'll test the error case with a mock scenario + payload = { + "name": "mixed_embedding_dialog", + "kb_ids": ["kb_with_model_a", "kb_with_model_b"], + "prompt_config": {"system": "You are a helpful assistant with knowledge: {knowledge}", "parameters": [{"key": "knowledge", "optional": True}]}, + } + res = create_dialog(WebApiAuth, payload) + # This should fail due to different embedding models + assert res["code"] == 102, res + assert "Datasets use different embedding models" in res["message"], res + + @pytest.mark.p2 + def test_create_dialog_with_extremely_long_system_prompt(self, WebApiAuth): + """Test creating dialog with very long system prompt""" + long_prompt = "You are a helpful assistant. " * 1000 + payload = {"name": "long_prompt_dialog", "prompt_config": {"system": long_prompt, "parameters": []}} + res = create_dialog(WebApiAuth, payload) + assert res["code"] == 0, res + + @pytest.mark.p2 + def test_create_dialog_with_unicode_characters(self, WebApiAuth): + """Test creating dialog with Unicode characters in various fields""" + payload = { + "name": "Unicode测试对话🤖", + "description": "测试Unicode字符支持 with émojis 🚀🌟", + "icon": "🤖", + "prompt_config": {"system": "你是一个有用的助手。You are helpful. Vous êtes utile. 🌍", "parameters": []}, + } + res = create_dialog(WebApiAuth, payload) + assert res["code"] == 0, res + assert res["data"]["name"] == "Unicode测试对话🤖", res + assert res["data"]["description"] == "测试Unicode字符支持 with émojis 🚀🌟", res + + @pytest.mark.p2 + def test_create_dialog_with_extreme_parameter_values(self, WebApiAuth): + """Test creating dialog with extreme parameter values""" + payload = { + "name": "extreme_params_dialog", + "top_n": 0, + "top_k": 1, + "similarity_threshold": 0.0, + "vector_similarity_weight": 1.0, + "prompt_config": {"system": "You are a helpful assistant.", "parameters": []}, + } + res = create_dialog(WebApiAuth, payload) + assert res["code"] == 0, res + assert res["data"]["top_n"] == 0, res + assert res["data"]["top_k"] == 1, res + assert res["data"]["similarity_threshold"] == 0.0, res + assert res["data"]["vector_similarity_weight"] == 1.0, res + + @pytest.mark.p2 + def test_create_dialog_with_negative_parameter_values(self, WebApiAuth): + """Test creating dialog with negative parameter values""" + payload = { + "name": "negative_params_dialog", + "top_n": -1, + "top_k": -100, + "similarity_threshold": -0.5, + "vector_similarity_weight": -0.3, + "prompt_config": {"system": "You are a helpful assistant.", "parameters": []}, + } + res = create_dialog(WebApiAuth, payload) + assert res["code"] in [0, 102], res + + @pytest.mark.p2 + def test_update_dialog_with_empty_kb_ids(self, WebApiAuth, add_dialog_func): + """Test updating dialog to remove all knowledge bases""" + dataset_id, dialog_id = add_dialog_func + payload = {"dialog_id": dialog_id, "kb_ids": [], "prompt_config": {"system": "You are a helpful assistant without knowledge.", "parameters": []}} + res = update_dialog(WebApiAuth, payload) + assert res["code"] == 0, res + assert res["data"]["kb_ids"] == [], res + + @pytest.mark.p2 + def test_update_dialog_with_null_values(self, WebApiAuth, add_dialog_func): + """Test updating dialog with null/None values""" + dataset_id, dialog_id = add_dialog_func + payload = {"dialog_id": dialog_id, "description": None, "icon": None, "rerank_id": None, "prompt_config": {"system": "You are a helpful assistant.", "parameters": []}} + res = update_dialog(WebApiAuth, payload) + assert res["code"] == 0, res + + @pytest.mark.p3 + def test_dialog_with_complex_prompt_parameters(self, WebApiAuth, add_dataset_func): + """Test dialog with complex prompt parameter configurations""" + payload = { + "name": "complex_params_dialog", + "prompt_config": { + "system": "You are {role} assistant. Use {knowledge} and consider {context}. Optional: {optional_param}", + "parameters": [{"key": "role", "optional": False}, {"key": "knowledge", "optional": True}, {"key": "context", "optional": False}, {"key": "optional_param", "optional": True}], + }, + "kb_ids": [add_dataset_func], + } + res = create_dialog(WebApiAuth, payload) + assert res["code"] == 0, res + + @pytest.mark.p3 + def test_dialog_with_malformed_prompt_parameters(self, WebApiAuth): + """Test dialog with malformed prompt parameter configurations""" + payload = { + "name": "malformed_params_dialog", + "prompt_config": { + "system": "You are a helpful assistant.", + "parameters": [ + { + "key": "", + "optional": False, + }, + {"optional": True}, + { + "key": "valid_param", + }, + ], + }, + } + res = create_dialog(WebApiAuth, payload) + + assert res["code"] in [0, 102], res + + @pytest.mark.p3 + def test_dialog_operations_with_special_ids(self, WebApiAuth): + """Test dialog operations with special ID formats""" + special_ids = [ + "00000000-0000-0000-0000-000000000000", + "ffffffff-ffff-ffff-ffff-ffffffffffff", + "12345678-1234-1234-1234-123456789abc", + ] + + for special_id in special_ids: + res = get_dialog(WebApiAuth, {"dialog_id": special_id}) + assert res["code"] == 102, f"Should fail for ID: {special_id}" + + res = delete_dialog(WebApiAuth, {"dialog_ids": [special_id]}) + assert res["code"] == 103, f"Should fail for ID: {special_id}" + + @pytest.mark.p3 + def test_dialog_with_extremely_large_llm_settings(self, WebApiAuth): + """Test dialog with very large LLM settings""" + large_llm_setting = { + "model": "gpt-4", + "temperature": 0.7, + "max_tokens": 999999, + "custom_param_" + "x" * 1000: "large_value_" + "y" * 1000, + } + payload = {"name": "large_llm_settings_dialog", "llm_setting": large_llm_setting, "prompt_config": {"system": "You are a helpful assistant.", "parameters": []}} + res = create_dialog(WebApiAuth, payload) + assert res["code"] == 0, res + + @pytest.mark.p3 + def test_concurrent_dialog_operations(self, WebApiAuth, add_dialog_func): + """Test concurrent operations on the same dialog""" + from concurrent.futures import ThreadPoolExecutor, as_completed + + _, dialog_id = add_dialog_func + + def update_operation(i): + payload = {"dialog_id": dialog_id, "name": f"concurrent_update_{i}", "prompt_config": {"system": f"You are assistant number {i}.", "parameters": []}} + return update_dialog(WebApiAuth, payload) + + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(update_operation, i) for i in range(10)] + + responses = [future.result() for future in as_completed(futures)] + + successful_updates = sum(1 for response in responses if response["code"] == 0) + assert successful_updates > 0, "No updates succeeded" + + res = get_dialog(WebApiAuth, {"dialog_id": dialog_id}) + assert res["code"] == 0, res diff --git a/test/testcases/test_web_api/test_dialog_app/test_dialog_routes_unit.py b/test/testcases/test_web_api/test_dialog_app/test_dialog_routes_unit.py new file mode 100644 index 000000000..b3007a9e0 --- /dev/null +++ b/test/testcases/test_web_api/test_dialog_app/test_dialog_routes_unit.py @@ -0,0 +1,572 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import asyncio +import importlib.util +import inspect +import sys +from pathlib import Path +from types import ModuleType, SimpleNamespace +from functools import wraps + +import pytest + + +class _DummyManager: + def route(self, *_args, **_kwargs): + def decorator(func): + return func + + return decorator + + +class _AwaitableValue: + def __init__(self, value): + self._value = value + + def __await__(self): + async def _co(): + return self._value + + return _co().__await__() + + +class _Args(dict): + def get(self, key, default=None): + return super().get(key, default) + + +def _run(coro): + return asyncio.run(coro) + + +def _set_request_json(monkeypatch, module, payload): + monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue(payload)) + + +def _set_request_args(monkeypatch, module, args): + monkeypatch.setattr(module, "request", SimpleNamespace(args=_Args(args))) + + +@pytest.fixture(scope="session") +def auth(): + return "unit-auth" + + +@pytest.fixture(scope="session", autouse=True) +def set_tenant_info(): + return None + + +def _load_dialog_module(monkeypatch): + repo_root = Path(__file__).resolve().parents[4] + + common_pkg = ModuleType("common") + common_pkg.__path__ = [str(repo_root / "common")] + monkeypatch.setitem(sys.modules, "common", common_pkg) + + quart_mod = ModuleType("quart") + quart_mod.request = SimpleNamespace(args=_Args()) + monkeypatch.setitem(sys.modules, "quart", quart_mod) + + api_pkg = ModuleType("api") + api_pkg.__path__ = [str(repo_root / "api")] + monkeypatch.setitem(sys.modules, "api", api_pkg) + + apps_mod = ModuleType("api.apps") + apps_mod.__path__ = [str(repo_root / "api" / "apps")] + apps_mod.current_user = SimpleNamespace(id="tenant-1") + apps_mod.login_required = lambda func: func + monkeypatch.setitem(sys.modules, "api.apps", apps_mod) + api_pkg.apps = apps_mod + + db_pkg = ModuleType("api.db") + db_pkg.__path__ = [] + monkeypatch.setitem(sys.modules, "api.db", db_pkg) + api_pkg.db = db_pkg + + services_pkg = ModuleType("api.db.services") + services_pkg.__path__ = [] + services_pkg.duplicate_name = lambda _checker, **kwargs: kwargs.get("name", "") + monkeypatch.setitem(sys.modules, "api.db.services", services_pkg) + + dialog_service_mod = ModuleType("api.db.services.dialog_service") + + class _DialogService: + model = SimpleNamespace(create_time="create_time") + + @staticmethod + def query(**_kwargs): + return [] + + @staticmethod + def save(**_kwargs): + return True + + @staticmethod + def update_by_id(*_args, **_kwargs): + return True + + @staticmethod + def get_by_id(_id): + return True, SimpleNamespace(to_dict=lambda: {"id": _id, "kb_ids": []}) + + @staticmethod + def get_by_tenant_ids(*_args, **_kwargs): + return [], 0 + + @staticmethod + def update_many_by_id(_payload): + return True + + dialog_service_mod.DialogService = _DialogService + monkeypatch.setitem(sys.modules, "api.db.services.dialog_service", dialog_service_mod) + + tenant_llm_service_mod = ModuleType("api.db.services.tenant_llm_service") + + class _MockTableObject: + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + def to_dict(self): + return {k: v for k, v in self.__dict__.items()} + + class _TenantLLMService: + @staticmethod + def split_model_name_and_factory(embd_id): + return embd_id.split("@") + + @staticmethod + def get_api_key(tenant_id, model_name, model_type=None): + return _MockTableObject( + id=1, + tenant_id=tenant_id, + llm_factory="", + model_type="chat", + llm_name=model_name, + api_key="fake-api-key", + api_base="https://api.example.com", + max_tokens=8192, + used_tokens=0, + status=1 + ) + + tenant_llm_service_mod.TenantLLMService = _TenantLLMService + monkeypatch.setitem(sys.modules, "api.db.services.tenant_llm_service", tenant_llm_service_mod) + + knowledgebase_service_mod = ModuleType("api.db.services.knowledgebase_service") + + class _KnowledgebaseService: + @staticmethod + def get_by_ids(_ids): + return [] + + @staticmethod + def get_by_id(_id): + return False, None + + @staticmethod + def query(**_kwargs): + return [] + + knowledgebase_service_mod.KnowledgebaseService = _KnowledgebaseService + monkeypatch.setitem(sys.modules, "api.db.services.knowledgebase_service", knowledgebase_service_mod) + + user_service_mod = ModuleType("api.db.services.user_service") + + class _TenantService: + @staticmethod + def get_by_id(_id): + return True, SimpleNamespace(llm_id="llm-default") + + class _UserTenantService: + @staticmethod + def query(**_kwargs): + return [SimpleNamespace(tenant_id="tenant-1")] + + user_service_mod.TenantService = _TenantService + user_service_mod.UserTenantService = _UserTenantService + monkeypatch.setitem(sys.modules, "api.db.services.user_service", user_service_mod) + + api_utils_mod = ModuleType("api.utils.api_utils") + from common.constants import RetCode + + async def _default_request_json(): + return {} + + def _get_data_error_result(code=RetCode.DATA_ERROR, message="Sorry! Data missing!"): + return {"code": code, "message": message} + + def _get_json_result(code=RetCode.SUCCESS, message="success", data=None): + return {"code": code, "message": message, "data": data} + + def _server_error_response(error): + return {"code": RetCode.EXCEPTION_ERROR, "message": repr(error)} + + def _validate_request(*_args, **_kwargs): + def _decorator(func): + if inspect.iscoroutinefunction(func): + @wraps(func) + async def _wrapped(*func_args, **func_kwargs): + return await func(*func_args, **func_kwargs) + + return _wrapped + + @wraps(func) + def _wrapped(*func_args, **func_kwargs): + return func(*func_args, **func_kwargs) + + return _wrapped + + return _decorator + + api_utils_mod.get_request_json = _default_request_json + api_utils_mod.get_data_error_result = _get_data_error_result + api_utils_mod.get_json_result = _get_json_result + api_utils_mod.server_error_response = _server_error_response + api_utils_mod.validate_request = _validate_request + monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod) + + module_name = "test_dialog_routes_unit_module" + module_path = repo_root / "api" / "apps" / "dialog_app.py" + spec = importlib.util.spec_from_file_location(module_name, module_path) + module = importlib.util.module_from_spec(spec) + module.manager = _DummyManager() + monkeypatch.setitem(sys.modules, module_name, module) + spec.loader.exec_module(module) + return module + + +@pytest.mark.p2 +def test_set_dialog_branch_matrix_unit(monkeypatch): + module = _load_dialog_module(monkeypatch) + handler = inspect.unwrap(module.set_dialog) + + _set_request_json(monkeypatch, module, {"name": 1, "prompt_config": {"system": "", "parameters": []}}) + res = _run(handler()) + assert res["message"] == "Dialog name must be string." + + _set_request_json(monkeypatch, module, {"name": " ", "prompt_config": {"system": "", "parameters": []}}) + res = _run(handler()) + assert res["message"] == "Dialog name can't be empty." + + _set_request_json(monkeypatch, module, {"name": "a" * 256, "prompt_config": {"system": "", "parameters": []}}) + res = _run(handler()) + assert res["message"] == "Dialog name length is 256 which is larger than 255" + + captured = {} + + def _dup_name(checker, **kwargs): + assert checker(name=kwargs["name"]) is True + return kwargs["name"] + " (1)" + + monkeypatch.setattr(module, "duplicate_name", _dup_name) + monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: [SimpleNamespace(name="new dialog")]) + monkeypatch.setattr(module.TenantService, "get_by_id", lambda _id: (True, SimpleNamespace(llm_id="llm-x", tenant_llm_id=1))) + monkeypatch.setattr(module.KnowledgebaseService, "get_by_ids", lambda _ids: [SimpleNamespace(embd_id="embd-a@builtin", tenant_embd_id=2)]) + monkeypatch.setattr(module.TenantLLMService, "split_model_name_and_factory", lambda embd_id: embd_id.split("@")) + monkeypatch.setattr(module.DialogService, "save", lambda **kwargs: captured.update(kwargs) or False) + _set_request_json( + monkeypatch, + module, + { + "name": "New Dialog", + "kb_ids": ["kb-1"], + "prompt_config": {"system": "Use {knowledge}", "parameters": []}, + }, + ) + res = _run(handler()) + assert res["message"] == "Fail to new a dialog!" + assert captured["name"] == "New Dialog (1)" + assert captured["prompt_config"]["parameters"] == [{"key": "knowledge", "optional": False}] + + _set_request_json( + monkeypatch, + module, + { + "dialog_id": "dialog-1", + "name": "Update", + "kb_ids": [], + "prompt_config": { + "system": "Use {knowledge}", + "parameters": [{"key": "knowledge", "optional": True}], + }, + }, + ) + res = _run(handler()) + assert "Please remove `{knowledge}` in system prompt" in res["message"] + + _set_request_json( + monkeypatch, + module, + {"name": "demo", "prompt_config": {"system": "hello", "parameters": [{"key": "must", "optional": False}]}}, + ) + res = _run(handler()) + assert "Parameter 'must' is not used" in res["message"] + + monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: []) + monkeypatch.setattr(module.TenantService, "get_by_id", lambda _id: (False, None)) + _set_request_json(monkeypatch, module, {"name": "demo", "prompt_config": {"system": "hello", "parameters": []}}) + res = _run(handler()) + assert res["message"] == "Tenant not found!" + + monkeypatch.setattr(module.TenantService, "get_by_id", lambda _id: (True, SimpleNamespace(llm_id="llm-x", tenant_llm_id=1))) + monkeypatch.setattr( + module, + "get_request_json", + lambda: _AwaitableValue( + { + "name": "demo", + "kb_ids": ["kb-1", "kb-2"], + "prompt_config": {"system": "hello", "parameters": []}, + } + ), + ) + monkeypatch.setattr( + module.KnowledgebaseService, + "get_by_ids", + lambda _ids: [SimpleNamespace(embd_id="embd-a@f1", tenant_embd_id=2), SimpleNamespace(embd_id="embd-b@f2", tenant_embd_id=2)], + ) + monkeypatch.setattr(module.TenantLLMService, "split_model_name_and_factory", lambda embd_id: embd_id.split("@")) + res = _run(handler()) + assert "Datasets use different embedding models" in res["message"] + + monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: []) + monkeypatch.setattr( + module, + "get_request_json", + lambda: _AwaitableValue( + { + "name": "optional-param-dialog", + "prompt_config": {"system": "hello", "parameters": [{"key": "ignored", "optional": True}]}, + } + ), + ) + monkeypatch.setattr(module.KnowledgebaseService, "get_by_ids", lambda _ids: []) + monkeypatch.setattr(module.DialogService, "save", lambda **_kwargs: False) + res = _run(handler()) + assert res["message"] == "Fail to new a dialog!" + + monkeypatch.setattr(module.KnowledgebaseService, "get_by_ids", lambda _ids: []) + monkeypatch.setattr(module.DialogService, "update_by_id", lambda *_args, **_kwargs: False) + _set_request_json( + monkeypatch, + module, + { + "dialog_id": "dialog-1", + "kb_names": ["legacy"], + "name": "rename", + "prompt_config": {"system": "hello", "parameters": []}, + }, + ) + res = _run(handler()) + assert res["message"] == "Dialog not found!" + + monkeypatch.setattr(module.DialogService, "update_by_id", lambda *_args, **_kwargs: True) + monkeypatch.setattr(module.DialogService, "get_by_id", lambda _id: (False, None)) + _set_request_json( + monkeypatch, + module, + { + "dialog_id": "dialog-1", + "name": "rename", + "prompt_config": {"system": "hello", "parameters": []}, + }, + ) + res = _run(handler()) + assert res["message"] == "Fail to update a dialog!" + + monkeypatch.setattr(module.DialogService, "get_by_id", lambda _id: (True, SimpleNamespace(to_dict=lambda: {"id": _id, "kb_ids": ["kb-1"]}))) + monkeypatch.setattr( + module.KnowledgebaseService, + "get_by_id", + lambda _id: (True, SimpleNamespace(status=module.StatusEnum.VALID.value, name="KB One")), + ) + _set_request_json( + monkeypatch, + module, + { + "dialog_id": "dialog-1", + "kb_names": ["legacy"], + "name": "new-name", + "prompt_config": {"system": "hello", "parameters": []}, + }, + ) + res = _run(handler()) + assert res["code"] == 0 + assert res["data"]["name"] == "new-name" + assert res["data"]["kb_names"] == ["KB One"] + + def _raise_tenant(_id): + raise RuntimeError("set boom") + + monkeypatch.setattr(module.TenantService, "get_by_id", _raise_tenant) + _set_request_json(monkeypatch, module, {"name": "demo", "prompt_config": {"system": "hello", "parameters": []}}) + res = _run(handler()) + assert "set boom" in res["message"] + + +@pytest.mark.p2 +def test_get_get_kb_names_and_list_dialogs_exception_matrix_unit(monkeypatch): + module = _load_dialog_module(monkeypatch) + get_handler = inspect.unwrap(module.get) + + monkeypatch.setattr( + module.DialogService, + "get_by_id", + lambda _id: (True, SimpleNamespace(to_dict=lambda: {"id": _id, "kb_ids": ["kb-1", "kb-2"]})), + ) + monkeypatch.setattr( + module.KnowledgebaseService, + "get_by_id", + lambda kid: ( + (True, SimpleNamespace(status=module.StatusEnum.VALID.value, name="KB-1")) + if kid == "kb-1" + else (False, None) + ), + ) + _set_request_args(monkeypatch, module, {"dialog_id": "dialog-1"}) + res = get_handler() + assert res["code"] == 0 + assert res["data"]["kb_ids"] == ["kb-1"] + assert res["data"]["kb_names"] == ["KB-1"] + + monkeypatch.setattr(module.DialogService, "get_by_id", lambda _id: (False, None)) + _set_request_args(monkeypatch, module, {"dialog_id": "dialog-missing"}) + res = get_handler() + assert res["message"] == "Dialog not found!" + + def _raise_get(_id): + raise RuntimeError("get boom") + + monkeypatch.setattr(module.DialogService, "get_by_id", _raise_get) + _set_request_args(monkeypatch, module, {"dialog_id": "dialog-1"}) + res = get_handler() + assert "get boom" in res["message"] + + monkeypatch.setattr( + module.KnowledgebaseService, + "get_by_id", + lambda kid: ( + (True, SimpleNamespace(status=module.StatusEnum.VALID.value, name=f"KB-{kid}")) + if kid.startswith("ok") + else (True, SimpleNamespace(status=module.StatusEnum.INVALID.value, name=f"BAD-{kid}")) + ), + ) + ids, names = module.get_kb_names(["ok-1", "bad-1", "ok-2"]) + assert ids == ["ok-1", "ok-2"] + assert names == ["KB-ok-1", "KB-ok-2"] + + def _raise_list(**_kwargs): + raise RuntimeError("list boom") + + monkeypatch.setattr(module.DialogService, "query", _raise_list) + res = module.list_dialogs() + assert "list boom" in res["message"] + + +@pytest.mark.p2 +def test_list_dialogs_next_owner_desc_and_pagination_matrix_unit(monkeypatch): + module = _load_dialog_module(monkeypatch) + handler = inspect.unwrap(module.list_dialogs_next) + + calls = [] + + def _get_by_tenant_ids(tenants, user_id, page_number, items_per_page, orderby, desc, keywords, parser_id): + calls.append( + { + "tenants": tenants, + "user_id": user_id, + "page_number": page_number, + "items_per_page": items_per_page, + "orderby": orderby, + "desc": desc, + "keywords": keywords, + "parser_id": parser_id, + } + ) + if tenants: + return ( + [ + {"id": "dialog-1", "tenant_id": "tenant-a"}, + {"id": "dialog-2", "tenant_id": "tenant-x"}, + {"id": "dialog-3", "tenant_id": "tenant-b"}, + ], + 3, + ) + return ([{"id": "dialog-0", "tenant_id": "tenant-1"}], 1) + + monkeypatch.setattr(module.DialogService, "get_by_tenant_ids", _get_by_tenant_ids) + + _set_request_args( + monkeypatch, + module, + { + "keywords": "k", + "page": "1", + "page_size": "2", + "parser_id": "parser-x", + "orderby": "create_time", + "desc": "false", + }, + ) + _set_request_json(monkeypatch, module, {"owner_ids": []}) + res = _run(handler()) + assert res["code"] == 0 + assert res["data"]["total"] == 1 + assert calls[-1]["tenants"] == [] + assert calls[-1]["desc"] is False + + _set_request_args(monkeypatch, module, {"page": "2", "page_size": "1"}) + _set_request_json(monkeypatch, module, {"owner_ids": ["tenant-a", "tenant-b"]}) + res = _run(handler()) + assert res["code"] == 0 + assert res["data"]["total"] == 2 + assert res["data"]["dialogs"] == [{"id": "dialog-3", "tenant_id": "tenant-b"}] + assert calls[-1]["page_number"] == 0 + assert calls[-1]["items_per_page"] == 0 + assert calls[-1]["desc"] is True + + def _raise_next(*_args, **_kwargs): + raise RuntimeError("next boom") + + monkeypatch.setattr(module.DialogService, "get_by_tenant_ids", _raise_next) + _set_request_args(monkeypatch, module, {"page": "1", "page_size": "1"}) + _set_request_json(monkeypatch, module, {"owner_ids": []}) + res = _run(handler()) + assert "next boom" in res["message"] + + +@pytest.mark.p2 +def test_rm_permission_and_exception_matrix_unit(monkeypatch): + module = _load_dialog_module(monkeypatch) + handler = inspect.unwrap(module.rm) + + monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [SimpleNamespace(tenant_id="tenant-a")]) + monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: []) + _set_request_json(monkeypatch, module, {"dialog_ids": ["dialog-1"]}) + res = _run(handler()) + assert res["code"] == module.RetCode.OPERATING_ERROR + assert "Only owner of dialog authorized for this operation." in res["message"] + + def _raise_query(**_kwargs): + raise RuntimeError("rm boom") + + monkeypatch.setattr(module.DialogService, "query", _raise_query) + _set_request_json(monkeypatch, module, {"dialog_ids": ["dialog-1"]}) + res = _run(handler()) + assert "rm boom" in res["message"] diff --git a/test/testcases/test_web_api/test_dialog_app/test_get_dialog.py b/test/testcases/test_web_api/test_dialog_app/test_get_dialog.py new file mode 100644 index 000000000..1762f8043 --- /dev/null +++ b/test/testcases/test_web_api/test_dialog_app/test_get_dialog.py @@ -0,0 +1,177 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import pytest +from common import create_dialog, get_dialog +from configs import INVALID_API_TOKEN +from libs.auth import RAGFlowWebApiAuth + + +@pytest.mark.usefixtures("clear_dialogs") +class TestAuthorization: + @pytest.mark.p2 + @pytest.mark.parametrize( + "invalid_auth, expected_code, expected_message", + [ + (None, 401, ""), + (RAGFlowWebApiAuth(INVALID_API_TOKEN), 401, ""), + ], + ids=["empty_auth", "invalid_api_token"], + ) + def test_auth_invalid(self, invalid_auth, expected_code, expected_message, add_dialog_func): + _, dialog_id = add_dialog_func + res = get_dialog(invalid_auth, {"dialog_id": dialog_id}) + assert res["code"] == expected_code, res + assert res["message"] == expected_message, res + + +class TestDialogGet: + @pytest.mark.p1 + def test_get_existing_dialog(self, WebApiAuth, add_dialog_func): + _, dialog_id = add_dialog_func + res = get_dialog(WebApiAuth, {"dialog_id": dialog_id}) + assert res["code"] == 0, res + data = res["data"] + assert data["id"] == dialog_id, res + assert "name" in data, res + assert "description" in data, res + assert "kb_ids" in data, res + assert "kb_names" in data, res + assert "prompt_config" in data, res + assert "llm_setting" in data, res + assert "top_n" in data, res + assert "top_k" in data, res + assert "similarity_threshold" in data, res + assert "vector_similarity_weight" in data, res + + @pytest.mark.p1 + def test_get_dialog_with_kb_names(self, WebApiAuth, add_dialog_func): + _, dialog_id = add_dialog_func + res = get_dialog(WebApiAuth, {"dialog_id": dialog_id}) + assert res["code"] == 0, res + data = res["data"] + assert isinstance(data["kb_ids"], list), res + assert isinstance(data["kb_names"], list), res + assert len(data["kb_ids"]) == len(data["kb_names"]), res + + @pytest.mark.p2 + def test_get_nonexistent_dialog(self, WebApiAuth): + fake_dialog_id = "nonexistent_dialog_id" + res = get_dialog(WebApiAuth, {"dialog_id": fake_dialog_id}) + assert res["code"] == 102, res + assert "Dialog not found" in res["message"], res + + @pytest.mark.p2 + def test_get_dialog_missing_id(self, WebApiAuth): + res = get_dialog(WebApiAuth, {}) + assert res["code"] == 100, res + assert res["message"] == "", res + + @pytest.mark.p2 + def test_get_dialog_empty_id(self, WebApiAuth): + res = get_dialog(WebApiAuth, {"dialog_id": ""}) + assert res["code"] == 102, res + + @pytest.mark.p2 + def test_get_dialog_invalid_id_format(self, WebApiAuth): + res = get_dialog(WebApiAuth, {"dialog_id": "invalid_format"}) + assert res["code"] == 102, res + + @pytest.mark.p3 + def test_get_dialog_data_structure(self, WebApiAuth, add_dialog_func): + _, dialog_id = add_dialog_func + res = get_dialog(WebApiAuth, {"dialog_id": dialog_id}) + assert res["code"] == 0, res + data = res["data"] + + required_fields = [ + "id", + "name", + "description", + "kb_ids", + "kb_names", + "prompt_config", + "llm_setting", + "top_n", + "top_k", + "similarity_threshold", + "vector_similarity_weight", + "create_time", + "update_time", + ] + for field in required_fields: + assert field in data, f"Missing field: {field}" + + assert isinstance(data["id"], str), res + assert isinstance(data["name"], str), res + assert isinstance(data["kb_ids"], list), res + assert isinstance(data["kb_names"], list), res + assert isinstance(data["prompt_config"], dict), res + assert isinstance(data["top_n"], int), res + assert isinstance(data["top_k"], int), res + assert isinstance(data["similarity_threshold"], (int, float)), res + assert isinstance(data["vector_similarity_weight"], (int, float)), res + + @pytest.mark.p3 + def test_get_dialog_prompt_config_structure(self, WebApiAuth, add_dialog_func): + _, dialog_id = add_dialog_func + res = get_dialog(WebApiAuth, {"dialog_id": dialog_id}) + assert res["code"] == 0, res + + prompt_config = res["data"]["prompt_config"] + assert "system" in prompt_config, res + assert "parameters" in prompt_config, res + assert isinstance(prompt_config["system"], str), res + assert isinstance(prompt_config["parameters"], list), res + + @pytest.mark.p3 + def test_get_dialog_with_multiple_kbs(self, WebApiAuth, add_dataset_func): + dataset_id1 = add_dataset_func + dataset_id2 = add_dataset_func + + payload = { + "name": "multi_kb_dialog", + "kb_ids": [dataset_id1, dataset_id2], + "prompt_config": {"system": "You are a helpful assistant with knowledge: {knowledge}", "parameters": [{"key": "knowledge", "optional": True}]}, + } + create_res = create_dialog(WebApiAuth, payload) + assert create_res["code"] == 0, create_res + dialog_id = create_res["data"]["id"] + + res = get_dialog(WebApiAuth, {"dialog_id": dialog_id}) + assert res["code"] == 0, res + data = res["data"] + assert len(data["kb_ids"]) == 2, res + assert len(data["kb_names"]) == 2, res + assert dataset_id1 in data["kb_ids"], res + assert dataset_id2 in data["kb_ids"], res + + @pytest.mark.p3 + def test_get_dialog_with_invalid_kb(self, WebApiAuth): + payload = { + "name": "invalid_kb_dialog", + "kb_ids": ["invalid_kb_id"], + "prompt_config": {"system": "You are a helpful assistant with knowledge: {knowledge}", "parameters": [{"key": "knowledge", "optional": True}]}, + } + create_res = create_dialog(WebApiAuth, payload) + assert create_res["code"] == 0, create_res + dialog_id = create_res["data"]["id"] + + res = get_dialog(WebApiAuth, {"dialog_id": dialog_id}) + assert res["code"] == 0, res + data = res["data"] + + assert len(data["kb_ids"]) == 0, res + assert len(data["kb_names"]) == 0, res diff --git a/test/testcases/test_web_api/test_dialog_app/test_list_dialogs.py b/test/testcases/test_web_api/test_dialog_app/test_list_dialogs.py new file mode 100644 index 000000000..fc48b1ba4 --- /dev/null +++ b/test/testcases/test_web_api/test_dialog_app/test_list_dialogs.py @@ -0,0 +1,210 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import pytest +from common import batch_create_dialogs, create_dialog, list_dialogs +from configs import INVALID_API_TOKEN +from libs.auth import RAGFlowWebApiAuth + + +@pytest.mark.usefixtures("clear_dialogs") +class TestAuthorization: + @pytest.mark.p2 + @pytest.mark.parametrize( + "invalid_auth, expected_code, expected_message", + [ + (None, 401, ""), + (RAGFlowWebApiAuth(INVALID_API_TOKEN), 401, ""), + ], + ids=["empty_auth", "invalid_api_token"], + ) + def test_auth_invalid(self, invalid_auth, expected_code, expected_message): + res = list_dialogs(invalid_auth) + assert res["code"] == expected_code, res + assert res["message"] == expected_message, res + + +class TestDialogList: + @pytest.mark.p1 + @pytest.mark.usefixtures("add_dialogs_func") + def test_list_empty_dialogs(self, WebApiAuth): + res = list_dialogs(WebApiAuth) + assert res["code"] == 0, res + assert len(res["data"]) == 5, res + + @pytest.mark.p1 + def test_list_multiple_dialogs(self, WebApiAuth, add_dialogs_func): + _, dialog_ids = add_dialogs_func + res = list_dialogs(WebApiAuth) + assert res["code"] == 0, res + assert len(res["data"]) == 5, res + + returned_ids = [dialog["id"] for dialog in res["data"]] + for dialog_id in dialog_ids: + assert dialog_id in returned_ids, res + + @pytest.mark.p2 + @pytest.mark.usefixtures("add_dialogs_func") + def test_list_dialogs_data_structure(self, WebApiAuth): + res = list_dialogs(WebApiAuth) + assert res["code"] == 0, res + assert len(res["data"]) == 5, res + + dialog = res["data"][0] + required_fields = [ + "id", + "name", + "description", + "kb_ids", + "kb_names", + "prompt_config", + "llm_setting", + "top_n", + "top_k", + "similarity_threshold", + "vector_similarity_weight", + "create_time", + "update_time", + ] + for field in required_fields: + assert field in dialog, f"Missing field: {field}" + + assert isinstance(dialog["id"], str), res + assert isinstance(dialog["name"], str), res + assert isinstance(dialog["kb_ids"], list), res + assert isinstance(dialog["kb_names"], list), res + assert isinstance(dialog["prompt_config"], dict), res + assert isinstance(dialog["top_n"], int), res + assert isinstance(dialog["top_k"], int), res + assert isinstance(dialog["similarity_threshold"], (int, float)), res + assert isinstance(dialog["vector_similarity_weight"], (int, float)), res + + @pytest.mark.p2 + @pytest.mark.usefixtures("add_dialogs_func") + def test_list_dialogs_with_kb_names(self, WebApiAuth): + res = list_dialogs(WebApiAuth) + assert res["code"] == 0, res + + dialog = res["data"][0] + assert isinstance(dialog["kb_ids"], list), res + assert isinstance(dialog["kb_names"], list), res + assert len(dialog["kb_ids"]) == len(dialog["kb_names"]), res + + @pytest.mark.p2 + @pytest.mark.usefixtures("add_dialogs_func") + def test_list_dialogs_ordering(self, WebApiAuth): + res = list_dialogs(WebApiAuth) + assert res["code"] == 0, res + assert len(res["data"]) == 5, res + + dialogs = res["data"] + for i in range(len(dialogs) - 1): + current_time = dialogs[i]["create_time"] + next_time = dialogs[i + 1]["create_time"] + assert current_time >= next_time, f"Dialogs not properly ordered: {current_time} should be >= {next_time}" + + @pytest.mark.p3 + @pytest.mark.usefixtures("clear_dialogs") + def test_list_dialogs_with_invalid_kb(self, WebApiAuth): + payload = { + "name": "invalid_kb_dialog", + "kb_ids": ["invalid_kb_id"], + "prompt_config": {"system": "You are a helpful assistant with knowledge: {knowledge}", "parameters": [{"key": "knowledge", "optional": True}]}, + } + create_res = create_dialog(WebApiAuth, payload) + assert create_res["code"] == 0, create_res + + res = list_dialogs(WebApiAuth) + assert res["code"] == 0, res + assert len(res["data"]) == 1, res + + dialog = res["data"][0] + + assert len(dialog["kb_ids"]) == 0, res + assert len(dialog["kb_names"]) == 0, res + + @pytest.mark.p3 + @pytest.mark.usefixtures("clear_dialogs") + def test_list_dialogs_with_multiple_kbs(self, WebApiAuth, add_dataset_func): + dataset_id1 = add_dataset_func + dataset_id2 = add_dataset_func + + payload = { + "name": "multi_kb_dialog", + "kb_ids": [dataset_id1, dataset_id2], + "prompt_config": {"system": "You are a helpful assistant with knowledge: {knowledge}", "parameters": [{"key": "knowledge", "optional": True}]}, + } + create_res = create_dialog(WebApiAuth, payload) + assert create_res["code"] == 0, create_res + + res = list_dialogs(WebApiAuth) + assert res["code"] == 0, res + assert len(res["data"]) == 1, res + + dialog = res["data"][0] + assert len(dialog["kb_ids"]) == 2, res + assert len(dialog["kb_names"]) == 2, res + assert dataset_id1 in dialog["kb_ids"], res + assert dataset_id2 in dialog["kb_ids"], res + + @pytest.mark.p3 + @pytest.mark.usefixtures("add_dialogs_func") + def test_list_dialogs_prompt_config_structure(self, WebApiAuth): + res = list_dialogs(WebApiAuth) + assert res["code"] == 0, res + + dialog = res["data"][0] + prompt_config = dialog["prompt_config"] + assert "system" in prompt_config, res + assert "parameters" in prompt_config, res + assert isinstance(prompt_config["system"], str), res + assert isinstance(prompt_config["parameters"], list), res + + @pytest.mark.p3 + @pytest.mark.usefixtures("clear_dialogs") + def test_list_dialogs_performance(self, WebApiAuth, add_document): + dataset_id, _ = add_document + dialog_ids = batch_create_dialogs(WebApiAuth, 100, [dataset_id]) + assert len(dialog_ids) == 100, "Failed to create 100 dialogs" + + res = list_dialogs(WebApiAuth) + assert res["code"] == 0, res + assert len(res["data"]) == 100, res + + returned_ids = [dialog["id"] for dialog in res["data"]] + for dialog_id in dialog_ids: + assert dialog_id in returned_ids, f"Dialog {dialog_id} not found in list" + + @pytest.mark.p3 + @pytest.mark.usefixtures("clear_dialogs") + def test_list_dialogs_with_mixed_kb_states(self, WebApiAuth, add_dataset_func): + valid_dataset_id = add_dataset_func + + payload = { + "name": "mixed_kb_dialog", + "kb_ids": [valid_dataset_id, "invalid_kb_id"], + "prompt_config": {"system": "You are a helpful assistant with knowledge: {knowledge}", "parameters": [{"key": "knowledge", "optional": True}]}, + } + create_res = create_dialog(WebApiAuth, payload) + assert create_res["code"] == 0, create_res + + res = list_dialogs(WebApiAuth) + assert res["code"] == 0, res + assert len(res["data"]) == 1, res + + dialog = res["data"][0] + assert len(dialog["kb_ids"]) == 1, res + assert dialog["kb_ids"][0] == valid_dataset_id, res + assert len(dialog["kb_names"]) == 1, res diff --git a/test/testcases/test_web_api/test_dialog_app/test_update_dialog.py b/test/testcases/test_web_api/test_dialog_app/test_update_dialog.py new file mode 100644 index 000000000..30f55b89b --- /dev/null +++ b/test/testcases/test_web_api/test_dialog_app/test_update_dialog.py @@ -0,0 +1,170 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import pytest +from common import update_dialog +from configs import INVALID_API_TOKEN +from libs.auth import RAGFlowWebApiAuth + + +@pytest.mark.usefixtures("clear_dialogs") +class TestAuthorization: + @pytest.mark.p2 + @pytest.mark.parametrize( + "invalid_auth, expected_code, expected_message", + [ + (None, 401, ""), + (RAGFlowWebApiAuth(INVALID_API_TOKEN), 401, ""), + ], + ids=["empty_auth", "invalid_api_token"], + ) + def test_auth_invalid(self, invalid_auth, expected_code, expected_message, add_dialog_func): + _, dialog_id = add_dialog_func + payload = {"dialog_id": dialog_id, "name": "updated_name", "prompt_config": {"system": "You are a helpful assistant.", "parameters": []}} + res = update_dialog(invalid_auth, payload) + assert res["code"] == expected_code, res + assert res["message"] == expected_message, res + + +class TestDialogUpdate: + @pytest.mark.p1 + def test_update_name(self, WebApiAuth, add_dialog_func): + _, dialog_id = add_dialog_func + new_name = "updated_dialog_name" + payload = {"dialog_id": dialog_id, "name": new_name, "prompt_config": {"system": "You are a helpful assistant.", "parameters": []}} + res = update_dialog(WebApiAuth, payload) + assert res["code"] == 0, res + assert res["data"]["name"] == new_name, res + + @pytest.mark.p2 + def test_update_description(self, WebApiAuth, add_dialog_func): + _, dialog_id = add_dialog_func + new_description = "Updated description" + payload = {"dialog_id": dialog_id, "description": new_description, "prompt_config": {"system": "You are a helpful assistant.", "parameters": []}} + res = update_dialog(WebApiAuth, payload) + assert res["code"] == 0, res + assert res["data"]["description"] == new_description, res + + @pytest.mark.p1 + def test_update_prompt_config(self, WebApiAuth, add_dialog_func): + _, dialog_id = add_dialog_func + new_prompt_config = {"system": "You are an updated helpful assistant with {param1}.", "parameters": [{"key": "param1", "optional": False}]} + payload = {"dialog_id": dialog_id, "prompt_config": new_prompt_config} + res = update_dialog(WebApiAuth, payload) + assert res["code"] == 0, res + assert res["data"]["prompt_config"]["system"] == new_prompt_config["system"], res + + @pytest.mark.p1 + def test_update_kb_ids(self, WebApiAuth, add_dialog_func, add_dataset_func): + _, dialog_id = add_dialog_func + new_dataset_id = add_dataset_func + payload = { + "dialog_id": dialog_id, + "kb_ids": [new_dataset_id], + "prompt_config": {"system": "You are a helpful assistant with knowledge: {knowledge}", "parameters": [{"key": "knowledge", "optional": True}]}, + } + res = update_dialog(WebApiAuth, payload) + assert res["code"] == 0, res + assert new_dataset_id in res["data"]["kb_ids"], res + + @pytest.mark.p1 + def test_update_llm_settings(self, WebApiAuth, add_dialog_func): + _, dialog_id = add_dialog_func + new_llm_setting = {"model": "gpt-4", "temperature": 0.9, "max_tokens": 2000} + payload = {"dialog_id": dialog_id, "llm_setting": new_llm_setting, "prompt_config": {"system": "You are a helpful assistant.", "parameters": []}} + res = update_dialog(WebApiAuth, payload) + assert res["code"] == 0, res + assert res["data"]["llm_setting"]["model"] == "gpt-4", res + assert res["data"]["llm_setting"]["temperature"] == 0.9, res + + @pytest.mark.p1 + def test_update_retrieval_settings(self, WebApiAuth, add_dialog_func): + _, dialog_id = add_dialog_func + payload = { + "dialog_id": dialog_id, + "top_n": 15, + "top_k": 4096, + "similarity_threshold": 0.3, + "vector_similarity_weight": 0.7, + "prompt_config": {"system": "You are a helpful assistant.", "parameters": []}, + } + res = update_dialog(WebApiAuth, payload) + assert res["code"] == 0, res + assert res["data"]["top_n"] == 15, res + assert res["data"]["top_k"] == 4096, res + assert res["data"]["similarity_threshold"] == 0.3, res + assert res["data"]["vector_similarity_weight"] == 0.7, res + + @pytest.mark.p2 + def test_update_nonexistent_dialog(self, WebApiAuth): + fake_dialog_id = "nonexistent_dialog_id" + payload = {"dialog_id": fake_dialog_id, "name": "updated_name", "prompt_config": {"system": "You are a helpful assistant.", "parameters": []}} + res = update_dialog(WebApiAuth, payload) + assert res["code"] == 102, res + assert "Dialog not found" in res["message"], res + + @pytest.mark.p2 + def test_update_with_invalid_prompt_config(self, WebApiAuth, add_dialog_func): + _, dialog_id = add_dialog_func + payload = {"dialog_id": dialog_id, "prompt_config": {"system": "You are a helpful assistant.", "parameters": [{"key": "unused_param", "optional": False}]}} + res = update_dialog(WebApiAuth, payload) + assert res["code"] == 102, res + assert "Parameter 'unused_param' is not used" in res["message"], res + + @pytest.mark.p2 + def test_update_with_knowledge_but_no_kb(self, WebApiAuth, add_dialog_func): + _, dialog_id = add_dialog_func + payload = {"dialog_id": dialog_id, "kb_ids": [], "prompt_config": {"system": "You are a helpful assistant with knowledge: {knowledge}", "parameters": [{"key": "knowledge", "optional": True}]}} + res = update_dialog(WebApiAuth, payload) + assert res["code"] == 102, res + assert "Please remove `{knowledge}` in system prompt" in res["message"], res + + @pytest.mark.p2 + def test_update_icon(self, WebApiAuth, add_dialog_func): + _, dialog_id = add_dialog_func + new_icon = "🚀" + payload = {"dialog_id": dialog_id, "icon": new_icon, "prompt_config": {"system": "You are a helpful assistant.", "parameters": []}} + res = update_dialog(WebApiAuth, payload) + assert res["code"] == 0, res + assert res["data"]["icon"] == new_icon, res + + @pytest.mark.p2 + def test_update_rerank_id(self, WebApiAuth, add_dialog_func): + _, dialog_id = add_dialog_func + payload = {"dialog_id": dialog_id, "rerank_id": "test_rerank_model", "prompt_config": {"system": "You are a helpful assistant.", "parameters": []}} + res = update_dialog(WebApiAuth, payload) + assert res["code"] == 0, res + assert res["data"]["rerank_id"] == "test_rerank_model", res + + @pytest.mark.p3 + def test_update_multiple_fields(self, WebApiAuth, add_dialog_func): + _, dialog_id = add_dialog_func + payload = { + "dialog_id": dialog_id, + "name": "multi_update_dialog", + "description": "Updated with multiple fields", + "icon": "🔄", + "top_n": 20, + "similarity_threshold": 0.4, + "prompt_config": {"system": "You are a multi-updated assistant.", "parameters": []}, + } + res = update_dialog(WebApiAuth, payload) + assert res["code"] == 0, res + data = res["data"] + assert data["name"] == "multi_update_dialog", res + assert data["description"] == "Updated with multiple fields", res + assert data["icon"] == "🔄", res + assert data["top_n"] == 20, res + assert data["similarity_threshold"] == 0.4, res diff --git a/web/src/components/knowledge-base-item.tsx b/web/src/components/knowledge-base-item.tsx index 29a3f3dc2..824866231 100644 --- a/web/src/components/knowledge-base-item.tsx +++ b/web/src/components/knowledge-base-item.tsx @@ -67,7 +67,7 @@ export function useDisableDifferenceEmbeddingDataset(name: string) { export function KnowledgeBaseFormField({ showVariable = false, - name = 'dataset_ids', + name = 'kb_ids', required = false, }: { showVariable?: boolean; diff --git a/web/src/hooks/use-chat-request.ts b/web/src/hooks/use-chat-request.ts index 149d26c26..533205a85 100644 --- a/web/src/hooks/use-chat-request.ts +++ b/web/src/hooks/use-chat-request.ts @@ -30,12 +30,10 @@ import { import { useHandleSearchStrChange } from './logic-hooks/use-change-search'; export const enum ChatApiAction { - FetchChatList = 'fetchChatList', - DeleteChat = 'deleteChat', - CreateChat = 'createChat', - UpdateChat = 'updateChat', - PatchChat = 'patchChat', - FetchChat = 'fetchChat', + FetchDialogList = 'fetchDialogList', + RemoveDialog = 'removeDialog', + SetDialog = 'setDialog', + FetchDialog = 'fetchDialog', FetchConversationList = 'fetchConversationList', FetchConversation = 'fetchConversation', FetchConversationManually = 'fetchConversationManually', @@ -62,7 +60,7 @@ export const useGetChatSearchParams = () => { }; }; -export const useFetchChatList = () => { +export const useFetchDialogList = () => { const { searchString, handleInputChange } = useHandleSearchChange(); const { pagination, setPagination } = useGetPaginationWithRouter(); const debouncedSearchString = useDebounce(searchString, { wait: 500 }); @@ -71,19 +69,19 @@ export const useFetchChatList = () => { data, isFetching: loading, refetch, - } = useQuery<{ chats: IDialog[]; total: number }>({ + } = useQuery<{ dialogs: IDialog[]; total: number }>({ queryKey: [ - ChatApiAction.FetchChatList, + ChatApiAction.FetchDialogList, { debouncedSearchString, ...pagination, }, ], - initialData: { chats: [], total: 0 }, + initialData: { dialogs: [], total: 0 }, gcTime: 0, refetchOnWindowFocus: false, queryFn: async () => { - const { data } = await chatService.listChats( + const { data } = await chatService.listDialog( { params: { keywords: debouncedSearchString, @@ -95,7 +93,7 @@ export const useFetchChatList = () => { true, ); - return data?.data ?? { chats: [], total: 0 }; + return data?.data ?? { dialogs: [], total: 0 }; }, }); @@ -117,7 +115,7 @@ export const useFetchChatList = () => { }; }; -export const useDeleteChat = () => { +export const useRemoveDialog = () => { const queryClient = useQueryClient(); const { t } = useTranslation(); @@ -126,23 +124,22 @@ export const useDeleteChat = () => { isPending: loading, mutateAsync, } = useMutation({ - mutationKey: [ChatApiAction.DeleteChat], - mutationFn: async (chatId: string) => { - const { data } = await chatService.deleteChat(chatId); + mutationKey: [ChatApiAction.RemoveDialog], + mutationFn: async (dialogIds: string[]) => { + const { data } = await chatService.removeDialog({ dialogIds }); if (data.code === 0) { - queryClient.invalidateQueries({ - queryKey: [ChatApiAction.FetchChatList], - }); + queryClient.invalidateQueries({ queryKey: ['fetchDialogList'] }); + message.success(t('message.deleted')); } return data.code; }, }); - return { data, loading, deleteChat: mutateAsync }; + return { data, loading, removeDialog: mutateAsync }; }; -export const useCreateChat = () => { +export const useSetDialog = () => { const queryClient = useQueryClient(); const { t } = useTranslation(); @@ -151,96 +148,31 @@ export const useCreateChat = () => { isPending: loading, mutateAsync, } = useMutation({ - mutationKey: [ChatApiAction.CreateChat], - mutationFn: async (params: Record) => { - const { data } = await chatService.createChat(params); + mutationKey: [ChatApiAction.SetDialog], + mutationFn: async (params: Partial) => { + const { data } = await chatService.setDialog(params); if (data.code === 0) { queryClient.invalidateQueries({ exact: false, - queryKey: [ChatApiAction.FetchChatList], + queryKey: [ChatApiAction.FetchDialogList], }); - message.success(t('message.created')); + + queryClient.invalidateQueries({ + queryKey: [ChatApiAction.FetchDialog], + }); + + message.success( + t(`message.${params.dialog_id ? 'modified' : 'created'}`), + ); } return data?.code; }, }); - return { data, loading, createChat: mutateAsync }; + return { data, loading, setDialog: mutateAsync }; }; -export const useUpdateChat = () => { - const queryClient = useQueryClient(); - const { t } = useTranslation(); - - const { - data, - isPending: loading, - mutateAsync, - } = useMutation({ - mutationKey: [ChatApiAction.UpdateChat], - mutationFn: async ({ - chatId, - params, - }: { - chatId: string; - params: Record; - }) => { - const { data } = await chatService.updateChat( - { url: api.updateChat(chatId), data: params }, - true, - ); - if (data.code === 0) { - queryClient.invalidateQueries({ - exact: false, - queryKey: [ChatApiAction.FetchChatList], - }); - queryClient.invalidateQueries({ queryKey: [ChatApiAction.FetchChat] }); - message.success(t('message.modified')); - } - return data?.code; - }, - }); - - return { data, loading, updateChat: mutateAsync }; -}; - -export const usePatchChat = () => { - const queryClient = useQueryClient(); - const { t } = useTranslation(); - - const { - data, - isPending: loading, - mutateAsync, - } = useMutation({ - mutationKey: [ChatApiAction.PatchChat], - mutationFn: async ({ - chatId, - params, - }: { - chatId: string; - params: Record; - }) => { - const { data } = await chatService.patchChat( - { url: api.patchChat(chatId), data: params }, - true, - ); - if (data.code === 0) { - queryClient.invalidateQueries({ - exact: false, - queryKey: [ChatApiAction.FetchChatList], - }); - queryClient.invalidateQueries({ queryKey: [ChatApiAction.FetchChat] }); - message.success(t('message.modified')); - } - return data?.code; - }, - }); - - return { data, loading, patchChat: mutateAsync }; -}; - -export const useFetchChat = () => { +export const useFetchDialog = () => { const { id } = useParams(); const { @@ -248,13 +180,17 @@ export const useFetchChat = () => { isFetching: loading, refetch, } = useQuery({ - queryKey: [ChatApiAction.FetchChat, id], + queryKey: [ChatApiAction.FetchDialog, id], gcTime: 0, initialData: {} as IDialog, enabled: !!id, refetchOnWindowFocus: false, queryFn: async () => { - const { data } = await chatService.getChat(id); + const { data } = await chatService.getDialog( + { params: { dialogId: id } }, + true, + ); + return data?.data ?? ({} as IDialog); }, }); diff --git a/web/src/interfaces/database/chat.ts b/web/src/interfaces/database/chat.ts index cb879456e..e8f5f2b24 100644 --- a/web/src/interfaces/database/chat.ts +++ b/web/src/interfaces/database/chat.ts @@ -14,7 +14,6 @@ export interface PromptConfig { reasoning?: boolean; cross_languages?: Array; tavily_api_key?: string; - toc_enhance?: boolean; } export interface Parameter { @@ -35,8 +34,8 @@ export interface Variable { presence_penalty?: number; temperature?: number; top_p?: number; + llm_id?: string; tenant_llm_id?: string; - model_type?: string; } export interface IDialog { @@ -45,14 +44,14 @@ export interface IDialog { description: string; icon: string; id: string; - dialog_id?: string; - dataset_ids: string[]; + dialog_id: string; + kb_ids: string[]; kb_names: string[]; language: string; llm_id: string; tenant_llm_id?: string; llm_setting: Variable; - llm_setting_type?: string; + llm_setting_type: string; name: string; prompt_config: PromptConfig; prompt_type: string; @@ -64,7 +63,6 @@ export interface IDialog { similarity_threshold: number; top_k: number; top_n: number; - rerank_id?: string; meta_data_filter: MetaDataFilter; } diff --git a/web/src/pages/home/chat-list.tsx b/web/src/pages/home/chat-list.tsx index c6d5661a3..c53ea1708 100644 --- a/web/src/pages/home/chat-list.tsx +++ b/web/src/pages/home/chat-list.tsx @@ -2,7 +2,7 @@ import { HomeCard } from '@/components/home-card'; import { MoreButton } from '@/components/more-button'; import { RenameDialog } from '@/components/rename-dialog'; import { useNavigatePage } from '@/hooks/logic-hooks/navigate-hooks'; -import { useFetchChatList } from '@/hooks/use-chat-request'; +import { useFetchDialogList } from '@/hooks/use-chat-request'; import { useEffect } from 'react'; import { useTranslation } from 'react-i18next'; import { ChatDropdown } from '../next-chats/chat-dropdown'; @@ -16,7 +16,7 @@ export function ChatList({ setLoading?: (loading: boolean) => void; }) { const { t } = useTranslation(); - const { data, loading } = useFetchChatList(); + const { data, loading } = useFetchDialogList(); const { navigateToChat } = useNavigatePage(); const { @@ -28,12 +28,12 @@ export function ChatList({ chatRenameLoading, } = useRenameChat(); useEffect(() => { - setListLength(data?.chats?.length || 0); + setListLength(data?.dialogs?.length || 0); setLoading?.(loading || false); }, [data, setListLength, loading, setLoading]); return ( <> - {data.chats.slice(0, 10).map((x) => ( + {data.dialogs.slice(0, 10).map((x) => ( = useCallback( @@ -37,8 +37,8 @@ export function ChatDropdown({ ); const handleDelete: MouseEventHandler = useCallback(() => { - deleteChat(chat.id); - }, [chat.id, deleteChat]); + removeDialog([chat.id]); + }, [chat.id, removeDialog]); return ( diff --git a/web/src/pages/next-chats/chat/app-settings/chat-settings.tsx b/web/src/pages/next-chats/chat/app-settings/chat-settings.tsx index e6a908055..14cd0da62 100644 --- a/web/src/pages/next-chats/chat/app-settings/chat-settings.tsx +++ b/web/src/pages/next-chats/chat/app-settings/chat-settings.tsx @@ -4,7 +4,7 @@ import { ScrollArea } from '@/components/ui/scroll-area'; import { Separator } from '@/components/ui/separator'; import { DatasetMetadata } from '@/constants/chat'; import { useSetModalState } from '@/hooks/common-hooks'; -import { useFetchChat, useUpdateChat } from '@/hooks/use-chat-request'; +import { useFetchDialog, useSetDialog } from '@/hooks/use-chat-request'; import { cn } from '@/lib/utils'; import { removeUselessFieldsFromValues, @@ -28,8 +28,8 @@ type ChatSettingsProps = { hasSingleChatBox: boolean }; export function ChatSettings({ hasSingleChatBox }: ChatSettingsProps) { const formSchema = useChatSettingSchema(); - const { data } = useFetchChat(); - const { updateChat, loading } = useUpdateChat(); + const { data } = useFetchDialog(); + const { setDialog, loading } = useSetDialog(); const { id } = useParams(); const { t } = useTranslation(); @@ -45,7 +45,7 @@ export function ChatSettings({ hasSingleChatBox }: ChatSettingsProps) { name: '', icon: '', description: '', - dataset_ids: [], + kb_ids: [], prompt_config: { quote: true, keyword: false, @@ -75,32 +75,22 @@ export function ChatSettings({ hasSingleChatBox }: ChatSettingsProps) { 'llm_setting.', ); - updateChat({ - chatId: id!, - params: { - ...omit(data, [ - 'operator_permission', - 'tenant_id', - 'created_by', - 'create_time', - 'create_date', - 'update_time', - 'update_date', - 'id', - ]), - ...nextValues, - }, + setDialog({ + ...omit(data, 'operator_permission'), + ...nextValues, + dialog_id: id, }); } function onInvalid(errors: any) { - void errors; + console.log('Form validation failed:', errors); } useEffect(() => { const llmSettingEnabledValues = setLLMSettingEnabledValues( data.llm_setting, ); + const nextData = { ...data, ...llmSettingEnabledValues, diff --git a/web/src/pages/next-chats/chat/app-settings/use-chat-setting-schema.tsx b/web/src/pages/next-chats/chat/app-settings/use-chat-setting-schema.tsx index ba29383f9..f4d96b999 100644 --- a/web/src/pages/next-chats/chat/app-settings/use-chat-setting-schema.tsx +++ b/web/src/pages/next-chats/chat/app-settings/use-chat-setting-schema.tsx @@ -42,7 +42,7 @@ export function useChatSettingSchema() { name: z.string().min(1, { message: t('assistantNameMessage') }), icon: z.string(), description: z.string().optional(), - dataset_ids: z.array(z.string()).min(0, { + kb_ids: z.array(z.string()).min(0, { message: t('knowledgeBasesMessage'), }), prompt_config: promptConfigSchema, diff --git a/web/src/pages/next-chats/chat/chat-box/next-multiple-chat-box.tsx b/web/src/pages/next-chats/chat/chat-box/next-multiple-chat-box.tsx index 089a8face..d4efa5539 100644 --- a/web/src/pages/next-chats/chat/chat-box/next-multiple-chat-box.tsx +++ b/web/src/pages/next-chats/chat/chat-box/next-multiple-chat-box.tsx @@ -21,9 +21,9 @@ import { useScrollToBottom, } from '@/hooks/logic-hooks'; import { - useFetchChat, + useFetchDialog, useGetChatSearchParams, - usePatchChat, + useSetDialog, } from '@/hooks/use-chat-request'; import { useFetchUserInfo } from '@/hooks/use-user-setting-request'; import { IClientConversation } from '@/interfaces/database/chat'; @@ -102,7 +102,7 @@ const ChatCard = forwardRef(function ChatCard( ref, ) { const { id: dialogId } = useParams(); - const { patchChat } = usePatchChat(); + const { setDialog } = useSetDialog(); const { removeMessageById, derivedMessages, handlePressEnter, sendLoading } = useSendSingleMessage({ @@ -131,7 +131,7 @@ const ChatCard = forwardRef(function ChatCard( const llmId = useWatch({ control: form.control, name: 'llm_id' }); const { data: userInfo } = useFetchUserInfo(); - const { data: currentDialog } = useFetchChat(); + const { data: currentDialog } = useFetchDialog(); useSetDefaultModel(form); @@ -143,15 +143,13 @@ const ChatCard = forwardRef(function ChatCard( const handleApplyConfig = useCallback(() => { const values = form.getValues(); - patchChat({ - chatId: dialogId!, - params: { - ...currentDialog, - llm_id: values.llm_id, - llm_setting: omit(values, 'llm_id'), - }, + setDialog({ + ...currentDialog, + llm_id: values.llm_id, + llm_setting: omit(values, 'llm_id'), + dialog_id: dialogId, }); - }, [currentDialog, dialogId, form, patchChat]); + }, [currentDialog, dialogId, form, setDialog]); useImperativeHandle( ref, diff --git a/web/src/pages/next-chats/chat/chat-box/single-chat-box.tsx b/web/src/pages/next-chats/chat/chat-box/single-chat-box.tsx index 737aaf059..45aa7f25a 100644 --- a/web/src/pages/next-chats/chat/chat-box/single-chat-box.tsx +++ b/web/src/pages/next-chats/chat/chat-box/single-chat-box.tsx @@ -3,7 +3,10 @@ import MessageItem from '@/components/message-item'; import PdfSheet from '@/components/pdf-drawer'; import { useClickDrawer } from '@/components/pdf-drawer/hooks'; import { MessageType } from '@/constants/chat'; -import { useFetchChat, useGetChatSearchParams } from '@/hooks/use-chat-request'; +import { + useFetchDialog, + useGetChatSearchParams, +} from '@/hooks/use-chat-request'; import { useFetchUserInfo } from '@/hooks/use-user-setting-request'; import { IClientConversation } from '@/interfaces/database/chat'; import { buildMessageUuidWithRole } from '@/utils/chat'; @@ -44,7 +47,7 @@ export function SingleChatBox({ setDerivedMessages, } = useSendMessage(controller); const { data: userInfo } = useFetchUserInfo(); - const { data: currentDialog } = useFetchChat(); + const { data: currentDialog } = useFetchDialog(); const { createConversationBeforeUploadDocument } = useCreateConversationBeforeUploadDocument(); const { conversationId } = useGetChatSearchParams(); diff --git a/web/src/pages/next-chats/chat/llm-select-form.tsx b/web/src/pages/next-chats/chat/llm-select-form.tsx index ff9d839f4..e42427dd4 100644 --- a/web/src/pages/next-chats/chat/llm-select-form.tsx +++ b/web/src/pages/next-chats/chat/llm-select-form.tsx @@ -1,7 +1,7 @@ import { LargeModelFormFieldWithoutFilter } from '@/components/large-model-form-field'; import { LlmSettingSchema } from '@/components/llm-setting-items/next'; import { Form } from '@/components/ui/form'; -import { useFetchChat } from '@/hooks/use-chat-request'; +import { useFetchDialog } from '@/hooks/use-chat-request'; import { zodResolver } from '@hookform/resolvers/zod'; import { isEmpty } from 'lodash'; import { useEffect } from 'react'; @@ -10,7 +10,7 @@ import { z } from 'zod'; export function LLMSelectForm() { const FormSchema = z.object(LlmSettingSchema); - const { data } = useFetchChat(); + const { data } = useFetchDialog(); const form = useForm>({ resolver: zodResolver(FormSchema), @@ -25,6 +25,7 @@ export function LLMSelectForm() { if (!isEmpty(data)) { form.reset({ llm_id: data.llm_id, ...data.llm_setting }); } + form.reset(data); }, [data, form]); return ( diff --git a/web/src/pages/next-chats/chat/sessions.tsx b/web/src/pages/next-chats/chat/sessions.tsx index f37976d3e..b4e2b9e68 100644 --- a/web/src/pages/next-chats/chat/sessions.tsx +++ b/web/src/pages/next-chats/chat/sessions.tsx @@ -14,7 +14,7 @@ import { import { SharedFrom } from '@/constants/chat'; import { useSetModalState } from '@/hooks/common-hooks'; import { - useFetchChat, + useFetchDialog, useGetChatSearchParams, useRemoveConversation, } from '@/hooks/use-chat-request'; @@ -48,7 +48,7 @@ export function Sessions({ handleConversationCardClick }: SessionProps) { handleInputChange, searchString, } = useSelectDerivedConversationList(); - const { data } = useFetchChat(); + const { data } = useFetchDialog(); const { visible, switchVisible } = useSetModalState(true); const { removeConversation } = useRemoveConversation(); const { setConversationBoth } = useChatUrlParams(); diff --git a/web/src/pages/next-chats/chat/use-show-internet.ts b/web/src/pages/next-chats/chat/use-show-internet.ts index 64ac48c20..767c31f5b 100644 --- a/web/src/pages/next-chats/chat/use-show-internet.ts +++ b/web/src/pages/next-chats/chat/use-show-internet.ts @@ -1,8 +1,8 @@ -import { useFetchChat } from '@/hooks/use-chat-request'; +import { useFetchDialog } from '@/hooks/use-chat-request'; import { isEmpty } from 'lodash'; export function useShowInternet() { - const { data: currentDialog } = useFetchChat(); + const { data: currentDialog } = useFetchDialog(); return !isEmpty(currentDialog?.prompt_config?.tavily_api_key); } diff --git a/web/src/pages/next-chats/hooks/use-rename-chat.ts b/web/src/pages/next-chats/hooks/use-rename-chat.ts index 705a5ef01..bf8fc2fe6 100644 --- a/web/src/pages/next-chats/hooks/use-rename-chat.ts +++ b/web/src/pages/next-chats/hooks/use-rename-chat.ts @@ -1,8 +1,8 @@ import { useSetModalState } from '@/hooks/common-hooks'; -import { useCreateChat, usePatchChat } from '@/hooks/use-chat-request'; +import { useSetDialog } from '@/hooks/use-chat-request'; import { useFetchTenantInfo } from '@/hooks/use-user-setting-request'; import { IDialog } from '@/interfaces/database/chat'; -import { isEmpty } from 'lodash'; +import { isEmpty, omit } from 'lodash'; import { useCallback, useMemo, useState } from 'react'; import { useTranslation } from 'react-i18next'; @@ -13,8 +13,7 @@ export const useRenameChat = () => { hideModal: hideChatRenameModal, showModal: showChatRenameModal, } = useSetModalState(); - const { createChat, loading: createLoading } = useCreateChat(); - const { patchChat, loading: patchLoading } = usePatchChat(); + const { setDialog, loading } = useSetDialog(); const { t } = useTranslation(); const tenantInfo = useFetchTenantInfo(); @@ -24,7 +23,6 @@ export const useRenameChat = () => { icon: '', language: 'English', description: '', - dataset_ids: [], prompt_config: { empty_response: '', prologue: t('chat.setAnOpenerInitial'), @@ -43,28 +41,28 @@ export const useRenameChat = () => { similarity_threshold: 0.2, vector_similarity_weight: 0.3, top_n: 8, - top_k: 1024, }), [t, tenantInfo.data.llm_id], ); const onChatRenameOk = useCallback( async (name: string) => { - let ret: number | undefined; - if (isEmpty(chat)) { - ret = await createChat({ ...InitialData, name }); - } else { - ret = await patchChat({ - chatId: chat.id, - params: { name }, - }); - } + const nextChat = { + ...(isEmpty(chat) + ? InitialData + : { + ...omit(chat, 'nickname', 'tenant_avatar', 'operator_permission'), + dialog_id: chat.id, + }), + name, + }; + const ret = await setDialog(nextChat); if (ret === 0) { hideChatRenameModal(); } }, - [chat, InitialData, createChat, patchChat, hideChatRenameModal], + [chat, InitialData, setDialog, hideChatRenameModal], ); const handleShowChatRenameModal = useCallback( @@ -85,7 +83,7 @@ export const useRenameChat = () => { }, [hideChatRenameModal]); return { - chatRenameLoading: createLoading || patchLoading, + chatRenameLoading: loading, initialChatName: chat?.name, onChatRenameOk, chatRenameVisible, diff --git a/web/src/pages/next-chats/hooks/use-select-conversation-list.ts b/web/src/pages/next-chats/hooks/use-select-conversation-list.ts index b553dcbc4..89fd5a447 100644 --- a/web/src/pages/next-chats/hooks/use-select-conversation-list.ts +++ b/web/src/pages/next-chats/hooks/use-select-conversation-list.ts @@ -1,8 +1,8 @@ import { MessageType } from '@/constants/chat'; import { useTranslate } from '@/hooks/common-hooks'; import { - useFetchChatList, useFetchConversationList, + useFetchDialogList, } from '@/hooks/use-chat-request'; import { IConversation } from '@/interfaces/database/chat'; import { generateConversationId } from '@/utils/chat'; @@ -12,10 +12,10 @@ import { useChatUrlParams } from './use-chat-url'; export const useFindPrologueFromDialogList = () => { const { id: dialogId } = useParams(); - const { data } = useFetchChatList(); + const { data } = useFetchDialogList(); const prologue = useMemo(() => { - return data.chats.find((x) => x.id === dialogId)?.prompt_config?.prologue; + return data.dialogs.find((x) => x.id === dialogId)?.prompt_config.prologue; }, [dialogId, data]); return prologue; diff --git a/web/src/pages/next-chats/index.tsx b/web/src/pages/next-chats/index.tsx index fb49f0024..a6d764c53 100644 --- a/web/src/pages/next-chats/index.tsx +++ b/web/src/pages/next-chats/index.tsx @@ -5,7 +5,7 @@ import ListFilterBar from '@/components/list-filter-bar'; import { RenameDialog } from '@/components/rename-dialog'; import { Button } from '@/components/ui/button'; import { RAGFlowPagination } from '@/components/ui/ragflow-pagination'; -import { useFetchChatList } from '@/hooks/use-chat-request'; +import { useFetchDialogList } from '@/hooks/use-chat-request'; import { pick } from 'lodash'; import { Plus } from 'lucide-react'; import { useCallback, useEffect } from 'react'; @@ -16,7 +16,7 @@ import { useRenameChat } from './hooks/use-rename-chat'; export default function ChatList() { const { data, setPagination, pagination, handleInputChange, searchString } = - useFetchChatList(); + useFetchDialogList(); const { t } = useTranslation(); const { initialChatName, @@ -50,7 +50,7 @@ export default function ChatList() { return ( <> - {data.chats?.length || searchString ? ( + {data.dialogs?.length || searchString ? (
- {data.chats?.length ? ( + {data.dialogs?.length ? ( <> - {data.chats.map((x) => ( + {data.dialogs.map((x) => ( `${ExternalApi}${api_host}/chats/${chatId}`, - updateChat: (chatId: string) => `${ExternalApi}${api_host}/chats/${chatId}`, - patchChat: (chatId: string) => `${ExternalApi}${api_host}/chats/${chatId}`, - deleteChat: (chatId: string) => `${ExternalApi}${api_host}/chats/${chatId}`, - bulkDeleteChats: `${ExternalApi}${api_host}/chats`, + setDialog: `${api_host}/dialog/set`, + getDialog: `${api_host}/dialog/get`, + removeDialog: `${api_host}/dialog/rm`, + listDialog: `${api_host}/dialog/list`, setConversation: `${api_host}/conversation/set`, getConversation: `${api_host}/conversation/get`, getConversationSSE: (dialogId: string) => @@ -159,6 +156,7 @@ export default { uploadAndParseExternal: `${api_host}/api/document/upload_and_parse`, // next chat + listNextDialog: `${api_host}/dialog/next`, fetchExternalChatInfo: (id: string) => `${ExternalApi}${api_host}/chatbots/${id}/info`, diff --git a/web/src/utils/llm-util.ts b/web/src/utils/llm-util.ts index 6086e8fac..b267642ed 100644 --- a/web/src/utils/llm-util.ts +++ b/web/src/utils/llm-util.ts @@ -79,7 +79,7 @@ const modelParamMap: ModelParamMap = { // API endpoint whitelist - only these endpoints will have tenant parameters added const API_WHITELIST = [ '/v1/user/set_tenant_info', - '/api/v1/chats', + '/v1/dialog/set', '/v1/canvas/set', '/v1/canvas/setting', '/api/v1/searches/',