Feat: add new tests and tescases for restful api suite (#15038)

### What problem does this PR solve?

extend restful api suite

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
- [x] Other (please describe): test
This commit is contained in:
Idriss Sbaaoui
2026-05-20 14:56:55 +08:00
committed by GitHub
parent 2836a934b5
commit aea90f4e39
3 changed files with 1391 additions and 7 deletions

View File

@ -14,7 +14,12 @@
# limitations under the License.
#
from concurrent.futures import ThreadPoolExecutor
import os
import pytest
from test.testcases.configs import INVALID_API_TOKEN, INVALID_ID_32
from test.testcases.restful_api.helpers.client import RestClient
from test.testcases.utils import wait_for
def _assert_created_chunk_id(payload):
@ -25,6 +30,41 @@ def _assert_created_chunk_id(payload):
return chunk_id
@wait_for(10, 1, "Chunk indexing timeout in RESTful batch 09 tests")
def _chunk_count(rest_client, base_path, expected_total):
res = rest_client.get(base_path)
if res.status_code != 200:
return False
payload = res.json()
if payload["code"] != 0:
return False
return payload["data"]["total"] == expected_total and len(payload["data"]["chunks"]) == min(expected_total, 30)
def _reset_chunk_batch(rest_client, base_path, count=4):
cleanup_res = rest_client.delete(base_path, json={"chunk_ids": None, "delete_all": True})
assert cleanup_res.status_code == 200, cleanup_res.text
cleanup_payload = cleanup_res.json()
assert cleanup_payload["code"] == 0, cleanup_payload
baseline_res = rest_client.post(base_path, json={"content": "ragflow test upload"})
assert baseline_res.status_code == 200, baseline_res.text
baseline_payload = baseline_res.json()
assert baseline_payload["code"] == 0, baseline_payload
baseline_id = _assert_created_chunk_id(baseline_payload)
chunk_ids = []
for index in range(count):
res = rest_client.post(base_path, json={"content": f"chunk test {index}"})
assert res.status_code == 200, (index, res.text)
payload = res.json()
assert payload["code"] == 0, (index, payload)
chunk_ids.append(_assert_created_chunk_id(payload))
_chunk_count(rest_client, base_path, count + 1)
return baseline_id, chunk_ids
@pytest.mark.p1
def test_chunks_add_list_get_update_delete_cycle(rest_client, create_document):
dataset_id, document_id = create_document("chunk_cycle.txt")
@ -88,6 +128,42 @@ def test_chunks_add_list_get_update_delete_cycle(rest_client, create_document):
assert deleted_get_payload["code"] != 0, deleted_get_payload
@pytest.mark.p1
def test_chunk_add_requires_auth(create_document):
dataset_id, document_id = create_document("chunk_add_auth.txt")
path = f"/datasets/{dataset_id}/documents/{document_id}/chunks"
for scenario_name, client in (("missing token", RestClient(token=None)), ("invalid token", RestClient(token=INVALID_API_TOKEN))):
res = client.post(path, json={"content": "chunk test"})
assert res.status_code == 401, (scenario_name, res.text)
payload = res.json()
assert payload["code"] == 401, (scenario_name, payload)
assert payload["message"] == "<Unauthorized '401: Unauthorized'>", (scenario_name, payload)
@pytest.mark.p1
def test_chunk_delete_requires_auth(create_document):
dataset_id, document_id = create_document("chunk_delete_auth.txt")
path = f"/datasets/{dataset_id}/documents/{document_id}/chunks"
for scenario_name, client in (("missing token", RestClient(token=None)), ("invalid token", RestClient(token=INVALID_API_TOKEN))):
res = client.delete(path, json={"chunk_ids": []})
assert res.status_code == 401, (scenario_name, res.text)
payload = res.json()
assert payload["code"] == 401, (scenario_name, payload)
assert payload["message"] == "<Unauthorized '401: Unauthorized'>", (scenario_name, payload)
@pytest.mark.p1
def test_chunk_list_requires_auth(create_document):
dataset_id, document_id = create_document("chunk_list_auth.txt")
path = f"/datasets/{dataset_id}/documents/{document_id}/chunks"
for scenario_name, client in (("missing token", RestClient(token=None)), ("invalid token", RestClient(token=INVALID_API_TOKEN))):
res = client.get(path)
assert res.status_code == 401, (scenario_name, res.text)
payload = res.json()
assert payload["code"] == 401, (scenario_name, payload)
assert payload["message"] == "<Unauthorized '401: Unauthorized'>", (scenario_name, payload)
@pytest.mark.p2
def test_chunks_add_requires_content(rest_client, create_document):
dataset_id, document_id = create_document("chunk_requires_content.txt")
@ -101,6 +177,156 @@ def test_chunks_add_requires_content(rest_client, create_document):
assert payload["message"] == "`content` is required", payload
@pytest.mark.p2
def test_chunk_add_keyword_question_and_tag_contract(rest_client, create_document):
add_cases = [
(
"important keywords",
[
({"content": "chunk test", "important_keywords": ["a", "b", "c"]}, 0, ""),
({"content": "chunk test", "important_keywords": [""]}, 0, ""),
({"content": "chunk test", "important_keywords": [1]}, 100, "TypeError('sequence item 0: expected str instance, int found')"),
({"content": "chunk test", "important_keywords": ["a", "a"]}, 0, ""),
({"content": "chunk test", "important_keywords": "abc"}, 102, "`important_keywords` is required to be a list"),
({"content": "chunk test", "important_keywords": 123}, 102, "`important_keywords` is required to be a list"),
],
),
(
"questions",
[
({"content": "chunk test", "questions": ["a", "b", "c"]}, 0, ""),
({"content": "chunk test", "questions": [""]}, 0, ""),
({"content": "chunk test", "questions": [1]}, 100, "TypeError('sequence item 0: expected str instance, int found')"),
({"content": "chunk test", "questions": ["a", "a"]}, 0, ""),
({"content": "chunk test", "questions": "abc"}, 102, "`questions` is required to be a list"),
({"content": "chunk test", "questions": 123}, 102, "`questions` is required to be a list"),
],
),
(
"tag_kwd",
[
({"content": "chunk test", "tag_kwd": ["tag1", "tag2"]}, 0, ""),
({"content": "chunk test", "tag_kwd": [""]}, 0, ""),
({"content": "chunk test", "tag_kwd": [1]}, 102, "`tag_kwd` must be a list of strings"),
({"content": "chunk test", "tag_kwd": ["tag", "tag"]}, 0, ""),
({"content": "chunk test", "tag_kwd": "abc"}, 102, "`tag_kwd` is required to be a list"),
({"content": "chunk test", "tag_kwd": 123}, 102, "`tag_kwd` is required to be a list"),
],
),
]
for group_index, (group_name, cases) in enumerate(add_cases):
dataset_id, document_id = create_document(f"chunk_add_contracts_{group_index}.txt")
base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks"
for scenario_index, (payload, expected_code, expected_message) in enumerate(cases):
scenario_name = f"{group_name}-{scenario_index}"
before_payload = rest_client.get(base_path).json()
assert before_payload["code"] == 0, (scenario_name, before_payload)
before_total = before_payload["data"]["doc"]["chunk_count"]
res = rest_client.post(base_path, json=payload)
assert res.status_code == 200, (scenario_name, res.text)
body = res.json()
assert body["code"] == expected_code, (scenario_name, body)
if expected_code == 0:
chunk = body["data"]["chunk"]
assert chunk["dataset_id"] == dataset_id, (scenario_name, body)
assert chunk["document_id"] == document_id, (scenario_name, body)
assert chunk["content"] == payload["content"], (scenario_name, body)
if "important_keywords" in payload:
assert chunk["important_keywords"] == payload["important_keywords"], (scenario_name, body)
if "questions" in payload:
assert chunk["questions"] == [str(q).strip() for q in payload["questions"] if str(q).strip()], (scenario_name, body)
if "tag_kwd" in payload:
assert chunk["tag_kwd"] == payload["tag_kwd"], (scenario_name, body)
after_payload = rest_client.get(base_path).json()
assert after_payload["code"] == 0, (scenario_name, after_payload)
assert after_payload["data"]["doc"]["chunk_count"] == before_total + 1, (scenario_name, after_payload)
else:
assert body["message"] == expected_message, (scenario_name, body)
@pytest.mark.p2
def test_chunk_add_invalid_dataset_and_document_contract(rest_client, create_document):
dataset_id, document_id = create_document("chunk_invalid_targets.txt")
invalid_dataset_res = rest_client.post(
f"/datasets/{INVALID_ID_32}/documents/{document_id}/chunks",
json={"content": "chunk test"},
)
assert invalid_dataset_res.status_code == 200
invalid_dataset_payload = invalid_dataset_res.json()
assert invalid_dataset_payload["code"] == 102, invalid_dataset_payload
assert invalid_dataset_payload["message"] == f"You don't own the dataset {INVALID_ID_32}.", invalid_dataset_payload
invalid_document_res = rest_client.post(
f"/datasets/{dataset_id}/documents/{INVALID_ID_32}/chunks",
json={"content": "chunk test"},
)
assert invalid_document_res.status_code == 200
invalid_document_payload = invalid_document_res.json()
assert invalid_document_payload["code"] == 102, invalid_document_payload
assert invalid_document_payload["message"] == f"You don't own the document {INVALID_ID_32}.", invalid_document_payload
@pytest.mark.p2
def test_chunk_add_repeated_and_deleted_document_contract(rest_client, create_document):
dataset_id, document_id = create_document("chunk_repeat_deleted.txt")
base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks"
first_payload = rest_client.get(base_path).json()
assert first_payload["code"] == 0, first_payload
initial_count = first_payload["data"]["doc"]["chunk_count"]
first_add_res = rest_client.post(base_path, json={"content": "chunk test"})
second_add_res = rest_client.post(base_path, json={"content": "chunk test"})
first_add_payload = first_add_res.json()
second_add_payload = second_add_res.json()
assert first_add_payload["code"] == 0, first_add_payload
assert second_add_payload["code"] == 0, second_add_payload
assert first_add_payload["data"]["chunk"]["id"] == second_add_payload["data"]["chunk"]["id"], (first_add_payload, second_add_payload)
repeated_list_payload = rest_client.get(base_path).json()
assert repeated_list_payload["code"] == 0, repeated_list_payload
assert repeated_list_payload["data"]["doc"]["chunk_count"] == initial_count + 2, repeated_list_payload
assert repeated_list_payload["data"]["total"] == 1, repeated_list_payload
delete_document_res = rest_client.delete(f"/datasets/{dataset_id}/documents", json={"ids": [document_id]})
assert delete_document_res.status_code == 200
delete_document_payload = delete_document_res.json()
assert delete_document_payload["code"] == 0, delete_document_payload
add_after_delete_res = rest_client.post(base_path, json={"content": "chunk test"})
assert add_after_delete_res.status_code == 200
add_after_delete_payload = add_after_delete_res.json()
assert add_after_delete_payload["code"] == 102, add_after_delete_payload
assert add_after_delete_payload["message"] == f"You don't own the document {document_id}.", add_after_delete_payload
@pytest.mark.p2
@pytest.mark.parametrize("count", [20])
def test_chunk_concurrent_add_contract(rest_client, create_document, count):
dataset_id, document_id = create_document("chunk_concurrent_add.txt")
base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks"
baseline_payload = rest_client.get(base_path).json()
assert baseline_payload["code"] == 0, baseline_payload
initial_count = baseline_payload["data"]["doc"]["chunk_count"]
with ThreadPoolExecutor(max_workers=5) as executor:
results = list(
executor.map(
lambda index: rest_client.post(base_path, json={"content": f"chunk test {index}"}).json(),
range(count),
)
)
assert len(results) == count, results
assert all(result["code"] == 0 for result in results), results
final_payload = rest_client.get(base_path).json()
assert final_payload["code"] == 0, final_payload
assert final_payload["data"]["doc"]["chunk_count"] == initial_count + count, final_payload
@pytest.mark.p2
def test_chunks_list_empty_document(rest_client, create_document):
dataset_id, document_id = create_document("chunk_list_empty.txt")
@ -114,11 +340,478 @@ def test_chunks_list_empty_document(rest_client, create_document):
@pytest.mark.p2
def test_chunks_delete_partial_invalid(rest_client, create_document):
dataset_id, document_id = create_document("chunk_delete_partial.txt")
def test_chunk_delete_basic_contract(rest_client, create_document):
dataset_id, document_id = create_document("chunk_delete_basic.txt")
base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks"
delete_res = rest_client.delete(base_path, json={"chunk_ids": ["invalid_chunk_id"]})
assert delete_res.status_code == 200
delete_payload = delete_res.json()
assert delete_payload["code"] == 102, delete_payload
assert "expect 1" in delete_payload["message"], delete_payload
cases = [
("none payload", None, 0, "", 5),
("invalid only", {"chunk_ids": ["invalid_id"]}, 102, "rm_chunk deleted chunks 0, expect 1", 5),
("delete first", lambda ids: {"chunk_ids": ids[:1]}, 0, "", 4),
("delete generated", lambda ids: {"chunk_ids": ids}, 0, "", 1),
("empty ids", {"chunk_ids": []}, 0, "", 5),
]
for scenario_name, payload, expected_code, expected_message, remaining in cases:
_reset_chunk_batch(rest_client, base_path)
request_body = payload
generated_ids = rest_client.get(base_path).json()["data"]["chunks"][1:]
generated_ids = [chunk["id"] for chunk in generated_ids]
if callable(payload):
request_body = payload(generated_ids)
res = rest_client.delete(base_path, json=request_body)
assert res.status_code == 200, (scenario_name, res.text)
body = res.json()
assert body["code"] == expected_code, (scenario_name, body)
if expected_message:
assert body.get("message", "") == expected_message, (scenario_name, body)
list_payload = rest_client.get(base_path).json()
assert list_payload["code"] == 0, (scenario_name, list_payload)
assert len(list_payload["data"]["chunks"]) == remaining, (scenario_name, list_payload)
assert list_payload["data"]["total"] == remaining, (scenario_name, list_payload)
@pytest.mark.p2
def test_chunk_delete_partial_duplicate_repeat_and_invalid_target_contract(rest_client, create_document):
dataset_id, document_id = create_document("chunk_delete_detail.txt")
base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks"
for scenario_name, payload_builder in (
("invalid first", lambda ids: {"chunk_ids": ["invalid_id"] + ids}),
("invalid middle", lambda ids: {"chunk_ids": ids[:1] + ["invalid_id"] + ids[1:]}),
("invalid last", lambda ids: {"chunk_ids": ids + ["invalid_id"]}),
):
_, generated_ids = _reset_chunk_batch(rest_client, base_path)
res = rest_client.delete(base_path, json=payload_builder(generated_ids))
assert res.status_code == 200, (scenario_name, res.text)
body = res.json()
assert body["code"] == 102, (scenario_name, body)
assert body["message"] == "rm_chunk deleted chunks 4, expect 5", (scenario_name, body)
list_payload = rest_client.get(base_path).json()
assert list_payload["code"] == 0, (scenario_name, list_payload)
assert list_payload["data"]["total"] == 1, (scenario_name, list_payload)
_, generated_ids = _reset_chunk_batch(rest_client, base_path)
duplicate_res = rest_client.delete(base_path, json={"chunk_ids": generated_ids * 2})
assert duplicate_res.status_code == 200
duplicate_payload = duplicate_res.json()
assert duplicate_payload["code"] == 0, duplicate_payload
assert duplicate_payload["data"]["success_count"] == 4, duplicate_payload
assert len(duplicate_payload["data"]["errors"]) == 4, duplicate_payload
assert all(error.startswith("Duplicate chunk ids: ") for error in duplicate_payload["data"]["errors"]), duplicate_payload
duplicate_list_payload = rest_client.get(base_path).json()
assert duplicate_list_payload["code"] == 0, duplicate_list_payload
assert duplicate_list_payload["data"]["total"] == 1, duplicate_list_payload
_, generated_ids = _reset_chunk_batch(rest_client, base_path)
first_delete_res = rest_client.delete(base_path, json={"chunk_ids": generated_ids})
assert first_delete_res.status_code == 200
assert first_delete_res.json()["code"] == 0
second_delete_res = rest_client.delete(base_path, json={"chunk_ids": generated_ids})
assert second_delete_res.status_code == 200
second_delete_payload = second_delete_res.json()
assert second_delete_payload["code"] == 102, second_delete_payload
assert second_delete_payload["message"] == "rm_chunk deleted chunks 0, expect 4", second_delete_payload
invalid_dataset_res = rest_client.delete(
f"/datasets/{INVALID_ID_32}/documents/{document_id}/chunks",
json={"chunk_ids": ["chunk-id"]},
)
assert invalid_dataset_res.status_code == 200
invalid_dataset_payload = invalid_dataset_res.json()
assert invalid_dataset_payload["code"] == 102, invalid_dataset_payload
assert invalid_dataset_payload["message"] == f"You don't own the dataset {INVALID_ID_32}.", invalid_dataset_payload
invalid_document_res = rest_client.delete(
f"/datasets/{dataset_id}/documents/{INVALID_ID_32}/chunks",
json={"chunk_ids": ["chunk-id"]},
)
assert invalid_document_res.status_code == 200
invalid_document_payload = invalid_document_res.json()
assert invalid_document_payload["code"] == 102, invalid_document_payload
assert invalid_document_payload["message"] == f"You don't own the document {INVALID_ID_32}.", invalid_document_payload
@pytest.mark.p2
def test_chunk_delete_web_legacy_basic_variants(rest_client, create_document):
dataset_id, document_id = create_document("chunk_delete_web_legacy_again.txt")
base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks"
cases = [
("web invalid id", {"chunk_ids": ["invalid_id"]}, 102, 5),
("web delete first", lambda ids: {"chunk_ids": ids[:1]}, 0, 4),
("web delete generated", lambda ids: {"chunk_ids": ids}, 0, 1),
("web empty ids", {"chunk_ids": []}, 0, 5),
]
for scenario_name, payload, expected_code, remaining in cases:
_, generated_ids = _reset_chunk_batch(rest_client, base_path)
request_body = payload(generated_ids) if callable(payload) else payload
res = rest_client.delete(base_path, json=request_body)
assert res.status_code == 200, (scenario_name, res.text)
body = res.json()
assert body["code"] == expected_code, (scenario_name, body)
list_payload = rest_client.get(base_path).json()
assert list_payload["code"] == 0, (scenario_name, list_payload)
assert list_payload["data"]["total"] == remaining, (scenario_name, list_payload)
@pytest.mark.p2
def test_chunk_delete_concurrent_and_bulk_contract(rest_client, create_document):
dataset_id, document_id = create_document("chunk_delete_bulk_contract.txt")
base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks"
rest_client.delete(base_path, json={"chunk_ids": None, "delete_all": True})
for index in range(12):
payload = rest_client.post(base_path, json={"content": f"chunk test {index}"}).json()
assert payload["code"] == 0, payload
ids_payload = rest_client.get(base_path).json()
assert ids_payload["code"] == 0, ids_payload
chunk_ids = [chunk["id"] for chunk in ids_payload["data"]["chunks"]]
with ThreadPoolExecutor(max_workers=5) as executor:
results = list(executor.map(lambda chunk_id: rest_client.delete(base_path, json={"chunk_ids": [chunk_id]}).json(), chunk_ids))
assert len(results) == len(chunk_ids), results
assert all(result["code"] == 0 for result in results), results
final_payload = rest_client.get(base_path).json()
assert final_payload["code"] == 0, final_payload
assert final_payload["data"]["total"] == 0, final_payload
rest_client.delete(base_path, json={"chunk_ids": None, "delete_all": True})
for index in range(40):
payload = rest_client.post(base_path, json={"content": f"bulk chunk {index}"}).json()
assert payload["code"] == 0, payload
bulk_ids_payload = rest_client.get(base_path, params={"page_size": 200}).json()
assert bulk_ids_payload["code"] == 0, bulk_ids_payload
bulk_ids = [chunk["id"] for chunk in bulk_ids_payload["data"]["chunks"]]
bulk_res = rest_client.delete(base_path, json={"chunk_ids": bulk_ids})
assert bulk_res.status_code == 200
bulk_payload = bulk_res.json()
assert bulk_payload["code"] == 0, bulk_payload
@pytest.mark.p2
def test_chunk_list_default_get_id_and_invalid_target_contract(rest_client, create_document):
dataset_id, document_id = create_document("chunk_list_core.txt")
base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks"
baseline_id, generated_ids = _reset_chunk_batch(rest_client, base_path)
default_res = rest_client.get(base_path)
assert default_res.status_code == 200
default_payload = default_res.json()
assert default_payload["code"] == 0, default_payload
assert default_payload["data"]["total"] == 5, default_payload
assert len(default_payload["data"]["chunks"]) == 5, default_payload
get_res = rest_client.get(f"{base_path}/{generated_ids[0]}")
assert get_res.status_code == 200
get_payload = get_res.json()
assert get_payload["code"] == 0, get_payload
assert get_payload["data"]["id"] == generated_ids[0], get_payload
assert get_payload["data"]["doc_id"] == document_id, get_payload
invalid_get_res = rest_client.get(f"{base_path}/unknown")
assert invalid_get_res.status_code == 200
invalid_get_payload = invalid_get_res.json()
assert invalid_get_payload["code"] == 102, invalid_get_payload
assert invalid_get_payload["message"] == "Chunk not found!", invalid_get_payload
id_cases = [
("id none", {"id": None}, 0, 5, None),
("id empty", {"id": ""}, 0, 5, None),
("id valid", {"id": generated_ids[0]}, 0, 1, generated_ids[0]),
("id invalid", {"id": "unknown"}, 102, None, None),
]
for scenario_name, params, expected_code, expected_total, expected_id in id_cases:
res = rest_client.get(base_path, params=params)
assert res.status_code == 200, (scenario_name, res.text)
payload = res.json()
assert payload["code"] == expected_code, (scenario_name, payload)
if expected_code == 0:
assert payload["data"]["total"] == expected_total, (scenario_name, payload)
if expected_id is not None:
assert payload["data"]["chunks"][0]["id"] == expected_id, (scenario_name, payload)
else:
assert payload["message"] == f"Chunk not found: {dataset_id}/unknown", (scenario_name, payload)
invalid_dataset_res = rest_client.get(f"/datasets/{INVALID_ID_32}/documents/{document_id}/chunks")
assert invalid_dataset_res.status_code == 200
invalid_dataset_payload = invalid_dataset_res.json()
assert invalid_dataset_payload["code"] == 102, invalid_dataset_payload
assert invalid_dataset_payload["message"] == f"You don't own the dataset {INVALID_ID_32}.", invalid_dataset_payload
invalid_document_res = rest_client.get(f"/datasets/{dataset_id}/documents/{INVALID_ID_32}/chunks")
assert invalid_document_res.status_code == 200
invalid_document_payload = invalid_document_res.json()
assert invalid_document_payload["code"] == 102, invalid_document_payload
assert invalid_document_payload["message"] == f"You don't own the document {INVALID_ID_32}.", invalid_document_payload
@pytest.mark.p2
@pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="infinity")
def test_chunk_list_keyword_and_invalid_param_contract(rest_client, create_document):
dataset_id, document_id = create_document("chunk_list_keywords.txt")
base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks"
_reset_chunk_batch(rest_client, base_path)
cases = [
("keywords none", {"keywords": None}, 5),
("keywords empty", {"keywords": ""}, 5),
("keywords exact one", {"keywords": "1"}, 1),
("keywords chunk", {"keywords": "chunk"}, 4),
("keywords ragflow", {"keywords": "ragflow"}, 1),
("keywords unknown", {"keywords": "unknown"}, 0),
("invalid params ignored", {"a": "b"}, 5),
]
for scenario_name, params, expected_total in cases:
res = rest_client.get(base_path, params=params)
assert res.status_code == 200, (scenario_name, res.text)
payload = res.json()
assert payload["code"] == 0, (scenario_name, payload)
assert payload["data"]["total"] == expected_total, (scenario_name, payload)
assert len(payload["data"]["chunks"]) == expected_total, (scenario_name, payload)
@pytest.mark.p2
@pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="infinity")
def test_chunk_list_page_and_page_size_contract(rest_client, create_document):
dataset_id, document_id = create_document("chunk_list_paging.txt")
base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks"
_reset_chunk_batch(rest_client, base_path)
cases = [
("page none", {"page": None, "page_size": 2}, 0, 2, ""),
("page zero", {"page": 0, "page_size": 2}, 100, None, "ValueError('Search does not support negative slicing.')"),
("page two", {"page": 2, "page_size": 2}, 0, 2, ""),
("page three", {"page": 3, "page_size": 2}, 0, 1, ""),
("page string", {"page": "3", "page_size": 2}, 0, 1, ""),
("page negative", {"page": -1, "page_size": 2}, 100, None, "ValueError('Search does not support negative slicing.')"),
("page alpha", {"page": "a", "page_size": 2}, 100, None, "ValueError(\"invalid literal for int() with base 10: 'a'\")"),
("page_size none", {"page_size": None}, 0, 5, ""),
("page_size zero", {"page_size": 0}, 0, 5, ""),
("page_size one", {"page_size": 1}, 0, 1, ""),
("page_size six", {"page_size": 6}, 0, 5, ""),
("page_size string", {"page_size": "1"}, 0, 1, ""),
("page_size negative", {"page_size": -1}, 0, 5, ""),
("page_size alpha", {"page_size": "a"}, 100, None, "ValueError(\"invalid literal for int() with base 10: 'a'\")"),
]
for scenario_name, params, expected_code, expected_total, expected_message in cases:
res = rest_client.get(base_path, params=params)
assert res.status_code == 200, (scenario_name, res.text)
payload = res.json()
assert payload["code"] == expected_code, (scenario_name, payload)
if expected_code == 0:
assert payload["data"]["total"] == 5, (scenario_name, payload)
assert len(payload["data"]["chunks"]) == expected_total, (scenario_name, payload)
else:
assert expected_message in payload["message"], (scenario_name, payload)
@pytest.mark.p2
def test_chunk_list_concurrent_contract(rest_client, create_document):
dataset_id, document_id = create_document("chunk_list_concurrent.txt")
base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks"
_reset_chunk_batch(rest_client, base_path)
with ThreadPoolExecutor(max_workers=5) as executor:
results = list(executor.map(lambda _: rest_client.get(base_path).json(), range(20)))
assert len(results) == 20, results
assert all(result["code"] == 0 for result in results), results
assert all(result["data"]["total"] == 5 for result in results), results
def _create_chunk_for_update(rest_client, create_document, file_name):
dataset_id, document_id = create_document(file_name)
base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks"
add_res = rest_client.post(base_path, json={"content": "chunk update test"})
assert add_res.status_code == 200, add_res.text
add_payload = add_res.json()
assert add_payload["code"] == 0, add_payload
chunk_id = add_payload["data"]["chunk"]["id"]
return dataset_id, document_id, chunk_id, base_path
@pytest.mark.p2
def test_chunk_update_requires_auth(rest_client, create_document):
_, _, chunk_id, base_path = _create_chunk_for_update(rest_client, create_document, "chunk_update_auth.txt")
for scenario_name, client in (("missing token", RestClient(token=None)), ("invalid token", RestClient(token=INVALID_API_TOKEN))):
res = client.patch(f"{base_path}/{chunk_id}", json={"content": "updated"})
assert res.status_code == 401, (scenario_name, res.text)
payload = res.json()
assert payload["code"] == 401, (scenario_name, payload)
assert payload["message"] == "<Unauthorized '401: Unauthorized'>", (scenario_name, payload)
@pytest.mark.p2
def test_chunk_update_content_and_available_contract(rest_client, create_document):
content_cases = [
("content none", {"content": None}, 0, ""),
("content empty", {"content": ""}, 102, "`content` is required"),
("content text", {"content": "update chunk"}, 0, ""),
("content spaces", {"content": " "}, 102, "`content` is required"),
("content punctuation", {"content": "\n!?。;!?\"'"}, 0, ""),
]
for scenario_name, payload, expected_code, expected_message in content_cases:
_, _, chunk_id, base_path = _create_chunk_for_update(rest_client, create_document, f"{scenario_name}.txt")
res = rest_client.patch(f"{base_path}/{chunk_id}", json=payload)
assert res.status_code == 200, (scenario_name, res.text)
body = res.json()
assert body["code"] == expected_code, (scenario_name, body)
if expected_code != 0:
assert body["message"] == expected_message, (scenario_name, body)
available_cases = [
("available true", {"available": True}, 0, ""),
("available true str", {"available": "True"}, 100, "invalid literal for int()"),
("available one", {"available": 1}, 0, ""),
("available false", {"available": False}, 0, ""),
("available false str", {"available": "False"}, 100, "invalid literal for int()"),
("available zero", {"available": 0}, 0, ""),
]
for scenario_name, payload, expected_code, expected_message in available_cases:
_, _, chunk_id, base_path = _create_chunk_for_update(rest_client, create_document, f"{scenario_name}.txt")
res = rest_client.patch(f"{base_path}/{chunk_id}", json=payload)
assert res.status_code == 200, (scenario_name, res.text)
body = res.json()
assert body["code"] == expected_code, (scenario_name, body)
if expected_code != 0:
assert expected_message in body["message"], (scenario_name, body)
@pytest.mark.p2
def test_chunk_update_keywords_questions_and_tag_contract(rest_client, create_document):
_, _, chunk_id, base_path = _create_chunk_for_update(rest_client, create_document, "chunk_update_fields.txt")
cases = [
("important keywords", {"important_keywords": ["a", "b", "c"]}, 0, ""),
("important keywords empty", {"important_keywords": [""]}, 0, ""),
("important keywords int", {"important_keywords": [1]}, 100, "TypeError"),
("important keywords dup", {"important_keywords": ["a", "a"]}, 0, ""),
("important keywords str", {"important_keywords": "abc"}, 102, "`important_keywords` should be a list"),
("important keywords number", {"important_keywords": 123}, 102, "`important_keywords` should be a list"),
("questions", {"questions": ["a", "b", "c"]}, 0, ""),
("questions empty", {"questions": [""]}, 0, ""),
("questions int", {"questions": [1]}, 100, "TypeError"),
("questions dup", {"questions": ["a", "a"]}, 0, ""),
("questions str", {"questions": "abc"}, 102, "`questions` should be a list"),
("questions number", {"questions": 123}, 102, "`questions` should be a list"),
("tag kwd", {"tag_kwd": ["tag1", "tag2"]}, 0, ""),
("tag kwd empty", {"tag_kwd": [""]}, 0, ""),
("tag kwd int in list", {"tag_kwd": [1]}, 102, "`tag_kwd` must be a list of strings"),
("tag kwd dup", {"tag_kwd": ["tag", "tag"]}, 0, ""),
("tag kwd str", {"tag_kwd": "tag"}, 102, "`tag_kwd` should be a list"),
("tag kwd number", {"tag_kwd": 123}, 102, "`tag_kwd` should be a list"),
]
for scenario_name, payload, expected_code, expected_message in cases:
res = rest_client.patch(f"{base_path}/{chunk_id}", json=payload)
assert res.status_code == 200, (scenario_name, res.text)
body = res.json()
assert body["code"] == expected_code, (scenario_name, body)
if expected_code != 0:
assert expected_message in body["message"], (scenario_name, body)
@pytest.mark.p2
def test_chunk_update_invalid_target_and_param_contract(rest_client, create_document):
dataset_id, document_id, chunk_id, base_path = _create_chunk_for_update(rest_client, create_document, "chunk_update_invalid_targets.txt")
invalid_dataset_res = rest_client.patch(
f"/datasets/{INVALID_ID_32}/documents/{document_id}/chunks/{chunk_id}",
json={"content": "updated"},
)
assert invalid_dataset_res.status_code == 200
invalid_dataset_payload = invalid_dataset_res.json()
assert invalid_dataset_payload["code"] == 102, invalid_dataset_payload
assert invalid_dataset_payload["message"] in {
f"You don't own the dataset {INVALID_ID_32}.",
f"Can't find this chunk {chunk_id}",
}, invalid_dataset_payload
invalid_document_res = rest_client.patch(
f"/datasets/{dataset_id}/documents/{INVALID_ID_32}/chunks/{chunk_id}",
json={"content": "updated"},
)
assert invalid_document_res.status_code == 200
invalid_document_payload = invalid_document_res.json()
assert invalid_document_payload["code"] == 102, invalid_document_payload
assert invalid_document_payload["message"] == f"You don't own the document {INVALID_ID_32}.", invalid_document_payload
invalid_chunk_res = rest_client.patch(
f"{base_path}/{INVALID_ID_32}",
json={"content": "updated"},
)
assert invalid_chunk_res.status_code == 200
invalid_chunk_payload = invalid_chunk_res.json()
assert invalid_chunk_payload["code"] == 102, invalid_chunk_payload
assert invalid_chunk_payload["message"] == f"Can't find this chunk {INVALID_ID_32}", invalid_chunk_payload
for scenario_name, payload in (
("unknown key", {"unknown_key": "unknown_value"}),
("empty payload", {}),
):
res = rest_client.patch(f"{base_path}/{chunk_id}", json=payload)
assert res.status_code == 200, (scenario_name, res.text)
body = res.json()
assert body["code"] == 0, (scenario_name, body)
@pytest.mark.p2
def test_chunk_update_repeated_concurrent_and_deleted_document_contract(rest_client, create_document):
dataset_id, document_id, chunk_id, base_path = _create_chunk_for_update(
rest_client, create_document, "chunk_update_repeated_concurrent_deleted.txt"
)
first_res = rest_client.patch(f"{base_path}/{chunk_id}", json={"content": "chunk test 1"})
assert first_res.status_code == 200
assert first_res.json()["code"] == 0
second_res = rest_client.patch(f"{base_path}/{chunk_id}", json={"content": "chunk test 2"})
assert second_res.status_code == 200
assert second_res.json()["code"] == 0
get_after_repeat = rest_client.get(f"{base_path}/{chunk_id}")
assert get_after_repeat.status_code == 200
get_after_repeat_payload = get_after_repeat.json()
assert get_after_repeat_payload["code"] == 0, get_after_repeat_payload
assert get_after_repeat_payload["data"]["content_with_weight"] == "chunk test 2", get_after_repeat_payload
chunk_ids = [chunk_id]
for index in range(3):
add_res = rest_client.post(base_path, json={"content": f"concurrent update {index}"})
assert add_res.status_code == 200, add_res.text
add_payload = add_res.json()
assert add_payload["code"] == 0, add_payload
chunk_ids.append(add_payload["data"]["chunk"]["id"])
with ThreadPoolExecutor(max_workers=5) as executor:
futures = []
for index in range(20):
target_id = chunk_ids[index % len(chunk_ids)]
futures.append(
executor.submit(
lambda cid, i: rest_client.patch(
f"{base_path}/{cid}",
json={"content": f"update chunk test {i}"},
).json(),
target_id,
index,
)
)
results = [future.result() for future in futures]
assert len(results) == 20, results
assert all(item["code"] == 0 for item in results), results
delete_document_res = rest_client.delete(f"/datasets/{dataset_id}/documents", json={"ids": [document_id]})
assert delete_document_res.status_code == 200
assert delete_document_res.json()["code"] == 0
update_after_delete = rest_client.patch(f"{base_path}/{chunk_id}", json={"content": "after delete"})
assert update_after_delete.status_code == 200
update_after_delete_payload = update_after_delete.json()
assert update_after_delete_payload["code"] == 102, update_after_delete_payload
assert update_after_delete_payload["message"] in {
f"You don't own the document {document_id}.",
f"Can't find this chunk {chunk_id}",
}, update_after_delete_payload

