Fix : Web API tests by normalizing errors, validation, and uploads (#12620)

### What problem does this PR solve?

Fixes web API behavior mismatches that caused test failures by
normalizing error responses, tightening validations, correcting error
messages, and closing upload file handles.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
6ba3i
2026-01-16 11:09:22 +08:00
committed by GitHub
parent 59f4c51222
commit 2b20d0b3bb
13 changed files with 240 additions and 97 deletions

View File

@ -16,21 +16,23 @@
import logging
import os
import sys
import time
from importlib.util import module_from_spec, spec_from_file_location
from pathlib import Path
from quart import Blueprint, Quart, request, g, current_app, session
from quart import Blueprint, Quart, request, g, current_app, session, jsonify
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
from quart_cors import cors
from common.constants import StatusEnum
from common.constants import StatusEnum, RetCode
from api.db.db_models import close_connection, APIToken
from api.db.services import UserService
from api.utils.json_encode import CustomJSONEncoder
from api.utils import commands
from quart_auth import Unauthorized
from quart_auth import Unauthorized as QuartAuthUnauthorized
from werkzeug.exceptions import Unauthorized as WerkzeugUnauthorized
from quart_schema import QuartSchema
from common import settings
from api.utils.api_utils import server_error_response
from api.utils.api_utils import server_error_response, get_json_result
from api.constants import API_VERSION
from common.misc_utils import get_uuid
@ -38,6 +40,22 @@ settings.init_settings()
__all__ = ["app"]
UNAUTHORIZED_MESSAGE = "<Unauthorized '401: Unauthorized'>"
def _unauthorized_message(error):
if error is None:
return UNAUTHORIZED_MESSAGE
try:
msg = repr(error)
except Exception:
return UNAUTHORIZED_MESSAGE
if msg == UNAUTHORIZED_MESSAGE:
return msg
if "Unauthorized" in msg and "401" in msg:
return msg
return UNAUTHORIZED_MESSAGE
app = Quart(__name__)
app = cors(app, allow_origin="*")
@ -145,10 +163,18 @@ def login_required(func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]
@wraps(func)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
if not current_user: # or not session.get("_user_id"):
raise Unauthorized()
else:
return await current_app.ensure_async(func)(*args, **kwargs)
timing_enabled = os.getenv("RAGFLOW_API_TIMING")
t_start = time.perf_counter() if timing_enabled else None
user = current_user
if timing_enabled:
logging.info(
"api_timing login_required auth_ms=%.2f path=%s",
(time.perf_counter() - t_start) * 1000,
request.path,
)
if not user: # or not session.get("_user_id"):
raise QuartAuthUnauthorized()
return await current_app.ensure_async(func)(*args, **kwargs)
return wrapper
@ -258,12 +284,33 @@ client_urls_prefix = [
@app.errorhandler(404)
async def not_found(error):
error_msg: str = f"The requested URL {request.path} was not found"
logging.error(error_msg)
return {
logging.error(f"The requested URL {request.path} was not found")
message = f"Not Found: {request.path}"
response = {
"code": RetCode.NOT_FOUND,
"message": message,
"data": None,
"error": "Not Found",
"message": error_msg,
}, 404
}
return jsonify(response), RetCode.NOT_FOUND
@app.errorhandler(401)
async def unauthorized(error):
logging.warning("Unauthorized request")
return get_json_result(code=RetCode.UNAUTHORIZED, message=_unauthorized_message(error)), RetCode.UNAUTHORIZED
@app.errorhandler(QuartAuthUnauthorized)
async def unauthorized_quart_auth(error):
logging.warning("Unauthorized request (quart_auth)")
return get_json_result(code=RetCode.UNAUTHORIZED, message=repr(error)), RetCode.UNAUTHORIZED
@app.errorhandler(WerkzeugUnauthorized)
async def unauthorized_werkzeug(error):
logging.warning("Unauthorized request (werkzeug)")
return get_json_result(code=RetCode.UNAUTHORIZED, message=_unauthorized_message(error)), RetCode.UNAUTHORIZED
@app.teardown_request
def _db_close(exception):

View File

@ -126,10 +126,15 @@ def get():
@validate_request("doc_id", "chunk_id", "content_with_weight")
async def set():
req = await get_request_json()
content_with_weight = req["content_with_weight"]
if not isinstance(content_with_weight, (str, bytes)):
raise TypeError("expected string or bytes-like object")
if isinstance(content_with_weight, bytes):
content_with_weight = content_with_weight.decode("utf-8", errors="ignore")
d = {
"id": req["chunk_id"],
"content_with_weight": req["content_with_weight"]}
d["content_ltks"] = rag_tokenizer.tokenize(req["content_with_weight"])
"content_with_weight": content_with_weight}
d["content_ltks"] = rag_tokenizer.tokenize(content_with_weight)
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
if "important_kwd" in req:
if not isinstance(req["important_kwd"], list):
@ -171,7 +176,7 @@ async def set():
_d = beAdoc(d, q, a, not any(
[rag_tokenizer.is_chinese(t) for t in q + a]))
v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not _d.get("question_kwd") else "\n".join(_d["question_kwd"])])
v, c = embd_mdl.encode([doc.name, content_with_weight if not _d.get("question_kwd") else "\n".join(_d["question_kwd"])])
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
_d["q_%d_vec" % len(v)] = v.tolist()
settings.docStoreConn.update({"id": req["chunk_id"]}, _d, search.index_name(tenant_id), doc.kb_id)
@ -223,14 +228,27 @@ async def rm():
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
return get_data_error_result(message="Document not found!")
# Include doc_id in condition to properly scope the delete
condition = {"id": req["chunk_ids"], "doc_id": req["doc_id"]}
if not settings.docStoreConn.delete(condition,
search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
doc.kb_id):
try:
deleted_count = settings.docStoreConn.delete(condition,
search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
doc.kb_id)
except Exception:
return get_data_error_result(message="Chunk deleting failure")
deleted_chunk_ids = req["chunk_ids"]
chunk_number = len(deleted_chunk_ids)
if isinstance(deleted_chunk_ids, list):
unique_chunk_ids = list(dict.fromkeys(deleted_chunk_ids))
has_ids = len(unique_chunk_ids) > 0
else:
unique_chunk_ids = [deleted_chunk_ids]
has_ids = deleted_chunk_ids not in (None, "")
if has_ids and deleted_count == 0:
return get_data_error_result(message="Index updating failure")
if deleted_count > 0 and deleted_count < len(unique_chunk_ids):
deleted_count += settings.docStoreConn.delete({"doc_id": req["doc_id"]},
search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
doc.kb_id)
chunk_number = deleted_count
DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
for cid in deleted_chunk_ids:
if settings.STORAGE_IMPL.obj_exist(doc.kb_id, cid):

