From aea90f4e3974a0e506f83e6da96c62bbaa734fe7 Mon Sep 17 00:00:00 2001 From: Idriss Sbaaoui <112825897+6ba3i@users.noreply.github.com> Date: Wed, 20 May 2026 14:56:55 +0800 Subject: [PATCH] Feat: add new tests and tescases for restful api suite (#15038) ### What problem does this PR solve? extend restful api suite ### Type of change - [x] New Feature (non-breaking change which adds functionality) - [x] Other (please describe): test --- test/testcases/restful_api/test_chunks.py | 707 +++++++++++++++++- .../test_dify_retrieval_routes_unit.py | 425 +++++++++++ test/testcases/restful_api/test_retrieval.py | 266 +++++++ 3 files changed, 1391 insertions(+), 7 deletions(-) create mode 100644 test/testcases/restful_api/test_dify_retrieval_routes_unit.py diff --git a/test/testcases/restful_api/test_chunks.py b/test/testcases/restful_api/test_chunks.py index 42009a2af..e2ed7b48c 100644 --- a/test/testcases/restful_api/test_chunks.py +++ b/test/testcases/restful_api/test_chunks.py @@ -14,7 +14,12 @@ # limitations under the License. # +from concurrent.futures import ThreadPoolExecutor +import os import pytest +from test.testcases.configs import INVALID_API_TOKEN, INVALID_ID_32 +from test.testcases.restful_api.helpers.client import RestClient +from test.testcases.utils import wait_for def _assert_created_chunk_id(payload): @@ -25,6 +30,41 @@ def _assert_created_chunk_id(payload): return chunk_id +@wait_for(10, 1, "Chunk indexing timeout in RESTful batch 09 tests") +def _chunk_count(rest_client, base_path, expected_total): + res = rest_client.get(base_path) + if res.status_code != 200: + return False + payload = res.json() + if payload["code"] != 0: + return False + return payload["data"]["total"] == expected_total and len(payload["data"]["chunks"]) == min(expected_total, 30) + + +def _reset_chunk_batch(rest_client, base_path, count=4): + cleanup_res = rest_client.delete(base_path, json={"chunk_ids": None, "delete_all": True}) + assert cleanup_res.status_code == 200, cleanup_res.text + cleanup_payload = cleanup_res.json() + assert cleanup_payload["code"] == 0, cleanup_payload + + baseline_res = rest_client.post(base_path, json={"content": "ragflow test upload"}) + assert baseline_res.status_code == 200, baseline_res.text + baseline_payload = baseline_res.json() + assert baseline_payload["code"] == 0, baseline_payload + baseline_id = _assert_created_chunk_id(baseline_payload) + + chunk_ids = [] + for index in range(count): + res = rest_client.post(base_path, json={"content": f"chunk test {index}"}) + assert res.status_code == 200, (index, res.text) + payload = res.json() + assert payload["code"] == 0, (index, payload) + chunk_ids.append(_assert_created_chunk_id(payload)) + + _chunk_count(rest_client, base_path, count + 1) + return baseline_id, chunk_ids + + @pytest.mark.p1 def test_chunks_add_list_get_update_delete_cycle(rest_client, create_document): dataset_id, document_id = create_document("chunk_cycle.txt") @@ -88,6 +128,42 @@ def test_chunks_add_list_get_update_delete_cycle(rest_client, create_document): assert deleted_get_payload["code"] != 0, deleted_get_payload +@pytest.mark.p1 +def test_chunk_add_requires_auth(create_document): + dataset_id, document_id = create_document("chunk_add_auth.txt") + path = f"/datasets/{dataset_id}/documents/{document_id}/chunks" + for scenario_name, client in (("missing token", RestClient(token=None)), ("invalid token", RestClient(token=INVALID_API_TOKEN))): + res = client.post(path, json={"content": "chunk test"}) + assert res.status_code == 401, (scenario_name, res.text) + payload = res.json() + assert payload["code"] == 401, (scenario_name, payload) + assert payload["message"] == "", (scenario_name, payload) + + +@pytest.mark.p1 +def test_chunk_delete_requires_auth(create_document): + dataset_id, document_id = create_document("chunk_delete_auth.txt") + path = f"/datasets/{dataset_id}/documents/{document_id}/chunks" + for scenario_name, client in (("missing token", RestClient(token=None)), ("invalid token", RestClient(token=INVALID_API_TOKEN))): + res = client.delete(path, json={"chunk_ids": []}) + assert res.status_code == 401, (scenario_name, res.text) + payload = res.json() + assert payload["code"] == 401, (scenario_name, payload) + assert payload["message"] == "", (scenario_name, payload) + + +@pytest.mark.p1 +def test_chunk_list_requires_auth(create_document): + dataset_id, document_id = create_document("chunk_list_auth.txt") + path = f"/datasets/{dataset_id}/documents/{document_id}/chunks" + for scenario_name, client in (("missing token", RestClient(token=None)), ("invalid token", RestClient(token=INVALID_API_TOKEN))): + res = client.get(path) + assert res.status_code == 401, (scenario_name, res.text) + payload = res.json() + assert payload["code"] == 401, (scenario_name, payload) + assert payload["message"] == "", (scenario_name, payload) + + @pytest.mark.p2 def test_chunks_add_requires_content(rest_client, create_document): dataset_id, document_id = create_document("chunk_requires_content.txt") @@ -101,6 +177,156 @@ def test_chunks_add_requires_content(rest_client, create_document): assert payload["message"] == "`content` is required", payload +@pytest.mark.p2 +def test_chunk_add_keyword_question_and_tag_contract(rest_client, create_document): + add_cases = [ + ( + "important keywords", + [ + ({"content": "chunk test", "important_keywords": ["a", "b", "c"]}, 0, ""), + ({"content": "chunk test", "important_keywords": [""]}, 0, ""), + ({"content": "chunk test", "important_keywords": [1]}, 100, "TypeError('sequence item 0: expected str instance, int found')"), + ({"content": "chunk test", "important_keywords": ["a", "a"]}, 0, ""), + ({"content": "chunk test", "important_keywords": "abc"}, 102, "`important_keywords` is required to be a list"), + ({"content": "chunk test", "important_keywords": 123}, 102, "`important_keywords` is required to be a list"), + ], + ), + ( + "questions", + [ + ({"content": "chunk test", "questions": ["a", "b", "c"]}, 0, ""), + ({"content": "chunk test", "questions": [""]}, 0, ""), + ({"content": "chunk test", "questions": [1]}, 100, "TypeError('sequence item 0: expected str instance, int found')"), + ({"content": "chunk test", "questions": ["a", "a"]}, 0, ""), + ({"content": "chunk test", "questions": "abc"}, 102, "`questions` is required to be a list"), + ({"content": "chunk test", "questions": 123}, 102, "`questions` is required to be a list"), + ], + ), + ( + "tag_kwd", + [ + ({"content": "chunk test", "tag_kwd": ["tag1", "tag2"]}, 0, ""), + ({"content": "chunk test", "tag_kwd": [""]}, 0, ""), + ({"content": "chunk test", "tag_kwd": [1]}, 102, "`tag_kwd` must be a list of strings"), + ({"content": "chunk test", "tag_kwd": ["tag", "tag"]}, 0, ""), + ({"content": "chunk test", "tag_kwd": "abc"}, 102, "`tag_kwd` is required to be a list"), + ({"content": "chunk test", "tag_kwd": 123}, 102, "`tag_kwd` is required to be a list"), + ], + ), + ] + + for group_index, (group_name, cases) in enumerate(add_cases): + dataset_id, document_id = create_document(f"chunk_add_contracts_{group_index}.txt") + base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks" + for scenario_index, (payload, expected_code, expected_message) in enumerate(cases): + scenario_name = f"{group_name}-{scenario_index}" + before_payload = rest_client.get(base_path).json() + assert before_payload["code"] == 0, (scenario_name, before_payload) + before_total = before_payload["data"]["doc"]["chunk_count"] + + res = rest_client.post(base_path, json=payload) + assert res.status_code == 200, (scenario_name, res.text) + body = res.json() + assert body["code"] == expected_code, (scenario_name, body) + if expected_code == 0: + chunk = body["data"]["chunk"] + assert chunk["dataset_id"] == dataset_id, (scenario_name, body) + assert chunk["document_id"] == document_id, (scenario_name, body) + assert chunk["content"] == payload["content"], (scenario_name, body) + if "important_keywords" in payload: + assert chunk["important_keywords"] == payload["important_keywords"], (scenario_name, body) + if "questions" in payload: + assert chunk["questions"] == [str(q).strip() for q in payload["questions"] if str(q).strip()], (scenario_name, body) + if "tag_kwd" in payload: + assert chunk["tag_kwd"] == payload["tag_kwd"], (scenario_name, body) + after_payload = rest_client.get(base_path).json() + assert after_payload["code"] == 0, (scenario_name, after_payload) + assert after_payload["data"]["doc"]["chunk_count"] == before_total + 1, (scenario_name, after_payload) + else: + assert body["message"] == expected_message, (scenario_name, body) + + +@pytest.mark.p2 +def test_chunk_add_invalid_dataset_and_document_contract(rest_client, create_document): + dataset_id, document_id = create_document("chunk_invalid_targets.txt") + + invalid_dataset_res = rest_client.post( + f"/datasets/{INVALID_ID_32}/documents/{document_id}/chunks", + json={"content": "chunk test"}, + ) + assert invalid_dataset_res.status_code == 200 + invalid_dataset_payload = invalid_dataset_res.json() + assert invalid_dataset_payload["code"] == 102, invalid_dataset_payload + assert invalid_dataset_payload["message"] == f"You don't own the dataset {INVALID_ID_32}.", invalid_dataset_payload + + invalid_document_res = rest_client.post( + f"/datasets/{dataset_id}/documents/{INVALID_ID_32}/chunks", + json={"content": "chunk test"}, + ) + assert invalid_document_res.status_code == 200 + invalid_document_payload = invalid_document_res.json() + assert invalid_document_payload["code"] == 102, invalid_document_payload + assert invalid_document_payload["message"] == f"You don't own the document {INVALID_ID_32}.", invalid_document_payload + + +@pytest.mark.p2 +def test_chunk_add_repeated_and_deleted_document_contract(rest_client, create_document): + dataset_id, document_id = create_document("chunk_repeat_deleted.txt") + base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks" + + first_payload = rest_client.get(base_path).json() + assert first_payload["code"] == 0, first_payload + initial_count = first_payload["data"]["doc"]["chunk_count"] + + first_add_res = rest_client.post(base_path, json={"content": "chunk test"}) + second_add_res = rest_client.post(base_path, json={"content": "chunk test"}) + first_add_payload = first_add_res.json() + second_add_payload = second_add_res.json() + assert first_add_payload["code"] == 0, first_add_payload + assert second_add_payload["code"] == 0, second_add_payload + assert first_add_payload["data"]["chunk"]["id"] == second_add_payload["data"]["chunk"]["id"], (first_add_payload, second_add_payload) + + repeated_list_payload = rest_client.get(base_path).json() + assert repeated_list_payload["code"] == 0, repeated_list_payload + assert repeated_list_payload["data"]["doc"]["chunk_count"] == initial_count + 2, repeated_list_payload + assert repeated_list_payload["data"]["total"] == 1, repeated_list_payload + + delete_document_res = rest_client.delete(f"/datasets/{dataset_id}/documents", json={"ids": [document_id]}) + assert delete_document_res.status_code == 200 + delete_document_payload = delete_document_res.json() + assert delete_document_payload["code"] == 0, delete_document_payload + + add_after_delete_res = rest_client.post(base_path, json={"content": "chunk test"}) + assert add_after_delete_res.status_code == 200 + add_after_delete_payload = add_after_delete_res.json() + assert add_after_delete_payload["code"] == 102, add_after_delete_payload + assert add_after_delete_payload["message"] == f"You don't own the document {document_id}.", add_after_delete_payload + + +@pytest.mark.p2 +@pytest.mark.parametrize("count", [20]) +def test_chunk_concurrent_add_contract(rest_client, create_document, count): + dataset_id, document_id = create_document("chunk_concurrent_add.txt") + base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks" + baseline_payload = rest_client.get(base_path).json() + assert baseline_payload["code"] == 0, baseline_payload + initial_count = baseline_payload["data"]["doc"]["chunk_count"] + + with ThreadPoolExecutor(max_workers=5) as executor: + results = list( + executor.map( + lambda index: rest_client.post(base_path, json={"content": f"chunk test {index}"}).json(), + range(count), + ) + ) + assert len(results) == count, results + assert all(result["code"] == 0 for result in results), results + + final_payload = rest_client.get(base_path).json() + assert final_payload["code"] == 0, final_payload + assert final_payload["data"]["doc"]["chunk_count"] == initial_count + count, final_payload + + @pytest.mark.p2 def test_chunks_list_empty_document(rest_client, create_document): dataset_id, document_id = create_document("chunk_list_empty.txt") @@ -114,11 +340,478 @@ def test_chunks_list_empty_document(rest_client, create_document): @pytest.mark.p2 -def test_chunks_delete_partial_invalid(rest_client, create_document): - dataset_id, document_id = create_document("chunk_delete_partial.txt") +def test_chunk_delete_basic_contract(rest_client, create_document): + dataset_id, document_id = create_document("chunk_delete_basic.txt") base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks" - delete_res = rest_client.delete(base_path, json={"chunk_ids": ["invalid_chunk_id"]}) - assert delete_res.status_code == 200 - delete_payload = delete_res.json() - assert delete_payload["code"] == 102, delete_payload - assert "expect 1" in delete_payload["message"], delete_payload + + cases = [ + ("none payload", None, 0, "", 5), + ("invalid only", {"chunk_ids": ["invalid_id"]}, 102, "rm_chunk deleted chunks 0, expect 1", 5), + ("delete first", lambda ids: {"chunk_ids": ids[:1]}, 0, "", 4), + ("delete generated", lambda ids: {"chunk_ids": ids}, 0, "", 1), + ("empty ids", {"chunk_ids": []}, 0, "", 5), + ] + + for scenario_name, payload, expected_code, expected_message, remaining in cases: + _reset_chunk_batch(rest_client, base_path) + request_body = payload + generated_ids = rest_client.get(base_path).json()["data"]["chunks"][1:] + generated_ids = [chunk["id"] for chunk in generated_ids] + if callable(payload): + request_body = payload(generated_ids) + res = rest_client.delete(base_path, json=request_body) + assert res.status_code == 200, (scenario_name, res.text) + body = res.json() + assert body["code"] == expected_code, (scenario_name, body) + if expected_message: + assert body.get("message", "") == expected_message, (scenario_name, body) + + list_payload = rest_client.get(base_path).json() + assert list_payload["code"] == 0, (scenario_name, list_payload) + assert len(list_payload["data"]["chunks"]) == remaining, (scenario_name, list_payload) + assert list_payload["data"]["total"] == remaining, (scenario_name, list_payload) + + +@pytest.mark.p2 +def test_chunk_delete_partial_duplicate_repeat_and_invalid_target_contract(rest_client, create_document): + dataset_id, document_id = create_document("chunk_delete_detail.txt") + base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks" + + for scenario_name, payload_builder in ( + ("invalid first", lambda ids: {"chunk_ids": ["invalid_id"] + ids}), + ("invalid middle", lambda ids: {"chunk_ids": ids[:1] + ["invalid_id"] + ids[1:]}), + ("invalid last", lambda ids: {"chunk_ids": ids + ["invalid_id"]}), + ): + _, generated_ids = _reset_chunk_batch(rest_client, base_path) + res = rest_client.delete(base_path, json=payload_builder(generated_ids)) + assert res.status_code == 200, (scenario_name, res.text) + body = res.json() + assert body["code"] == 102, (scenario_name, body) + assert body["message"] == "rm_chunk deleted chunks 4, expect 5", (scenario_name, body) + + list_payload = rest_client.get(base_path).json() + assert list_payload["code"] == 0, (scenario_name, list_payload) + assert list_payload["data"]["total"] == 1, (scenario_name, list_payload) + + _, generated_ids = _reset_chunk_batch(rest_client, base_path) + duplicate_res = rest_client.delete(base_path, json={"chunk_ids": generated_ids * 2}) + assert duplicate_res.status_code == 200 + duplicate_payload = duplicate_res.json() + assert duplicate_payload["code"] == 0, duplicate_payload + assert duplicate_payload["data"]["success_count"] == 4, duplicate_payload + assert len(duplicate_payload["data"]["errors"]) == 4, duplicate_payload + assert all(error.startswith("Duplicate chunk ids: ") for error in duplicate_payload["data"]["errors"]), duplicate_payload + duplicate_list_payload = rest_client.get(base_path).json() + assert duplicate_list_payload["code"] == 0, duplicate_list_payload + assert duplicate_list_payload["data"]["total"] == 1, duplicate_list_payload + + _, generated_ids = _reset_chunk_batch(rest_client, base_path) + first_delete_res = rest_client.delete(base_path, json={"chunk_ids": generated_ids}) + assert first_delete_res.status_code == 200 + assert first_delete_res.json()["code"] == 0 + second_delete_res = rest_client.delete(base_path, json={"chunk_ids": generated_ids}) + assert second_delete_res.status_code == 200 + second_delete_payload = second_delete_res.json() + assert second_delete_payload["code"] == 102, second_delete_payload + assert second_delete_payload["message"] == "rm_chunk deleted chunks 0, expect 4", second_delete_payload + + invalid_dataset_res = rest_client.delete( + f"/datasets/{INVALID_ID_32}/documents/{document_id}/chunks", + json={"chunk_ids": ["chunk-id"]}, + ) + assert invalid_dataset_res.status_code == 200 + invalid_dataset_payload = invalid_dataset_res.json() + assert invalid_dataset_payload["code"] == 102, invalid_dataset_payload + assert invalid_dataset_payload["message"] == f"You don't own the dataset {INVALID_ID_32}.", invalid_dataset_payload + + invalid_document_res = rest_client.delete( + f"/datasets/{dataset_id}/documents/{INVALID_ID_32}/chunks", + json={"chunk_ids": ["chunk-id"]}, + ) + assert invalid_document_res.status_code == 200 + invalid_document_payload = invalid_document_res.json() + assert invalid_document_payload["code"] == 102, invalid_document_payload + assert invalid_document_payload["message"] == f"You don't own the document {INVALID_ID_32}.", invalid_document_payload + + +@pytest.mark.p2 +def test_chunk_delete_web_legacy_basic_variants(rest_client, create_document): + dataset_id, document_id = create_document("chunk_delete_web_legacy_again.txt") + base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks" + cases = [ + ("web invalid id", {"chunk_ids": ["invalid_id"]}, 102, 5), + ("web delete first", lambda ids: {"chunk_ids": ids[:1]}, 0, 4), + ("web delete generated", lambda ids: {"chunk_ids": ids}, 0, 1), + ("web empty ids", {"chunk_ids": []}, 0, 5), + ] + for scenario_name, payload, expected_code, remaining in cases: + _, generated_ids = _reset_chunk_batch(rest_client, base_path) + request_body = payload(generated_ids) if callable(payload) else payload + res = rest_client.delete(base_path, json=request_body) + assert res.status_code == 200, (scenario_name, res.text) + body = res.json() + assert body["code"] == expected_code, (scenario_name, body) + list_payload = rest_client.get(base_path).json() + assert list_payload["code"] == 0, (scenario_name, list_payload) + assert list_payload["data"]["total"] == remaining, (scenario_name, list_payload) + + +@pytest.mark.p2 +def test_chunk_delete_concurrent_and_bulk_contract(rest_client, create_document): + dataset_id, document_id = create_document("chunk_delete_bulk_contract.txt") + base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks" + + rest_client.delete(base_path, json={"chunk_ids": None, "delete_all": True}) + for index in range(12): + payload = rest_client.post(base_path, json={"content": f"chunk test {index}"}).json() + assert payload["code"] == 0, payload + ids_payload = rest_client.get(base_path).json() + assert ids_payload["code"] == 0, ids_payload + chunk_ids = [chunk["id"] for chunk in ids_payload["data"]["chunks"]] + + with ThreadPoolExecutor(max_workers=5) as executor: + results = list(executor.map(lambda chunk_id: rest_client.delete(base_path, json={"chunk_ids": [chunk_id]}).json(), chunk_ids)) + assert len(results) == len(chunk_ids), results + assert all(result["code"] == 0 for result in results), results + + final_payload = rest_client.get(base_path).json() + assert final_payload["code"] == 0, final_payload + assert final_payload["data"]["total"] == 0, final_payload + + rest_client.delete(base_path, json={"chunk_ids": None, "delete_all": True}) + for index in range(40): + payload = rest_client.post(base_path, json={"content": f"bulk chunk {index}"}).json() + assert payload["code"] == 0, payload + bulk_ids_payload = rest_client.get(base_path, params={"page_size": 200}).json() + assert bulk_ids_payload["code"] == 0, bulk_ids_payload + bulk_ids = [chunk["id"] for chunk in bulk_ids_payload["data"]["chunks"]] + bulk_res = rest_client.delete(base_path, json={"chunk_ids": bulk_ids}) + assert bulk_res.status_code == 200 + bulk_payload = bulk_res.json() + assert bulk_payload["code"] == 0, bulk_payload + + +@pytest.mark.p2 +def test_chunk_list_default_get_id_and_invalid_target_contract(rest_client, create_document): + dataset_id, document_id = create_document("chunk_list_core.txt") + base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks" + baseline_id, generated_ids = _reset_chunk_batch(rest_client, base_path) + + default_res = rest_client.get(base_path) + assert default_res.status_code == 200 + default_payload = default_res.json() + assert default_payload["code"] == 0, default_payload + assert default_payload["data"]["total"] == 5, default_payload + assert len(default_payload["data"]["chunks"]) == 5, default_payload + + get_res = rest_client.get(f"{base_path}/{generated_ids[0]}") + assert get_res.status_code == 200 + get_payload = get_res.json() + assert get_payload["code"] == 0, get_payload + assert get_payload["data"]["id"] == generated_ids[0], get_payload + assert get_payload["data"]["doc_id"] == document_id, get_payload + + invalid_get_res = rest_client.get(f"{base_path}/unknown") + assert invalid_get_res.status_code == 200 + invalid_get_payload = invalid_get_res.json() + assert invalid_get_payload["code"] == 102, invalid_get_payload + assert invalid_get_payload["message"] == "Chunk not found!", invalid_get_payload + + id_cases = [ + ("id none", {"id": None}, 0, 5, None), + ("id empty", {"id": ""}, 0, 5, None), + ("id valid", {"id": generated_ids[0]}, 0, 1, generated_ids[0]), + ("id invalid", {"id": "unknown"}, 102, None, None), + ] + for scenario_name, params, expected_code, expected_total, expected_id in id_cases: + res = rest_client.get(base_path, params=params) + assert res.status_code == 200, (scenario_name, res.text) + payload = res.json() + assert payload["code"] == expected_code, (scenario_name, payload) + if expected_code == 0: + assert payload["data"]["total"] == expected_total, (scenario_name, payload) + if expected_id is not None: + assert payload["data"]["chunks"][0]["id"] == expected_id, (scenario_name, payload) + else: + assert payload["message"] == f"Chunk not found: {dataset_id}/unknown", (scenario_name, payload) + + invalid_dataset_res = rest_client.get(f"/datasets/{INVALID_ID_32}/documents/{document_id}/chunks") + assert invalid_dataset_res.status_code == 200 + invalid_dataset_payload = invalid_dataset_res.json() + assert invalid_dataset_payload["code"] == 102, invalid_dataset_payload + assert invalid_dataset_payload["message"] == f"You don't own the dataset {INVALID_ID_32}.", invalid_dataset_payload + + invalid_document_res = rest_client.get(f"/datasets/{dataset_id}/documents/{INVALID_ID_32}/chunks") + assert invalid_document_res.status_code == 200 + invalid_document_payload = invalid_document_res.json() + assert invalid_document_payload["code"] == 102, invalid_document_payload + assert invalid_document_payload["message"] == f"You don't own the document {INVALID_ID_32}.", invalid_document_payload + + +@pytest.mark.p2 +@pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="infinity") +def test_chunk_list_keyword_and_invalid_param_contract(rest_client, create_document): + dataset_id, document_id = create_document("chunk_list_keywords.txt") + base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks" + _reset_chunk_batch(rest_client, base_path) + + cases = [ + ("keywords none", {"keywords": None}, 5), + ("keywords empty", {"keywords": ""}, 5), + ("keywords exact one", {"keywords": "1"}, 1), + ("keywords chunk", {"keywords": "chunk"}, 4), + ("keywords ragflow", {"keywords": "ragflow"}, 1), + ("keywords unknown", {"keywords": "unknown"}, 0), + ("invalid params ignored", {"a": "b"}, 5), + ] + + for scenario_name, params, expected_total in cases: + res = rest_client.get(base_path, params=params) + assert res.status_code == 200, (scenario_name, res.text) + payload = res.json() + assert payload["code"] == 0, (scenario_name, payload) + assert payload["data"]["total"] == expected_total, (scenario_name, payload) + assert len(payload["data"]["chunks"]) == expected_total, (scenario_name, payload) + + +@pytest.mark.p2 +@pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="infinity") +def test_chunk_list_page_and_page_size_contract(rest_client, create_document): + dataset_id, document_id = create_document("chunk_list_paging.txt") + base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks" + _reset_chunk_batch(rest_client, base_path) + + cases = [ + ("page none", {"page": None, "page_size": 2}, 0, 2, ""), + ("page zero", {"page": 0, "page_size": 2}, 100, None, "ValueError('Search does not support negative slicing.')"), + ("page two", {"page": 2, "page_size": 2}, 0, 2, ""), + ("page three", {"page": 3, "page_size": 2}, 0, 1, ""), + ("page string", {"page": "3", "page_size": 2}, 0, 1, ""), + ("page negative", {"page": -1, "page_size": 2}, 100, None, "ValueError('Search does not support negative slicing.')"), + ("page alpha", {"page": "a", "page_size": 2}, 100, None, "ValueError(\"invalid literal for int() with base 10: 'a'\")"), + ("page_size none", {"page_size": None}, 0, 5, ""), + ("page_size zero", {"page_size": 0}, 0, 5, ""), + ("page_size one", {"page_size": 1}, 0, 1, ""), + ("page_size six", {"page_size": 6}, 0, 5, ""), + ("page_size string", {"page_size": "1"}, 0, 1, ""), + ("page_size negative", {"page_size": -1}, 0, 5, ""), + ("page_size alpha", {"page_size": "a"}, 100, None, "ValueError(\"invalid literal for int() with base 10: 'a'\")"), + ] + + for scenario_name, params, expected_code, expected_total, expected_message in cases: + res = rest_client.get(base_path, params=params) + assert res.status_code == 200, (scenario_name, res.text) + payload = res.json() + assert payload["code"] == expected_code, (scenario_name, payload) + if expected_code == 0: + assert payload["data"]["total"] == 5, (scenario_name, payload) + assert len(payload["data"]["chunks"]) == expected_total, (scenario_name, payload) + else: + assert expected_message in payload["message"], (scenario_name, payload) + + +@pytest.mark.p2 +def test_chunk_list_concurrent_contract(rest_client, create_document): + dataset_id, document_id = create_document("chunk_list_concurrent.txt") + base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks" + _reset_chunk_batch(rest_client, base_path) + + with ThreadPoolExecutor(max_workers=5) as executor: + results = list(executor.map(lambda _: rest_client.get(base_path).json(), range(20))) + assert len(results) == 20, results + assert all(result["code"] == 0 for result in results), results + assert all(result["data"]["total"] == 5 for result in results), results + + +def _create_chunk_for_update(rest_client, create_document, file_name): + dataset_id, document_id = create_document(file_name) + base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks" + add_res = rest_client.post(base_path, json={"content": "chunk update test"}) + assert add_res.status_code == 200, add_res.text + add_payload = add_res.json() + assert add_payload["code"] == 0, add_payload + chunk_id = add_payload["data"]["chunk"]["id"] + return dataset_id, document_id, chunk_id, base_path + + +@pytest.mark.p2 +def test_chunk_update_requires_auth(rest_client, create_document): + _, _, chunk_id, base_path = _create_chunk_for_update(rest_client, create_document, "chunk_update_auth.txt") + for scenario_name, client in (("missing token", RestClient(token=None)), ("invalid token", RestClient(token=INVALID_API_TOKEN))): + res = client.patch(f"{base_path}/{chunk_id}", json={"content": "updated"}) + assert res.status_code == 401, (scenario_name, res.text) + payload = res.json() + assert payload["code"] == 401, (scenario_name, payload) + assert payload["message"] == "", (scenario_name, payload) + + +@pytest.mark.p2 +def test_chunk_update_content_and_available_contract(rest_client, create_document): + content_cases = [ + ("content none", {"content": None}, 0, ""), + ("content empty", {"content": ""}, 102, "`content` is required"), + ("content text", {"content": "update chunk"}, 0, ""), + ("content spaces", {"content": " "}, 102, "`content` is required"), + ("content punctuation", {"content": "\n!?。;!?\"'"}, 0, ""), + ] + for scenario_name, payload, expected_code, expected_message in content_cases: + _, _, chunk_id, base_path = _create_chunk_for_update(rest_client, create_document, f"{scenario_name}.txt") + res = rest_client.patch(f"{base_path}/{chunk_id}", json=payload) + assert res.status_code == 200, (scenario_name, res.text) + body = res.json() + assert body["code"] == expected_code, (scenario_name, body) + if expected_code != 0: + assert body["message"] == expected_message, (scenario_name, body) + + available_cases = [ + ("available true", {"available": True}, 0, ""), + ("available true str", {"available": "True"}, 100, "invalid literal for int()"), + ("available one", {"available": 1}, 0, ""), + ("available false", {"available": False}, 0, ""), + ("available false str", {"available": "False"}, 100, "invalid literal for int()"), + ("available zero", {"available": 0}, 0, ""), + ] + for scenario_name, payload, expected_code, expected_message in available_cases: + _, _, chunk_id, base_path = _create_chunk_for_update(rest_client, create_document, f"{scenario_name}.txt") + res = rest_client.patch(f"{base_path}/{chunk_id}", json=payload) + assert res.status_code == 200, (scenario_name, res.text) + body = res.json() + assert body["code"] == expected_code, (scenario_name, body) + if expected_code != 0: + assert expected_message in body["message"], (scenario_name, body) + + +@pytest.mark.p2 +def test_chunk_update_keywords_questions_and_tag_contract(rest_client, create_document): + _, _, chunk_id, base_path = _create_chunk_for_update(rest_client, create_document, "chunk_update_fields.txt") + cases = [ + ("important keywords", {"important_keywords": ["a", "b", "c"]}, 0, ""), + ("important keywords empty", {"important_keywords": [""]}, 0, ""), + ("important keywords int", {"important_keywords": [1]}, 100, "TypeError"), + ("important keywords dup", {"important_keywords": ["a", "a"]}, 0, ""), + ("important keywords str", {"important_keywords": "abc"}, 102, "`important_keywords` should be a list"), + ("important keywords number", {"important_keywords": 123}, 102, "`important_keywords` should be a list"), + ("questions", {"questions": ["a", "b", "c"]}, 0, ""), + ("questions empty", {"questions": [""]}, 0, ""), + ("questions int", {"questions": [1]}, 100, "TypeError"), + ("questions dup", {"questions": ["a", "a"]}, 0, ""), + ("questions str", {"questions": "abc"}, 102, "`questions` should be a list"), + ("questions number", {"questions": 123}, 102, "`questions` should be a list"), + ("tag kwd", {"tag_kwd": ["tag1", "tag2"]}, 0, ""), + ("tag kwd empty", {"tag_kwd": [""]}, 0, ""), + ("tag kwd int in list", {"tag_kwd": [1]}, 102, "`tag_kwd` must be a list of strings"), + ("tag kwd dup", {"tag_kwd": ["tag", "tag"]}, 0, ""), + ("tag kwd str", {"tag_kwd": "tag"}, 102, "`tag_kwd` should be a list"), + ("tag kwd number", {"tag_kwd": 123}, 102, "`tag_kwd` should be a list"), + ] + for scenario_name, payload, expected_code, expected_message in cases: + res = rest_client.patch(f"{base_path}/{chunk_id}", json=payload) + assert res.status_code == 200, (scenario_name, res.text) + body = res.json() + assert body["code"] == expected_code, (scenario_name, body) + if expected_code != 0: + assert expected_message in body["message"], (scenario_name, body) + + +@pytest.mark.p2 +def test_chunk_update_invalid_target_and_param_contract(rest_client, create_document): + dataset_id, document_id, chunk_id, base_path = _create_chunk_for_update(rest_client, create_document, "chunk_update_invalid_targets.txt") + + invalid_dataset_res = rest_client.patch( + f"/datasets/{INVALID_ID_32}/documents/{document_id}/chunks/{chunk_id}", + json={"content": "updated"}, + ) + assert invalid_dataset_res.status_code == 200 + invalid_dataset_payload = invalid_dataset_res.json() + assert invalid_dataset_payload["code"] == 102, invalid_dataset_payload + assert invalid_dataset_payload["message"] in { + f"You don't own the dataset {INVALID_ID_32}.", + f"Can't find this chunk {chunk_id}", + }, invalid_dataset_payload + + invalid_document_res = rest_client.patch( + f"/datasets/{dataset_id}/documents/{INVALID_ID_32}/chunks/{chunk_id}", + json={"content": "updated"}, + ) + assert invalid_document_res.status_code == 200 + invalid_document_payload = invalid_document_res.json() + assert invalid_document_payload["code"] == 102, invalid_document_payload + assert invalid_document_payload["message"] == f"You don't own the document {INVALID_ID_32}.", invalid_document_payload + + invalid_chunk_res = rest_client.patch( + f"{base_path}/{INVALID_ID_32}", + json={"content": "updated"}, + ) + assert invalid_chunk_res.status_code == 200 + invalid_chunk_payload = invalid_chunk_res.json() + assert invalid_chunk_payload["code"] == 102, invalid_chunk_payload + assert invalid_chunk_payload["message"] == f"Can't find this chunk {INVALID_ID_32}", invalid_chunk_payload + + for scenario_name, payload in ( + ("unknown key", {"unknown_key": "unknown_value"}), + ("empty payload", {}), + ): + res = rest_client.patch(f"{base_path}/{chunk_id}", json=payload) + assert res.status_code == 200, (scenario_name, res.text) + body = res.json() + assert body["code"] == 0, (scenario_name, body) + + +@pytest.mark.p2 +def test_chunk_update_repeated_concurrent_and_deleted_document_contract(rest_client, create_document): + dataset_id, document_id, chunk_id, base_path = _create_chunk_for_update( + rest_client, create_document, "chunk_update_repeated_concurrent_deleted.txt" + ) + + first_res = rest_client.patch(f"{base_path}/{chunk_id}", json={"content": "chunk test 1"}) + assert first_res.status_code == 200 + assert first_res.json()["code"] == 0 + + second_res = rest_client.patch(f"{base_path}/{chunk_id}", json={"content": "chunk test 2"}) + assert second_res.status_code == 200 + assert second_res.json()["code"] == 0 + + get_after_repeat = rest_client.get(f"{base_path}/{chunk_id}") + assert get_after_repeat.status_code == 200 + get_after_repeat_payload = get_after_repeat.json() + assert get_after_repeat_payload["code"] == 0, get_after_repeat_payload + assert get_after_repeat_payload["data"]["content_with_weight"] == "chunk test 2", get_after_repeat_payload + + chunk_ids = [chunk_id] + for index in range(3): + add_res = rest_client.post(base_path, json={"content": f"concurrent update {index}"}) + assert add_res.status_code == 200, add_res.text + add_payload = add_res.json() + assert add_payload["code"] == 0, add_payload + chunk_ids.append(add_payload["data"]["chunk"]["id"]) + + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [] + for index in range(20): + target_id = chunk_ids[index % len(chunk_ids)] + futures.append( + executor.submit( + lambda cid, i: rest_client.patch( + f"{base_path}/{cid}", + json={"content": f"update chunk test {i}"}, + ).json(), + target_id, + index, + ) + ) + results = [future.result() for future in futures] + assert len(results) == 20, results + assert all(item["code"] == 0 for item in results), results + + delete_document_res = rest_client.delete(f"/datasets/{dataset_id}/documents", json={"ids": [document_id]}) + assert delete_document_res.status_code == 200 + assert delete_document_res.json()["code"] == 0 + + update_after_delete = rest_client.patch(f"{base_path}/{chunk_id}", json={"content": "after delete"}) + assert update_after_delete.status_code == 200 + update_after_delete_payload = update_after_delete.json() + assert update_after_delete_payload["code"] == 102, update_after_delete_payload + assert update_after_delete_payload["message"] in { + f"You don't own the document {document_id}.", + f"Can't find this chunk {chunk_id}", + }, update_after_delete_payload diff --git a/test/testcases/restful_api/test_dify_retrieval_routes_unit.py b/test/testcases/restful_api/test_dify_retrieval_routes_unit.py new file mode 100644 index 000000000..3187846a7 --- /dev/null +++ b/test/testcases/restful_api/test_dify_retrieval_routes_unit.py @@ -0,0 +1,425 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import asyncio +import importlib.util +import inspect +import sys +from copy import deepcopy +from pathlib import Path +from types import ModuleType, SimpleNamespace + +import pytest + + +class _DummyManager: + def route(self, *_args, **_kwargs): + def decorator(func): + return func + + return decorator + + +class _AwaitableValue: + def __init__(self, value): + self._value = value + + def __await__(self): + async def _co(): + return self._value + + return _co().__await__() + + +class _DummyKB: + def __init__(self, tenant_id="tenant-1", embd_id="embd-1", tenant_embd_id=1): + self.tenant_id = tenant_id + self.embd_id = embd_id + self.tenant_embd_id = tenant_embd_id + + +class _DummyRetriever: + async def retrieval(self, *_args, **_kwargs): + return { + "chunks": [ + {"doc_id": "doc-1", "content_with_weight": "chunk-content", "similarity": 0.8, "docnm_kwd": "doc-title", "vector": [0.1]} + ] + } + + def retrieval_by_children(self, chunks, _tenant_ids): + return chunks + + +def _run(coro): + return asyncio.run(coro) + + +def _load_dify_retrieval_module(monkeypatch): + repo_root = Path(__file__).resolve().parents[3] + + common_pkg = ModuleType("common") + common_pkg.__path__ = [str(repo_root / "common")] + monkeypatch.setitem(sys.modules, "common", common_pkg) + + deepdoc_pkg = ModuleType("deepdoc") + deepdoc_parser_pkg = ModuleType("deepdoc.parser") + deepdoc_parser_pkg.__path__ = [] + + class _StubPdfParser: + pass + + class _StubExcelParser: + pass + + class _StubDocxParser: + pass + + deepdoc_parser_pkg.PdfParser = _StubPdfParser + deepdoc_parser_pkg.ExcelParser = _StubExcelParser + deepdoc_parser_pkg.DocxParser = _StubDocxParser + deepdoc_pkg.parser = deepdoc_parser_pkg + monkeypatch.setitem(sys.modules, "deepdoc", deepdoc_pkg) + monkeypatch.setitem(sys.modules, "deepdoc.parser", deepdoc_parser_pkg) + + deepdoc_excel_module = ModuleType("deepdoc.parser.excel_parser") + deepdoc_excel_module.RAGFlowExcelParser = _StubExcelParser + monkeypatch.setitem(sys.modules, "deepdoc.parser.excel_parser", deepdoc_excel_module) + + deepdoc_parser_utils = ModuleType("deepdoc.parser.utils") + deepdoc_parser_utils.get_text = lambda *_args, **_kwargs: "" + monkeypatch.setitem(sys.modules, "deepdoc.parser.utils", deepdoc_parser_utils) + monkeypatch.setitem(sys.modules, "xgboost", ModuleType("xgboost")) + + tenant_llm_service_mod = ModuleType("api.db.services.tenant_llm_service") + + class _MockModelConfig: + def __init__(self, tenant_id, model_name): + self.tenant_id = tenant_id + self.llm_name = model_name + self.llm_factory = "Builtin" + self.api_key = "fake-api-key" + self.api_base = "https://api.example.com" + self.model_type = "chat" + self.max_tokens = 8192 + self.used_tokens = 0 + self.status = 1 + self.id = 1 + + def to_dict(self): + return { + "tenant_id": self.tenant_id, + "llm_name": self.llm_name, + "llm_factory": self.llm_factory, + "api_key": self.api_key, + "api_base": self.api_base, + "model_type": self.model_type, + "max_tokens": self.max_tokens, + "used_tokens": self.used_tokens, + "status": self.status, + "id": self.id, + } + + class _StubTenantService: + @staticmethod + def get_by_id(tenant_id): + return True, SimpleNamespace( + id=tenant_id, + llm_id="chat-model", + embd_id="embd-model", + asr_id="asr-model", + img2txt_id="img2txt-model", + rerank_id="rerank-model", + tts_id="tts-model", + ) + + class _StubTenantLLMService: + @staticmethod + def get_api_key(tenant_id, model_name): + return _MockModelConfig(tenant_id, model_name) + + @staticmethod + def split_model_name_and_factory(model_name): + if "@" in model_name: + parts = model_name.split("@") + return parts[0], parts[1] + return model_name, None + + tenant_llm_service_mod.TenantService = _StubTenantService + tenant_llm_service_mod.TenantLLMService = _StubTenantLLMService + + class _StubLLMFactoriesService: + pass + + tenant_llm_service_mod.LLMFactoriesService = _StubLLMFactoriesService + monkeypatch.setitem(sys.modules, "api.db.services.tenant_llm_service", tenant_llm_service_mod) + + llm_service_mod = ModuleType("api.db.services.llm_service") + + class _StubLLM: + def __init__(self, llm_name): + self.llm_name = llm_name + self.is_tools = False + + class _StubLLMBundle: + def __init__(self, tenant_id: str, model_config: dict, lang="Chinese", **kwargs): + self.tenant_id = tenant_id + self.model_config = model_config + self.lang = lang + + def encode(self, texts: list): + import numpy as np + + return [np.array([0.1, 0.2, 0.3]) for _ in texts], len(texts) * 10 + + llm_service_mod.LLMService = SimpleNamespace(query=lambda llm_name: [_StubLLM(llm_name)] if llm_name else []) + llm_service_mod.LLMBundle = _StubLLMBundle + monkeypatch.setitem(sys.modules, "api.db.services.llm_service", llm_service_mod) + + tenant_model_service_mod = ModuleType("api.db.joint_services.tenant_model_service") + + class _MockModelConfig2: + def __init__(self, tenant_id, model_name): + self.tenant_id = tenant_id + self.llm_name = model_name + self.llm_factory = "Builtin" + self.api_key = "fake-api-key" + self.api_base = "https://api.example.com" + self.model_type = "chat" + self.max_tokens = 8192 + self.used_tokens = 0 + self.status = 1 + self.id = 1 + + def to_dict(self): + return { + "tenant_id": self.tenant_id, + "llm_name": self.llm_name, + "llm_factory": self.llm_factory, + "api_key": self.api_key, + "api_base": self.api_base, + "model_type": self.model_type, + "max_tokens": self.max_tokens, + "used_tokens": self.used_tokens, + "status": self.status, + "id": self.id, + } + + def _get_model_config_by_id(tenant_model_id: int, allowed_tenant_ids=None, requester_tenant_id=None) -> dict: + mock_tenant_id = "tenant-1" + if allowed_tenant_ids is not None: + if isinstance(allowed_tenant_ids, str): + allowed_tenant_ids = {allowed_tenant_ids} + else: + allowed_tenant_ids = {str(tenant_id) for tenant_id in allowed_tenant_ids if tenant_id} + if mock_tenant_id not in allowed_tenant_ids and str(requester_tenant_id) != mock_tenant_id: + raise LookupError(f"Tenant Model with id {tenant_model_id} not authorized") + return _MockModelConfig2(mock_tenant_id, "model-1").to_dict() + + def _get_model_config_by_type_and_name(tenant_id: str, model_type: str, model_name: str): + if not model_name: + raise Exception("Model Name is required") + return _MockModelConfig2(tenant_id, model_name).to_dict() + + def _get_tenant_default_model_by_type(tenant_id: str, model_type): + return _MockModelConfig2(tenant_id, "chat-model").to_dict() + + tenant_model_service_mod.get_model_config_by_id = _get_model_config_by_id + tenant_model_service_mod.get_model_config_by_type_and_name = _get_model_config_by_type_and_name + tenant_model_service_mod.get_tenant_default_model_by_type = _get_tenant_default_model_by_type + monkeypatch.setitem(sys.modules, "api.db.joint_services.tenant_model_service", tenant_model_service_mod) + + module_name = "test_dify_retrieval_routes_unit_module" + module_path = repo_root / "api" / "apps" / "sdk" / "dify_retrieval.py" + spec = importlib.util.spec_from_file_location(module_name, module_path) + module = importlib.util.module_from_spec(spec) + module.manager = _DummyManager() + monkeypatch.setitem(sys.modules, module_name, module) + spec.loader.exec_module(module) + return module + + +def _set_request_json(monkeypatch, module, payload): + monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue(deepcopy(payload))) + + +@pytest.mark.p2 +def test_retrieval_success_with_metadata_and_kg(monkeypatch): + module = _load_dify_retrieval_module(monkeypatch) + _set_request_json( + monkeypatch, + module, + { + "knowledge_id": "kb-1", + "query": "hello", + "use_kg": True, + "retrieval_setting": {"score_threshold": 0.1, "top_k": 3}, + "metadata_condition": {"conditions": [{"name": "author", "comparison_operator": "is", "value": "alice"}], "logic": "and"}, + }, + ) + + monkeypatch.setattr(module, "jsonify", lambda payload: payload) + monkeypatch.setattr(module.DocMetadataService, "get_flatted_meta_by_kbs", lambda _kbs: [{"doc_id": "doc-1"}]) + monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, _DummyKB())) + monkeypatch.setattr(module, "convert_conditions", lambda cond: cond.get("conditions", [])) + monkeypatch.setattr(module, "meta_filter", lambda *_args, **_kwargs: []) + + retriever = _DummyRetriever() + monkeypatch.setattr(module.settings, "retriever", retriever) + + class _DummyKgRetriever: + async def retrieval(self, *_args, **_kwargs): + return { + "doc_id": "doc-2", + "content_with_weight": "kg-content", + "similarity": 0.9, + "docnm_kwd": "kg-title", + } + + monkeypatch.setattr(module.settings, "kg_retriever", _DummyKgRetriever()) + monkeypatch.setattr(module.DocumentService, "get_by_id", lambda doc_id: (True, SimpleNamespace(meta_fields={"origin": f"meta-{doc_id}"}))) + monkeypatch.setattr(module, "label_question", lambda *_args, **_kwargs: []) + + res = _run(inspect.unwrap(module.retrieval)("tenant-1")) + assert "records" in res, res + assert len(res["records"]) == 2, res + top = res["records"][0] + assert top["title"] == "kg-title", res + assert top["metadata"]["doc_id"] == "doc-2", res + assert "score" in top, res + + +@pytest.mark.p2 +def test_retrieval_kb_not_found(monkeypatch): + module = _load_dify_retrieval_module(monkeypatch) + _set_request_json(monkeypatch, module, {"knowledge_id": "kb-missing", "query": "hello"}) + monkeypatch.setattr(module.DocMetadataService, "get_flatted_meta_by_kbs", lambda _kbs: []) + monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (False, None)) + + res = _run(inspect.unwrap(module.retrieval)("tenant-1")) + assert res["code"] == module.RetCode.NOT_FOUND, res + assert "Knowledgebase not found" in res["message"], res + + +@pytest.mark.p2 +def test_retrieval_not_found_exception_mapping(monkeypatch): + module = _load_dify_retrieval_module(monkeypatch) + _set_request_json(monkeypatch, module, {"knowledge_id": "kb-1", "query": "hello"}) + monkeypatch.setattr(module.DocMetadataService, "get_flatted_meta_by_kbs", lambda _kbs: []) + monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, _DummyKB())) + monkeypatch.setattr(module, "label_question", lambda *_args, **_kwargs: []) + + class _BrokenRetriever: + async def retrieval(self, *_args, **_kwargs): + raise RuntimeError("chunk_not_found_error") + + monkeypatch.setattr(module.settings, "retriever", _BrokenRetriever()) + + res = _run(inspect.unwrap(module.retrieval)("tenant-1")) + assert res["code"] == module.RetCode.NOT_FOUND, res + assert "No chunk found" in res["message"], res + + +@pytest.mark.p2 +def test_retrieval_generic_exception_mapping(monkeypatch): + module = _load_dify_retrieval_module(monkeypatch) + _set_request_json(monkeypatch, module, {"knowledge_id": "kb-1", "query": "hello"}) + monkeypatch.setattr(module.DocMetadataService, "get_flatted_meta_by_kbs", lambda _kbs: []) + monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, _DummyKB())) + monkeypatch.setattr(module, "label_question", lambda *_args, **_kwargs: []) + + class _BrokenRetriever: + async def retrieval(self, *_args, **_kwargs): + raise RuntimeError("boom") + + monkeypatch.setattr(module.settings, "retriever", _BrokenRetriever()) + + res = _run(inspect.unwrap(module.retrieval)("tenant-1")) + assert res["code"] == module.RetCode.SERVER_ERROR, res + assert "boom" in res["message"], res + + +@pytest.mark.p2 +def test_read_retrieval_request_from_get_args(monkeypatch): + module = _load_dify_retrieval_module(monkeypatch) + monkeypatch.setattr( + module, + "request", + SimpleNamespace( + method="GET", + args={ + "knowledge_id": "kb-1", + "query": "hello", + "use_kg": "true", + "top_k": "12", + "score_threshold": "0.66", + }, + ), + ) + + req = _run(module._read_retrieval_request()) + assert req["knowledge_id"] == "kb-1", req + assert req["query"] == "hello", req + assert req["use_kg"] is True, req + assert req["retrieval_setting"]["top_k"] == 12, req + assert req["retrieval_setting"]["score_threshold"] == 0.66, req + + +@pytest.mark.p2 +def test_read_retrieval_request_from_post_json(monkeypatch): + module = _load_dify_retrieval_module(monkeypatch) + payload = {"knowledge_id": "kb-1", "query": "hello"} + monkeypatch.setattr(module, "request", SimpleNamespace(method="POST", args={})) + monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue(payload)) + + req = _run(module._read_retrieval_request()) + assert req == payload, req + + +@pytest.mark.p2 +def test_retrieval_argument_error_messages(monkeypatch): + module = _load_dify_retrieval_module(monkeypatch) + + _set_request_json( + monkeypatch, + module, + { + "knowledge_id": "kb-1", + "query": "hello", + "retrieval_setting": {"top_k": "not-int", "score_threshold": "not-float"}, + }, + ) + res = _run(inspect.unwrap(module.retrieval)("tenant-1")) + assert res["code"] == module.RetCode.ARGUMENT_ERROR, res + assert "invalid or malformed arguments:" in res["message"], res + + _set_request_json(monkeypatch, module, {}) + res_missing = _run(inspect.unwrap(module.retrieval)("tenant-1")) + assert res_missing["code"] == module.RetCode.ARGUMENT_ERROR, res_missing + assert "required arguments are missing:" in res_missing["message"], res_missing + + _set_request_json(monkeypatch, module, {"knowledge_id": "kb-1"}) + res_missing_query = _run(inspect.unwrap(module.retrieval)("tenant-1")) + assert res_missing_query["code"] == module.RetCode.ARGUMENT_ERROR, res_missing_query + assert "query" in res_missing_query["message"], res_missing_query + + _set_request_json( + monkeypatch, + module, + {"knowledge_id": "kb-1", "query": "hello", "retrieval_setting": "bad-type"}, + ) + res_wrong_type = _run(inspect.unwrap(module.retrieval)("tenant-1")) + assert res_wrong_type["code"] == module.RetCode.ARGUMENT_ERROR, res_wrong_type + assert "retrieval_setting must be an object" in res_wrong_type["message"], res_wrong_type diff --git a/test/testcases/restful_api/test_retrieval.py b/test/testcases/restful_api/test_retrieval.py index bce37c4cd..5f6531a8c 100644 --- a/test/testcases/restful_api/test_retrieval.py +++ b/test/testcases/restful_api/test_retrieval.py @@ -14,7 +14,12 @@ # limitations under the License. # +from concurrent.futures import ThreadPoolExecutor import pytest +import requests +from test.testcases.configs import HOST_ADDRESS, INVALID_API_TOKEN, VERSION +from test.testcases.restful_api.helpers.client import RestClient +from test.testcases.utils import wait_for @pytest.mark.p1 @@ -107,3 +112,264 @@ def test_retrieval_compatibility_requires_auth(rest_client_noauth): # token_required preserves legacy payload code/message while returning HTTP 401. assert payload["code"] == 0, payload assert payload["message"] == "`Authorization` can't be empty", payload + + +@wait_for(20, 1, "Retrieval indexing timeout in RESTful batch 10 tests") +def _retrieval_has_question(rest_client, dataset_id, question): + res = rest_client.post("/retrieval", json={"question": question, "dataset_ids": [dataset_id]}) + if res.status_code != 200: + return False + payload = res.json() + if payload["code"] != 0: + return False + return len(payload["data"]["chunks"]) > 0 + + +@wait_for(20, 1, "Retrieval indexing timeout waiting for chunk presence in RESTful batch 10 tests") +def _retrieval_has_chunks(rest_client, dataset_id, question, chunk_ids): + res = rest_client.post("/retrieval", json={"question": question, "dataset_ids": [dataset_id]}) + if res.status_code != 200: + return False + payload = res.json() + if payload["code"] != 0: + return False + retrieved_ids = {chunk["id"] for chunk in payload["data"]["chunks"]} + expected_ids = set(chunk_ids) + return expected_ids.issubset(retrieved_ids) + + +@wait_for(20, 1, "Retrieval indexing timeout waiting for chunk deletion in RESTful batch 10 tests") +def _retrieval_lacks_chunks(rest_client, dataset_id, question, chunk_ids): + res = rest_client.post("/retrieval", json={"question": question, "dataset_ids": [dataset_id]}) + if res.status_code != 200: + return False + payload = res.json() + if payload["code"] != 0: + return False + retrieved_ids = {chunk["id"] for chunk in payload["data"]["chunks"]} + expected_ids = set(chunk_ids) + return expected_ids.isdisjoint(retrieved_ids) + + +@pytest.mark.p2 +def test_retrieval_requires_auth_contract(ensure_parsed_document): + dataset_id, _ = ensure_parsed_document() + for scenario_name, token, expected_code, expected_message in ( + ("missing token", None, 0, "`Authorization` can't be empty"), + ("invalid token", INVALID_API_TOKEN, 109, "Authentication error: API key is invalid!"), + ): + client = RestClient(token=token) + res = client.post("/retrieval", json={"question": "chunk", "dataset_ids": [dataset_id]}) + assert res.status_code == 401, (scenario_name, res.text) + payload = res.json() + assert payload["code"] == expected_code, (scenario_name, payload) + assert payload["message"] == expected_message, (scenario_name, payload) + + +@pytest.mark.p2 +def test_retrieval_page_and_page_size_contract(rest_client, ensure_parsed_document): + dataset_id, _ = ensure_parsed_document() + cases = [ + ("page none", {"question": "chunk", "dataset_ids": [dataset_id], "page": None, "page_size": 2}, 100, "TypeError"), + ("page zero", {"question": "chunk", "dataset_ids": [dataset_id], "page": 0, "page_size": 2}, 0, ""), + ("page two", {"question": "chunk", "dataset_ids": [dataset_id], "page": 2, "page_size": 2}, 0, ""), + ("page three", {"question": "chunk", "dataset_ids": [dataset_id], "page": 3, "page_size": 2}, 0, ""), + ("page str", {"question": "chunk", "dataset_ids": [dataset_id], "page": "3", "page_size": 2}, 0, ""), + ("page negative", {"question": "chunk", "dataset_ids": [dataset_id], "page": -1, "page_size": 2}, 0, ""), + ("page alpha", {"question": "chunk", "dataset_ids": [dataset_id], "page": "a", "page_size": 2}, 100, "invalid literal for int()"), + ("page_size none", {"question": "chunk", "dataset_ids": [dataset_id], "page_size": None}, 100, "TypeError"), + ("page_size one", {"question": "chunk", "dataset_ids": [dataset_id], "page_size": 1}, 0, ""), + ("page_size five", {"question": "chunk", "dataset_ids": [dataset_id], "page_size": 5}, 0, ""), + ("page_size str", {"question": "chunk", "dataset_ids": [dataset_id], "page_size": "1"}, 0, ""), + ("page_size alpha", {"question": "chunk", "dataset_ids": [dataset_id], "page_size": "a"}, 100, "invalid literal for int()"), + ] + for scenario_name, payload, expected_code, expected_message in cases: + res = rest_client.post("/retrieval", json=payload) + assert res.status_code == 200, (scenario_name, res.text) + body = res.json() + assert body["code"] == expected_code, (scenario_name, body) + if expected_code != 0: + assert expected_message in body["message"], (scenario_name, body) + + +@pytest.mark.p2 +def test_retrieval_highlight_keyword_and_invalid_params_contract(rest_client, ensure_parsed_document): + dataset_id, _ = ensure_parsed_document() + + highlight_cases = [ + ("highlight true", True, True), + ("highlight true str", "True", True), + ("highlight false", False, False), + ("highlight false str", "False", False), + ("highlight none", None, False), + ] + for scenario_name, highlight_value, expect_highlight in highlight_cases: + res = rest_client.post( + "/retrieval", + json={"question": "chunk", "dataset_ids": [dataset_id], "highlight": highlight_value}, + ) + assert res.status_code == 200, (scenario_name, res.text) + body = res.json() + assert body["code"] == 0, (scenario_name, body) + for chunk in body["data"]["chunks"]: + if expect_highlight: + assert "highlight" in chunk, (scenario_name, body) + else: + assert "highlight" not in chunk, (scenario_name, body) + + invalid_highlight = rest_client.post( + "/retrieval", + json={"question": "chunk", "dataset_ids": [dataset_id], "highlight": "not_bool"}, + ) + assert invalid_highlight.status_code == 200 + invalid_highlight_payload = invalid_highlight.json() + assert invalid_highlight_payload["code"] == 102, invalid_highlight_payload + assert invalid_highlight_payload["message"] == "`highlight` should be a boolean", invalid_highlight_payload + + for scenario_name, keyword_value in ( + ("keyword true", True), + ("keyword true str", "True"), + ("keyword false", False), + ("keyword false str", "False"), + ("keyword none", None), + ): + keyword_res = rest_client.post( + "/retrieval", + json={"question": "chunk test", "dataset_ids": [dataset_id], "keyword": keyword_value}, + ) + assert keyword_res.status_code == 200, (scenario_name, keyword_res.text) + keyword_payload = keyword_res.json() + assert keyword_payload["code"] == 0, (scenario_name, keyword_payload) + assert isinstance(keyword_payload["data"]["chunks"], list), (scenario_name, keyword_payload) + + invalid_params_res = rest_client.post( + "/retrieval", + json={"question": "chunk", "dataset_ids": [dataset_id], "a": "b"}, + ) + assert invalid_params_res.status_code == 200 + invalid_params_payload = invalid_params_res.json() + assert invalid_params_payload["code"] == 0, invalid_params_payload + + +@pytest.mark.p2 +def test_retrieval_vector_similarity_and_top_k_contract(rest_client, ensure_parsed_document): + dataset_id, _ = ensure_parsed_document() + cases = [ + ("vector 0", {"vector_similarity_weight": 0}, 0, ""), + ("vector 0.5", {"vector_similarity_weight": 0.5}, 0, ""), + ("vector 10", {"vector_similarity_weight": 10}, 0, ""), + ("vector alpha", {"vector_similarity_weight": "a"}, 100, "could not convert string to float"), + ("top_k 10", {"top_k": 10}, 0, ""), + ("top_k 1", {"top_k": 1}, 0, ""), + ("top_k -1", {"top_k": -1}, 102, "`top_k` must be greater than 0"), + ("top_k alpha", {"top_k": "a"}, 100, "invalid literal for int()"), + ] + for scenario_name, updates, expected_code, expected_message in cases: + payload = {"question": "chunk", "dataset_ids": [dataset_id]} + payload.update(updates) + res = rest_client.post("/retrieval", json=payload) + assert res.status_code == 200, (scenario_name, res.text) + body = res.json() + assert body["code"] == expected_code, (scenario_name, body) + if expected_code != 0: + assert expected_message in body["message"], (scenario_name, body) + + +@pytest.mark.p2 +def test_retrieval_rerank_unknown_contract(rest_client, ensure_parsed_document): + dataset_id, _ = ensure_parsed_document() + res = rest_client.post( + "/retrieval", + json={"question": "chunk", "dataset_ids": [dataset_id], "rerank_id": "unknown"}, + ) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] != 0, payload + assert payload["message"], payload + + +@pytest.mark.p2 +def test_retrieval_concurrent_contract(rest_client, ensure_parsed_document): + dataset_id, _ = ensure_parsed_document() + payload = {"question": "chunk", "dataset_ids": [dataset_id]} + with ThreadPoolExecutor(max_workers=5) as executor: + results = list(executor.map(lambda _: rest_client.post("/retrieval", json=payload).json(), range(20))) + assert len(results) == 20, results + assert all(result["code"] == 0 for result in results), results + + +@pytest.mark.p2 +def test_deleted_chunk_not_in_retrieval_contract(rest_client, create_document): + dataset_id, document_id = create_document("retrieval_deleted_chunk.txt") + base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks" + content = "UNIQUE_TEST_CONTENT_12520_REST" + + add_res = rest_client.post(base_path, json={"content": content}) + assert add_res.status_code == 200 + add_payload = add_res.json() + assert add_payload["code"] == 0, add_payload + chunk_id = add_payload["data"]["chunk"]["id"] + + _retrieval_has_chunks(rest_client, dataset_id, content, [chunk_id]) + + delete_res = rest_client.delete(base_path, json={"chunk_ids": [chunk_id]}) + assert delete_res.status_code == 200 + assert delete_res.json()["code"] == 0 + _retrieval_lacks_chunks(rest_client, dataset_id, content, [chunk_id]) + + +@pytest.mark.p2 +def test_deleted_chunks_batch_not_in_retrieval_contract(rest_client, create_document): + dataset_id, document_id = create_document("retrieval_deleted_chunks_batch.txt") + base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks" + chunk_ids = [] + for index in range(3): + content = f"BATCH_DELETE_TEST_CHUNK_{index}_REST_12520" + add_res = rest_client.post(base_path, json={"content": content}) + assert add_res.status_code == 200 + add_payload = add_res.json() + assert add_payload["code"] == 0, add_payload + chunk_ids.append(add_payload["data"]["chunk"]["id"]) + _retrieval_has_chunks(rest_client, dataset_id, "BATCH_DELETE_TEST_CHUNK", chunk_ids) + + delete_res = rest_client.delete(base_path, json={"chunk_ids": chunk_ids}) + assert delete_res.status_code == 200 + assert delete_res.json()["code"] == 0 + _retrieval_lacks_chunks(rest_client, dataset_id, "BATCH_DELETE_TEST_CHUNK", chunk_ids) + + +@pytest.mark.p2 +def test_related_questions_contract(auth, rest_client, rest_client_noauth): + tokens_res = requests.get( + f"{HOST_ADDRESS}/api/{VERSION}/system/tokens", + headers={"Authorization": auth}, + timeout=30, + ) + assert tokens_res.status_code == 200, tokens_res.text + tokens_payload = tokens_res.json() + assert tokens_payload["code"] == 0, tokens_payload + assert tokens_payload["data"], tokens_payload + beta_token = tokens_payload["data"][0]["beta"] + + success_client = RestClient(token=beta_token) + success_res = success_client.post("/searchbots/related_questions", json={"question": "ragflow", "industry": "search"}) + assert success_res.status_code == 200 + success_payload = success_res.json() + assert success_payload["code"] == 0, success_payload + assert isinstance(success_payload["data"], list), success_payload + + missing_res = rest_client.post("/searchbots/related_questions", json={"industry": "search"}) + assert missing_res.status_code == 200 + missing_payload = missing_res.json() + assert missing_payload["code"] == 101, missing_payload + assert "question" in missing_payload["message"], missing_payload + + invalid_auth_res = rest_client_noauth.post( + "/searchbots/related_questions", + json={"question": "ragflow", "industry": "search"}, + headers={"Authorization": "invalid"}, + ) + assert invalid_auth_res.status_code == 200 + invalid_auth_payload = invalid_auth_res.json() + assert invalid_auth_payload["code"] == 102, invalid_auth_payload + assert "Authorization is not valid!" in invalid_auth_payload["message"], invalid_auth_payload