View File

@ -0,0 +1,425 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import importlib.util
import inspect
import sys
from copy import deepcopy
from pathlib import Path
from types import ModuleType, SimpleNamespace
import pytest
class _DummyManager:
def route(self, *_args, **_kwargs):
def decorator(func):
return func
return decorator
class _AwaitableValue:
def __init__(self, value):
self._value = value
def __await__(self):
async def _co():
return self._value
return _co().__await__()
class _DummyKB:
def __init__(self, tenant_id="tenant-1", embd_id="embd-1", tenant_embd_id=1):
self.tenant_id = tenant_id
self.embd_id = embd_id
self.tenant_embd_id = tenant_embd_id
class _DummyRetriever:
async def retrieval(self, *_args, **_kwargs):
return {
"chunks": [
{"doc_id": "doc-1", "content_with_weight": "chunk-content", "similarity": 0.8, "docnm_kwd": "doc-title", "vector": [0.1]}
]
}
def retrieval_by_children(self, chunks, _tenant_ids):
return chunks
def _run(coro):
return asyncio.run(coro)
def _load_dify_retrieval_module(monkeypatch):
repo_root = Path(__file__).resolve().parents[3]
common_pkg = ModuleType("common")
common_pkg.__path__ = [str(repo_root / "common")]
monkeypatch.setitem(sys.modules, "common", common_pkg)
deepdoc_pkg = ModuleType("deepdoc")
deepdoc_parser_pkg = ModuleType("deepdoc.parser")
deepdoc_parser_pkg.__path__ = []
class _StubPdfParser:
pass
class _StubExcelParser:
pass
class _StubDocxParser:
pass
deepdoc_parser_pkg.PdfParser = _StubPdfParser
deepdoc_parser_pkg.ExcelParser = _StubExcelParser
deepdoc_parser_pkg.DocxParser = _StubDocxParser
deepdoc_pkg.parser = deepdoc_parser_pkg
monkeypatch.setitem(sys.modules, "deepdoc", deepdoc_pkg)
monkeypatch.setitem(sys.modules, "deepdoc.parser", deepdoc_parser_pkg)
deepdoc_excel_module = ModuleType("deepdoc.parser.excel_parser")
deepdoc_excel_module.RAGFlowExcelParser = _StubExcelParser
monkeypatch.setitem(sys.modules, "deepdoc.parser.excel_parser", deepdoc_excel_module)
deepdoc_parser_utils = ModuleType("deepdoc.parser.utils")
deepdoc_parser_utils.get_text = lambda *_args, **_kwargs: ""
monkeypatch.setitem(sys.modules, "deepdoc.parser.utils", deepdoc_parser_utils)
monkeypatch.setitem(sys.modules, "xgboost", ModuleType("xgboost"))
tenant_llm_service_mod = ModuleType("api.db.services.tenant_llm_service")
class _MockModelConfig:
def __init__(self, tenant_id, model_name):
self.tenant_id = tenant_id
self.llm_name = model_name
self.llm_factory = "Builtin"
self.api_key = "fake-api-key"
self.api_base = "https://api.example.com"
self.model_type = "chat"
self.max_tokens = 8192
self.used_tokens = 0
self.status = 1
self.id = 1
def to_dict(self):
return {
"tenant_id": self.tenant_id,
"llm_name": self.llm_name,
"llm_factory": self.llm_factory,
"api_key": self.api_key,
"api_base": self.api_base,
"model_type": self.model_type,
"max_tokens": self.max_tokens,
"used_tokens": self.used_tokens,
"status": self.status,
"id": self.id,
}
class _StubTenantService:
@staticmethod
def get_by_id(tenant_id):
return True, SimpleNamespace(
id=tenant_id,
llm_id="chat-model",
embd_id="embd-model",
asr_id="asr-model",
img2txt_id="img2txt-model",
rerank_id="rerank-model",
tts_id="tts-model",
)
class _StubTenantLLMService:
@staticmethod
def get_api_key(tenant_id, model_name):
return _MockModelConfig(tenant_id, model_name)
@staticmethod
def split_model_name_and_factory(model_name):
if "@" in model_name:
parts = model_name.split("@")
return parts[0], parts[1]
return model_name, None
tenant_llm_service_mod.TenantService = _StubTenantService
tenant_llm_service_mod.TenantLLMService = _StubTenantLLMService
class _StubLLMFactoriesService:
pass
tenant_llm_service_mod.LLMFactoriesService = _StubLLMFactoriesService
monkeypatch.setitem(sys.modules, "api.db.services.tenant_llm_service", tenant_llm_service_mod)
llm_service_mod = ModuleType("api.db.services.llm_service")
class _StubLLM:
def __init__(self, llm_name):
self.llm_name = llm_name
self.is_tools = False
class _StubLLMBundle:
def __init__(self, tenant_id: str, model_config: dict, lang="Chinese", **kwargs):
self.tenant_id = tenant_id
self.model_config = model_config
self.lang = lang
def encode(self, texts: list):
import numpy as np
return [np.array([0.1, 0.2, 0.3]) for _ in texts], len(texts) * 10
llm_service_mod.LLMService = SimpleNamespace(query=lambda llm_name: [_StubLLM(llm_name)] if llm_name else [])
llm_service_mod.LLMBundle = _StubLLMBundle
monkeypatch.setitem(sys.modules, "api.db.services.llm_service", llm_service_mod)
tenant_model_service_mod = ModuleType("api.db.joint_services.tenant_model_service")
class _MockModelConfig2:
def __init__(self, tenant_id, model_name):
self.tenant_id = tenant_id
self.llm_name = model_name
self.llm_factory = "Builtin"
self.api_key = "fake-api-key"
self.api_base = "https://api.example.com"
self.model_type = "chat"
self.max_tokens = 8192
self.used_tokens = 0
self.status = 1
self.id = 1
def to_dict(self):
return {
"tenant_id": self.tenant_id,
"llm_name": self.llm_name,
"llm_factory": self.llm_factory,
"api_key": self.api_key,
"api_base": self.api_base,
"model_type": self.model_type,
"max_tokens": self.max_tokens,
"used_tokens": self.used_tokens,
"status": self.status,
"id": self.id,
}
def _get_model_config_by_id(tenant_model_id: int, allowed_tenant_ids=None, requester_tenant_id=None) -> dict:
mock_tenant_id = "tenant-1"
if allowed_tenant_ids is not None:
if isinstance(allowed_tenant_ids, str):
allowed_tenant_ids = {allowed_tenant_ids}
else:
allowed_tenant_ids = {str(tenant_id) for tenant_id in allowed_tenant_ids if tenant_id}
if mock_tenant_id not in allowed_tenant_ids and str(requester_tenant_id) != mock_tenant_id:
raise LookupError(f"Tenant Model with id {tenant_model_id} not authorized")
return _MockModelConfig2(mock_tenant_id, "model-1").to_dict()
def _get_model_config_by_type_and_name(tenant_id: str, model_type: str, model_name: str):
if not model_name:
raise Exception("Model Name is required")
return _MockModelConfig2(tenant_id, model_name).to_dict()
def _get_tenant_default_model_by_type(tenant_id: str, model_type):
return _MockModelConfig2(tenant_id, "chat-model").to_dict()
tenant_model_service_mod.get_model_config_by_id = _get_model_config_by_id
tenant_model_service_mod.get_model_config_by_type_and_name = _get_model_config_by_type_and_name
tenant_model_service_mod.get_tenant_default_model_by_type = _get_tenant_default_model_by_type
monkeypatch.setitem(sys.modules, "api.db.joint_services.tenant_model_service", tenant_model_service_mod)
module_name = "test_dify_retrieval_routes_unit_module"
module_path = repo_root / "api" / "apps" / "sdk" / "dify_retrieval.py"
spec = importlib.util.spec_from_file_location(module_name, module_path)
module = importlib.util.module_from_spec(spec)
module.manager = _DummyManager()
monkeypatch.setitem(sys.modules, module_name, module)
spec.loader.exec_module(module)
return module
def _set_request_json(monkeypatch, module, payload):
monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue(deepcopy(payload)))
@pytest.mark.p2
def test_retrieval_success_with_metadata_and_kg(monkeypatch):
module = _load_dify_retrieval_module(monkeypatch)
_set_request_json(
monkeypatch,
module,
{
"knowledge_id": "kb-1",
"query": "hello",
"use_kg": True,
"retrieval_setting": {"score_threshold": 0.1, "top_k": 3},
"metadata_condition": {"conditions": [{"name": "author", "comparison_operator": "is", "value": "alice"}], "logic": "and"},
},
)
monkeypatch.setattr(module, "jsonify", lambda payload: payload)
monkeypatch.setattr(module.DocMetadataService, "get_flatted_meta_by_kbs", lambda _kbs: [{"doc_id": "doc-1"}])
monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, _DummyKB()))
monkeypatch.setattr(module, "convert_conditions", lambda cond: cond.get("conditions", []))
monkeypatch.setattr(module, "meta_filter", lambda *_args, **_kwargs: [])
retriever = _DummyRetriever()
monkeypatch.setattr(module.settings, "retriever", retriever)
class _DummyKgRetriever:
async def retrieval(self, *_args, **_kwargs):
return {
"doc_id": "doc-2",
"content_with_weight": "kg-content",
"similarity": 0.9,
"docnm_kwd": "kg-title",
}
monkeypatch.setattr(module.settings, "kg_retriever", _DummyKgRetriever())
monkeypatch.setattr(module.DocumentService, "get_by_id", lambda doc_id: (True, SimpleNamespace(meta_fields={"origin": f"meta-{doc_id}"})))
monkeypatch.setattr(module, "label_question", lambda *_args, **_kwargs: [])
res = _run(inspect.unwrap(module.retrieval)("tenant-1"))
assert "records" in res, res
assert len(res["records"]) == 2, res
top = res["records"][0]
assert top["title"] == "kg-title", res
assert top["metadata"]["doc_id"] == "doc-2", res
assert "score" in top, res
@pytest.mark.p2
def test_retrieval_kb_not_found(monkeypatch):
module = _load_dify_retrieval_module(monkeypatch)
_set_request_json(monkeypatch, module, {"knowledge_id": "kb-missing", "query": "hello"})
monkeypatch.setattr(module.DocMetadataService, "get_flatted_meta_by_kbs", lambda _kbs: [])
monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (False, None))
res = _run(inspect.unwrap(module.retrieval)("tenant-1"))
assert res["code"] == module.RetCode.NOT_FOUND, res
assert "Knowledgebase not found" in res["message"], res
@pytest.mark.p2
def test_retrieval_not_found_exception_mapping(monkeypatch):
module = _load_dify_retrieval_module(monkeypatch)
_set_request_json(monkeypatch, module, {"knowledge_id": "kb-1", "query": "hello"})
monkeypatch.setattr(module.DocMetadataService, "get_flatted_meta_by_kbs", lambda _kbs: [])
monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, _DummyKB()))
monkeypatch.setattr(module, "label_question", lambda *_args, **_kwargs: [])
class _BrokenRetriever:
async def retrieval(self, *_args, **_kwargs):
raise RuntimeError("chunk_not_found_error")
monkeypatch.setattr(module.settings, "retriever", _BrokenRetriever())
res = _run(inspect.unwrap(module.retrieval)("tenant-1"))
assert res["code"] == module.RetCode.NOT_FOUND, res
assert "No chunk found" in res["message"], res
@pytest.mark.p2
def test_retrieval_generic_exception_mapping(monkeypatch):
module = _load_dify_retrieval_module(monkeypatch)
_set_request_json(monkeypatch, module, {"knowledge_id": "kb-1", "query": "hello"})
monkeypatch.setattr(module.DocMetadataService, "get_flatted_meta_by_kbs", lambda _kbs: [])
monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, _DummyKB()))
monkeypatch.setattr(module, "label_question", lambda *_args, **_kwargs: [])
class _BrokenRetriever:
async def retrieval(self, *_args, **_kwargs):
raise RuntimeError("boom")
monkeypatch.setattr(module.settings, "retriever", _BrokenRetriever())
res = _run(inspect.unwrap(module.retrieval)("tenant-1"))
assert res["code"] == module.RetCode.SERVER_ERROR, res
assert "boom" in res["message"], res
@pytest.mark.p2
def test_read_retrieval_request_from_get_args(monkeypatch):
module = _load_dify_retrieval_module(monkeypatch)
monkeypatch.setattr(
module,
"request",
SimpleNamespace(
method="GET",
args={
"knowledge_id": "kb-1",
"query": "hello",
"use_kg": "true",
"top_k": "12",
"score_threshold": "0.66",
},
),
)
req = _run(module._read_retrieval_request())
assert req["knowledge_id"] == "kb-1", req
assert req["query"] == "hello", req
assert req["use_kg"] is True, req
assert req["retrieval_setting"]["top_k"] == 12, req
assert req["retrieval_setting"]["score_threshold"] == 0.66, req
@pytest.mark.p2
def test_read_retrieval_request_from_post_json(monkeypatch):
module = _load_dify_retrieval_module(monkeypatch)
payload = {"knowledge_id": "kb-1", "query": "hello"}
monkeypatch.setattr(module, "request", SimpleNamespace(method="POST", args={}))
monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue(payload))
req = _run(module._read_retrieval_request())
assert req == payload, req
@pytest.mark.p2
def test_retrieval_argument_error_messages(monkeypatch):
module = _load_dify_retrieval_module(monkeypatch)
_set_request_json(
monkeypatch,
module,
{
"knowledge_id": "kb-1",
"query": "hello",
"retrieval_setting": {"top_k": "not-int", "score_threshold": "not-float"},
},
)
res = _run(inspect.unwrap(module.retrieval)("tenant-1"))
assert res["code"] == module.RetCode.ARGUMENT_ERROR, res
assert "invalid or malformed arguments:" in res["message"], res
_set_request_json(monkeypatch, module, {})
res_missing = _run(inspect.unwrap(module.retrieval)("tenant-1"))
assert res_missing["code"] == module.RetCode.ARGUMENT_ERROR, res_missing
assert "required arguments are missing:" in res_missing["message"], res_missing
_set_request_json(monkeypatch, module, {"knowledge_id": "kb-1"})
res_missing_query = _run(inspect.unwrap(module.retrieval)("tenant-1"))
assert res_missing_query["code"] == module.RetCode.ARGUMENT_ERROR, res_missing_query
assert "query" in res_missing_query["message"], res_missing_query
_set_request_json(
monkeypatch,
module,
{"knowledge_id": "kb-1", "query": "hello", "retrieval_setting": "bad-type"},
)
res_wrong_type = _run(inspect.unwrap(module.retrieval)("tenant-1"))
assert res_wrong_type["code"] == module.RetCode.ARGUMENT_ERROR, res_wrong_type
assert "retrieval_setting must be an object" in res_wrong_type["message"], res_wrong_type