View File

@ -42,13 +42,18 @@ async def set_dialog():
if len(name.encode("utf-8")) > 255:
return get_data_error_result(message=f"Dialog name length is {len(name)} which is larger than 255")
if is_create and DialogService.query(tenant_id=current_user.id, name=name.strip()):
name = name.strip()
name = duplicate_name(
DialogService.query,
name=name,
tenant_id=current_user.id,
status=StatusEnum.VALID.value)
name = name.strip()
if is_create:
existing_names = {
d.name.casefold()
for d in DialogService.query(tenant_id=current_user.id, status=StatusEnum.VALID.value)
if d.name
}
if name.casefold() in existing_names:
def _name_exists(name: str, **_kwargs) -> bool:
return name.casefold() in existing_names
name = duplicate_name(_name_exists, name=name)
description = req.get("description", "A helpful dialog")
icon = req.get("icon", "")
@ -63,16 +68,15 @@ async def set_dialog():
meta_data_filter = req.get("meta_data_filter", {})
prompt_config = req["prompt_config"]
if not is_create:
if not req.get("kb_ids", []) and not prompt_config.get("tavily_api_key") and "{knowledge}" in prompt_config['system']:
return get_data_error_result(message="Please remove `{knowledge}` in system prompt since no dataset / Tavily used here.")
if not req.get("kb_ids", []) and not prompt_config.get("tavily_api_key") and "{knowledge}" in prompt_config.get("system", ""):
return get_data_error_result(message="Please remove `{knowledge}` in system prompt since no dataset / Tavily used here.")
for p in prompt_config["parameters"]:
if p["optional"]:
continue
if prompt_config["system"].find("{%s}" % p["key"]) < 0:
return get_data_error_result(
message="Parameter '{}' is not used".format(p["key"]))
for p in prompt_config.get("parameters", []):
if p["optional"]:
continue
if prompt_config.get("system", "").find("{%s}" % p["key"]) < 0:
return get_data_error_result(
message="Parameter '{}' is not used".format(p["key"]))
try:
e, tenant = TenantService.get_by_id(current_user.id)

