diff --git a/api/apps/restful_apis/openai_api.py b/api/apps/restful_apis/openai_api.py new file mode 100644 index 000000000..320ecd09d --- /dev/null +++ b/api/apps/restful_apis/openai_api.py @@ -0,0 +1,309 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import json +import time + +from quart import Response, jsonify + +from api.apps import current_user, login_required +from api.db.services.dialog_service import DialogService, async_chat +from api.db.services.doc_metadata_service import DocMetadataService +from api.db.services.tenant_llm_service import TenantLLMService +from api.utils.api_utils import get_error_data_result, get_request_json, validate_request +from common.constants import RetCode, StatusEnum +from common.metadata_utils import convert_conditions, meta_filter +from common.token_utils import num_tokens_from_string +from rag.prompts.generator import chunks_format + +def _validate_llm_id(llm_id, tenant_id, llm_setting=None): + if not llm_id: + return None + + llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(llm_id) + model_type = (llm_setting or {}).get("model_type") + if model_type not in {"chat", "image2text"}: + model_type = "chat" + + if not TenantLLMService.query( + tenant_id=tenant_id, + llm_name=llm_name, + llm_factory=llm_factory, + model_type=model_type, + ): + return f"`llm_id` {llm_id} doesn't exist" + return None + + +def _build_reference_chunks(reference, include_metadata=False, metadata_fields=None): + chunks = chunks_format(reference) + if not include_metadata: + return chunks + + doc_ids_by_kb = {} + for chunk in chunks: + kb_id = chunk.get("dataset_id") + doc_id = chunk.get("document_id") + if not kb_id or not doc_id: + continue + doc_ids_by_kb.setdefault(kb_id, set()).add(doc_id) + + if not doc_ids_by_kb: + return chunks + + meta_by_doc = {} + for kb_id, doc_ids in doc_ids_by_kb.items(): + meta_map = DocMetadataService.get_metadata_for_documents(list(doc_ids), kb_id) + if meta_map: + meta_by_doc.update(meta_map) + + if metadata_fields is not None: + metadata_fields = {f for f in metadata_fields if isinstance(f, str)} + if not metadata_fields: + return chunks + + for chunk in chunks: + doc_id = chunk.get("document_id") + if not doc_id: + continue + meta = meta_by_doc.get(doc_id) + if not meta: + continue + if metadata_fields is not None: + meta = {k: v for k, v in meta.items() if k in metadata_fields} + if meta: + chunk["document_metadata"] = meta + + return chunks + + +def _build_sse_response(body): + resp = Response(body, mimetype="text/event-stream") + resp.headers.add_header("Cache-control", "no-cache") + resp.headers.add_header("Connection", "keep-alive") + resp.headers.add_header("X-Accel-Buffering", "no") + resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8") + return resp + + +@manager.route("/openai//chat/completions", methods=["POST"]) # noqa: F821 +@login_required +@validate_request("model", "messages") +async def openai_chat_completions(chat_id): + req = await get_request_json() + + extra_body = req.get("extra_body") or {} + if extra_body and not isinstance(extra_body, dict): + return get_error_data_result("extra_body must be an object.") + + need_reference = bool(extra_body.get("reference", False)) + reference_metadata = extra_body.get("reference_metadata") or {} + if reference_metadata and not isinstance(reference_metadata, dict): + return get_error_data_result("reference_metadata must be an object.") + include_reference_metadata = bool(reference_metadata.get("include", False)) + metadata_fields = reference_metadata.get("fields") + if metadata_fields is not None and not isinstance(metadata_fields, list): + return get_error_data_result("reference_metadata.fields must be an array.") + + messages = req.get("messages", []) + if len(messages) < 1: + return get_error_data_result("You have to provide messages.") + if messages[-1]["role"] != "user": + return get_error_data_result("The last content of this conversation is not from user.") + + prompt = messages[-1]["content"] + context_token_used = sum(num_tokens_from_string(message["content"]) for message in messages) + requested_model = req.get("model", "") or "" + completion_id = f"chatcmpl-{chat_id}" + + dia = DialogService.query(tenant_id=current_user.id, id=chat_id, status=StatusEnum.VALID.value) + if not dia: + return get_error_data_result(f"You don't own the chat {chat_id}") + dia = dia[0] + + using_placeholder_model = requested_model == "model" + if using_placeholder_model: + requested_model = dia.llm_id or requested_model + else: + llm_id_error = _validate_llm_id(requested_model, current_user.id, {"model_type": "chat"}) + if llm_id_error: + return get_error_data_result(message=llm_id_error, code=RetCode.ARGUMENT_ERROR) + dia.llm_id = requested_model + if not TenantLLMService.get_api_key(tenant_id=dia.tenant_id, model_name=requested_model): + return get_error_data_result(message=f"Cannot use specified model {requested_model}.") + + metadata_condition = extra_body.get("metadata_condition") or {} + if metadata_condition and not isinstance(metadata_condition, dict): + return get_error_data_result(message="metadata_condition must be an object.") + + doc_ids_str = None + if metadata_condition: + metas = DocMetadataService.get_flatted_meta_by_kbs(dia.kb_ids or []) + filtered_doc_ids = meta_filter( + metas, + convert_conditions(metadata_condition), + metadata_condition.get("logic", "and"), + ) + if metadata_condition.get("conditions") and not filtered_doc_ids: + filtered_doc_ids = ["-999"] + doc_ids_str = ",".join(filtered_doc_ids) if filtered_doc_ids else None + + msg = [] + for message in messages: + if message["role"] == "system": + continue + if message["role"] == "assistant" and not msg: + continue + msg.append(message) + + tools = None + toolcall_session = None + stream_mode = req.get("stream", True) + + if stream_mode: + async def streamed_response_generator(): + token_used = 0 + last_ans = {} + full_content = "" + final_answer = None + final_reference = None + in_think = False + response = { + "id": completion_id, + "choices": [ + { + "delta": { + "content": "", + "role": "assistant", + "function_call": None, + "tool_calls": None, + "reasoning_content": "", + }, + "finish_reason": None, + "index": 0, + "logprobs": None, + } + ], + "created": int(time.time()), + "model": requested_model, + "object": "chat.completion.chunk", + "system_fingerprint": "", + "usage": None, + } + + try: + chat_kwargs = {"toolcall_session": toolcall_session, "tools": tools, "quote": need_reference} + if doc_ids_str: + chat_kwargs["doc_ids"] = doc_ids_str + async for ans in async_chat(dia, msg, True, **chat_kwargs): + last_ans = ans + if ans.get("final"): + if ans.get("answer"): + full_content = ans["answer"] + response["choices"][0]["delta"]["content"] = full_content + response["choices"][0]["delta"]["reasoning_content"] = None + yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n" + final_answer = full_content + final_reference = ans.get("reference", {}) + continue + if ans.get("start_to_think"): + in_think = True + continue + if ans.get("end_to_think"): + in_think = False + continue + delta = ans.get("answer") or "" + if not delta: + continue + token_used += num_tokens_from_string(delta) + if in_think: + response["choices"][0]["delta"]["reasoning_content"] = delta + response["choices"][0]["delta"]["content"] = None + else: + full_content += delta + response["choices"][0]["delta"]["content"] = delta + response["choices"][0]["delta"]["reasoning_content"] = None + yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n" + except Exception as e: + response["choices"][0]["delta"]["content"] = "**ERROR**: " + str(e) + yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n" + + response["choices"][0]["delta"]["content"] = None + response["choices"][0]["delta"]["reasoning_content"] = None + response["choices"][0]["finish_reason"] = "stop" + prompt_tokens = num_tokens_from_string(prompt) + response["usage"] = { + "prompt_tokens": prompt_tokens, + "completion_tokens": token_used, + "total_tokens": prompt_tokens + token_used, + } + if need_reference: + reference_payload = final_reference if final_reference is not None else last_ans.get("reference", []) + response["choices"][0]["delta"]["reference"] = _build_reference_chunks( + reference_payload, + include_metadata=include_reference_metadata, + metadata_fields=metadata_fields, + ) + response["choices"][0]["delta"]["final_content"] = final_answer if final_answer is not None else full_content + yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n" + yield "data:[DONE]\n\n" + + return _build_sse_response(streamed_response_generator()) + + answer = None + chat_kwargs = {"toolcall_session": toolcall_session, "tools": tools, "quote": need_reference} + if doc_ids_str: + chat_kwargs["doc_ids"] = doc_ids_str + async for ans in async_chat(dia, msg, False, **chat_kwargs): + answer = ans + break + + content = answer["answer"] + response = { + "id": completion_id, + "object": "chat.completion", + "created": int(time.time()), + "model": requested_model, + "usage": { + "prompt_tokens": num_tokens_from_string(prompt), + "completion_tokens": num_tokens_from_string(content), + "total_tokens": num_tokens_from_string(prompt) + num_tokens_from_string(content), + "completion_tokens_details": { + "reasoning_tokens": context_token_used, + "accepted_prediction_tokens": num_tokens_from_string(content), + "rejected_prediction_tokens": 0, + }, + }, + "choices": [ + { + "message": { + "role": "assistant", + "content": content, + }, + "logprobs": None, + "finish_reason": "stop", + "index": 0, + } + ], + } + if need_reference: + response["choices"][0]["message"]["reference"] = _build_reference_chunks( + answer.get("reference", {}), + include_metadata=include_reference_metadata, + metadata_fields=metadata_fields, + ) + + return jsonify(response) diff --git a/api/apps/sdk/session.py b/api/apps/sdk/session.py index 92f01233c..0eaf45b1e 100644 --- a/api/apps/sdk/session.py +++ b/api/apps/sdk/session.py @@ -15,30 +15,23 @@ # import json import re -import time -import os -import tempfile import logging -from quart import Response, jsonify, request - -from common.token_utils import num_tokens_from_string +from quart import Response, request from agent.canvas import Canvas from api.db.db_models import APIToken from api.db.services.api_service import API4ConversationService from api.db.services.canvas_service import UserCanvasService from api.db.services.canvas_service import completion as agent_completion -from api.db.services.conversation_service import ConversationService from api.db.services.user_canvas_version import UserCanvasVersionService from api.db.services.conversation_service import async_iframe_completion as iframe_completion -from api.db.services.conversation_service import async_completion as rag_completion -from api.db.services.dialog_service import DialogService, async_ask, async_chat, gen_mindmap +from api.db.services.dialog_service import DialogService, async_ask, gen_mindmap from api.db.services.doc_metadata_service import DocMetadataService from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.llm_service import LLMBundle -from common.metadata_utils import apply_meta_data_filter, convert_conditions, meta_filter +from common.metadata_utils import apply_meta_data_filter from api.db.services.search_service import SearchService from api.db.services.user_service import UserTenantService from api.db.joint_services.tenant_model_service import get_tenant_default_model_by_type, get_model_config_by_id, \ @@ -48,8 +41,8 @@ from api.utils.api_utils import check_duplicate_ids, get_error_data_result, get_ get_result, get_request_json, server_error_response, token_required, validate_request from rag.app.tag import label_question from rag.prompts.template import load_prompt -from rag.prompts.generator import cross_languages, keyword_extraction, chunks_format -from common.constants import RetCode, LLMType, StatusEnum +from rag.prompts.generator import cross_languages, keyword_extraction +from common.constants import RetCode, LLMType from common import settings @@ -90,349 +83,6 @@ async def create_agent_session(tenant_id, agent_id): return get_result(data=conv) -@manager.route("/chats//completions", methods=["POST"]) # noqa: F821 -@token_required -async def chat_completion(tenant_id, chat_id): - req = await get_request_json() - if not req: - req = {"question": ""} - if not req.get("session_id"): - req["question"] = "" - dia = DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value) - if not dia: - return get_error_data_result(f"You don't own the chat {chat_id}") - dia = dia[0] - if req.get("session_id"): - if not ConversationService.query(id=req["session_id"], dialog_id=chat_id): - return get_error_data_result(f"You don't own the session {req['session_id']}") - - metadata_condition = req.get("metadata_condition") or {} - if metadata_condition and not isinstance(metadata_condition, dict): - return get_error_data_result(message="metadata_condition must be an object.") - - if metadata_condition and req.get("question"): - metas = DocMetadataService.get_flatted_meta_by_kbs(dia.kb_ids or []) - filtered_doc_ids = meta_filter( - metas, - convert_conditions(metadata_condition), - metadata_condition.get("logic", "and"), - ) - if metadata_condition.get("conditions") and not filtered_doc_ids: - filtered_doc_ids = ["-999"] - - if filtered_doc_ids: - req["doc_ids"] = ",".join(filtered_doc_ids) - else: - req.pop("doc_ids", None) - - if req.get("stream", True): - resp = Response(rag_completion(tenant_id, chat_id, **req), mimetype="text/event-stream") - resp.headers.add_header("Cache-control", "no-cache") - resp.headers.add_header("Connection", "keep-alive") - resp.headers.add_header("X-Accel-Buffering", "no") - resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8") - - return resp - else: - answer = None - async for ans in rag_completion(tenant_id, chat_id, **req): - answer = ans - break - return get_result(data=answer) - - -@manager.route("/chats_openai//chat/completions", methods=["POST"]) # noqa: F821 -@validate_request("model", "messages") # noqa: F821 -@token_required -async def chat_completion_openai_like(tenant_id, chat_id): - """ - OpenAI-like chat completion API that simulates the behavior of OpenAI's completions endpoint. - - This function allows users to interact with a model and receive responses based on a series of historical messages. - If `stream` is set to True (by default), the response will be streamed in chunks, mimicking the OpenAI-style API. - Set `stream` to False explicitly, the response will be returned in a single complete answer. - - Reference: - - - If `stream` is True, the final answer and reference information will appear in the **last chunk** of the stream. - - If `stream` is False, the reference will be included in `choices[0].message.reference`. - - If `extra_body.reference_metadata.include` is True, each reference chunk may include `document_metadata` in both streaming and non-streaming responses. - - Example usage: - - curl -X POST https://ragflow_address.com/api/v1/chats_openai//chat/completions \ - -H "Content-Type: application/json" \ - -H "Authorization: Bearer $RAGFLOW_API_KEY" \ - -d '{ - "model": "model", - "messages": [{"role": "user", "content": "Say this is a test!"}], - "stream": true - }' - - Alternatively, you can use Python's `OpenAI` client: - - NOTE: Streaming via `client.chat.completions.create(stream=True, ...)` does - not return `reference` currently. The only way to return `reference` is - non-stream mode with `with_raw_response`. - - from openai import OpenAI - import json - - model = "model" - client = OpenAI(api_key="ragflow-api-key", base_url=f"http://ragflow_address/api/v1/chats_openai/") - - stream = True - reference = True - - request_kwargs = dict( - model="model", - messages=[ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Who are you?"}, - {"role": "assistant", "content": "I am an AI assistant named..."}, - {"role": "user", "content": "Can you tell me how to install neovim"}, - ], - extra_body={ - "reference": reference, - "reference_metadata": { - "include": True, - "fields": ["author", "year", "source"], - }, - "metadata_condition": { - "logic": "and", - "conditions": [ - { - "name": "author", - "comparison_operator": "is", - "value": "bob" - } - ] - } - }, - ) - - if stream: - completion = client.chat.completions.create(stream=True, **request_kwargs) - for chunk in completion: - print(chunk) - else: - resp = client.chat.completions.with_raw_response.create( - stream=False, **request_kwargs - ) - print("status:", resp.http_response.status_code) - raw_text = resp.http_response.text - print("raw:", raw_text) - - data = json.loads(raw_text) - print("assistant:", data["choices"][0]["message"].get("content")) - print("reference:", data["choices"][0]["message"].get("reference")) - - """ - req = await get_request_json() - - extra_body = req.get("extra_body") or {} - if extra_body and not isinstance(extra_body, dict): - return get_error_data_result("extra_body must be an object.") - - need_reference = bool(extra_body.get("reference", False)) - reference_metadata = extra_body.get("reference_metadata") or {} - if reference_metadata and not isinstance(reference_metadata, dict): - return get_error_data_result("reference_metadata must be an object.") - include_reference_metadata = bool(reference_metadata.get("include", False)) - metadata_fields = reference_metadata.get("fields") - if metadata_fields is not None and not isinstance(metadata_fields, list): - return get_error_data_result("reference_metadata.fields must be an array.") - - messages = req.get("messages", []) - # To prevent empty [] input - if len(messages) < 1: - return get_error_data_result("You have to provide messages.") - if messages[-1]["role"] != "user": - return get_error_data_result("The last content of this conversation is not from user.") - - prompt = messages[-1]["content"] - # Treat context tokens as reasoning tokens - context_token_used = sum(num_tokens_from_string(message["content"]) for message in messages) - - dia = DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value) - if not dia: - return get_error_data_result(f"You don't own the chat {chat_id}") - dia = dia[0] - - metadata_condition = extra_body.get("metadata_condition") or {} - if metadata_condition and not isinstance(metadata_condition, dict): - return get_error_data_result(message="metadata_condition must be an object.") - - doc_ids_str = None - if metadata_condition: - metas = DocMetadataService.get_flatted_meta_by_kbs(dia.kb_ids or []) - filtered_doc_ids = meta_filter( - metas, - convert_conditions(metadata_condition), - metadata_condition.get("logic", "and"), - ) - if metadata_condition.get("conditions") and not filtered_doc_ids: - filtered_doc_ids = ["-999"] - doc_ids_str = ",".join(filtered_doc_ids) if filtered_doc_ids else None - - # Filter system and non-sense assistant messages - msg = [] - for m in messages: - if m["role"] == "system": - continue - if m["role"] == "assistant" and not msg: - continue - msg.append(m) - - # tools = get_tools() - # toolcall_session = SimpleFunctionCallServer() - tools = None - toolcall_session = None - - if req.get("stream", True): - # The value for the usage field on all chunks except for the last one will be null. - # The usage field on the last chunk contains token usage statistics for the entire request. - # The choices field on the last chunk will always be an empty array []. - async def streamed_response_generator(chat_id, dia, msg): - token_used = 0 - last_ans = {} - full_content = "" - full_reasoning = "" - final_answer = None - final_reference = None - in_think = False - response = { - "id": f"chatcmpl-{chat_id}", - "choices": [ - { - "delta": { - "content": "", - "role": "assistant", - "function_call": None, - "tool_calls": None, - "reasoning_content": "", - }, - "finish_reason": None, - "index": 0, - "logprobs": None, - } - ], - "created": int(time.time()), - "model": "model", - "object": "chat.completion.chunk", - "system_fingerprint": "", - "usage": None, - } - - try: - chat_kwargs = {"toolcall_session": toolcall_session, "tools": tools, "quote": need_reference} - if doc_ids_str: - chat_kwargs["doc_ids"] = doc_ids_str - async for ans in async_chat(dia, msg, True, **chat_kwargs): - last_ans = ans - if ans.get("final"): - if ans.get("answer"): - full_content = ans["answer"] - response["choices"][0]["delta"]["content"] = full_content - response["choices"][0]["delta"]["reasoning_content"] = None - yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n" - final_answer = full_content - final_reference = ans.get("reference", {}) - continue - if ans.get("start_to_think"): - in_think = True - continue - if ans.get("end_to_think"): - in_think = False - continue - delta = ans.get("answer") or "" - if not delta: - continue - token_used += num_tokens_from_string(delta) - if in_think: - full_reasoning += delta - response["choices"][0]["delta"]["reasoning_content"] = delta - response["choices"][0]["delta"]["content"] = None - else: - full_content += delta - response["choices"][0]["delta"]["content"] = delta - response["choices"][0]["delta"]["reasoning_content"] = None - yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n" - except Exception as e: - response["choices"][0]["delta"]["content"] = "**ERROR**: " + str(e) - yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n" - - # The last chunk - response["choices"][0]["delta"]["content"] = None - response["choices"][0]["delta"]["reasoning_content"] = None - response["choices"][0]["finish_reason"] = "stop" - prompt_tokens = num_tokens_from_string(prompt) - response["usage"] = {"prompt_tokens": prompt_tokens, "completion_tokens": token_used, "total_tokens": prompt_tokens + token_used} - if need_reference: - reference_payload = final_reference if final_reference is not None else last_ans.get("reference", []) - response["choices"][0]["delta"]["reference"] = _build_reference_chunks( - reference_payload, - include_metadata=include_reference_metadata, - metadata_fields=metadata_fields, - ) - response["choices"][0]["delta"]["final_content"] = final_answer if final_answer is not None else full_content - yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n" - yield "data:[DONE]\n\n" - - resp = Response(streamed_response_generator(chat_id, dia, msg), mimetype="text/event-stream") - resp.headers.add_header("Cache-control", "no-cache") - resp.headers.add_header("Connection", "keep-alive") - resp.headers.add_header("X-Accel-Buffering", "no") - resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8") - return resp - else: - answer = None - chat_kwargs = {"toolcall_session": toolcall_session, "tools": tools, "quote": need_reference} - if doc_ids_str: - chat_kwargs["doc_ids"] = doc_ids_str - async for ans in async_chat(dia, msg, False, **chat_kwargs): - # focus answer content only - answer = ans - break - content = answer["answer"] - - response = { - "id": f"chatcmpl-{chat_id}", - "object": "chat.completion", - "created": int(time.time()), - "model": req.get("model", ""), - "usage": { - "prompt_tokens": num_tokens_from_string(prompt), - "completion_tokens": num_tokens_from_string(content), - "total_tokens": num_tokens_from_string(prompt) + num_tokens_from_string(content), - "completion_tokens_details": { - "reasoning_tokens": context_token_used, - "accepted_prediction_tokens": num_tokens_from_string(content), - "rejected_prediction_tokens": 0, # 0 for simplicity - }, - }, - "choices": [ - { - "message": { - "role": "assistant", - "content": content, - }, - "logprobs": None, - "finish_reason": "stop", - "index": 0, - } - ], - } - if need_reference: - response["choices"][0]["message"]["reference"] = _build_reference_chunks( - answer.get("reference", {}), - include_metadata=include_reference_metadata, - metadata_fields=metadata_fields, - ) - - return jsonify(response) - - @manager.route("/agents//sessions", methods=["DELETE"]) # noqa: F821 @token_required async def delete_agent_session(tenant_id, agent_id): @@ -486,97 +136,6 @@ async def delete_agent_session(tenant_id, agent_id): return get_result() -@manager.route("/sessions/ask", methods=["POST"]) # noqa: F821 -@token_required -async def ask_about(tenant_id): - req = await get_request_json() - if not req.get("question"): - return get_error_data_result("`question` is required.") - if not req.get("dataset_ids"): - return get_error_data_result("`dataset_ids` is required.") - if not isinstance(req.get("dataset_ids"), list): - return get_error_data_result("`dataset_ids` should be a list.") - req["kb_ids"] = req.pop("dataset_ids") - for kb_id in req["kb_ids"]: - if not KnowledgebaseService.accessible(kb_id, tenant_id): - return get_error_data_result(f"You don't own the dataset {kb_id}.") - kbs = KnowledgebaseService.query(id=kb_id) - kb = kbs[0] - if kb.chunk_num == 0: - return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file") - uid = tenant_id - - async def stream(): - nonlocal req, uid - try: - async for ans in async_ask(req["question"], req["kb_ids"], uid): - yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n" - except Exception as e: - yield "data:" + json.dumps( - {"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, - ensure_ascii=False) + "\n\n" - yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n" - - resp = Response(stream(), mimetype="text/event-stream") - resp.headers.add_header("Cache-control", "no-cache") - resp.headers.add_header("Connection", "keep-alive") - resp.headers.add_header("X-Accel-Buffering", "no") - resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8") - return resp - - -@manager.route("/sessions/related_questions", methods=["POST"]) # noqa: F821 -@token_required -async def related_questions(tenant_id): - req = await get_request_json() - if not req.get("question"): - return get_error_data_result("`question` is required.") - question = req["question"] - industry = req.get("industry", "") - chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT) - chat_mdl = LLMBundle(tenant_id, chat_model_config) - prompt = """ -Objective: To generate search terms related to the user's search keywords, helping users find more valuable information. -Instructions: - - Based on the keywords provided by the user, generate 5-10 related search terms. - - Each search term should be directly or indirectly related to the keyword, guiding the user to find more valuable information. - - Use common, general terms as much as possible, avoiding obscure words or technical jargon. - - Keep the term length between 2-4 words, concise and clear. - - DO NOT translate, use the language of the original keywords. -""" - if industry: - prompt += f" - Ensure all search terms are relevant to the industry: {industry}.\n" - prompt += """ -### Example: -Keywords: Chinese football -Related search terms: -1. Current status of Chinese football -2. Reform of Chinese football -3. Youth training of Chinese football -4. Chinese football in the Asian Cup -5. Chinese football in the World Cup - -Reason: - - When searching, users often only use one or two keywords, making it difficult to fully express their information needs. - - Generating related search terms can help users dig deeper into relevant information and improve search efficiency. - - At the same time, related terms can also help search engines better understand user needs and return more accurate search results. - -""" - ans = await chat_mdl.async_chat( - prompt, - [ - { - "role": "user", - "content": f""" -Keywords: {question} -Related search terms: - """, - } - ], - {"temperature": 0.9}, - ) - return get_result(data=[re.sub(r"^[0-9]\. ", "", a) for a in ans.split("\n") if re.match(r"^[0-9]\. ", a)]) - @manager.route("/chatbots//completions", methods=["POST"]) # noqa: F821 async def chatbot_completions(dialog_id): @@ -968,126 +527,3 @@ async def mindmap(): return server_error_response(Exception(mind_map["error"])) return get_json_result(data=mind_map) -@manager.route("/sequence2txt", methods=["POST"]) # noqa: F821 -@token_required -async def sequence2txt(tenant_id): - req = await request.form - stream_mode = req.get("stream", "false").lower() == "true" - files = await request.files - if "file" not in files: - return get_error_data_result(message="Missing 'file' in multipart form-data") - - uploaded = files["file"] - - ALLOWED_EXTS = { - ".wav", ".mp3", ".m4a", ".aac", - ".flac", ".ogg", ".webm", - ".opus", ".wma" - } - - filename = uploaded.filename or "" - suffix = os.path.splitext(filename)[-1].lower() - if suffix not in ALLOWED_EXTS: - return get_error_data_result(message= - f"Unsupported audio format: {suffix}. " - f"Allowed: {', '.join(sorted(ALLOWED_EXTS))}" - ) - fd, temp_audio_path = tempfile.mkstemp(suffix=suffix) - os.close(fd) - await uploaded.save(temp_audio_path) - - try: - default_asr_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.SPEECH2TEXT) - except Exception as e: - return get_error_data_result(message=str(e)) - asr_mdl=LLMBundle(tenant_id, default_asr_model_config) - if not stream_mode: - text = asr_mdl.transcription(temp_audio_path) - try: - os.remove(temp_audio_path) - except Exception as e: - logging.error(f"Failed to remove temp audio file: {str(e)}") - return get_json_result(data={"text": text}) - async def event_stream(): - try: - for evt in asr_mdl.stream_transcription(temp_audio_path): - yield f"data: {json.dumps(evt, ensure_ascii=False)}\n\n" - except Exception as e: - err = {"event": "error", "text": str(e)} - yield f"data: {json.dumps(err, ensure_ascii=False)}\n\n" - finally: - try: - os.remove(temp_audio_path) - except Exception as e: - logging.error(f"Failed to remove temp audio file: {str(e)}") - - return Response(event_stream(), content_type="text/event-stream") - -@manager.route("/tts", methods=["POST"]) # noqa: F821 -@token_required -async def tts(tenant_id): - req = await get_request_json() - text = req["text"] - - try: - default_tts_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.TTS) - except Exception as e: - return get_error_data_result(message=str(e)) - tts_mdl = LLMBundle(tenant_id, default_tts_model_config) - - def stream_audio(): - try: - for txt in re.split(r"[,。/《》?;:!\n\r:;]+", text): - for chunk in tts_mdl.tts(txt): - yield chunk - except Exception as e: - yield ("data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e)}}, ensure_ascii=False)).encode("utf-8") - - resp = Response(stream_audio(), mimetype="audio/mpeg") - resp.headers.add_header("Cache-Control", "no-cache") - resp.headers.add_header("Connection", "keep-alive") - resp.headers.add_header("X-Accel-Buffering", "no") - - return resp - - -def _build_reference_chunks(reference, include_metadata=False, metadata_fields=None): - chunks = chunks_format(reference) - if not include_metadata: - return chunks - - doc_ids_by_kb = {} - for chunk in chunks: - kb_id = chunk.get("dataset_id") - doc_id = chunk.get("document_id") - if not kb_id or not doc_id: - continue - doc_ids_by_kb.setdefault(kb_id, set()).add(doc_id) - - if not doc_ids_by_kb: - return chunks - - meta_by_doc = {} - for kb_id, doc_ids in doc_ids_by_kb.items(): - meta_map = DocMetadataService.get_metadata_for_documents(list(doc_ids), kb_id) - if meta_map: - meta_by_doc.update(meta_map) - - if metadata_fields is not None: - metadata_fields = {f for f in metadata_fields if isinstance(f, str)} - if not metadata_fields: - return chunks - - for chunk in chunks: - doc_id = chunk.get("document_id") - if not doc_id: - continue - meta = meta_by_doc.get(doc_id) - if not meta: - continue - if metadata_fields is not None: - meta = {k: v for k, v in meta.items() if k in metadata_fields} - if meta: - chunk["document_metadata"] = meta - - return chunks diff --git a/docs/references/http_api_reference.md b/docs/references/http_api_reference.md index 04d025ad4..47dccada4 100644 --- a/docs/references/http_api_reference.md +++ b/docs/references/http_api_reference.md @@ -33,7 +33,7 @@ A complete reference for RAGFlow's RESTful API. Before proceeding, please ensure ### Create chat completion -**POST** `/api/v1/chats_openai/{chat_id}/chat/completions` +**POST** `/api/v1/openai/{chat_id}/chat/completions` Creates a model response for a given chat conversation. @@ -42,7 +42,7 @@ This API follows the same request and response format as OpenAI's API. It allows #### Request - Method: POST -- URL: `/api/v1/chats_openai/{chat_id}/chat/completions` +- URL: `/api/v1/openai/{chat_id}/chat/completions` - Headers: - `'content-Type: application/json'` - `'Authorization: Bearer '` @@ -56,11 +56,11 @@ This API follows the same request and response format as OpenAI's API. It allows ```bash curl --request POST \ - --url http://{address}/api/v1/chats_openai/{chat_id}/chat/completions \ + --url http://{address}/api/v1/openai/{chat_id}/chat/completions \ --header 'Content-Type: application/json' \ --header 'Authorization: Bearer ' \ --data '{ - "model": "model", + "model": "glm-4-flash@ZHIPU-AI", "messages": [{"role": "user", "content": "Say this is a test!"}], "stream": true, "extra_body": { @@ -85,8 +85,11 @@ curl --request POST \ ##### Request Parameters +- `chat_id` (*Path parameter*) `string`, *Required* + Existing chat assistant ID. The request will use that chat assistant's knowledge and settings. + - `model` (*Body parameter*) `string`, *Required* - The model used to generate the response. The server will parse this automatically, so you can set it to any value for now. + The model used to generate the response. When `chat_id` is provided, you may also use the legacy placeholder value `"model"` to keep using the chat assistant's configured model. - `messages` (*Body parameter*) `list[object]`, *Required* A list of historical chat messages used to generate the response. This must contain at least one message with the `user` role. diff --git a/docs/references/python_api_reference.md b/docs/references/python_api_reference.md index d7a781000..f809463dc 100644 --- a/docs/references/python_api_reference.md +++ b/docs/references/python_api_reference.md @@ -46,9 +46,13 @@ Creates a model response for the given historical chat conversation via OpenAI's #### Parameters +##### chat_id: `string`, *Required* + +Existing chat assistant ID. This value is part of the request path: `/api/v1/openai//chat/completions`. + ##### model: `string`, *Required* -The model used to generate the response. The server will parse this automatically, so you can set it to any value for now. +The model used to generate the response. You may also use the legacy placeholder value `"model"` to keep using the chat assistant's configured model. ##### messages: `list[object]`, *Required* @@ -65,20 +69,12 @@ Whether to receive the response as a stream. Set this to `false` explicitly if y #### Examples -> **Note** -> Streaming via `client.chat.completions.create(stream=True, ...)` does not -> return `reference` currently because `reference` is only exposed in the -> non-stream response payload. The only way to return `reference` is non-stream -> mode with `with_raw_response`. -:::caution NOTE -Streaming via `client.chat.completions.create(stream=True, ...)` does not return `reference` because it is *only* included in the raw response payload in non-stream mode. To return `reference`, set `stream=False`. -::: ```python from openai import OpenAI import json -model = "model" -client = OpenAI(api_key="ragflow-api-key", base_url=f"http://ragflow_address/api/v1/chats_openai/") +model = "glm-4-flash@ZHIPU-AI" +client = OpenAI(api_key="ragflow-api-key", base_url="http://ragflow_address/api/v1/openai//chat") stream = True reference = True @@ -92,13 +88,11 @@ request_kwargs = dict( {"role": "user", "content": "Can you tell me how to install neovim"}, ], extra_body={ - "extra_body": { - "reference": reference, - "reference_metadata": { - "include": True, - "fields": ["author", "year", "source"], - }, - } + "reference": reference, + "reference_metadata": { + "include": True, + "fields": ["author", "year", "source"], + }, }, ) @@ -119,6 +113,8 @@ else: print("reference:", data["choices"][0]["message"].get("reference")) ``` +When `extra_body.reference` is `true`, the streamed final chunk may include `choices[0].delta.reference`, and the non-stream response may include `choices[0].message.reference`. + When `extra_body.reference_metadata.include` is `true`, each reference chunk may include a `document_metadata` object in both streaming and non-streaming responses. ## DATASET MANAGEMENT diff --git a/test/benchmark/chat.py b/test/benchmark/chat.py index cfff29c7b..7d38ebc00 100644 --- a/test/benchmark/chat.py +++ b/test/benchmark/chat.py @@ -80,7 +80,7 @@ def stream_chat_completion( t0 = time.perf_counter() response = client.request( "POST", - f"/chats_openai/{chat_id}/chat/completions", + f"/openai/{chat_id}/chat/completions", json_body=payload, stream=True, ) diff --git a/test/testcases/test_http_api/common.py b/test/testcases/test_http_api/common.py index bcfcf5541..33cb8e77d 100644 --- a/test/testcases/test_http_api/common.py +++ b/test/testcases/test_http_api/common.py @@ -336,7 +336,7 @@ def update_documents_metadata(auth, dataset_id, payload=None): # CHAT COMPLETIONS AND RELATED QUESTIONS def related_questions(auth, payload=None, *, headers=HEADERS): - url = f"{HOST_ADDRESS}/api/{VERSION}/sessions/related_questions" + url = f"{HOST_ADDRESS}/api/{VERSION}/searchbots/related_questions" res = requests.post(url=url, headers=headers, auth=auth, json=payload) return res.json() @@ -430,7 +430,8 @@ def chat_completions_openai(auth, chat_id, payload=None, *, headers=HEADERS): Returns: Response JSON in OpenAI chat completions format with usage information """ - url = f"{HOST_ADDRESS}/api/{VERSION}/chats_openai/{chat_id}/chat/completions" + url = f"{HOST_ADDRESS}/api/{VERSION}/openai/{chat_id}/chat/completions" + payload = dict(payload or {}) res = requests.post(url=url, headers=headers, auth=auth, json=payload) return res.json() diff --git a/test/testcases/test_http_api/test_chat_assistant_management/test_chat_sdk_routes_unit.py b/test/testcases/test_http_api/test_chat_assistant_management/test_chat_sdk_routes_unit.py index 359aa6159..9d72a63da 100644 --- a/test/testcases/test_http_api/test_chat_assistant_management/test_chat_sdk_routes_unit.py +++ b/test/testcases/test_http_api/test_chat_assistant_management/test_chat_sdk_routes_unit.py @@ -80,6 +80,15 @@ class _StubResponse: self.headers = _StubHeaders() +class _DummyUploadFile: + def __init__(self, filename): + self.filename = filename + self.saved_path = None + + async def save(self, path): + self.saved_path = path + + def _passthrough_login_required(func): @wraps(func) async def _wrapper(*args, **kwargs): @@ -130,6 +139,21 @@ def _run(coro): return asyncio.run(coro) +async def _collect_stream(body): + items = [] + if hasattr(body, "__aiter__"): + async for item in body: + if isinstance(item, bytes): + item = item.decode("utf-8") + items.append(item) + else: + for item in body: + if isinstance(item, bytes): + item = item.decode("utf-8") + items.append(item) + return items + + @pytest.fixture(scope="session") def auth(): return "unit-auth" @@ -171,6 +195,8 @@ def _load_chat_module(monkeypatch): CHAT = "chat" IMAGE2TEXT = "image2text" RERANK = "rerank" + SPEECH2TEXT = "speech2text" + TTS = "tts" class _StubRetCode(int, Enum): SUCCESS = 0 @@ -995,3 +1021,138 @@ def test_chat_session_delete_routes_partial_duplicate_unit(monkeypatch): assert res["code"] == 0 assert res["data"]["success_count"] == 1 assert res["data"]["errors"] == ["Duplicate session ids: ok"] + + +@pytest.mark.p2 +def test_chat_audio_transcription_routes_unit(monkeypatch): + module = _load_chat_module(monkeypatch) + monkeypatch.setattr(module, "Response", _StubResponse) + monkeypatch.setattr(module.tempfile, "mkstemp", lambda suffix: (11, f"/tmp/audio{suffix}")) + monkeypatch.setattr(module.os, "close", lambda _fd: None) + + def _set_request(form, files): + monkeypatch.setattr( + module, + "request", + SimpleNamespace(form=_AwaitableValue(form), files=_AwaitableValue(files)), + ) + + _set_request({"stream": "false"}, {}) + res = _run(module.transcription.__wrapped__()) + assert "Missing 'file' in multipart form-data" in res["message"] + + _set_request({"stream": "false"}, {"file": _DummyUploadFile("bad.txt")}) + res = _run(module.transcription.__wrapped__()) + assert "Unsupported audio format: .txt" in res["message"] + + _set_request({"stream": "false"}, {"file": _DummyUploadFile("audio.wav")}) + monkeypatch.setattr( + module, + "get_tenant_default_model_by_type", + lambda *_args, **_kwargs: (_ for _ in ()).throw(LookupError("Tenant not found!")), + ) + res = _run(module.transcription.__wrapped__()) + assert res["message"] == "Tenant not found!" + + _set_request({"stream": "false"}, {"file": _DummyUploadFile("audio.wav")}) + monkeypatch.setattr( + module, + "get_tenant_default_model_by_type", + lambda *_args, **_kwargs: (_ for _ in ()).throw(Exception("No default ASR model is set")), + ) + res = _run(module.transcription.__wrapped__()) + assert res["message"] == "No default ASR model is set" + + class _SyncASR: + def transcription(self, _path): + return "transcribed text" + + def stream_transcription(self, _path): + return [] + + _set_request({"stream": "false"}, {"file": _DummyUploadFile("audio.wav")}) + monkeypatch.setattr(module, "get_tenant_default_model_by_type", lambda *_args, **_kwargs: {"llm_name": "asr-x"}) + monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: _SyncASR()) + monkeypatch.setattr(module.os, "remove", lambda _path: (_ for _ in ()).throw(RuntimeError("cleanup fail"))) + res = _run(module.transcription.__wrapped__()) + assert res["code"] == 0 + assert res["data"]["text"] == "transcribed text" + + class _StreamASR: + def transcription(self, _path): + return "" + + def stream_transcription(self, _path): + yield {"event": "partial", "text": "hello"} + + _set_request({"stream": "true"}, {"file": _DummyUploadFile("audio.wav")}) + monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: _StreamASR()) + monkeypatch.setattr(module.os, "remove", lambda _path: None) + resp = _run(module.transcription.__wrapped__()) + assert isinstance(resp, _StubResponse) + assert resp.content_type == "text/event-stream" + chunks = _run(_collect_stream(resp.body)) + assert any('"event": "partial"' in chunk for chunk in chunks) + + class _ErrorASR: + def transcription(self, _path): + return "" + + def stream_transcription(self, _path): + raise RuntimeError("stream asr boom") + + _set_request({"stream": "true"}, {"file": _DummyUploadFile("audio.wav")}) + monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: _ErrorASR()) + monkeypatch.setattr(module.os, "remove", lambda _path: (_ for _ in ()).throw(RuntimeError("cleanup boom"))) + resp = _run(module.transcription.__wrapped__()) + chunks = _run(_collect_stream(resp.body)) + assert any("stream asr boom" in chunk for chunk in chunks) + + +@pytest.mark.p2 +def test_chat_audio_speech_routes_unit(monkeypatch): + module = _load_chat_module(monkeypatch) + monkeypatch.setattr(module, "Response", _StubResponse) + _set_request_json(monkeypatch, module, {"text": "A。B"}) + + monkeypatch.setattr( + module, + "get_tenant_default_model_by_type", + lambda *_args, **_kwargs: (_ for _ in ()).throw(LookupError("Tenant not found!")), + ) + res = _run(module.tts.__wrapped__()) + assert res["message"] == "Tenant not found!" + + monkeypatch.setattr( + module, + "get_tenant_default_model_by_type", + lambda *_args, **_kwargs: (_ for _ in ()).throw(Exception("No default TTS model is set")), + ) + res = _run(module.tts.__wrapped__()) + assert res["message"] == "No default TTS model is set" + + class _TTSOk: + def tts(self, txt): + if not txt: + return [] + yield f"chunk-{txt}".encode("utf-8") + + monkeypatch.setattr(module, "get_tenant_default_model_by_type", lambda *_args, **_kwargs: {"llm_name": "tts-x"}) + monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: _TTSOk()) + resp = _run(module.tts.__wrapped__()) + assert resp.mimetype == "audio/mpeg" + assert resp.headers.get("Cache-Control") == "no-cache" + assert resp.headers.get("Connection") == "keep-alive" + assert resp.headers.get("X-Accel-Buffering") == "no" + chunks = _run(_collect_stream(resp.body)) + assert any("chunk-A" in chunk for chunk in chunks) + assert any("chunk-B" in chunk for chunk in chunks) + + class _TTSErr: + def tts(self, _txt): + raise RuntimeError("tts boom") + + monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: _TTSErr()) + resp = _run(module.tts.__wrapped__()) + chunks = _run(_collect_stream(resp.body)) + assert any('"code": 500' in chunk and "**ERROR**: tts boom" in chunk for chunk in chunks) diff --git a/test/testcases/test_http_api/test_session_management/test_chat_completions_openai.py b/test/testcases/test_http_api/test_session_management/test_chat_completions_openai.py index 54d5fe29d..4df694dc6 100644 --- a/test/testcases/test_http_api/test_session_management/test_chat_completions_openai.py +++ b/test/testcases/test_http_api/test_session_management/test_chat_completions_openai.py @@ -59,7 +59,7 @@ class TestChatCompletionsOpenAI: HttpApiAuth, chat_id, { - "model": "model", # Required by OpenAI-compatible API, value is ignored by RAGFlow + "model": "model", # Legacy placeholder keeps using the chat assistant's configured model "messages": [{"role": "user", "content": "hello"}], "stream": False, }, @@ -100,7 +100,7 @@ class TestChatCompletionsOpenAI: HttpApiAuth, chat_id, { - "model": "model", # Required by OpenAI-compatible API, value is ignored by RAGFlow + "model": "model", # Legacy placeholder keeps using the chat assistant's configured model "messages": [{"role": "user", "content": "hello"}], "stream": False, }, @@ -123,7 +123,7 @@ class TestChatCompletionsOpenAI: HttpApiAuth, "invalid_chat_id", { - "model": "model", # Required by OpenAI-compatible API, value is ignored by RAGFlow + "model": "model", # Legacy placeholder keeps using the chat assistant's configured model "messages": [{"role": "user", "content": "hello"}], "stream": False, }, diff --git a/test/testcases/test_http_api/test_session_management/test_related_questions.py b/test/testcases/test_http_api/test_session_management/test_related_questions.py index 427708b27..c70322ddf 100644 --- a/test/testcases/test_http_api/test_session_management/test_related_questions.py +++ b/test/testcases/test_http_api/test_session_management/test_related_questions.py @@ -29,11 +29,11 @@ class TestRelatedQuestions: @pytest.mark.p2 def test_related_questions_missing_question(self, HttpApiAuth): res = related_questions(HttpApiAuth, {"industry": "search"}) - assert res["code"] == 102, res + assert res["code"] == 101, res assert "question" in res.get("message", ""), res @pytest.mark.p2 def test_related_questions_invalid_auth(self): res = related_questions(RAGFlowHttpApiAuth(INVALID_API_TOKEN), {"question": "ragflow", "industry": "search"}) - assert res["code"] == 109, res + assert res["code"] == 102, res assert "API key is invalid" in res.get("message", ""), res diff --git a/test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py b/test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py index 9834b28e2..53973614f 100644 --- a/test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py +++ b/test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py @@ -667,6 +667,34 @@ def _load_agent_api_module(monkeypatch): return module +def _load_openai_api_module(monkeypatch): + _load_session_module(monkeypatch) + repo_root = Path(__file__).resolve().parents[4] + + api_apps_mod = ModuleType("api.apps") + api_apps_mod.__path__ = [str(repo_root / "api" / "apps")] + api_apps_mod.login_required = lambda func: func + api_apps_mod.current_user = SimpleNamespace(id="tenant-1") + monkeypatch.setitem(sys.modules, "api.apps", api_apps_mod) + + api_apps_restful_mod = ModuleType("api.apps.restful_apis") + api_apps_restful_mod.__path__ = [str(repo_root / "api" / "apps" / "restful_apis")] + monkeypatch.setitem(sys.modules, "api.apps.restful_apis", api_apps_restful_mod) + + quart_mod = ModuleType("quart") + quart_mod.Response = _StubResponse + quart_mod.jsonify = lambda payload: payload + monkeypatch.setitem(sys.modules, "quart", quart_mod) + + module_path = repo_root / "api" / "apps" / "restful_apis" / "openai_api.py" + spec = importlib.util.spec_from_file_location("test_openai_api_unit_module", module_path) + module = importlib.util.module_from_spec(spec) + module.manager = _DummyManager() + monkeypatch.setitem(sys.modules, "test_openai_api_unit_module", module) + spec.loader.exec_module(module) + return module + + @pytest.mark.p2 def test_create_and_update_guard_matrix(monkeypatch): module = _load_session_module(monkeypatch) @@ -687,62 +715,16 @@ def test_create_and_update_guard_matrix(monkeypatch): assert res["message"] == "You cannot access the agent." -@pytest.mark.p2 -def test_chat_completion_metadata_and_stream_paths(monkeypatch): - module = _load_session_module(monkeypatch) - - monkeypatch.setattr(module, "Response", _StubResponse) - monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: [SimpleNamespace(kb_ids=["kb-1"])]) - monkeypatch.setattr(module.DocMetadataService, "get_flatted_meta_by_kbs", lambda _kb_ids: [{"id": "doc-1"}]) - monkeypatch.setattr(module, "convert_conditions", lambda cond: cond.get("conditions", [])) - monkeypatch.setattr(module, "meta_filter", lambda *_args, **_kwargs: []) - - captured_requests = [] - - async def fake_rag_completion(_tenant_id, _chat_id, **req): - captured_requests.append(req) - yield {"answer": "ok"} - - monkeypatch.setattr(module, "rag_completion", fake_rag_completion) - - monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue(None)) - resp = _run(inspect.unwrap(module.chat_completion)("tenant-1", "chat-1")) - assert isinstance(resp, _StubResponse) - assert resp.headers.get("Content-Type") == "text/event-stream; charset=utf-8" - _run(_collect_stream(resp.body)) - assert captured_requests[-1].get("question") == "" - - req_with_conditions = { - "question": "hello", - "session_id": "session-1", - "metadata_condition": {"logic": "and", "conditions": [{"name": "author", "value": "bob"}]}, - "stream": True, - } - monkeypatch.setattr(module.ConversationService, "query", lambda **_kwargs: [SimpleNamespace(id="session-1")]) - monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue(req_with_conditions)) - resp = _run(inspect.unwrap(module.chat_completion)("tenant-1", "chat-1")) - _run(_collect_stream(resp.body)) - assert captured_requests[-1].get("doc_ids") == "-999" - - req_without_conditions = { - "question": "hello", - "session_id": "session-1", - "metadata_condition": {"logic": "and", "conditions": []}, - "stream": True, - "doc_ids": "legacy", - } - monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue(req_without_conditions)) - resp = _run(inspect.unwrap(module.chat_completion)("tenant-1", "chat-1")) - _run(_collect_stream(resp.body)) - assert "doc_ids" not in captured_requests[-1] - - @pytest.mark.p2 def test_openai_chat_validation_matrix_unit(monkeypatch): - module = _load_session_module(monkeypatch) + module = _load_openai_api_module(monkeypatch) monkeypatch.setattr(module, "num_tokens_from_string", lambda _text: 1) - monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: [SimpleNamespace(kb_ids=["kb-1"])]) + monkeypatch.setattr( + module.DialogService, + "query", + lambda **_kwargs: [SimpleNamespace(kb_ids=["kb-1"], llm_id="chat-model", tenant_id="tenant-1")], + ) cases = [ ( @@ -786,20 +768,23 @@ def test_openai_chat_validation_matrix_unit(monkeypatch): for payload, expected in cases: monkeypatch.setattr(module, "get_request_json", lambda p=payload: _AwaitableValue(p)) - res = _run(inspect.unwrap(module.chat_completion_openai_like)("tenant-1", "chat-1")) + res = _run(inspect.unwrap(module.openai_chat_completions)("chat-1")) assert expected in res["message"] @pytest.mark.p2 def test_openai_stream_generator_branches_unit(monkeypatch): - module = _load_session_module(monkeypatch) + module = _load_openai_api_module(monkeypatch) - monkeypatch.setattr(module, "Response", _StubResponse) monkeypatch.setattr(module, "num_tokens_from_string", lambda text: len(text or "")) monkeypatch.setattr(module, "convert_conditions", lambda cond: cond.get("conditions", [])) monkeypatch.setattr(module, "meta_filter", lambda *_args, **_kwargs: []) monkeypatch.setattr(module.DocMetadataService, "get_flatted_meta_by_kbs", lambda _kb_ids: [{"id": "doc-1"}]) - monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: [SimpleNamespace(kb_ids=["kb-1"])]) + monkeypatch.setattr( + module.DialogService, + "query", + lambda **_kwargs: [SimpleNamespace(kb_ids=["kb-1"], llm_id="chat-model", tenant_id="tenant-1")], + ) monkeypatch.setattr(module, "_build_reference_chunks", lambda *_args, **_kwargs: [{"id": "ref-1"}]) async def fake_async_chat(_dia, _msg, _stream, **_kwargs): @@ -829,7 +814,7 @@ def test_openai_stream_generator_branches_unit(monkeypatch): } monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue(payload)) - resp = _run(inspect.unwrap(module.chat_completion_openai_like)("tenant-1", "chat-1")) + resp = _run(inspect.unwrap(module.openai_chat_completions)("chat-1")) assert isinstance(resp, _StubResponse) assert resp.headers.get("Content-Type") == "text/event-stream; charset=utf-8" @@ -843,11 +828,14 @@ def test_openai_stream_generator_branches_unit(monkeypatch): @pytest.mark.p2 def test_openai_nonstream_branch_unit(monkeypatch): - module = _load_session_module(monkeypatch) + module = _load_openai_api_module(monkeypatch) - monkeypatch.setattr(module, "jsonify", lambda payload: payload) monkeypatch.setattr(module, "num_tokens_from_string", lambda text: len(text or "")) - monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: [SimpleNamespace(kb_ids=[])]) + monkeypatch.setattr( + module.DialogService, + "query", + lambda **_kwargs: [SimpleNamespace(kb_ids=[], llm_id="chat-model", tenant_id="tenant-1")], + ) async def fake_async_chat(_dia, _msg, _stream, **_kwargs): yield {"answer": "world", "reference": {}} @@ -865,7 +853,7 @@ def test_openai_nonstream_branch_unit(monkeypatch): ), ) - res = _run(inspect.unwrap(module.chat_completion_openai_like)("tenant-1", "chat-1")) + res = _run(inspect.unwrap(module.openai_chat_completions)("chat-1")) assert res["choices"][0]["message"]["content"] == "world" @@ -1115,92 +1103,6 @@ def test_delete_agent_session_error_matrix_unit(monkeypatch): assert res["data"]["errors"] == ["Duplicate session ids: ok"] -@pytest.mark.p2 -def test_sessions_ask_route_validation_and_stream_unit(monkeypatch): - module = _load_session_module(monkeypatch) - monkeypatch.setattr(module, "Response", _StubResponse) - - monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"dataset_ids": ["kb-1"]})) - res = _run(inspect.unwrap(module.ask_about)("tenant-1")) - assert res["message"] == "`question` is required." - - monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"question": "q"})) - res = _run(inspect.unwrap(module.ask_about)("tenant-1")) - assert res["message"] == "`dataset_ids` is required." - - monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"question": "q", "dataset_ids": "kb-1"})) - res = _run(inspect.unwrap(module.ask_about)("tenant-1")) - assert res["message"] == "`dataset_ids` should be a list." - - monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda *_args, **_kwargs: False) - monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"question": "q", "dataset_ids": ["kb-1"]})) - res = _run(inspect.unwrap(module.ask_about)("tenant-1")) - assert res["message"] == "You don't own the dataset kb-1." - - monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda *_args, **_kwargs: True) - monkeypatch.setattr(module.KnowledgebaseService, "query", lambda **_kwargs: [SimpleNamespace(chunk_num=0)]) - monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"question": "q", "dataset_ids": ["kb-1"]})) - res = _run(inspect.unwrap(module.ask_about)("tenant-1")) - assert res["message"] == "The dataset kb-1 doesn't own parsed file" - - monkeypatch.setattr(module.KnowledgebaseService, "query", lambda **_kwargs: [SimpleNamespace(chunk_num=1)]) - captured = {} - - async def _streaming_async_ask(question, kb_ids, uid): - captured["question"] = question - captured["kb_ids"] = kb_ids - captured["uid"] = uid - yield {"answer": "first"} - raise RuntimeError("ask stream boom") - - monkeypatch.setattr(module, "async_ask", _streaming_async_ask) - monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"question": "q", "dataset_ids": ["kb-1"]})) - resp = _run(inspect.unwrap(module.ask_about)("tenant-1")) - assert isinstance(resp, _StubResponse) - assert resp.headers.get("Content-Type") == "text/event-stream; charset=utf-8" - chunks = _run(_collect_stream(resp.body)) - assert any('"answer": "first"' in chunk for chunk in chunks) - assert any('"code": 500' in chunk and "**ERROR**: ask stream boom" in chunk for chunk in chunks) - assert '"data": true' in chunks[-1].lower() - assert captured == {"question": "q", "kb_ids": ["kb-1"], "uid": "tenant-1"} - - -@pytest.mark.p2 -def test_sessions_related_questions_prompt_build_unit(monkeypatch): - module = _load_session_module(monkeypatch) - - monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({})) - res = _run(inspect.unwrap(module.related_questions)("tenant-1")) - assert res["message"] == "`question` is required." - - captured = {} - - class _FakeLLMBundle: - def __init__(self, *args, **kwargs): - captured["bundle_args"] = args - captured["bundle_kwargs"] = kwargs - - async def async_chat(self, prompt, messages, options): - captured["prompt"] = prompt - captured["messages"] = messages - captured["options"] = options - return "1. First related\n2. Second related\nplain text" - - monkeypatch.setattr(module, "LLMBundle", _FakeLLMBundle) - monkeypatch.setattr( - module, - "get_request_json", - lambda: _AwaitableValue({"question": "solar energy", "industry": "renewables"}), - ) - res = _run(inspect.unwrap(module.related_questions)("tenant-1")) - assert res["data"] == ["First related", "Second related"] - assert "Keep the term length between 2-4 words" in captured["prompt"] - assert "related terms can also help search engines" in captured["prompt"] - assert "Ensure all search terms are relevant to the industry: renewables." in captured["prompt"] - assert "Keywords: solar energy" in captured["messages"][0]["content"] - assert captured["options"] == {"temperature": 0.9} - - @pytest.mark.p2 def test_chatbot_routes_auth_stream_nonstream_unit(monkeypatch): module = _load_session_module(monkeypatch) @@ -1701,133 +1603,9 @@ def test_searchbots_mindmap_embedded_matrix_unit(monkeypatch): assert "mindmap boom" in res["message"] -@pytest.mark.p2 -def test_sequence2txt_embedded_validation_and_stream_matrix_unit(monkeypatch): - module = _load_session_module(monkeypatch) - handler = inspect.unwrap(module.sequence2txt) - monkeypatch.setattr(module, "Response", _StubResponse) - monkeypatch.setattr(module.tempfile, "mkstemp", lambda suffix: (11, f"/tmp/audio{suffix}")) - monkeypatch.setattr(module.os, "close", lambda _fd: None) - - def _set_request(form, files): - monkeypatch.setattr( - module, - "request", - SimpleNamespace(form=_AwaitableValue(form), files=_AwaitableValue(files)), - ) - - _set_request({"stream": "false"}, {}) - res = _run(handler("tenant-1")) - assert "Missing 'file' in multipart form-data" in res["message"] - - _set_request({"stream": "false"}, {"file": _DummyUploadFile("bad.txt")}) - res = _run(handler("tenant-1")) - assert "Unsupported audio format: .txt" in res["message"] - - _set_request({"stream": "false"}, {"file": _DummyUploadFile("audio.wav")}) - tenant_llm_service = sys.modules["api.db.services.tenant_llm_service"] - monkeypatch.setattr(tenant_llm_service.TenantService, "get_by_id", lambda _tid: (False, None)) - res = _run(handler("tenant-1")) - assert res["message"] == "Tenant not found!" - - _set_request({"stream": "false"}, {"file": _DummyUploadFile("audio.wav")}) - tenant_llm_service = sys.modules["api.db.services.tenant_llm_service"] - monkeypatch.setattr(tenant_llm_service.TenantService, "get_by_id", lambda _tid: (True, SimpleNamespace(asr_id="", tts_id="", llm_id="", embd_id="", img2txt_id="", rerank_id=""))) - res = _run(handler("tenant-1")) - assert res["message"] == "No default ASR model is set" - - class _SyncASR: - def transcription(self, _path): - return "transcribed text" - - def stream_transcription(self, _path): - return [] - - _set_request({"stream": "false"}, {"file": _DummyUploadFile("audio.wav")}) - monkeypatch.setattr(tenant_llm_service.TenantService, "get_by_id", lambda _tid: (True, SimpleNamespace(asr_id="asr-x", tts_id="", llm_id="", embd_id="", img2txt_id="", rerank_id=""))) - monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: _SyncASR()) - monkeypatch.setattr(module.os, "remove", lambda _path: (_ for _ in ()).throw(RuntimeError("cleanup fail"))) - res = _run(handler("tenant-1")) - assert res["code"] == 0 - assert res["data"]["text"] == "transcribed text" - - class _StreamASR: - def transcription(self, _path): - return "" - - def stream_transcription(self, _path): - yield {"event": "partial", "text": "hello"} - - _set_request({"stream": "true"}, {"file": _DummyUploadFile("audio.wav")}) - monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: _StreamASR()) - monkeypatch.setattr(module.os, "remove", lambda _path: None) - resp = _run(handler("tenant-1")) - assert isinstance(resp, _StubResponse) - assert resp.content_type == "text/event-stream" - chunks = _run(_collect_stream(resp.body)) - assert any('"event": "partial"' in chunk for chunk in chunks) - - class _ErrorASR: - def transcription(self, _path): - return "" - - def stream_transcription(self, _path): - raise RuntimeError("stream asr boom") - - _set_request({"stream": "true"}, {"file": _DummyUploadFile("audio.wav")}) - monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: _ErrorASR()) - monkeypatch.setattr(module.os, "remove", lambda _path: (_ for _ in ()).throw(RuntimeError("cleanup boom"))) - resp = _run(handler("tenant-1")) - chunks = _run(_collect_stream(resp.body)) - assert any("stream asr boom" in chunk for chunk in chunks) - - -@pytest.mark.p2 -def test_tts_embedded_stream_and_error_matrix_unit(monkeypatch): - module = _load_session_module(monkeypatch) - handler = inspect.unwrap(module.tts) - monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"text": "A。B"})) - monkeypatch.setattr(module, "Response", _StubResponse) - - tenant_llm_service = sys.modules["api.db.services.tenant_llm_service"] - monkeypatch.setattr(tenant_llm_service.TenantService, "get_by_id", lambda _tid: (False, None)) - res = _run(handler("tenant-1")) - assert res["message"] == "Tenant not found!" - - monkeypatch.setattr(tenant_llm_service.TenantService, "get_by_id", lambda _tid: (True, SimpleNamespace(asr_id="", tts_id="", llm_id="", embd_id="", img2txt_id="", rerank_id=""))) - res = _run(handler("tenant-1")) - assert res["message"] == "No default TTS model is set" - - class _TTSOk: - def tts(self, txt): - if not txt: - return [] - yield f"chunk-{txt}".encode("utf-8") - - monkeypatch.setattr(tenant_llm_service.TenantService, "get_by_id", lambda _tid: (True, SimpleNamespace(asr_id="", tts_id="tts-x", llm_id="", embd_id="", img2txt_id="", rerank_id=""))) - monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: _TTSOk()) - resp = _run(handler("tenant-1")) - assert resp.mimetype == "audio/mpeg" - assert resp.headers.get("Cache-Control") == "no-cache" - assert resp.headers.get("Connection") == "keep-alive" - assert resp.headers.get("X-Accel-Buffering") == "no" - chunks = _run(_collect_stream(resp.body)) - assert any("chunk-A" in chunk for chunk in chunks) - assert any("chunk-B" in chunk for chunk in chunks) - - class _TTSErr: - def tts(self, _txt): - raise RuntimeError("tts boom") - - monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: _TTSErr()) - resp = _run(handler("tenant-1")) - chunks = _run(_collect_stream(resp.body)) - assert any('"code": 500' in chunk and "**ERROR**: tts boom" in chunk for chunk in chunks) - - @pytest.mark.p2 def test_build_reference_chunks_metadata_matrix_unit(monkeypatch): - module = _load_session_module(monkeypatch) + module = _load_openai_api_module(monkeypatch) monkeypatch.setattr(module, "chunks_format", lambda _reference: [{"dataset_id": "kb-1", "document_id": "doc-1"}]) res = module._build_reference_chunks([], include_metadata=False)