View File

@ -14,7 +14,12 @@
# limitations under the License.
#
from concurrent.futures import ThreadPoolExecutor
import pytest
import requests
from test.testcases.configs import HOST_ADDRESS, INVALID_API_TOKEN, VERSION
from test.testcases.restful_api.helpers.client import RestClient
from test.testcases.utils import wait_for
@pytest.mark.p1
@ -107,3 +112,264 @@ def test_retrieval_compatibility_requires_auth(rest_client_noauth):
# token_required preserves legacy payload code/message while returning HTTP 401.
assert payload["code"] == 0, payload
assert payload["message"] == "`Authorization` can't be empty", payload
@wait_for(20, 1, "Retrieval indexing timeout in RESTful batch 10 tests")
def _retrieval_has_question(rest_client, dataset_id, question):
res = rest_client.post("/retrieval", json={"question": question, "dataset_ids": [dataset_id]})
if res.status_code != 200:
return False
payload = res.json()
if payload["code"] != 0:
return False
return len(payload["data"]["chunks"]) > 0
@wait_for(20, 1, "Retrieval indexing timeout waiting for chunk presence in RESTful batch 10 tests")
def _retrieval_has_chunks(rest_client, dataset_id, question, chunk_ids):
res = rest_client.post("/retrieval", json={"question": question, "dataset_ids": [dataset_id]})
if res.status_code != 200:
return False
payload = res.json()
if payload["code"] != 0:
return False
retrieved_ids = {chunk["id"] for chunk in payload["data"]["chunks"]}
expected_ids = set(chunk_ids)
return expected_ids.issubset(retrieved_ids)
@wait_for(20, 1, "Retrieval indexing timeout waiting for chunk deletion in RESTful batch 10 tests")
def _retrieval_lacks_chunks(rest_client, dataset_id, question, chunk_ids):
res = rest_client.post("/retrieval", json={"question": question, "dataset_ids": [dataset_id]})
if res.status_code != 200:
return False
payload = res.json()
if payload["code"] != 0:
return False
retrieved_ids = {chunk["id"] for chunk in payload["data"]["chunks"]}
expected_ids = set(chunk_ids)
return expected_ids.isdisjoint(retrieved_ids)
@pytest.mark.p2
def test_retrieval_requires_auth_contract(ensure_parsed_document):
dataset_id, _ = ensure_parsed_document()
for scenario_name, token, expected_code, expected_message in (
("missing token", None, 0, "`Authorization` can't be empty"),
("invalid token", INVALID_API_TOKEN, 109, "Authentication error: API key is invalid!"),
):
client = RestClient(token=token)
res = client.post("/retrieval", json={"question": "chunk", "dataset_ids": [dataset_id]})
assert res.status_code == 401, (scenario_name, res.text)
payload = res.json()
assert payload["code"] == expected_code, (scenario_name, payload)
assert payload["message"] == expected_message, (scenario_name, payload)
@pytest.mark.p2
def test_retrieval_page_and_page_size_contract(rest_client, ensure_parsed_document):
dataset_id, _ = ensure_parsed_document()
cases = [
("page none", {"question": "chunk", "dataset_ids": [dataset_id], "page": None, "page_size": 2}, 100, "TypeError"),
("page zero", {"question": "chunk", "dataset_ids": [dataset_id], "page": 0, "page_size": 2}, 0, ""),
("page two", {"question": "chunk", "dataset_ids": [dataset_id], "page": 2, "page_size": 2}, 0, ""),
("page three", {"question": "chunk", "dataset_ids": [dataset_id], "page": 3, "page_size": 2}, 0, ""),
("page str", {"question": "chunk", "dataset_ids": [dataset_id], "page": "3", "page_size": 2}, 0, ""),
("page negative", {"question": "chunk", "dataset_ids": [dataset_id], "page": -1, "page_size": 2}, 0, ""),
("page alpha", {"question": "chunk", "dataset_ids": [dataset_id], "page": "a", "page_size": 2}, 100, "invalid literal for int()"),
("page_size none", {"question": "chunk", "dataset_ids": [dataset_id], "page_size": None}, 100, "TypeError"),
("page_size one", {"question": "chunk", "dataset_ids": [dataset_id], "page_size": 1}, 0, ""),
("page_size five", {"question": "chunk", "dataset_ids": [dataset_id], "page_size": 5}, 0, ""),
("page_size str", {"question": "chunk", "dataset_ids": [dataset_id], "page_size": "1"}, 0, ""),
("page_size alpha", {"question": "chunk", "dataset_ids": [dataset_id], "page_size": "a"}, 100, "invalid literal for int()"),
]
for scenario_name, payload, expected_code, expected_message in cases:
res = rest_client.post("/retrieval", json=payload)
assert res.status_code == 200, (scenario_name, res.text)
body = res.json()
assert body["code"] == expected_code, (scenario_name, body)
if expected_code != 0:
assert expected_message in body["message"], (scenario_name, body)
@pytest.mark.p2
def test_retrieval_highlight_keyword_and_invalid_params_contract(rest_client, ensure_parsed_document):
dataset_id, _ = ensure_parsed_document()
highlight_cases = [
("highlight true", True, True),
("highlight true str", "True", True),
("highlight false", False, False),
("highlight false str", "False", False),
("highlight none", None, False),
]
for scenario_name, highlight_value, expect_highlight in highlight_cases:
res = rest_client.post(
"/retrieval",
json={"question": "chunk", "dataset_ids": [dataset_id], "highlight": highlight_value},
)
assert res.status_code == 200, (scenario_name, res.text)
body = res.json()
assert body["code"] == 0, (scenario_name, body)
for chunk in body["data"]["chunks"]:
if expect_highlight:
assert "highlight" in chunk, (scenario_name, body)
else:
assert "highlight" not in chunk, (scenario_name, body)
invalid_highlight = rest_client.post(
"/retrieval",
json={"question": "chunk", "dataset_ids": [dataset_id], "highlight": "not_bool"},
)
assert invalid_highlight.status_code == 200
invalid_highlight_payload = invalid_highlight.json()
assert invalid_highlight_payload["code"] == 102, invalid_highlight_payload
assert invalid_highlight_payload["message"] == "`highlight` should be a boolean", invalid_highlight_payload
for scenario_name, keyword_value in (
("keyword true", True),
("keyword true str", "True"),
("keyword false", False),
("keyword false str", "False"),
("keyword none", None),
):
keyword_res = rest_client.post(
"/retrieval",
json={"question": "chunk test", "dataset_ids": [dataset_id], "keyword": keyword_value},
)
assert keyword_res.status_code == 200, (scenario_name, keyword_res.text)
keyword_payload = keyword_res.json()
assert keyword_payload["code"] == 0, (scenario_name, keyword_payload)
assert isinstance(keyword_payload["data"]["chunks"], list), (scenario_name, keyword_payload)
invalid_params_res = rest_client.post(
"/retrieval",
json={"question": "chunk", "dataset_ids": [dataset_id], "a": "b"},
)
assert invalid_params_res.status_code == 200
invalid_params_payload = invalid_params_res.json()
assert invalid_params_payload["code"] == 0, invalid_params_payload
@pytest.mark.p2
def test_retrieval_vector_similarity_and_top_k_contract(rest_client, ensure_parsed_document):
dataset_id, _ = ensure_parsed_document()
cases = [
("vector 0", {"vector_similarity_weight": 0}, 0, ""),
("vector 0.5", {"vector_similarity_weight": 0.5}, 0, ""),
("vector 10", {"vector_similarity_weight": 10}, 0, ""),
("vector alpha", {"vector_similarity_weight": "a"}, 100, "could not convert string to float"),
("top_k 10", {"top_k": 10}, 0, ""),
("top_k 1", {"top_k": 1}, 0, ""),
("top_k -1", {"top_k": -1}, 102, "`top_k` must be greater than 0"),
("top_k alpha", {"top_k": "a"}, 100, "invalid literal for int()"),
]
for scenario_name, updates, expected_code, expected_message in cases:
payload = {"question": "chunk", "dataset_ids": [dataset_id]}
payload.update(updates)
res = rest_client.post("/retrieval", json=payload)
assert res.status_code == 200, (scenario_name, res.text)
body = res.json()
assert body["code"] == expected_code, (scenario_name, body)
if expected_code != 0:
assert expected_message in body["message"], (scenario_name, body)
@pytest.mark.p2
def test_retrieval_rerank_unknown_contract(rest_client, ensure_parsed_document):
dataset_id, _ = ensure_parsed_document()
res = rest_client.post(
"/retrieval",
json={"question": "chunk", "dataset_ids": [dataset_id], "rerank_id": "unknown"},
)
assert res.status_code == 200
payload = res.json()
assert payload["code"] != 0, payload
assert payload["message"], payload
@pytest.mark.p2
def test_retrieval_concurrent_contract(rest_client, ensure_parsed_document):
dataset_id, _ = ensure_parsed_document()
payload = {"question": "chunk", "dataset_ids": [dataset_id]}
with ThreadPoolExecutor(max_workers=5) as executor:
results = list(executor.map(lambda _: rest_client.post("/retrieval", json=payload).json(), range(20)))
assert len(results) == 20, results
assert all(result["code"] == 0 for result in results), results
@pytest.mark.p2
def test_deleted_chunk_not_in_retrieval_contract(rest_client, create_document):
dataset_id, document_id = create_document("retrieval_deleted_chunk.txt")
base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks"
content = "UNIQUE_TEST_CONTENT_12520_REST"
add_res = rest_client.post(base_path, json={"content": content})
assert add_res.status_code == 200
add_payload = add_res.json()
assert add_payload["code"] == 0, add_payload
chunk_id = add_payload["data"]["chunk"]["id"]
_retrieval_has_chunks(rest_client, dataset_id, content, [chunk_id])
delete_res = rest_client.delete(base_path, json={"chunk_ids": [chunk_id]})
assert delete_res.status_code == 200
assert delete_res.json()["code"] == 0
_retrieval_lacks_chunks(rest_client, dataset_id, content, [chunk_id])
@pytest.mark.p2
def test_deleted_chunks_batch_not_in_retrieval_contract(rest_client, create_document):
dataset_id, document_id = create_document("retrieval_deleted_chunks_batch.txt")
base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks"
chunk_ids = []
for index in range(3):
content = f"BATCH_DELETE_TEST_CHUNK_{index}_REST_12520"
add_res = rest_client.post(base_path, json={"content": content})
assert add_res.status_code == 200
add_payload = add_res.json()
assert add_payload["code"] == 0, add_payload
chunk_ids.append(add_payload["data"]["chunk"]["id"])
_retrieval_has_chunks(rest_client, dataset_id, "BATCH_DELETE_TEST_CHUNK", chunk_ids)
delete_res = rest_client.delete(base_path, json={"chunk_ids": chunk_ids})
assert delete_res.status_code == 200
assert delete_res.json()["code"] == 0
_retrieval_lacks_chunks(rest_client, dataset_id, "BATCH_DELETE_TEST_CHUNK", chunk_ids)
@pytest.mark.p2
def test_related_questions_contract(auth, rest_client, rest_client_noauth):
tokens_res = requests.get(
f"{HOST_ADDRESS}/api/{VERSION}/system/tokens",
headers={"Authorization": auth},
timeout=30,
)
assert tokens_res.status_code == 200, tokens_res.text
tokens_payload = tokens_res.json()
assert tokens_payload["code"] == 0, tokens_payload
assert tokens_payload["data"], tokens_payload
beta_token = tokens_payload["data"][0]["beta"]
success_client = RestClient(token=beta_token)
success_res = success_client.post("/searchbots/related_questions", json={"question": "ragflow", "industry": "search"})
assert success_res.status_code == 200
success_payload = success_res.json()
assert success_payload["code"] == 0, success_payload
assert isinstance(success_payload["data"], list), success_payload
missing_res = rest_client.post("/searchbots/related_questions", json={"industry": "search"})
assert missing_res.status_code == 200
missing_payload = missing_res.json()
assert missing_payload["code"] == 101, missing_payload
assert "question" in missing_payload["message"], missing_payload
invalid_auth_res = rest_client_noauth.post(
"/searchbots/related_questions",
json={"question": "ragflow", "industry": "search"},
headers={"Authorization": "invalid"},
)
assert invalid_auth_res.status_code == 200
invalid_auth_payload = invalid_auth_res.json()
assert invalid_auth_payload["code"] == 102, invalid_auth_payload
assert "Authorization is not valid!" in invalid_auth_payload["message"], invalid_auth_payload