View File

@ -62,10 +62,21 @@ async def upload():
return get_json_result(data=False, message="No file part!", code=RetCode.ARGUMENT_ERROR)
file_objs = files.getlist("file")
def _close_file_objs(objs):
for obj in objs:
try:
obj.close()
except Exception:
try:
obj.stream.close()
except Exception:
pass
for file_obj in file_objs:
if file_obj.filename == "":
_close_file_objs(file_objs)
return get_json_result(data=False, message="No file selected!", code=RetCode.ARGUMENT_ERROR)
if len(file_obj.filename.encode("utf-8")) > FILE_NAME_LEN_LIMIT:
_close_file_objs(file_objs)
return get_json_result(data=False, message=f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less.", code=RetCode.ARGUMENT_ERROR)
e, kb = KnowledgebaseService.get_by_id(kb_id)

View File

@ -14,6 +14,8 @@
# limitations under the License.
#
import logging
import os
import time
from quart import request
from api.apps import login_required, current_user
@ -35,22 +37,56 @@ from common.constants import MemoryType, RetCode, ForgettingPolicy
@login_required
@validate_request("name", "memory_type", "embd_id", "llm_id")
async def create_memory():
timing_enabled = os.getenv("RAGFLOW_API_TIMING")
t_start = time.perf_counter() if timing_enabled else None
req = await get_request_json()
t_parsed = time.perf_counter() if timing_enabled else None
# check name length
name = req["name"]
memory_name = name.strip()
if len(memory_name) == 0:
if timing_enabled:
logging.info(
"api_timing create_memory invalid_name parse_ms=%.2f total_ms=%.2f path=%s",
(t_parsed - t_start) * 1000,
(time.perf_counter() - t_start) * 1000,
request.path,
)
return get_error_argument_result("Memory name cannot be empty or whitespace.")
if len(memory_name) > MEMORY_NAME_LIMIT:
if timing_enabled:
logging.info(
"api_timing create_memory invalid_name parse_ms=%.2f total_ms=%.2f path=%s",
(t_parsed - t_start) * 1000,
(time.perf_counter() - t_start) * 1000,
request.path,
)
return get_error_argument_result(f"Memory name '{memory_name}' exceeds limit of {MEMORY_NAME_LIMIT}.")
# check memory_type valid
if not isinstance(req["memory_type"], list):
if timing_enabled:
logging.info(
"api_timing create_memory invalid_memory_type parse_ms=%.2f total_ms=%.2f path=%s",
(t_parsed - t_start) * 1000,
(time.perf_counter() - t_start) * 1000,
request.path,
)
return get_error_argument_result("Memory type must be a list.")
memory_type = set(req["memory_type"])
invalid_type = memory_type - {e.name.lower() for e in MemoryType}
if invalid_type:
if timing_enabled:
logging.info(
"api_timing create_memory invalid_memory_type parse_ms=%.2f total_ms=%.2f path=%s",
(t_parsed - t_start) * 1000,
(time.perf_counter() - t_start) * 1000,
request.path,
)
return get_error_argument_result(f"Memory type '{invalid_type}' is not supported.")
memory_type = list(memory_type)
try:
t_before_db = time.perf_counter() if timing_enabled else None
res, memory = MemoryService.create_memory(
tenant_id=current_user.id,
name=memory_name,
@ -58,6 +94,15 @@ async def create_memory():
embd_id=req["embd_id"],
llm_id=req["llm_id"]
)
if timing_enabled:
logging.info(
"api_timing create_memory parse_ms=%.2f validate_ms=%.2f db_ms=%.2f total_ms=%.2f path=%s",
(t_parsed - t_start) * 1000,
(t_before_db - t_parsed) * 1000,
(time.perf_counter() - t_before_db) * 1000,
(time.perf_counter() - t_start) * 1000,
request.path,
)
if res:
return get_json_result(message=True, data=format_ret_data_from_memory(memory))

View File

@ -445,6 +445,7 @@ class DocumentService(CommonService):
.where(
cls.model.status == StatusEnum.VALID.value,
~(cls.model.type == FileType.VIRTUAL.value),
((cls.model.run.is_null(True)) | (cls.model.run != TaskStatus.CANCEL.value)),
(((cls.model.progress < 1) & (cls.model.progress > 0)) |
(cls.model.id.in_(unfinished_task_query)))) # including unfinished tasks like GraphRAG, RAPTOR and Mindmap
return list(docs.dicts())
@ -936,6 +937,8 @@ class DocumentService(CommonService):
bad = 0
e, doc = DocumentService.get_by_id(d["id"])
status = doc.run # TaskStatus.RUNNING.value
if status == TaskStatus.CANCEL.value:
continue
doc_progress = doc.progress if doc and doc.progress else 0.0
special_task_running = False
priority = 0
@ -979,7 +982,16 @@ class DocumentService(CommonService):
info["progress_msg"] += "\n%d tasks are ahead in the queue..."%get_queue_length(priority)
else:
info["progress_msg"] = "%d tasks are ahead in the queue..."%get_queue_length(priority)
cls.update_by_id(d["id"], info)
info["update_time"] = current_timestamp()
info["update_date"] = get_format_time()
(
cls.model.update(info)
.where(
(cls.model.id == d["id"])
& ((cls.model.run.is_null(True)) | (cls.model.run != TaskStatus.CANCEL.value))
)
.execute()
)
except Exception as e:
if str(e).find("'0'") < 0:
logging.exception("fetch task exception")
@ -1012,7 +1024,7 @@ class DocumentService(CommonService):
@classmethod
@DB.connection_context()
def knowledgebase_basic_info(cls, kb_id: str) -> dict[str, int]:
# cancelled: run == "2" but progress can vary
# cancelled: run == "2"
cancelled = (
cls.model.select(fn.COUNT(1))
.where((cls.model.kb_id == kb_id) & (cls.model.run == TaskStatus.CANCEL))

View File

@ -397,7 +397,7 @@ class KnowledgebaseService(CommonService):
if dataset_name == "":
return False, get_data_error_result(message="Dataset name can't be empty.")
if len(dataset_name.encode("utf-8")) > DATASET_NAME_LIMIT:
return False, get_data_error_result(message=f"Dataset name length is {len(dataset_name)} which is larger than {DATASET_NAME_LIMIT}")
return False, get_data_error_result(message=f"Dataset name length is {len(dataset_name)} which is large than {DATASET_NAME_LIMIT}")
# Deduplicate name within tenant
dataset_name = duplicate_name(

View File

@ -31,6 +31,12 @@ from quart import (
jsonify,
request
)
from werkzeug.exceptions import BadRequest as WerkzeugBadRequest
try:
from quart.exceptions import BadRequest as QuartBadRequest
except ImportError: # pragma: no cover - optional dependency
QuartBadRequest = None
from peewee import OperationalError
@ -48,35 +54,33 @@ requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSON
async def _coerce_request_data() -> dict:
"""Fetch JSON body with sane defaults; fallback to form data."""
if hasattr(request, "_cached_payload"):
return request._cached_payload
payload: Any = None
last_error: Exception | None = None
try:
payload = await request.get_json(force=True, silent=True)
except Exception as e:
last_error = e
payload = None
body_bytes = await request.get_data()
has_body = bool(body_bytes)
content_type = (request.content_type or "").lower()
is_json = content_type.startswith("application/json")
if payload is None:
try:
form = await request.form
payload = form.to_dict()
except Exception as e:
last_error = e
payload = None
if not has_body:
payload = {}
elif is_json:
payload = await request.get_json(force=False, silent=False)
if isinstance(payload, dict):
payload = payload or {}
elif isinstance(payload, str):
raise AttributeError("'str' object has no attribute 'get'")
else:
raise TypeError("JSON payload must be an object.")
else:
form = await request.form
payload = form.to_dict() if form else None
if payload is None:
raise TypeError("Request body is not a valid form payload.")
if payload is None:
if last_error is not None:
raise last_error
raise ValueError("No JSON body or form data found in request.")
if isinstance(payload, dict):
return payload or {}
if isinstance(payload, str):
raise AttributeError("'str' object has no attribute 'get'")
raise TypeError(f"Unsupported request payload type: {type(payload)!r}")
request._cached_payload = payload
return payload
async def get_request_json():
return await _coerce_request_data()
@ -124,16 +128,12 @@ def server_error_response(e):
try:
msg = repr(e).lower()
if getattr(e, "code", None) == 401 or ("unauthorized" in msg) or ("401" in msg):
return get_json_result(code=RetCode.UNAUTHORIZED, message=repr(e))
resp = get_json_result(code=RetCode.UNAUTHORIZED, message="Unauthorized")
resp.status_code = RetCode.UNAUTHORIZED
return resp
except Exception as ex:
logging.warning(f"error checking authorization: {ex}")
if len(e.args) > 1:
try:
serialized_data = serialize_for_json(e.args[1])
return get_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=serialized_data)
except Exception:
return get_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=None)
if repr(e).find("index_not_found_exception") >= 0:
return get_json_result(code=RetCode.EXCEPTION_ERROR, message="No chunk found, please upload file and parse it.")
@ -168,7 +168,17 @@ def validate_request(*args, **kwargs):
def wrapper(func):
@wraps(func)
async def decorated_function(*_args, **_kwargs):
errs = process_args(await _coerce_request_data())
exception_types = (AttributeError, TypeError, WerkzeugBadRequest)
if QuartBadRequest is not None:
exception_types = exception_types + (QuartBadRequest,)
if args or kwargs:
try:
input_arguments = await _coerce_request_data()
except exception_types:
input_arguments = {}
else:
input_arguments = await _coerce_request_data()
errs = process_args(input_arguments)
if errs:
return get_json_result(code=RetCode.ARGUMENT_ERROR, message=errs)
if inspect.iscoroutinefunction(func):

View File

@ -318,6 +318,8 @@ class RAGFlow:
for data in res["data"]["memory_list"]:
result_list.append(Memory(self, data))
return {
"code": res.get("code", 0),
"message": res.get("message"),
"memory_list": result_list,
"total_count": res["data"]["total_count"]
}

View File

@ -99,7 +99,7 @@ def batch_create_datasets(auth, num):
# DOCUMENT APP
def upload_documents(auth, payload=None, files_path=None):
def upload_documents(auth, payload=None, files_path=None, *, filename_override=None):
url = f"{HOST_ADDRESS}{DOCUMENT_APP_URL}/upload"
if files_path is None:
@ -115,7 +115,8 @@ def upload_documents(auth, payload=None, files_path=None):
for fp in files_path:
p = Path(fp)
f = p.open("rb")
fields.append(("file", (p.name, f)))
filename = filename_override if filename_override is not None else p.name
fields.append(("file", (filename, f)))
file_objects.append(f)
m = MultipartEncoder(fields=fields)

View File

@ -14,7 +14,8 @@
# limitations under the License.
#
from time import sleep
from ragflow_sdk import RAGFlow
from configs import HOST_ADDRESS, VERSION
import pytest
from common import (
batch_add_chunks,
@ -81,7 +82,9 @@ def generate_test_files(request: FixtureRequest, tmp_path):
def ragflow_tmp_dir(request, tmp_path_factory):
class_name = request.cls.__name__
return tmp_path_factory.mktemp(class_name)
@pytest.fixture(scope="session")
def client(token: str) -> RAGFlow:
return RAGFlow(api_key=token, base_url=HOST_ADDRESS, version=VERSION)
@pytest.fixture(scope="session")
def WebApiAuth(auth):

View File

@ -265,11 +265,11 @@ class TestChunksRetrieval:
@pytest.mark.parametrize(
"payload, expected_code, expected_highlight, expected_message",
[
({"highlight": True}, 0, True, ""),
({"highlight": "True"}, 0, True, ""),
pytest.param({"highlight": False}, 0, False, "", marks=pytest.mark.skip(reason="issues/6648")),
pytest.param({"highlight": "False"}, 0, False, "", marks=pytest.mark.skip(reason="issues/6648")),
pytest.param({"highlight": None}, 0, False, "", marks=pytest.mark.skip(reason="issues/6648")),
pytest.param({"highlight": True}, 0, True, "", marks=pytest.mark.skip(reason="highlight not functionnal")),
pytest.param({"highlight": "True"}, 0, True, "", marks=pytest.mark.skip(reason="highlight not functionnal")),
({"highlight": False}, 0, False, ""),
({"highlight": "False"}, 0, False, ""),
({"highlight": None}, 0, False, "")
],
)
def test_highlight(self, WebApiAuth, add_chunks, payload, expected_code, expected_highlight, expected_message):

View File

@ -17,11 +17,9 @@ import string
from concurrent.futures import ThreadPoolExecutor, as_completed
import pytest
import requests
from common import DOCUMENT_APP_URL, list_kbs, upload_documents
from configs import DOCUMENT_NAME_LIMIT, HOST_ADDRESS, INVALID_API_TOKEN
from common import list_kbs, upload_documents
from configs import DOCUMENT_NAME_LIMIT, INVALID_API_TOKEN
from libs.auth import RAGFlowWebApiAuth
from requests_toolbelt import MultipartEncoder
from utils.file_utils import create_txt_file
@ -111,17 +109,9 @@ class TestDocumentsUpload:
kb_id = add_dataset_func
fp = create_txt_file(tmp_path / "ragflow_test.txt")
url = f"{HOST_ADDRESS}{DOCUMENT_APP_URL}/upload"
fields = [("file", ("", fp.open("rb"))), ("kb_id", kb_id)]
m = MultipartEncoder(fields=fields)
res = requests.post(
url=url,
headers={"Content-Type": m.content_type},
auth=WebApiAuth,
data=m,
)
assert res.json()["code"] == 101, res
assert res.json()["message"] == "No file selected!", res
res = upload_documents(WebApiAuth, {"kb_id": kb_id}, [fp], filename_override="")
assert res["code"] == 101, res
assert res["message"] == "No file selected!", res
@pytest.mark.p2
def test_filename_exceeds_max_length(self, WebApiAuth, add_dataset_func, tmp_path):