mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-05-30 04:27:30 +08:00
Feat: add new tests and tescases for restful api suite (#15038)
### What problem does this PR solve? extend restful api suite ### Type of change - [x] New Feature (non-breaking change which adds functionality) - [x] Other (please describe): test
This commit is contained in:
@ -14,7 +14,12 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import os
|
||||
import pytest
|
||||
from test.testcases.configs import INVALID_API_TOKEN, INVALID_ID_32
|
||||
from test.testcases.restful_api.helpers.client import RestClient
|
||||
from test.testcases.utils import wait_for
|
||||
|
||||
|
||||
def _assert_created_chunk_id(payload):
|
||||
@ -25,6 +30,41 @@ def _assert_created_chunk_id(payload):
|
||||
return chunk_id
|
||||
|
||||
|
||||
@wait_for(10, 1, "Chunk indexing timeout in RESTful batch 09 tests")
|
||||
def _chunk_count(rest_client, base_path, expected_total):
|
||||
res = rest_client.get(base_path)
|
||||
if res.status_code != 200:
|
||||
return False
|
||||
payload = res.json()
|
||||
if payload["code"] != 0:
|
||||
return False
|
||||
return payload["data"]["total"] == expected_total and len(payload["data"]["chunks"]) == min(expected_total, 30)
|
||||
|
||||
|
||||
def _reset_chunk_batch(rest_client, base_path, count=4):
|
||||
cleanup_res = rest_client.delete(base_path, json={"chunk_ids": None, "delete_all": True})
|
||||
assert cleanup_res.status_code == 200, cleanup_res.text
|
||||
cleanup_payload = cleanup_res.json()
|
||||
assert cleanup_payload["code"] == 0, cleanup_payload
|
||||
|
||||
baseline_res = rest_client.post(base_path, json={"content": "ragflow test upload"})
|
||||
assert baseline_res.status_code == 200, baseline_res.text
|
||||
baseline_payload = baseline_res.json()
|
||||
assert baseline_payload["code"] == 0, baseline_payload
|
||||
baseline_id = _assert_created_chunk_id(baseline_payload)
|
||||
|
||||
chunk_ids = []
|
||||
for index in range(count):
|
||||
res = rest_client.post(base_path, json={"content": f"chunk test {index}"})
|
||||
assert res.status_code == 200, (index, res.text)
|
||||
payload = res.json()
|
||||
assert payload["code"] == 0, (index, payload)
|
||||
chunk_ids.append(_assert_created_chunk_id(payload))
|
||||
|
||||
_chunk_count(rest_client, base_path, count + 1)
|
||||
return baseline_id, chunk_ids
|
||||
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_chunks_add_list_get_update_delete_cycle(rest_client, create_document):
|
||||
dataset_id, document_id = create_document("chunk_cycle.txt")
|
||||
@ -88,6 +128,42 @@ def test_chunks_add_list_get_update_delete_cycle(rest_client, create_document):
|
||||
assert deleted_get_payload["code"] != 0, deleted_get_payload
|
||||
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_chunk_add_requires_auth(create_document):
|
||||
dataset_id, document_id = create_document("chunk_add_auth.txt")
|
||||
path = f"/datasets/{dataset_id}/documents/{document_id}/chunks"
|
||||
for scenario_name, client in (("missing token", RestClient(token=None)), ("invalid token", RestClient(token=INVALID_API_TOKEN))):
|
||||
res = client.post(path, json={"content": "chunk test"})
|
||||
assert res.status_code == 401, (scenario_name, res.text)
|
||||
payload = res.json()
|
||||
assert payload["code"] == 401, (scenario_name, payload)
|
||||
assert payload["message"] == "<Unauthorized '401: Unauthorized'>", (scenario_name, payload)
|
||||
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_chunk_delete_requires_auth(create_document):
|
||||
dataset_id, document_id = create_document("chunk_delete_auth.txt")
|
||||
path = f"/datasets/{dataset_id}/documents/{document_id}/chunks"
|
||||
for scenario_name, client in (("missing token", RestClient(token=None)), ("invalid token", RestClient(token=INVALID_API_TOKEN))):
|
||||
res = client.delete(path, json={"chunk_ids": []})
|
||||
assert res.status_code == 401, (scenario_name, res.text)
|
||||
payload = res.json()
|
||||
assert payload["code"] == 401, (scenario_name, payload)
|
||||
assert payload["message"] == "<Unauthorized '401: Unauthorized'>", (scenario_name, payload)
|
||||
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_chunk_list_requires_auth(create_document):
|
||||
dataset_id, document_id = create_document("chunk_list_auth.txt")
|
||||
path = f"/datasets/{dataset_id}/documents/{document_id}/chunks"
|
||||
for scenario_name, client in (("missing token", RestClient(token=None)), ("invalid token", RestClient(token=INVALID_API_TOKEN))):
|
||||
res = client.get(path)
|
||||
assert res.status_code == 401, (scenario_name, res.text)
|
||||
payload = res.json()
|
||||
assert payload["code"] == 401, (scenario_name, payload)
|
||||
assert payload["message"] == "<Unauthorized '401: Unauthorized'>", (scenario_name, payload)
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_chunks_add_requires_content(rest_client, create_document):
|
||||
dataset_id, document_id = create_document("chunk_requires_content.txt")
|
||||
@ -101,6 +177,156 @@ def test_chunks_add_requires_content(rest_client, create_document):
|
||||
assert payload["message"] == "`content` is required", payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_chunk_add_keyword_question_and_tag_contract(rest_client, create_document):
|
||||
add_cases = [
|
||||
(
|
||||
"important keywords",
|
||||
[
|
||||
({"content": "chunk test", "important_keywords": ["a", "b", "c"]}, 0, ""),
|
||||
({"content": "chunk test", "important_keywords": [""]}, 0, ""),
|
||||
({"content": "chunk test", "important_keywords": [1]}, 100, "TypeError('sequence item 0: expected str instance, int found')"),
|
||||
({"content": "chunk test", "important_keywords": ["a", "a"]}, 0, ""),
|
||||
({"content": "chunk test", "important_keywords": "abc"}, 102, "`important_keywords` is required to be a list"),
|
||||
({"content": "chunk test", "important_keywords": 123}, 102, "`important_keywords` is required to be a list"),
|
||||
],
|
||||
),
|
||||
(
|
||||
"questions",
|
||||
[
|
||||
({"content": "chunk test", "questions": ["a", "b", "c"]}, 0, ""),
|
||||
({"content": "chunk test", "questions": [""]}, 0, ""),
|
||||
({"content": "chunk test", "questions": [1]}, 100, "TypeError('sequence item 0: expected str instance, int found')"),
|
||||
({"content": "chunk test", "questions": ["a", "a"]}, 0, ""),
|
||||
({"content": "chunk test", "questions": "abc"}, 102, "`questions` is required to be a list"),
|
||||
({"content": "chunk test", "questions": 123}, 102, "`questions` is required to be a list"),
|
||||
],
|
||||
),
|
||||
(
|
||||
"tag_kwd",
|
||||
[
|
||||
({"content": "chunk test", "tag_kwd": ["tag1", "tag2"]}, 0, ""),
|
||||
({"content": "chunk test", "tag_kwd": [""]}, 0, ""),
|
||||
({"content": "chunk test", "tag_kwd": [1]}, 102, "`tag_kwd` must be a list of strings"),
|
||||
({"content": "chunk test", "tag_kwd": ["tag", "tag"]}, 0, ""),
|
||||
({"content": "chunk test", "tag_kwd": "abc"}, 102, "`tag_kwd` is required to be a list"),
|
||||
({"content": "chunk test", "tag_kwd": 123}, 102, "`tag_kwd` is required to be a list"),
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
for group_index, (group_name, cases) in enumerate(add_cases):
|
||||
dataset_id, document_id = create_document(f"chunk_add_contracts_{group_index}.txt")
|
||||
base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks"
|
||||
for scenario_index, (payload, expected_code, expected_message) in enumerate(cases):
|
||||
scenario_name = f"{group_name}-{scenario_index}"
|
||||
before_payload = rest_client.get(base_path).json()
|
||||
assert before_payload["code"] == 0, (scenario_name, before_payload)
|
||||
before_total = before_payload["data"]["doc"]["chunk_count"]
|
||||
|
||||
res = rest_client.post(base_path, json=payload)
|
||||
assert res.status_code == 200, (scenario_name, res.text)
|
||||
body = res.json()
|
||||
assert body["code"] == expected_code, (scenario_name, body)
|
||||
if expected_code == 0:
|
||||
chunk = body["data"]["chunk"]
|
||||
assert chunk["dataset_id"] == dataset_id, (scenario_name, body)
|
||||
assert chunk["document_id"] == document_id, (scenario_name, body)
|
||||
assert chunk["content"] == payload["content"], (scenario_name, body)
|
||||
if "important_keywords" in payload:
|
||||
assert chunk["important_keywords"] == payload["important_keywords"], (scenario_name, body)
|
||||
if "questions" in payload:
|
||||
assert chunk["questions"] == [str(q).strip() for q in payload["questions"] if str(q).strip()], (scenario_name, body)
|
||||
if "tag_kwd" in payload:
|
||||
assert chunk["tag_kwd"] == payload["tag_kwd"], (scenario_name, body)
|
||||
after_payload = rest_client.get(base_path).json()
|
||||
assert after_payload["code"] == 0, (scenario_name, after_payload)
|
||||
assert after_payload["data"]["doc"]["chunk_count"] == before_total + 1, (scenario_name, after_payload)
|
||||
else:
|
||||
assert body["message"] == expected_message, (scenario_name, body)
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_chunk_add_invalid_dataset_and_document_contract(rest_client, create_document):
|
||||
dataset_id, document_id = create_document("chunk_invalid_targets.txt")
|
||||
|
||||
invalid_dataset_res = rest_client.post(
|
||||
f"/datasets/{INVALID_ID_32}/documents/{document_id}/chunks",
|
||||
json={"content": "chunk test"},
|
||||
)
|
||||
assert invalid_dataset_res.status_code == 200
|
||||
invalid_dataset_payload = invalid_dataset_res.json()
|
||||
assert invalid_dataset_payload["code"] == 102, invalid_dataset_payload
|
||||
assert invalid_dataset_payload["message"] == f"You don't own the dataset {INVALID_ID_32}.", invalid_dataset_payload
|
||||
|
||||
invalid_document_res = rest_client.post(
|
||||
f"/datasets/{dataset_id}/documents/{INVALID_ID_32}/chunks",
|
||||
json={"content": "chunk test"},
|
||||
)
|
||||
assert invalid_document_res.status_code == 200
|
||||
invalid_document_payload = invalid_document_res.json()
|
||||
assert invalid_document_payload["code"] == 102, invalid_document_payload
|
||||
assert invalid_document_payload["message"] == f"You don't own the document {INVALID_ID_32}.", invalid_document_payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_chunk_add_repeated_and_deleted_document_contract(rest_client, create_document):
|
||||
dataset_id, document_id = create_document("chunk_repeat_deleted.txt")
|
||||
base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks"
|
||||
|
||||
first_payload = rest_client.get(base_path).json()
|
||||
assert first_payload["code"] == 0, first_payload
|
||||
initial_count = first_payload["data"]["doc"]["chunk_count"]
|
||||
|
||||
first_add_res = rest_client.post(base_path, json={"content": "chunk test"})
|
||||
second_add_res = rest_client.post(base_path, json={"content": "chunk test"})
|
||||
first_add_payload = first_add_res.json()
|
||||
second_add_payload = second_add_res.json()
|
||||
assert first_add_payload["code"] == 0, first_add_payload
|
||||
assert second_add_payload["code"] == 0, second_add_payload
|
||||
assert first_add_payload["data"]["chunk"]["id"] == second_add_payload["data"]["chunk"]["id"], (first_add_payload, second_add_payload)
|
||||
|
||||
repeated_list_payload = rest_client.get(base_path).json()
|
||||
assert repeated_list_payload["code"] == 0, repeated_list_payload
|
||||
assert repeated_list_payload["data"]["doc"]["chunk_count"] == initial_count + 2, repeated_list_payload
|
||||
assert repeated_list_payload["data"]["total"] == 1, repeated_list_payload
|
||||
|
||||
delete_document_res = rest_client.delete(f"/datasets/{dataset_id}/documents", json={"ids": [document_id]})
|
||||
assert delete_document_res.status_code == 200
|
||||
delete_document_payload = delete_document_res.json()
|
||||
assert delete_document_payload["code"] == 0, delete_document_payload
|
||||
|
||||
add_after_delete_res = rest_client.post(base_path, json={"content": "chunk test"})
|
||||
assert add_after_delete_res.status_code == 200
|
||||
add_after_delete_payload = add_after_delete_res.json()
|
||||
assert add_after_delete_payload["code"] == 102, add_after_delete_payload
|
||||
assert add_after_delete_payload["message"] == f"You don't own the document {document_id}.", add_after_delete_payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
@pytest.mark.parametrize("count", [20])
|
||||
def test_chunk_concurrent_add_contract(rest_client, create_document, count):
|
||||
dataset_id, document_id = create_document("chunk_concurrent_add.txt")
|
||||
base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks"
|
||||
baseline_payload = rest_client.get(base_path).json()
|
||||
assert baseline_payload["code"] == 0, baseline_payload
|
||||
initial_count = baseline_payload["data"]["doc"]["chunk_count"]
|
||||
|
||||
with ThreadPoolExecutor(max_workers=5) as executor:
|
||||
results = list(
|
||||
executor.map(
|
||||
lambda index: rest_client.post(base_path, json={"content": f"chunk test {index}"}).json(),
|
||||
range(count),
|
||||
)
|
||||
)
|
||||
assert len(results) == count, results
|
||||
assert all(result["code"] == 0 for result in results), results
|
||||
|
||||
final_payload = rest_client.get(base_path).json()
|
||||
assert final_payload["code"] == 0, final_payload
|
||||
assert final_payload["data"]["doc"]["chunk_count"] == initial_count + count, final_payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_chunks_list_empty_document(rest_client, create_document):
|
||||
dataset_id, document_id = create_document("chunk_list_empty.txt")
|
||||
@ -114,11 +340,478 @@ def test_chunks_list_empty_document(rest_client, create_document):
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_chunks_delete_partial_invalid(rest_client, create_document):
|
||||
dataset_id, document_id = create_document("chunk_delete_partial.txt")
|
||||
def test_chunk_delete_basic_contract(rest_client, create_document):
|
||||
dataset_id, document_id = create_document("chunk_delete_basic.txt")
|
||||
base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks"
|
||||
delete_res = rest_client.delete(base_path, json={"chunk_ids": ["invalid_chunk_id"]})
|
||||
assert delete_res.status_code == 200
|
||||
delete_payload = delete_res.json()
|
||||
assert delete_payload["code"] == 102, delete_payload
|
||||
assert "expect 1" in delete_payload["message"], delete_payload
|
||||
|
||||
cases = [
|
||||
("none payload", None, 0, "", 5),
|
||||
("invalid only", {"chunk_ids": ["invalid_id"]}, 102, "rm_chunk deleted chunks 0, expect 1", 5),
|
||||
("delete first", lambda ids: {"chunk_ids": ids[:1]}, 0, "", 4),
|
||||
("delete generated", lambda ids: {"chunk_ids": ids}, 0, "", 1),
|
||||
("empty ids", {"chunk_ids": []}, 0, "", 5),
|
||||
]
|
||||
|
||||
for scenario_name, payload, expected_code, expected_message, remaining in cases:
|
||||
_reset_chunk_batch(rest_client, base_path)
|
||||
request_body = payload
|
||||
generated_ids = rest_client.get(base_path).json()["data"]["chunks"][1:]
|
||||
generated_ids = [chunk["id"] for chunk in generated_ids]
|
||||
if callable(payload):
|
||||
request_body = payload(generated_ids)
|
||||
res = rest_client.delete(base_path, json=request_body)
|
||||
assert res.status_code == 200, (scenario_name, res.text)
|
||||
body = res.json()
|
||||
assert body["code"] == expected_code, (scenario_name, body)
|
||||
if expected_message:
|
||||
assert body.get("message", "") == expected_message, (scenario_name, body)
|
||||
|
||||
list_payload = rest_client.get(base_path).json()
|
||||
assert list_payload["code"] == 0, (scenario_name, list_payload)
|
||||
assert len(list_payload["data"]["chunks"]) == remaining, (scenario_name, list_payload)
|
||||
assert list_payload["data"]["total"] == remaining, (scenario_name, list_payload)
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_chunk_delete_partial_duplicate_repeat_and_invalid_target_contract(rest_client, create_document):
|
||||
dataset_id, document_id = create_document("chunk_delete_detail.txt")
|
||||
base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks"
|
||||
|
||||
for scenario_name, payload_builder in (
|
||||
("invalid first", lambda ids: {"chunk_ids": ["invalid_id"] + ids}),
|
||||
("invalid middle", lambda ids: {"chunk_ids": ids[:1] + ["invalid_id"] + ids[1:]}),
|
||||
("invalid last", lambda ids: {"chunk_ids": ids + ["invalid_id"]}),
|
||||
):
|
||||
_, generated_ids = _reset_chunk_batch(rest_client, base_path)
|
||||
res = rest_client.delete(base_path, json=payload_builder(generated_ids))
|
||||
assert res.status_code == 200, (scenario_name, res.text)
|
||||
body = res.json()
|
||||
assert body["code"] == 102, (scenario_name, body)
|
||||
assert body["message"] == "rm_chunk deleted chunks 4, expect 5", (scenario_name, body)
|
||||
|
||||
list_payload = rest_client.get(base_path).json()
|
||||
assert list_payload["code"] == 0, (scenario_name, list_payload)
|
||||
assert list_payload["data"]["total"] == 1, (scenario_name, list_payload)
|
||||
|
||||
_, generated_ids = _reset_chunk_batch(rest_client, base_path)
|
||||
duplicate_res = rest_client.delete(base_path, json={"chunk_ids": generated_ids * 2})
|
||||
assert duplicate_res.status_code == 200
|
||||
duplicate_payload = duplicate_res.json()
|
||||
assert duplicate_payload["code"] == 0, duplicate_payload
|
||||
assert duplicate_payload["data"]["success_count"] == 4, duplicate_payload
|
||||
assert len(duplicate_payload["data"]["errors"]) == 4, duplicate_payload
|
||||
assert all(error.startswith("Duplicate chunk ids: ") for error in duplicate_payload["data"]["errors"]), duplicate_payload
|
||||
duplicate_list_payload = rest_client.get(base_path).json()
|
||||
assert duplicate_list_payload["code"] == 0, duplicate_list_payload
|
||||
assert duplicate_list_payload["data"]["total"] == 1, duplicate_list_payload
|
||||
|
||||
_, generated_ids = _reset_chunk_batch(rest_client, base_path)
|
||||
first_delete_res = rest_client.delete(base_path, json={"chunk_ids": generated_ids})
|
||||
assert first_delete_res.status_code == 200
|
||||
assert first_delete_res.json()["code"] == 0
|
||||
second_delete_res = rest_client.delete(base_path, json={"chunk_ids": generated_ids})
|
||||
assert second_delete_res.status_code == 200
|
||||
second_delete_payload = second_delete_res.json()
|
||||
assert second_delete_payload["code"] == 102, second_delete_payload
|
||||
assert second_delete_payload["message"] == "rm_chunk deleted chunks 0, expect 4", second_delete_payload
|
||||
|
||||
invalid_dataset_res = rest_client.delete(
|
||||
f"/datasets/{INVALID_ID_32}/documents/{document_id}/chunks",
|
||||
json={"chunk_ids": ["chunk-id"]},
|
||||
)
|
||||
assert invalid_dataset_res.status_code == 200
|
||||
invalid_dataset_payload = invalid_dataset_res.json()
|
||||
assert invalid_dataset_payload["code"] == 102, invalid_dataset_payload
|
||||
assert invalid_dataset_payload["message"] == f"You don't own the dataset {INVALID_ID_32}.", invalid_dataset_payload
|
||||
|
||||
invalid_document_res = rest_client.delete(
|
||||
f"/datasets/{dataset_id}/documents/{INVALID_ID_32}/chunks",
|
||||
json={"chunk_ids": ["chunk-id"]},
|
||||
)
|
||||
assert invalid_document_res.status_code == 200
|
||||
invalid_document_payload = invalid_document_res.json()
|
||||
assert invalid_document_payload["code"] == 102, invalid_document_payload
|
||||
assert invalid_document_payload["message"] == f"You don't own the document {INVALID_ID_32}.", invalid_document_payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_chunk_delete_web_legacy_basic_variants(rest_client, create_document):
|
||||
dataset_id, document_id = create_document("chunk_delete_web_legacy_again.txt")
|
||||
base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks"
|
||||
cases = [
|
||||
("web invalid id", {"chunk_ids": ["invalid_id"]}, 102, 5),
|
||||
("web delete first", lambda ids: {"chunk_ids": ids[:1]}, 0, 4),
|
||||
("web delete generated", lambda ids: {"chunk_ids": ids}, 0, 1),
|
||||
("web empty ids", {"chunk_ids": []}, 0, 5),
|
||||
]
|
||||
for scenario_name, payload, expected_code, remaining in cases:
|
||||
_, generated_ids = _reset_chunk_batch(rest_client, base_path)
|
||||
request_body = payload(generated_ids) if callable(payload) else payload
|
||||
res = rest_client.delete(base_path, json=request_body)
|
||||
assert res.status_code == 200, (scenario_name, res.text)
|
||||
body = res.json()
|
||||
assert body["code"] == expected_code, (scenario_name, body)
|
||||
list_payload = rest_client.get(base_path).json()
|
||||
assert list_payload["code"] == 0, (scenario_name, list_payload)
|
||||
assert list_payload["data"]["total"] == remaining, (scenario_name, list_payload)
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_chunk_delete_concurrent_and_bulk_contract(rest_client, create_document):
|
||||
dataset_id, document_id = create_document("chunk_delete_bulk_contract.txt")
|
||||
base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks"
|
||||
|
||||
rest_client.delete(base_path, json={"chunk_ids": None, "delete_all": True})
|
||||
for index in range(12):
|
||||
payload = rest_client.post(base_path, json={"content": f"chunk test {index}"}).json()
|
||||
assert payload["code"] == 0, payload
|
||||
ids_payload = rest_client.get(base_path).json()
|
||||
assert ids_payload["code"] == 0, ids_payload
|
||||
chunk_ids = [chunk["id"] for chunk in ids_payload["data"]["chunks"]]
|
||||
|
||||
with ThreadPoolExecutor(max_workers=5) as executor:
|
||||
results = list(executor.map(lambda chunk_id: rest_client.delete(base_path, json={"chunk_ids": [chunk_id]}).json(), chunk_ids))
|
||||
assert len(results) == len(chunk_ids), results
|
||||
assert all(result["code"] == 0 for result in results), results
|
||||
|
||||
final_payload = rest_client.get(base_path).json()
|
||||
assert final_payload["code"] == 0, final_payload
|
||||
assert final_payload["data"]["total"] == 0, final_payload
|
||||
|
||||
rest_client.delete(base_path, json={"chunk_ids": None, "delete_all": True})
|
||||
for index in range(40):
|
||||
payload = rest_client.post(base_path, json={"content": f"bulk chunk {index}"}).json()
|
||||
assert payload["code"] == 0, payload
|
||||
bulk_ids_payload = rest_client.get(base_path, params={"page_size": 200}).json()
|
||||
assert bulk_ids_payload["code"] == 0, bulk_ids_payload
|
||||
bulk_ids = [chunk["id"] for chunk in bulk_ids_payload["data"]["chunks"]]
|
||||
bulk_res = rest_client.delete(base_path, json={"chunk_ids": bulk_ids})
|
||||
assert bulk_res.status_code == 200
|
||||
bulk_payload = bulk_res.json()
|
||||
assert bulk_payload["code"] == 0, bulk_payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_chunk_list_default_get_id_and_invalid_target_contract(rest_client, create_document):
|
||||
dataset_id, document_id = create_document("chunk_list_core.txt")
|
||||
base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks"
|
||||
baseline_id, generated_ids = _reset_chunk_batch(rest_client, base_path)
|
||||
|
||||
default_res = rest_client.get(base_path)
|
||||
assert default_res.status_code == 200
|
||||
default_payload = default_res.json()
|
||||
assert default_payload["code"] == 0, default_payload
|
||||
assert default_payload["data"]["total"] == 5, default_payload
|
||||
assert len(default_payload["data"]["chunks"]) == 5, default_payload
|
||||
|
||||
get_res = rest_client.get(f"{base_path}/{generated_ids[0]}")
|
||||
assert get_res.status_code == 200
|
||||
get_payload = get_res.json()
|
||||
assert get_payload["code"] == 0, get_payload
|
||||
assert get_payload["data"]["id"] == generated_ids[0], get_payload
|
||||
assert get_payload["data"]["doc_id"] == document_id, get_payload
|
||||
|
||||
invalid_get_res = rest_client.get(f"{base_path}/unknown")
|
||||
assert invalid_get_res.status_code == 200
|
||||
invalid_get_payload = invalid_get_res.json()
|
||||
assert invalid_get_payload["code"] == 102, invalid_get_payload
|
||||
assert invalid_get_payload["message"] == "Chunk not found!", invalid_get_payload
|
||||
|
||||
id_cases = [
|
||||
("id none", {"id": None}, 0, 5, None),
|
||||
("id empty", {"id": ""}, 0, 5, None),
|
||||
("id valid", {"id": generated_ids[0]}, 0, 1, generated_ids[0]),
|
||||
("id invalid", {"id": "unknown"}, 102, None, None),
|
||||
]
|
||||
for scenario_name, params, expected_code, expected_total, expected_id in id_cases:
|
||||
res = rest_client.get(base_path, params=params)
|
||||
assert res.status_code == 200, (scenario_name, res.text)
|
||||
payload = res.json()
|
||||
assert payload["code"] == expected_code, (scenario_name, payload)
|
||||
if expected_code == 0:
|
||||
assert payload["data"]["total"] == expected_total, (scenario_name, payload)
|
||||
if expected_id is not None:
|
||||
assert payload["data"]["chunks"][0]["id"] == expected_id, (scenario_name, payload)
|
||||
else:
|
||||
assert payload["message"] == f"Chunk not found: {dataset_id}/unknown", (scenario_name, payload)
|
||||
|
||||
invalid_dataset_res = rest_client.get(f"/datasets/{INVALID_ID_32}/documents/{document_id}/chunks")
|
||||
assert invalid_dataset_res.status_code == 200
|
||||
invalid_dataset_payload = invalid_dataset_res.json()
|
||||
assert invalid_dataset_payload["code"] == 102, invalid_dataset_payload
|
||||
assert invalid_dataset_payload["message"] == f"You don't own the dataset {INVALID_ID_32}.", invalid_dataset_payload
|
||||
|
||||
invalid_document_res = rest_client.get(f"/datasets/{dataset_id}/documents/{INVALID_ID_32}/chunks")
|
||||
assert invalid_document_res.status_code == 200
|
||||
invalid_document_payload = invalid_document_res.json()
|
||||
assert invalid_document_payload["code"] == 102, invalid_document_payload
|
||||
assert invalid_document_payload["message"] == f"You don't own the document {INVALID_ID_32}.", invalid_document_payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
@pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="infinity")
|
||||
def test_chunk_list_keyword_and_invalid_param_contract(rest_client, create_document):
|
||||
dataset_id, document_id = create_document("chunk_list_keywords.txt")
|
||||
base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks"
|
||||
_reset_chunk_batch(rest_client, base_path)
|
||||
|
||||
cases = [
|
||||
("keywords none", {"keywords": None}, 5),
|
||||
("keywords empty", {"keywords": ""}, 5),
|
||||
("keywords exact one", {"keywords": "1"}, 1),
|
||||
("keywords chunk", {"keywords": "chunk"}, 4),
|
||||
("keywords ragflow", {"keywords": "ragflow"}, 1),
|
||||
("keywords unknown", {"keywords": "unknown"}, 0),
|
||||
("invalid params ignored", {"a": "b"}, 5),
|
||||
]
|
||||
|
||||
for scenario_name, params, expected_total in cases:
|
||||
res = rest_client.get(base_path, params=params)
|
||||
assert res.status_code == 200, (scenario_name, res.text)
|
||||
payload = res.json()
|
||||
assert payload["code"] == 0, (scenario_name, payload)
|
||||
assert payload["data"]["total"] == expected_total, (scenario_name, payload)
|
||||
assert len(payload["data"]["chunks"]) == expected_total, (scenario_name, payload)
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
@pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="infinity")
|
||||
def test_chunk_list_page_and_page_size_contract(rest_client, create_document):
|
||||
dataset_id, document_id = create_document("chunk_list_paging.txt")
|
||||
base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks"
|
||||
_reset_chunk_batch(rest_client, base_path)
|
||||
|
||||
cases = [
|
||||
("page none", {"page": None, "page_size": 2}, 0, 2, ""),
|
||||
("page zero", {"page": 0, "page_size": 2}, 100, None, "ValueError('Search does not support negative slicing.')"),
|
||||
("page two", {"page": 2, "page_size": 2}, 0, 2, ""),
|
||||
("page three", {"page": 3, "page_size": 2}, 0, 1, ""),
|
||||
("page string", {"page": "3", "page_size": 2}, 0, 1, ""),
|
||||
("page negative", {"page": -1, "page_size": 2}, 100, None, "ValueError('Search does not support negative slicing.')"),
|
||||
("page alpha", {"page": "a", "page_size": 2}, 100, None, "ValueError(\"invalid literal for int() with base 10: 'a'\")"),
|
||||
("page_size none", {"page_size": None}, 0, 5, ""),
|
||||
("page_size zero", {"page_size": 0}, 0, 5, ""),
|
||||
("page_size one", {"page_size": 1}, 0, 1, ""),
|
||||
("page_size six", {"page_size": 6}, 0, 5, ""),
|
||||
("page_size string", {"page_size": "1"}, 0, 1, ""),
|
||||
("page_size negative", {"page_size": -1}, 0, 5, ""),
|
||||
("page_size alpha", {"page_size": "a"}, 100, None, "ValueError(\"invalid literal for int() with base 10: 'a'\")"),
|
||||
]
|
||||
|
||||
for scenario_name, params, expected_code, expected_total, expected_message in cases:
|
||||
res = rest_client.get(base_path, params=params)
|
||||
assert res.status_code == 200, (scenario_name, res.text)
|
||||
payload = res.json()
|
||||
assert payload["code"] == expected_code, (scenario_name, payload)
|
||||
if expected_code == 0:
|
||||
assert payload["data"]["total"] == 5, (scenario_name, payload)
|
||||
assert len(payload["data"]["chunks"]) == expected_total, (scenario_name, payload)
|
||||
else:
|
||||
assert expected_message in payload["message"], (scenario_name, payload)
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_chunk_list_concurrent_contract(rest_client, create_document):
|
||||
dataset_id, document_id = create_document("chunk_list_concurrent.txt")
|
||||
base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks"
|
||||
_reset_chunk_batch(rest_client, base_path)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=5) as executor:
|
||||
results = list(executor.map(lambda _: rest_client.get(base_path).json(), range(20)))
|
||||
assert len(results) == 20, results
|
||||
assert all(result["code"] == 0 for result in results), results
|
||||
assert all(result["data"]["total"] == 5 for result in results), results
|
||||
|
||||
|
||||
def _create_chunk_for_update(rest_client, create_document, file_name):
|
||||
dataset_id, document_id = create_document(file_name)
|
||||
base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks"
|
||||
add_res = rest_client.post(base_path, json={"content": "chunk update test"})
|
||||
assert add_res.status_code == 200, add_res.text
|
||||
add_payload = add_res.json()
|
||||
assert add_payload["code"] == 0, add_payload
|
||||
chunk_id = add_payload["data"]["chunk"]["id"]
|
||||
return dataset_id, document_id, chunk_id, base_path
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_chunk_update_requires_auth(rest_client, create_document):
|
||||
_, _, chunk_id, base_path = _create_chunk_for_update(rest_client, create_document, "chunk_update_auth.txt")
|
||||
for scenario_name, client in (("missing token", RestClient(token=None)), ("invalid token", RestClient(token=INVALID_API_TOKEN))):
|
||||
res = client.patch(f"{base_path}/{chunk_id}", json={"content": "updated"})
|
||||
assert res.status_code == 401, (scenario_name, res.text)
|
||||
payload = res.json()
|
||||
assert payload["code"] == 401, (scenario_name, payload)
|
||||
assert payload["message"] == "<Unauthorized '401: Unauthorized'>", (scenario_name, payload)
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_chunk_update_content_and_available_contract(rest_client, create_document):
|
||||
content_cases = [
|
||||
("content none", {"content": None}, 0, ""),
|
||||
("content empty", {"content": ""}, 102, "`content` is required"),
|
||||
("content text", {"content": "update chunk"}, 0, ""),
|
||||
("content spaces", {"content": " "}, 102, "`content` is required"),
|
||||
("content punctuation", {"content": "\n!?。;!?\"'"}, 0, ""),
|
||||
]
|
||||
for scenario_name, payload, expected_code, expected_message in content_cases:
|
||||
_, _, chunk_id, base_path = _create_chunk_for_update(rest_client, create_document, f"{scenario_name}.txt")
|
||||
res = rest_client.patch(f"{base_path}/{chunk_id}", json=payload)
|
||||
assert res.status_code == 200, (scenario_name, res.text)
|
||||
body = res.json()
|
||||
assert body["code"] == expected_code, (scenario_name, body)
|
||||
if expected_code != 0:
|
||||
assert body["message"] == expected_message, (scenario_name, body)
|
||||
|
||||
available_cases = [
|
||||
("available true", {"available": True}, 0, ""),
|
||||
("available true str", {"available": "True"}, 100, "invalid literal for int()"),
|
||||
("available one", {"available": 1}, 0, ""),
|
||||
("available false", {"available": False}, 0, ""),
|
||||
("available false str", {"available": "False"}, 100, "invalid literal for int()"),
|
||||
("available zero", {"available": 0}, 0, ""),
|
||||
]
|
||||
for scenario_name, payload, expected_code, expected_message in available_cases:
|
||||
_, _, chunk_id, base_path = _create_chunk_for_update(rest_client, create_document, f"{scenario_name}.txt")
|
||||
res = rest_client.patch(f"{base_path}/{chunk_id}", json=payload)
|
||||
assert res.status_code == 200, (scenario_name, res.text)
|
||||
body = res.json()
|
||||
assert body["code"] == expected_code, (scenario_name, body)
|
||||
if expected_code != 0:
|
||||
assert expected_message in body["message"], (scenario_name, body)
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_chunk_update_keywords_questions_and_tag_contract(rest_client, create_document):
|
||||
_, _, chunk_id, base_path = _create_chunk_for_update(rest_client, create_document, "chunk_update_fields.txt")
|
||||
cases = [
|
||||
("important keywords", {"important_keywords": ["a", "b", "c"]}, 0, ""),
|
||||
("important keywords empty", {"important_keywords": [""]}, 0, ""),
|
||||
("important keywords int", {"important_keywords": [1]}, 100, "TypeError"),
|
||||
("important keywords dup", {"important_keywords": ["a", "a"]}, 0, ""),
|
||||
("important keywords str", {"important_keywords": "abc"}, 102, "`important_keywords` should be a list"),
|
||||
("important keywords number", {"important_keywords": 123}, 102, "`important_keywords` should be a list"),
|
||||
("questions", {"questions": ["a", "b", "c"]}, 0, ""),
|
||||
("questions empty", {"questions": [""]}, 0, ""),
|
||||
("questions int", {"questions": [1]}, 100, "TypeError"),
|
||||
("questions dup", {"questions": ["a", "a"]}, 0, ""),
|
||||
("questions str", {"questions": "abc"}, 102, "`questions` should be a list"),
|
||||
("questions number", {"questions": 123}, 102, "`questions` should be a list"),
|
||||
("tag kwd", {"tag_kwd": ["tag1", "tag2"]}, 0, ""),
|
||||
("tag kwd empty", {"tag_kwd": [""]}, 0, ""),
|
||||
("tag kwd int in list", {"tag_kwd": [1]}, 102, "`tag_kwd` must be a list of strings"),
|
||||
("tag kwd dup", {"tag_kwd": ["tag", "tag"]}, 0, ""),
|
||||
("tag kwd str", {"tag_kwd": "tag"}, 102, "`tag_kwd` should be a list"),
|
||||
("tag kwd number", {"tag_kwd": 123}, 102, "`tag_kwd` should be a list"),
|
||||
]
|
||||
for scenario_name, payload, expected_code, expected_message in cases:
|
||||
res = rest_client.patch(f"{base_path}/{chunk_id}", json=payload)
|
||||
assert res.status_code == 200, (scenario_name, res.text)
|
||||
body = res.json()
|
||||
assert body["code"] == expected_code, (scenario_name, body)
|
||||
if expected_code != 0:
|
||||
assert expected_message in body["message"], (scenario_name, body)
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_chunk_update_invalid_target_and_param_contract(rest_client, create_document):
|
||||
dataset_id, document_id, chunk_id, base_path = _create_chunk_for_update(rest_client, create_document, "chunk_update_invalid_targets.txt")
|
||||
|
||||
invalid_dataset_res = rest_client.patch(
|
||||
f"/datasets/{INVALID_ID_32}/documents/{document_id}/chunks/{chunk_id}",
|
||||
json={"content": "updated"},
|
||||
)
|
||||
assert invalid_dataset_res.status_code == 200
|
||||
invalid_dataset_payload = invalid_dataset_res.json()
|
||||
assert invalid_dataset_payload["code"] == 102, invalid_dataset_payload
|
||||
assert invalid_dataset_payload["message"] in {
|
||||
f"You don't own the dataset {INVALID_ID_32}.",
|
||||
f"Can't find this chunk {chunk_id}",
|
||||
}, invalid_dataset_payload
|
||||
|
||||
invalid_document_res = rest_client.patch(
|
||||
f"/datasets/{dataset_id}/documents/{INVALID_ID_32}/chunks/{chunk_id}",
|
||||
json={"content": "updated"},
|
||||
)
|
||||
assert invalid_document_res.status_code == 200
|
||||
invalid_document_payload = invalid_document_res.json()
|
||||
assert invalid_document_payload["code"] == 102, invalid_document_payload
|
||||
assert invalid_document_payload["message"] == f"You don't own the document {INVALID_ID_32}.", invalid_document_payload
|
||||
|
||||
invalid_chunk_res = rest_client.patch(
|
||||
f"{base_path}/{INVALID_ID_32}",
|
||||
json={"content": "updated"},
|
||||
)
|
||||
assert invalid_chunk_res.status_code == 200
|
||||
invalid_chunk_payload = invalid_chunk_res.json()
|
||||
assert invalid_chunk_payload["code"] == 102, invalid_chunk_payload
|
||||
assert invalid_chunk_payload["message"] == f"Can't find this chunk {INVALID_ID_32}", invalid_chunk_payload
|
||||
|
||||
for scenario_name, payload in (
|
||||
("unknown key", {"unknown_key": "unknown_value"}),
|
||||
("empty payload", {}),
|
||||
):
|
||||
res = rest_client.patch(f"{base_path}/{chunk_id}", json=payload)
|
||||
assert res.status_code == 200, (scenario_name, res.text)
|
||||
body = res.json()
|
||||
assert body["code"] == 0, (scenario_name, body)
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_chunk_update_repeated_concurrent_and_deleted_document_contract(rest_client, create_document):
|
||||
dataset_id, document_id, chunk_id, base_path = _create_chunk_for_update(
|
||||
rest_client, create_document, "chunk_update_repeated_concurrent_deleted.txt"
|
||||
)
|
||||
|
||||
first_res = rest_client.patch(f"{base_path}/{chunk_id}", json={"content": "chunk test 1"})
|
||||
assert first_res.status_code == 200
|
||||
assert first_res.json()["code"] == 0
|
||||
|
||||
second_res = rest_client.patch(f"{base_path}/{chunk_id}", json={"content": "chunk test 2"})
|
||||
assert second_res.status_code == 200
|
||||
assert second_res.json()["code"] == 0
|
||||
|
||||
get_after_repeat = rest_client.get(f"{base_path}/{chunk_id}")
|
||||
assert get_after_repeat.status_code == 200
|
||||
get_after_repeat_payload = get_after_repeat.json()
|
||||
assert get_after_repeat_payload["code"] == 0, get_after_repeat_payload
|
||||
assert get_after_repeat_payload["data"]["content_with_weight"] == "chunk test 2", get_after_repeat_payload
|
||||
|
||||
chunk_ids = [chunk_id]
|
||||
for index in range(3):
|
||||
add_res = rest_client.post(base_path, json={"content": f"concurrent update {index}"})
|
||||
assert add_res.status_code == 200, add_res.text
|
||||
add_payload = add_res.json()
|
||||
assert add_payload["code"] == 0, add_payload
|
||||
chunk_ids.append(add_payload["data"]["chunk"]["id"])
|
||||
|
||||
with ThreadPoolExecutor(max_workers=5) as executor:
|
||||
futures = []
|
||||
for index in range(20):
|
||||
target_id = chunk_ids[index % len(chunk_ids)]
|
||||
futures.append(
|
||||
executor.submit(
|
||||
lambda cid, i: rest_client.patch(
|
||||
f"{base_path}/{cid}",
|
||||
json={"content": f"update chunk test {i}"},
|
||||
).json(),
|
||||
target_id,
|
||||
index,
|
||||
)
|
||||
)
|
||||
results = [future.result() for future in futures]
|
||||
assert len(results) == 20, results
|
||||
assert all(item["code"] == 0 for item in results), results
|
||||
|
||||
delete_document_res = rest_client.delete(f"/datasets/{dataset_id}/documents", json={"ids": [document_id]})
|
||||
assert delete_document_res.status_code == 200
|
||||
assert delete_document_res.json()["code"] == 0
|
||||
|
||||
update_after_delete = rest_client.patch(f"{base_path}/{chunk_id}", json={"content": "after delete"})
|
||||
assert update_after_delete.status_code == 200
|
||||
update_after_delete_payload = update_after_delete.json()
|
||||
assert update_after_delete_payload["code"] == 102, update_after_delete_payload
|
||||
assert update_after_delete_payload["message"] in {
|
||||
f"You don't own the document {document_id}.",
|
||||
f"Can't find this chunk {chunk_id}",
|
||||
}, update_after_delete_payload
|
||||
|
||||
425
test/testcases/restful_api/test_dify_retrieval_routes_unit.py
Normal file
425
test/testcases/restful_api/test_dify_retrieval_routes_unit.py
Normal file
@ -0,0 +1,425 @@
|
||||
#
|
||||
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import importlib.util
|
||||
import inspect
|
||||
import sys
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from types import ModuleType, SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class _DummyManager:
|
||||
def route(self, *_args, **_kwargs):
|
||||
def decorator(func):
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class _AwaitableValue:
|
||||
def __init__(self, value):
|
||||
self._value = value
|
||||
|
||||
def __await__(self):
|
||||
async def _co():
|
||||
return self._value
|
||||
|
||||
return _co().__await__()
|
||||
|
||||
|
||||
class _DummyKB:
|
||||
def __init__(self, tenant_id="tenant-1", embd_id="embd-1", tenant_embd_id=1):
|
||||
self.tenant_id = tenant_id
|
||||
self.embd_id = embd_id
|
||||
self.tenant_embd_id = tenant_embd_id
|
||||
|
||||
|
||||
class _DummyRetriever:
|
||||
async def retrieval(self, *_args, **_kwargs):
|
||||
return {
|
||||
"chunks": [
|
||||
{"doc_id": "doc-1", "content_with_weight": "chunk-content", "similarity": 0.8, "docnm_kwd": "doc-title", "vector": [0.1]}
|
||||
]
|
||||
}
|
||||
|
||||
def retrieval_by_children(self, chunks, _tenant_ids):
|
||||
return chunks
|
||||
|
||||
|
||||
def _run(coro):
|
||||
return asyncio.run(coro)
|
||||
|
||||
|
||||
def _load_dify_retrieval_module(monkeypatch):
|
||||
repo_root = Path(__file__).resolve().parents[3]
|
||||
|
||||
common_pkg = ModuleType("common")
|
||||
common_pkg.__path__ = [str(repo_root / "common")]
|
||||
monkeypatch.setitem(sys.modules, "common", common_pkg)
|
||||
|
||||
deepdoc_pkg = ModuleType("deepdoc")
|
||||
deepdoc_parser_pkg = ModuleType("deepdoc.parser")
|
||||
deepdoc_parser_pkg.__path__ = []
|
||||
|
||||
class _StubPdfParser:
|
||||
pass
|
||||
|
||||
class _StubExcelParser:
|
||||
pass
|
||||
|
||||
class _StubDocxParser:
|
||||
pass
|
||||
|
||||
deepdoc_parser_pkg.PdfParser = _StubPdfParser
|
||||
deepdoc_parser_pkg.ExcelParser = _StubExcelParser
|
||||
deepdoc_parser_pkg.DocxParser = _StubDocxParser
|
||||
deepdoc_pkg.parser = deepdoc_parser_pkg
|
||||
monkeypatch.setitem(sys.modules, "deepdoc", deepdoc_pkg)
|
||||
monkeypatch.setitem(sys.modules, "deepdoc.parser", deepdoc_parser_pkg)
|
||||
|
||||
deepdoc_excel_module = ModuleType("deepdoc.parser.excel_parser")
|
||||
deepdoc_excel_module.RAGFlowExcelParser = _StubExcelParser
|
||||
monkeypatch.setitem(sys.modules, "deepdoc.parser.excel_parser", deepdoc_excel_module)
|
||||
|
||||
deepdoc_parser_utils = ModuleType("deepdoc.parser.utils")
|
||||
deepdoc_parser_utils.get_text = lambda *_args, **_kwargs: ""
|
||||
monkeypatch.setitem(sys.modules, "deepdoc.parser.utils", deepdoc_parser_utils)
|
||||
monkeypatch.setitem(sys.modules, "xgboost", ModuleType("xgboost"))
|
||||
|
||||
tenant_llm_service_mod = ModuleType("api.db.services.tenant_llm_service")
|
||||
|
||||
class _MockModelConfig:
|
||||
def __init__(self, tenant_id, model_name):
|
||||
self.tenant_id = tenant_id
|
||||
self.llm_name = model_name
|
||||
self.llm_factory = "Builtin"
|
||||
self.api_key = "fake-api-key"
|
||||
self.api_base = "https://api.example.com"
|
||||
self.model_type = "chat"
|
||||
self.max_tokens = 8192
|
||||
self.used_tokens = 0
|
||||
self.status = 1
|
||||
self.id = 1
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"tenant_id": self.tenant_id,
|
||||
"llm_name": self.llm_name,
|
||||
"llm_factory": self.llm_factory,
|
||||
"api_key": self.api_key,
|
||||
"api_base": self.api_base,
|
||||
"model_type": self.model_type,
|
||||
"max_tokens": self.max_tokens,
|
||||
"used_tokens": self.used_tokens,
|
||||
"status": self.status,
|
||||
"id": self.id,
|
||||
}
|
||||
|
||||
class _StubTenantService:
|
||||
@staticmethod
|
||||
def get_by_id(tenant_id):
|
||||
return True, SimpleNamespace(
|
||||
id=tenant_id,
|
||||
llm_id="chat-model",
|
||||
embd_id="embd-model",
|
||||
asr_id="asr-model",
|
||||
img2txt_id="img2txt-model",
|
||||
rerank_id="rerank-model",
|
||||
tts_id="tts-model",
|
||||
)
|
||||
|
||||
class _StubTenantLLMService:
|
||||
@staticmethod
|
||||
def get_api_key(tenant_id, model_name):
|
||||
return _MockModelConfig(tenant_id, model_name)
|
||||
|
||||
@staticmethod
|
||||
def split_model_name_and_factory(model_name):
|
||||
if "@" in model_name:
|
||||
parts = model_name.split("@")
|
||||
return parts[0], parts[1]
|
||||
return model_name, None
|
||||
|
||||
tenant_llm_service_mod.TenantService = _StubTenantService
|
||||
tenant_llm_service_mod.TenantLLMService = _StubTenantLLMService
|
||||
|
||||
class _StubLLMFactoriesService:
|
||||
pass
|
||||
|
||||
tenant_llm_service_mod.LLMFactoriesService = _StubLLMFactoriesService
|
||||
monkeypatch.setitem(sys.modules, "api.db.services.tenant_llm_service", tenant_llm_service_mod)
|
||||
|
||||
llm_service_mod = ModuleType("api.db.services.llm_service")
|
||||
|
||||
class _StubLLM:
|
||||
def __init__(self, llm_name):
|
||||
self.llm_name = llm_name
|
||||
self.is_tools = False
|
||||
|
||||
class _StubLLMBundle:
|
||||
def __init__(self, tenant_id: str, model_config: dict, lang="Chinese", **kwargs):
|
||||
self.tenant_id = tenant_id
|
||||
self.model_config = model_config
|
||||
self.lang = lang
|
||||
|
||||
def encode(self, texts: list):
|
||||
import numpy as np
|
||||
|
||||
return [np.array([0.1, 0.2, 0.3]) for _ in texts], len(texts) * 10
|
||||
|
||||
llm_service_mod.LLMService = SimpleNamespace(query=lambda llm_name: [_StubLLM(llm_name)] if llm_name else [])
|
||||
llm_service_mod.LLMBundle = _StubLLMBundle
|
||||
monkeypatch.setitem(sys.modules, "api.db.services.llm_service", llm_service_mod)
|
||||
|
||||
tenant_model_service_mod = ModuleType("api.db.joint_services.tenant_model_service")
|
||||
|
||||
class _MockModelConfig2:
|
||||
def __init__(self, tenant_id, model_name):
|
||||
self.tenant_id = tenant_id
|
||||
self.llm_name = model_name
|
||||
self.llm_factory = "Builtin"
|
||||
self.api_key = "fake-api-key"
|
||||
self.api_base = "https://api.example.com"
|
||||
self.model_type = "chat"
|
||||
self.max_tokens = 8192
|
||||
self.used_tokens = 0
|
||||
self.status = 1
|
||||
self.id = 1
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"tenant_id": self.tenant_id,
|
||||
"llm_name": self.llm_name,
|
||||
"llm_factory": self.llm_factory,
|
||||
"api_key": self.api_key,
|
||||
"api_base": self.api_base,
|
||||
"model_type": self.model_type,
|
||||
"max_tokens": self.max_tokens,
|
||||
"used_tokens": self.used_tokens,
|
||||
"status": self.status,
|
||||
"id": self.id,
|
||||
}
|
||||
|
||||
def _get_model_config_by_id(tenant_model_id: int, allowed_tenant_ids=None, requester_tenant_id=None) -> dict:
|
||||
mock_tenant_id = "tenant-1"
|
||||
if allowed_tenant_ids is not None:
|
||||
if isinstance(allowed_tenant_ids, str):
|
||||
allowed_tenant_ids = {allowed_tenant_ids}
|
||||
else:
|
||||
allowed_tenant_ids = {str(tenant_id) for tenant_id in allowed_tenant_ids if tenant_id}
|
||||
if mock_tenant_id not in allowed_tenant_ids and str(requester_tenant_id) != mock_tenant_id:
|
||||
raise LookupError(f"Tenant Model with id {tenant_model_id} not authorized")
|
||||
return _MockModelConfig2(mock_tenant_id, "model-1").to_dict()
|
||||
|
||||
def _get_model_config_by_type_and_name(tenant_id: str, model_type: str, model_name: str):
|
||||
if not model_name:
|
||||
raise Exception("Model Name is required")
|
||||
return _MockModelConfig2(tenant_id, model_name).to_dict()
|
||||
|
||||
def _get_tenant_default_model_by_type(tenant_id: str, model_type):
|
||||
return _MockModelConfig2(tenant_id, "chat-model").to_dict()
|
||||
|
||||
tenant_model_service_mod.get_model_config_by_id = _get_model_config_by_id
|
||||
tenant_model_service_mod.get_model_config_by_type_and_name = _get_model_config_by_type_and_name
|
||||
tenant_model_service_mod.get_tenant_default_model_by_type = _get_tenant_default_model_by_type
|
||||
monkeypatch.setitem(sys.modules, "api.db.joint_services.tenant_model_service", tenant_model_service_mod)
|
||||
|
||||
module_name = "test_dify_retrieval_routes_unit_module"
|
||||
module_path = repo_root / "api" / "apps" / "sdk" / "dify_retrieval.py"
|
||||
spec = importlib.util.spec_from_file_location(module_name, module_path)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
module.manager = _DummyManager()
|
||||
monkeypatch.setitem(sys.modules, module_name, module)
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
|
||||
def _set_request_json(monkeypatch, module, payload):
|
||||
monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue(deepcopy(payload)))
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_retrieval_success_with_metadata_and_kg(monkeypatch):
|
||||
module = _load_dify_retrieval_module(monkeypatch)
|
||||
_set_request_json(
|
||||
monkeypatch,
|
||||
module,
|
||||
{
|
||||
"knowledge_id": "kb-1",
|
||||
"query": "hello",
|
||||
"use_kg": True,
|
||||
"retrieval_setting": {"score_threshold": 0.1, "top_k": 3},
|
||||
"metadata_condition": {"conditions": [{"name": "author", "comparison_operator": "is", "value": "alice"}], "logic": "and"},
|
||||
},
|
||||
)
|
||||
|
||||
monkeypatch.setattr(module, "jsonify", lambda payload: payload)
|
||||
monkeypatch.setattr(module.DocMetadataService, "get_flatted_meta_by_kbs", lambda _kbs: [{"doc_id": "doc-1"}])
|
||||
monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, _DummyKB()))
|
||||
monkeypatch.setattr(module, "convert_conditions", lambda cond: cond.get("conditions", []))
|
||||
monkeypatch.setattr(module, "meta_filter", lambda *_args, **_kwargs: [])
|
||||
|
||||
retriever = _DummyRetriever()
|
||||
monkeypatch.setattr(module.settings, "retriever", retriever)
|
||||
|
||||
class _DummyKgRetriever:
|
||||
async def retrieval(self, *_args, **_kwargs):
|
||||
return {
|
||||
"doc_id": "doc-2",
|
||||
"content_with_weight": "kg-content",
|
||||
"similarity": 0.9,
|
||||
"docnm_kwd": "kg-title",
|
||||
}
|
||||
|
||||
monkeypatch.setattr(module.settings, "kg_retriever", _DummyKgRetriever())
|
||||
monkeypatch.setattr(module.DocumentService, "get_by_id", lambda doc_id: (True, SimpleNamespace(meta_fields={"origin": f"meta-{doc_id}"})))
|
||||
monkeypatch.setattr(module, "label_question", lambda *_args, **_kwargs: [])
|
||||
|
||||
res = _run(inspect.unwrap(module.retrieval)("tenant-1"))
|
||||
assert "records" in res, res
|
||||
assert len(res["records"]) == 2, res
|
||||
top = res["records"][0]
|
||||
assert top["title"] == "kg-title", res
|
||||
assert top["metadata"]["doc_id"] == "doc-2", res
|
||||
assert "score" in top, res
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_retrieval_kb_not_found(monkeypatch):
|
||||
module = _load_dify_retrieval_module(monkeypatch)
|
||||
_set_request_json(monkeypatch, module, {"knowledge_id": "kb-missing", "query": "hello"})
|
||||
monkeypatch.setattr(module.DocMetadataService, "get_flatted_meta_by_kbs", lambda _kbs: [])
|
||||
monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (False, None))
|
||||
|
||||
res = _run(inspect.unwrap(module.retrieval)("tenant-1"))
|
||||
assert res["code"] == module.RetCode.NOT_FOUND, res
|
||||
assert "Knowledgebase not found" in res["message"], res
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_retrieval_not_found_exception_mapping(monkeypatch):
|
||||
module = _load_dify_retrieval_module(monkeypatch)
|
||||
_set_request_json(monkeypatch, module, {"knowledge_id": "kb-1", "query": "hello"})
|
||||
monkeypatch.setattr(module.DocMetadataService, "get_flatted_meta_by_kbs", lambda _kbs: [])
|
||||
monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, _DummyKB()))
|
||||
monkeypatch.setattr(module, "label_question", lambda *_args, **_kwargs: [])
|
||||
|
||||
class _BrokenRetriever:
|
||||
async def retrieval(self, *_args, **_kwargs):
|
||||
raise RuntimeError("chunk_not_found_error")
|
||||
|
||||
monkeypatch.setattr(module.settings, "retriever", _BrokenRetriever())
|
||||
|
||||
res = _run(inspect.unwrap(module.retrieval)("tenant-1"))
|
||||
assert res["code"] == module.RetCode.NOT_FOUND, res
|
||||
assert "No chunk found" in res["message"], res
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_retrieval_generic_exception_mapping(monkeypatch):
|
||||
module = _load_dify_retrieval_module(monkeypatch)
|
||||
_set_request_json(monkeypatch, module, {"knowledge_id": "kb-1", "query": "hello"})
|
||||
monkeypatch.setattr(module.DocMetadataService, "get_flatted_meta_by_kbs", lambda _kbs: [])
|
||||
monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, _DummyKB()))
|
||||
monkeypatch.setattr(module, "label_question", lambda *_args, **_kwargs: [])
|
||||
|
||||
class _BrokenRetriever:
|
||||
async def retrieval(self, *_args, **_kwargs):
|
||||
raise RuntimeError("boom")
|
||||
|
||||
monkeypatch.setattr(module.settings, "retriever", _BrokenRetriever())
|
||||
|
||||
res = _run(inspect.unwrap(module.retrieval)("tenant-1"))
|
||||
assert res["code"] == module.RetCode.SERVER_ERROR, res
|
||||
assert "boom" in res["message"], res
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_read_retrieval_request_from_get_args(monkeypatch):
|
||||
module = _load_dify_retrieval_module(monkeypatch)
|
||||
monkeypatch.setattr(
|
||||
module,
|
||||
"request",
|
||||
SimpleNamespace(
|
||||
method="GET",
|
||||
args={
|
||||
"knowledge_id": "kb-1",
|
||||
"query": "hello",
|
||||
"use_kg": "true",
|
||||
"top_k": "12",
|
||||
"score_threshold": "0.66",
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
req = _run(module._read_retrieval_request())
|
||||
assert req["knowledge_id"] == "kb-1", req
|
||||
assert req["query"] == "hello", req
|
||||
assert req["use_kg"] is True, req
|
||||
assert req["retrieval_setting"]["top_k"] == 12, req
|
||||
assert req["retrieval_setting"]["score_threshold"] == 0.66, req
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_read_retrieval_request_from_post_json(monkeypatch):
|
||||
module = _load_dify_retrieval_module(monkeypatch)
|
||||
payload = {"knowledge_id": "kb-1", "query": "hello"}
|
||||
monkeypatch.setattr(module, "request", SimpleNamespace(method="POST", args={}))
|
||||
monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue(payload))
|
||||
|
||||
req = _run(module._read_retrieval_request())
|
||||
assert req == payload, req
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_retrieval_argument_error_messages(monkeypatch):
|
||||
module = _load_dify_retrieval_module(monkeypatch)
|
||||
|
||||
_set_request_json(
|
||||
monkeypatch,
|
||||
module,
|
||||
{
|
||||
"knowledge_id": "kb-1",
|
||||
"query": "hello",
|
||||
"retrieval_setting": {"top_k": "not-int", "score_threshold": "not-float"},
|
||||
},
|
||||
)
|
||||
res = _run(inspect.unwrap(module.retrieval)("tenant-1"))
|
||||
assert res["code"] == module.RetCode.ARGUMENT_ERROR, res
|
||||
assert "invalid or malformed arguments:" in res["message"], res
|
||||
|
||||
_set_request_json(monkeypatch, module, {})
|
||||
res_missing = _run(inspect.unwrap(module.retrieval)("tenant-1"))
|
||||
assert res_missing["code"] == module.RetCode.ARGUMENT_ERROR, res_missing
|
||||
assert "required arguments are missing:" in res_missing["message"], res_missing
|
||||
|
||||
_set_request_json(monkeypatch, module, {"knowledge_id": "kb-1"})
|
||||
res_missing_query = _run(inspect.unwrap(module.retrieval)("tenant-1"))
|
||||
assert res_missing_query["code"] == module.RetCode.ARGUMENT_ERROR, res_missing_query
|
||||
assert "query" in res_missing_query["message"], res_missing_query
|
||||
|
||||
_set_request_json(
|
||||
monkeypatch,
|
||||
module,
|
||||
{"knowledge_id": "kb-1", "query": "hello", "retrieval_setting": "bad-type"},
|
||||
)
|
||||
res_wrong_type = _run(inspect.unwrap(module.retrieval)("tenant-1"))
|
||||
assert res_wrong_type["code"] == module.RetCode.ARGUMENT_ERROR, res_wrong_type
|
||||
assert "retrieval_setting must be an object" in res_wrong_type["message"], res_wrong_type
|
||||
@ -14,7 +14,12 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import pytest
|
||||
import requests
|
||||
from test.testcases.configs import HOST_ADDRESS, INVALID_API_TOKEN, VERSION
|
||||
from test.testcases.restful_api.helpers.client import RestClient
|
||||
from test.testcases.utils import wait_for
|
||||
|
||||
|
||||
@pytest.mark.p1
|
||||
@ -107,3 +112,264 @@ def test_retrieval_compatibility_requires_auth(rest_client_noauth):
|
||||
# token_required preserves legacy payload code/message while returning HTTP 401.
|
||||
assert payload["code"] == 0, payload
|
||||
assert payload["message"] == "`Authorization` can't be empty", payload
|
||||
|
||||
|
||||
@wait_for(20, 1, "Retrieval indexing timeout in RESTful batch 10 tests")
|
||||
def _retrieval_has_question(rest_client, dataset_id, question):
|
||||
res = rest_client.post("/retrieval", json={"question": question, "dataset_ids": [dataset_id]})
|
||||
if res.status_code != 200:
|
||||
return False
|
||||
payload = res.json()
|
||||
if payload["code"] != 0:
|
||||
return False
|
||||
return len(payload["data"]["chunks"]) > 0
|
||||
|
||||
|
||||
@wait_for(20, 1, "Retrieval indexing timeout waiting for chunk presence in RESTful batch 10 tests")
|
||||
def _retrieval_has_chunks(rest_client, dataset_id, question, chunk_ids):
|
||||
res = rest_client.post("/retrieval", json={"question": question, "dataset_ids": [dataset_id]})
|
||||
if res.status_code != 200:
|
||||
return False
|
||||
payload = res.json()
|
||||
if payload["code"] != 0:
|
||||
return False
|
||||
retrieved_ids = {chunk["id"] for chunk in payload["data"]["chunks"]}
|
||||
expected_ids = set(chunk_ids)
|
||||
return expected_ids.issubset(retrieved_ids)
|
||||
|
||||
|
||||
@wait_for(20, 1, "Retrieval indexing timeout waiting for chunk deletion in RESTful batch 10 tests")
|
||||
def _retrieval_lacks_chunks(rest_client, dataset_id, question, chunk_ids):
|
||||
res = rest_client.post("/retrieval", json={"question": question, "dataset_ids": [dataset_id]})
|
||||
if res.status_code != 200:
|
||||
return False
|
||||
payload = res.json()
|
||||
if payload["code"] != 0:
|
||||
return False
|
||||
retrieved_ids = {chunk["id"] for chunk in payload["data"]["chunks"]}
|
||||
expected_ids = set(chunk_ids)
|
||||
return expected_ids.isdisjoint(retrieved_ids)
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_retrieval_requires_auth_contract(ensure_parsed_document):
|
||||
dataset_id, _ = ensure_parsed_document()
|
||||
for scenario_name, token, expected_code, expected_message in (
|
||||
("missing token", None, 0, "`Authorization` can't be empty"),
|
||||
("invalid token", INVALID_API_TOKEN, 109, "Authentication error: API key is invalid!"),
|
||||
):
|
||||
client = RestClient(token=token)
|
||||
res = client.post("/retrieval", json={"question": "chunk", "dataset_ids": [dataset_id]})
|
||||
assert res.status_code == 401, (scenario_name, res.text)
|
||||
payload = res.json()
|
||||
assert payload["code"] == expected_code, (scenario_name, payload)
|
||||
assert payload["message"] == expected_message, (scenario_name, payload)
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_retrieval_page_and_page_size_contract(rest_client, ensure_parsed_document):
|
||||
dataset_id, _ = ensure_parsed_document()
|
||||
cases = [
|
||||
("page none", {"question": "chunk", "dataset_ids": [dataset_id], "page": None, "page_size": 2}, 100, "TypeError"),
|
||||
("page zero", {"question": "chunk", "dataset_ids": [dataset_id], "page": 0, "page_size": 2}, 0, ""),
|
||||
("page two", {"question": "chunk", "dataset_ids": [dataset_id], "page": 2, "page_size": 2}, 0, ""),
|
||||
("page three", {"question": "chunk", "dataset_ids": [dataset_id], "page": 3, "page_size": 2}, 0, ""),
|
||||
("page str", {"question": "chunk", "dataset_ids": [dataset_id], "page": "3", "page_size": 2}, 0, ""),
|
||||
("page negative", {"question": "chunk", "dataset_ids": [dataset_id], "page": -1, "page_size": 2}, 0, ""),
|
||||
("page alpha", {"question": "chunk", "dataset_ids": [dataset_id], "page": "a", "page_size": 2}, 100, "invalid literal for int()"),
|
||||
("page_size none", {"question": "chunk", "dataset_ids": [dataset_id], "page_size": None}, 100, "TypeError"),
|
||||
("page_size one", {"question": "chunk", "dataset_ids": [dataset_id], "page_size": 1}, 0, ""),
|
||||
("page_size five", {"question": "chunk", "dataset_ids": [dataset_id], "page_size": 5}, 0, ""),
|
||||
("page_size str", {"question": "chunk", "dataset_ids": [dataset_id], "page_size": "1"}, 0, ""),
|
||||
("page_size alpha", {"question": "chunk", "dataset_ids": [dataset_id], "page_size": "a"}, 100, "invalid literal for int()"),
|
||||
]
|
||||
for scenario_name, payload, expected_code, expected_message in cases:
|
||||
res = rest_client.post("/retrieval", json=payload)
|
||||
assert res.status_code == 200, (scenario_name, res.text)
|
||||
body = res.json()
|
||||
assert body["code"] == expected_code, (scenario_name, body)
|
||||
if expected_code != 0:
|
||||
assert expected_message in body["message"], (scenario_name, body)
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_retrieval_highlight_keyword_and_invalid_params_contract(rest_client, ensure_parsed_document):
|
||||
dataset_id, _ = ensure_parsed_document()
|
||||
|
||||
highlight_cases = [
|
||||
("highlight true", True, True),
|
||||
("highlight true str", "True", True),
|
||||
("highlight false", False, False),
|
||||
("highlight false str", "False", False),
|
||||
("highlight none", None, False),
|
||||
]
|
||||
for scenario_name, highlight_value, expect_highlight in highlight_cases:
|
||||
res = rest_client.post(
|
||||
"/retrieval",
|
||||
json={"question": "chunk", "dataset_ids": [dataset_id], "highlight": highlight_value},
|
||||
)
|
||||
assert res.status_code == 200, (scenario_name, res.text)
|
||||
body = res.json()
|
||||
assert body["code"] == 0, (scenario_name, body)
|
||||
for chunk in body["data"]["chunks"]:
|
||||
if expect_highlight:
|
||||
assert "highlight" in chunk, (scenario_name, body)
|
||||
else:
|
||||
assert "highlight" not in chunk, (scenario_name, body)
|
||||
|
||||
invalid_highlight = rest_client.post(
|
||||
"/retrieval",
|
||||
json={"question": "chunk", "dataset_ids": [dataset_id], "highlight": "not_bool"},
|
||||
)
|
||||
assert invalid_highlight.status_code == 200
|
||||
invalid_highlight_payload = invalid_highlight.json()
|
||||
assert invalid_highlight_payload["code"] == 102, invalid_highlight_payload
|
||||
assert invalid_highlight_payload["message"] == "`highlight` should be a boolean", invalid_highlight_payload
|
||||
|
||||
for scenario_name, keyword_value in (
|
||||
("keyword true", True),
|
||||
("keyword true str", "True"),
|
||||
("keyword false", False),
|
||||
("keyword false str", "False"),
|
||||
("keyword none", None),
|
||||
):
|
||||
keyword_res = rest_client.post(
|
||||
"/retrieval",
|
||||
json={"question": "chunk test", "dataset_ids": [dataset_id], "keyword": keyword_value},
|
||||
)
|
||||
assert keyword_res.status_code == 200, (scenario_name, keyword_res.text)
|
||||
keyword_payload = keyword_res.json()
|
||||
assert keyword_payload["code"] == 0, (scenario_name, keyword_payload)
|
||||
assert isinstance(keyword_payload["data"]["chunks"], list), (scenario_name, keyword_payload)
|
||||
|
||||
invalid_params_res = rest_client.post(
|
||||
"/retrieval",
|
||||
json={"question": "chunk", "dataset_ids": [dataset_id], "a": "b"},
|
||||
)
|
||||
assert invalid_params_res.status_code == 200
|
||||
invalid_params_payload = invalid_params_res.json()
|
||||
assert invalid_params_payload["code"] == 0, invalid_params_payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_retrieval_vector_similarity_and_top_k_contract(rest_client, ensure_parsed_document):
|
||||
dataset_id, _ = ensure_parsed_document()
|
||||
cases = [
|
||||
("vector 0", {"vector_similarity_weight": 0}, 0, ""),
|
||||
("vector 0.5", {"vector_similarity_weight": 0.5}, 0, ""),
|
||||
("vector 10", {"vector_similarity_weight": 10}, 0, ""),
|
||||
("vector alpha", {"vector_similarity_weight": "a"}, 100, "could not convert string to float"),
|
||||
("top_k 10", {"top_k": 10}, 0, ""),
|
||||
("top_k 1", {"top_k": 1}, 0, ""),
|
||||
("top_k -1", {"top_k": -1}, 102, "`top_k` must be greater than 0"),
|
||||
("top_k alpha", {"top_k": "a"}, 100, "invalid literal for int()"),
|
||||
]
|
||||
for scenario_name, updates, expected_code, expected_message in cases:
|
||||
payload = {"question": "chunk", "dataset_ids": [dataset_id]}
|
||||
payload.update(updates)
|
||||
res = rest_client.post("/retrieval", json=payload)
|
||||
assert res.status_code == 200, (scenario_name, res.text)
|
||||
body = res.json()
|
||||
assert body["code"] == expected_code, (scenario_name, body)
|
||||
if expected_code != 0:
|
||||
assert expected_message in body["message"], (scenario_name, body)
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_retrieval_rerank_unknown_contract(rest_client, ensure_parsed_document):
|
||||
dataset_id, _ = ensure_parsed_document()
|
||||
res = rest_client.post(
|
||||
"/retrieval",
|
||||
json={"question": "chunk", "dataset_ids": [dataset_id], "rerank_id": "unknown"},
|
||||
)
|
||||
assert res.status_code == 200
|
||||
payload = res.json()
|
||||
assert payload["code"] != 0, payload
|
||||
assert payload["message"], payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_retrieval_concurrent_contract(rest_client, ensure_parsed_document):
|
||||
dataset_id, _ = ensure_parsed_document()
|
||||
payload = {"question": "chunk", "dataset_ids": [dataset_id]}
|
||||
with ThreadPoolExecutor(max_workers=5) as executor:
|
||||
results = list(executor.map(lambda _: rest_client.post("/retrieval", json=payload).json(), range(20)))
|
||||
assert len(results) == 20, results
|
||||
assert all(result["code"] == 0 for result in results), results
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_deleted_chunk_not_in_retrieval_contract(rest_client, create_document):
|
||||
dataset_id, document_id = create_document("retrieval_deleted_chunk.txt")
|
||||
base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks"
|
||||
content = "UNIQUE_TEST_CONTENT_12520_REST"
|
||||
|
||||
add_res = rest_client.post(base_path, json={"content": content})
|
||||
assert add_res.status_code == 200
|
||||
add_payload = add_res.json()
|
||||
assert add_payload["code"] == 0, add_payload
|
||||
chunk_id = add_payload["data"]["chunk"]["id"]
|
||||
|
||||
_retrieval_has_chunks(rest_client, dataset_id, content, [chunk_id])
|
||||
|
||||
delete_res = rest_client.delete(base_path, json={"chunk_ids": [chunk_id]})
|
||||
assert delete_res.status_code == 200
|
||||
assert delete_res.json()["code"] == 0
|
||||
_retrieval_lacks_chunks(rest_client, dataset_id, content, [chunk_id])
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_deleted_chunks_batch_not_in_retrieval_contract(rest_client, create_document):
|
||||
dataset_id, document_id = create_document("retrieval_deleted_chunks_batch.txt")
|
||||
base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks"
|
||||
chunk_ids = []
|
||||
for index in range(3):
|
||||
content = f"BATCH_DELETE_TEST_CHUNK_{index}_REST_12520"
|
||||
add_res = rest_client.post(base_path, json={"content": content})
|
||||
assert add_res.status_code == 200
|
||||
add_payload = add_res.json()
|
||||
assert add_payload["code"] == 0, add_payload
|
||||
chunk_ids.append(add_payload["data"]["chunk"]["id"])
|
||||
_retrieval_has_chunks(rest_client, dataset_id, "BATCH_DELETE_TEST_CHUNK", chunk_ids)
|
||||
|
||||
delete_res = rest_client.delete(base_path, json={"chunk_ids": chunk_ids})
|
||||
assert delete_res.status_code == 200
|
||||
assert delete_res.json()["code"] == 0
|
||||
_retrieval_lacks_chunks(rest_client, dataset_id, "BATCH_DELETE_TEST_CHUNK", chunk_ids)
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_related_questions_contract(auth, rest_client, rest_client_noauth):
|
||||
tokens_res = requests.get(
|
||||
f"{HOST_ADDRESS}/api/{VERSION}/system/tokens",
|
||||
headers={"Authorization": auth},
|
||||
timeout=30,
|
||||
)
|
||||
assert tokens_res.status_code == 200, tokens_res.text
|
||||
tokens_payload = tokens_res.json()
|
||||
assert tokens_payload["code"] == 0, tokens_payload
|
||||
assert tokens_payload["data"], tokens_payload
|
||||
beta_token = tokens_payload["data"][0]["beta"]
|
||||
|
||||
success_client = RestClient(token=beta_token)
|
||||
success_res = success_client.post("/searchbots/related_questions", json={"question": "ragflow", "industry": "search"})
|
||||
assert success_res.status_code == 200
|
||||
success_payload = success_res.json()
|
||||
assert success_payload["code"] == 0, success_payload
|
||||
assert isinstance(success_payload["data"], list), success_payload
|
||||
|
||||
missing_res = rest_client.post("/searchbots/related_questions", json={"industry": "search"})
|
||||
assert missing_res.status_code == 200
|
||||
missing_payload = missing_res.json()
|
||||
assert missing_payload["code"] == 101, missing_payload
|
||||
assert "question" in missing_payload["message"], missing_payload
|
||||
|
||||
invalid_auth_res = rest_client_noauth.post(
|
||||
"/searchbots/related_questions",
|
||||
json={"question": "ragflow", "industry": "search"},
|
||||
headers={"Authorization": "invalid"},
|
||||
)
|
||||
assert invalid_auth_res.status_code == 200
|
||||
invalid_auth_payload = invalid_auth_res.json()
|
||||
assert invalid_auth_payload["code"] == 102, invalid_auth_payload
|
||||
assert "Authorization is not valid!" in invalid_auth_payload["message"], invalid_auth_payload
|
||||
|
||||
Reference in New Issue
Block a user