diff --git a/admin/client/ragflow_client.py b/admin/client/ragflow_client.py index cf64d8c4f..965f4ffaa 100644 --- a/admin/client/ragflow_client.py +++ b/admin/client/ragflow_client.py @@ -15,7 +15,6 @@ # import json import time -import uuid from typing import Any, List, Optional import multiprocessing as mp from concurrent.futures import ProcessPoolExecutor, as_completed @@ -1130,7 +1129,7 @@ class RAGFlowClient: def _list_chat_sessions(self, dialog_id): """List all sessions (conversations) for a given dialog.""" - response = self.http_client.request("GET", f"/conversation/list?dialog_id={dialog_id}", use_api_base=False, + response = self.http_client.request("GET", f"/chats/{dialog_id}/conversations", use_api_base=True, auth_kind="web") res_json = response.json() if response.status_code == 200 and res_json["code"] == 0: @@ -1146,14 +1145,9 @@ class RAGFlowClient: dialog_id = self._get_chat_id_by_name(chat_name) if dialog_id is None: return - conversation_id = str(uuid.uuid4()).replace("-", "") - payload = { - "conversation_id": conversation_id, - "is_new": True, - "dialog_id": dialog_id - } - response = self.http_client.request("POST", "/conversation/set", json_body=payload, use_api_base=False, - auth_kind="web") + payload = {"name": "New conversation"} + response = self.http_client.request("POST", f"/chats/{dialog_id}/conversations", json_body=payload, + use_api_base=True, auth_kind="web") res_json = response.json() if response.status_code == 200 and res_json["code"] == 0: print(f"Success to create chat session for chat: {chat_name}") @@ -1179,9 +1173,9 @@ class RAGFlowClient: if not to_drop_session_ids: print(f"Chat session '{session_id}' not found in chat '{chat_name}'") return - payload = {"conversation_ids": to_drop_session_ids} - response = self.http_client.request("POST", "/conversation/rm", json_body=payload, use_api_base=False, - auth_kind="web") + payload = {"ids": to_drop_session_ids} + response = self.http_client.request("DELETE", f"/chats/{dialog_id}/conversations", json_body=payload, + use_api_base=True, auth_kind="web") res_json = response.json() if response.status_code == 200 and res_json["code"] == 0: print(f"Success to drop chat session '{session_id}' from chat: {chat_name}") diff --git a/api/apps/conversation_app.py b/api/apps/conversation_app.py deleted file mode 100644 index f7518d225..000000000 --- a/api/apps/conversation_app.py +++ /dev/null @@ -1,479 +0,0 @@ -# -# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import json -import os -import re -import logging -from copy import deepcopy -import tempfile -from quart import Response, request -from api.apps import current_user, login_required -from api.db.db_models import APIToken -from api.db.services.conversation_service import ConversationService, structure_answer -from api.db.services.dialog_service import DialogService, async_ask, async_chat, gen_mindmap -from api.db.services.llm_service import LLMBundle -from api.db.services.search_service import SearchService -from api.db.services.tenant_llm_service import TenantLLMService -from api.db.services.user_service import UserTenantService -from api.db.joint_services.tenant_model_service import get_model_config_by_type_and_name, get_tenant_default_model_by_type -from api.utils.api_utils import get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request -from rag.prompts.template import load_prompt -from rag.prompts.generator import chunks_format -from common.constants import RetCode, LLMType - - -@manager.route("/set", methods=["POST"]) # noqa: F821 -@login_required -async def set_conversation(): - req = await get_request_json() - conv_id = req.get("conversation_id") - is_new = req.get("is_new") - name = req.get("name", "New conversation") - req["user_id"] = current_user.id - - if len(name) > 255: - name = name[0:255] - - del req["is_new"] - if not is_new: - del req["conversation_id"] - try: - if not ConversationService.update_by_id(conv_id, req): - return get_data_error_result(message="Conversation not found!") - e, conv = ConversationService.get_by_id(conv_id) - if not e: - return get_data_error_result(message="Fail to update a conversation!") - conv = conv.to_dict() - return get_json_result(data=conv) - except Exception as e: - return server_error_response(e) - - try: - e, dia = DialogService.get_by_id(req["dialog_id"]) - if not e: - return get_data_error_result(message="Dialog not found") - conv = { - "id": conv_id, - "dialog_id": req["dialog_id"], - "name": name, - "message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}], - "user_id": current_user.id, - "reference": [], - } - ConversationService.save(**conv) - return get_json_result(data=conv) - except Exception as e: - return server_error_response(e) - - -@manager.route("/get", methods=["GET"]) # noqa: F821 -@login_required -async def get(): - conv_id = request.args["conversation_id"] - try: - e, conv = ConversationService.get_by_id(conv_id) - if not e: - return get_data_error_result(message="Conversation not found!") - tenants = UserTenantService.query(user_id=current_user.id) - for tenant in tenants: - dialog = DialogService.query(tenant_id=tenant.tenant_id, id=conv.dialog_id) - if dialog and len(dialog) > 0: - avatar = dialog[0].icon - break - else: - return get_json_result(data=False, message="Only owner of conversation authorized for this operation.", code=RetCode.OPERATING_ERROR) - - for ref in conv.reference: - if isinstance(ref, list): - continue - ref["chunks"] = chunks_format(ref) - - conv = conv.to_dict() - conv["avatar"] = avatar - return get_json_result(data=conv) - except Exception as e: - return server_error_response(e) - - -@manager.route("/getsse/", methods=["GET"]) # type: ignore # noqa: F821 -def getsse(dialog_id): - token = request.headers.get("Authorization").split() - if len(token) != 2: - return get_data_error_result(message='Authorization is not valid!') - token = token[1] - objs = APIToken.query(beta=token) - if not objs: - return get_data_error_result(message='Authentication error: API key is invalid!"') - try: - e, conv = DialogService.get_by_id(dialog_id) - if not e: - return get_data_error_result(message="Dialog not found!") - conv = conv.to_dict() - conv["avatar"] = conv["icon"] - del conv["icon"] - return get_json_result(data=conv) - except Exception as e: - return server_error_response(e) - - -@manager.route("/rm", methods=["POST"]) # noqa: F821 -@login_required -async def rm(): - req = await get_request_json() - conv_ids = req["conversation_ids"] - try: - for cid in conv_ids: - exist, conv = ConversationService.get_by_id(cid) - if not exist: - return get_data_error_result(message="Conversation not found!") - tenants = UserTenantService.query(user_id=current_user.id) - for tenant in tenants: - if DialogService.query(tenant_id=tenant.tenant_id, id=conv.dialog_id): - break - else: - return get_json_result(data=False, message="Only owner of conversation authorized for this operation.", code=RetCode.OPERATING_ERROR) - ConversationService.delete_by_id(cid) - return get_json_result(data=True) - except Exception as e: - return server_error_response(e) - - -@manager.route("/list", methods=["GET"]) # noqa: F821 -@login_required -async def list_conversation(): - dialog_id = request.args["dialog_id"] - try: - if not DialogService.query(tenant_id=current_user.id, id=dialog_id): - return get_json_result(data=False, message="Only owner of dialog authorized for this operation.", code=RetCode.OPERATING_ERROR) - convs = ConversationService.query(dialog_id=dialog_id, order_by=ConversationService.model.create_time, reverse=True) - - convs = [d.to_dict() for d in convs] - return get_json_result(data=convs) - except Exception as e: - return server_error_response(e) - - -@manager.route("/completion", methods=["POST"]) # noqa: F821 -@login_required -@validate_request("conversation_id", "messages") -async def completion(): - req = await get_request_json() - msg = [] - for m in req["messages"]: - if m["role"] == "system": - continue - if m["role"] == "assistant" and not msg: - continue - msg.append(m) - message_id = msg[-1].get("id") - chat_model_id = req.get("llm_id", "") - req.pop("llm_id", None) - - chat_model_config = {} - for model_config in [ - "temperature", - "top_p", - "frequency_penalty", - "presence_penalty", - "max_tokens", - ]: - config = req.get(model_config) - if config: - chat_model_config[model_config] = config - - try: - e, conv = ConversationService.get_by_id(req["conversation_id"]) - if not e: - return get_data_error_result(message="Conversation not found!") - conv.message = deepcopy(req["messages"]) - e, dia = DialogService.get_by_id(conv.dialog_id) - if not e: - return get_data_error_result(message="Dialog not found!") - del req["conversation_id"] - del req["messages"] - - if not conv.reference: - conv.reference = [] - conv.reference = [r for r in conv.reference if r] - conv.reference.append({"chunks": [], "doc_aggs": []}) - - if chat_model_id: - if not TenantLLMService.get_api_key(tenant_id=dia.tenant_id, model_name=chat_model_id): - req.pop("chat_model_id", None) - req.pop("chat_model_config", None) - return get_data_error_result(message=f"Cannot use specified model {chat_model_id}.") - dia.llm_id = chat_model_id - dia.llm_setting = chat_model_config - - is_embedded = bool(chat_model_id) - # Remove stream from req to avoid duplicate argument error - stream_mode = req.pop("stream", True) - async def stream(): - nonlocal dia, msg, req, conv - try: - async for ans in async_chat(dia, msg, True, **req): - ans = structure_answer(conv, ans, message_id, conv.id) - yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n" - if not is_embedded: - ConversationService.update_by_id(conv.id, conv.to_dict()) - except Exception as e: - logging.exception(e) - yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n" - yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n" - - if stream_mode: - resp = Response(stream(), mimetype="text/event-stream") - resp.headers.add_header("Cache-control", "no-cache") - resp.headers.add_header("Connection", "keep-alive") - resp.headers.add_header("X-Accel-Buffering", "no") - resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8") - return resp - - else: - answer = None - async for ans in async_chat(dia, msg, **req): - answer = structure_answer(conv, ans, message_id, conv.id) - if not is_embedded: - ConversationService.update_by_id(conv.id, conv.to_dict()) - break - return get_json_result(data=answer) - except Exception as e: - return server_error_response(e) - -@manager.route("/sequence2txt", methods=["POST"]) # noqa: F821 -@login_required -async def sequence2txt(): - req = await request.form - stream_mode = req.get("stream", "false").lower() == "true" - files = await request.files - if "file" not in files: - return get_data_error_result(message="Missing 'file' in multipart form-data") - - uploaded = files["file"] - - ALLOWED_EXTS = { - ".wav", ".mp3", ".m4a", ".aac", - ".flac", ".ogg", ".webm", - ".opus", ".wma" - } - - filename = uploaded.filename or "" - suffix = os.path.splitext(filename)[-1].lower() - if suffix not in ALLOWED_EXTS: - return get_data_error_result(message= - f"Unsupported audio format: {suffix}. " - f"Allowed: {', '.join(sorted(ALLOWED_EXTS))}" - ) - fd, temp_audio_path = tempfile.mkstemp(suffix=suffix) - os.close(fd) - await uploaded.save(temp_audio_path) - - try: - default_asr_model_config = get_tenant_default_model_by_type(current_user.id, LLMType.SPEECH2TEXT) - except Exception as e: - return get_data_error_result(message=str(e)) - - asr_mdl=LLMBundle(current_user.id, default_asr_model_config) - if not stream_mode: - text = asr_mdl.transcription(temp_audio_path) - try: - os.remove(temp_audio_path) - except Exception as e: - logging.error(f"Failed to remove temp audio file: {str(e)}") - return get_json_result(data={"text": text}) - async def event_stream(): - try: - for evt in asr_mdl.stream_transcription(temp_audio_path): - yield f"data: {json.dumps(evt, ensure_ascii=False)}\n\n" - except Exception as e: - err = {"event": "error", "text": str(e)} - yield f"data: {json.dumps(err, ensure_ascii=False)}\n\n" - finally: - try: - os.remove(temp_audio_path) - except Exception as e: - logging.error(f"Failed to remove temp audio file: {str(e)}") - - return Response(event_stream(), content_type="text/event-stream") - -@manager.route("/tts", methods=["POST"]) # noqa: F821 -@login_required -async def tts(): - req = await get_request_json() - text = req["text"] - - try: - default_tts_model_config = get_tenant_default_model_by_type(current_user.id, LLMType.TTS) - except Exception as e: - return get_data_error_result(message=str(e)) - - tts_mdl = LLMBundle(current_user.id, default_tts_model_config) - - def stream_audio(): - try: - for txt in re.split(r"[,。/《》?;:!\n\r:;]+", text): - for chunk in tts_mdl.tts(txt): - yield chunk - except Exception as e: - yield ("data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e)}}, ensure_ascii=False)).encode("utf-8") - - resp = Response(stream_audio(), mimetype="audio/mpeg") - resp.headers.add_header("Cache-Control", "no-cache") - resp.headers.add_header("Connection", "keep-alive") - resp.headers.add_header("X-Accel-Buffering", "no") - - return resp - - -@manager.route("/delete_msg", methods=["POST"]) # noqa: F821 -@login_required -@validate_request("conversation_id", "message_id") -async def delete_msg(): - req = await get_request_json() - e, conv = ConversationService.get_by_id(req["conversation_id"]) - if not e: - return get_data_error_result(message="Conversation not found!") - - conv = conv.to_dict() - for i, msg in enumerate(conv["message"]): - if req["message_id"] != msg.get("id", ""): - continue - assert conv["message"][i + 1]["id"] == req["message_id"] - conv["message"].pop(i) - conv["message"].pop(i) - conv["reference"].pop(max(0, i // 2 - 1)) - break - - ConversationService.update_by_id(conv["id"], conv) - return get_json_result(data=conv) - - -@manager.route("/thumbup", methods=["POST"]) # noqa: F821 -@login_required -@validate_request("conversation_id", "message_id") -async def thumbup(): - req = await get_request_json() - e, conv = ConversationService.get_by_id(req["conversation_id"]) - if not e: - return get_data_error_result(message="Conversation not found!") - up_down = req.get("thumbup") - feedback = req.get("feedback", "") - conv = conv.to_dict() - for i, msg in enumerate(conv["message"]): - if req["message_id"] == msg.get("id", "") and msg.get("role", "") == "assistant": - if up_down: - msg["thumbup"] = True - if "feedback" in msg: - del msg["feedback"] - else: - msg["thumbup"] = False - if feedback: - msg["feedback"] = feedback - break - - ConversationService.update_by_id(conv["id"], conv) - return get_json_result(data=conv) - - -@manager.route("/ask", methods=["POST"]) # noqa: F821 -@login_required -@validate_request("question", "kb_ids") -async def ask_about(): - req = await get_request_json() - uid = current_user.id - - search_id = req.get("search_id", "") - search_app = None - search_config = {} - if search_id: - search_app = SearchService.get_detail(search_id) - if search_app: - search_config = search_app.get("search_config", {}) - - async def stream(): - nonlocal req, uid - try: - async for ans in async_ask(req["question"], req["kb_ids"], uid, search_config=search_config): - yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n" - except Exception as e: - yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n" - yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n" - - resp = Response(stream(), mimetype="text/event-stream") - resp.headers.add_header("Cache-control", "no-cache") - resp.headers.add_header("Connection", "keep-alive") - resp.headers.add_header("X-Accel-Buffering", "no") - resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8") - return resp - - -@manager.route("/mindmap", methods=["POST"]) # noqa: F821 -@login_required -@validate_request("question", "kb_ids") -async def mindmap(): - req = await get_request_json() - search_id = req.get("search_id", "") - search_app = SearchService.get_detail(search_id) if search_id else {} - search_config = search_app.get("search_config", {}) if search_app else {} - kb_ids = search_config.get("kb_ids", []) - kb_ids.extend(req["kb_ids"]) - kb_ids = list(set(kb_ids)) - - mind_map = await gen_mindmap(req["question"], kb_ids, search_app.get("tenant_id", current_user.id), search_config) - if "error" in mind_map: - return server_error_response(Exception(mind_map["error"])) - return get_json_result(data=mind_map) - - -@manager.route("/related_questions", methods=["POST"]) # noqa: F821 -@login_required -@validate_request("question") -async def related_questions(): - req = await get_request_json() - - search_id = req.get("search_id", "") - search_config = {} - if search_id: - if search_app := SearchService.get_detail(search_id): - search_config = search_app.get("search_config", {}) - - question = req["question"] - - chat_id = search_config.get("chat_id", "") - if chat_id: - chat_model_config = get_model_config_by_type_and_name(current_user.id, LLMType.CHAT, chat_id) - else: - chat_model_config = get_tenant_default_model_by_type(current_user.id, LLMType.CHAT) - chat_mdl = LLMBundle(current_user.id, chat_model_config) - - gen_conf = search_config.get("llm_setting", {"temperature": 0.9}) - if "parameter" in gen_conf: - del gen_conf["parameter"] - prompt = load_prompt("related_question") - ans = await chat_mdl.async_chat( - prompt, - [ - { - "role": "user", - "content": f""" -Keywords: {question} -Related search terms: - """, - } - ], - gen_conf, - ) - return get_json_result(data=[re.sub(r"^[0-9]\. ", "", a) for a in ans.split("\n") if re.match(r"^[0-9]\. ", a)]) diff --git a/api/apps/restful_apis/chat_api.py b/api/apps/restful_apis/chat_api.py index dedf2b917..6cda09501 100644 --- a/api/apps/restful_apis/chat_api.py +++ b/api/apps/restful_apis/chat_api.py @@ -14,13 +14,25 @@ # limitations under the License. # +import json +import logging +import os +import re +import tempfile from copy import deepcopy -from quart import request +from quart import Response, request from api.apps import current_user, login_required -from api.db.services.dialog_service import DialogService +from api.db.joint_services.tenant_model_service import ( + get_model_config_by_type_and_name, + get_tenant_default_model_by_type, +) +from api.db.services.conversation_service import ConversationService, structure_answer +from api.db.services.dialog_service import DialogService, async_ask, async_chat, gen_mindmap from api.db.services.knowledgebase_service import KnowledgebaseService +from api.db.services.llm_service import LLMBundle +from api.db.services.search_service import SearchService from api.db.services.tenant_llm_service import TenantLLMService from api.db.services.user_service import TenantService, UserTenantService from api.utils.api_utils import ( @@ -29,10 +41,13 @@ from api.utils.api_utils import ( get_json_result, get_request_json, server_error_response, + validate_request, ) from api.utils.tenant_utils import ensure_tenant_model_id_for_params -from common.constants import RetCode, StatusEnum +from common.constants import LLMType, RetCode, StatusEnum from common.misc_utils import get_uuid +from rag.prompts.generator import chunks_format +from rag.prompts.template import load_prompt _DEFAULT_PROMPT_CONFIG = { "system": ( @@ -95,6 +110,13 @@ def _validate_name(name, *, required=True): return name, None +def _build_session_response(conv: dict) -> dict: + conv = dict(conv) + conv["chat_id"] = conv.pop("dialog_id", conv.get("chat_id")) + conv["messages"] = conv.pop("message", conv.get("messages", [])) + return conv + + def _ensure_owned_chat(chat_id): return DialogService.query( tenant_id=current_user.id, id=chat_id, status=StatusEnum.VALID.value @@ -567,3 +589,458 @@ async def bulk_delete_chats(): return get_data_error_result(message="; ".join(all_errors)) return get_json_result(data={"success_count": success_count}) + + +@manager.route("/chats//sessions", methods=["POST"]) # noqa: F821 +@login_required +async def create_session(chat_id): + if not _ensure_owned_chat(chat_id): + return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) + try: + req = await get_request_json() + ok, dia = DialogService.get_by_id(chat_id) + if not ok: + return get_data_error_result(message="Chat not found!") + name = req.get("name", "New session") + if not isinstance(name, str) or not name.strip(): + return get_data_error_result(message="`name` can not be empty.") + name = name.strip()[:255] + conv = { + "id": get_uuid(), + "dialog_id": chat_id, + "name": name, + "message": [{"role": "assistant", "content": dia.prompt_config.get("prologue", "")}], + "user_id": req.get("user_id", current_user.id), + "reference": [], + } + ConversationService.save(**conv) + ok, conv_obj = ConversationService.get_by_id(conv["id"]) + if not ok: + return get_data_error_result(message="Fail to create a session!") + return get_json_result(data=_build_session_response(conv_obj.to_dict())) + except Exception as ex: + return server_error_response(ex) + + +@manager.route("/chats//sessions", methods=["GET"]) # noqa: F821 +@login_required +def list_sessions(chat_id): + try: + if not _ensure_owned_chat(chat_id): + return get_json_result( + data=False, + message="No authorization.", + code=RetCode.AUTHENTICATION_ERROR, + ) + page_number = int(request.args.get("page", 1)) + items_per_page = int(request.args.get("page_size", 30)) + orderby = request.args.get("orderby", "create_time") + desc = request.args.get("desc", "true").lower() != "false" + session_id = request.args.get("id") + name = request.args.get("name") + user_id = request.args.get("user_id") + convs = ConversationService.get_list( + chat_id, page_number, items_per_page, orderby, desc, session_id, name, user_id + ) + if items_per_page == 0: + convs = [] + return get_json_result(data=[_build_session_response(c) for c in convs]) + except Exception as ex: + return server_error_response(ex) + + +@manager.route("/chats//sessions/", methods=["GET"]) # noqa: F821 +@login_required +async def get_session(chat_id, session_id): + if not _ensure_owned_chat(chat_id): + return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) + try: + ok, conv = ConversationService.get_by_id(session_id) + if not ok: + return get_data_error_result(message="Session not found!") + if conv.dialog_id != chat_id: + return get_data_error_result(message="Session does not belong to this chat!") + dialog = _ensure_owned_chat(chat_id) + avatar = dialog[0].icon if dialog else "" + for ref in conv.reference: + if isinstance(ref, list): + continue + ref["chunks"] = chunks_format(ref) + result = _build_session_response(conv.to_dict()) + result["avatar"] = avatar + return get_json_result(data=result) + except Exception as ex: + return server_error_response(ex) + + +@manager.route("/chats//sessions/", methods=["PUT"]) # noqa: F821 +@login_required +async def update_session(chat_id, session_id): + if not _ensure_owned_chat(chat_id): + return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) + try: + req = await get_request_json() + if not ConversationService.query(id=session_id, dialog_id=chat_id): + return get_data_error_result(message="Session not found!") + if "message" in req or "messages" in req: + return get_data_error_result(message="`messages` cannot be changed.") + if "reference" in req: + return get_data_error_result(message="`reference` cannot be changed.") + name = req.get("name") + if name is not None: + if not isinstance(name, str) or not name.strip(): + return get_data_error_result(message="`name` can not be empty.") + req["name"] = name.strip()[:255] + update_fields = {k: v for k, v in req.items() if k not in {"id", "dialog_id", "chat_id", "user_id"}} + if not ConversationService.update_by_id(session_id, update_fields): + return get_data_error_result(message="Session not found!") + ok, conv = ConversationService.get_by_id(session_id) + if not ok: + return get_data_error_result(message="Fail to update a session!") + return get_json_result(data=_build_session_response(conv.to_dict())) + except Exception as ex: + return server_error_response(ex) + + +@manager.route("/chats//sessions", methods=["DELETE"]) # noqa: F821 +@login_required +async def delete_sessions(chat_id): + if not _ensure_owned_chat(chat_id): + return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) + try: + req = await get_request_json() + if not req: + return get_json_result(data={}) + + session_ids = req.get("ids") + if not session_ids: + if req.get("delete_all") is True: + session_ids = [conv.id for conv in ConversationService.query(dialog_id=chat_id)] + if not session_ids: + return get_json_result(data={}) + else: + return get_json_result(data={}) + unique_ids, duplicate_messages = check_duplicate_ids(session_ids, "session") + errors = [] + success_count = 0 + for sid in unique_ids: + if not ConversationService.query(id=sid, dialog_id=chat_id): + errors.append(f"The chat doesn't own the session {sid}") + continue + ConversationService.delete_by_id(sid) + success_count += 1 + all_errors = errors + duplicate_messages + if all_errors: + if success_count > 0: + return get_json_result( + data={"success_count": success_count, "errors": all_errors}, + message=f"Partially deleted {success_count} sessions with {len(all_errors)} errors", + ) + return get_data_error_result(message="; ".join(all_errors)) + return get_json_result(data=True) + except Exception as ex: + return server_error_response(ex) + + +@manager.route("/chats//sessions//messages/", methods=["DELETE"]) # noqa: F821 +@login_required +async def delete_session_message(chat_id, session_id, msg_id): + if not _ensure_owned_chat(chat_id): + return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) + try: + ok, conv = ConversationService.get_by_id(session_id) + if not ok or conv.dialog_id != chat_id: + return get_data_error_result(message="Session not found!") + conv = conv.to_dict() + for i, msg in enumerate(conv["message"]): + if msg_id != msg.get("id", ""): + continue + assert conv["message"][i + 1]["id"] == msg_id + conv["message"].pop(i) + conv["message"].pop(i) + conv["reference"].pop(max(0, i // 2 - 1)) + break + ConversationService.update_by_id(conv["id"], conv) + return get_json_result(data=_build_session_response(conv)) + except Exception as ex: + return server_error_response(ex) + + +@manager.route("/chats//sessions//messages//feedback", methods=["PUT"]) # noqa: F821 +@login_required +async def update_message_feedback(chat_id, session_id, msg_id): + if not _ensure_owned_chat(chat_id): + return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) + try: + req = await get_request_json() + ok, conv = ConversationService.get_by_id(session_id) + if not ok or conv.dialog_id != chat_id: + return get_data_error_result(message="Session not found!") + up_down = req.get("thumbup") + feedback = req.get("feedback", "") + conv = conv.to_dict() + for msg in conv["message"]: + if msg_id == msg.get("id", "") and msg.get("role", "") == "assistant": + if up_down: + msg["thumbup"] = True + msg.pop("feedback", None) + else: + msg["thumbup"] = False + if feedback: + msg["feedback"] = feedback + break + ConversationService.update_by_id(conv["id"], conv) + return get_json_result(data=_build_session_response(conv)) + except Exception as ex: + return server_error_response(ex) + + +@manager.route("/chats/tts", methods=["POST"]) # noqa: F821 +@login_required +async def tts(): + req = await get_request_json() + text = req["text"] + + try: + default_tts_model_config = get_tenant_default_model_by_type(current_user.id, LLMType.TTS) + except Exception as e: + return get_data_error_result(message=str(e)) + + tts_mdl = LLMBundle(current_user.id, default_tts_model_config) + + def stream_audio(): + try: + for txt in re.split(r"[,。/《》?;:!\n\r:;]+", text): + for chunk in tts_mdl.tts(txt): + yield chunk + except Exception as e: + yield ("data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e)}}, ensure_ascii=False)).encode("utf-8") + + resp = Response(stream_audio(), mimetype="audio/mpeg") + resp.headers.add_header("Cache-Control", "no-cache") + resp.headers.add_header("Connection", "keep-alive") + resp.headers.add_header("X-Accel-Buffering", "no") + return resp + + +@manager.route("/chats/transcriptions", methods=["POST"]) # noqa: F821 +@login_required +async def transcriptions(): + req = await request.form + stream_mode = req.get("stream", "false").lower() == "true" + files = await request.files + if "file" not in files: + return get_data_error_result(message="Missing 'file' in multipart form-data") + + uploaded = files["file"] + + ALLOWED_EXTS = { + ".wav", ".mp3", ".m4a", ".aac", + ".flac", ".ogg", ".webm", + ".opus", ".wma", + } + + filename = uploaded.filename or "" + suffix = os.path.splitext(filename)[-1].lower() + if suffix not in ALLOWED_EXTS: + return get_data_error_result( + message=f"Unsupported audio format: {suffix}. Allowed: {', '.join(sorted(ALLOWED_EXTS))}" + ) + + fd, temp_audio_path = tempfile.mkstemp(suffix=suffix) + os.close(fd) + await uploaded.save(temp_audio_path) + + try: + default_asr_model_config = get_tenant_default_model_by_type(current_user.id, LLMType.SPEECH2TEXT) + except Exception as e: + return get_data_error_result(message=str(e)) + + asr_mdl = LLMBundle(current_user.id, default_asr_model_config) + if not stream_mode: + text = asr_mdl.transcription(temp_audio_path) + try: + os.remove(temp_audio_path) + except Exception as e: + logging.error(f"Failed to remove temp audio file: {str(e)}") + return get_json_result(data={"text": text}) + + async def event_stream(): + try: + for evt in asr_mdl.stream_transcription(temp_audio_path): + yield f"data: {json.dumps(evt, ensure_ascii=False)}\n\n" + except Exception as e: + err = {"event": "error", "text": str(e)} + yield f"data: {json.dumps(err, ensure_ascii=False)}\n\n" + finally: + try: + os.remove(temp_audio_path) + except Exception as e: + logging.error(f"Failed to remove temp audio file: {str(e)}") + + return Response(event_stream(), content_type="text/event-stream") + + +@manager.route("/chats/mindmap", methods=["POST"]) # noqa: F821 +@login_required +@validate_request("question", "kb_ids") +async def mindmap(): + req = await get_request_json() + search_id = req.get("search_id", "") + search_app = SearchService.get_detail(search_id) if search_id else {} + search_config = search_app.get("search_config", {}) if search_app else {} + kb_ids = search_config.get("kb_ids", []) + kb_ids.extend(req["kb_ids"]) + kb_ids = list(set(kb_ids)) + + mind_map = await gen_mindmap(req["question"], kb_ids, search_app.get("tenant_id", current_user.id), search_config) + if "error" in mind_map: + return server_error_response(Exception(mind_map["error"])) + return get_json_result(data=mind_map) + + +@manager.route("/chats/related_questions", methods=["POST"]) # noqa: F821 +@login_required +@validate_request("question") +async def related_questions(): + req = await get_request_json() + + search_id = req.get("search_id", "") + search_config = {} + if search_id: + if search_app := SearchService.get_detail(search_id): + search_config = search_app.get("search_config", {}) + + question = req["question"] + + chat_id = search_config.get("chat_id", "") + if chat_id: + chat_model_config = get_model_config_by_type_and_name(current_user.id, LLMType.CHAT, chat_id) + else: + chat_model_config = get_tenant_default_model_by_type(current_user.id, LLMType.CHAT) + chat_mdl = LLMBundle(current_user.id, chat_model_config) + + gen_conf = search_config.get("llm_setting", {"temperature": 0.9}) + if "parameter" in gen_conf: + del gen_conf["parameter"] + prompt = load_prompt("related_question") + ans = await chat_mdl.async_chat( + prompt, + [ + { + "role": "user", + "content": f"\nKeywords: {question}\nRelated search terms:\n ", + } + ], + gen_conf, + ) + return get_json_result(data=[re.sub(r"^[0-9]\. ", "", a) for a in ans.split("\n") if re.match(r"^[0-9]\. ", a)]) + + +@manager.route("/chats//sessions//completions", methods=["POST"]) # noqa: F821 +@login_required +@validate_request("messages") +async def session_completion(chat_id, session_id): + req = await get_request_json() + msg = [] + for m in req["messages"]: + if m["role"] == "system": + continue + if m["role"] == "assistant" and not msg: + continue + msg.append(m) + message_id = msg[-1].get("id") if msg else None + chat_model_id = req.pop("llm_id", "") + + chat_model_config = {} + for model_config in ["temperature", "top_p", "frequency_penalty", "presence_penalty", "max_tokens"]: + config = req.get(model_config) + if config: + chat_model_config[model_config] = config + + try: + e, conv = ConversationService.get_by_id(session_id) + if not e: + return get_data_error_result(message="Session not found!") + if conv.dialog_id != chat_id: + return get_data_error_result(message="Session does not belong to this chat!") + conv.message = deepcopy(req["messages"]) + e, dia = DialogService.get_by_id(chat_id) + if not e: + return get_data_error_result(message="Chat not found!") + del req["messages"] + + if not conv.reference: + conv.reference = [] + conv.reference = [r for r in conv.reference if r] + conv.reference.append({"chunks": [], "doc_aggs": []}) + + if chat_model_id: + if not TenantLLMService.get_api_key(tenant_id=dia.tenant_id, model_name=chat_model_id): + return get_data_error_result(message=f"Cannot use specified model {chat_model_id}.") + dia.llm_id = chat_model_id + dia.llm_setting = chat_model_config + + is_embedded = bool(chat_model_id) + stream_mode = req.pop("stream", True) + + async def stream(): + nonlocal dia, msg, req, conv + try: + async for ans in async_chat(dia, msg, True, **req): + ans = structure_answer(conv, ans, message_id, conv.id) + yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n" + if not is_embedded: + ConversationService.update_by_id(conv.id, conv.to_dict()) + except Exception as ex: + logging.exception(ex) + yield "data:" + json.dumps({"code": 500, "message": str(ex), "data": {"answer": "**ERROR**: " + str(ex), "reference": []}}, ensure_ascii=False) + "\n\n" + yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n" + + if stream_mode: + resp = Response(stream(), mimetype="text/event-stream") + resp.headers.add_header("Cache-control", "no-cache") + resp.headers.add_header("Connection", "keep-alive") + resp.headers.add_header("X-Accel-Buffering", "no") + resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8") + return resp + + answer = None + async for ans in async_chat(dia, msg, **req): + answer = structure_answer(conv, ans, message_id, conv.id) + if not is_embedded: + ConversationService.update_by_id(conv.id, conv.to_dict()) + break + return get_json_result(data=answer) + except Exception as ex: + return server_error_response(ex) + + +@manager.route("/chats/ask", methods=["POST"]) # noqa: F821 +@login_required +@validate_request("question", "kb_ids") +async def ask(): + req = await get_request_json() + uid = current_user.id + + search_id = req.get("search_id", "") + search_config = {} + if search_id: + if search_app := SearchService.get_detail(search_id): + search_config = search_app.get("search_config", {}) + + async def stream(): + nonlocal req, uid + try: + async for ans in async_ask(req["question"], req["kb_ids"], uid, search_config=search_config): + yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n" + except Exception as ex: + yield "data:" + json.dumps({"code": 500, "message": str(ex), "data": {"answer": "**ERROR**: " + str(ex), "reference": []}}, ensure_ascii=False) + "\n\n" + yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n" + + resp = Response(stream(), mimetype="text/event-stream") + resp.headers.add_header("Cache-control", "no-cache") + resp.headers.add_header("Connection", "keep-alive") + resp.headers.add_header("X-Accel-Buffering", "no") + resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8") + return resp diff --git a/api/apps/sdk/session.py b/api/apps/sdk/session.py index 7c3395d3a..82e048ff1 100644 --- a/api/apps/sdk/session.py +++ b/api/apps/sdk/session.py @@ -54,35 +54,6 @@ from common.constants import RetCode, LLMType, StatusEnum from common import settings -@manager.route("/chats//sessions", methods=["POST"]) # noqa: F821 -@token_required -async def create(tenant_id, chat_id): - req = await get_request_json() - req["dialog_id"] = chat_id - dia = DialogService.query(tenant_id=tenant_id, id=req["dialog_id"], status=StatusEnum.VALID.value) - if not dia: - return get_error_data_result(message="You do not own the assistant.") - conv = { - "id": get_uuid(), - "dialog_id": req["dialog_id"], - "name": req.get("name", "New session"), - "message": [{"role": "assistant", "content": dia[0].prompt_config.get("prologue")}], - "user_id": req.get("user_id", ""), - "reference": [], - } - if not conv.get("name"): - return get_error_data_result(message="`name` can not be empty.") - ConversationService.save(**conv) - e, conv = ConversationService.get_by_id(conv["id"]) - if not e: - return get_error_data_result(message="Fail to create a session!") - conv = conv.to_dict() - conv["messages"] = conv.pop("message") - conv["chat_id"] = conv.pop("dialog_id") - del conv["reference"] - return get_result(data=conv) - - @manager.route("/agents//sessions", methods=["POST"]) # noqa: F821 @token_required async def create_agent_session(tenant_id, agent_id): @@ -121,28 +92,6 @@ async def create_agent_session(tenant_id, agent_id): return get_result(data=conv) -@manager.route("/chats//sessions/", methods=["PUT"]) # noqa: F821 -@token_required -async def update(tenant_id, chat_id, session_id): - req = await get_request_json() - req["dialog_id"] = chat_id - conv_id = session_id - conv = ConversationService.query(id=conv_id, dialog_id=chat_id) - if not conv: - return get_error_data_result(message="Session does not exist") - if not DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value): - return get_error_data_result(message="You do not own the session") - if "message" in req or "messages" in req: - return get_error_data_result(message="`message` can not be change") - if "reference" in req: - return get_error_data_result(message="`reference` can not be change") - if "name" in req and not req.get("name"): - return get_error_data_result(message="`name` can not be empty.") - if not ConversationService.update_by_id(conv_id, req): - return get_error_data_result(message="Session updates error") - return get_result() - - @manager.route("/chats//completions", methods=["POST"]) # noqa: F821 @token_required async def chat_completion(tenant_id, chat_id): @@ -632,60 +581,6 @@ async def agent_completions(tenant_id, agent_id): return get_result(data=final_ans) -@manager.route("/chats//sessions", methods=["GET"]) # noqa: F821 -@token_required -async def list_session(tenant_id, chat_id): - if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value): - return get_error_data_result(message=f"You don't own the assistant {chat_id}.") - id = request.args.get("id") - name = request.args.get("name") - page_number = int(request.args.get("page", 1)) - items_per_page = int(request.args.get("page_size", 30)) - orderby = request.args.get("orderby", "create_time") - user_id = request.args.get("user_id") - if request.args.get("desc") == "False" or request.args.get("desc") == "false": - desc = False - else: - desc = True - convs = ConversationService.get_list(chat_id, page_number, items_per_page, orderby, desc, id, name, user_id) - if not convs: - return get_result(data=[]) - for conv in convs: - conv["messages"] = conv.pop("message") - infos = conv["messages"] - for info in infos: - if "prompt" in info: - info.pop("prompt") - conv["chat_id"] = conv.pop("dialog_id") - ref_messages = conv["reference"] - if ref_messages: - messages = conv["messages"] - message_num = 0 - ref_num = 0 - while message_num < len(messages) and ref_num < len(ref_messages): - if messages[message_num]["role"] != "user": - chunk_list = [] - if "chunks" in ref_messages[ref_num]: - chunks = ref_messages[ref_num]["chunks"] - for chunk in chunks: - new_chunk = { - "id": chunk.get("chunk_id", chunk.get("id")), - "content": chunk.get("content_with_weight", chunk.get("content")), - "document_id": chunk.get("doc_id", chunk.get("document_id")), - "document_name": chunk.get("docnm_kwd", chunk.get("document_name")), - "dataset_id": chunk.get("kb_id", chunk.get("dataset_id")), - "image_id": chunk.get("image_id", chunk.get("img_id")), - "positions": chunk.get("positions", chunk.get("position_int")), - } - - chunk_list.append(new_chunk) - messages[message_num]["reference"] = chunk_list - ref_num += 1 - message_num += 1 - del conv["reference"] - return get_result(data=convs) - - @manager.route("/agents//sessions", methods=["GET"]) # noqa: F821 @token_required async def list_agent_session(tenant_id, agent_id): @@ -749,58 +644,6 @@ async def list_agent_session(tenant_id, agent_id): return get_result(data=convs) -@manager.route("/chats//sessions", methods=["DELETE"]) # noqa: F821 -@token_required -async def delete(tenant_id, chat_id): - if not DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value): - return get_error_data_result(message="You don't own the chat") - - errors = [] - success_count = 0 - req = await get_request_json() - if not req: - return get_result() - - ids = req.get("ids") - if not ids: - if req.get("delete_all") is True: - ids = [conv.id for conv in ConversationService.query(dialog_id=chat_id)] - if not ids: - return get_result() - else: - return get_result() - - conv_list = ids - - unique_conv_ids, duplicate_messages = check_duplicate_ids(conv_list, "session") - conv_list = unique_conv_ids - - for id in conv_list: - conv = ConversationService.query(id=id, dialog_id=chat_id) - if not conv: - errors.append(f"The chat doesn't own the session {id}") - continue - ConversationService.delete_by_id(id) - success_count += 1 - - if errors: - if success_count > 0: - return get_result(data={"success_count": success_count, "errors": errors}, - message=f"Partially deleted {success_count} sessions with {len(errors)} errors") - else: - return get_error_data_result(message="; ".join(errors)) - - if duplicate_messages: - if success_count > 0: - return get_result( - message=f"Partially deleted {success_count} sessions with {len(duplicate_messages)} errors", - data={"success_count": success_count, "errors": duplicate_messages}) - else: - return get_error_data_result(message=";".join(duplicate_messages)) - - return get_result() - - @manager.route("/agents//sessions", methods=["DELETE"]) # noqa: F821 @token_required async def delete_agent_session(tenant_id, agent_id): diff --git a/api/db/services/conversation_service.py b/api/db/services/conversation_service.py index 0a433b692..5a205b142 100644 --- a/api/db/services/conversation_service.py +++ b/api/db/services/conversation_service.py @@ -44,7 +44,8 @@ class ConversationService(CommonService): else: sessions = sessions.order_by(cls.model.getter_by(orderby).asc()) - sessions = sessions.paginate(page_number, items_per_page) + if items_per_page > 0: + sessions = sessions.paginate(page_number, items_per_page) return list(sessions.dicts()) diff --git a/docs/references/http_api_reference.md b/docs/references/http_api_reference.md index 0b1b7f2f2..659f1957b 100644 --- a/docs/references/http_api_reference.md +++ b/docs/references/http_api_reference.md @@ -3456,7 +3456,7 @@ Failure: ```json { "code": 102, - "message": "Name cannot be empty." + "message": "`name` can not be empty." } ``` @@ -3476,8 +3476,7 @@ Updates a session of a specified chat assistant. - `'content-Type: application/json'` - `'Authorization: Bearer '` - Body: - - `"name`: `string` - - `"user_id`: `string` (optional) + - `"name"`: `string` ##### Request example @@ -3494,14 +3493,12 @@ curl --request PUT \ ##### Request Parameter -- `chat_id`: (*Path parameter*) +- `chat_id`: (*Path parameter*) The ID of the associated chat assistant. -- `session_id`: (*Path parameter*) +- `session_id`: (*Path parameter*) The ID of the session to update. -- `"name"`: (*Body Parameter*), `string` +- `"name"`: (*Body Parameter*), `string` The revised name of the session. -- `"user_id"`: (*Body parameter*), `string` - Optional user-defined ID. #### Response @@ -3509,7 +3506,23 @@ Success: ```json { - "code": 0 + "code": 0, + "data": { + "chat_id": "2ca4b22e878011ef88fe0242ac120005", + "create_date": "Fri, 11 Oct 2024 08:46:14 GMT", + "create_time": 1728636374571, + "id": "4606b4ec87ad11efbc4f0242ac120006", + "messages": [ + { + "content": "Hi! I am your assistant, can I help you?", + "role": "assistant" + } + ], + "name": "updated session name", + "update_date": "Fri, 11 Oct 2024 08:46:14 GMT", + "update_time": 1728636374571, + "user_id": "" + } } ``` @@ -3518,7 +3531,7 @@ Failure: ```json { "code": 102, - "message": "Name cannot be empty." + "message": "`name` can not be empty." } ``` @@ -3526,7 +3539,7 @@ Failure: ### List chat assistant's sessions -**GET** `/api/v1/chats/{chat_id}/sessions?page={page}&page_size={page_size}&orderby={orderby}&desc={desc}&name={session_name}&id={session_id}` +**GET** `/api/v1/chats/{chat_id}/sessions?page={page}&page_size={page_size}&orderby={orderby}&desc={desc}&name={session_name}&id={session_id}&user_id={user_id}` Lists sessions associated with a specified chat assistant. @@ -3541,7 +3554,7 @@ Lists sessions associated with a specified chat assistant. ```bash curl --request GET \ - --url http://{address}/api/v1/chats/{chat_id}/sessions?page={page}&page_size={page_size}&orderby={orderby}&desc={desc}&name={session_name}&id={session_id} \ + --url http://{address}/api/v1/chats/{chat_id}/sessions?page={page}&page_size={page_size}&orderby={orderby}&desc={desc}&name={session_name}&id={session_id}&user_id={user_id} \ --header 'Authorization: Bearer ' ``` @@ -3552,7 +3565,7 @@ curl --request GET \ - `page`: (*Filter parameter*), `integer` Specifies the page on which the sessions will be displayed. Defaults to `1`. - `page_size`: (*Filter parameter*), `integer` - The number of sessions on each page. Defaults to `30`. + The number of sessions on each page. Defaults to `30`. If set to `0`, an empty list is returned. - `orderby`: (*Filter parameter*), `string` The field by which sessions should be sorted. Available options: - `create_time` (default) @@ -3575,7 +3588,7 @@ Success: "code": 0, "data": [ { - "chat": "2ca4b22e878011ef88fe0242ac120005", + "chat_id": "2ca4b22e878011ef88fe0242ac120005", "create_date": "Fri, 11 Oct 2024 08:46:43 GMT", "create_time": 1728636403974, "id": "578d541e87ad11ef96b90242ac120006", @@ -3586,8 +3599,10 @@ Success: } ], "name": "new session", + "reference": [], "update_date": "Fri, 11 Oct 2024 08:46:43 GMT", - "update_time": 1728636403974 + "update_time": 1728636403974, + "user_id": "" } ] } @@ -3604,6 +3619,202 @@ Failure: --- +### Get chat assistant's session + +**GET** `/api/v1/chats/{chat_id}/sessions/{session_id}` + +Gets a specific session of a specified chat assistant, including its messages, references, and avatar. + +#### Request + +- Method: GET +- URL: `/api/v1/chats/{chat_id}/sessions/{session_id}` +- Headers: + - `'Authorization: Bearer '` + +##### Request example + +```bash +curl --request GET \ + --url http://{address}/api/v1/chats/{chat_id}/sessions/{session_id} \ + --header 'Authorization: Bearer ' +``` + +##### Request Parameters + +- `chat_id`: (*Path parameter*) + The ID of the associated chat assistant. +- `session_id`: (*Path parameter*) + The ID of the session to retrieve. + +#### Response + +Success: + +```json +{ + "code": 0, + "data": { + "chat_id": "2ca4b22e878011ef88fe0242ac120005", + "id": "4606b4ec87ad11efbc4f0242ac120006", + "name": "new session", + "avatar": "data:image/png;base64,...", + "messages": [ + { + "content": "Hi! I am your assistant, can I help you?", + "role": "assistant" + } + ], + "reference": [] + } +} +``` + +Failure: + +```json +{ + "code": 102, + "message": "Session not found!" +} +``` + +--- + +### Delete a message from a chat assistant's session + +**DELETE** `/api/v1/chats/{chat_id}/sessions/{session_id}/messages/{msg_id}` + +Deletes a user message and its paired assistant reply from a specified chat assistant session. + +#### Request + +- Method: DELETE +- URL: `/api/v1/chats/{chat_id}/sessions/{session_id}/messages/{msg_id}` +- Headers: + - `'Authorization: Bearer '` + +##### Request example + +```bash +curl --request DELETE \ + --url http://{address}/api/v1/chats/{chat_id}/sessions/{session_id}/messages/{msg_id} \ + --header 'Authorization: Bearer ' +``` + +##### Request Parameters + +- `chat_id`: (*Path parameter*) + The ID of the associated chat assistant. +- `session_id`: (*Path parameter*) + The ID of the session that owns the message. +- `msg_id`: (*Path parameter*) + The ID of the message to delete. + +#### Response + +Success: returns the updated session object. + +```json +{ + "code": 0, + "data": { + "chat_id": "2ca4b22e878011ef88fe0242ac120005", + "id": "4606b4ec87ad11efbc4f0242ac120006", + "messages": [], + "reference": [] + } +} +``` + +Failure: + +```json +{ + "code": 102, + "message": "Session not found!" +} +``` + +--- + +### Update message feedback in a chat assistant's session + +**PUT** `/api/v1/chats/{chat_id}/sessions/{session_id}/messages/{msg_id}/feedback` + +Updates feedback for an assistant message in a specified chat assistant session. + +#### Request + +- Method: PUT +- URL: `/api/v1/chats/{chat_id}/sessions/{session_id}/messages/{msg_id}/feedback` +- Headers: + - `'Content-Type: application/json'` + - `'Authorization: Bearer '` +- Body: + - `"thumbup"`: `boolean` + - `"feedback"`: `string` (optional) + +##### Request example + +```bash +curl --request PUT \ + --url http://{address}/api/v1/chats/{chat_id}/sessions/{session_id}/messages/{msg_id}/feedback \ + --header 'Content-Type: application/json' \ + --header 'Authorization: Bearer ' \ + --data '{ + "thumbup": false, + "feedback": "The answer missed the cited document." + }' +``` + +##### Request Parameters + +- `chat_id`: (*Path parameter*) + The ID of the associated chat assistant. +- `session_id`: (*Path parameter*) + The ID of the session that owns the message. +- `msg_id`: (*Path parameter*) + The ID of the assistant message to update. +- `"thumbup"`: (*Body parameter*), `boolean` + Whether the assistant message is marked as positive feedback. +- `"feedback"`: (*Body parameter*), `string` + Optional feedback text, typically used when `"thumbup"` is `false`. + +#### Response + +Success: returns the updated session object. + +```json +{ + "code": 0, + "data": { + "chat_id": "2ca4b22e878011ef88fe0242ac120005", + "id": "4606b4ec87ad11efbc4f0242ac120006", + "messages": [ + { + "id": "message-id", + "role": "assistant", + "content": "Here is the answer.", + "thumbup": false, + "feedback": "The answer missed the cited document." + } + ] + } +} +``` + +Failure: + +```json +{ + "code": 102, + "message": "Session not found!" +} +``` + +--- + ### Delete chat assistant's sessions **DELETE** `/api/v1/chats/{chat_id}/sessions` @@ -5057,9 +5268,159 @@ Failure: --- +### Text-to-speech + +**POST** `/api/v1/chats/tts` + +Converts text to speech audio using the tenant's default TTS model, returning a streaming audio response. + +#### Request + +- Method: POST +- URL: `/api/v1/chats/tts` +- Headers: + - `'Content-Type: application/json'` + - `'Authorization: Bearer '` +- Body: + - `"text"`: `string` *(Required)* The text to synthesize. + +##### Request example + +```bash +curl --request POST \ + --url http://{address}/api/v1/chats/tts \ + --header 'Content-Type: application/json' \ + --header 'Authorization: Bearer ' \ + --output audio.mp3 \ + --data '{"text": "Hello, how can I help you today?"}' +``` + +#### Response + +Success: binary `audio/mpeg` stream with headers `Cache-Control: no-cache`, `Connection: keep-alive`, `X-Accel-Buffering: no`. + +Failure: + +```json +{ + "code": 102, + "message": "No default TTS model is set" +} +``` + +--- + +### Speech-to-text + +**POST** `/api/v1/chats/transcriptions` + +Transcribes an audio file using the tenant's default ASR (automatic speech recognition) model. + +#### Request + +- Method: POST +- URL: `/api/v1/chats/transcriptions` +- Headers: + - `'Authorization: Bearer '` +- Body (multipart/form-data): + - `"file"`: audio file (`.wav`, `.mp3`, `.m4a`, `.aac`, `.flac`, `.ogg`, `.webm`, `.opus`, `.wma`) + - `"stream"`: `string` `"true"` for SSE streaming, `"false"` (default) for a single JSON response. + +##### Request example + +```bash +curl --request POST \ + --url http://{address}/api/v1/chats/transcriptions \ + --header 'Authorization: Bearer ' \ + --form file=@recording.wav \ + --form stream=false +``` + +#### Response + +Success (non-streaming): + +```json +{ + "code": 0, + "data": { + "text": "Hello, how can I help you today?" + } +} +``` + +Success (streaming): SSE events with `data: {"event": "partial", "text": "..."}`. + +Failure: + +```json +{ + "code": 102, + "message": "Unsupported audio format: .mp4. Allowed: .aac, .flac, .m4a, .mp3, .ogg, .opus, .wav, .webm, .wma" +} +``` + +--- + +### Generate mind map + +**POST** `/api/v1/chats/mindmap` + +Generates a mind map from a question and a set of knowledge base IDs. + +#### Request + +- Method: POST +- URL: `/api/v1/chats/mindmap` +- Headers: + - `'Content-Type: application/json'` + - `'Authorization: Bearer '` +- Body: + - `"question"`: `string` *(Required)* The central question or topic. + - `"kb_ids"`: `list[string]` *(Required)* Knowledge base IDs to search. + - `"search_id"`: `string` *(Optional)* ID of a saved search configuration to merge additional `kb_ids` and settings. + +##### Request example + +```bash +curl --request POST \ + --url http://{address}/api/v1/chats/mindmap \ + --header 'Content-Type: application/json' \ + --header 'Authorization: Bearer ' \ + --data '{ + "question": "What is retrieval-augmented generation?", + "kb_ids": ["kb-abc123"] + }' +``` + +#### Response + +Success: + +```json +{ + "code": 0, + "data": { + "name": "Retrieval-Augmented Generation", + "children": [...] + } +} +``` + +Failure: + +```json +{ + "code": 500, + "message": "..." +} +``` + +--- + ### Generate related questions -**POST** `/api/v1/sessions/related_questions` +**POST** `/api/v1/chats/related_questions` Generates five to ten alternative question strings from the user's original query to retrieve more relevant search results. @@ -5074,25 +5435,23 @@ The chat model autonomously determines the number of questions to generate based #### Request - Method: POST -- URL: `/api/v1/sessions/related_questions` +- URL: `/api/v1/chats/related_questions` - Headers: - `'content-Type: application/json'` - `'Authorization: Bearer '` - Body: - - `"question"`: `string` - - `"industry"`: `string` + - `"question"`: `string` *(Required)* The original user question. + - `"search_id"`: `string` *(Optional)* ID of a saved search configuration to use custom LLM settings. ##### Request example ```bash curl --request POST \ - --url http://{address}/api/v1/sessions/related_questions \ + --url http://{address}/api/v1/chats/related_questions \ --header 'Content-Type: application/json' \ --header 'Authorization: Bearer ' \ - --data ' - { - "question": "What are the key advantages of Neovim over Vim?", - "industry": "software_development" + --data '{ + "question": "What are the key advantages of Neovim over Vim?" }' ``` @@ -5100,8 +5459,8 @@ curl --request POST \ - `"question"`: (*Body Parameter*), `string` The original user question. -- `"industry"`: (*Body Parameter*), `string` - Industry of the question. +- `"search_id"`: (*Body Parameter*), `string` + ID of a saved search configuration to use custom LLM settings. If provided, the LLM model and generation settings from the search configuration will be used. #### Response diff --git a/docs/references/python_api_reference.md b/docs/references/python_api_reference.md index 8d522cc50..170955026 100644 --- a/docs/references/python_api_reference.md +++ b/docs/references/python_api_reference.md @@ -1469,12 +1469,13 @@ session.update({"name": "updated_name"}) ```python Chat.list_sessions( - page: int = 1, - page_size: int = 30, - orderby: str = "create_time", + page: int = 1, + page_size: int = 30, + orderby: str = "create_time", desc: bool = True, id: str = None, - name: str = None + name: str = None, + user_id: str = None ) -> list[Session] ``` @@ -1509,6 +1510,10 @@ The ID of the chat session to retrieve. Defaults to `None`. The name of the chat session to retrieve. Defaults to `None`. +##### user_id: `str` + +The optional user-defined ID to filter sessions by. Defaults to `None`. + #### Returns - Success: A list of `Session` objects associated with the current chat assistant. diff --git a/sdk/python/ragflow_sdk/modules/chat.py b/sdk/python/ragflow_sdk/modules/chat.py index 82374b2d5..18822eb4f 100644 --- a/sdk/python/ragflow_sdk/modules/chat.py +++ b/sdk/python/ragflow_sdk/modules/chat.py @@ -48,10 +48,10 @@ class Chat(Base): res = res.json() if res.get("code") == 0: return Session(self.rag, res["data"]) - raise Exception(res["message"]) + raise Exception(res.get("message")) - def list_sessions(self, page: int = 1, page_size: int = 30, orderby: str = "create_time", desc: bool = True, id: str = None, name: str = None) -> list[Session]: - res = self.get(f"/chats/{self.id}/sessions", {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc, "id": id, "name": name}) + def list_sessions(self, page: int = 1, page_size: int = 30, orderby: str = "create_time", desc: bool = True, id: str = None, name: str = None, user_id: str = None) -> list[Session]: + res = self.get(f"/chats/{self.id}/sessions", {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc, "id": id, "name": name, "user_id": user_id}) res = res.json() if res.get("code") == 0: result_list = [] diff --git a/test/playwright/e2e/test_next_apps_chat.py b/test/playwright/e2e/test_next_apps_chat.py index 135b10af2..e0169a8a5 100644 --- a/test/playwright/e2e/test_next_apps_chat.py +++ b/test/playwright/e2e/test_next_apps_chat.py @@ -678,7 +678,9 @@ def mm_step_12_composer_and_single_send(ctx: FlowContext, step, snap): def _on_completion_request(req): if ( req.method.upper() in MM_REQUEST_METHOD_WHITELIST - and "/conversation/completion" in req.url + and "/api/v1/chats/" in req.url + and "/sessions/" in req.url + and req.url.rstrip("/").endswith("/completions") ): completion_payloads.append(_mm_payload_from_request(req)) @@ -747,7 +749,7 @@ def mm_step_12_composer_and_single_send(ctx: FlowContext, step, snap): page.remove_listener("request", _on_completion_request) attach_path.unlink(missing_ok=True) - assert completion_payloads, "no /conversation/completion request was captured" + assert completion_payloads, "no chat session completion request was captured" payloads_with_messages = [p for p in completion_payloads if p.get("messages")] assert payloads_with_messages, "completion requests did not include messages" diff --git a/test/testcases/test_http_api/test_chat_assistant_management/test_chat_sdk_routes_unit.py b/test/testcases/test_http_api/test_chat_assistant_management/test_chat_sdk_routes_unit.py index dd67e8c72..f0851f4a2 100644 --- a/test/testcases/test_http_api/test_chat_assistant_management/test_chat_sdk_routes_unit.py +++ b/test/testcases/test_http_api/test_chat_assistant_management/test_chat_sdk_routes_unit.py @@ -58,6 +58,28 @@ class _DummyArgs(dict): return [value] +class _StubHeaders: + def __init__(self): + self._items = [] + + def add_header(self, key, value): + self._items.append((key, value)) + + def get(self, key, default=None): + for existing_key, value in reversed(self._items): + if existing_key == key: + return value + return default + + +class _StubResponse: + def __init__(self, body=None, mimetype=None, content_type=None): + self.body = body + self.mimetype = mimetype + self.content_type = content_type + self.headers = _StubHeaders() + + def _passthrough_login_required(func): @wraps(func) async def _wrapper(*args, **kwargs): @@ -125,6 +147,7 @@ def _load_chat_module(monkeypatch): quart_mod = ModuleType("quart") quart_mod.request = SimpleNamespace(args=_DummyArgs()) + quart_mod.Response = _StubResponse monkeypatch.setitem(sys.modules, "quart", quart_mod) api_pkg = ModuleType("api") @@ -144,6 +167,11 @@ def _load_chat_module(monkeypatch): common_constants_mod = ModuleType("common.constants") + class _StubLLMType(str, Enum): + CHAT = "chat" + IMAGE2TEXT = "image2text" + RERANK = "rerank" + class _StubRetCode(int, Enum): SUCCESS = 0 DATA_ERROR = 102 @@ -153,6 +181,7 @@ def _load_chat_module(monkeypatch): VALID = "1" INVALID = "0" + common_constants_mod.LLMType = _StubLLMType common_constants_mod.RetCode = _StubRetCode common_constants_mod.StatusEnum = _StubStatusEnum monkeypatch.setitem(sys.modules, "common.constants", common_constants_mod) @@ -213,8 +242,42 @@ def _load_chat_module(monkeypatch): return [], 0 dialog_service_mod.DialogService = _StubDialogService + dialog_service_mod.async_ask = lambda *_args, **_kwargs: None + dialog_service_mod.async_chat = lambda *_args, **_kwargs: None + dialog_service_mod.gen_mindmap = lambda *_args, **_kwargs: None monkeypatch.setitem(sys.modules, "api.db.services.dialog_service", dialog_service_mod) + conversation_service_mod = ModuleType("api.db.services.conversation_service") + + class _StubConversationService: + @staticmethod + def query(**_kwargs): + return [] + + @staticmethod + def get_list(*_args, **_kwargs): + return [] + + @staticmethod + def get_by_id(_session_id): + return False, None + + @staticmethod + def update_by_id(_session_id, _payload): + return True + + @staticmethod + def delete_by_id(_session_id): + return True + + @staticmethod + def save(**_kwargs): + return True + + conversation_service_mod.ConversationService = _StubConversationService + conversation_service_mod.structure_answer = lambda *_args, **_kwargs: {} + monkeypatch.setitem(sys.modules, "api.db.services.conversation_service", conversation_service_mod) + kb_service_mod = ModuleType("api.db.services.knowledgebase_service") class _StubKnowledgebaseService: @@ -253,6 +316,24 @@ def _load_chat_module(monkeypatch): tenant_llm_service_mod.TenantLLMService = _StubTenantLLMService monkeypatch.setitem(sys.modules, "api.db.services.tenant_llm_service", tenant_llm_service_mod) + llm_service_mod = ModuleType("api.db.services.llm_service") + + class _StubLLMBundle: + def __init__(self, *_args, **_kwargs): + pass + + llm_service_mod.LLMBundle = _StubLLMBundle + monkeypatch.setitem(sys.modules, "api.db.services.llm_service", llm_service_mod) + + search_service_mod = ModuleType("api.db.services.search_service") + search_service_mod.SearchService = SimpleNamespace() + monkeypatch.setitem(sys.modules, "api.db.services.search_service", search_service_mod) + + tenant_model_service_mod = ModuleType("api.db.joint_services.tenant_model_service") + tenant_model_service_mod.get_model_config_by_type_and_name = lambda *_args, **_kwargs: {} + tenant_model_service_mod.get_tenant_default_model_by_type = lambda *_args, **_kwargs: {} + monkeypatch.setitem(sys.modules, "api.db.joint_services.tenant_model_service", tenant_model_service_mod) + user_service_mod = ModuleType("api.db.services.user_service") class _StubTenantService: @@ -283,12 +364,29 @@ def _load_chat_module(monkeypatch): api_utils_mod.get_json_result = lambda data=None, message="", code=0: {"code": code, "data": data, "message": message} api_utils_mod.get_request_json = lambda: _AwaitableValue({}) api_utils_mod.server_error_response = lambda ex: {"code": 500, "data": None, "message": str(ex)} + api_utils_mod.validate_request = lambda *_args, **_kwargs: (lambda func: func) monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod) tenant_utils_mod = ModuleType("api.utils.tenant_utils") tenant_utils_mod.ensure_tenant_model_id_for_params = lambda _tenant_id, req: req monkeypatch.setitem(sys.modules, "api.utils.tenant_utils", tenant_utils_mod) + rag_pkg = ModuleType("rag") + rag_pkg.__path__ = [str(repo_root / "rag")] + monkeypatch.setitem(sys.modules, "rag", rag_pkg) + + rag_prompts_pkg = ModuleType("rag.prompts") + rag_prompts_pkg.__path__ = [str(repo_root / "rag" / "prompts")] + monkeypatch.setitem(sys.modules, "rag.prompts", rag_prompts_pkg) + + rag_prompts_generator_mod = ModuleType("rag.prompts.generator") + rag_prompts_generator_mod.chunks_format = lambda reference: reference.get("chunks", []) if isinstance(reference, dict) else [] + monkeypatch.setitem(sys.modules, "rag.prompts.generator", rag_prompts_generator_mod) + + rag_prompts_template_mod = ModuleType("rag.prompts.template") + rag_prompts_template_mod.load_prompt = lambda *_args, **_kwargs: "" + monkeypatch.setitem(sys.modules, "rag.prompts.template", rag_prompts_template_mod) + spec = importlib.util.spec_from_file_location(module_name, module_path) module = importlib.util.module_from_spec(spec) module.manager = _DummyManager() @@ -741,24 +839,148 @@ def test_list_chats_keeps_zero_pagination_semantics(monkeypatch): assert calls[-1] == (0, 2) assert len(res["data"]["chats"]) == 1 + +@pytest.mark.p2 +def test_chat_session_create_and_update_guard_matrix_unit(monkeypatch): + module = _load_chat_module(monkeypatch) + + _set_request_json(monkeypatch, module, {"name": "session"}) + monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: []) + res = _run(module.create_session.__wrapped__("chat-1")) + assert res["message"] == "No authorization." + + dia = SimpleNamespace(prompt_config={"prologue": "hello"}) + monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: [dia]) + monkeypatch.setattr(module.DialogService, "get_by_id", lambda _id: (True, dia)) + monkeypatch.setattr(module.ConversationService, "save", lambda **_kwargs: None) + monkeypatch.setattr(module.ConversationService, "get_by_id", lambda _id: (False, None)) + res = _run(module.create_session.__wrapped__("chat-1")) + assert "Fail to create a session" in res["message"] + + _set_request_json(monkeypatch, module, {}) + monkeypatch.setattr(module.ConversationService, "query", lambda **_kwargs: []) + res = _run(module.update_session.__wrapped__("chat-1", "session-1")) + assert res["message"] == "Session not found!" + + monkeypatch.setattr(module.ConversationService, "query", lambda **_kwargs: [SimpleNamespace(id="session-1")]) + monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: []) + res = _run(module.update_session.__wrapped__("chat-1", "session-1")) + assert res["message"] == "No authorization." + + monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: [SimpleNamespace(id="chat-1")]) + _set_request_json(monkeypatch, module, {"message": []}) + res = _run(module.update_session.__wrapped__("chat-1", "session-1")) + assert "`messages` cannot be changed." in res["message"] + + _set_request_json(monkeypatch, module, {"reference": []}) + res = _run(module.update_session.__wrapped__("chat-1", "session-1")) + assert "`reference` cannot be changed." in res["message"] + + _set_request_json(monkeypatch, module, {"name": ""}) + res = _run(module.update_session.__wrapped__("chat-1", "session-1")) + assert "`name` can not be empty." in res["message"] + + _set_request_json(monkeypatch, module, {"name": "renamed"}) + monkeypatch.setattr(module.ConversationService, "update_by_id", lambda *_args, **_kwargs: False) + res = _run(module.update_session.__wrapped__("chat-1", "session-1")) + assert res["message"] == "Session not found!" + + +@pytest.mark.p2 +def test_chat_session_list_projection_unit(monkeypatch): + module = _load_chat_module(monkeypatch) + monkeypatch.setattr( module, "request", SimpleNamespace( args=SimpleNamespace( get=lambda key, default=None: { - "keywords": "", - "page_size": 2, + "page": 1, + "page_size": 30, "orderby": "create_time", "desc": "true", - }.get(key, default), - getlist=lambda _key: [], + "id": None, + "name": None, + "user_id": None, + }.get(key, default) ) ), ) + monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: [SimpleNamespace(id="chat-1")]) + monkeypatch.setattr( + module.ConversationService, + "get_list", + lambda *_args, **_kwargs: [ + { + "id": "session-1", + "dialog_id": "chat-1", + "message": [{"role": "assistant", "content": "hello"}], + "reference": [], + } + ], + ) - res = module.list_chats.__wrapped__() + res = module.list_sessions.__wrapped__("chat-1") + assert res["data"][0]["chat_id"] == "chat-1" + assert res["data"][0]["messages"][0]["content"] == "hello" + monkeypatch.setattr( + module, + "request", + SimpleNamespace( + args=SimpleNamespace( + get=lambda key, default=None: { + "page": 1, + "page_size": 0, + "orderby": "create_time", + "desc": "true", + "id": None, + "name": None, + "user_id": None, + }.get(key, default) + ) + ), + ) + res = module.list_sessions.__wrapped__("chat-1") + assert res["data"] == [] + + +@pytest.mark.p2 +def test_chat_session_delete_routes_partial_duplicate_unit(monkeypatch): + module = _load_chat_module(monkeypatch) + + monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: [SimpleNamespace(id="chat-1")]) + _set_request_json(monkeypatch, module, {}) + res = _run(module.delete_sessions.__wrapped__("chat-1")) assert res["code"] == 0 - assert calls[-1] == (0, 2) - assert len(res["data"]["chats"]) == 1 + + monkeypatch.setattr(module.ConversationService, "delete_by_id", lambda *_args, **_kwargs: True) + + def _conversation_query(**kwargs): + if "dialog_id" in kwargs and "id" not in kwargs: + return [SimpleNamespace(id="seed")] + if kwargs.get("id") == "ok": + return [SimpleNamespace(id="ok")] + return [] + + monkeypatch.setattr(module.ConversationService, "query", _conversation_query) + + _set_request_json(monkeypatch, module, {"ids": ["ok", "bad"]}) + monkeypatch.setattr(module, "check_duplicate_ids", lambda ids, _kind: (ids, [])) + res = _run(module.delete_sessions.__wrapped__("chat-1")) + assert res["code"] == 0 + assert res["data"]["success_count"] == 1 + assert res["data"]["errors"] == ["The chat doesn't own the session bad"] + + _set_request_json(monkeypatch, module, {"ids": ["bad"]}) + monkeypatch.setattr(module, "check_duplicate_ids", lambda ids, _kind: (ids, [])) + res = _run(module.delete_sessions.__wrapped__("chat-1")) + assert res["message"] == "The chat doesn't own the session bad" + + _set_request_json(monkeypatch, module, {"ids": ["ok", "ok"]}) + monkeypatch.setattr(module, "check_duplicate_ids", lambda ids, _kind: (["ok"], ["Duplicate session ids: ok"])) + res = _run(module.delete_sessions.__wrapped__("chat-1")) + assert res["code"] == 0 + assert res["data"]["success_count"] == 1 + assert res["data"]["errors"] == ["Duplicate session ids: ok"] diff --git a/test/testcases/test_http_api/test_session_management/test_create_session_with_chat_assistant.py b/test/testcases/test_http_api/test_session_management/test_create_session_with_chat_assistant.py index 322fd1b7a..c91727b89 100644 --- a/test/testcases/test_http_api/test_session_management/test_create_session_with_chat_assistant.py +++ b/test/testcases/test_http_api/test_session_management/test_create_session_with_chat_assistant.py @@ -26,12 +26,8 @@ class TestAuthorization: @pytest.mark.parametrize( "invalid_auth, expected_code, expected_message", [ - (None, 0, "`Authorization` can't be empty"), - ( - RAGFlowHttpApiAuth(INVALID_API_TOKEN), - 109, - "Authentication error: API key is invalid!", - ), + (None, 401, ""), + (RAGFlowHttpApiAuth(INVALID_API_TOKEN), 401, ""), ], ) def test_invalid_auth(self, invalid_auth, expected_code, expected_message): @@ -74,7 +70,7 @@ class TestSessionWithChatAssistantCreate: "chat_assistant_id, expected_code, expected_message", [ ("", 100, ""), - ("invalid_chat_assistant_id", 102, "You do not own the assistant."), + ("invalid_chat_assistant_id", 109, "No authorization."), ], ) def test_invalid_chat_assistant_id(self, HttpApiAuth, chat_assistant_id, expected_code, expected_message): @@ -115,5 +111,5 @@ class TestSessionWithChatAssistantCreate: res = delete_chat_assistants(HttpApiAuth, {"ids": [chat_assistant_ids[0]]}) assert res["code"] == 0 res = create_session_with_chat_assistant(HttpApiAuth, chat_assistant_ids[0], {"name": "valid_name"}) - assert res["code"] == 102 - assert res["message"] == "You do not own the assistant." + assert res["code"] == 109 + assert res["message"] == "No authorization." diff --git a/test/testcases/test_http_api/test_session_management/test_delete_sessions_with_chat_assistant.py b/test/testcases/test_http_api/test_session_management/test_delete_sessions_with_chat_assistant.py index 637cb1f1d..ea67018ec 100644 --- a/test/testcases/test_http_api/test_session_management/test_delete_sessions_with_chat_assistant.py +++ b/test/testcases/test_http_api/test_session_management/test_delete_sessions_with_chat_assistant.py @@ -26,12 +26,8 @@ class TestAuthorization: @pytest.mark.parametrize( "invalid_auth, expected_code, expected_message", [ - (None, 0, "`Authorization` can't be empty"), - ( - RAGFlowHttpApiAuth(INVALID_API_TOKEN), - 109, - "Authentication error: API key is invalid!", - ), + (None, 401, ""), + (RAGFlowHttpApiAuth(INVALID_API_TOKEN), 401, ""), ], ) def test_invalid_auth(self, invalid_auth, expected_code, expected_message): @@ -48,8 +44,8 @@ class TestSessionWithChatAssistantDelete: ("", 100, ""), ( "invalid_chat_assistant_id", - 102, - "You don't own the chat", + 109, + "No authorization.", ), ], ) @@ -146,6 +142,7 @@ class TestSessionWithChatAssistantDelete: pytest.param("not json", 100, """AttributeError("\'str\' object has no attribute \'get\'")""", 5, marks=pytest.mark.skip), pytest.param(lambda r: {"ids": r[:1]}, 0, "", 4, marks=pytest.mark.p3), pytest.param(lambda r: {"ids": r}, 0, "", 0, marks=pytest.mark.p1), + pytest.param({"delete_all": True}, 0, "", 0, marks=pytest.mark.p1), pytest.param({"ids": []}, 0, "", 5, marks=pytest.mark.p3), ], ) diff --git a/test/testcases/test_http_api/test_session_management/test_list_sessions_with_chat_assistant.py b/test/testcases/test_http_api/test_session_management/test_list_sessions_with_chat_assistant.py index fb1f1737a..8db09d520 100644 --- a/test/testcases/test_http_api/test_session_management/test_list_sessions_with_chat_assistant.py +++ b/test/testcases/test_http_api/test_session_management/test_list_sessions_with_chat_assistant.py @@ -27,12 +27,8 @@ class TestAuthorization: @pytest.mark.parametrize( "invalid_auth, expected_code, expected_message", [ - (None, 0, "`Authorization` can't be empty"), - ( - RAGFlowHttpApiAuth(INVALID_API_TOKEN), - 109, - "Authentication error: API key is invalid!", - ), + (None, 401, ""), + (RAGFlowHttpApiAuth(INVALID_API_TOKEN), 401, ""), ], ) def test_invalid_auth(self, invalid_auth, expected_code, expected_message): @@ -246,5 +242,5 @@ class TestSessionsWithChatAssistantList: assert res["code"] == 0 res = list_session_with_chat_assistants(HttpApiAuth, chat_assistant_id) - assert res["code"] == 102 - assert "You don't own the assistant" in res["message"] + assert res["code"] == 109 + assert res["message"] == "No authorization." diff --git a/test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py b/test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py index c305cc20d..df28d68cb 100644 --- a/test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py +++ b/test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py @@ -530,18 +530,7 @@ def _load_session_module(monkeypatch): def test_create_and_update_guard_matrix(monkeypatch): module = _load_session_module(monkeypatch) - monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"name": "session"})) - monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: []) - res = _run(inspect.unwrap(module.create)("tenant-1", "chat-1")) - assert res["message"] == "You do not own the assistant." - - dia = SimpleNamespace(prompt_config={"prologue": "hello"}) - monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: [dia]) - monkeypatch.setattr(module.ConversationService, "save", lambda **_kwargs: None) - monkeypatch.setattr(module.ConversationService, "get_by_id", lambda _id: (False, None)) - res = _run(inspect.unwrap(module.create)("tenant-1", "chat-1")) - assert "Fail to create a session" in res["message"] - + monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({})) monkeypatch.setattr(module, "request", SimpleNamespace(args=_Args())) monkeypatch.setattr(module.UserCanvasService, "query", lambda **_kwargs: [SimpleNamespace(id="agent-1")]) @@ -556,34 +545,6 @@ def test_create_and_update_guard_matrix(monkeypatch): res = _run(inspect.unwrap(module.create_agent_session)("tenant-1", "agent-1")) assert res["message"] == "You cannot access the agent." - monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({})) - monkeypatch.setattr(module.ConversationService, "query", lambda **_kwargs: []) - res = _run(inspect.unwrap(module.update)("tenant-1", "chat-1", "session-1")) - assert res["message"] == "Session does not exist" - - monkeypatch.setattr(module.ConversationService, "query", lambda **_kwargs: [SimpleNamespace(id="session-1")]) - monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: []) - res = _run(inspect.unwrap(module.update)("tenant-1", "chat-1", "session-1")) - assert res["message"] == "You do not own the session" - - monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: [SimpleNamespace(id="chat-1")]) - monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"message": []})) - res = _run(inspect.unwrap(module.update)("tenant-1", "chat-1", "session-1")) - assert "`message` can not be change" in res["message"] - - monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"reference": []})) - res = _run(inspect.unwrap(module.update)("tenant-1", "chat-1", "session-1")) - assert "`reference` can not be change" in res["message"] - - monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"name": ""})) - res = _run(inspect.unwrap(module.update)("tenant-1", "chat-1", "session-1")) - assert "`name` can not be empty" in res["message"] - - monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"name": "renamed"})) - monkeypatch.setattr(module.ConversationService, "update_by_id", lambda *_args, **_kwargs: False) - res = _run(inspect.unwrap(module.update)("tenant-1", "chat-1", "session-1")) - assert res["message"] == "Session updates error" - @pytest.mark.p2 def test_chat_completion_metadata_and_stream_paths(monkeypatch): @@ -929,44 +890,6 @@ def test_agent_completions_stream_and_nonstream_unit(monkeypatch): assert res["data"].startswith("**ERROR**") -@pytest.mark.p2 -def test_list_session_projection_unit(monkeypatch): - module = _load_session_module(monkeypatch) - - monkeypatch.setattr(module, "request", SimpleNamespace(args=_Args({}))) - monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: [SimpleNamespace(id="chat-1")]) - - convs = [ - { - "id": "session-1", - "dialog_id": "chat-1", - "message": [{"role": "assistant", "content": "hello", "prompt": "internal"}], - "reference": [ - { - "chunks": [ - { - "chunk_id": "chunk-1", - "content_with_weight": "weighted", - "doc_id": "doc-1", - "docnm_kwd": "doc-name", - "kb_id": "kb-1", - "image_id": "img-1", - "positions": [1, 2], - } - ] - } - ], - } - ] - monkeypatch.setattr(module.ConversationService, "get_list", lambda *_args, **_kwargs: convs) - - res = _run(inspect.unwrap(module.list_session)("tenant-1", "chat-1")) - assert res["data"][0]["chat_id"] == "chat-1" - assert "reference" not in res["data"][0] - assert "prompt" not in res["data"][0]["messages"][0] - assert res["data"][0]["messages"][0]["reference"][0]["positions"] == [1, 2] - - @pytest.mark.p2 def test_list_agent_session_projection_unit(monkeypatch): module = _load_session_module(monkeypatch) @@ -1020,41 +943,6 @@ def test_list_agent_session_projection_unit(monkeypatch): def test_delete_routes_partial_duplicate_unit(monkeypatch): module = _load_session_module(monkeypatch) - monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: [SimpleNamespace(id="chat-1")]) - monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({})) - res = _run(inspect.unwrap(module.delete)("tenant-1", "chat-1")) - assert res["code"] == 0 - - monkeypatch.setattr(module.ConversationService, "delete_by_id", lambda *_args, **_kwargs: True) - - def _conversation_query(**kwargs): - if "id" not in kwargs: - return [SimpleNamespace(id="seed")] - if kwargs["id"] == "ok": - return [SimpleNamespace(id="ok")] - return [] - - monkeypatch.setattr(module.ConversationService, "query", _conversation_query) - - monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"ids": ["ok", "bad"]})) - monkeypatch.setattr(module, "check_duplicate_ids", lambda ids, _kind: (ids, [])) - res = _run(inspect.unwrap(module.delete)("tenant-1", "chat-1")) - assert res["code"] == 0 - assert res["data"]["success_count"] == 1 - assert res["data"]["errors"] == ["The chat doesn't own the session bad"] - - monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"ids": ["bad"]})) - monkeypatch.setattr(module, "check_duplicate_ids", lambda ids, _kind: (ids, [])) - res = _run(inspect.unwrap(module.delete)("tenant-1", "chat-1")) - assert res["message"] == "The chat doesn't own the session bad" - - monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"ids": ["ok", "ok"]})) - monkeypatch.setattr(module, "check_duplicate_ids", lambda ids, _kind: (["ok"], ["Duplicate session ids: ok"])) - res = _run(inspect.unwrap(module.delete)("tenant-1", "chat-1")) - assert res["code"] == 0 - assert res["data"]["success_count"] == 1 - assert res["data"]["errors"] == ["Duplicate session ids: ok"] - monkeypatch.setattr(module.UserCanvasService, "query", lambda **_kwargs: [SimpleNamespace(id="agent-1")]) monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({})) res = _run(inspect.unwrap(module.delete_agent_session)("tenant-1", "agent-1")) diff --git a/test/testcases/test_http_api/test_session_management/test_update_session_with_chat_assistant.py b/test/testcases/test_http_api/test_session_management/test_update_session_with_chat_assistant.py index fa22b27aa..7694c99c1 100644 --- a/test/testcases/test_http_api/test_session_management/test_update_session_with_chat_assistant.py +++ b/test/testcases/test_http_api/test_session_management/test_update_session_with_chat_assistant.py @@ -27,12 +27,8 @@ class TestAuthorization: @pytest.mark.parametrize( "invalid_auth, expected_code, expected_message", [ - (None, 0, "`Authorization` can't be empty"), - ( - RAGFlowHttpApiAuth(INVALID_API_TOKEN), - 109, - "Authentication error: API key is invalid!", - ), + (None, 401, ""), + (RAGFlowHttpApiAuth(INVALID_API_TOKEN), 401, ""), ], ) def test_invalid_auth(self, invalid_auth, expected_code, expected_message): @@ -72,7 +68,7 @@ class TestSessionWithChatAssistantUpdate: @pytest.mark.parametrize( "chat_assistant_id, expected_code, expected_message", [ - (INVALID_ID_32, 102, "Session does not exist"), + (INVALID_ID_32, 109, "No authorization."), ], ) def test_invalid_chat_assistant_id(self, HttpApiAuth, add_sessions_with_chat_assistant_func, chat_assistant_id, expected_code, expected_message): @@ -86,7 +82,7 @@ class TestSessionWithChatAssistantUpdate: "session_id, expected_code, expected_message", [ ("", 100, ""), - ("invalid_session_id", 102, "Session does not exist"), + ("invalid_session_id", 102, "Session not found!"), ], ) def test_invalid_session_id(self, HttpApiAuth, add_sessions_with_chat_assistant_func, session_id, expected_code, expected_message): @@ -145,5 +141,5 @@ class TestSessionWithChatAssistantUpdate: chat_assistant_id, session_ids = add_sessions_with_chat_assistant_func delete_chat_assistants(HttpApiAuth, {"ids": [chat_assistant_id]}) res = update_session_with_chat_assistant(HttpApiAuth, chat_assistant_id, session_ids[0], {"name": "valid_name"}) - assert res["code"] == 102 - assert res["message"] == "You do not own the session" + assert res["code"] == 109 + assert res["message"] == "No authorization." diff --git a/test/testcases/test_sdk_api/test_chat_assistant_management/test_chat_crud_unit.py b/test/testcases/test_sdk_api/test_chat_assistant_management/test_chat_crud_unit.py new file mode 100644 index 000000000..e713f43ff --- /dev/null +++ b/test/testcases/test_sdk_api/test_chat_assistant_management/test_chat_crud_unit.py @@ -0,0 +1,87 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import pytest +from ragflow_sdk import RAGFlow +from ragflow_sdk.modules.chat import Chat +from ragflow_sdk.modules.session import Session + + +class _DummyResponse: + def __init__(self, payload): + self._payload = payload + + def json(self): + return self._payload + + +@pytest.fixture(scope="session") +def auth(): + return "unit-auth" + + +@pytest.fixture(scope="session", autouse=True) +def set_tenant_info(): + return None + + +@pytest.mark.p2 +def test_chat_create_session_raises_server_error_message(monkeypatch): + client = RAGFlow("token", "http://localhost:9380") + chat = Chat(client, {"id": "chat-1"}) + + monkeypatch.setattr( + chat, + "post", + lambda *_args, **_kwargs: _DummyResponse({"code": 102, "message": "`name` can not be empty."}), + ) + + with pytest.raises(Exception) as exception_info: + chat.create_session(name="") + assert "`name` can not be empty." in str(exception_info.value), str(exception_info.value) + + +@pytest.mark.p2 +def test_chat_list_sessions_forwards_restful_query_params(monkeypatch): + client = RAGFlow("token", "http://localhost:9380") + chat = Chat(client, {"id": "chat-1"}) + calls = [] + + def _ok_get(path, params=None): + calls.append((path, params)) + return _DummyResponse( + { + "code": 0, + "data": [ + {"id": "session-1", "chat_id": "chat-1", "name": "one"}, + {"id": "session-2", "chat_id": "chat-1", "name": "two"}, + ], + } + ) + + monkeypatch.setattr(chat, "get", _ok_get) + + sessions = chat.list_sessions(page=2, page_size=2, orderby="create_time", desc=False, id="session-1", name="one", user_id="user-1") + assert len(sessions) == 2, str(sessions) + assert all(isinstance(item, Session) for item in sessions), str(sessions) + assert calls[-1][0] == "/chats/chat-1/sessions" + assert calls[-1][1]["page_size"] == 2 + assert calls[-1][1]["name"] == "one" + assert calls[-1][1]["user_id"] == "user-1" + + all_sessions = chat.list_sessions(page_size=0) + assert len(all_sessions) == 2, str(all_sessions) + assert calls[-1][1]["page_size"] == 0 diff --git a/test/testcases/test_web_api/test_conversation_app/test_conversation_routes_unit.py b/test/testcases/test_web_api/test_conversation_app/test_conversation_routes_unit.py deleted file mode 100644 index 2dd862759..000000000 --- a/test/testcases/test_web_api/test_conversation_app/test_conversation_routes_unit.py +++ /dev/null @@ -1,801 +0,0 @@ -# -# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import asyncio -import importlib.util -import sys -from copy import deepcopy -from pathlib import Path -from types import ModuleType, SimpleNamespace - -import pytest -from anyio import Path as AsyncPath - - -class _DummyManager: - def route(self, *_args, **_kwargs): - def decorator(func): - return func - - return decorator - - -class _AwaitableValue: - def __init__(self, value): - self._value = value - - def __await__(self): - async def _co(): - return self._value - - return _co().__await__() - - -class _DummyRequest: - def __init__(self, *, args=None, headers=None, form=None, files=None): - self.args = args or {} - self.headers = headers or {} - self.form = _AwaitableValue(form or {}) - self.files = _AwaitableValue(files or {}) - self.method = "POST" - self.content_length = 0 - - -class _DummyConversation: - def __init__(self, *, conv_id="conv-1", dialog_id="dialog-1", message=None, reference=None): - self.id = conv_id - self.dialog_id = dialog_id - self.message = message if message is not None else [] - self.reference = reference if reference is not None else [] - - def to_dict(self): - return { - "id": self.id, - "dialog_id": self.dialog_id, - "message": deepcopy(self.message), - "reference": deepcopy(self.reference), - } - - -class _DummyDialog: - def __init__(self, *, dialog_id="dialog-1", tenant_id="tenant-1", icon="avatar.png"): - self.id = dialog_id - self.tenant_id = tenant_id - self.icon = icon - self.prompt_config = {"prologue": "hello"} - self.llm_id = "" - self.llm_setting = {} - - def to_dict(self): - return { - "id": self.id, - "icon": self.icon, - "tenant_id": self.tenant_id, - "prompt_config": deepcopy(self.prompt_config), - } - - -class _DummyUploadedFile: - def __init__(self, filename): - self.filename = filename - self.saved_path = None - - async def save(self, path): - self.saved_path = path - await AsyncPath(path).write_bytes(b"audio-bytes") - - -def _run(coro): - return asyncio.run(coro) - - -def _load_conversation_module(monkeypatch): - repo_root = Path(__file__).resolve().parents[4] - - common_pkg = ModuleType("common") - common_pkg.__path__ = [str(repo_root / "common")] - monkeypatch.setitem(sys.modules, "common", common_pkg) - - deepdoc_pkg = ModuleType("deepdoc") - deepdoc_parser_pkg = ModuleType("deepdoc.parser") - deepdoc_parser_pkg.__path__ = [] - - class _StubPdfParser: - pass - - class _StubExcelParser: - pass - - class _StubDocxParser: - pass - - deepdoc_parser_pkg.PdfParser = _StubPdfParser - deepdoc_parser_pkg.ExcelParser = _StubExcelParser - deepdoc_parser_pkg.DocxParser = _StubDocxParser - deepdoc_pkg.parser = deepdoc_parser_pkg - monkeypatch.setitem(sys.modules, "deepdoc", deepdoc_pkg) - monkeypatch.setitem(sys.modules, "deepdoc.parser", deepdoc_parser_pkg) - - deepdoc_excel_module = ModuleType("deepdoc.parser.excel_parser") - deepdoc_excel_module.RAGFlowExcelParser = _StubExcelParser - monkeypatch.setitem(sys.modules, "deepdoc.parser.excel_parser", deepdoc_excel_module) - - deepdoc_parser_utils = ModuleType("deepdoc.parser.utils") - deepdoc_parser_utils.get_text = lambda *_args, **_kwargs: "" - monkeypatch.setitem(sys.modules, "deepdoc.parser.utils", deepdoc_parser_utils) - monkeypatch.setitem(sys.modules, "xgboost", ModuleType("xgboost")) - - apps_mod = ModuleType("api.apps") - apps_mod.current_user = SimpleNamespace(id="user-1") - apps_mod.login_required = lambda func: func - monkeypatch.setitem(sys.modules, "api.apps", apps_mod) - - # Create user_service module with TenantService stub if not already exists - if "api.db.services.user_service" not in sys.modules: - user_service_mod = ModuleType("api.db.services.user_service") - user_service_mod.UserService = SimpleNamespace() # Dummy UserService class - user_service_mod.TenantService = SimpleNamespace( - get_info_by=lambda _uid: [], - get_by_id=lambda _uid: (False, None) - ) - user_service_mod.UserTenantService = SimpleNamespace( - query=lambda **_kwargs: [] - ) - monkeypatch.setitem(sys.modules, "api.db.services.user_service", user_service_mod) - - module_name = "test_conversation_routes_unit_module" - module_path = repo_root / "api" / "apps" / "conversation_app.py" - spec = importlib.util.spec_from_file_location(module_name, module_path) - module = importlib.util.module_from_spec(spec) - module.manager = _DummyManager() - monkeypatch.setitem(sys.modules, module_name, module) - spec.loader.exec_module(module) - return module - - -def _set_request_json(monkeypatch, module, payload): - monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue(deepcopy(payload))) - - -async def _read_sse_text(response): - chunks = [] - async for chunk in response.response: - if isinstance(chunk, bytes): - chunks.append(chunk.decode("utf-8")) - else: - chunks.append(chunk) - return "".join(chunks) - - -@pytest.fixture(scope="session") -def auth(): - return "unit-auth" - - -@pytest.fixture(scope="session", autouse=True) -def set_tenant_info(): - return None - - -@pytest.mark.p2 -def test_set_conversation_update_create_and_errors(monkeypatch): - module = _load_conversation_module(monkeypatch) - - long_name = "n" * 300 - create_payload = { - "conversation_id": "conv-new", - "dialog_id": "dialog-1", - "is_new": True, - "name": long_name, - } - _set_request_json(monkeypatch, module, create_payload) - - saved = {} - monkeypatch.setattr(module.DialogService, "get_by_id", lambda _id: (True, _DummyDialog())) - monkeypatch.setattr(module.ConversationService, "save", lambda **kwargs: saved.update(kwargs) or True) - res = _run(module.set_conversation()) - assert res["code"] == 0 - assert len(res["data"]["name"]) == 255 - assert saved["user_id"] == "user-1" - - update_payload = { - "conversation_id": "conv-1", - "dialog_id": "dialog-1", - "is_new": False, - "name": "rename", - } - _set_request_json(monkeypatch, module, update_payload) - monkeypatch.setattr(module.ConversationService, "update_by_id", lambda *_args, **_kwargs: False) - res = _run(module.set_conversation()) - assert "Conversation not found" in res["message"] - - _set_request_json(monkeypatch, module, update_payload) - monkeypatch.setattr(module.ConversationService, "update_by_id", lambda *_args, **_kwargs: True) - monkeypatch.setattr(module.ConversationService, "get_by_id", lambda _id: (False, None)) - res = _run(module.set_conversation()) - assert "Fail to update" in res["message"] - - _set_request_json(monkeypatch, module, update_payload) - monkeypatch.setattr(module.ConversationService, "update_by_id", lambda *_args, **_kwargs: True) - monkeypatch.setattr(module.ConversationService, "get_by_id", lambda _id: (True, _DummyConversation(conv_id="conv-1"))) - res = _run(module.set_conversation()) - assert res["code"] == 0 - assert res["data"]["id"] == "conv-1" - - _set_request_json(monkeypatch, module, update_payload) - - def _raise_update(*_args, **_kwargs): - raise RuntimeError("update boom") - - monkeypatch.setattr(module.ConversationService, "update_by_id", _raise_update) - res = _run(module.set_conversation()) - assert res["code"] == module.RetCode.EXCEPTION_ERROR - assert "update boom" in res["message"] - - missing_dialog_payload = { - "conversation_id": "conv-2", - "dialog_id": "dialog-missing", - "is_new": True, - "name": "create", - } - _set_request_json(monkeypatch, module, missing_dialog_payload) - monkeypatch.setattr(module.DialogService, "get_by_id", lambda _id: (False, None)) - res = _run(module.set_conversation()) - assert res["message"] == "Dialog not found" - - _set_request_json(monkeypatch, module, missing_dialog_payload) - - def _raise_dialog(_id): - raise RuntimeError("dialog boom") - - monkeypatch.setattr(module.DialogService, "get_by_id", _raise_dialog) - res = _run(module.set_conversation()) - assert res["code"] == module.RetCode.EXCEPTION_ERROR - assert "dialog boom" in res["message"] - - -@pytest.mark.p2 -def test_get_and_getsse_authorization_and_reference_paths(monkeypatch): - module = _load_conversation_module(monkeypatch) - - conv = _DummyConversation(reference=[{"doc": "d"}, ["already-formatted"]]) - monkeypatch.setattr(module, "request", _DummyRequest(args={"conversation_id": "conv-1"})) - monkeypatch.setattr(module.ConversationService, "get_by_id", lambda _id: (True, conv)) - monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [SimpleNamespace(tenant_id="tenant-1")]) - monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: [SimpleNamespace(icon="bot-avatar")]) - monkeypatch.setattr(module, "chunks_format", lambda _ref: [{"chunk": "normalized"}]) - - res = _run(module.get()) - assert res["code"] == 0 - assert res["data"]["avatar"] == "bot-avatar" - assert res["data"]["reference"][0]["chunks"] == [{"chunk": "normalized"}] - - monkeypatch.setattr(module.ConversationService, "get_by_id", lambda _id: (False, None)) - res = _run(module.get()) - assert res["message"] == "Conversation not found!" - - monkeypatch.setattr(module, "request", _DummyRequest(args={"conversation_id": "conv-1"})) - monkeypatch.setattr(module.ConversationService, "get_by_id", lambda _id: (True, conv)) - monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: []) - res = _run(module.get()) - assert res["code"] == module.RetCode.OPERATING_ERROR - assert "Only owner of conversation" in res["message"] - - def _raise_get(*_args, **_kwargs): - raise RuntimeError("get boom") - - monkeypatch.setattr(module.ConversationService, "get_by_id", _raise_get) - res = _run(module.get()) - assert res["code"] == module.RetCode.EXCEPTION_ERROR - assert "get boom" in res["message"] - - monkeypatch.setattr(module, "request", _DummyRequest(headers={"Authorization": "Bearer"})) - res = module.getsse("dialog-1") - assert "Authorization is not valid" in res["message"] - - monkeypatch.setattr(module, "request", _DummyRequest(headers={"Authorization": "Bearer token-1"})) - monkeypatch.setattr(module.APIToken, "query", lambda **_kwargs: []) - res = module.getsse("dialog-1") - assert "API key is invalid" in res["message"] - - monkeypatch.setattr(module.APIToken, "query", lambda **_kwargs: [SimpleNamespace()]) - monkeypatch.setattr(module.DialogService, "get_by_id", lambda _id: (False, None)) - res = module.getsse("dialog-1") - assert res["message"] == "Dialog not found!" - - monkeypatch.setattr(module.DialogService, "get_by_id", lambda _id: (True, _DummyDialog())) - res = module.getsse("dialog-1") - assert res["code"] == 0 - assert res["data"]["avatar"] == "avatar.png" - assert "icon" not in res["data"] - - def _raise_getsse(_id): - raise RuntimeError("getsse boom") - - monkeypatch.setattr(module.DialogService, "get_by_id", _raise_getsse) - res = module.getsse("dialog-1") - assert res["code"] == module.RetCode.EXCEPTION_ERROR - assert "getsse boom" in res["message"] - - -@pytest.mark.p2 -def test_rm_and_list_conversation_guards(monkeypatch): - module = _load_conversation_module(monkeypatch) - - _set_request_json(monkeypatch, module, {"conversation_ids": ["conv-1"]}) - monkeypatch.setattr(module.ConversationService, "get_by_id", lambda _id: (False, None)) - res = _run(module.rm()) - assert "Conversation not found" in res["message"] - - conv = _DummyConversation(conv_id="conv-1", dialog_id="dialog-1") - _set_request_json(monkeypatch, module, {"conversation_ids": ["conv-1"]}) - monkeypatch.setattr(module.ConversationService, "get_by_id", lambda _id: (True, conv)) - monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [SimpleNamespace(tenant_id="tenant-1")]) - monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: []) - res = _run(module.rm()) - assert res["code"] == module.RetCode.OPERATING_ERROR - - deleted = [] - _set_request_json(monkeypatch, module, {"conversation_ids": ["conv-1"]}) - monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: [SimpleNamespace(id="dialog-1")]) - monkeypatch.setattr(module.ConversationService, "delete_by_id", lambda cid: deleted.append(cid) or True) - res = _run(module.rm()) - assert res["code"] == 0 - assert res["data"] is True - assert deleted == ["conv-1"] - - _set_request_json(monkeypatch, module, {"conversation_ids": ["conv-1"]}) - - def _raise_rm(*_args, **_kwargs): - raise RuntimeError("rm boom") - - monkeypatch.setattr(module.ConversationService, "get_by_id", _raise_rm) - res = _run(module.rm()) - assert res["code"] == module.RetCode.EXCEPTION_ERROR - assert "rm boom" in res["message"] - - monkeypatch.setattr(module, "request", _DummyRequest(args={"dialog_id": "dialog-1"})) - monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: []) - res = _run(module.list_conversation()) - assert res["code"] == module.RetCode.OPERATING_ERROR - assert "Only owner of dialog" in res["message"] - - monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: [SimpleNamespace(id="dialog-1")]) - monkeypatch.setattr(module.ConversationService, "model", SimpleNamespace(create_time="create_time")) - monkeypatch.setattr(module.ConversationService, "query", lambda **_kwargs: [_DummyConversation(conv_id="c1"), _DummyConversation(conv_id="c2")]) - res = _run(module.list_conversation()) - assert res["code"] == 0 - assert [x["id"] for x in res["data"]] == ["c1", "c2"] - - def _raise_list(**_kwargs): - raise RuntimeError("list boom") - - monkeypatch.setattr(module.ConversationService, "query", _raise_list) - res = _run(module.list_conversation()) - assert res["code"] == module.RetCode.EXCEPTION_ERROR - assert "list boom" in res["message"] - - -@pytest.mark.p2 -def test_completion_stream_and_nonstream_branches(monkeypatch): - module = _load_conversation_module(monkeypatch) - - conv = _DummyConversation(conv_id="conv-1", dialog_id="dialog-1", reference=[]) - dia = _DummyDialog(dialog_id="dialog-1", tenant_id="tenant-1") - monkeypatch.setattr(module.ConversationService, "get_by_id", lambda _id: (True, conv)) - monkeypatch.setattr(module.DialogService, "get_by_id", lambda _id: (True, dia)) - monkeypatch.setattr(module, "structure_answer", lambda _conv, ans, message_id, conv_id: {"answer": ans["answer"], "id": message_id, "conversation_id": conv_id, "reference": []}) - - updates = [] - monkeypatch.setattr(module.ConversationService, "update_by_id", lambda conv_id, payload: updates.append((conv_id, payload)) or True) - - stream_payload = { - "conversation_id": "conv-1", - "messages": [ - {"role": "system", "content": "ignored"}, - {"role": "assistant", "content": "ignored-first-assistant"}, - {"role": "user", "content": "hello", "id": "m-1"}, - ], - "stream": True, - } - - async def _stream_ok(_dia, sanitized, *_args, **_kwargs): - assert [m["role"] for m in sanitized] == ["user"] - yield {"answer": "sse-ok"} - - monkeypatch.setattr(module, "async_chat", _stream_ok) - _set_request_json(monkeypatch, module, stream_payload) - resp = _run(module.completion.__wrapped__()) - assert resp.headers["Content-Type"].startswith("text/event-stream") - sse_text = _run(_read_sse_text(resp)) - assert "sse-ok" in sse_text - assert '"data": true' in sse_text - assert updates - - async def _stream_error(_dia, _sanitized, *_args, **_kwargs): - raise RuntimeError("stream explode") - if False: - yield {"answer": "never"} - - monkeypatch.setattr(module, "async_chat", _stream_error) - _set_request_json(monkeypatch, module, stream_payload) - resp = _run(module.completion.__wrapped__()) - sse_text = _run(_read_sse_text(resp)) - assert "**ERROR**: stream explode" in sse_text - - async def _non_stream(_dia, _sanitized, **_kwargs): - yield {"answer": "plain-ok"} - - monkeypatch.setattr(module, "async_chat", _non_stream) - _set_request_json( - monkeypatch, - module, - { - "conversation_id": "conv-1", - "messages": [{"role": "user", "content": "plain", "id": "m-2"}], - "stream": False, - }, - ) - res = _run(module.completion.__wrapped__()) - assert res["code"] == 0 - assert res["data"]["answer"] == "plain-ok" - - monkeypatch.setattr(module.TenantLLMService, "get_api_key", lambda **_kwargs: False) - _set_request_json( - monkeypatch, - module, - { - "conversation_id": "conv-1", - "messages": [{"role": "user", "content": "embed", "id": "m-3"}], - "llm_id": "bad-model", - "stream": False, - }, - ) - res = _run(module.completion.__wrapped__()) - assert "Cannot use specified model bad-model" in res["message"] - - monkeypatch.setattr(module.TenantLLMService, "get_api_key", lambda **_kwargs: "api-key") - _set_request_json( - monkeypatch, - module, - { - "conversation_id": "conv-1", - "messages": [{"role": "user", "content": "embed", "id": "m-4"}], - "llm_id": "glm-4", - "temperature": 0.7, - "top_p": 0.2, - "stream": False, - }, - ) - res = _run(module.completion.__wrapped__()) - assert res["code"] == 0 - assert dia.llm_id == "glm-4" - assert dia.llm_setting == {"temperature": 0.7, "top_p": 0.2} - - _set_request_json( - monkeypatch, - module, - { - "conversation_id": "missing", - "messages": [{"role": "user", "content": "x", "id": "m-5"}], - "stream": False, - }, - ) - monkeypatch.setattr(module.ConversationService, "get_by_id", lambda _id: (False, None)) - res = _run(module.completion.__wrapped__()) - assert res["message"] == "Conversation not found!" - - monkeypatch.setattr(module.ConversationService, "get_by_id", lambda _id: (True, conv)) - monkeypatch.setattr(module.DialogService, "get_by_id", lambda _id: (False, None)) - _set_request_json( - monkeypatch, - module, - { - "conversation_id": "conv-1", - "messages": [{"role": "user", "content": "x", "id": "m-6"}], - "stream": False, - }, - ) - res = _run(module.completion.__wrapped__()) - assert res["message"] == "Dialog not found!" - - monkeypatch.setattr(module.ConversationService, "get_by_id", lambda _id: (_ for _ in ()).throw(RuntimeError("completion boom"))) - _set_request_json( - monkeypatch, - module, - { - "conversation_id": "conv-1", - "messages": [{"role": "user", "content": "x", "id": "m-7"}], - "stream": False, - }, - ) - res = _run(module.completion.__wrapped__()) - assert res["code"] == module.RetCode.EXCEPTION_ERROR - assert "completion boom" in res["message"] - - -@pytest.mark.p2 -def test_sequence2txt_validation_and_transcription_paths(monkeypatch): - module = _load_conversation_module(monkeypatch) - - monkeypatch.setattr(module, "request", _DummyRequest(form={"stream": "false"}, files={})) - res = _run(module.sequence2txt()) - assert "Missing 'file'" in res["message"] - - bad_file = _DummyUploadedFile("audio.txt") - monkeypatch.setattr(module, "request", _DummyRequest(form={"stream": "false"}, files={"file": bad_file})) - res = _run(module.sequence2txt()) - assert "Unsupported audio format" in res["message"] - - wav_file = _DummyUploadedFile("audio.wav") - monkeypatch.setattr(module, "request", _DummyRequest(form={"stream": "false"}, files={"file": wav_file})) - monkeypatch.setattr(module, "get_tenant_default_model_by_type", lambda *_args, **_kwargs: (_ for _ in ()).throw(LookupError("Tenant not found"))) - res = _run(module.sequence2txt()) - assert res["message"] == "Tenant not found" - - wav_file = _DummyUploadedFile("audio.wav") - monkeypatch.setattr(module, "request", _DummyRequest(form={"stream": "false"}, files={"file": wav_file})) - monkeypatch.setattr(module, "get_tenant_default_model_by_type", lambda *_args, **_kwargs: (_ for _ in ()).throw(Exception("No default speech2text model is set."))) - res = _run(module.sequence2txt()) - assert res["message"] == "No default speech2text model is set." - - class _SyncAsr: - def transcription(self, _path): - return "transcribed text" - - def stream_transcription(self, _path): - return [] - - wav_file = _DummyUploadedFile("audio.wav") - monkeypatch.setattr(module, "request", _DummyRequest(form={"stream": "false"}, files={"file": wav_file})) - monkeypatch.setattr( - module, - "get_tenant_default_model_by_type", - lambda *_args, **_kwargs: {"llm_factory": "test", "llm_name": "asr-model", "model_type": module.LLMType.SPEECH2TEXT.value}, - ) - monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: _SyncAsr()) - monkeypatch.setattr(module.os, "remove", lambda _path: (_ for _ in ()).throw(RuntimeError("remove failed"))) - res = _run(module.sequence2txt()) - assert res["code"] == 0 - assert res["data"]["text"] == "transcribed text" - - class _StreamAsr: - def transcription(self, _path): - return "" - - def stream_transcription(self, _path): - yield {"event": "partial", "text": "hello"} - - wav_file = _DummyUploadedFile("audio.wav") - monkeypatch.setattr(module, "request", _DummyRequest(form={"stream": "true"}, files={"file": wav_file})) - monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: _StreamAsr()) - resp = _run(module.sequence2txt()) - assert resp.headers["Content-Type"].startswith("text/event-stream") - sse_text = _run(_read_sse_text(resp)) - assert '"event": "partial"' in sse_text - - class _ErrorStreamAsr: - def transcription(self, _path): - return "" - - def stream_transcription(self, _path): - raise RuntimeError("stream asr boom") - - wav_file = _DummyUploadedFile("audio.wav") - monkeypatch.setattr(module, "request", _DummyRequest(form={"stream": "true"}, files={"file": wav_file})) - monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: _ErrorStreamAsr()) - resp = _run(module.sequence2txt()) - sse_text = _run(_read_sse_text(resp)) - assert "stream asr boom" in sse_text - - -@pytest.mark.p2 -def test_tts_request_parse_entry(monkeypatch): - module = _load_conversation_module(monkeypatch) - _set_request_json(monkeypatch, module, {"text": "A。B"}) - monkeypatch.setattr(module, "get_tenant_default_model_by_type", lambda *_args, **_kwargs: (_ for _ in ()).throw(LookupError("Tenant not found"))) - res = _run(module.tts()) - assert res["message"] == "Tenant not found" - - monkeypatch.setattr(module, "get_tenant_default_model_by_type", lambda *_args, **_kwargs: (_ for _ in ()).throw(Exception("No default tts model is set."))) - res = _run(module.tts()) - assert res["message"] == "No default tts model is set." - - class _TTSOk: - def tts(self, txt): - if not txt: - return [] - yield f"chunk-{txt}".encode("utf-8") - - monkeypatch.setattr( - module, - "get_tenant_default_model_by_type", - lambda *_args, **_kwargs: {"llm_factory": "test", "llm_name": "tts-x", "model_type": module.LLMType.TTS.value}, - ) - monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: _TTSOk()) - resp = _run(module.tts()) - assert resp.mimetype == "audio/mpeg" - assert resp.headers.get("Cache-Control") == "no-cache" - assert resp.headers.get("Connection") == "keep-alive" - assert resp.headers.get("X-Accel-Buffering") == "no" - stream_text = _run(_read_sse_text(resp)) - assert "chunk-A" in stream_text - assert "chunk-B" in stream_text - - class _TTSErr: - def tts(self, _txt): - raise RuntimeError("tts boom") - - monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: _TTSErr()) - resp = _run(module.tts()) - stream_text = _run(_read_sse_text(resp)) - assert '"code": 500' in stream_text - assert "**ERROR**: tts boom" in stream_text - - -@pytest.mark.p2 -def test_delete_msg_and_thumbup_matrix_unit(monkeypatch): - module = _load_conversation_module(monkeypatch) - - updates = [] - monkeypatch.setattr(module.ConversationService, "update_by_id", lambda conv_id, payload: updates.append((conv_id, payload)) or True) - - _set_request_json(monkeypatch, module, {"conversation_id": "missing", "message_id": "pair-1"}) - monkeypatch.setattr(module.ConversationService, "get_by_id", lambda _id: (False, None)) - res = _run(module.delete_msg.__wrapped__()) - assert res["message"] == "Conversation not found!" - - conv = _DummyConversation( - conv_id="conv-del", - message=[ - {"id": "other", "role": "user"}, - {"id": "pair-1", "role": "user"}, - {"id": "pair-1", "role": "assistant"}, - ], - reference=[{"chunks": [{"id": "c1"}]}], - ) - _set_request_json(monkeypatch, module, {"conversation_id": "conv-del", "message_id": "pair-1"}) - monkeypatch.setattr(module.ConversationService, "get_by_id", lambda _id: (True, conv)) - res = _run(module.delete_msg.__wrapped__()) - assert res["code"] == 0 - assert [m["id"] for m in res["data"]["message"]] == ["other"] - assert res["data"]["reference"] == [] - assert updates[-1][0] == "conv-del" - - _set_request_json(monkeypatch, module, {"conversation_id": "missing", "message_id": "assistant-1", "thumbup": True}) - monkeypatch.setattr(module.ConversationService, "get_by_id", lambda _id: (False, None)) - res = _run(module.thumbup.__wrapped__()) - assert res["message"] == "Conversation not found!" - - conv_up = _DummyConversation( - conv_id="conv-up", - message=[{"id": "assistant-1", "role": "assistant", "feedback": "old"}], - ) - _set_request_json(monkeypatch, module, {"conversation_id": "conv-up", "message_id": "assistant-1", "thumbup": True}) - monkeypatch.setattr(module.ConversationService, "get_by_id", lambda _id: (True, conv_up)) - res = _run(module.thumbup.__wrapped__()) - assert res["code"] == 0 - assert res["data"]["message"][0]["thumbup"] is True - assert "feedback" not in res["data"]["message"][0] - - conv_down = _DummyConversation(conv_id="conv-down", message=[{"id": "assistant-2", "role": "assistant"}]) - _set_request_json( - monkeypatch, - module, - {"conversation_id": "conv-down", "message_id": "assistant-2", "thumbup": False, "feedback": "needs sources"}, - ) - monkeypatch.setattr(module.ConversationService, "get_by_id", lambda _id: (True, conv_down)) - res = _run(module.thumbup.__wrapped__()) - assert res["code"] == 0 - assert res["data"]["message"][0]["thumbup"] is False - assert res["data"]["message"][0]["feedback"] == "needs sources" - - -@pytest.mark.p2 -def test_ask_about_stream_search_config_matrix_unit(monkeypatch): - module = _load_conversation_module(monkeypatch) - _set_request_json(monkeypatch, module, {"question": "q", "kb_ids": ["kb-1"], "search_id": "search-1"}) - monkeypatch.setattr(module.SearchService, "get_detail", lambda _sid: {"search_config": {"mode": "test"}}) - - captured = {} - - async def _fake_async_ask(question, kb_ids, uid, search_config=None): - captured["question"] = question - captured["kb_ids"] = kb_ids - captured["uid"] = uid - captured["search_config"] = search_config - yield {"answer": "first"} - raise RuntimeError("ask boom") - - monkeypatch.setattr(module, "async_ask", _fake_async_ask) - resp = _run(module.ask_about.__wrapped__()) - assert resp.headers["Content-Type"] == "text/event-stream; charset=utf-8" - sse_text = _run(_read_sse_text(resp)) - assert '"answer": "first"' in sse_text - assert "**ERROR**: ask boom" in sse_text - assert '"data": true' in sse_text.lower() - assert captured == {"question": "q", "kb_ids": ["kb-1"], "uid": "user-1", "search_config": {"mode": "test"}} - - -@pytest.mark.p2 -def test_mindmap_and_related_questions_matrix_unit(monkeypatch): - module = _load_conversation_module(monkeypatch) - - def _search_detail(_sid): - return { - "tenant_id": "tenant-x", - "search_config": { - "kb_ids": ["kb-2", "kb-3"], - "chat_id": "chat-x", - "llm_setting": {"temperature": 0.2, "parameter": {"k": "v"}}, - }, - } - - monkeypatch.setattr(module.SearchService, "get_detail", _search_detail) - - _set_request_json(monkeypatch, module, {"question": "mindmap-q", "kb_ids": ["kb-1", "kb-2"], "search_id": "search-1"}) - mindmap_calls = {} - - async def _gen_ok(question, kb_ids, tenant_id, search_config): - mindmap_calls["question"] = question - mindmap_calls["kb_ids"] = set(kb_ids) - mindmap_calls["tenant_id"] = tenant_id - mindmap_calls["search_config"] = search_config - return {"nodes": [question]} - - monkeypatch.setattr(module, "gen_mindmap", _gen_ok) - res = _run(module.mindmap.__wrapped__()) - assert res["code"] == 0 - assert res["data"] == {"nodes": ["mindmap-q"]} - assert mindmap_calls["kb_ids"] == {"kb-1", "kb-2", "kb-3"} - assert mindmap_calls["tenant_id"] == "tenant-x" - assert set(mindmap_calls["search_config"]["kb_ids"]) == {"kb-1", "kb-2", "kb-3"} - - async def _gen_error(*_args, **_kwargs): - return {"error": "mindmap boom"} - - monkeypatch.setattr(module, "gen_mindmap", _gen_error) - res = _run(module.mindmap.__wrapped__()) - assert "mindmap boom" in res["message"] - - llm_calls = {} - - class _FakeChat: - async def async_chat(self, prompt, messages, options): - llm_calls["prompt"] = prompt - llm_calls["messages"] = messages - llm_calls["options"] = options - return "1. Alpha\n2. Beta\nignored" - - def _fake_bundle(tenant_id, model_config, lang="Chinese", **kwargs): - llm_calls["bundle"] = (tenant_id, model_config) - return _FakeChat() - - monkeypatch.setattr(module, "LLMBundle", _fake_bundle) - monkeypatch.setattr(module, "load_prompt", lambda name: f"prompt-{name}") - monkeypatch.setattr( - module, - "get_model_config_by_type_and_name", - lambda *_args, **_kwargs: {"llm_factory": "test", "llm_name": "chat-x", "model_type": module.LLMType.CHAT.value}, - ) - _set_request_json(monkeypatch, module, {"question": "solar", "search_id": "search-1"}) - res = _run(module.related_questions.__wrapped__()) - assert res["code"] == 0 - assert res["data"] == ["Alpha", "Beta"] - assert llm_calls["bundle"][0] == "user-1" - assert llm_calls["options"] == {"temperature": 0.2} - assert llm_calls["prompt"] == "prompt-related_question" - assert "Keywords: solar" in llm_calls["messages"][0]["content"] diff --git a/web/src/components/originui/select-with-search.tsx b/web/src/components/originui/select-with-search.tsx index 840284042..8b55d1f8b 100644 --- a/web/src/components/originui/select-with-search.tsx +++ b/web/src/components/originui/select-with-search.tsx @@ -2,7 +2,6 @@ import { CheckIcon, ChevronDownIcon, XIcon } from 'lucide-react'; import { - Fragment, MouseEventHandler, ReactNode, forwardRef, @@ -207,40 +206,43 @@ export const SelectWithSearch = forwardRef<
- {options.map((group) => { + {options.map((group, groupIndex) => { if (group.options) { return ( - - - {group.options.map((option) => ( - - {option.label} + + {group.options.map((option, optionIndex) => ( + + {option.label} - {value === option.value && ( - - )} - - ))} - - + {value === option.value && ( + + )} + + ))} + ); } else { return ( , + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + +)); +SelectGroup.displayName = SelectPrimitive.Group.displayName; const SelectValue = SelectPrimitive.Value; diff --git a/web/src/hooks/logic-hooks.ts b/web/src/hooks/logic-hooks.ts index dfdea1f5e..55981fbc9 100644 --- a/web/src/hooks/logic-hooks.ts +++ b/web/src/hooks/logic-hooks.ts @@ -201,9 +201,7 @@ function useSetDoneRecord() { }; } -export const useSendMessageWithSse = ( - url: string = api.completeConversation, -) => { +export const useSendMessageWithSse = () => { const [answer, setAnswer] = useState({} as IAnswer); const [done, setDone] = useState(true); const { doneRecord, clearDoneRecord, setDoneRecordById, allDone } = @@ -238,6 +236,7 @@ export const useSendMessageWithSse = ( const send = useCallback( async ( + url: string, body: any, controller?: AbortController, ): Promise<{ response: Response; data: ResponseType } | undefined> => { @@ -322,7 +321,7 @@ export const useSendMessageWithSse = ( // Swallow fetch errors silently } }, - [initializeSseRef, setDoneValue, url, resetAnswer], + [initializeSseRef, setDoneValue, resetAnswer], ); const stopOutputMessage = useCallback(() => { @@ -342,7 +341,7 @@ export const useSendMessageWithSse = ( }; }; -export const useSpeechWithSse = (url: string = api.tts) => { +export const useSpeechWithSse = (url: string = api.chatsTts) => { const read = useCallback( async (body: any) => { const response = await fetch(url, { diff --git a/web/src/hooks/use-chat-request.ts b/web/src/hooks/use-chat-request.ts index 149d26c26..145111f4f 100644 --- a/web/src/hooks/use-chat-request.ts +++ b/web/src/hooks/use-chat-request.ts @@ -13,10 +13,9 @@ import { } from '@/interfaces/request/chat'; import i18n from '@/locales/config'; import { useGetSharedChatSearchParams } from '@/pages/next-chats/hooks/use-send-shared-message'; -import { isConversationIdExist } from '@/pages/next-chats/utils'; import chatService from '@/services/next-chat-service'; import api from '@/utils/api'; -import { buildMessageListWithUuid, generateConversationId } from '@/utils/chat'; +import { buildMessageListWithUuid } from '@/utils/chat'; import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query'; import { useDebounce } from 'ahooks'; import { has } from 'lodash'; @@ -36,11 +35,12 @@ export const enum ChatApiAction { UpdateChat = 'updateChat', PatchChat = 'patchChat', FetchChat = 'fetchChat', - FetchConversationList = 'fetchConversationList', - FetchConversation = 'fetchConversation', - FetchConversationManually = 'fetchConversationManually', - UpdateConversation = 'updateConversation', - RemoveConversation = 'removeConversation', + FetchSessionList = 'fetchSessionList', + FetchSession = 'fetchSession', + FetchSessionManually = 'fetchSessionManually', + CreateSession = 'createSession', + UpdateSession = 'updateSession', + RemoveSession = 'removeSession', DeleteMessage = 'deleteMessage', FetchMindMap = 'fetchMindMap', FetchRelatedQuestions = 'fetchRelatedQuestions', @@ -48,7 +48,6 @@ export const enum ChatApiAction { FetchExternalChatInfo = 'fetchExternalChatInfo', Feedback = 'feedback', CreateSharedConversation = 'createSharedConversation', - FetchConversationSse = 'fetchConversationSSE', } export const useGetChatSearchParams = () => { @@ -262,9 +261,9 @@ export const useFetchChat = () => { return { data, loading, refetch }; }; -//#region Conversation +//#region Session -export const useFetchConversationList = () => { +export const useFetchSessionList = () => { const { id } = useParams(); const { searchString, handleInputChange } = useHandleSearchStrChange(); @@ -274,7 +273,7 @@ export const useFetchConversationList = () => { isFetching: loading, refetch, } = useQuery({ - queryKey: [ChatApiAction.FetchConversationList, id], + queryKey: [ChatApiAction.FetchSessionList, id], initialData: [], gcTime: 0, refetchOnWindowFocus: false, @@ -285,8 +284,8 @@ export const useFetchConversationList = () => { : data; }, queryFn: async () => { - const { data } = await chatService.listConversation( - { params: { dialog_id: id } }, + const { data } = await chatService.listSessions( + { url: api.listSessions(id!) }, true, ); return data?.data; @@ -296,35 +295,57 @@ export const useFetchConversationList = () => { return { data, loading, refetch, searchString, handleInputChange }; }; -export function useFetchConversationManually() { +export function useFetchSessionManually() { + const { id: chatId } = useParams(); const { data, isPending: loading, mutateAsync, } = useMutation({ - mutationKey: [ChatApiAction.FetchConversationManually], - mutationFn: async (conversationId) => { - const { data } = await chatService.getConversation( - { - params: { - conversationId, - }, - }, + mutationKey: [ChatApiAction.FetchSessionManually], + mutationFn: async (sessionId) => { + const { data } = await chatService.getSession( + { url: api.getSession(chatId!, sessionId) }, true, ); const conversation = data?.data ?? {}; - const messageList = buildMessageListWithUuid(conversation?.message); + const messageList = buildMessageListWithUuid(conversation?.messages); - return { ...conversation, message: messageList }; + return { ...conversation, messages: messageList }; }, }); - return { data, loading, fetchConversationManually: mutateAsync }; + return { data, loading, fetchSessionManually: mutateAsync }; } -export const useUpdateConversation = () => { +export const useCreateSession = () => { + const queryClient = useQueryClient(); + const { + data, + isPending: loading, + mutateAsync, + } = useMutation({ + mutationKey: [ChatApiAction.CreateSession], + mutationFn: async ({ chatId, name }: { chatId: string; name: string }) => { + const { data } = await chatService.createSession( + { url: api.createSession(chatId), data: { name } }, + true, + ); + if (data.code === 0) { + queryClient.invalidateQueries({ + queryKey: [ChatApiAction.FetchSessionList], + }); + } + return data; + }, + }); + + return { data, loading, createSession: mutateAsync }; +}; + +export const useUpdateSession = () => { const { t } = useTranslation(); const queryClient = useQueryClient(); const { @@ -332,17 +353,23 @@ export const useUpdateConversation = () => { isPending: loading, mutateAsync, } = useMutation({ - mutationKey: [ChatApiAction.UpdateConversation], - mutationFn: async (params: Record) => { - const { data } = await chatService.setConversation({ - ...params, - conversation_id: params.conversation_id - ? params.conversation_id - : generateConversationId(), - }); + mutationKey: [ChatApiAction.UpdateSession], + mutationFn: async ({ + chatId, + sessionId, + params, + }: { + chatId: string; + sessionId: string; + params: Record; + }) => { + const { data } = await chatService.updateSession( + { url: api.updateSession(chatId, sessionId), data: params }, + true, + ); if (data.code === 0) { queryClient.invalidateQueries({ - queryKey: [ChatApiAction.FetchConversationList], + queryKey: [ChatApiAction.FetchSessionList], }); message.success(t(`message.modified`)); } @@ -350,38 +377,39 @@ export const useUpdateConversation = () => { }, }); - return { data, loading, updateConversation: mutateAsync }; + return { data, loading, updateSession: mutateAsync }; }; -export const useRemoveConversation = () => { +export const useRemoveSessions = () => { const queryClient = useQueryClient(); - const { dialogId } = useGetChatSearchParams(); + const { id: chatId } = useParams(); const { data, isPending: loading, mutateAsync, } = useMutation({ - mutationKey: [ChatApiAction.RemoveConversation], - mutationFn: async (conversationIds: string[]) => { - const { data } = await chatService.removeConversation({ - conversationIds, - dialogId, - }); + mutationKey: [ChatApiAction.RemoveSession], + mutationFn: async (sessionIds: string[]) => { + const { data } = await chatService.removeSessions( + { url: api.removeSessions(chatId!), data: { ids: sessionIds } }, + true, + ); if (data.code === 0) { queryClient.invalidateQueries({ - queryKey: [ChatApiAction.FetchConversationList], + queryKey: [ChatApiAction.FetchSessionList], }); } return data.code; }, }); - return { data, loading, removeConversation: mutateAsync }; + return { data, loading, removeSessions: mutateAsync }; }; export const useDeleteMessage = () => { const { conversationId } = useGetChatSearchParams(); + const { id: chatId } = useParams(); const { t } = useTranslation(); const { @@ -391,10 +419,10 @@ export const useDeleteMessage = () => { } = useMutation({ mutationKey: [ChatApiAction.DeleteMessage], mutationFn: async (messageId: string) => { - const { data } = await chatService.deleteMessage({ - messageId, - conversationId, - }); + const { data } = await chatService.deleteMessage( + { url: api.deleteMessage(chatId!, conversationId, messageId) }, + true, + ); if (data.code === 0) { message.success(t(`message.deleted`)); @@ -409,6 +437,7 @@ export const useDeleteMessage = () => { export const useFeedback = () => { const { conversationId } = useGetChatSearchParams(); + const { id: chatId } = useParams(); const { data, @@ -417,10 +446,13 @@ export const useFeedback = () => { } = useMutation({ mutationKey: [ChatApiAction.Feedback], mutationFn: async (params: IFeedbackRequestBody) => { - const { data } = await chatService.thumbup({ - ...params, - conversationId, - }); + const { data } = await chatService.thumbup( + { + url: api.thumbup(chatId!, conversationId, params.messageId!), + data: { thumbup: params.thumbup, feedback: params.feedback }, + }, + true, + ); if (data.code === 0) { message.success(i18n.t(`message.operated`)); } @@ -519,7 +551,7 @@ export const useFetchExternalChatInfo = () => { return { data, loading, refetch }; }; -//#endregion +//#endregion Session //#region search page @@ -533,7 +565,7 @@ export const useFetchMindMap = () => { gcTime: 0, mutationFn: async (params: IAskRequestBody) => { try { - const ret = await chatService.getMindMap(params); + const ret = await chatService.chatsMindmap(params); return ret?.data?.data ?? {}; } catch (error: any) { if (has(error, 'message')) { @@ -557,7 +589,7 @@ export const useFetchRelatedQuestions = () => { mutationKey: [ChatApiAction.FetchRelatedQuestions], gcTime: 0, mutationFn: async (question: string): Promise => { - const { data } = await chatService.getRelatedQuestions({ question }); + const { data } = await chatService.chatsRelatedQuestions({ question }); return data?.data ?? []; }, @@ -566,47 +598,3 @@ export const useFetchRelatedQuestions = () => { return { data, loading, fetchRelatedQuestions: mutateAsync }; }; //#endregion - -export const useCreateNextSharedConversation = () => { - const { - data, - isPending: loading, - mutateAsync, - } = useMutation({ - mutationKey: [ChatApiAction.CreateSharedConversation], - mutationFn: async (userId?: string) => { - const { data } = await chatService.createExternalConversation({ userId }); - - return data; - }, - }); - - return { data, loading, createSharedConversation: mutateAsync }; -}; - -export const useFetchNextConversationSSE = () => { - const { isNew } = useGetChatSearchParams(); - const { sharedId } = useGetSharedChatSearchParams(); - const { - data, - isFetching: loading, - refetch, - } = useQuery({ - queryKey: [ChatApiAction.FetchConversationSse, sharedId], - initialData: {} as IClientConversation, - gcTime: 0, - refetchOnWindowFocus: false, - queryFn: async () => { - if (isNew !== 'true' && isConversationIdExist(sharedId || '')) { - if (!sharedId) return {}; - const { data } = await chatService.getConversationSSE(sharedId); - const conversation = data?.data ?? {}; - const messageList = buildMessageListWithUuid(conversation?.message); - return { ...conversation, message: messageList }; - } - return { message: [] }; - }, - }); - - return { data, loading, refetch }; -}; diff --git a/web/src/hooks/use-send-message.ts b/web/src/hooks/use-send-message.ts index 50848b3dc..8ebcade20 100644 --- a/web/src/hooks/use-send-message.ts +++ b/web/src/hooks/use-send-message.ts @@ -2,7 +2,6 @@ import message from '@/components/ui/message'; import { Authorization } from '@/constants/authorization'; import { IReferenceObject } from '@/interfaces/database/chat'; import { BeginQuery } from '@/pages/agent/interface'; -import api from '@/utils/api'; import { getAuthorization } from '@/utils/authorization-util'; import { EventSourceParserStream } from 'eventsource-parser/stream'; import { useCallback, useRef, useState } from 'react'; @@ -86,7 +85,7 @@ export type IChatEvent = INodeEvent | IMessageEvent | IMessageEndEvent; export type IEventList = Array; -export const useSendMessageBySSE = (url: string = api.completeConversation) => { +export const useSendMessageBySSE = (url: string) => { const [answerList, setAnswerList] = useState([]); const [done, setDone] = useState(true); const timer = useRef(); diff --git a/web/src/interfaces/database/chat.ts b/web/src/interfaces/database/chat.ts index cb879456e..eeb298fc1 100644 --- a/web/src/interfaces/database/chat.ts +++ b/web/src/interfaces/database/chat.ts @@ -82,10 +82,10 @@ interface Manual { export interface IConversation { create_date: string; create_time: number; - dialog_id: string; + chat_id: string; id: string; avatar: string; - message: Message[]; + messages: Message[]; reference: IReference[]; name: string; update_date: string; @@ -197,7 +197,7 @@ export interface IMessage extends Message { } export interface IClientConversation extends IConversation { - message: IMessage[]; + messages: IMessage[]; } export interface UploadResponseDataType { diff --git a/web/src/pages/agent/utils/chat.ts b/web/src/pages/agent/utils/chat.ts index a2859b6dc..369cb5aa4 100644 --- a/web/src/pages/agent/utils/chat.ts +++ b/web/src/pages/agent/utils/chat.ts @@ -3,10 +3,10 @@ import { IMessage, IReference } from '@/interfaces/database/chat'; import { isEmpty } from 'lodash'; export const buildAgentMessageItemReference = ( - conversation: { message: IMessage[]; reference: IReference[] }, + conversation: { messages: IMessage[]; reference: IReference[] }, message: IMessage, ) => { - const assistantMessages = conversation.message?.filter( + const assistantMessages = conversation.messages?.filter( (x) => x.role === MessageType.Assistant, ); const referenceIndex = assistantMessages.findIndex( diff --git a/web/src/pages/next-chats/chat/chat-box/next-multiple-chat-box.tsx b/web/src/pages/next-chats/chat/chat-box/next-multiple-chat-box.tsx index 089a8face..071499271 100644 --- a/web/src/pages/next-chats/chat/chat-box/next-multiple-chat-box.tsx +++ b/web/src/pages/next-chats/chat/chat-box/next-multiple-chat-box.tsx @@ -240,7 +240,7 @@ const ChatCard = forwardRef(function ChatCard( avatarDialog={currentDialog.icon} reference={buildMessageItemReference( { - message: derivedMessages, + messages: derivedMessages, reference: conversation.reference, }, message, diff --git a/web/src/pages/next-chats/chat/chat-box/single-chat-box.tsx b/web/src/pages/next-chats/chat/chat-box/single-chat-box.tsx index 737aaf059..625441d35 100644 --- a/web/src/pages/next-chats/chat/chat-box/single-chat-box.tsx +++ b/web/src/pages/next-chats/chat/chat-box/single-chat-box.tsx @@ -56,11 +56,11 @@ export function SingleChatBox({ const showInternet = useShowInternet(); useEffect(() => { - const messages = conversation?.message; + const messages = conversation?.messages; if (Array.isArray(messages)) { setDerivedMessages(messages); } - }, [conversation?.message, setDerivedMessages]); + }, [conversation?.messages, setDerivedMessages]); useEffect(() => { // Clear the message list after deleting the conversation. @@ -90,7 +90,7 @@ export function SingleChatBox({ avatarDialog={currentDialog.icon} reference={buildMessageItemReference( { - message: derivedMessages, + messages: derivedMessages, reference: conversation.reference, }, message, diff --git a/web/src/pages/next-chats/chat/conversation-dropdown.tsx b/web/src/pages/next-chats/chat/conversation-dropdown.tsx index 8d8098660..6c1db96a9 100644 --- a/web/src/pages/next-chats/chat/conversation-dropdown.tsx +++ b/web/src/pages/next-chats/chat/conversation-dropdown.tsx @@ -7,7 +7,7 @@ import { } from '@/components/ui/dropdown-menu'; import { useGetChatSearchParams, - useRemoveConversation, + useRemoveSessions, } from '@/hooks/use-chat-request'; import { IConversation } from '@/interfaces/database/chat'; import { Trash2 } from 'lucide-react'; @@ -25,7 +25,7 @@ export function ConversationDropdown({ }) { const { t } = useTranslation(); const { setConversationBoth } = useChatUrlParams(); - const { removeConversation } = useRemoveConversation(); + const { removeSessions } = useRemoveSessions(); const { conversationId, isNew } = useGetChatSearchParams(); const handleDelete: MouseEventHandler = @@ -36,7 +36,7 @@ export function ConversationDropdown({ setConversationBoth('', ''); } } else { - const code = await removeConversation([conversation.id]); + const code = await removeSessions([conversation.id]); if (code === 0) { setConversationBoth('', ''); } @@ -45,7 +45,7 @@ export function ConversationDropdown({ conversation.id, conversationId, isNew, - removeConversation, + removeSessions, removeTemporaryConversation, setConversationBoth, ]); diff --git a/web/src/pages/next-chats/chat/index.tsx b/web/src/pages/next-chats/chat/index.tsx index 62d8f2051..a1b3367df 100644 --- a/web/src/pages/next-chats/chat/index.tsx +++ b/web/src/pages/next-chats/chat/index.tsx @@ -1,8 +1,8 @@ import { Button } from '@/components/ui/button'; import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card'; import { - useFetchConversationList, - useFetchConversationManually, + useFetchSessionList, + useFetchSessionManually, useGetChatSearchParams, } from '@/hooks/use-chat-request'; import { IClientConversation } from '@/interfaces/database/chat'; @@ -26,7 +26,7 @@ export default function Chat() { const [currentConversation, setCurrentConversation] = useState({} as IClientConversation); - const { fetchConversationManually } = useFetchConversationManually(); + const { fetchSessionManually } = useFetchSessionManually(); const { handleConversationCardClick, controller, stopOutputMessage } = useHandleClickConversationCard(); @@ -37,7 +37,7 @@ export default function Chat() { const { conversationId, isNew } = useGetChatSearchParams(); - const { data: dialogList } = useFetchConversationList(); + const { data: dialogList } = useFetchSessionList(); const currentConversationName = useMemo(() => { return ( @@ -49,13 +49,13 @@ export default function Chat() { const fetchConversation: typeof handleConversationCardClick = useCallback( async (conversationId, isNew) => { if (conversationId && !isNew) { - const conversation = await fetchConversationManually(conversationId); + const conversation = await fetchSessionManually(conversationId); if (!isEmpty(conversation)) { setCurrentConversation(conversation); } } }, - [fetchConversationManually], + [fetchSessionManually], ); const handleSessionClick: typeof handleConversationCardClick = useCallback( diff --git a/web/src/pages/next-chats/chat/sessions.tsx b/web/src/pages/next-chats/chat/sessions.tsx index f37976d3e..9bbd6c1bb 100644 --- a/web/src/pages/next-chats/chat/sessions.tsx +++ b/web/src/pages/next-chats/chat/sessions.tsx @@ -16,7 +16,7 @@ import { useSetModalState } from '@/hooks/common-hooks'; import { useFetchChat, useGetChatSearchParams, - useRemoveConversation, + useRemoveSessions, } from '@/hooks/use-chat-request'; import { LucideCopyX, @@ -50,7 +50,7 @@ export function Sessions({ handleConversationCardClick }: SessionProps) { } = useSelectDerivedConversationList(); const { data } = useFetchChat(); const { visible, switchVisible } = useSetModalState(true); - const { removeConversation } = useRemoveConversation(); + const { removeSessions } = useRemoveSessions(); const { setConversationBoth } = useChatUrlParams(); const { conversationId } = useGetChatSearchParams(); @@ -118,7 +118,7 @@ export function Sessions({ handleConversationCardClick }: SessionProps) { let removeCode = -1; if (persistedIds.length > 0) { - removeCode = await removeConversation(persistedIds); + removeCode = await removeSessions(persistedIds); } if (currentConversationDeleted && conversationId) { @@ -136,7 +136,7 @@ export function Sessions({ handleConversationCardClick }: SessionProps) { conversationList, setConversationBoth, removeTemporaryConversation, - removeConversation, + removeSessions, exitSelectionMode, ]); diff --git a/web/src/pages/next-chats/hooks/use-chat-url.ts b/web/src/pages/next-chats/hooks/use-chat-url.ts index a1c736bb9..4006b2cd2 100644 --- a/web/src/pages/next-chats/hooks/use-chat-url.ts +++ b/web/src/pages/next-chats/hooks/use-chat-url.ts @@ -1,7 +1,6 @@ import { ChatSearchParams } from '@/constants/chat'; import { useGetChatSearchParams } from '@/hooks/use-chat-request'; import { IMessage } from '@/interfaces/database/chat'; -import { generateConversationId } from '@/utils/chat'; import { useCallback, useMemo } from 'react'; import { useSearchParams } from 'react-router'; import { useSetConversation } from './use-set-conversation'; @@ -57,38 +56,32 @@ export const useChatUrlParams = () => { export function useCreateConversationBeforeSendMessage() { const { conversationId, isNew } = useGetChatSearchParams(); const { setConversation } = useSetConversation(); - const { setIsNew, setConversationBoth } = useChatUrlParams(); + const { setConversationBoth } = useChatUrlParams(); // Create conversation if it doesn't exist const createConversationBeforeSendMessage = useCallback( async (value: string) => { let currentMessages: Array = []; - const currentConversationId = generateConversationId(); if (conversationId === '' || isNew === 'true') { - if (conversationId === '') { - setConversationBoth(currentConversationId, 'true'); - } - const data = await setConversation( - value, - true, - conversationId || currentConversationId, - ); - if (data.code !== 0) { + const data = await setConversation(value); + if (!data || data.code !== 0) { return; - } else { - setIsNew(''); - currentMessages = data.data.message; } + const backendConvId = data.data.id; + setConversationBoth(backendConvId, ''); + currentMessages = data.data.messages; + return { + targetConversationId: backendConvId, + currentMessages, + }; } - const targetConversationId = conversationId || currentConversationId; - return { - targetConversationId, + targetConversationId: conversationId, currentMessages, }; }, - [conversationId, isNew, setConversation, setConversationBoth, setIsNew], + [conversationId, isNew, setConversation, setConversationBoth], ); return { diff --git a/web/src/pages/next-chats/hooks/use-create-conversation.ts b/web/src/pages/next-chats/hooks/use-create-conversation.ts index 6b06b9d0d..76e022673 100644 --- a/web/src/pages/next-chats/hooks/use-create-conversation.ts +++ b/web/src/pages/next-chats/hooks/use-create-conversation.ts @@ -12,7 +12,7 @@ export const useCreateConversationBeforeUploadDocument = () => { async (message: string) => { const isNew = getIsNew(); if (isNew === 'true') { - const data = await setConversation(message, true); + const data = await setConversation(message); return data; } diff --git a/web/src/pages/next-chats/hooks/use-select-conversation-list.ts b/web/src/pages/next-chats/hooks/use-select-conversation-list.ts index b553dcbc4..674cb9f8a 100644 --- a/web/src/pages/next-chats/hooks/use-select-conversation-list.ts +++ b/web/src/pages/next-chats/hooks/use-select-conversation-list.ts @@ -2,7 +2,7 @@ import { MessageType } from '@/constants/chat'; import { useTranslate } from '@/hooks/common-hooks'; import { useFetchChatList, - useFetchConversationList, + useFetchSessionList, } from '@/hooks/use-chat-request'; import { IConversation } from '@/interfaces/database/chat'; import { generateConversationId } from '@/utils/chat'; @@ -30,7 +30,7 @@ export const useSelectDerivedConversationList = () => { loading, handleInputChange, searchString, - } = useFetchConversationList(); + } = useFetchSessionList(); const { id: dialogId } = useParams(); const prologue = useFindPrologueFromDialogList(); @@ -45,9 +45,9 @@ export const useSelectDerivedConversationList = () => { { id: conversationId, name: t('newConversation'), - dialog_id: dialogId, + chat_id: dialogId, is_new: true, - message: [ + messages: [ { content: prologue, role: MessageType.Assistant, diff --git a/web/src/pages/next-chats/hooks/use-send-chat-message.ts b/web/src/pages/next-chats/hooks/use-send-chat-message.ts index f7a7e2842..6997d5776 100644 --- a/web/src/pages/next-chats/hooks/use-send-chat-message.ts +++ b/web/src/pages/next-chats/hooks/use-send-chat-message.ts @@ -70,9 +70,8 @@ export const useSendMessage = (controller: AbortController) => { const { handleUploadFile, isUploading, removeFile, files, clearFiles } = useUploadFile(); - const { send, answer, done } = useSendMessageWithSse( - api.completeConversation, - ); + const { id: chatId } = useParams(); + const { send, answer, done } = useSendMessageWithSse(); const { scrollRef, messageContainerRef, @@ -97,9 +96,10 @@ export const useSendMessage = (controller: AbortController) => { currentConversationId?: string; messages?: IMessage[]; } & NextMessageInputOnPressEnterParameter) => { + const sessionId = currentConversationId ?? conversationId; const res = await send( + api.completionUrl(chatId!, sessionId), { - conversation_id: currentConversationId ?? conversationId, messages: [ ...(Array.isArray(messages) && messages?.length > 0 ? messages @@ -122,6 +122,7 @@ export const useSendMessage = (controller: AbortController) => { [ derivedMessages, conversationId, + chatId, removeLatestMessage, setValue, send, diff --git a/web/src/pages/next-chats/hooks/use-send-shared-message.ts b/web/src/pages/next-chats/hooks/use-send-shared-message.ts index ba160168d..de99d344e 100644 --- a/web/src/pages/next-chats/hooks/use-send-shared-message.ts +++ b/web/src/pages/next-chats/hooks/use-send-shared-message.ts @@ -6,7 +6,6 @@ import { useSelectDerivedMessages, useSendMessageWithSse, } from '@/hooks/logic-hooks'; -import { useCreateNextSharedConversation } from '@/hooks/use-chat-request'; import { Message } from '@/interfaces/database/chat'; import { get } from 'lodash'; import trim from 'lodash/trim'; @@ -47,12 +46,9 @@ export const useSendSharedMessage = () => { sharedId: conversationId, data: data, } = useGetSharedChatSearchParams(); - const { createSharedConversation: setConversation } = - useCreateNextSharedConversation(); const { handleInputChange, value, setValue } = useHandleMessageInputChange(); - const { send, answer, done, stopOutputMessage } = useSendMessageWithSse( - `/api/v1/${from === SharedFrom.Agent ? 'agentbots' : 'chatbots'}/${conversationId}/completions`, - ); + const completionUrl = `/api/v1/${from === SharedFrom.Agent ? 'agentbots' : 'chatbots'}/${conversationId}/completions`; + const { send, answer, done, stopOutputMessage } = useSendMessageWithSse(); const { derivedMessages, removeLatestMessage, @@ -72,7 +68,7 @@ export const useSendSharedMessage = () => { enableThinking?: boolean, enableInternet?: boolean, ) => { - const res = await send({ + const res = await send(completionUrl, { conversation_id: id ?? conversationId, quote: true, question: message.content, @@ -87,7 +83,14 @@ export const useSendSharedMessage = () => { removeLatestMessage(); } }, - [send, conversationId, derivedMessages, setValue, removeLatestMessage], + [ + send, + completionUrl, + conversationId, + derivedMessages, + setValue, + removeLatestMessage, + ], ); const handleSendMessage = useCallback( @@ -96,27 +99,19 @@ export const useSendSharedMessage = () => { enableThinking?: boolean, enableInternet?: boolean, ) => { - if (conversationId !== '') { - sendMessage(message, undefined, enableThinking, enableInternet); - } else { - const data = await setConversation('user id'); - if (data.code === 0) { - const id = data.data.id; - sendMessage(message, id, enableThinking, enableInternet); - } - } + sendMessage(message, undefined, enableThinking, enableInternet); }, - [conversationId, setConversation, sendMessage], + [sendMessage], ); const fetchSessionId = useCallback(async () => { const payload = { question: '' }; - const ret = await send({ ...payload, ...data }); + const ret = await send(completionUrl, { ...payload, ...data }); if (isCompletionError(ret)) { message.error(ret?.data.message); setHasError(true); } - }, [send]); + }, [send, completionUrl]); useEffect(() => { fetchSessionId(); diff --git a/web/src/pages/next-chats/hooks/use-send-single-message.ts b/web/src/pages/next-chats/hooks/use-send-single-message.ts index ef624ef60..6dcf7d597 100644 --- a/web/src/pages/next-chats/hooks/use-send-single-message.ts +++ b/web/src/pages/next-chats/hooks/use-send-single-message.ts @@ -9,6 +9,7 @@ import { useGetChatSearchParams } from '@/hooks/use-chat-request'; import { IMessage } from '@/interfaces/database/chat'; import api from '@/utils/api'; import { useCallback, useEffect } from 'react'; +import { useParams } from 'react-router'; import { v4 as uuid } from 'uuid'; import { CreateConversationBeforeSendMessageReturnType } from './use-chat-url'; import { useUploadFile } from './use-upload-file'; @@ -29,10 +30,9 @@ export function useSendSingleMessage({ } & Pick, 'value' | 'setValue'> & Pick, 'files' | 'clearFiles'>) { const { conversationId } = useGetChatSearchParams(); + const { id: chatId } = useParams(); - const { send, answer, done } = useSendMessageWithSse( - api.completeConversation, - ); + const { send, answer, done } = useSendMessageWithSse(); const { scrollRef, @@ -65,9 +65,10 @@ export function useSendSingleMessage({ currentConversationId?: string; messages?: IMessage[]; } & NextMessageInputOnPressEnterParameter) => { + const sessionId = currentConversationId ?? conversationId; const res = await send( + api.completionUrl(chatId!, sessionId), { - conversation_id: currentConversationId ?? conversationId, messages: [ ...(Array.isArray(messages) && messages?.length > 0 ? messages diff --git a/web/src/pages/next-chats/hooks/use-set-conversation.ts b/web/src/pages/next-chats/hooks/use-set-conversation.ts index 4fe608043..e627e923f 100644 --- a/web/src/pages/next-chats/hooks/use-set-conversation.ts +++ b/web/src/pages/next-chats/hooks/use-set-conversation.ts @@ -1,35 +1,17 @@ -import { MessageType } from '@/constants/chat'; -import { useUpdateConversation } from '@/hooks/use-chat-request'; +import { useCreateSession } from '@/hooks/use-chat-request'; import { useCallback } from 'react'; import { useParams } from 'react-router'; export const useSetConversation = () => { - const { id: dialogId } = useParams(); - const { updateConversation } = useUpdateConversation(); + const { id: chatId } = useParams(); + const { createSession } = useCreateSession(); const setConversation = useCallback( - async ( - message: string, - isNew: boolean = false, - conversationId?: string, - ) => { - const data = await updateConversation({ - dialog_id: dialogId, - name: message, - is_new: isNew, - conversation_id: conversationId, - message: [ - { - role: MessageType.Assistant, - content: message, - conversationId, - }, - ], - }); - + async (name: string) => { + const data = await createSession({ chatId: chatId!, name }); return data; }, - [updateConversation, dialogId], + [createSession, chatId], ); return { setConversation }; diff --git a/web/src/pages/next-chats/hooks/use-upload-file.ts b/web/src/pages/next-chats/hooks/use-upload-file.ts index a38015e3d..b2bf87104 100644 --- a/web/src/pages/next-chats/hooks/use-upload-file.ts +++ b/web/src/pages/next-chats/hooks/use-upload-file.ts @@ -3,7 +3,6 @@ import { useGetChatSearchParams, useUploadAndParseFile, } from '@/hooks/use-chat-request'; -import { generateConversationId } from '@/utils/chat'; import { useCallback, useState } from 'react'; import { useChatUrlParams } from './use-chat-url'; import { useSetConversation } from './use-set-conversation'; @@ -16,7 +15,7 @@ export function useUploadFile() { ); const { setConversation } = useSetConversation(); const { conversationId, isNew } = useGetChatSearchParams(); - const { setIsNew, setConversationBoth } = useChatUrlParams(); + const { setConversationBoth } = useChatUrlParams(); type FileUploadParameters = Parameters< NonNullable @@ -58,20 +57,11 @@ export function useUploadFile() { Array.isArray(files) && files.length ) { - const currentConversationId = generateConversationId(); - - if (conversationId === '') { - setConversationBoth(currentConversationId, 'true'); - } - - const data = await setConversation( - files[0].name, - true, - conversationId || currentConversationId, - ); - if (data.code === 0) { - setIsNew(''); - handleUploadFile(files, options, data.data?.id); + const data = await setConversation(files[0].name); + if (data?.code === 0) { + const backendConvId = data.data.id; + setConversationBoth(backendConvId, ''); + handleUploadFile(files, options, backendConvId); } } else { handleUploadFile(files, options); @@ -83,7 +73,6 @@ export function useUploadFile() { isNew, setConversation, setConversationBoth, - setIsNew, ], ); diff --git a/web/src/pages/next-chats/share/index.tsx b/web/src/pages/next-chats/share/index.tsx index cb26eed9d..dd109dccc 100644 --- a/web/src/pages/next-chats/share/index.tsx +++ b/web/src/pages/next-chats/share/index.tsx @@ -6,13 +6,10 @@ import { useClickDrawer } from '@/components/pdf-drawer/hooks'; import { useSyncThemeFromParams } from '@/components/theme-provider'; import { MessageType, SharedFrom } from '@/constants/chat'; import { useFetchFlowSSE } from '@/hooks/use-agent-request'; -import { - useFetchExternalChatInfo, - useFetchNextConversationSSE, -} from '@/hooks/use-chat-request'; +import { useFetchExternalChatInfo } from '@/hooks/use-chat-request'; import i18n, { changeLanguageAsync } from '@/locales/config'; import { buildMessageUuidWithRole } from '@/utils/chat'; -import React, { forwardRef, useMemo } from 'react'; +import React, { forwardRef } from 'react'; import { useSendButtonDisabled } from '../hooks/use-button-disabled'; import { useGetSharedChatSearchParams, @@ -47,18 +44,15 @@ const ChatContainer = () => { const sendDisabled = useSendButtonDisabled(value); const { data: chatInfo } = useFetchExternalChatInfo(); - const useFetchAvatar = useMemo(() => { - return from === SharedFrom.Agent - ? useFetchFlowSSE - : useFetchNextConversationSSE; - }, [from]); + const { data: flowData } = useFetchFlowSSE(); React.useEffect(() => { if (locale && i18n.language !== locale) { changeLanguageAsync(locale); } }, [locale, visibleAvatar]); - const { data: avatarData } = useFetchAvatar(); + const avatarDialogSrc = + from === SharedFrom.Agent ? flowData?.avatar : chatInfo.avatar; if (!conversationId) { return
empty
; @@ -84,12 +78,12 @@ const ChatContainer = () => { { }; export const buildMessageItemReference = ( - conversation: { message: IMessage[]; reference: IReference[] }, + conversation: { messages: IMessage[]; reference: IReference[] }, message: IMessage, ) => { - const assistantMessages = conversation.message + const assistantMessages = conversation.messages ?.filter( (x) => x.role === MessageType.Assistant && !x.content.startsWith('**ERROR**:'), // Exclude error messages diff --git a/web/src/pages/next-search/hooks.ts b/web/src/pages/next-search/hooks.ts index 3266c6956..d68dbc113 100644 --- a/web/src/pages/next-search/hooks.ts +++ b/web/src/pages/next-search/hooks.ts @@ -68,7 +68,7 @@ export const useSearchFetchMindMap = () => { const sharedId = searchParams.get('shared_id'); const fetchMindMapFunc = sharedId ? searchService.mindmapShare - : chatService.getMindMap; + : chatService.chatsMindmap; const { data, isPending: loading, @@ -280,7 +280,7 @@ export const useFetchRelatedQuestions = ( const shared_id = searchParams.get('shared_id'); const retrievalTestFunc = shared_id ? searchService.getRelatedQuestionsShare - : chatService.getRelatedQuestions; + : chatService.chatsRelatedQuestions; const { data, isPending: loading, @@ -309,9 +309,8 @@ export const useSendQuestion = ( related_search: boolean = false, ) => { const { sharedId } = useGetSharedSearchParams(); - const { send, answer, done, stopOutputMessage } = useSendMessageWithSse( - sharedId ? api.askShare : api.ask, - ); + const askUrl = sharedId ? api.askShare : api.ask; + const { send, answer, done, stopOutputMessage } = useSendMessageWithSse(); const { testChunk, loading } = useTestChunkRetrieval(tenantId); const { testChunkAll } = useTestChunkAllRetrieval(tenantId); @@ -334,7 +333,12 @@ export const useSendQuestion = ( setCurrentAnswer({} as IAnswer); if (enableAI) { setSendingLoading(true); - send({ kb_ids: kbIds, question: q, tenantId, search_id: searchId }); + send(askUrl, { + kb_ids: kbIds, + question: q, + tenantId, + search_id: searchId, + }); } testChunk({ kb_id: kbIds, diff --git a/web/src/services/next-chat-service.ts b/web/src/services/next-chat-service.ts index d45840170..aa9119418 100644 --- a/web/src/services/next-chat-service.ts +++ b/web/src/services/next-chat-service.ts @@ -9,26 +9,21 @@ const { patchChat, deleteChat, bulkDeleteChats, - getConversation, - getConversationSSE, - setConversation, - completeConversation, - listConversation, - removeConversation, + createSession, + listSessions, + getSession, + updateSession, + removeSessions, + deleteMessage, + thumbup, createToken, listToken, removeToken, getStats, - createExternalConversation, - getExternalConversation, - completeExternalConversation, - uploadAndParseExternal, - deleteMessage, - thumbup, - tts, + chatsTts, ask, - mindmap, - getRelatedQuestions, + chatsMindmap, + chatsRelatedQuestions, upload_and_parse, fetchExternalChatInfo, } = api; @@ -62,29 +57,33 @@ const methods = { url: bulkDeleteChats, method: 'delete', }, - listConversation: { - url: listConversation, - method: 'get', - }, - getConversation: { - url: getConversation, - method: 'get', - }, - getConversationSSE: { - url: getConversationSSE, - method: 'get', - }, - setConversation: { - url: setConversation, + createSession: { + url: createSession, method: 'post', }, - completeConversation: { - url: completeConversation, - method: 'post', + listSessions: { + url: listSessions, + method: 'get', }, - removeConversation: { - url: removeConversation, - method: 'post', + getSession: { + url: getSession, + method: 'get', + }, + updateSession: { + url: updateSession, + method: 'put', + }, + removeSessions: { + url: removeSessions, + method: 'delete', + }, + deleteMessage: { + url: deleteMessage, + method: 'delete', + }, + thumbup: { + url: thumbup, + method: 'put', }, createToken: { url: createToken, @@ -102,44 +101,20 @@ const methods = { url: getStats, method: 'get', }, - createExternalConversation: { - url: createExternalConversation, - method: 'get', - }, - getExternalConversation: { - url: getExternalConversation, - method: 'get', - }, - completeExternalConversation: { - url: completeExternalConversation, - method: 'post', - }, - uploadAndParseExternal: { - url: uploadAndParseExternal, - method: 'post', - }, - deleteMessage: { - url: deleteMessage, - method: 'post', - }, - thumbup: { - url: thumbup, - method: 'post', - }, - tts: { - url: tts, + chatsTts: { + url: chatsTts, method: 'post', }, ask: { url: ask, method: 'post', }, - getMindMap: { - url: mindmap, + chatsMindmap: { + url: chatsMindmap, method: 'post', }, - getRelatedQuestions: { - url: getRelatedQuestions, + chatsRelatedQuestions: { + url: chatsRelatedQuestions, method: 'post', }, uploadAndParse: { diff --git a/web/src/utils/api.ts b/web/src/utils/api.ts index 440614d7c..2880ca2e0 100644 --- a/web/src/utils/api.ts +++ b/web/src/utils/api.ts @@ -52,7 +52,7 @@ export default { // plugin llm_tools: `${api_host}/plugin/llm_tools`, - sequence2txt: `${api_host}/conversation/sequence2txt`, + chatsTranscriptions: `${ExternalApi}${api_host}/chats/transcriptions`, // knowledge base @@ -135,28 +135,31 @@ export default { patchChat: (chatId: string) => `${ExternalApi}${api_host}/chats/${chatId}`, deleteChat: (chatId: string) => `${ExternalApi}${api_host}/chats/${chatId}`, bulkDeleteChats: `${ExternalApi}${api_host}/chats`, - setConversation: `${api_host}/conversation/set`, - getConversation: `${api_host}/conversation/get`, - getConversationSSE: (dialogId: string) => - `${api_host}/conversation/getsse/${dialogId}`, - listConversation: `${api_host}/conversation/list`, - removeConversation: `${api_host}/conversation/rm`, - completeConversation: `${api_host}/conversation/completion`, - deleteMessage: `${api_host}/conversation/delete_msg`, - thumbup: `${api_host}/conversation/thumbup`, - tts: `${api_host}/conversation/tts`, - ask: `${api_host}/conversation/ask`, - mindmap: `${api_host}/conversation/mindmap`, - getRelatedQuestions: `${api_host}/conversation/related_questions`, + createSession: (chatId: string) => + `${ExternalApi}${api_host}/chats/${chatId}/sessions`, + listSessions: (chatId: string) => + `${ExternalApi}${api_host}/chats/${chatId}/sessions`, + getSession: (chatId: string, sessionId: string) => + `${ExternalApi}${api_host}/chats/${chatId}/sessions/${sessionId}`, + updateSession: (chatId: string, sessionId: string) => + `${ExternalApi}${api_host}/chats/${chatId}/sessions/${sessionId}`, + removeSessions: (chatId: string) => + `${ExternalApi}${api_host}/chats/${chatId}/sessions`, + deleteMessage: (chatId: string, sessionId: string, msgId: string) => + `${ExternalApi}${api_host}/chats/${chatId}/sessions/${sessionId}/messages/${msgId}`, + thumbup: (chatId: string, sessionId: string, msgId: string) => + `${ExternalApi}${api_host}/chats/${chatId}/sessions/${sessionId}/messages/${msgId}/feedback`, + completionUrl: (chatId: string, sessionId: string) => + `${ExternalApi}${api_host}/chats/${chatId}/sessions/${sessionId}/completions`, + chatsTts: `${ExternalApi}${api_host}/chats/tts`, + ask: `${ExternalApi}${api_host}/chats/ask`, + chatsMindmap: `${ExternalApi}${api_host}/chats/mindmap`, + chatsRelatedQuestions: `${ExternalApi}${api_host}/chats/related_questions`, // chat for external createToken: `${api_host}/api/new_token`, listToken: `${api_host}/api/token_list`, removeToken: `${api_host}/api/rm`, getStats: `${api_host}/api/stats`, - createExternalConversation: `${api_host}/api/new_conversation`, - getExternalConversation: `${api_host}/api/conversation`, - completeExternalConversation: `${api_host}/api/completion`, - uploadAndParseExternal: `${api_host}/api/document/upload_and_parse`, // next chat fetchExternalChatInfo: (id: string) =>